In [8]:
import torch
import itertools
import utils

In [2]:
# Site (column) indices into prepseq for each neighbourhood size, smallest first.
radius = [
    [23, 28, 29, 33, 34],                                          # 5-qubit radius
    [17, 22, 23, 27, 28, 29, 32, 33, 34],                          # 9-qubit radius
    [14, 15, 16, 17, 20, 21, 22, 23, 26, 27, 28, 29, 32, 33, 34],  # 15-qubit radius
]

In [6]:
# Baseline reconstruction: enumerate every bitstring on the selected sites and, for each
# one, average the post-selection rhoS over the rows that carry that bitstring.
# NOTE(review): this does two full-length comparisons per bitstring (2**num_sites of
# them), which gets very slow for the 15-site radius.
N_THETA = 11  # theta datasets data/theta0 ... data/theta10
test_size = 10**6
results = []
for sites in radius:
    for theta_idx in range(N_THETA):
        prepseq = torch.load(f'data/theta{theta_idx}/all_prepseq_theta={theta_idx}.pt')
        rhoS = torch.load(f'data/theta{theta_idx}/all_rhoS_theta={theta_idx}.pt')
        print(f"prepseq.shape={prepseq.shape}, theta_idx={theta_idx}")
        print(f"rhoS.shape={rhoS.shape}, theta_idx={theta_idx}")

        # Keep only the columns of the chosen sites.
        prepseq = prepseq[:, sites]

        # The first test_size rows are evaluated; the remaining rows build the estimates.
        prepseq_test = prepseq[:test_size]
        rhoS_test = rhoS[:test_size]
        prepseq_post_select = prepseq[test_size:]
        rhoS_post_select = rhoS[test_size:]

        # Generate all possible bitstrings for the selected sites
        num_sites = len(sites)
        all_bitstrings = list(itertools.product([0, 1], repeat=num_sites))
        print(f"Processing {len(all_bitstrings)} bitstrings for {num_sites} sites...")

        # Test rows whose bitstring is never matched below stay all-zero.
        rhoC = torch.zeros_like(rhoS_test)
        dim = rhoC.shape[-1]
        maximally_mixed = torch.eye(dim, dtype=rhoC.dtype, device=rhoC.device) / dim

        for bitstring in all_bitstrings:
            bitstring_tensor = torch.tensor(bitstring, dtype=torch.int64)
            post_idx = (prepseq_post_select == bitstring_tensor).all(dim=1)
            test_idx = (prepseq_test == bitstring_tensor).all(dim=1)
            num_test = int(test_idx.sum())
            if num_test == 0:
                continue  # nothing in the test split to fill for this bitstring

            if post_idx.any():
                group = rhoS_post_select[post_idx].mean(dim=0, keepdim=True)
                min_eig = torch.linalg.eigvals(group)[0].real.min()
                if min_eig < 0:
                    # e solves (1 - e) * min_eig + e / dim = 0, i.e. the smallest amount of
                    # mixing that lifts min_eig to zero -- assuming utils.depolarize(rho, e)
                    # returns (1 - e) * rho + e * I / dim (TODO confirm).
                    e = -min_eig / (1 / dim - min_eig)
                    group = utils.depolarize(group, e=e)
                rhoC[test_idx] = group.expand(num_test, -1, -1)
            else:
                # No post-selection data for this bitstring: fall back to I / dim.
                rhoC[test_idx] = maximally_mixed.unsqueeze(0).expand(num_test, -1, -1)

        results.append(utils.bSqc(rhoS_test, rhoC))

# One run per (radius, theta): derive the leading dims from the loops instead of
# hard-coding them, so the grouping stays correct when `radius` changes.
results = torch.cat(results, 0).view(len(radius), N_THETA, -1)

prepseq.shape=torch.Size([21574529, 35]), theta_idx=0
rhoS.shape=torch.Size([21574529, 2, 2]), theta_idx=0
Processing 32 bitstrings for 5 sites...
prepseq.shape=torch.Size([21528096, 35]), theta_idx=1
rhoS.shape=torch.Size([21528096, 2, 2]), theta_idx=1
Processing 32 bitstrings for 5 sites...
prepseq.shape=torch.Size([21464994, 35]), theta_idx=2
rhoS.shape=torch.Size([21464994, 2, 2]), theta_idx=2
Processing 32 bitstrings for 5 sites...
prepseq.shape=torch.Size([21449696, 35]), theta_idx=3
rhoS.shape=torch.Size([21449696, 2, 2]), theta_idx=3
Processing 32 bitstrings for 5 sites...
prepseq.shape=torch.Size([21490391, 35]), theta_idx=4
rhoS.shape=torch.Size([21490391, 2, 2]), theta_idx=4
Processing 32 bitstrings for 5 sites...
prepseq.shape=torch.Size([21462555, 35]), theta_idx=5
rhoS.shape=torch.Size([21462555, 2, 2]), theta_idx=5
Processing 32 bitstrings for 5 sites...
prepseq.shape=torch.Size([21479187, 35]), theta_idx=6
rhoS.shape=torch.Size([21479187, 2, 2]), theta_idx=6
Processing 

In [3]:
# Key-based reconstruction: pack each row's site bits into one integer key, group the
# post-selection rows by key with a single sort, then give every test row its key's estimate.
N_THETA = 11  # theta datasets data/theta0 ... data/theta10
test_size = 10**6


def binary_tensor_to_int(tensor):
    """Pack each row of a (rows, n_bits) tensor into an integer key, first column = MSB.

    NOTE(review): assumes entries are 0/1; other values could make distinct rows collide.
    """
    powers = 2 ** torch.arange(tensor.shape[1] - 1, -1, -1, device=tensor.device)
    return (tensor * powers).sum(dim=1)


def estimate_per_key(keys, post_keys, rhoS_post_select):
    """Return one reconstructed state per entry of `keys`, shape (len(keys), d, d).

    For each key, the post-selection rhoS rows with that key are averaged. If the mean
    has a negative eigenvalue, it is passed through utils.depolarize with the mixing
    weight that brings the smallest eigenvalue to zero. Keys with no post-selection rows
    get I / d.
    """
    dim = rhoS_post_select.shape[-1]
    maximally_mixed = torch.eye(dim, dtype=rhoS_post_select.dtype,
                                device=rhoS_post_select.device) / dim

    # One stable sort instead of one full-length `post_keys == key` scan per key.
    # Stability keeps each group's row indices ascending, i.e. the same rows in the
    # same order a boolean mask would select, so the means are unchanged.
    sorted_post_keys, order = torch.sort(post_keys, stable=True)
    group_keys, group_sizes = torch.unique_consecutive(sorted_post_keys, return_counts=True)
    rows_by_key = dict(zip(group_keys.tolist(), torch.split(order, group_sizes.tolist())))

    estimates = maximally_mixed.expand(len(keys), dim, dim).clone()
    for slot, key in enumerate(keys.tolist()):
        rows = rows_by_key.get(key)
        if rows is None:
            continue  # no post-selection data: keep I / d
        group = rhoS_post_select[rows].mean(dim=0, keepdim=True)  # (1, d, d)
        min_eig = torch.linalg.eigvals(group)[0].real.min()
        if min_eig < 0:
            e = -min_eig / (1 / dim - min_eig)
            group = utils.depolarize(group, e=e)
        estimates[slot] = group[0]
    return estimates


results = []
for sites in radius:
    for theta_idx in range(N_THETA):
        prepseq = torch.load(f'data/theta{theta_idx}/all_prepseq_theta={theta_idx}.pt')
        rhoS = torch.load(f'data/theta{theta_idx}/all_rhoS_theta={theta_idx}.pt')
        print(f"prepseq.shape={prepseq.shape}, theta_idx={theta_idx}")
        print(f"rhoS.shape={rhoS.shape}, theta_idx={theta_idx}")
        # Select relevant sites
        prepseq = prepseq[:, sites]

        # The first test_size rows are evaluated; the remaining rows build the estimates.
        prepseq_test = prepseq[:test_size]
        rhoS_test = rhoS[:test_size]
        prepseq_post_select = prepseq[test_size:]
        rhoS_post_select = rhoS[test_size:]

        test_keys = binary_tensor_to_int(prepseq_test)
        post_keys = binary_tensor_to_int(prepseq_post_select)

        # Estimate once per distinct test key, then gather back to one state per test row.
        unique_keys, test_inverse = torch.unique(test_keys, return_inverse=True)
        rhoC = estimate_per_key(unique_keys, post_keys, rhoS_post_select)[test_inverse]

        results.append(utils.bSqc(rhoS_test, rhoC))

# One run per (radius, theta); leading dims come from the loops, not hard-coded numbers.
results = torch.cat(results, 0).view(len(radius), N_THETA, -1)

prepseq.shape=torch.Size([21574529, 35]), theta_idx=0
rhoS.shape=torch.Size([21574529, 2, 2]), theta_idx=0
prepseq.shape=torch.Size([21528096, 35]), theta_idx=1
rhoS.shape=torch.Size([21528096, 2, 2]), theta_idx=1
prepseq.shape=torch.Size([21464994, 35]), theta_idx=2
rhoS.shape=torch.Size([21464994, 2, 2]), theta_idx=2
prepseq.shape=torch.Size([21449696, 35]), theta_idx=3
rhoS.shape=torch.Size([21449696, 2, 2]), theta_idx=3
prepseq.shape=torch.Size([21490391, 35]), theta_idx=4
rhoS.shape=torch.Size([21490391, 2, 2]), theta_idx=4
prepseq.shape=torch.Size([21462555, 35]), theta_idx=5
rhoS.shape=torch.Size([21462555, 2, 2]), theta_idx=5
prepseq.shape=torch.Size([21479187, 35]), theta_idx=6
rhoS.shape=torch.Size([21479187, 2, 2]), theta_idx=6
prepseq.shape=torch.Size([21427448, 35]), theta_idx=7
rhoS.shape=torch.Size([21427448, 2, 2]), theta_idx=7
prepseq.shape=torch.Size([21492355, 35]), theta_idx=8
rhoS.shape=torch.Size([21492355, 2, 2]), theta_idx=8
prepseq.shape=torch.Size([21550490, 3

RuntimeError: shape '[3, 11, -1]' is invalid for input of size 3000000

In [3]:
# Key-based reconstruction with progress logging: pack each row's site bits into one
# integer key, group the post-selection rows by key with a single sort, then give every
# test row its key's estimate.
N_THETA = 11  # theta datasets data/theta0 ... data/theta10
test_size = 10**6


def binary_tensor_to_int(tensor):
    """Pack each row of a (rows, n_bits) tensor into an integer key, first column = MSB.

    NOTE(review): assumes entries are 0/1; other values could make distinct rows collide.
    """
    powers = 2 ** torch.arange(tensor.shape[1] - 1, -1, -1, device=tensor.device)
    return (tensor * powers).sum(dim=1)


def estimate_per_key(keys, post_keys, rhoS_post_select):
    """Return one reconstructed state per entry of `keys`, shape (len(keys), d, d).

    For each key, the post-selection rhoS rows with that key are averaged. If the mean
    has a negative eigenvalue, it is passed through utils.depolarize with the mixing
    weight that brings the smallest eigenvalue to zero. Keys with no post-selection rows
    get I / d.
    """
    dim = rhoS_post_select.shape[-1]
    maximally_mixed = torch.eye(dim, dtype=rhoS_post_select.dtype,
                                device=rhoS_post_select.device) / dim

    # One stable sort instead of one full-length `post_keys == key` scan per key.
    # Stability keeps each group's row indices ascending, i.e. the same rows in the
    # same order a boolean mask would select, so the means are unchanged.
    sorted_post_keys, order = torch.sort(post_keys, stable=True)
    group_keys, group_sizes = torch.unique_consecutive(sorted_post_keys, return_counts=True)
    rows_by_key = dict(zip(group_keys.tolist(), torch.split(order, group_sizes.tolist())))

    estimates = maximally_mixed.expand(len(keys), dim, dim).clone()
    for slot, key in enumerate(keys.tolist()):
        rows = rows_by_key.get(key)
        if rows is None:
            continue  # no post-selection data: keep I / d
        group = rhoS_post_select[rows].mean(dim=0, keepdim=True)  # (1, d, d)
        min_eig = torch.linalg.eigvals(group)[0].real.min()
        if min_eig < 0:
            e = -min_eig / (1 / dim - min_eig)
            group = utils.depolarize(group, e=e)
        estimates[slot] = group[0]
    return estimates


results = []

for radius_idx, sites in enumerate(radius):
    print(f"\n=== Processing radius {radius_idx + 1} with {len(sites)} sites ===")

    for theta_idx in range(N_THETA):
        print(f"\n-- θ index = {theta_idx} --")

        # Load data
        prepseq = torch.load(f'data/theta{theta_idx}/all_prepseq_theta={theta_idx}.pt')
        rhoS = torch.load(f'data/theta{theta_idx}/all_rhoS_theta={theta_idx}.pt')
        print(f"Loaded prepseq: shape = {prepseq.shape}, rhoS: shape = {rhoS.shape}")

        # Select relevant sites; the first test_size rows are evaluated and the
        # remaining rows build the estimates.
        prepseq = prepseq[:, sites]
        prepseq_test = prepseq[:test_size]
        rhoS_test = rhoS[:test_size]
        prepseq_post_select = prepseq[test_size:]
        rhoS_post_select = rhoS[test_size:]

        num_sites = len(sites)
        print(f"Using {num_sites} selected qubit sites")

        test_keys = binary_tensor_to_int(prepseq_test)
        post_keys = binary_tensor_to_int(prepseq_post_select)

        unique_keys, test_inverse = torch.unique(test_keys, return_inverse=True)
        print(f"Found {len(unique_keys)} unique bitstring keys in test set")

        # Gather: every test row takes the estimate of its own key.
        rhoC = estimate_per_key(unique_keys, post_keys, rhoS_post_select)[test_inverse]
        total_assigned = test_inverse.numel()

        print(f"Total reconstructed states: {total_assigned} / {test_size}")
        acc = utils.bSqc(rhoS_test, rhoC)
        results.append(acc)

# Final reshape: one run per (radius, theta); the trailing axis holds whatever
# utils.bSqc returns per run (presumably per-sample values -- confirm).
results = torch.cat(results, 0).view(len(radius), N_THETA, -1)
print("\n=== All processing complete ===")
print(f"Final results shape: {results.shape}")


=== Processing radius 1 with 5 sites ===

-- θ index = 0 --
Loaded prepseq: shape = torch.Size([21574529, 35]), rhoS: shape = torch.Size([21574529, 2, 2])
Using 5 selected qubit sites
Found 32 unique bitstring keys in test set
Total reconstructed states: 1000000 / 1000000

-- θ index = 1 --
Loaded prepseq: shape = torch.Size([21528096, 35]), rhoS: shape = torch.Size([21528096, 2, 2])
Using 5 selected qubit sites
Found 32 unique bitstring keys in test set
Total reconstructed states: 1000000 / 1000000

-- θ index = 2 --
Loaded prepseq: shape = torch.Size([21464994, 35]), rhoS: shape = torch.Size([21464994, 2, 2])
Using 5 selected qubit sites
Found 32 unique bitstring keys in test set
Total reconstructed states: 1000000 / 1000000

-- θ index = 3 --
Loaded prepseq: shape = torch.Size([21449696, 35]), rhoS: shape = torch.Size([21449696, 2, 2])
Using 5 selected qubit sites
Found 32 unique bitstring keys in test set
Total reconstructed states: 1000000 / 1000000

-- θ index = 4 --
Loaded prep

In [6]:
# Collapse the trailing axis: one averaged value per (radius, theta) run.
results.mean(dim=-1)

tensor([[0.1285, 0.0791, 0.1338, 0.2508, 0.3678, 0.4915, 0.5931, 0.6581, 0.6855,
         0.6926, 0.6930],
        [0.1190, 0.0656, 0.1028, 0.1888, 0.2838, 0.4196, 0.5505, 0.6405, 0.6819,
         0.6923, 0.6931],
        [0.1885, 0.1094, 0.1494, 0.2306, 0.2759, 0.3954, 0.5365, 0.6391, 0.6875,
         0.6989, 0.7004]], dtype=torch.float64)

In [None]:
# NOTE(review): stray, never-executed cell (In [None]); it only displays the `torch`
# module object. Consider removing it.
torch