In [1]:
import torch

def Neg(rhoS, rhoC):
    rhoC_pt = rhoC.view(-1,2,2,2,2).permute(0,1,4,3,2).reshape(-1,4,4)
    rhoS_pt = rhoS.view(-1,2,2,2,2).permute(0,1,4,3,2).reshape(-1,4,4)
    e, v = torch.linalg.eig(rhoC_pt)
    #e += 1e-5
    mask = e.real < 0
    negative_v = v * mask.unsqueeze(1)
    P = torch.bmm(negative_v, negative_v.mT.conj()) # projection matrix
    return -torch.vmap(torch.trace)(torch.bmm(P, rhoS_pt)).real

def blogm(A):
    E, U = torch.linalg.eig(A)
    #E += 1e-5
    logE = torch.log(E.abs()).to(U.dtype)
    logA = torch.bmm(torch.bmm(U, torch.diag_embed(logE, offset=0, dim1=-2, dim2=-1)), U.conj().mT)
    return logA

def bSqc(rhoQ, rhoC):
    return -torch.vmap(torch.trace)(rhoQ@blogm(rhoC)).real

def eps(rho, e=0.1):
    I = torch.eye(rho.shape[-1], rho.shape[-1], dtype=rho.dtype, device=rho.device)[None,...].expand(rho.shape[0], -1, -1)/rho.shape[-1]
    return (1-e)*rho + e*I

In [2]:
H = (torch.tensor([[1,1],[1,-1]])/2**0.5).to(torch.cdouble)
H.conj().T@H

tensor([[1.0000+0.j, 0.0000+0.j],
        [0.0000+0.j, 1.0000+0.j]], dtype=torch.complex128)

In [3]:
u = torch.tensor([[torch.tensor(-1j*torch.pi/4).exp()/torch.sqrt(torch.tensor(2)),torch.tensor(1j*torch.pi/4).exp()/torch.sqrt(torch.tensor(2))],
                  [0,0],
                  [0,0],
                  [torch.tensor(1j*torch.pi/4).exp()/torch.sqrt(torch.tensor(2)),torch.tensor(-1j*torch.pi/4).exp()/torch.sqrt(torch.tensor(2))]]).view(2,2,2).to(torch.cdouble)
v = torch.tensor([[0,0,0,2**0.5],[2**0.5,0,0,0]]).view(2,2,2).to(torch.cdouble)
torch.einsum(u, [1,2,3], v, [3,4,5], [1,4,2,5]).reshape(4,4)
v = v.permute(1,2,0)

In [4]:
theta = torch.rand(1)*torch.pi/2
phi = torch.rand(1)*torch.pi
U = torch.tensor([[torch.cos(theta/2)*torch.exp(1.0j*phi/2),torch.sin(theta/2)*torch.exp(-1.0j*phi/2)],
                          [-torch.sin(theta/2)*torch.exp(1.0j*phi/2),torch.cos(theta/2)*torch.exp(-1.0j*phi/2)]])

U.T.conj()@U

tensor([[ 1.0000e+00+4.1024e-10j, -2.2664e-09+3.3258e-10j],
        [-2.2664e-09+2.9080e-10j,  1.0000e+00-3.9737e-10j]])

In [5]:
def cross(m, theta, phi, kernel):
    U = torch.tensor([[torch.cos(theta/2)*torch.exp(1.0j*phi/2),torch.sin(theta/2)*torch.exp(-1.0j*phi/2)],
                          [-torch.sin(theta/2)*torch.exp(1.0j*phi/2),torch.cos(theta/2)*torch.exp(-1.0j*phi/2)]], dtype=kernel.dtype)
    out = H@torch.tensor([1,0], dtype=kernel.dtype)
    out = torch.einsum(kernel, [0,1,2],
                       out, [0],
                       [1,2])
    out = torch.einsum(kernel, [1,3,4],
                       out, [1,2],
                       [2,3,4])
    out = torch.einsum(kernel, [3,5,6],
                       out, [2,3,4],
                       [2,4,5,6])
    out = torch.einsum(kernel, [5,7,8],
                       out, [2,4,5,6],
                       [7,8,6,4,2])
    # single qubit rotation
    out = torch.einsum(U, [9,7],
                       out, [7,8,6,4,2],
                       [9,8,6,4,2])
    if m == 0:
        m = torch.tensor([1,0], dtype=u.dtype)
    elif m == 1:
        m = torch.tensor([0,1], dtype=u.dtype)
    if m is not None:
        out = torch.einsum(m, [9],
                        out, [9,8,6,4,2],
                        [8,6,4,2])
    return out

def side(m, theta, phi, kernel):
    U = torch.tensor([[torch.cos(theta/2)*torch.exp(1.0j*phi/2),torch.sin(theta/2)*torch.exp(-1.0j*phi/2)],
                          [-torch.sin(theta/2)*torch.exp(1.0j*phi/2),torch.cos(theta/2)*torch.exp(-1.0j*phi/2)]], dtype=kernel.dtype)
    out = H@torch.tensor([1,0], dtype=kernel.dtype)
    out = torch.einsum(kernel, [0,1,2],
                       out, [0],
                       [1,2])
    out = torch.einsum(kernel, [1,3,4],
                       out, [1,2],
                       [2,3,4])
    out = torch.einsum(kernel, [3,5,6],
                       out, [2,3,4],
                       [2,4,5,6])
    # single qubit rotation
    out = torch.einsum(U, [7,5],
                       out, [2,4,5,6],
                       [7,6,4,2])
    if m == 0:
        m = torch.tensor([1,0], dtype=u.dtype)
    elif m == 1:
        m = torch.tensor([0,1], dtype=u.dtype)
    if m is not None: # (phy, bond, bond, bond)
        out = torch.einsum(m, [7],
                        out, [7,6,4,2],
                        [6,4,2])
    return out

def corner(m, theta, phi, kernel):
    U = torch.tensor([[torch.cos(theta/2)*torch.exp(1.0j*phi/2),torch.sin(theta/2)*torch.exp(-1.0j*phi/2)],
                          [-torch.sin(theta/2)*torch.exp(1.0j*phi/2),torch.cos(theta/2)*torch.exp(-1.0j*phi/2)]], dtype=kernel.dtype)
    out = H@torch.tensor([1,0], dtype=kernel.dtype)
    out = torch.einsum(kernel, [0,1,2],
                       out, [0],
                       [1,2])
    out = torch.einsum(kernel, [1,3,4],
                       out, [1,2],
                       [2,3,4])
    # single qubit rotation
    out = torch.einsum(U, [5,3],
                       out, [2,3,4],
                       [5,2,4])
    if m == 0:
        m = torch.tensor([1,0], dtype=u.dtype)
    elif m == 1:
        m = torch.tensor([0,1], dtype=u.dtype)
    if m is not None:
        out = torch.einsum(m, [5],
                        out, [5,2,4],
                        [2,4])
    return out

def contract(src, src_idx, dst, dst_idx):
    out_idx = src_idx.copy()
    for i in dst_idx:
        if i != 0: # batch dim
            if i in out_idx:
                out_idx.remove(i)
            else:
                out_idx.append(i)
    #print(out_idx)
    out_tensor = torch.einsum(src, src_idx,
                              dst, dst_idx,
                              out_idx)
    return out_tensor, out_idx

In [6]:
torch.arange(0,0.11,0.01).tolist()

[0.0,
 0.009999999776482582,
 0.019999999552965164,
 0.029999999329447746,
 0.03999999910593033,
 0.05000000074505806,
 0.05999999865889549,
 0.07000000029802322,
 0.07999999821186066,
 0.09000000357627869,
 0.10000000149011612]

In [7]:
rates = torch.arange(0,0.31,0.01).tolist()
out = torch.zeros(11, 36, len(rates))
for theta_idx in range(11):
    prep0, rhoS = torch.load(f'data/theta{theta_idx}/all_prepseq_theta={theta_idx}.pt'), torch.load(f'data/theta{theta_idx}/all_rhoS_theta={theta_idx}.pt')
    prep0 = prep0[:10**6]
    rhoS = rhoS[:10**6]
    #prep0 = torch.randint(0,2,(100,34))
    print(f'theta_idx={theta_idx} start')

    theta = torch.linspace(0, torch.pi/2, 11)[theta_idx]
    phi = torch.tensor((5/4)*torch.pi)

    u_cross = torch.cat([cross(0, theta, phi, u)[None,...], cross(1, theta, phi, u)[None,...]], dim=0)
    u_side = torch.cat([side(0, theta, phi, u)[None,...], side(1, theta, phi, u)[None,...]], dim=0)
    u_corner = torch.cat([corner(0, theta, phi, u)[None,...], corner(1, theta, phi, u)[None,...]], dim=0)
    v_cross = torch.cat([cross(0, theta, phi, v)[None,...], cross(1, theta, phi, v)[None,...]], dim=0)
    v_side = torch.cat([side(0, theta, phi, v)[None,...], side(1, theta, phi, v)[None,...]], dim=0)
    v_corner = torch.cat([corner(0, theta, phi, v)[None,...], corner(1, theta, phi, v)[None,...]], dim=0)
    # SENSITIVITY TESTING: Test robustness by flipping measurement outcomes
    # This loop is designed to evaluate how sensitive the quantum state reconstruction
    # is to errors in specific measurement outcomes
    for site in [35]:  # List of sites to test (can be expanded to [35, 34, 33, ...] for multiple sites)
        prep = prep0.clone()  # Create a copy of original preparation sequence
        
        # MEASUREMENT OUTCOME FLIPPING LOGIC:
        # This section flips binary measurement outcomes (0 ↔ 1) at the specified site
        # to simulate measurement errors or test sensitivity
        if site != 35:  # note: Current logic only flips if site ≠ 35
            # Flip binary outcomes: 0 → 1, 1 → 0
            # Mathematical operation: new_value = -(old_value - 1)
            # For 0: 0-1=-1, then -1*-1=1 → 0 becomes 1
            # For 1: 1-1=0,  then 0*-1=0  → 1 becomes 0
            prep[:,site] -= 1  # Subtract 1 from all measurements at this site
            prep[:,site] *= -1  # Multiply by -1 to complete the flip
            
        # note: To test multiple sites, modify the list above to:
        # for site in [35, 34, 33, 32, ...]:  # Test sensitivity at multiple locations
        # This will create separate runs testing the effect of flipping each site individually
        # first row
        src, src_idx = contract(v_corner[prep[:,0]], [0,1,2],
                u_side[prep[:,1]], [0,2,3,4])
        src, src_idx = contract(src, src_idx,
                                v_side[prep[:,2]], [0,4,5,6])
        src, src_idx = contract(src, src_idx,
                                u_side[prep[:,3]], [0,6,7,8])
        src, src_idx = contract(src, src_idx,
                                v_side[prep[:,4]], [0,8,9,10])
        src, src_idx = contract(src, src_idx,
                                u_corner[prep[:,5]], [0,10,11])
        #print(src.shape, src_idx)
        # second row
        src, src_idx = contract(src, src_idx,
                                u_side[prep[:,6]], [0,1,13,14])
        src, src_idx = contract(src, src_idx,
                                v_cross[prep[:,7]], [0,3,14,15,16])
        src, src_idx = contract(src, src_idx,
                                u_cross[prep[:,8]], [0,5,16,17,18])
        src, src_idx = contract(src, src_idx,
                                v_cross[prep[:,9]], [0,7,18,19,20])
        src, src_idx = contract(src, src_idx,
                                u_cross[prep[:,10]], [0,9,20,21,22])
        src, src_idx = contract(src, src_idx,
                                v_side[prep[:,11]], [0,11,22,23])
        #print(src.shape, src_idx)
        # third row
        src, src_idx = contract(src, src_idx,
                                v_side[prep[:,12]], [0,13,24,25])
        src, src_idx = contract(src, src_idx,
                                u_cross[prep[:,13]], [0,15,25,26,27])
        src, src_idx = contract(src, src_idx,
                                v_cross[prep[:,14]], [0,17,27,28,29])
        src, src_idx = contract(src, src_idx,
                                u_cross[prep[:,15]], [0,19,29,30,31])
        src, src_idx = contract(src, src_idx,
                                v_cross[prep[:,16]], [0,21,31,32,33])
        src, src_idx = contract(src, src_idx,
                                u_side[prep[:,17]], [0,23,33,34])
        #print(src.shape, src_idx)
        # fourth row # labeling [0,24,26,28,30,32,34]->[0,1,2,3,4,5,6]
        src, src_idx = contract(src, [0,1,2,3,4,5,6],
                                u_side[prep[:,18]], [0,1,7,8])
        src, src_idx = contract(src, src_idx,
                                v_cross[prep[:,19]], [0,2,7,9,10])
        src, src_idx = contract(src, src_idx,
                                u_cross[prep[:,20]], [0,3,10,11,12])
        src, src_idx = contract(src, src_idx,
                                v_cross[prep[:,21]], [0,4,12,13,14])
        src, src_idx = contract(src, src_idx,
                                u_cross[prep[:,22]], [0,5,14,15,16])
        src, src_idx = contract(src, src_idx,
                                v_side[prep[:,23]], [0,6,16,17])
        #print(src.shape, src_idx)
        # fifth row
        src, src_idx = contract(src, src_idx,
                                v_side[prep[:,24]], [0,8,18,19])
        src, src_idx = contract(src, src_idx,
                                u_cross[prep[:,25]], [0,9,18,20,21])
        src, src_idx = contract(src, src_idx,
                                v_cross[prep[:,26]], [0,11,21,22,23])
        src, src_idx = contract(src, src_idx,
                                u_cross[prep[:,27]], [0,13,23,24,25])
        src, src_idx = contract(src, src_idx,
                                v_cross[prep[:,28]], [0,15,25,26,27])
        src, src_idx = contract(src, src_idx,
                                u_side[prep[:,29]], [0,17,27,28])
        #print(src.shape, src_idx)
        # sixth row
        src, src_idx = contract(src, src_idx,
                                u_corner[prep[:,30]], [0,19,29])
        src, src_idx = contract(src, src_idx,
                                v_side[prep[:,31]], [0,20,29,30])
        src, src_idx = contract(src, src_idx,
                                u_side[prep[:,32]], [0,22,30,31])
        src, src_idx = contract(src, src_idx,
                                v_side[prep[:,33]], [0,24,31,32])
        src, src_idx = contract(src, src_idx,
                                u_side[prep[:,34]], [0,26,32,33])
        src, src_idx = contract(src, src_idx,
                                corner(None, theta, phi, v), [41,28,33])
        #print(src.shape, src_idx)
        rho = torch.vmap(torch.outer)(src.view(-1, 2), src.view(-1, 2).conj())
        coef = torch.vmap(torch.trace)(rho).view(-1,1,1)
        idx = (coef.real != 0).view(-1)
        rho = rho[idx]
        coef = coef[idx]
        rhoS = rhoS[idx]
        rho /= coef

        for r_idx, rate in enumerate(rates):
            depol_rho = eps(rho, rate)
            score = bSqc(rhoS, depol_rho).mean().item()
            out[theta_idx, site, r_idx] = score
            print(f"theta_idx={theta_idx} site={site} rate={rate} score={score:.4f}")

theta_idx=0 start
theta_idx=0 site=35 rate=0.0 score=1.1812
theta_idx=0 site=35 rate=0.009999999776482582 score=0.1708
theta_idx=0 site=35 rate=0.019999999552965164 score=0.1540
theta_idx=0 site=35 rate=0.029999999329447746 score=0.1462
theta_idx=0 site=35 rate=0.03999999910593033 score=0.1421
theta_idx=0 site=35 rate=0.04999999701976776 score=0.1401
theta_idx=0 site=35 rate=0.05999999865889549 score=0.1394
theta_idx=0 site=35 rate=0.07000000029802322 score=0.1395
theta_idx=0 site=35 rate=0.07999999821186066 score=0.1404
theta_idx=0 site=35 rate=0.08999999612569809 score=0.1418
theta_idx=0 site=35 rate=0.09999999403953552 score=0.1435
theta_idx=0 site=35 rate=0.10999999940395355 score=0.1457
theta_idx=0 site=35 rate=0.11999999731779099 score=0.1481
theta_idx=0 site=35 rate=0.12999999523162842 score=0.1507
theta_idx=0 site=35 rate=0.14000000059604645 score=0.1536
theta_idx=0 site=35 rate=0.14999999105930328 score=0.1567
theta_idx=0 site=35 rate=0.1599999964237213 score=0.1599
theta_idx=

In [22]:
torch.save(out[:,-1].T, 'data/Feb_1_cz_calibration_center/decode_6x6_single.pt')

In [21]:
out[:,-1].T.shape

torch.Size([31, 11])