## Relay-BP Decoder Demo

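This notebook demonstrates the Relay-BP decoder, integrated into the Qiskit-CSS-T project structure. It estimates the logical error rates of repetition codes under i.i.d. bit-flip noise and plots them against the physical error rate.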

In [None]:
import numpy as np
import scipy.sparse
import matplotlib.pyplot as plt
from datetime import datetime
import sys
import os

# Add path to decoder source
sys.path.append(os.path.abspath('../src'))
from relay_bp_decoder import RelayDecoder

### Section 1: Mathematical & Matrix Functions

In [None]:
def repetition_code_matrices(d):
    """
    Build the parity-check matrix of a length-``d`` repetition code.

    Row ``i`` of the returned matrix has ones in columns ``i`` and ``i + 1``,
    so each check compares a pair of neighbouring bits. Note that a single
    matrix is returned (not an Hx/Hz pair, despite the plural name).

    Parameters
    ----------
    d : int
        Code length (number of data bits). Must be >= 1.

    Returns
    -------
    scipy.sparse.csr_matrix
        Integer matrix of shape ``(d - 1, d)``.

    Raises
    ------
    ValueError
        If ``d < 1``.
    """
    if d < 1:
        raise ValueError(f"repetition code length must be >= 1, got {d}")
    n = d
    m = d - 1
    rows = np.arange(m)
    # Every check row holds exactly two nonzeros, at columns i and i + 1,
    # so the CSR arrays can be built directly without a Python loop.
    data = np.ones(2 * m, dtype=int)
    indices = np.column_stack((rows, rows + 1)).ravel()
    indptr = np.arange(0, 2 * m + 1, 2)
    return scipy.sparse.csr_matrix((data, indices, indptr), shape=(m, n))

def get_parity_matrices(n):
    """
    Return the ``(Hx, Hz)`` parity-check pair used by this demo.

    Only Z checks correcting X (bit-flip) errors are simulated. ``Hz`` is
    therefore the length-``n`` repetition-code matrix, and ``Hx`` is an
    all-zero ``1 x n`` placeholder.
    """
    repetition_hz = repetition_code_matrices(n)
    placeholder_hx = scipy.sparse.csr_matrix((1, n), dtype=int)  # dummy, never used for decoding
    return placeholder_hx, repetition_hz

In [None]:
n = 13
Hx, Hz = get_parity_matrices(n)
# Quick sanity check of the generated Z-check matrix dimensions.
print(f"Generated matrices for n={n}", f"Hz shape: {Hz.shape}", sep="\n")

### Section 2: Decoder Initialization

In [None]:
def initialize_decoder(H, error_rate, *, pre_iter=20, num_legs=10,
                       iter_per_leg=20, gamma_interval=(-0.5, 0.5)):
    """
    Build a ``RelayDecoder`` for check matrix ``H`` with uniform priors.

    Parameters
    ----------
    H : matrix-like
        Parity-check matrix. One prior is created per column (``H.shape[1]``).
    error_rate : float
        Prior error probability assigned to every column of ``H``.
    pre_iter, num_legs, iter_per_leg, gamma_interval : optional, keyword-only
        Relay-BP hyperparameters, forwarded unchanged to ``RelayDecoder``.
        The defaults reproduce the values this function previously hard-coded.

    Returns
    -------
    RelayDecoder
        The configured decoder instance.
    """
    # Force a float array so an integer rate (e.g. 0) still yields float priors.
    priors = np.full(H.shape[1], error_rate, dtype=float)
    return RelayDecoder(
        check_matrix=H,
        error_priors=priors,
        pre_iter=pre_iter,
        num_legs=num_legs,
        iter_per_leg=iter_per_leg,
        gamma_interval=gamma_interval,
    )

### Section 3: Simulation Functions

In [None]:
def generate_errors(num_shots, n, error_probability, *, rng=None):
    """
    Sample i.i.d. bit-flip error patterns.

    Each entry is independently 1 with probability ``error_probability``
    and 0 otherwise.

    Parameters
    ----------
    num_shots : int
        Number of independent error patterns (rows).
    n : int
        Number of bits per pattern (columns).
    error_probability : float
        Per-bit flip probability, in ``[0, 1]``.
    rng : numpy.random.Generator, optional
        Source of randomness. When omitted, NumPy's global RNG is used exactly
        as before, so ``np.random.seed`` still controls the output. Pass a
        seeded ``Generator`` for reproducible, isolated runs.

    Returns
    -------
    numpy.ndarray
        Integer array of shape ``(num_shots, n)`` with entries in {0, 1}.

    Raises
    ------
    ValueError
        If ``error_probability`` is outside ``[0, 1]`` (or NaN).
    """
    if not 0.0 <= error_probability <= 1.0:
        raise ValueError(
            f"error_probability must be in [0, 1], got {error_probability}"
        )
    sampler = np.random if rng is None else rng
    return sampler.choice(
        [0, 1], size=(num_shots, n), p=[1 - error_probability, error_probability]
    )

def simulate_single_shot(H, error, decoder):
    """
    Decode one error pattern and report whether the shot failed.

    Returns 1 if the decoder reports an unsuccessful decode, or if the
    corrected pattern ``error + correction`` (mod 2) is not all zeros.
    Otherwise it returns 0.

    For this repetition-code demo there are no X-type stabilizers (``Hx`` is
    a zero placeholder). Assuming a successful decode reproduces the syndrome,
    the only nonzero residual left is the all-ones pattern. That pattern is
    a logical error, so "any nonzero residual" is the logical-failure test.
    """
    observed_syndrome = (H @ error) % 2
    converged, correction = decoder.decode(observed_syndrome)

    # A failed decode is conservatively counted as a failed shot.
    if not converged:
        return 1

    leftover = (error + correction) % 2
    return int(np.any(leftover))

In [None]:
def run_simulation(ns, ps, num_shots, *, verbose=False):
    """
    Estimate repetition-code logical error rates over an (n, p) grid.

    For every code length ``n`` in ``ns`` and physical error rate ``p`` in
    ``ps``, the function does the following:
    1. Sample ``num_shots`` i.i.d. bit-flip patterns.
    2. Decode them with a decoder whose priors are set to ``p``.
    3. Record the fraction of shots that ``simulate_single_shot`` counts
       as failures.

    Parameters
    ----------
    ns : iterable of int
        Code lengths to simulate.
    ps : iterable of float
        Physical error rates to simulate for each code length.
    num_shots : int
        Number of sampled shots per ``(n, p)`` point. Must be positive.
    verbose : bool, optional, keyword-only
        If true, also print the estimated rate for every ``(n, p)`` point.

    Returns
    -------
    list of list of float
        ``result[i][j]`` is the logical error rate for the i-th ``n`` and the
        j-th ``p``.

    Raises
    ------
    ValueError
        If ``num_shots`` is not positive (the rate would be undefined).
    """
    if num_shots <= 0:
        raise ValueError(f"num_shots must be positive, got {num_shots}")

    log_errors_counts = []
    for n in ns:
        print(f"Simulating for n={n}...")
        # We simulate Z checks correcting X errors; Hx is not needed.
        _, H = get_parity_matrices(n)

        log_errors = []
        for p in ps:
            # Priors are built from p at construction, so rebuild per rate.
            decoder = initialize_decoder(H, p)
            errors = generate_errors(num_shots, n, p)
            num_logical_errors = sum(
                simulate_single_shot(H, error, decoder) for error in errors
            )
            logical_error_rate = num_logical_errors / num_shots
            log_errors.append(logical_error_rate)
            if verbose:
                print(f"  p={p:.4f}, LER={logical_error_rate:.4f}")

        log_errors_counts.append(log_errors)

    return log_errors_counts

### Section 4: Plotting

In [None]:
def plot_logical_error_rates(ns, ps, log_errors_counts):
    """
    Plot logical vs. physical error rate on log-log axes, one curve per code.

    Parameters
    ----------
    ns : sequence of int
        Code lengths, used only for the legend labels.
    ps : sequence of float
        Physical error rates, shared as the x-values of every curve.
    log_errors_counts : iterable of sequence of float
        The i-th entry holds the logical error rates for ``ns[i]``, aligned
        with ``ps`` (the shape returned by ``run_simulation``).

    Raises
    ------
    ValueError
        If ``ns`` and ``log_errors_counts`` have different lengths. Previously
        a mismatch either raised IndexError or silently dropped curves.
    """
    curves = list(log_errors_counts)
    if len(ns) != len(curves):
        raise ValueError(
            f"got {len(ns)} code lengths but {len(curves)} error-rate curves"
        )

    plt.figure(figsize=(12, 8))

    for n, log_errors in zip(ns, curves):
        plt.plot(ps, log_errors, marker='o', linestyle='-', label=f'n={n}')

    # Break-even reference: logical error rate equal to the physical rate.
    plt.plot(ps, ps, linestyle='--', color='black', label='Break-even (y=x)')

    plt.xscale('log')
    plt.yscale('log')
    plt.xlabel('Physical Error Rate (p)')
    plt.ylabel('Logical Error Rate')
    plt.title('Logical Error Rate vs. Physical Error Rate (Relay-BP)')
    # Legend sits outside the axes, to the right, so it never hides data.
    plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))
    plt.grid(True, which="both", ls="--")

    plt.tight_layout()
    plt.show()

### Section 5: Execution

In [None]:
# Sweep configuration: odd repetition-code lengths 5..11, ten physical
# error rates evenly spaced from 1% to 15%, and shots per grid point.
ns = list(range(5, 13, 2))
ps = np.linspace(start=0.01, stop=0.15, num=10).tolist()
num_shots = 1_000

# One list of logical error rates (aligned with ps) per code length in ns.
log_errors_counts = run_simulation(ns, ps, num_shots)

In [None]:
# Render the log-log logical-vs-physical error-rate comparison.
plot_logical_error_rates(ns=ns, ps=ps, log_errors_counts=log_errors_counts)