# Post-Processing Experimental Two-Probe Data

This notebook post-processes raw experimental measurement data from two-probe quantum experiments into a format suitable for machine learning training. It performs post-selection based on ancilla-physical qubit pairs for error mitigation, extracts probe qubit measurements, constructs quantum shadow states from different measurement bases, and aggregates data across multiple experimental runs into consolidated training datasets.

### Qubit Layout: 6x6 Physical Lattice with Ancilla Perimeter

```
        ⊕0  ⊕1   ⊕2    ⊕3   ⊕4   ⊕5           ← ANCILLA (top)
        │    │    │    │    │    │ 
   ⊕6 ──●7──●8 ──●9 ──●10──●11──●12──⊕13      
        │    │    │    │    │    │    
  ⊕14──●15──●16──●17──●18──●19──●20──⊕21
        │    │    │    │    │    │    
  ⊕22──●23──●24──●25──●26──●27──●28──⊕29      ← 6×6 PHYSICAL
        │    │    │    │    │    │            
  ⊕30──●31──●32──●33──●34──●35──●36──⊕37       
        │    │    │    │    │    │    
  ⊕38──●39──●40──●41──●42──●43──●44──⊕45
        │    │    │    │    │    │    
  ⊕46──●47──●48──●49──●50──●51──●52──⊕53
        │    │    │    │    │    │
       ⊕54  ⊕55  ⊕56  ⊕57  ⊕58  ⊕59           ← ANCILLA (bottom)
```

- **Physical qubits** (prep_idx): Interior 6×6 grid used for computation
- **Ancilla qubits**: Perimeter qubits used for error correction post-selection
- **Post-selection**: A shot is kept only if every perimeter ancilla reports the same outcome as its adjacent physical qubit (e.g., ancilla 0 ↔ physical 7)

In [1]:
import torch # type: ignore
import os

# Shadow states and density matrices are built in double-precision complex.
dtype = torch.complex128
# Everything runs on the CPU; to use a GPU when present, swap in
# torch.device("cuda" if torch.cuda.is_available() else "cpu").
device = torch.device("cpu")

# Single-qubit operators stacked along dim 0 in the order I, X, Y, Z.
pauli = torch.tensor(
    [
        [[1, 0], [0, 1]],
        [[0, 1], [1, 0]],
        [[0, -1j], [1j, 0]],
        [[1, 0], [0, -1]],
    ],
    dtype=dtype,
    device=device,
)
# Eigenvectors of X, Y and Z (the identity at index 0 is dropped). `.mT` turns
# the eigenvector columns into rows, so basis[b][o] is one eigenvector of
# pauli[b + 1].
# NOTE(review): which eigenvalue row `o` belongs to follows the output order of
# torch.linalg.eig -- confirm it matches the 0/1 outcome convention of the data.
basis = torch.linalg.eig(pauli).eigenvectors[1:].mT  # (3, 2, 2)

In [2]:
# process two probe qubits data
def torch_data(filename, d):
    """Load one run's raw shots, post-select them and build probe shadow data.

    For each of the nine settings ``(i, j)`` with ``i, j in range(3)`` the file
    ``f"{filename}_({i},{j}).pt"`` is loaded, shots failing the
    ancilla/physical agreement check are discarded, and the surviving shots
    are split into the two probe columns and the remaining physical columns.
    ``i`` and ``j`` select the entry of the module-level ``basis`` used for the
    first and second probe respectively.

    Args:
        filename: Path prefix of the per-setting ``.pt`` files.
        d: Selects the pair of probe columns; must be 3, 4, 5 or 6.

    Returns:
        A tuple ``(prepseq, shadow_state, rhoS)`` concatenated over all nine
        settings, with N the number of shots that survive post-selection:
        ``prepseq`` is an int64 tensor of shape (N, 34) holding the non-probe
        physical columns, ``shadow_state`` has shape (N, 4) and holds the
        Kronecker product of the two probes' post-measurement states, and
        ``rhoS`` has shape (N, 4, 4) and holds the Kronecker product of the
        single-probe matrices ``3|psi><psi| - I``.

    Raises:
        ValueError: If ``d`` is not one of the supported values.
    """
    # Probe columns for each supported d. (For d == 5 the original code also
    # carried a commented-out alternative pair, [12, 44].)
    probe_idx_by_distance = {
        6: [7, 12],
        5: [7, 11],
        4: [8, 11],
        3: [8, 10],
    }
    if d not in probe_idx_by_distance:
        # Fail before any file is read instead of hitting an unbound
        # `probe_idx` halfway through the first iteration.
        raise ValueError(
            f'unsupported d={d!r}; expected one of {sorted(probe_idx_by_distance)}'
        )
    probe_idx = probe_idx_by_distance[d]

    # (ancilla, physical) column pairs that must agree for a shot to be kept
    # (2-qubit mitigation).
    mitigation_pairs = [
        # top edge
        (0, 7), (1, 8), (2, 9), (3, 10), (4, 11), (5, 12),
        # bottom edge
        (54, 47), (55, 48), (56, 49), (57, 50), (58, 51), (59, 52),
        # left edge
        (6, 7), (14, 15), (22, 23), (30, 31), (38, 39), (46, 47),
        # right edge
        (13, 12), (21, 20), (29, 28), (37, 36), (45, 44), (53, 52),
    ]

    # Columns of the interior 6x6 grid, with the two probe columns removed.
    lattice_idx = [7, 8, 9, 10, 11, 12,
                   15, 16, 17, 18, 19, 20,
                   23, 24, 25, 26, 27, 28,
                   31, 32, 33, 34, 35, 36,
                   39, 40, 41, 42, 43, 44,
                   47, 48, 49, 50, 51, 52]
    prep_idx = [p for p in lattice_idx if p not in probe_idx]

    data = {}
    for i in range(3):
        for j in range(3):
            m = torch.load(filename+f'_({i},{j}).pt')
            # Start from "keep everything" and AND in one agreement test per
            # pair; done pair by pair to avoid a large temporary.
            msk = torch.ones(m.shape[0], device=device, dtype=torch.bool)
            for anc, phy in mitigation_pairs:
                msk = msk & (m[:, anc] == m[:, phy])
            m = m[msk]  # (batch, num_qubits)
            data[(i, j)] = (m[:, prep_idx], m[:, probe_idx])

    # Broadcasts against the (batch, 2, 2) outer products below.
    identity = torch.eye(2, device=device)

    prepseq, shadow_state, rhoS = [], [], []
    for (i, j), (prep, probe) in data.items():
        # construct post-measure state
        probseq = probe.to(dtype=torch.int64).to(device=device)  # (batch, 2)
        # Row `o` of basis[b] is the state recorded for outcome `o`, so a plain
        # row lookup picks each shot's post-measurement state.
        shadow_state0 = basis[i][probseq[:, 0]]  # (batch, 2)
        shadow_state1 = basis[j][probseq[:, 1]]  # (batch, 2)
        shadow_state01 = torch.vmap(torch.kron)(shadow_state0, shadow_state1)  # (batch, 4)
        # construct rhoS
        rhoS0 = 3*torch.vmap(torch.outer)(shadow_state0, shadow_state0.conj()) - identity
        rhoS1 = 3*torch.vmap(torch.outer)(shadow_state1, shadow_state1.conj()) - identity
        rhoS01 = torch.vmap(torch.kron)(rhoS0, rhoS1)  # (batch, 4, 4)
        # collect result
        prepseq.append(prep.to(dtype=torch.int64).to(device=device))
        shadow_state.append(shadow_state01)
        rhoS.append(rhoS01)

    prepseq = torch.cat(prepseq, 0)
    shadow_state = torch.cat(shadow_state, 0)
    rhoS = torch.cat(rhoS, 0)
    return prepseq, shadow_state, rhoS

def shuffle(prepseq, shadow_state, rhoS):
    """Reorder the three tensors along dim 0 with one shared random permutation.

    A single permutation of length ``prepseq.shape[0]`` is drawn from torch's
    global RNG and applied to all three inputs, so row k of each output still
    refers to the same shot.

    Returns:
        The permuted ``(prepseq, shadow_state, rhoS)``.
    """
    order = torch.randperm(prepseq.shape[0])
    return prepseq[order], shadow_state[order], rhoS[order]

### Data Object Descriptions

**Key saved objects and their meanings:**

- `all_prepseq` - Shape: (N, 34)  
  Measurement outcomes of the 34 non-probe physical qubits (the 6×6 interior minus the two probe qubits) for shots that pass post-selection. Used as input features for ML models to predict probe states.

- `all_shadow_state` - Shape: (N, 4)  
  Shadow states of the two probe qubits: |ψ₁⟩⊗|ψ₂⟩ in the 4D Hilbert space. Constructed from probe measurements in X/Y/Z bases.

- `all_rhoS` - Shape: (N, 4, 4)  
  Tensor product of single-qubit shadow density matrices: ρS = ρS₁⊗ρS₂ where ρSᵢ = 3|ψᵢ⟩⟨ψᵢ| - I. These are the target density matrices for ML training.

---

### Data Storage Summary

**File Organization:**
```
data/theta{θ_idx}/
├── all_prepseq_theta={θ_idx}.pt     # Input: non-probe physical-qubit outcomes
├── all_shadow_state_theta={θ_idx}.pt # Target: probe shadow states
└── all_rhoS_theta={θ_idx}.pt        # Target: probe shadow density matrices
```

**Usage:** ML models learn to predict probe quantum states (shadow_state/rhoS) from the measurement outcomes of the remaining physical qubits (prepseq) for quantum error correction decoding.


In [3]:
# Post-process every run ("loop") of the selected theta index and probe setting,
# shuffle each run with a run-specific seed, then concatenate and save.
for theta_idx in [4]:
    for d in [5]:
        prepseq_runs, shadow_state_runs, rhoS_runs = [], [], []
        for loop in range(17):
            filename = f'data/theta{theta_idx}/loop{loop}/theta={theta_idx}'
            # Seeding with the run number makes the per-run shuffle reproducible.
            torch.manual_seed(loop)
            prepseq, shadow_state, rhoS = shuffle(*torch_data(filename, d))
            prepseq_runs.append(prepseq)
            shadow_state_runs.append(shadow_state)
            rhoS_runs.append(rhoS)
            # The kept fraction is reported relative to a fixed 9,000,000 shots.
            print(f'distance={d}, loop={loop}, theta_idx={theta_idx}, portion to keep={((prepseq.shape[0]/9000000)):.4f}')

        all_prepseq = torch.cat(prepseq_runs, 0)
        all_shadow_state = torch.cat(shadow_state_runs, 0)
        all_rhoS = torch.cat(rhoS_runs, 0)

        torch.save(all_prepseq, f'data/theta{theta_idx}/all_prepseq_theta={theta_idx}.pt')
        torch.save(all_shadow_state, f'data/theta{theta_idx}/all_shadow_state_theta={theta_idx}.pt')
        torch.save(all_rhoS, f'data/theta{theta_idx}/all_rhoS_theta={theta_idx}.pt')

        for combined in (all_prepseq, all_shadow_state, all_rhoS):
            print(combined.shape, theta_idx)

distance=5, loop=0, theta_idx=4, portion to keep=0.5838
distance=5, loop=1, theta_idx=4, portion to keep=0.5862
distance=5, loop=2, theta_idx=4, portion to keep=0.5859
distance=5, loop=3, theta_idx=4, portion to keep=0.5833
distance=5, loop=4, theta_idx=4, portion to keep=0.5830
distance=5, loop=5, theta_idx=4, portion to keep=0.5831
distance=5, loop=6, theta_idx=4, portion to keep=0.5810
distance=5, loop=7, theta_idx=4, portion to keep=0.5833
distance=5, loop=8, theta_idx=4, portion to keep=0.5801
distance=5, loop=9, theta_idx=4, portion to keep=0.5811
distance=5, loop=10, theta_idx=4, portion to keep=0.5862
distance=5, loop=11, theta_idx=4, portion to keep=0.5358
distance=5, loop=12, theta_idx=4, portion to keep=0.5414
distance=5, loop=13, theta_idx=4, portion to keep=0.5338
distance=5, loop=14, theta_idx=4, portion to keep=0.5470
distance=5, loop=15, theta_idx=4, portion to keep=0.5313
distance=5, loop=16, theta_idx=4, portion to keep=0.5250
torch.Size([86679738, 34]) 4
torch.Size([

In [None]:
# Scan for duplicates and save results (run once)
import torch
import numpy as np
from utils import create_train_test_split

# Settings for the duplicate scan.
train_size = 78_000_000  # size of the train split requested from create_train_test_split
test_size = 1_000_000    # size of the test split requested from create_train_test_split
batch = 1000             # batch argument forwarded to create_train_test_split
d = 5
seed = 81                # torch RNG seed, re-applied before every theta is processed
ignore_sites = [28, 29, 30, 31, 32]  # prepseq column indices dropped before the scan

# Efficient duplicate analysis using tensor operations
def find_duplicates(data):
    """Return the multiplicity of every row of ``data`` that occurs more than once.

    Args:
        data: 2-D tensor of shape (num_rows, num_cols). For inputs of up to 63
            columns the entries are assumed to be 0 or 1 (other values can
            make distinct rows share a hash).

    Returns:
        1-D CPU tensor with one entry (> 1) per distinct row that appears at
        least twice: the number of times that row appears.
    """
    n_cols = data.shape[1]
    if n_cols <= 63:
        # Fast path: read each row as a base-2 number. With at most 63 bits the
        # value fits in a signed 64-bit integer, so distinct 0/1 rows always
        # get distinct hashes. The cast keeps the arithmetic in int64 even if
        # the input is stored as float or bool.
        powers = 2 ** torch.arange(n_cols - 1, -1, -1, device=data.device)
        hashes = (data.to(torch.int64) * powers).sum(dim=1)
        _, counts = torch.unique(hashes, return_counts=True)
    else:
        # Wider rows would overflow the int64 hash and collide silently, so
        # compare whole rows instead.
        _, counts = torch.unique(data, dim=0, return_counts=True)

    # Keep only rows seen more than once.
    return counts[counts > 1].cpu()

# Store results for all theta values
# Per-theta duplicate statistics, keyed by theta index.
duplicate_results = {}

for theta_idx in range(11):
    # Re-seed so every theta is processed from the same RNG state.
    torch.manual_seed(seed)
    print(f'Processing theta {theta_idx}...')

    # Consolidated tensors for this theta.
    prepseq_all = torch.load(f'data/theta{theta_idx}/all_prepseq_theta={theta_idx}.pt',weights_only=True)
    shadow_all = torch.load(f'data/theta{theta_idx}/all_shadow_state_theta={theta_idx}.pt',weights_only=True)
    rhoS_all = torch.load(f'data/theta{theta_idx}/all_rhoS_theta={theta_idx}.pt',weights_only=True)

    # Split into train and test portions (the helper lives in utils).
    train_data, test_data = create_train_test_split(
        prepseq_all, shadow_all, rhoS_all,
        train_size, test_size, batch
    )

    # Flatten to one row per shot; still 34 columns wide at this point.
    prepseq_train = train_data['prepseq'].view(-1, 34)
    prepseq_test = test_data['prepseq'].view(-1, 34)

    # Drop the ignored columns, leaving 34 - len(ignore_sites) per row.
    if ignore_sites:
        keep_cols = ~torch.isin(torch.arange(34), torch.tensor(ignore_sites))
        prepseq_train = prepseq_train[:, keep_cols]
        prepseq_test = prepseq_test[:, keep_cols]

    # Multiplicities of the rows that occur more than once in each split.
    train_duplicate_counts = find_duplicates(prepseq_train)
    test_duplicate_counts = find_duplicates(prepseq_test)

    duplicate_results[theta_idx] = dict(
        train_duplicate_counts=train_duplicate_counts,
        test_duplicate_counts=test_duplicate_counts,
        train_shape=prepseq_train.shape,
        test_shape=prepseq_test.shape,
    )

    print(f'  Train: {prepseq_train.shape}, duplicates: {len(train_duplicate_counts)}')
    print(f'  Test: {prepseq_test.shape}, duplicates: {len(test_duplicate_counts)}')

# Only the small summary is written out, not the data itself.
torch.save(duplicate_results, f'duplicate_analysis_results_ignore_sites={ignore_sites}.pt')
print(f'Saved duplicate analysis results to duplicate_analysis_results_ignore_sites={ignore_sites}.pt')


In [None]:
# Load and plot duplicate analysis results (run anytime for plotting)
import torch
import matplotlib.pyplot as plt
import numpy as np

# Specify sites to ignore (must match what was used in scanning cell)
ignore_sites = []
#ignore_sites = [28,29,30,31,32]

# Plot settings
max_frequency_display = 5  # Upper limit for x-axis display

# Imported once, up front, so that both the train and the test panel can use
# it regardless of which of them has duplicates.
from matplotlib.ticker import MaxNLocator

# Create title suffix based on ignore sites
if ignore_sites:
    title_suffix = f' (ignore sites: {ignore_sites})'
else:
    title_suffix = ' (all sites)'

# Load saved results
duplicate_results = torch.load(f'duplicate_analysis_results_ignore_sites={ignore_sites}.pt', weights_only=True)

thetas = torch.linspace(0, np.pi/2, 11)


def _plot_frequency_distribution(ax, duplicate_counts, total_shots, split_name, theta):
    """Draw one panel: the fraction of distinct strings seen exactly k times.

    Args:
        ax: Matplotlib axes to draw on.
        duplicate_counts: 1-D tensor with the multiplicity (> 1) of every
            string that occurs more than once.
        total_shots: Total number of strings (rows) in the split.
        split_name: Label used in the title ('Train' or 'Test').
        theta: Angle shown in the title.
    """
    title = f'{split_name} Frequency Distribution (θ={theta:.2f}){title_suffix}'

    if len(duplicate_counts) == 0:
        ax.text(0.5, 0.5, 'All singles', ha='center', va='center', transform=ax.transAxes)
        ax.set_title(title)
        return

    # Strings seen once = all shots minus the shots taken up by repeated strings.
    total_duplicate_samples = duplicate_counts.sum().item()
    num_single_strings = total_shots - total_duplicate_samples
    total_unique_strings = num_single_strings + len(duplicate_counts)

    # freq_counts[k] = number of distinct strings that occur exactly k times.
    freq_counts = {1: num_single_strings}
    frequencies, n_strings = torch.unique(duplicate_counts, return_counts=True)
    for f, n in zip(frequencies.tolist(), n_strings.tolist()):
        freq_counts[f] = freq_counts.get(f, 0) + n

    # Ratios are always relative to all distinct strings; only frequencies up to
    # the display limit are drawn, so no bar or label falls outside the x-range.
    display_frequencies = [f for f in sorted(freq_counts) if f <= max_frequency_display]
    display_values = [freq_counts[f] / total_unique_strings for f in display_frequencies]

    bars = ax.bar(display_frequencies, display_values, alpha=0.7, width=0.8)

    # Add value labels on top of bars
    for bar, value in zip(bars, display_values):
        ax.text(bar.get_x() + bar.get_width()/2., bar.get_height(),
                f'{value:.8f}', ha='center', va='bottom', fontsize=9)

    ax.set_xlabel('Frequency')
    ax.set_ylabel('Fraction of Unique Strings')
    ax.set_title(title)
    ax.grid(True, alpha=0.3, axis='y')

    # Adaptive y-axis scaling
    ax.set_ylim(0, max(display_values) * 1.08)

    ax.xaxis.set_major_locator(MaxNLocator(integer=True))
    ax.set_xlim(0.5, max_frequency_display + 0.5)
    ax.set_xticks(range(1, max_frequency_display + 1))


# Plot histogram of frequency distribution (including singles)
# Create one large figure with all thetas
valid_thetas = [i for i in range(11) if i in duplicate_results]
n_thetas = len(valid_thetas)
if n_thetas == 0:
    raise ValueError('duplicate_results contains no theta entries to plot')

# Create subplot grid: n_thetas rows, 2 columns (train, test).
# squeeze=False keeps `axes` two-dimensional even for a single row.
fig, axes = plt.subplots(n_thetas, 2, figsize=(16, 4*n_thetas), squeeze=False)

for plot_idx, theta_idx in enumerate(valid_thetas):
    result = duplicate_results[theta_idx]
    _plot_frequency_distribution(
        axes[plot_idx, 0],
        result['train_duplicate_counts'],
        result['train_shape'][0],
        'Train',
        thetas[theta_idx],
    )
    _plot_frequency_distribution(
        axes[plot_idx, 1],
        result['test_duplicate_counts'],
        result['test_shape'][0],
        'Test',
        thetas[theta_idx],
    )

# Show the combined figure with all thetas
plt.tight_layout()
plt.show()
