**Note: This notebook is not used in the current project.**

This notebook was written to reproduce training dynamics: it loads saved model checkpoints and re-evaluates test metrics for each one. It is not part of the active workflow for the two-probe quantum error correction experiments.


In [1]:
from TFM import LlamaPredictor
import torch
from utils import torch_data, shuffle, blogm, bSqc, Neg, Sa, eps, create_train_test_split, save_checkpoint, load_checkpoint, save_checkpoint_and_test
from math import prod

# Complex double precision; run on the GPU when one is visible, otherwise on the CPU.
dtype = torch.complex128
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

  from .autonotebook import tqdm as notebook_tqdm
2025-07-22 21:20:39.486657: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1753219239.496837    6906 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1753219239.500959    6906 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1753219239.506137    6906 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1753219239.506147    6906 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1753219239.506149    6906

In [None]:
# Code to reproduce training dynamics by loading checkpoints and running test loops
import os

def _product_of_marginals(rho):
    """
    Return the product of single-qubit marginals for a batch of two-qubit states.

    `rho` is viewed as (batch, i1, i2, j1, j2): row index (i1, i2) and column
    index (j1, j2), with the first qubit as the most significant index (the
    ordering torch.kron produces). The result is rho_1 ⊗ rho_2 with
    rho_1 = Tr_2[rho] and rho_2 = Tr_1[rho], as a (batch, 4, 4) tensor.
    """
    rho_reshaped = rho.view(-1, 2, 2, 2, 2)  # (batch, i1, i2, j1, j2)
    # Keep the 1st qubit: trace out the 2nd qubit (i2 == j2).
    rho_first = torch.einsum('bijkj->bik', rho_reshaped)
    # Keep the 2nd qubit: trace out the 1st qubit (i1 == j1).
    rho_second = torch.einsum('bijil->bjl', rho_reshaped)
    return torch.vmap(torch.kron)(rho_first, rho_second)


def run_test_evaluation(model, prepseq_test, shadow_state_test, rhoS_test, device):
    """
    Evaluate `model` on pre-batched test data and return the mean test metrics.

    Every metric is computed on the product of the model's single-qubit
    marginals, rho_1 ⊗ rho_2 (see _product_of_marginals), not on rho^C itself.
    This is intended to mirror the test step of utils.save_checkpoint_and_test
    without saving any checkpoint. NOTE(review): that correspondence is not
    verified here.

    Args:
        model: module called as model(prepseq_batch, mask_flag). Its output
            must be viewable as (batch, 2, 2, 2, 2).
        prepseq_test, shadow_state_test, rhoS_test: indexed along dim 0 by
            batch number. These are presumably (num_batches, batch_size, ...)
            tensors from create_train_test_split — TODO confirm.
        device: device each batch is moved to before evaluation.

    Returns:
        dict with keys 'loss', 'msk off Sqc', 'msk off Neg', 'msk off Sa'.
        - 'loss': the average over batches of the mean of -log Re<s|rho_1 ⊗ rho_2|s>
          for the shadow states s, using the masked (mask=True) forward pass.
        - 'msk off' entries: the per-sample averages of bSqc / Neg / Sa evaluated
          on (rhoS, rho_1 ⊗ rho_2), using the unmasked (mask=False) forward pass.

    The model's previous train/eval mode is restored before returning, even if
    evaluation raises.
    """
    was_training = model.training
    model.eval()
    temp_test_metrics = {'loss': [], 'msk off Sqc': [], 'msk off Neg': [], 'msk off Sa': []}

    try:
        with torch.no_grad():
            for j in range(prepseq_test.shape[0]):
                prepseq_batch = prepseq_test[j].to(device)
                shadow_state_batch = shadow_state_test[j].to(device)
                rhoS_batch = rhoS_test[j].to(device)

                # Test metrics with the mask off, on the product of marginals.
                rhoC_product = _product_of_marginals(model(prepseq_batch, False))
                temp_test_metrics['msk off Sqc'].extend(bSqc(rhoS_batch, rhoC_product).tolist())
                temp_test_metrics['msk off Neg'].extend(Neg(rhoS_batch, rhoC_product).tolist())
                temp_test_metrics['msk off Sa'].extend(Sa(rhoS_batch, rhoC_product).tolist())

                # Loss with the mask on: mean of -log <s|rho|s> over the batch.
                # NOTE(review): this also uses the product of marginals rather than
                # rho^C itself; confirm that matches the training objective.
                rhoC_masked_product = _product_of_marginals(model(prepseq_batch, True))
                probs_masked = torch.bmm(
                    torch.bmm(shadow_state_batch.conj().unsqueeze(1), rhoC_masked_product),
                    shadow_state_batch.unsqueeze(-1),
                ).view(-1).real
                temp_test_metrics['loss'].append(-probs_masked.log().mean().item())
    finally:
        # Restore whatever mode the caller had the model in.
        model.train(was_training)

    # Reduce each metric list to its mean, in the same key order as above.
    return {key: torch.tensor(values).mean().item() for key, values in temp_test_metrics.items()}

# ----- Evaluation configuration (mirrors the training run) -----
seed = 81                 # RNG seed applied right before the model is built
test_size = 1_000_000     # size of the test split handed to create_train_test_split
N = 36                    # L_max passed to LlamaPredictor
batch = 1000              # batch size handed to create_train_test_split
num_check = 20            # checkpoints per epoch written during training
max_epochs = 20           # number of training epochs whose checkpoints are scanned

# ----- Theta values to evaluate -----
# Edit this list to choose which runs are processed, e.g. [0, 1, 2, 3, 4].
theta_values_to_process = [2, 4, 6, 8, 10]

# Each theta_idx=t is expected to have its own directory tree:
#   training_theta=t/data/theta{t}/   -> prepseq / shadow_state / rhoS tensors
#   training_theta=t/save/models/     -> model checkpoints

# Reproduced records for every theta go to one shared directory.
reproduce_file = 'reproduce'
os.makedirs(f'{reproduce_file}/record', exist_ok=True)

# Re-evaluate every saved checkpoint for each (d, theta_idx, train_size) combination
# and save the per-checkpoint test metrics.
for d in [5]:
    for theta_idx in theta_values_to_process:
        for train_size in [78*10**6]:
            # Per-theta directory: data under {base_dir}/data/theta{theta_idx}/,
            # checkpoints under {base_dir}/save/models/.
            base_dir = f'training_theta={theta_idx}'
            file = f'{base_dir}/save'  # original training save directory

            print(f"Processing d={d}, theta_idx={theta_idx}, train_size={train_size}")
            print(f"Data/models directory: {base_dir}/")
            print(f"Results will be saved to: {reproduce_file}/record/")
            print(f"Expected checkpoints: {max_epochs} epochs × {num_check} checkpoints = {max_epochs * num_check} total")

            # Same architecture as training; each checkpoint overwrites the weights.
            torch.manual_seed(seed)
            mdl = LlamaPredictor(L_max=N,
                                 d=d,
                                 n_embd=96,
                                 n_layer=36,
                                 n_head=48,
                                 vocab_size=3,
                                 dropout_prob=0.0).to(device)

            # Only used to accept optimizer_state_dict entries from checkpoints;
            # it is never stepped here.
            optimizer = torch.optim.Adam(mdl.parameters(), lr=1e-4)

            # One entry per successfully evaluated checkpoint, in scan order.
            l_test_reproduced = {'loss': [], 'msk off Sqc': [], 'msk off Neg': [], 'msk off Sa': []}

            data_dir = f'{base_dir}/data/theta{theta_idx}'
            try:
                prepseq_all = torch.load(f'{data_dir}/all_prepseq_theta={theta_idx}.pt', weights_only=True)
                shadow_all = torch.load(f'{data_dir}/all_shadow_state_theta={theta_idx}.pt', weights_only=True)
                rhoS_all = torch.load(f'{data_dir}/all_rhoS_theta={theta_idx}.pt', weights_only=True)
                print(f"Data loaded from: {data_dir}/")
            except FileNotFoundError as e:
                print(f"ERROR: Could not load data from {data_dir}/ - {e}")
                print(f"Please check if directory {base_dir}/ exists and contains the data files")
                continue

            # Same preprocessing as training: shift every entry by +1 and append a
            # trailing column of zeros.
            prepseq_all = torch.cat([prepseq_all+1, torch.zeros(prepseq_all.shape[0], 1, dtype=prepseq_all.dtype)], -1)

            # Only the test split is needed here.
            _, test_data = create_train_test_split(
                prepseq_all, shadow_all, rhoS_all,
                train_size, test_size, batch
            )
            prepseq_test = test_data['prepseq']
            shadow_state_test = test_data['shadow_state']
            rhoS_test = test_data['rhoS']
            print(f"Test data loaded: {prepseq_test.shape[0]} batches")

            checkpoint_dir = f'{file}/models'
            filename_prefix = f'model_d{d}_theta_idx{theta_idx}'
            # Single template for checkpoint names, so the pattern printed below
            # matches the files actually probed. Epoch and step are both
            # zero-padded to width 4 (epoch -1 -> "-001").
            # NOTE(review): confirm this matches the naming used when saving.
            checkpoint_name = '{dir}/{prefix}_epoch{epoch:04d}_step{step:04d}.pt'

            print(f"Looking for checkpoints in: {checkpoint_dir}/")
            print(f"Checkpoint pattern: {filename_prefix}_epochXXXX_stepXXXX.pt")

            checkpoint_count = 0

            # Epochs -1 .. max_epochs-1; epoch -1 is included per the training
            # checkpoint naming.
            for epoch in range(-1, max_epochs):
                print(f"\nProcessing Epoch {epoch}...")
                epoch_checkpoints = 0

                # Probe step indices 0 .. num_check+4 (slack for extra final
                # checkpoints); indices with no file are skipped silently.
                for checkpoint_num in range(num_check + 5):
                    checkpoint_file = checkpoint_name.format(
                        dir=checkpoint_dir, prefix=filename_prefix,
                        epoch=epoch, step=checkpoint_num)
                    if not os.path.exists(checkpoint_file):
                        continue

                    # Loading and evaluation are guarded separately, so each
                    # failure is reported as what it actually was.
                    try:
                        # NOTE(review): torch.load unpickles; only use trusted checkpoints.
                        checkpoint = torch.load(checkpoint_file, map_location=device)
                        mdl.load_state_dict(checkpoint['model_state_dict'])
                        if 'optimizer_state_dict' in checkpoint:
                            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                        # Free the loaded copy before evaluation so it does not hold device memory.
                        del checkpoint
                    except Exception as e:
                        print(f"  Error loading {checkpoint_file}: {e}")
                        continue

                    try:
                        metrics = run_test_evaluation(mdl, prepseq_test, shadow_state_test, rhoS_test, device)
                    except Exception as e:
                        print(f"  Error evaluating {checkpoint_file}: {e}")
                        continue

                    for key in l_test_reproduced:
                        l_test_reproduced[key].append(metrics[key])

                    checkpoint_count += 1
                    epoch_checkpoints += 1

                    print(f"  Checkpoint {checkpoint_num}: Loss={metrics['loss']:.4f}, Sqc={metrics['msk off Sqc']:.4f}, Neg={metrics['msk off Neg']:.4f}, Sa={metrics['msk off Sa']:.4f}")

                print(f"Epoch {epoch} completed: {epoch_checkpoints} checkpoints processed")

            print("\n" + "="*80)
            print("SUMMARY:")
            print(f"Total checkpoints processed: {checkpoint_count}")
            print(f"Expected: {max_epochs * num_check}, Found: {checkpoint_count}")

            # Keys are intended to mirror the training test records (per the original
            # setup) so existing plotting code can read this file — verify.
            output_test_file = f'{reproduce_file}/record/reproduced_d{d}_theta_idx{theta_idx}_size{train_size}_test.pt'
            torch.save(l_test_reproduced, output_test_file)

            print(f"Reproduced test records saved to: {output_test_file}")
            print(f"Records contain {len(l_test_reproduced['loss'])} data points for plotting")
            print("="*80)
            
# Closing banner: how many theta values were scanned and where results live.
print("\n".join([
    "\n" + "=" * 80,
    f"COMPLETED: Processed {len(theta_values_to_process)} theta value(s): {theta_values_to_process}",
    f"All results saved in: {reproduce_file}/record/",
    "=" * 80,
]))

Processing d=5, theta_idx=2, train_size=78000000
Data/models directory: training_theta=2/
Results will be saved to: reproduce/record/
Expected checkpoints: 20 epochs × 20 checkpoints = 400 total
Data loaded from: training_theta=2/data/theta2/
test size=1000000, train size=78000000
test indices: [0-999999], train indices: [1000000-78999999]
Test data loaded: 1000 batches
Looking for checkpoints in: training_theta=2/save/models/
Checkpoint pattern: model_d5_theta_idx2_epochXXX_stepXXXX.pt

Processing Epoch -1...
