# QuartumSE Benchmark Suite

This notebook benchmarks **classical shadows** against direct measurement baselines for a single circuit.

In [None]:
# --- Setup ---
import sys
# Put ``../src`` first on the import path so ``quartumse`` can be imported
# from the source tree.
# NOTE(review): the path is relative to the notebook's working directory —
# assumes the notebook is launched from a sibling directory of ``src/``; confirm.
sys.path.insert(0, '../src')

import numpy as np
from collections import Counter, defaultdict
from qiskit import QuantumCircuit

from quartumse import (
    run_benchmark_suite,
    BenchmarkMode,
    BenchmarkSuiteConfig,
    generate_observable_set,
    Observable,
    ObservableSet,
)

print("Setup complete!")

---

## 1. Circuit Configuration

Define the circuit and observables to benchmark.

In [None]:
# =============================================================================
# CIRCUIT CONFIGURATION
# =============================================================================

N_QUBITS = 4
# Derived from N_QUBITS so the identifier cannot drift out of sync with the
# actual qubit count when N_QUBITS is edited.
CIRCUIT_ID = f"ghz_{N_QUBITS}q"  # Identifier for this benchmark

# --- Build Circuit ---
def build_ghz(n_qubits: int) -> QuantumCircuit:
    """Return a circuit preparing the GHZ state (|00...0⟩ + |11...1⟩) / sqrt(2).

    A Hadamard on qubit 0 is followed by a CNOT ladder in which each qubit
    controls a flip of the next one.
    """
    ghz = QuantumCircuit(n_qubits)
    ghz.h(0)
    for control, target in zip(range(n_qubits - 1), range(1, n_qubits)):
        ghz.cx(control, target)
    return ghz


circuit = build_ghz(N_QUBITS)
print(f"Circuit: {CIRCUIT_ID}")
print(circuit.draw('text'))

In [None]:
# --- Observables ---
# Five random Pauli observables at each fixed weight K = 1..N_QUBITS
# (generator seed 42 + K), followed by the all-Z and all-X Pauli strings.
observables = []
for weight in range(1, N_QUBITS + 1):
    random_set = generate_observable_set(
        generator_id='random_pauli',
        n_qubits=N_QUBITS,
        n_observables=5,
        seed=42 + weight,
        weight_distribution='fixed',
        fixed_weight=weight,
    )
    observables += random_set.observables

# GHZ checks: X...X stabilises the GHZ state for every N, while Z...Z is a
# stabiliser only for even N (for odd N its GHZ expectation value is 0).
observables += [Observable(pauli * N_QUBITS) for pauli in ('Z', 'X')]

obs_set = ObservableSet(
    observables=observables,
    observable_set_id=f'{CIRCUIT_ID}_obs',
    generator_id='mixed',
    generator_seed=42,
)

# Locality (Pauli weight) = number of non-identity factors in the string.
locality_map = {
    obs.observable_id: len(obs.pauli_string) - obs.pauli_string.count('I')
    for obs in observables
}

print(f"Observables: {len(obs_set)}")
loc_counts = Counter(locality_map.values())
for weight, count in sorted(loc_counts.items()):
    print(f"  K={weight}: {count} observables")

---

## 2. Benchmark Configuration

In [None]:
# =============================================================================
# BENCHMARK CONFIGURATION
# =============================================================================

# ANALYSIS mode runs the full set of tasks; 20 replicates is the
# publication-quality setting. epsilon is the target precision and delta the
# failure probability.
config = BenchmarkSuiteConfig(
    mode=BenchmarkMode.ANALYSIS,
    n_shots_grid=[100, 500, 1000, 5000],
    n_replicates=20,
    seed=42,
    epsilon=0.01,
    delta=0.05,
    shadows_protocol_id="classical_shadows_v0",
    baseline_protocol_id="direct_grouped",
    output_base_dir="benchmark_results",
)

print("Benchmark Configuration:")
for label, value in (
    ("Mode", config.mode.value),
    ("Shot grid", config.n_shots_grid),
    ("Replicates", config.n_replicates),
    ("Target ε", config.epsilon),
):
    print(f"  {label}: {value}")

---

## 3. Run Benchmark

In [None]:
%%time
# =============================================================================
# RUN BENCHMARK
# =============================================================================

# `%%time` is an IPython cell magic that reports this cell's run time; it must
# remain the first line of the cell.
# Runs the benchmark on `circuit` for the observables in `obs_set` using the
# settings in `config` (presumably both configured protocols over
# config.n_shots_grid — confirm against quartumse docs). The returned `result`
# is consumed by the summary report below.
result = run_benchmark_suite(
    circuit=circuit,
    observable_set=obs_set,
    circuit_id=CIRCUIT_ID,
    config=config,
    locality_map=locality_map,
)

---

## 4. Benchmark Summary Report

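A consolidated report answering the 8 Measurements Bible questions, with one box per task. Tasks 7 (noise sensitivity) and 8 (adaptive efficiency) need experiments beyond this single-circuit run, so they are reported as not applicable.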

In [None]:
# =============================================================================
# CONSOLIDATED BENCHMARK SUMMARY REPORT
# =============================================================================

# Number of characters between the left and right border of each task box;
# the full box width (borders included) is this + 2 = 100, matching the
# "=" * 100 separators.
_BOX_INNER_WIDTH = 98


def _finite_ses(rows):
    """Return the standard errors of *rows*, skipping rows whose SE is None."""
    return [r.se for r in rows if r.se is not None]


def _n_star_answers(by_protocol_n, protocols, max_n, epsilon, reduce_se):
    """Answer an "N*" question (Tasks 1 and 2) for every protocol.

    For each protocol, find the smallest N whose SEs, collapsed with
    ``reduce_se`` (``max`` for the worst case, ``np.mean`` for the average),
    are <= ``epsilon``. If no N qualifies, report the collapsed SE at
    ``max_n`` instead ("n/a" when there is no SE data there).
    """
    answers = []
    for protocol in protocols:
        by_n = by_protocol_n[protocol]
        n_star = None
        for n in sorted(by_n):
            ses = _finite_ses(by_n[n])
            if ses and reduce_se(ses) <= epsilon:
                n_star = n
                break
        if n_star is not None:
            answers.append(f"{protocol}: N*={n_star}")
        else:
            # .get avoids inserting an empty group into the defaultdict.
            ses = _finite_ses(by_n.get(max_n, []))
            se_text = f"{reduce_se(ses):.4f}" if ses else "n/a"
            answers.append(f"{protocol}: N*>{max_n} (SE={se_text})")
    return answers


def _se_distribution_answers(by_protocol_n, protocols, max_n):
    """Task 3: mean / median / max SE per protocol at ``max_n``."""
    answers = []
    for protocol in protocols:
        ses = _finite_ses(by_protocol_n[protocol].get(max_n, []))
        if ses:
            answers.append(
                f"{protocol}: mean={np.mean(ses):.4f}, median={np.median(ses):.4f}, max={np.max(ses):.4f}"
            )
    return answers


def _dominance_answers(by_protocol_n, protocols, max_n):
    """Task 4: count the observables on which each protocol has the lowest SE.

    SEs are first averaged over all rows of an observable at ``max_n`` (one
    row per replicate), so a single lucky replicate cannot decide the
    comparison. On equal mean SE the protocol listed first wins the
    observable.

    Returns:
        (answers, winner): one formatted line per protocol, and the protocol
        with the most wins ("TIE (...)" when several share the lead, "N/A"
        when nobody won anything).
    """
    mean_se = {}
    for protocol in protocols:
        per_obs = defaultdict(list)
        for row in by_protocol_n[protocol].get(max_n, []):
            if row.se is not None:
                per_obs[row.observable_id].append(row.se)
        mean_se[protocol] = {obs_id: float(np.mean(ses)) for obs_id, ses in per_obs.items()}

    all_obs = set().union(*mean_se.values())
    wins = dict.fromkeys(protocols, 0)
    for obs_id in all_obs:
        candidates = [p for p in protocols if obs_id in mean_se[p]]
        wins[min(candidates, key=lambda p: mean_se[p][obs_id])] += 1

    total_obs = len(all_obs)
    answers = []
    for protocol in protocols:
        win_count = wins[protocol]
        win_pct = 100 * win_count / total_obs if total_obs > 0 else 0
        answers.append(f"{protocol}: {win_count}/{total_obs} ({win_pct:.0f}%)")

    top = max(wins.values(), default=0)
    leaders = [p for p in protocols if wins[p] == top] if top > 0 else []
    if not leaders:
        winner = "N/A"
    elif len(leaders) == 1:
        winner = leaders[0]
    else:
        winner = "TIE (" + ", ".join(leaders) + ")"
    return answers, winner


def _bias_variance_answers(by_protocol_n, protocols, max_n, truth_values):
    """Task 6: average bias², variance and MSE per protocol at ``max_n``.

    Uses the population variance (``np.var``, ddof=0), so MSE = bias² + Var
    holds exactly for the empirical estimates.
    """
    if not truth_values:
        return ["Requires ground truth"]
    answers = []
    for protocol in protocols:
        by_obs = defaultdict(list)
        for row in by_protocol_n[protocol].get(max_n, []):
            if row.observable_id in truth_values and row.estimate is not None:
                by_obs[row.observable_id].append(row.estimate)
        if not by_obs:
            continue
        biases_sq = [(np.mean(est) - truth_values[obs_id]) ** 2 for obs_id, est in by_obs.items()]
        variances = [np.var(est) for est in by_obs.values()]
        bias_sq = np.mean(biases_sq)
        var = np.mean(variances)
        answers.append(f"{protocol}: Bias²={bias_sq:.6f}, Var={var:.6f}, MSE={bias_sq + var:.6f}")
    return answers


def _print_task_box(task_num, task_name, question, answer, inner_width=_BOX_INNER_WIDTH):
    """Print one task as a bordered box whose rows all share the border width.

    Question and answer lines longer than the box are wrapped (continuation
    lines are indented under the "Q: "/"A: " prefix) instead of overflowing.
    """
    import textwrap

    text_width = inner_width - 2  # one space of padding inside each border
    rule = "─" * inner_width

    def emit(prefix, text):
        indent = " " * len(prefix)
        segments = textwrap.wrap(text, width=text_width - len(prefix)) or [""]
        for i, segment in enumerate(segments):
            body = (prefix if i == 0 else indent) + segment
            print(f"│ {body:<{text_width}} │")

    print(f"┌{rule}┐")
    emit("", f"TASK {task_num}: {task_name}")
    print(f"├{rule}┤")
    emit("Q: ", question)
    print(f"├{rule}┤")
    for line in answer.split('\n'):
        emit("A: ", line)
    print(f"└{rule}┘")
    print()


def _print_executive_summary(result, config):
    """Print the best protocols and the shadows-vs-baseline comparison."""
    print("=" * 100)
    print("EXECUTIVE SUMMARY")
    print("=" * 100)

    protocol_summaries = result.summary.get('protocol_summaries', {})
    if not protocol_summaries:
        return

    best_by_mean = min(protocol_summaries, key=lambda p: protocol_summaries[p].get('mean_se', float('inf')))
    best_by_max = min(protocol_summaries, key=lambda p: protocol_summaries[p].get('max_se', float('inf')))
    print(f"Best protocol (mean SE): {best_by_mean}")
    print(f"Best protocol (max SE):  {best_by_max}")

    # Compare the protocols the benchmark was configured with rather than
    # hard-coded IDs; fall back to the notebook's defaults if absent.
    shadows_id = getattr(config, 'shadows_protocol_id', 'classical_shadows_v0')
    baseline_id = getattr(config, 'baseline_protocol_id', 'direct_grouped')
    shadows_se = protocol_summaries.get(shadows_id, {}).get('mean_se')
    baseline_se = protocol_summaries.get(baseline_id, {}).get('mean_se')

    # Both SEs must be present and non-zero for the ratio and its inverse.
    if shadows_se and baseline_se:
        ratio = shadows_se / baseline_se
        if ratio < 1:
            print(f"\n{shadows_id} has {1 / ratio:.1f}x lower mean SE than {baseline_id}")
        elif ratio > 1:
            print(f"\n{baseline_id} has {ratio:.1f}x lower mean SE than {shadows_id}")
        else:
            print(f"\n{shadows_id} and {baseline_id} have equal mean SE")
        # With SE ∝ 1/sqrt(N), matching the shadows precision needs
        # (SE_baseline / SE_shadows)^2 times more baseline shots; a value
        # below 1 means the baseline needs fewer shots.
        print(f"Shot Savings Factor (SSF): {1 / ratio**2:.2f}x")


def generate_summary_report(result, config, observable_set=None):
    """Print a single consolidated summary covering all 8 benchmark tasks.

    Args:
        result: benchmark result. Reads ``long_form_results`` (rows with
            ``protocol_id``, ``N_total``, ``observable_id``, ``se`` and
            ``estimate``), ``ground_truth`` (``truth_values`` mapping, or a
            falsy value), ``summary`` (dict), ``analysis`` and ``output_dir``.
        config: benchmark config. Reads ``epsilon`` and ``n_replicates``, and
            ``shadows_protocol_id`` / ``baseline_protocol_id`` when present.
        observable_set: optional observable set whose size is printed in the
            header; when omitted, the number of distinct observable IDs in
            the results is printed.

    Analyses that need a single budget use the largest N present in the
    results (falling back to the summary's shot grid when there are none).
    """
    long_form = list(result.long_form_results)
    truth_values = result.ground_truth.truth_values if result.ground_truth else {}
    epsilon = config.epsilon

    # Group rows by protocol, then by total shot count N.
    by_protocol_n = defaultdict(lambda: defaultdict(list))
    for row in long_form:
        by_protocol_n[row.protocol_id][row.N_total].append(row)
    protocols = list(by_protocol_n)

    # Use an N that actually has data: a grid value missing from the results
    # would otherwise yield empty groups and crash max()/produce NaN.
    observed_ns = {row.N_total for row in long_form}
    if observed_ns:
        max_n = max(observed_ns)
    else:
        max_n = max(result.summary.get('n_shots_grid') or [5000])

    if observable_set is not None:
        n_observables = len(observable_set)
    else:
        n_observables = len({row.observable_id for row in long_form})

    # TASKS 1-4, 6: derived from the long-form rows.
    task1_answers = _n_star_answers(by_protocol_n, protocols, max_n, epsilon, max)
    task2_answers = _n_star_answers(by_protocol_n, protocols, max_n, epsilon, np.mean)
    task3_answers = _se_distribution_answers(by_protocol_n, protocols, max_n)
    task4_answers, winner = _dominance_answers(by_protocol_n, protocols, max_n)
    task6_answers = _bias_variance_answers(by_protocol_n, protocols, max_n, truth_values)

    # TASK 5: only available when the run produced a pilot analysis.
    analysis = result.analysis
    pilot = getattr(analysis, 'pilot_analysis', None) if analysis else None
    if pilot:
        task5_answer = f"Optimal pilot fraction: {pilot.optimal_fraction*100:.0f}%"
    else:
        task5_answer = "Requires ANALYSIS mode"

    # TASKS 7 & 8: not answerable from a single-run benchmark.
    task7_answer = "Requires noise profile sweep (not in single-run benchmark)"
    task8_answer = "Requires adaptive protocol (see Task 5 pilot analysis)"

    print("=" * 100)
    print("BENCHMARK SUMMARY REPORT")
    print("=" * 100)
    print(f"Circuit: {result.summary.get('circuit_id', 'unknown')}")
    print(f"Observables: {n_observables} | Shot Grid: {result.summary.get('n_shots_grid', [])} | Replicates: {config.n_replicates}")
    print(f"Target precision ε = {epsilon}")
    print("=" * 100)
    print()

    summary_data = [
        ("1", "WORST-CASE N*",
         f"What N* achieves max SE ≤ {epsilon} for ALL observables?",
         "\n".join(task1_answers)),
        ("2", "AVERAGE N*",
         f"What N* achieves mean SE ≤ {epsilon}?",
         "\n".join(task2_answers)),
        ("3", "SE DISTRIBUTION",
         f"What is the SE distribution at N = {max_n}?",
         "\n".join(task3_answers)),
        ("4", "DOMINANCE",
         "Which protocol wins on more observables?",
         "\n".join(task4_answers) + f"\n→ WINNER: {winner}"),
        ("5", "PILOT SELECTION",
         "What fraction of budget for pilot?",
         task5_answer),
        ("6", "BIAS-VARIANCE",
         "How does MSE decompose into bias² + variance?",
         "\n".join(task6_answers)),
        ("7", "NOISE SENSITIVITY",
         "How does performance degrade with noise?",
         task7_answer),
        ("8", "ADAPTIVE EFFICIENCY",
         "Can budget reallocation improve results?",
         task8_answer),
    ]
    for task_num, task_name, question, answer in summary_data:
        _print_task_box(task_num, task_name, question, answer)

    _print_executive_summary(result, config)

    print()
    print("=" * 100)
    print(f"Full results saved to: {result.output_dir}")
    print("=" * 100)

# Print the consolidated report for the benchmark run above.
generate_summary_report(result=result, config=config)