In [1]:
import torch # type: ignore
import os

dtype = torch.complex128
device = torch.device("cpu")  # torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pauli matrices stacked as (I, X, Y, Z), shape (4, 2, 2).
pauli = torch.tensor([[[1,0],[0,1]],[[0,1],[1,0]],[[0,-1j],[1j,0]],[[1,0],[0,-1]]], device=device, dtype=dtype)

# Measurement eigenbases for X, Y, Z, shape (3, 2, 2): basis[k] holds the eigenvectors of
# pauli[k + 1] as rows, ordered so row 0 is the +1 eigenvector and row 1 the -1 eigenvector
# (for Z this is |0>, |1>). Downstream code indexes basis[k][outcome_bit], so this assumes
# measurement bit 0 <-> eigenvalue +1 -- NOTE(review): confirm against the hardware convention.
# torch.linalg.eig does not guarantee any eigenvalue order, hence the explicit sort of the
# eigenvector columns by descending eigenvalue before dropping the identity and transposing.
_eigvals, _eigvecs = torch.linalg.eig(pauli)
_order = torch.argsort(_eigvals.real, dim=-1, descending=True)  # (4, 2)
basis = _eigvecs.gather(-1, _order.unsqueeze(-2).expand_as(_eigvecs))[1:].mT  # (3, 2, 2)

In [2]:
# Process the two-probe-qubit data.

# (anc, phy) qubit-column pairs used for 2-qubit-mitigation post-selection: a shot is kept
# only if, for every pair, the ancilla outcome equals the physical-qubit outcome.
MITIGATION_PAIRS = (
    (0, 7), (1, 8), (2, 9), (3, 10), (4, 11), (5, 12),
    (54, 47), (55, 48), (56, 49), (57, 50), (58, 51), (59, 52),
    (6, 7), (14, 15), (22, 23), (30, 31), (38, 39), (46, 47),
    (13, 12), (21, 20), (29, 28), (37, 36), (45, 44), (53, 52),
)

# Candidate preparation-qubit columns; the two probe columns for the chosen distance are
# removed from this list (remaining order preserved).
PREP_IDX = (
    7, 8, 9, 10, 11, 12,
    15, 16, 17, 18, 19, 20,
    23, 24, 25, 26, 27, 28,
    31, 32, 33, 34, 35, 36,
    39, 40, 41, 42, 43, 44,
    47, 48, 49, 50, 51, 52,
)

# Probe-qubit column pair for each supported distance d.
PROBE_IDX_BY_DISTANCE = {6: [7, 12], 5: [7, 11], 4: [8, 11], 3: [8, 10]}


def _load_postselected(filename, i, j, d):
    """Load shot file (i, j), post-select on MITIGATION_PAIRS, and split it into (prep, probe).

    Returns ``(prep, probe)``: the surviving shots' preparation-qubit columns and their two
    probe-qubit columns (shape (shots, 2)).
    """
    # map_location keeps the loaded table on `device`, matching the mask built below.
    shots = torch.load(filename + f'_({i},{j}).pt', map_location=device)  # (batch, num_qubits)
    keep = torch.ones(shots.shape[0], device=device, dtype=torch.bool)
    for anc, phy in MITIGATION_PAIRS:
        keep &= shots[:, anc] == shots[:, phy]
    shots = shots[keep]
    probe_idx = PROBE_IDX_BY_DISTANCE[d]
    prep_idx = [p for p in PREP_IDX if p not in probe_idx]
    return shots[:, prep_idx], shots[:, probe_idx]


def _shadow_from_outcomes(axis, outcomes):
    """Single-qubit shadow for measurement setting ``axis`` and integer outcome bits ``outcomes``.

    Returns ``(states, rhos)``: ``states[r]`` is row ``outcomes[r]`` of ``basis[axis]``
    (shape (shots, 2)) and ``rhos[r] = 3 |states[r]><states[r]| - I`` (shape (shots, 2, 2)).
    """
    states = basis[axis][outcomes]
    projectors = states[:, :, None] * states.conj()[:, None, :]  # batched outer product
    identity = torch.eye(2, device=states.device, dtype=states.dtype)
    return states, 3 * projectors - identity


def torch_data(filename, d):
    """Load the nine basis-setting shot files of one run and build two-qubit shadow data.

    For each setting (i, j) with i, j in {0, 1, 2} (in order (0,0), (0,1), ..., (2,2)) the file
    ``filename + f'_({i},{j}).pt'`` is loaded and post-selected on MITIGATION_PAIRS agreement.
    The two probe-qubit outcomes, measured in ``basis[i]`` and ``basis[j]``, are then turned
    into a product state and a product shadow estimator.

    Args:
        filename: path prefix of the per-setting shot files.
        d: distance; selects the probe-qubit pair from PROBE_IDX_BY_DISTANCE.

    Returns:
        prepseq: int64 (N, len(PREP_IDX) - 2) preparation-qubit outcomes of the kept shots.
        shadow_state: (N, 4) product states kron(psi_i, psi_j), dtype of ``basis``.
        rhoS: (N, 4, 4) kron(3|psi_i><psi_i| - I, 3|psi_j><psi_j| - I), dtype of ``basis``.
        Rows from all nine settings are concatenated in setting order.

    Raises:
        ValueError: if ``d`` has no entry in PROBE_IDX_BY_DISTANCE.
    """
    if d not in PROBE_IDX_BY_DISTANCE:
        raise ValueError(f'unsupported distance d={d!r}; expected one of {sorted(PROBE_IDX_BY_DISTANCE)}')
    prepseq, shadow_state, rhoS = [], [], []
    for i in range(3):
        for j in range(3):
            prep, probe = _load_postselected(filename, i, j, d)
            outcomes = probe.to(device=device, dtype=torch.int64)  # (shots, 2) probe outcomes
            state0, rho0 = _shadow_from_outcomes(i, outcomes[:, 0])
            state1, rho1 = _shadow_from_outcomes(j, outcomes[:, 1])
            n = outcomes.shape[0]
            # Per-shot Kronecker products via broadcasting (explicit sizes so n == 0 works).
            shadow_state.append((state0[:, :, None] * state1[:, None, :]).reshape(n, 4))
            rhoS.append((rho0[:, :, None, :, None] * rho1[:, None, :, None, :]).reshape(n, 4, 4))
            prepseq.append(prep.to(device=device, dtype=torch.int64))
    return torch.cat(prepseq, 0), torch.cat(shadow_state, 0), torch.cat(rhoS, 0)

def shuffle(prepseq, shadow_state, rhoS):
    """Reorder three row-aligned tensors with one shared random permutation of dim 0.

    A single ``torch.randperm`` over ``prepseq.shape[0]`` rows is drawn from the global
    torch RNG and applied to all three tensors, so their rows stay aligned.
    Returns the permuted ``(prepseq, shadow_state, rhoS)`` as a tuple.
    """
    perm = torch.randperm(prepseq.shape[0])
    return tuple(tensor[perm] for tensor in (prepseq, shadow_state, rhoS))

In [3]:
# Build the shuffled, concatenated dataset for each (theta, distance) combination and
# save it under data/theta{theta_idx}/.
# NOTE(review): the saved file names encode only theta_idx, not d, so listing more than
# one distance below would make later distances overwrite earlier ones.
for theta_idx in [4]:
    for d in [5]:
        prep_parts, state_parts, rho_parts = [], [], []
        for loop in range(17):
            filename = f'data/theta{theta_idx}/loop{loop}/theta={theta_idx}'
            # Seed per loop so each loop's shuffle is reproducible on its own.
            torch.manual_seed(loop)
            prepseq, shadow_state, rhoS = shuffle(*torch_data(filename, d))
            prep_parts.append(prepseq)
            state_parts.append(shadow_state)
            rho_parts.append(rhoS)
            # NOTE(review): 9000000 is presumably the raw shot count per loop -- confirm.
            print(f'distance={d}, loop={loop}, theta_idx={theta_idx}, portion to keep={((prepseq.shape[0]/9000000)):.4f}')
        # Concatenate one quantity at a time and drop its per-loop list immediately,
        # so peak memory never holds two full copies of an already-finished quantity.
        all_prepseq = torch.cat(prep_parts, 0)
        del prep_parts
        all_shadow_state = torch.cat(state_parts, 0)
        del state_parts
        all_rhoS = torch.cat(rho_parts, 0)
        del rho_parts
        torch.save(all_prepseq, f'data/theta{theta_idx}/all_prepseq_theta={theta_idx}.pt')
        torch.save(all_shadow_state, f'data/theta{theta_idx}/all_shadow_state_theta={theta_idx}.pt')
        torch.save(all_rhoS, f'data/theta{theta_idx}/all_rhoS_theta={theta_idx}.pt')
        for combined in (all_prepseq, all_shadow_state, all_rhoS):
            print(combined.shape, theta_idx)

distance=5, loop=0, theta_idx=4, portion to keep=0.5838
distance=5, loop=1, theta_idx=4, portion to keep=0.5862
distance=5, loop=2, theta_idx=4, portion to keep=0.5859
distance=5, loop=3, theta_idx=4, portion to keep=0.5833
distance=5, loop=4, theta_idx=4, portion to keep=0.5830
distance=5, loop=5, theta_idx=4, portion to keep=0.5831
distance=5, loop=6, theta_idx=4, portion to keep=0.5810
distance=5, loop=7, theta_idx=4, portion to keep=0.5833
distance=5, loop=8, theta_idx=4, portion to keep=0.5801
distance=5, loop=9, theta_idx=4, portion to keep=0.5811
distance=5, loop=10, theta_idx=4, portion to keep=0.5862
distance=5, loop=11, theta_idx=4, portion to keep=0.5358
distance=5, loop=12, theta_idx=4, portion to keep=0.5414
distance=5, loop=13, theta_idx=4, portion to keep=0.5338
distance=5, loop=14, theta_idx=4, portion to keep=0.5470
distance=5, loop=15, theta_idx=4, portion to keep=0.5313
distance=5, loop=16, theta_idx=4, portion to keep=0.5250
torch.Size([86679738, 34]) 4
torch.Size([