# High-fidelity entanglement distillation (DEJMPS)

This notebook implements the DEJMPS purification protocol from the purification paper.
It follows the game handbook LOCC rules and uses the same server API calls as demo.ipynb.

Goal: maximize output fidelity for a single claim by using a Bell-diagonal error model
and a DEJMPS plan (alignment + rotations + round bases).


## Protocol summary (DEJMPS, Bell-diagonal form)

Given two identical Bell-diagonal pairs with coefficients p00, p01, p10, p11:
1) (Optional) local basis change (DEJMPS): Alice Sdg+H, Bob S+H.
2) Bilateral CNOT (source -> target) on Alice and Bob sides.
3) Measure the target pair in Z; keep only if parity matches (XOR = 0).
4) If both outcomes are 1, apply Z to Alice source qubit.
5) Repeat on remaining pairs (recurrence).

The map on Bell coefficients is:
- p00' = (p00^2 + p11^2) / N
- p01' = (p01^2 + p10^2) / N
- p10' = (2 p00 p11) / N
- p11' = (2 p01 p10) / N
- N = (p00 + p11)^2 + (p01 + p10)^2   (success probability)


In [31]:
from client import GameClient
from visualization import GraphTool
from qiskit import QuantumCircuit, QuantumRegister, ClassicalRegister
from qiskit import qasm3
from pathlib import Path
from datetime import datetime
import json
import random

from edge_error_models import (
    estimate_bell_coeffs,
    load_error_model_row,
    load_error_models,
    append_error_row,
    plan_dejmps_from_error_row,
    canonical_edge,
)


In [45]:
# Flag behavior reproducer for organizers (selfâ€‘contained).
# Run from repo root with requirements installed.
from pathlib import Path
import json
from client import GameClient
from qiskit import QuantumCircuit, ClassicalRegister, QuantumRegister, qasm3

SESSION_FILE = Path("session.json")

def load_or_register():
    if SESSION_FILE.exists():
        data = json.loads(SESSION_FILE.read_text(encoding="utf-8"))
        client = GameClient(api_token=data.get("api_token"))
        client.player_id = data.get("player_id")
        client.name = data.get("name")
        if client.get_status():
            print(f"Resumed {client.player_id}")
            return client
    # Register if no session
    client = GameClient()
    player_id = input("player_id: ").strip()
    name = input("name: ").strip()
    location = input("location (remote/in_person): ").strip()
    resp = client.register(player_id, name, location=location)
    if not resp.get("ok"):
        print("Register failed:", resp)
        return None
    SESSION_FILE.write_text(json.dumps({
        "api_token": client.api_token,
        "player_id": client.player_id,
        "name": client.name,
    }), encoding="utf-8")
    return client

def pick_edge(client):
    claimable = client.get_claimable_edges()
    if not claimable:
        print("No claimable edges; select a starting node first.")
        return None
    # pick highest threshold edge for visibility
    edge = max(claimable, key=lambda e: e.get("base_threshold", 0))
    return tuple(edge["edge_id"]), edge.get("base_threshold", 0)

def run_case(client, edge_id, circuit, flag_bit, num_bell_pairs=2):
    result = client.claim_edge(edge_id, circuit, flag_bit, num_bell_pairs)
    if not result.get("ok"):
        print(f"{circuit.name:18} | error: {result.get('error', {}).get('message')}")
        return
    data = result.get("data", {})
    print(f"{circuit.name:18} | flag_bit={flag_bit} | "
          f"success_prob={data.get('success_probability',0):.4f} | "
          f"fidelity={data.get('fidelity',0):.4f} | success={data.get('success')}")

def run_direct_qasm(client, edge_id, qasm, flag_bit, num_bell_pairs=1, label="direct_qasm"):
    payload = {
        "player_id": client.player_id,
        "edge": [edge_id[0], edge_id[1]],
        "num_bell_pairs": int(num_bell_pairs),
        "circuit_qasm": qasm,
        "flag_bit": int(flag_bit),
    }
    result = client._post("/v1/claim_edge", payload)
    if not result.get("ok"):
        print(f"{label:18} | error: {result.get('error', {}).get('message')}")
        return
    data = result.get("data", {})
    print(f"{label:18} | flag_bit={flag_bit} | "
          f"success_prob={data.get('success_probability',0):.4f} | "
          f"fidelity={data.get('fidelity',0):.4f} | success={data.get('success')}")

client = load_or_register()
if not client:
    raise SystemExit(1)

status = client.get_status()
if not status.get("starting_node"):
    print("No starting node selected. Select one in the main notebook/demo first.")
    raise SystemExit(1)

edge_id, thr = pick_edge(client)
print(f"Diagnostics edge: {edge_id} (threshold: {thr:.3f})")

# --- Build minimal circuits (N=2 unless noted) ---
def base_circuit(n=2, name="base"):
    qr = QuantumRegister(2*n, "q")
    cr = ClassicalRegister(3, "c")
    qc = QuantumCircuit(qr, cr, name=name)
    return qc, qr, cr

# 1) No flag measurement (flag bit never written) -> success_prob should be 1.0
qc, qr, cr = base_circuit(2, "no_flag")
run_case(client, edge_id, qc, flag_bit=2, num_bell_pairs=2)

# 2) Measure entangled qubit into flag -> success_prob ~0.5
qc, qr, cr = base_circuit(2, "flag_meas_qubit")
qc.measure(qr[0], cr[2])
run_case(client, edge_id, qc, flag_bit=2, num_bell_pairs=2)

# 3) Reset + measure (should be 1.0 if reset honored; observed ~0.5)
qc, qr, cr = base_circuit(2, "flag_reset_meas")
qc.reset(qr[0])
qc.measure(qr[0], cr[2])
run_case(client, edge_id, qc, flag_bit=2, num_bell_pairs=2)

# 4) Reset + X + measure (should be 0.0 if reset honored; observed ~0.5)
qc, qr, cr = base_circuit(2, "flag_reset_x_meas")
qc.reset(qr[0])
qc.x(qr[0])
qc.measure(qr[0], cr[2])
run_case(client, edge_id, qc, flag_bit=2, num_bell_pairs=2)

# 5) Purely classical flag assignment (direct QASM) -> ignored by server
qasm = """OPENQASM 3.0;
include "stdgates.inc";
qubit[2] q;
bit[3] c;
c[0] = measure q[0];
c[1] = measure q[1];
c[2] = c[0] ^ c[1];
"""
run_direct_qasm(client, edge_id, qasm, flag_bit=2, num_bell_pairs=1, label="classical_xor")

print("\nExpected/Observed:")
print("- no_flag -> success_prob = 1.0")
print("- any direct measurement into flag -> success_prob ~ 0.5")
print("- reset is ignored by server (still ~0.5)")
print("- classical assignment ignored (success_prob stays 1.0)")


Resumed bobbyKat123456
Diagnostics edge: ('Boston, MA', 'Cambridge, MA') (threshold: 0.900)
no_flag            | flag_bit=2 | success_prob=1.0000 | fidelity=0.8500 | success=False
flag_meas_qubit    | flag_bit=2 | success_prob=0.5000 | fidelity=0.8500 | success=False
flag_reset_meas    | flag_bit=2 | success_prob=0.5000 | fidelity=0.8500 | success=False
flag_reset_x_meas  | flag_bit=2 | success_prob=0.5000 | fidelity=0.8500 | success=False
classical_xor      | flag_bit=2 | success_prob=0.8500 | fidelity=0.5000 | success=False

Expected/Observed:
- no_flag -> success_prob = 1.0
- any direct measurement into flag -> success_prob ~ 0.5
- reset is ignored by server (still ~0.5)
- classical assignment ignored (success_prob stays 1.0)


In [32]:
SESSION_FILE = Path("session.json")

def save_session(client: GameClient) -> None:
    if client.api_token:
        with SESSION_FILE.open("w", encoding="utf-8") as f:
            json.dump(
                {
                    "api_token": client.api_token,
                    "player_id": client.player_id,
                    "name": client.name,
                },
                f,
            )
        print("Session saved.")


def load_session() -> GameClient | None:
    if not SESSION_FILE.exists():
        return None
    with SESSION_FILE.open("r", encoding="utf-8") as f:
        data = json.load(f)
    client = GameClient(api_token=data.get("api_token"))
    client.player_id = data.get("player_id")
    client.name = data.get("name")
    status = client.get_status()
    if status:
        print(
            f"Resumed: {client.player_id} | Score: {status.get('score', 0)} | Budget: {status.get('budget', 0)}"
        )
        return client
    return None


In [33]:
# Try to resume an existing session
client = load_session()
if not client:
    print("No saved session. Register below.")


Resumed: bobbyKat123456 | Score: 0 | Budget: 40


In [34]:
# Register if needed
if client and client.api_token:
    print(f"Already registered as {client.player_id}")
else:
    client = GameClient()

    # CHANGE THESE to your unique values
    PLAYER_ID = "your_player_id"
    PLAYER_NAME = "Your Name"

    result = client.register(PLAYER_ID, PLAYER_NAME, location=input("remote or in_person: ").strip())
    if result.get("ok"):
        print(f"Registered! Token: {client.api_token[:20]}...")
        candidates = result["data"].get("starting_candidates", [])
        print(f"Starting candidates ({len(candidates)}):")
        for c in candidates:
            print(f"  - {c['node_id']}: {c['utility_qubits']} qubits, +{c['bonus_bell_pairs']} bonus")
        save_session(client)
    else:
        print(f"Failed: {result.get('error', {}).get('message')}")


Already registered as bobbyKat123456


In [35]:
# Select a starting node if not chosen
status = client.get_status()
if status.get("starting_node"):
    print(f"Starting node: {status['starting_node']}")
    print(f"Budget: {status['budget']} | Score: {status['score']}")
else:
    print("Select a starting node from the candidates shown above.")
    # Uncomment and modify:
    # result = client.select_starting_node("Cambridge, MA")
    # print(result)


Starting node: Cambridge, MA
Budget: 40 | Score: 0


In [36]:
# Claimable edges overview
claimable = client.get_claimable_edges()
claimable_sorted = sorted(claimable, key=lambda e: (e.get("difficulty_rating", 999), e.get("base_threshold", 999)))

print(f"Claimable edges ({len(claimable_sorted)}):")
for edge in claimable_sorted[:10]:
    print(
        f"  {edge['edge_id']} - threshold: {edge.get('base_threshold', 0):.3f}, difficulty: {edge.get('difficulty_rating')}"
    )
if len(claimable_sorted) > 10:
    print(f"  ... and {len(claimable_sorted) - 10} more")


Claimable edges (6):
  ['Boston, MA', 'Cambridge, MA'] - threshold: 0.900, difficulty: 1
  ['Cambridge, MA', 'Providence, RI'] - threshold: 0.900, difficulty: 1
  ['Cambridge, MA', 'Worcester, MA'] - threshold: 0.900, difficulty: 1
  ['Cambridge, MA', 'Hartford, CT'] - threshold: 0.900, difficulty: 1
  ['Cambridge, MA', 'Portland, ME'] - threshold: 0.900, difficulty: 1
  ['Cambridge, MA', 'Halifax, Canada'] - threshold: 0.900, difficulty: 3


## Bell tomography and error-model helpers

We estimate Bell coefficients with N=1 and store them in edge_error_models.csv.
This lets us plan a DEJMPS sequence that maximizes expected fidelity.


In [37]:
ERROR_MODEL_CSV = Path("edge_error_models.csv")

CALIBRATE_IF_MISSING = True
CALIBRATE_NUM_BELL_PAIRS = 1  # uses N=1 for tomography

OBJECTIVE = "fidelity"  # "fidelity", "success", or "fidelity_times_success"
MIN_SUCCESS_PROB = 0.0  # set >0.0 to avoid very low success plans
ALIGNMENT_MODE = "heuristic"  # "heuristic" or "all"

# Protocol selection
USE_CAUSAL_ORDER_PROTOCOL = True  # PhysRevA.108.062601 (N=4 control-swap)
ALLOW_N8_WITH_DEJMPS = False      # if True, N=8 fallback uses DEJMPS

# Error-model awareness
USE_ERROR_MODEL = False  # set True to use tomography + DEJMPS planning

# QASM compatibility
PATCH_FLAG_SYNTAX = True  # rewrite `if (!c[i])` -> `if (c[i] == false)`




In [38]:
def ensure_error_model(client: GameClient, edge_id, num_bell_pairs=1, csv_path: Path = ERROR_MODEL_CSV):
    row = load_error_model_row(edge_id, path=csv_path, prefer_latest=True)
    if row:
        return row

    payload, err = estimate_bell_coeffs(client, edge_id, num_bell_pairs=num_bell_pairs)
    if err:
        print(f"Calibration error: {err.get('error', {}).get('message')}")
        return None

    coeffs, coeffs_norm, success_probs = payload
    p_sum = sum(coeffs.values())
    avg_success = sum(success_probs.values()) / max(len(success_probs), 1)

    edge_info = None
    try:
        edge_info = client.get_edge_info(edge_id[0], edge_id[1])
    except Exception:
        edge_info = None

    a, b = canonical_edge(edge_id)
    row = {
        "node_a": a,
        "node_b": b,
        "p_phi_plus": coeffs["phi_plus"],
        "p_phi_minus": coeffs["phi_minus"],
        "p_psi_plus": coeffs["psi_plus"],
        "p_psi_minus": coeffs["psi_minus"],
        "p_sum": p_sum,
        "p_phi_plus_norm": coeffs_norm["phi_plus"],
        "p_phi_minus_norm": coeffs_norm["phi_minus"],
        "p_psi_plus_norm": coeffs_norm["psi_plus"],
        "p_psi_minus_norm": coeffs_norm["psi_minus"],
        "base_threshold": (edge_info or {}).get("base_threshold"),
        "difficulty_rating": (edge_info or {}).get("difficulty_rating"),
        "success_probability": avg_success,
        "timestamp_utc": datetime.utcnow().isoformat() + "Z",
        "note": "tomography N=1, local Pauli on Alice",
    }
    append_error_row(csv_path, row)
    return row


def pick_best_edge_for_fidelity(claimable_edges, budget, error_table):
    candidate_ns = [n for n in (2, 4, 8) if n <= budget]
    if not candidate_ns:
        return None, None, None

    best_edge = None
    best_plan = None
    best_row = None

    for edge in claimable_edges:
        row = error_table.get(canonical_edge(edge["edge_id"]))
        if not row:
            continue
        plan = plan_dejmps_from_error_row(
            row,
            candidate_num_bell_pairs=candidate_ns,
            alignment_mode=ALIGNMENT_MODE,
            objective=OBJECTIVE,
            success_threshold=MIN_SUCCESS_PROB,
        )
        if not plan:
            continue
        if best_plan is None or plan.expected_fidelity > best_plan.expected_fidelity:
            best_edge, best_plan, best_row = edge, plan, row

    return best_edge, best_plan, best_row


In [39]:
# Choose a target edge, calibrating if needed
status = client.get_status()
budget = int(status.get("budget", 0)) if status else 0

if not claimable_sorted:
    print("No claimable edges. Claim or select a starting node first.")
else:
    error_table = load_error_models(ERROR_MODEL_CSV) if USE_ERROR_MODEL else {}
    target_edge = None
    plan = None
    row = None

    if USE_ERROR_MODEL:
        target_edge, plan, row = pick_best_edge_for_fidelity(claimable_sorted, budget, error_table)

    if target_edge is None:
        # Fallback: easiest edge by difficulty/threshold
        target_edge = claimable_sorted[0]
        row = error_table.get(canonical_edge(target_edge["edge_id"])) if USE_ERROR_MODEL else None
        if USE_ERROR_MODEL and row is None and CALIBRATE_IF_MISSING:
            row = ensure_error_model(
                client,
                target_edge["edge_id"],
                num_bell_pairs=CALIBRATE_NUM_BELL_PAIRS,
                csv_path=ERROR_MODEL_CSV,
            )
        if USE_ERROR_MODEL and row:
            candidate_ns = [n for n in (2, 4, 8) if n <= budget]
            plan = plan_dejmps_from_error_row(
                row,
                candidate_num_bell_pairs=candidate_ns,
                alignment_mode=ALIGNMENT_MODE,
                objective=OBJECTIVE,
                success_threshold=MIN_SUCCESS_PROB,
            )

    print(f"Target edge: {target_edge['edge_id']}")
    if plan:
        print("Planned DEJMPS:")
        print(plan)
    else:
        print("No plan available (error model disabled or missing).")


Target edge: ['Boston, MA', 'Cambridge, MA']
No plan available (error model disabled or missing).


## DEJMPS circuit builder

This circuit respects LOCC constraints and uses a sticky flag bit (c[2]) for post-selection.
The output pair is always (N-1, N) as required by the game rules.


In [40]:
def _apply_pauli(qc: QuantumCircuit, qubits, pauli: str) -> None:
    if pauli == "I":
        return
    for q in qubits:
        if pauli == "X":
            qc.x(q)
        elif pauli == "Y":
            qc.y(q)
        elif pauli == "Z":
            qc.z(q)
        else:
            raise ValueError(f"Unknown Pauli: {pauli}")


def _apply_alignment_gates(qc: QuantumCircuit, qubits, gates) -> None:
    for gate in gates:
        if gate == "H":
            for q in qubits:
                qc.h(q)
        elif gate == "S":
            for q in qubits:
                qc.s(q)
        elif gate == "I":
            continue
        else:
            raise ValueError(f"Unknown alignment gate: {gate}")


def _measure_parity_sticky(qc: QuantumCircuit, a_t, b_t, c) -> None:
    # Measure parity (Alice XOR Bob) into c0, with sticky OR using existing c0.
    qc.measure(b_t, c[1])
    # If c0 already 1, force it to remain 1 by measuring |1>
    with qc.if_test((c[0], 1)):
        qc.reset(a_t)
        qc.x(a_t)
        qc.measure(a_t, c[0])
    # If c0 is 0, compute parity into c0
    with qc.if_test((c[0], False)):
        with qc.if_test((c[1], 1)):
            qc.x(a_t)
        qc.measure(a_t, c[0])


def create_dejmps_circuit(
    num_bell_pairs: int,
    alignment_gates=(),
    pre_rotation_pauli: str = "I",
    round_rotations=None,
    round_bases=None,
    random_pairing: bool = False,
    rng_seed: int | None = 1234,
):
    if num_bell_pairs not in (2, 4, 8):
        raise ValueError("num_bell_pairs must be 2, 4, or 8")

    qr = QuantumRegister(2 * num_bell_pairs, "q")
    c = ClassicalRegister(3, "c")  # c0: sticky parity flag, c1: Bob meas, c2: unused
    qc = QuantumCircuit(qr, c)

    rng = random.Random(rng_seed)

    pairs = [(i, 2 * num_bell_pairs - 1 - i) for i in range(num_bell_pairs)]
    output_pair = (num_bell_pairs - 1, num_bell_pairs)

    bob_qubits = list(range(num_bell_pairs, 2 * num_bell_pairs))

    # Optional alignment (Bob-side only) and pre-rotation to map dominant Bell label to phi+
    _apply_alignment_gates(qc, [qr[q] for q in bob_qubits], alignment_gates)
    _apply_pauli(qc, [qr[q] for q in bob_qubits], pre_rotation_pauli)

    def _round_basis(level: int) -> str:
        if not round_bases:
            return "dejmps"
        if 0 <= level < len(round_bases):
            return round_bases[level]
        return "dejmps"

    def _round_rotation(level: int) -> str:
        if not round_rotations:
            return "I"
        if 0 <= level < len(round_rotations):
            return round_rotations[level]
        return "I"

    def _apply_dejmps_basis(a_s, a_t, b_s, b_t):
        for q in (a_s, a_t):
            qc.sdg(qr[q])
            qc.h(qr[q])
        for q in (b_s, b_t):
            qc.s(qr[q])
            qc.h(qr[q])

    def dejmps_round(source_pair, target_pair, level: int):
        a_s, b_s = source_pair
        a_t, b_t = target_pair

        basis = _round_basis(level)
        if basis == "dejmps":
            _apply_dejmps_basis(a_s, a_t, b_s, b_t)
        elif basis != "identity":
            raise ValueError(f"Unknown basis: {basis}")

        # Bilateral CNOT: source -> target
        qc.cx(qr[a_s], qr[a_t])
        qc.cx(qr[b_s], qr[b_t])

        # Parity measurement into c0 (sticky across rounds)
        _measure_parity_sticky(qc, qr[a_t], qr[b_t], c)

        # Conditional correction if both original outcomes are 1
        with qc.if_test((c[1], 1)):
            with qc.if_test((c[0], False)):
                qc.z(qr[a_s])

        # Optional Bob-side rotation after each round
        rot = _round_rotation(level)
        _apply_pauli(qc, [qr[b_s]], rot)

    def _prepare_pairing(pair_list):
        pair_list = list(pair_list)
        if output_pair in pair_list:
            pair_list.remove(output_pair)
        if random_pairing and pair_list:
            rng.shuffle(pair_list)
        return [output_pair] + pair_list

    current_pairs = list(pairs)
    level = 0
    while len(current_pairs) > 1:
        current_pairs = _prepare_pairing(current_pairs)
        next_pairs = []
        for i in range(0, len(current_pairs), 2):
            source = current_pairs[i]
            target = current_pairs[i + 1]
            dejmps_round(source, target, level)
            next_pairs.append(source)
        current_pairs = next_pairs
        level += 1

    return qc


def _ccx_decomposed(qc: QuantumCircuit, a: int, b: int, c: int) -> None:
    # Toffoli decomposition using only 1Q + CNOT gates.
    qc.h(c)
    qc.cx(b, c)
    qc.tdg(c)
    qc.cx(a, c)
    qc.t(c)
    qc.cx(b, c)
    qc.tdg(c)
    qc.cx(a, c)
    qc.t(b)
    qc.t(c)
    qc.h(c)
    qc.cx(a, b)
    qc.t(a)
    qc.tdg(b)
    qc.cx(a, b)


def _cswap_local(qc: QuantumCircuit, control: int, t1: int, t2: int) -> None:
    # Controlled-SWAP using decomposed Toffolis (all qubits local).
    _ccx_decomposed(qc, control, t1, t2)
    _ccx_decomposed(qc, control, t2, t1)
    _ccx_decomposed(qc, control, t1, t2)


def create_causal_order_circuit(
    num_bell_pairs: int = 4,
    alignment_gates=(),
    pre_rotation_pauli: str = "I",
    round_rotations=None,
    round_bases=None,
):
    # Coherent-control causal-order protocol (PhysRevA.108.062601) for N=4.
    # Pair layout: pair0=control, pair1/pair2 swapped, pair3 main output.
    if num_bell_pairs != 4:
        raise ValueError("causal-order protocol requires num_bell_pairs=4")

    qr = QuantumRegister(2 * num_bell_pairs, "q")
    c = ClassicalRegister(3, "c")  # c0: sticky parity flag, c1: Bob meas, c2: unused
    qc = QuantumCircuit(qr, c)

    pairs = [(i, 2 * num_bell_pairs - 1 - i) for i in range(num_bell_pairs)]
    control_pair = pairs[0]
    pair1 = pairs[1]
    pair2 = pairs[2]
    pair3 = pairs[3]  # output pair

    bob_qubits = list(range(num_bell_pairs, 2 * num_bell_pairs))

    # Optional alignment + pre-rotation on Bob side for all pairs
    _apply_alignment_gates(qc, [qr[q] for q in bob_qubits], alignment_gates)
    _apply_pauli(qc, [qr[q] for q in bob_qubits], pre_rotation_pauli)

    def _round_basis(level: int) -> str:
        if not round_bases:
            return "dejmps"
        if 0 <= level < len(round_bases):
            return round_bases[level]
        return "dejmps"

    def _round_rotation(level: int) -> str:
        if not round_rotations:
            return "I"
        if 0 <= level < len(round_rotations):
            return round_rotations[level]
        return "I"

    # Coherent control: swap pair1 and pair2 on each side
    _cswap_local(qc, control_pair[0], pair1[0], pair2[0])
    _cswap_local(qc, control_pair[1], pair1[1], pair2[1])

    def _apply_dejmps_basis(a_s, a_t, b_s, b_t):
        for q in (a_s, a_t):
            qc.sdg(qr[q])
            qc.h(qr[q])
        for q in (b_s, b_t):
            qc.s(qr[q])
            qc.h(qr[q])

    def dejmps_round(source_pair, target_pair, level: int):
        a_s, b_s = source_pair
        a_t, b_t = target_pair

        basis = _round_basis(level)
        if basis == "dejmps":
            _apply_dejmps_basis(a_s, a_t, b_s, b_t)
        elif basis != "identity":
            raise ValueError(f"Unknown basis: {basis}")

        # Bilateral CNOT: source -> target
        qc.cx(qr[a_s], qr[a_t])
        qc.cx(qr[b_s], qr[b_t])

        # Parity measurement into c0 (sticky across rounds)
        _measure_parity_sticky(qc, qr[a_t], qr[b_t], c)

        # Conditional correction if both original outcomes are 1
        with qc.if_test((c[1], 1)):
            with qc.if_test((c[0], False)):
                qc.z(qr[a_s])

        # Optional Bob-side rotation after each round
        rot = _round_rotation(level)
        _apply_pauli(qc, [qr[b_s]], rot)

    # Two DEJMPS steps with pair3 as the main pair
    dejmps_round(pair3, pair2, 0)
    dejmps_round(pair3, pair1, 1)

    # Hadamard on control pair and parity post-selection into c0
    qc.h(qr[control_pair[0]])
    qc.h(qr[control_pair[1]])
    _measure_parity_sticky(qc, qr[control_pair[0]], qr[control_pair[1]], c)

    return qc



## Coherent causal-order protocol (N=4)

Based on *Phys. Rev. A 108, 062601 (2023)*, this uses one Bell pair as a control that
coherently swaps two other pairs before two DEJMPS steps. After the two steps, the
control pair is measured in the X basis and we postselect even parity (flag=0).


In [41]:
import re

def claim_edge_with_qasm(client, edge_id, circuit, flag_bit, num_bell_pairs, patch_flag: bool = True):
    qasm = qasm3.dumps(circuit)
    if patch_flag:
        # Replace OpenQASM3 negation with explicit '== false' for server parser compatibility.
        qasm = re.sub(r"if \(!c\[(\d+)\]\)", r"if (c[\1] == false)", qasm)
    payload = {
        "player_id": client.player_id,
        "edge": [edge_id[0], edge_id[1]],
        "num_bell_pairs": int(num_bell_pairs),
        "circuit_qasm": qasm,
        "flag_bit": int(flag_bit),
    }
    return client._post("/v1/claim_edge", payload)


In [42]:
# Build circuit and claim the target edge (increase N on failure)
if not claimable_sorted:
    print("No claimable edges to attempt.")
else:
    status = client.get_status()
    budget = int(status.get("budget", 0)) if status else 0
    candidate_ns = sorted([n for n in (2, 4, 8) if n <= budget])

    if USE_CAUSAL_ORDER_PROTOCOL and not ALLOW_N8_WITH_DEJMPS:
        candidate_ns = [n for n in candidate_ns if n <= 4]

    if not candidate_ns:
        print("Budget too low for selected protocol (need at least 2 bell pairs).")
    else:
        edge_id = tuple(target_edge["edge_id"])

        for idx, num_bell_pairs in enumerate(candidate_ns):
            use_causal = USE_CAUSAL_ORDER_PROTOCOL and num_bell_pairs == 4

            plan_for_n = None
            if row:
                plan_for_n = plan_dejmps_from_error_row(
                    row,
                    candidate_num_bell_pairs=[num_bell_pairs],
                    alignment_mode=ALIGNMENT_MODE,
                    objective=OBJECTIVE,
                    success_threshold=MIN_SUCCESS_PROB,
                )

            if plan_for_n:
                alignment_gates = plan_for_n.alignment_gates
                pre_rotation = plan_for_n.pre_rotation_pauli
                round_rotations = plan_for_n.round_rotations
                round_bases = plan_for_n.round_bases
                if use_causal:
                    print(
                        f"N={num_bell_pairs} using causal-order protocol "
                        f"(alignment/pre-rotation from DEJMPS plan)."
                    )
                else:
                    print(
                        f"N={num_bell_pairs} expected fidelity: {plan_for_n.expected_fidelity:.4f} | "
                        f"expected success: {plan_for_n.expected_success_probability:.4f}"
                    )
            else:
                alignment_gates = ()
                pre_rotation = "I"
                rounds = {2: 1, 4: 2, 8: 3}[num_bell_pairs]
                round_rotations = ["I"] * rounds
                round_bases = ["dejmps"] * rounds

            if use_causal:
                circuit = create_causal_order_circuit(
                    num_bell_pairs=num_bell_pairs,
                    alignment_gates=alignment_gates,
                    pre_rotation_pauli=pre_rotation,
                    round_rotations=round_rotations,
                    round_bases=round_bases,
                )
                protocol_name = "causal-order (PhysRevA.108.062601)"
            else:
                circuit = create_dejmps_circuit(
                    num_bell_pairs=num_bell_pairs,
                    alignment_gates=alignment_gates,
                    pre_rotation_pauli=pre_rotation,
                    round_rotations=round_rotations,
                    round_bases=round_bases,
                    random_pairing=False,
                )
                protocol_name = "DEJMPS"

            flag_bit = 0

            PRINT_CIRCUIT = False
            if PRINT_CIRCUIT:
                print(circuit.draw(output="text"))

            print(
                f"Claiming {edge_id} with N={num_bell_pairs} using {protocol_name} "
                f"(threshold: {target_edge.get('base_threshold', 0):.3f})"
            )
            result = claim_edge_with_qasm(client, edge_id, circuit, flag_bit, num_bell_pairs, patch_flag=PATCH_FLAG_SYNTAX)

            if not result.get("ok"):
                print(f"Error: {result.get('error', {}).get('message')}")
                break

            data = result.get("data", {})
            print(f"Success: {data.get('success')}")
            print(f"Fidelity: {data.get('fidelity', 0):.4f} (threshold: {data.get('threshold', 0):.4f})")
            print(f"Success probability: {data.get('success_probability', 0):.4f}")

            if data.get("success"):
                break
            if idx < len(candidate_ns) - 1:
                print("Failed attempt, increasing N..")





Claiming ('Boston, MA', 'Cambridge, MA') with N=2 using DEJMPS (threshold: 0.900)
Success: False
Fidelity: 0.7225 (threshold: 0.9000)
Success probability: 1.0000
Failed attempt, increasing N..
Claiming ('Boston, MA', 'Cambridge, MA') with N=4 using causal-order (PhysRevA.108.062601) (threshold: 0.900)
Success: False
Fidelity: 0.5886 (threshold: 0.9000)
Success probability: 1.0000


In [43]:
# Status and leaderboard
client.print_status()
leaderboard = client.get_leaderboard().get("leaderboard", [])
print("Leaderboard:")
for i, p in enumerate(leaderboard[:10]):
    print(f"{i+1}. {p.get('player_id', 'Unknown'):20} Score: {p.get('score', 0)}")


Player: bobbyKat123456 (bobbyKat87)
Score: 0 | Budget: 40 bell pairs
Active: Yes
Starting node: Cambridge, MA
Owned: 1 nodes, 0 edges
Claimable edges: 6
  - ['Boston, MA', 'Cambridge, MA']: threshold=0.90, difficulty=1
  - ['Cambridge, MA', 'Halifax, Canada']: threshold=0.90, difficulty=3
  - ['Cambridge, MA', 'Providence, RI']: threshold=0.90, difficulty=1
  ... and 3 more
Leaderboard:
1. bloch                Score: 92
2. bloch_distiller      Score: 92
3. Ram_23356_new_version1 Score: 83
4. test_claude_1523b    Score: 78
5. Munich_id            Score: 77
6. willwantsROIv3andbudgetblock Score: 61
7. EEE3                 Score: 60
8. okaywillisveryawakenow Score: 51
9. Ram_23356_new        Score: 48
10. willwantsROIv3       Score: 45


In [44]:
# Flag diagnostics: probe how the server interprets flag bits
if not claimable_sorted:
    print("No claimable edges for diagnostics.")
else:
    status = client.get_status()
    budget = int(status.get("budget", 0)) if status else 0
    diag_edge = max(claimable_sorted, key=lambda e: e.get("base_threshold", 0))
    edge_id = tuple(diag_edge["edge_id"])

    num_bell_pairs = 2 if budget >= 2 else 1
    if num_bell_pairs < 1:
        print("Budget too low for diagnostics.")
    else:
        print(f"Diagnostics edge: {edge_id} (threshold: {diag_edge.get('base_threshold', 0):.3f})")

        def _base_circuit(n):
            qr = QuantumRegister(2 * n, "q")
            c = ClassicalRegister(3, "c")
            qc = QuantumCircuit(qr, c)
            return qc, qr, c

        tests = []

        # 1) No flag measurement at all (baseline)
        qc, qr, c = _base_circuit(num_bell_pairs)
        tests.append(("no_flag", qc, 2))

        # 2) Measure an entangled qubit into flag
        qc, qr, c = _base_circuit(num_bell_pairs)
        qc.measure(qr[0], c[2])
        tests.append(("flag_meas_qubit", qc, 2))

        # 3) Reset -> measure (should be success_prob ~1.0 if reset honored)
        qc, qr, c = _base_circuit(num_bell_pairs)
        qc.reset(qr[0])
        qc.measure(qr[0], c[2])
        tests.append(("flag_reset_meas", qc, 2))

        # 4) Reset + X -> measure (should be success_prob ~0.0 if reset honored)
        qc, qr, c = _base_circuit(num_bell_pairs)
        qc.reset(qr[0])
        qc.x(qr[0])
        qc.measure(qr[0], c[2])
        tests.append(("flag_reset_x_meas", qc, 2))

        # 5) Reset + H -> measure (should be ~0.5)
        qc, qr, c = _base_circuit(num_bell_pairs)
        qc.reset(qr[0])
        qc.h(qr[0])
        qc.measure(qr[0], c[2])
        tests.append(("flag_reset_h_meas", qc, 2))

        # 6) Parity/XOR into flag using re-measure (expected to break)
        qc, qr, c = _base_circuit(num_bell_pairs)
        qc.measure(qr[0], c[0])
        qc.measure(qr[1], c[1])
        with qc.if_test((c[0], 1)):
            qc.x(qr[0])
        with qc.if_test((c[1], 1)):
            qc.x(qr[0])
        qc.measure(qr[0], c[2])
        tests.append(("flag_xor_remeas", qc, 2))

        # 7) Same as (4) but with flag_bit=0
        qc, qr, c = _base_circuit(num_bell_pairs)
        qc.reset(qr[0])
        qc.x(qr[0])
        qc.measure(qr[0], c[0])
        tests.append(("flag_one_meas_bit0", qc, 0))

        # 8) Same as (4) but with flag_bit=1
        qc, qr, c = _base_circuit(num_bell_pairs)
        qc.reset(qr[0])
        qc.x(qr[0])
        qc.measure(qr[0], c[1])
        tests.append(("flag_one_meas_bit1", qc, 1))

        for name, circ, flag_bit in tests:
            result = claim_edge_with_qasm(
                client,
                edge_id,
                circ,
                flag_bit,
                num_bell_pairs,
                patch_flag=False,
            )
            if not result.get("ok"):
                print(f"{name}: error {result.get('error', {}).get('message')}")
                continue
            data = result.get("data", {})
            print(
                f"{name:22} | flag_bit={flag_bit} | success_prob={data.get('success_probability', 0):.4f} "
                f"| fidelity={data.get('fidelity', 0):.4f} | success={data.get('success')}"
            )

        # 9) Test classical XOR patch on a minimal circuit
        qc, qr, c = _base_circuit(num_bell_pairs)
        qc.measure(qr[0], c[0])
        qc.measure(qr[1], c[1])
        result = claim_edge_with_qasm(
            client,
            edge_id,
            qc,
            flag_bit=2,
            num_bell_pairs=num_bell_pairs,
            patch_flag=True,
        )
        if result.get("ok"):
            data = result.get("data", {})
            print(
                f"classical_or_xor_patch | flag_bit=2 | success_prob={data.get('success_probability', 0):.4f} "
                f"| fidelity={data.get('fidelity', 0):.4f} | success={data.get('success')}"
            )
        else:
            print(f"classical_or_xor_patch: error {result.get('error', {}).get('message')}")


Diagnostics edge: ('Boston, MA', 'Cambridge, MA') (threshold: 0.900)
no_flag                | flag_bit=2 | success_prob=1.0000 | fidelity=0.8500 | success=False
flag_meas_qubit        | flag_bit=2 | success_prob=0.5000 | fidelity=0.8500 | success=False
flag_reset_meas        | flag_bit=2 | success_prob=0.5000 | fidelity=0.8500 | success=False
flag_reset_x_meas      | flag_bit=2 | success_prob=0.5000 | fidelity=0.8500 | success=False
flag_reset_h_meas      | flag_bit=2 | success_prob=0.5000 | fidelity=0.8500 | success=False
flag_xor_remeas        | flag_bit=2 | success_prob=0.5000 | fidelity=0.4250 | success=False
flag_one_meas_bit0     | flag_bit=0 | success_prob=0.5000 | fidelity=0.8500 | success=False
flag_one_meas_bit1     | flag_bit=1 | success_prob=0.5000 | fidelity=0.8500 | success=False
classical_or_xor_patch | flag_bit=2 | success_prob=1.0000 | fidelity=0.4250 | success=False
