# QuartumSE Benchmark Suite

This notebook benchmarks **classical shadows** against direct measurement baselines for a single circuit.

In [None]:
# --- Setup ---
import sys
sys.path.insert(0, '../src')

import numpy as np
from collections import Counter, defaultdict
from qiskit import QuantumCircuit

from quartumse import (
    run_benchmark_suite,
    BenchmarkMode,
    BenchmarkSuiteConfig,
    generate_observable_set,
    Observable,
    ObservableSet,
)

print("Setup complete!")

---

## 1. Circuit Configuration

Define the circuit and observables to benchmark.

In [None]:
# =============================================================================
# CIRCUIT CONFIGURATION
# =============================================================================

# CIRCUIT_ID labels this run: it is passed to run_benchmark_suite and embedded
# in the observable set id. The "4q" part is free text, so keep it in sync with
# N_QUBITS when changing the circuit size.
CIRCUIT_ID = "ghz_4q"  # Identifier for this benchmark
# Number of qubits in the GHZ circuit; also fixes the Pauli-string length and
# the range of localities (1..N_QUBITS) used when generating observables.
N_QUBITS = 4

# --- Build Circuit ---
def build_ghz(n_qubits: int) -> QuantumCircuit:
    """Prepare the n-qubit GHZ state (|00...0⟩ + |11...1⟩) / sqrt(2).

    A Hadamard on qubit 0 creates the superposition; a chain of CNOTs
    (0→1, 1→2, ...) then propagates it to every remaining qubit.
    """
    ghz = QuantumCircuit(n_qubits)
    ghz.h(0)
    for control in range(n_qubits - 1):
        ghz.cx(control, control + 1)
    return ghz

# Instantiate the benchmark circuit and print an ASCII rendering so the gate
# layout can be checked by eye before running anything expensive.
circuit = build_ghz(N_QUBITS)
print(f"Circuit: {CIRCUIT_ID}")
print(circuit.draw('text'))

In [None]:
# --- Observables ---
# Five random Pauli observables per locality K = 1..N_QUBITS. Each locality
# uses its own seed (42 + K) so the per-weight draws are reproducible.
observables = []
for weight in range(1, N_QUBITS + 1):
    random_set = generate_observable_set(
        generator_id='random_pauli',
        n_qubits=N_QUBITS,
        n_observables=5,
        seed=42 + weight,
        weight_distribution='fixed',
        fixed_weight=weight,
    )
    observables += list(random_set.observables)

# Global Z...Z and X...X strings (GHZ stabilizers) on top of the random draws.
observables += [
    Observable('Z' * N_QUBITS),
    Observable('X' * N_QUBITS),
]

obs_set = ObservableSet(
    observables=observables,
    observable_set_id=f'{CIRCUIT_ID}_obs',
    generator_id='mixed',
    generator_seed=42,
)

# Locality = number of non-identity factors in the Pauli string.
locality_map = {
    obs.observable_id: sum(ch != 'I' for ch in obs.pauli_string)
    for obs in observables
}

print(f"Observables: {len(obs_set)}")
loc_counts = Counter(locality_map.values())
for k, count in sorted(loc_counts.items()):
    print(f"  K={k}: {count} observables")

---

## 2. Benchmark Configuration

In [None]:
# =============================================================================
# BENCHMARK CONFIGURATION
# =============================================================================

config = BenchmarkSuiteConfig(
    # What to run
    mode=BenchmarkMode.ANALYSIS,      # Full analysis with all tasks
    shadows_protocol_id="classical_shadows_v0",
    baseline_protocol_id="direct_grouped",
    # Sampling plan
    n_shots_grid=[100, 500, 1000, 5000],
    n_replicates=20,                  # 20 for publication quality
    seed=42,
    # Statistical targets
    epsilon=0.01,                     # Target precision
    delta=0.05,                       # Failure probability
    # Where results are written
    output_base_dir="benchmark_results",
)

# Echo the settings that matter most when reading the results later.
print("\n".join([
    "Benchmark Configuration:",
    f"  Mode: {config.mode.value}",
    f"  Shot grid: {config.n_shots_grid}",
    f"  Replicates: {config.n_replicates}",
    f"  Target ε: {config.epsilon}",
]))

---

## 3. Run Benchmark

In [None]:
%%time
# =============================================================================
# RUN BENCHMARK
# =============================================================================

# Run the benchmark for the circuit and observable set defined above.
# `locality_map` (observable id -> number of non-identity Paulis) is passed
# alongside so the suite has each observable's locality available.
# The summary cell below reads long_form_results, ground_truth, summary,
# analysis and output_dir from the returned object.
# NOTE(review): this is presumably the expensive step of the notebook
# (shot grid x replicates x protocols) - confirm before re-running casually.
result = run_benchmark_suite(
    circuit=circuit,
    observable_set=obs_set,
    circuit_id=CIRCUIT_ID,
    config=config,
    locality_map=locality_map,
)

---

## 4. Benchmark Summary Report

In [None]:
# =============================================================================
# SUMMARY REPORT
# =============================================================================

long_form = result.long_form_results
# Ground truth is optional; an empty dict disables the bias-variance rows.
truth_values = result.ground_truth.truth_values if result.ground_truth else {}
# Largest shot budget in the run. Prefer the grid recorded in the run summary;
# if it is missing or empty, fall back to the configured grid instead of a
# hard-coded constant (an empty grid would otherwise make max() raise).
shots_grid = result.summary.get('n_shots_grid') or config.n_shots_grid
max_n = max(shots_grid)
epsilon = config.epsilon

# Group data by protocol and N: by_protocol_n[protocol_id][N_total] -> rows
by_protocol_n = defaultdict(lambda: defaultdict(list))
for record in long_form:
    by_protocol_n[record.protocol_id][record.N_total].append(record)

# Sorted so the table columns come out in a stable order.
protocols = sorted(by_protocol_n.keys())

# Run overview printed above the table
print(f"Circuit: {CIRCUIT_ID} | Qubits: {N_QUBITS} | Observables: {len(obs_set)}")
print(f"Shot grid: {result.summary.get('n_shots_grid', [])} | Replicates: {config.n_replicates} | Target ε = {epsilon}")
print()

# One right-aligned column per protocol, titled with a shortened protocol id.
col_width = 22
column_titles = [
    protocol_id.replace('classical_shadows_v0', 'shadows').replace('direct_', '')
    for protocol_id in protocols
]
header = f"{'Task':<6} {'Question':<45}" + "".join(
    f" {title:>{col_width}}" for title in column_titles
)
print(header)
print("-" * len(header))

def _n_star_cell(rows_by_n, aggregate, target, fallback_n, width):
    """Format one protocol's N* table cell.

    Args:
        rows_by_n: Mapping from shot budget N to the result rows recorded at N.
        aggregate: Callable reducing a non-empty list of SEs to one number
            (e.g. ``max`` or ``np.mean``).
        target: SE threshold; N* is the smallest N whose aggregated SE is
            less than or equal to it.
        fallback_n: Shot budget whose aggregated SE is shown when no N in
            the grid meets the target.
        width: Column width used to right-align the cell.

    Returns:
        A string made of one leading space plus the right-aligned cell text.
    """
    # Aggregate once per N; None marks budgets where no row reports an SE,
    # so an empty selection never reaches the aggregate function.
    stat_by_n = {}
    for n in sorted(rows_by_n):
        ses = [r.se for r in rows_by_n[n] if r.se is not None]
        stat_by_n[n] = aggregate(ses) if ses else None

    n_star = next(
        (n for n, stat in stat_by_n.items() if stat is not None and stat <= target),
        None,
    )
    if n_star is not None:
        text = 'N*=' + str(n_star)
    else:
        # .get() avoids inserting an empty entry when fallback_n has no rows.
        stat = stat_by_n.get(fallback_n)
        if stat is not None:
            text = f'N*>{fallback_n} (SE={stat:.3f})'
        else:
            text = f'N*>{fallback_n} (SE=N/A)'
    return f" {text:>{width}}"


# Task 1: Worst-Case N* (largest SE over all rows must meet the target)
row = f"{'1':<6} {'What N* for max SE ≤ ε (all obs)?':<45}"
for protocol in protocols:
    row += _n_star_cell(by_protocol_n[protocol], max, epsilon, max_n, col_width)
print(row)

# Task 2: Average N* (mean SE over all rows must meet the target)
row = f"{'2':<6} {'What N* for mean SE ≤ ε?':<45}"
for protocol in protocols:
    row += _n_star_cell(by_protocol_n[protocol], np.mean, epsilon, max_n, col_width)
print(row)

# Task 3: SE Distribution (multiple rows)
# Build the question text as its own f-string: when it was a plain string
# nested inside the outer f-string, "{max_n}" was printed literally.
task3_question = f"SE distribution at N={max_n}?"
print(f"{'3':<6} {task3_question:<45}")
for label, summarize in [('mean', np.mean), ('median', np.median), ('max', np.max)]:
    row = f"{'':.<6} {'  ' + label:<45}"
    for protocol in protocols:
        # .get() keeps a missing max_n from inserting an empty entry.
        ses = [r.se for r in by_protocol_n[protocol].get(max_n, []) if r.se is not None]
        if ses:
            row += f" {summarize(ses):>{col_width}.4f}"
        else:
            # No usable SEs at this budget; np.max would raise on an empty list.
            row += f" {'N/A':>{col_width}}"
    print(row)

# Task 4: Dominance
# For every observable keep the lowest SE seen at N = max_n and the protocol
# that produced it. Ties keep the protocol that comes first in `protocols`.
# NOTE(review): the comparison runs over individual rows, so with several
# replicates the single best replicate decides the win - confirm this is
# intended rather than comparing per-observable mean SE.
obs_best = defaultdict(lambda: {'se': float('inf'), 'protocol': None})
for protocol in protocols:
    for r in by_protocol_n[protocol].get(max_n, []):
        # Touch the entry first so the observable counts toward the total
        # even when this row has no SE and therefore cannot win.
        best = obs_best[r.observable_id]
        if r.se is not None and r.se < best['se']:
            obs_best[r.observable_id] = {'se': r.se, 'protocol': protocol}

# Counter returns 0 for unseen protocols without inserting them, so `wins`
# stays empty when nobody won anything.
wins = Counter(entry['protocol'] for entry in obs_best.values() if entry['protocol'])
total_obs = len(obs_best)

row = f"{'4':<6} {'Which protocol wins most observables?':<45}"
for protocol in protocols:
    win_count = wins[protocol]
    win_pct = 100 * win_count / total_obs if total_obs > 0 else 0
    cell = f'{win_count}/{total_obs} ({win_pct:.0f}%)'
    row += f" {cell:>{col_width}}"
print(row)

# Task 5: Pilot Selection
# The value is not protocol-specific: it goes in the first protocol column
# and the remaining columns are padded with spaces.
row = f"{'5':<6} {'Optimal pilot fraction?':<45}"
pilot = getattr(result.analysis, 'pilot_analysis', None) if result.analysis else None
pilot_text = f"{pilot.optimal_fraction*100:.0f}%" if pilot else 'N/A'
row += f" {pilot_text:>{col_width}}" + " " * (col_width + 1) * (len(protocols) - 1)
print(row)

# Task 6: Bias-Variance (multiple rows)
print(f"{'6':<6} {'Bias-variance decomposition (MSE)?':<45}")
if truth_values:
    # Compute the decomposition once per protocol; the three printed rows
    # below only pick different entries from it.
    decomposition = {}
    for protocol in protocols:
        estimates_by_obs = defaultdict(list)
        for r in by_protocol_n[protocol].get(max_n, []):
            if r.observable_id in truth_values:
                estimates_by_obs[r.observable_id].append(r.estimate)
        if not estimates_by_obs:
            decomposition[protocol] = None  # nothing comparable to ground truth
            continue
        # Per observable: squared bias of the mean estimate and the
        # (population) variance of the estimates; then average over observables.
        bias_sq = np.mean([
            (np.mean(estimates) - truth_values[obs_id]) ** 2
            for obs_id, estimates in estimates_by_obs.items()
        ])
        variance = np.mean([np.var(estimates) for estimates in estimates_by_obs.values()])
        decomposition[protocol] = {'bias2': bias_sq, 'var': variance, 'mse': bias_sq + variance}

    for metric, label in [('bias2', 'Bias²'), ('var', 'Variance'), ('mse', 'MSE')]:
        row = f"{'':.<6} {'  ' + label:<45}"
        for protocol in protocols:
            stats = decomposition[protocol]
            if stats is None:
                row += f" {'N/A':>{col_width}}"
            else:
                row += f" {stats[metric]:>{col_width}.6f}"
        print(row)
else:
    print(f"{'':.<6} {'  (requires ground truth)':<45}")

# Tasks 7 (Noise Sensitivity) and 8 (Adaptive Efficiency) are not computed
# here; each prints a pointer in the first protocol column and blank padding
# for the remaining columns.
trailing_pad = " " * (col_width + 1) * (len(protocols) - 1)
for task_no, question, note in [
    ('7', 'Noise sensitivity?', '(requires noise sweep)'),
    ('8', 'Adaptive budget reallocation?', '(see Task 5)'),
]:
    print(f"{task_no:<6} {question:<45} {note:>{col_width}}" + trailing_pad)

print("-" * len(header))

# Winner summary
# Only name a winner when some protocol actually won at least one observable:
# `wins` may hold zero-count entries, so its truthiness alone is not enough.
# Ties are reported explicitly instead of silently picking the first key.
top_wins = max(wins.values(), default=0)
leaders = [name for name, count in wins.items() if count == top_wins] if top_wins > 0 else []
if not leaders:
    winner = "N/A"
elif len(leaders) == 1:
    winner = leaders[0]
else:
    winner = " / ".join(leaders) + " (tie)"
print(f"\nWinner (most observables): {winner}")
print(f"Results saved to: {result.output_dir}")