In [113]:
import sys
sys.path.append("src/")

from simulate import Algorithm, BivariateBicycle, NoiseModel, CodeFactory, StimSimulator, LogicalCircuit
import pandas as pd
import numpy as np
import torch
from torch.utils.data import IterableDataset
from torch_geometric.data import Data
import math
from functools import lru_cache


class BivariateBicycleDataset(IterableDataset):
    """
    PyTorch IterableDataset for bivariate bicycle code error correction data.
    
    Generates batches of bivariate bicycle code detector measurements and corresponding 
    logical error labels for training graph neural network decoders.
    
    Supports 3-stage curriculum learning with automatic p adjustment.
    """
    
    def __init__(self, l=6, m=6, rounds=10, p=2.0, batch_size=32, 
                 stage_manager=None, num_workers=8, global_step_offset=0, **kwargs):
        """
        Args:
            l: BB code parameter l
            m: BB code parameter m
            rounds: Number of syndrome rounds
            p: Default error probability parameter (used when stage_manager is None)
            batch_size: Batch size for generated data
            stage_manager: StageManager instance for curriculum learning (optional)
            num_workers: Total number of DataLoader workers (for step estimation)
            global_step_offset: Starting global step offset (for resuming)
        """
        super().__init__()
        self.default_p = p
        self.batch_size = batch_size
        self.l = l
        self.m = m
        self.rounds = rounds
        self.stage_manager = stage_manager
        self.num_workers = num_workers
        self.global_step_offset = global_step_offset
        
        # Track local worker sample count (will be set in each worker)
        self.local_sample_count = 0
        
        # Build static code graph structure (independent of rounds)
        self._static_graph_structure = self._build_static_graph_structure()
    
    def get_neighborhood_size(self):
        """Get the total number of relation types (static + temporal)."""
        detectors, logical_errors, (rounds, p), mt = self.generate_batch(
                self.l, self.m, 10, 0.01, 1
        )
        graph = self._build_spatiotemporal_graph_structure(rounds, mt)
        return graph.shape[1]
    
    def get_num_logical_qubits(self):
        """Get the number of logical qubits encoded by this BB code."""
        # logical_operators is a numpy structured array with fields: logical_id, logical_pauli, data_id, physical_pauli
        # logical_id goes from 0 to k-1 where k is the number of logical qubits
        logical_ops = self._static_graph_structure['metadata'].logical_operators
        return np.max(logical_ops['logical_id']) + 1
    
    def _estimate_global_step(self) -> int:
        """Estimate current global step from worker-local sample count."""
        return self.global_step_offset + (self.local_sample_count * self.num_workers)
    
    def get_current_p(self) -> float:
        """Get current p value based on stage manager or default."""
        if self.stage_manager is not None:
            estimated_step = self._estimate_global_step()
            return self.stage_manager.get_current_p(estimated_step)
        else:
            return self.default_p
    
    def _build_static_graph_structure(self):
        """Build the static BB code graph structure (independent of rounds)."""
        # Create BB code metadata
        cf = CodeFactory(BivariateBicycle, {'l': self.l, 'm': self.m})
        metadata = cf().metadata
        
        # Build tanner graph - this creates edges between checks and data
        tanner = pd.DataFrame(metadata.tanner)
        
        # Create check-to-check graph as in bb-test.ipynb
        c2c = tanner.merge(tanner, on='data_id')[['check_id_x', 'check_id_y', 'edge_type_x', 'edge_type_y']]
        all_edge_types = []
        graph_edges = []
        
        for (c1, c2), df in c2c.groupby(['check_id_x', 'check_id_y']):
            edge_types = df[['edge_type_x', 'edge_type_y']].to_numpy()
            edge_types = frozenset([tuple(map(int, x)) for x in edge_types])
            if edge_types not in all_edge_types:
                all_edge_types.append(edge_types)
            c1, c2 = map(int, (c1, c2))
            graph_edges.append((c1, c2, all_edge_types.index(edge_types)))
        
        graph_df = pd.DataFrame(graph_edges, columns=['syndrome_id', 'neighbor_syndrome_id', 'edge_type'])
        
        return {
            'graph_df': graph_df,
            'all_edge_types': all_edge_types,
            'metadata': metadata
        }
    
    def _build_spatiotemporal_graph_structure(self, rounds, mt):
        """Build the spatiotemporal graph structure using the static graph and temporal connections.
        
        Args:
            rounds: Number of error correction rounds
            mt: MeasurementTracker from the LogicalCircuit (contains detector metadata)
        """
        # Use pre-computed static structure
        graph_df = self._static_graph_structure['graph_df']
        all_edge_types = self._static_graph_structure['all_edge_types']
        
        # Get actual detector metadata from measurement tracker
        det = pd.DataFrame(mt.detectors)
        
        t0 = det[det['time'] == 1].copy()
        t0['time'] = 0
        t0['detector_id'] = -1
        tf = det[det['time'] == det['time'].max()]
        opposite_basis = t0[~t0['syndrome_id'].isin(tf['syndrome_id'].values)].copy()
        opposite_basis['time'] = det['time'].max()
        tf2 = t0.copy()
        tf2['time'] = det['time'].max()+1
        det = pd.concat([t0, det, opposite_basis, tf2]).sort_values(by=['time', 'syndrome_id']).reset_index(drop=True)
        
        graph_df = self._static_graph_structure['graph_df']
        all_edge_types = self._static_graph_structure['all_edge_types']
        det_graph = det.merge(graph_df, on='syndrome_id').merge(
            det, left_on='neighbor_syndrome_id', right_on='syndrome_id', suffixes=('', '_nb')
        )

        # Filter temporal connections (allow connections within 1 time step)
        det_graph = det_graph[np.abs(det_graph['time'] - det_graph['time_nb']) <= 1]

        # Encode temporal information in edge types
        dt = det_graph['time_nb'] - det_graph['time'] + 1  # Maps [-1, 0, 1] to [0, 1, 2]
        det_graph['final_edge_type'] = dt * len(all_edge_types) + det_graph['edge_type']

        actual_graph = det_graph[det_graph['detector_id'] != -1]
        actual_graph = actual_graph.sort_values(['detector_id', 'final_edge_type'])
        graph = actual_graph[['detector_id', 'detector_id_nb', 'final_edge_type']].to_numpy()
        _, counts = np.unique(graph[...,0], return_counts=True)
        assert np.all(counts == counts[0])
        nbrhood_size = counts[0]
        graph = graph.reshape((-1, nbrhood_size, 3))        
        return graph

    
    @staticmethod
    def generate_batch(l, m, rounds, p, batch_size):
        """Generate a single batch of bivariate bicycle code data."""
        
        # Create BB code circuit
        alg = Algorithm.build_memory(cycles=rounds)
        cf = CodeFactory(BivariateBicycle, {'l': l, 'm': m})
        lc, _, mt = LogicalCircuit.from_algorithm(alg, cf)
        noise_model = NoiseModel.get_scaled_noise_model(p).without_loss()
        sim = StimSimulator(alg, noise_model, cf, seed=np.random.randint(0, 2**48))
        
        # Sample syndrome data
        results = sim.sample(shots=batch_size)
        detectors = results.detectors
        logical_errors = results.logical_errors
        
        return detectors, logical_errors, (rounds, p), mt
    
    def __iter__(self):
        """Generate infinite stream of bivariate bicycle code data batches."""
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            worker_id = worker_info.id if hasattr(worker_info, 'id') else 'unknown'
        else: 
            worker_id = 'main'
            np.random.seed(0)
        
        # Reset local sample count for this worker
        self.local_sample_count = 0
        
        while True:
            # Calculate current p based on estimated global step
            current_p = self.get_current_p()
            
            # Generate batch with current curriculum p value
            detectors, logical_errors, (rounds, p), mt = self.generate_batch(
                self.l, self.m, self.rounds, current_p, self.batch_size
            )
            synd_id = mt.detectors['syndrome_id']
            detectors = detectors * (synd_id.max()+1) + synd_id[None,:]
            # No longer yielding graph since it's static and registered in the model
            yield detectors.astype(np.int32), np.squeeze(logical_errors.astype(np.float32), axis=1), (rounds, p)

            # Increment local sample count after generating batch
            self.local_sample_count += 1
    
    def get_graph(self):
        # Generate batch with current curriculum p value
        detectors, logical_errors, (rounds, p), mt = self.generate_batch(
            self.l, self.m, self.rounds, 1.0, self.batch_size
        )
        synd_id = mt.detectors['detector_id']
        detectors = detectors * (synd_id.max()+1) + synd_id[None,:]
        graph = self._build_spatiotemporal_graph_structure(rounds, mt)
        return graph
    
    def get_num_embeddings(self):
        detectors, logical_errors, (rounds, p), mt = self.generate_batch(
                self.l, self.m, self.rounds, 1.0, self.batch_size
            )
        synd_id = mt.detectors['syndrome_id']
        return 2*(synd_id.max()+1)

In [114]:
class BB2(IterableDataset):
    """
    PyTorch IterableDataset for bivariate bicycle code error correction data.
    
    Generates batches of bivariate bicycle code detector measurements and corresponding 
    logical error labels for training graph neural network decoders.
    
    Supports 3-stage curriculum learning with automatic p adjustment.
    """
    
    def __init__(self, l=6, m=6, rounds=10, p=2.0, batch_size=32, 
                 stage_manager=None, num_workers=8, global_step_offset=0, **kwargs):
        """
        Args:
            l: BB code parameter l
            m: BB code parameter m
            rounds: Number of syndrome rounds
            p: Default error probability parameter (used when stage_manager is None)
            batch_size: Batch size for generated data
            stage_manager: StageManager instance for curriculum learning (optional)
            num_workers: Total number of DataLoader workers (for step estimation)
            global_step_offset: Starting global step offset (for resuming)
        """
        super().__init__()
        self.default_p = p
        self.batch_size = batch_size
        self.l = l
        self.m = m
        self.rounds = rounds
        self.stage_manager = stage_manager
        self.num_workers = num_workers
        self.global_step_offset = global_step_offset
        
        # Track local worker sample count (will be set in each worker)
        self.local_sample_count = 0
        
        # Build static code graph structure (independent of rounds)
        self._static_graph_structure = self._build_static_graph_structure()
    
    def get_neighborhood_size(self):
        """Get the total number of relation types (static + temporal)."""
        detectors, logical_errors, (rounds, p), mt = self.generate_batch(
                self.l, self.m, 10, 0.01, 1
        )
        graph = self._build_spatiotemporal_graph_structure(rounds, mt)
        return graph.shape[1]
    
    def get_num_logical_qubits(self):
        """Get the number of logical qubits encoded by this BB code."""
        # logical_operators is a numpy structured array with fields: logical_id, logical_pauli, data_id, physical_pauli
        # logical_id goes from 0 to k-1 where k is the number of logical qubits
        logical_ops = self._static_graph_structure['metadata'].logical_operators
        return np.max(logical_ops['logical_id']) + 1
    
    def _estimate_global_step(self) -> int:
        """Estimate current global step from worker-local sample count."""
        return self.global_step_offset + (self.local_sample_count * self.num_workers)
    
    def get_current_p(self) -> float:
        """Get current p value based on stage manager or default."""
        if self.stage_manager is not None:
            estimated_step = self._estimate_global_step()
            return self.stage_manager.get_current_p(estimated_step)
        else:
            return self.default_p
    
    def _build_static_graph_structure(self):
        """Build the static BB code graph structure (independent of rounds)."""
        # Create BB code metadata
        cf = CodeFactory(BivariateBicycle, {'l': self.l, 'm': self.m})
        metadata = cf().metadata
        
        # Build tanner graph - this creates edges between checks and data
        tanner = pd.DataFrame(metadata.tanner)
        
        # Create check-to-check graph as in bb-test.ipynb
        c2c = tanner.merge(tanner, on='data_id')[['check_id_x', 'check_id_y', 'edge_type_x', 'edge_type_y']]
        all_edge_types = []
        graph_edges = []
        
        for (c1, c2), df in c2c.groupby(['check_id_x', 'check_id_y']):
            edge_types = df[['edge_type_x', 'edge_type_y']].to_numpy()
            edge_types = frozenset([tuple(map(int, x)) for x in edge_types])
            if edge_types not in all_edge_types:
                all_edge_types.append(edge_types)
            c1, c2 = map(int, (c1, c2))
            graph_edges.append((c1, c2, all_edge_types.index(edge_types)))
        
        graph_df = pd.DataFrame(graph_edges, columns=['syndrome_id', 'neighbor_syndrome_id', 'edge_type'])
        return {
            'graph_df': graph_df.sort_values(by=['syndrome_id', 'edge_type']),
            'all_edge_types': all_edge_types,
            'metadata': metadata
        }
        
    @staticmethod
    def generate_batch(l, m, rounds, p, batch_size):
        """Generate a single batch of bivariate bicycle code data."""
        
        # Create BB code circuit
        alg = Algorithm.build_memory(cycles=rounds)
        cf = CodeFactory(BivariateBicycle, {'l': l, 'm': m})
        lc, _, mt = LogicalCircuit.from_algorithm(alg, cf)
        noise_model = NoiseModel.get_scaled_noise_model(p).without_loss()
        sim = StimSimulator(alg, noise_model, cf, seed=np.random.randint(0, 2**48))
        
        # Sample syndrome data
        results = sim.sample(shots=batch_size)
        detectors = results.detectors
        logical_errors = results.logical_errors
        
        return detectors, logical_errors, (rounds, p), mt
    
    def __iter__(self):
        """Generate infinite stream of bivariate bicycle code data batches."""
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            worker_id = worker_info.id if hasattr(worker_info, 'id') else 'unknown'
        else: 
            worker_id = 'main'
            np.random.seed(0)
        
        # Reset local sample count for this worker
        self.local_sample_count = 0
        
        while True:
            # Calculate current p based on estimated global step
            current_p = self.get_current_p()
            
            # Generate batch with current curriculum p value
            detectors, logical_errors, (rounds, p), mt = self.generate_batch(
                self.l, self.m, self.rounds, current_p, self.batch_size
            )
            synd_id = mt.detectors['syndrome_id']
            detectors = detectors * (synd_id.max()+1) + synd_id[None,:]
            
            ret = np.zeros((detectors.shape[0], rounds+1, synd_id.max()+1), dtype=np.int32)
            
            ret[:, mt.detectors['time']-1, mt.detectors['syndrome_id']] = detectors+1
            # No longer yielding graph since it's static and registered in the model
            yield ret, np.squeeze(logical_errors.astype(np.float32), axis=1), (rounds, p)

            # Increment local sample count after generating batch
            self.local_sample_count += 1
    
    def get_graph(self):
        # Generate batch with current curriculum p value
        detectors, logical_errors, (rounds, p), mt = self.generate_batch(
            self.l, self.m, self.rounds, 1.0, self.batch_size
        )
        synd_id = mt.detectors['detector_id']
        detectors = detectors * (synd_id.max()+1) + synd_id[None,:]
        graph = self._build_spatiotemporal_graph_structure(rounds, mt)
        return graph
    
    def get_num_embeddings(self):
        detectors, logical_errors, (rounds, p), mt = self.generate_batch(
                self.l, self.m, self.rounds, 1.0, self.batch_size
            )
        synd_id = mt.detectors['syndrome_id']
        return 2*(synd_id.max()+1)+1

In [160]:
def profile_torch_function(func, *args, warmup_runs=5, num_runs=20, **kwargs):
    device = torch.device("cuda")

    # --- Warm-up Phase ---
    for _ in range(warmup_runs):
        _ = func(*args, **kwargs)
    torch.cuda.synchronize()

    # --- Time Measurement ---
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    timings = []
    for _ in range(num_runs):
        start_event.record()
        _ = func(*args, **kwargs)
        end_event.record()
        torch.cuda.synchronize()
        timings.append(start_event.elapsed_time(end_event))
    
    average_time_ms = sum(timings) / len(timings)

    # --- Memory Measurement ---
    torch.cuda.reset_peak_memory_stats(device)
    _ = func(*args, **kwargs)
    torch.cuda.synchronize()
    peak_memory_bytes = torch.cuda.max_memory_allocated(device)
    peak_memory_mb = peak_memory_bytes / (1024 * 1024)

    return {
        "average_time_ms": average_time_ms,
        "peak_memory_mb": peak_memory_mb
    }


In [167]:
ds = BB2(l=6, m=6, p=1.0, rounds=6, batch_size=64)
graph_df = ds._build_static_graph_structure()['graph_df'].copy()
edge_types = list(sorted(set(map(tuple, graph_df.groupby('syndrome_id').agg(list)['edge_type'].tolist()))))
edge_type_bdr = min(edge_types[1])
graph_df.loc[graph_df['edge_type'] >= edge_type_bdr, 'edge_type'] -= edge_type_bdr

graph = np.zeros((graph_df['syndrome_id'].max()+1, 22), dtype=int)
graph[graph_df['syndrome_id'], graph_df['edge_type']] = graph_df['neighbor_syndrome_id']


embedding_dim = 128
emb = nn.Embedding(ds.get_num_embeddings(), embedding_dim=embedding_dim).cuda()
lin = nn.Linear(graph.shape[1]*embedding_dim, 3*embedding_dim).cuda()

x, y, _ = next(iter(ds))
x, y = torch.from_numpy(x).int().cuda(), torch.from_numpy(y)
x = emb(x)





In [172]:
x0, x1, x2 = torch.chunk(lin(x[:,:,graph,:].flatten(start_dim=-2)), 3, dim=-1)

In [162]:
def method1():
    x0, x1, x2 = torch.chunk(lin(x[:,:,graph,:].flatten(start_dim=-2)), 3, dim=-1)
    xout = F.pad(x0[:,:-1], (0,0,0,0,1,0,0,0)) + x1 + F.pad(x2[:,1:], (0,0,0,0,0,1,0,0))
profiling_results = profile_torch_function(
    method1, warmup_runs=10, num_runs=100
)
print(profiling_results)

{'average_time_ms': 11.615856666564941, 'peak_memory_mb': 7635.63818359375}


In [164]:
import torch.nn as nn
import torch.nn.functional as F
ds = BivariateBicycleDataset(l=6, m=6, p=1.0, rounds=6, batch_size=64)
embedding_dim = 128
emb = nn.Embedding(ds.get_num_embeddings(), embedding_dim=embedding_dim).cuda()
lin = nn.Linear(ds.get_neighborhood_size()*embedding_dim, embedding_dim).cuda()
print(lin)
x, y, _ = next(iter(ds))
x, y = torch.from_numpy(x).int().cuda(), torch.from_numpy(y)
x = emb(x)
x2 = F.pad(x, (0,0,0,1,0,0))
out = lin(x2[:, ds.get_graph()[...,1]].flatten(start_dim=2))
def method2():
    x2 = F.pad(x, (0,0,0,1,0,0))
    out = lin(x2[:, ds.get_graph()[...,1]].flatten(start_dim=2))

profiling_results = profile_torch_function(
    method2, warmup_runs=10, num_runs=100
)

Linear(in_features=8448, out_features=128, bias=True)


In [165]:
profiling_results

{'average_time_ms': 108.21879806518555, 'peak_memory_mb': 8185.26025390625}