In [1]:
from qldpc_sim import *
from qldpc_sim.qldpc_experiment import *
from qldpc_sim.qec_code import *
from qldpc_sim.data_structure import *
from qldpc_sim.ckbb_surgery import *
import stim
import numpy as np

In [2]:
def setup_cnot_exp():
    patch1 = RSC3(name="p1")
    patch2 = RSC3(name="p2")
    qm = qldpc_experiment.QuantumMemory(size=600)
    patches = [patch1, patch2]
    mapqb = {}
    for c in patches:
        for lq in c.logical_qubits:
            mapqb[lq.logical_x] = c
            mapqb[lq.logical_z] = c

    ctx = Context(
        codes=[patch1, patch2],
        logical_qubits=patch1.logical_qubits + patch2.logical_qubits,
        initial_assignement=mapqb,
        memory=qm,
    )

    return ctx

In [3]:
def compile_and_sample(ctx, program, num_samples=10):
    compilers = []

    for p in program:
        compilers.extend(p.build_compiler_instructions())

    stim_instructions = []
    meas_tag = []

    for c in compilers:
        stim_instructions.append(f"# Compiler for {c.tag}")
        si, tag = c.compile(ctx.memory)
        if tag:
            if not isinstance(tag, list):
                tag = [tag]
            for t in tag:
                ctx.record.add_event(t)
        # print(f"Compiler {c.__class__.__name__} produced {len(si)} stim instructions")
        # print(si)
        
        stim_instructions.extend(si)
    for si in stim_instructions:
        print(si)
    circ = stim.Circuit("\n".join(stim_instructions))

    sampler = circ.compile_sampler()

    samples = sampler.sample(num_samples)

    return samples

In [4]:
from collections import defaultdict
import re


def run(context, program, num_samples=1000):
    samples = compile_and_sample(context, program, num_samples=num_samples)
    outcomes = defaultdict(dict)
    global_idx = 0
    for r in context.record.events:
        for n in r.measured_nodes:
            outcomes[r.tag][n.tag] = global_idx
            global_idx += 1
    out_array = np.array(samples)
    sample_value = outcomes.copy()
    for k, v in outcomes.items():
        for n, idx in v.items():
            sample_value[k][n] = out_array[:, idx]
    return sample_value

In [5]:
from qldpc_sim.qldpc_experiment.qec_gadget import Readout

def get_joinm_program(initial_state1, initial_state2, joint_pauli, readout_basis):
    context = setup_cnot_exp()

    p1 = context.codes[0]
    p2 = context.codes[1]
    if joint_pauli == "XX":
        ckbbm = [
            CKBBMeasurement(
                distance=3,
                context=context,
                tag="CKBBM",
                logical_target=[
                    p2.logical_qubits[0].logical_x,
                    p1.logical_qubits[0].logical_x,
                ],
            ),
        ]
    else:
        ckbbm = [
            CKBBMeasurement(
                distance=3,
                context=context,
                tag="CKBBM",
                logical_target=[
                    p2.logical_qubits[0].logical_z,
                    p1.logical_qubits[0].logical_z,
                ],
            ),
        ]
    return (
        [
            InitializeCode(
                code=p1,
                context=context,
                tag=f"init_{p1.id}",
                initial_state=initial_state1,
            )
        ]
        + [
            InitializeCode(
                code=p2,
                context=context,
                tag=f"init_{p2.id}",
                initial_state=initial_state2,
            )
        ]
        + [
            StabMeasurement(code=p, context=context, tag=f"isb_2_{p.id}", round=3)
            for p in [p1, p2]
        ]
        + ckbbm
        + [
            StabMeasurement(code=p, context=context, tag=f"isb_3_{p.id}", round=3)
            for p in [p1, p2]
        ]
        + [
            LM(
                logical_target=[p1.logical_qubits[0].logical_x if readout_basis == "X" else p1.logical_qubits[0].logical_z],
                context=context,
                tag=f"p1",
                basis=PauliChar.X if readout_basis == "X" else PauliChar.Z,
                reset_qubits=False,
            ),
            LM(
                logical_target=[p2.logical_qubits[0].logical_x if readout_basis == "X" else p2.logical_qubits[0].logical_z],
                context=context,
                tag=f"p2",
                basis=PauliChar.X if readout_basis == "X" else PauliChar.Z,
                reset_qubits=False,
            ),
            Readout(
                code=p1,
                context=context,
                tag=f"readout_{p1.id}",
                basis=PauliChar.X if readout_basis == "X" else PauliChar.Z,
            ),
            Readout(
                code=p2,
                context=context,
                tag=f"readout_{p2.id}",
                basis=PauliChar.X if readout_basis == "X" else PauliChar.Z,
            ),
        ]
    ), context

In [6]:
from functools import reduce
import operator


def xor_event_nodes(events: dict, event_name: str, tag: str | None = None) -> dict:
    """
    Compute per-sample XOR across selected nodes using ±1 encoding:
        False -> +1
        True  -> -1

    XOR becomes multiplication.
    """

    if event_name not in events:
        raise ValueError(f"Event '{event_name}' not found")

    event_nodes = events[event_name]

    # Select nodes by tag in name
    selected = [
        values
        for node_name, values in event_nodes.items()
        if tag is None or tag in node_name
    ]

    if not selected:
        return {event_name + (tag if tag else ""): []}

    result = []

    for sample_values in zip(*selected):
        prod = 1
        for v in sample_values:
            # False → +1, True → -1
            prod *= -1 if v else 1
        result.append(prod)

    return {event_name + (tag if tag else ""): result}


def concat_events_per_sample(event_results: dict) -> dict:
    """
    Convert:
        {event_name: [values per sample]}
    into:
        {sample_index: [values across events]}
    """

    if not event_results:
        return {}

    event_names = list(event_results.keys())
    num_samples = len(next(iter(event_results.values())))

    result = {}

    for i in range(num_samples):
        result[f"sample{i}"] = [event_results[event][i] for event in event_names]

    return result

In [7]:
from collections import Counter

prog, ctx = get_joinm_program(
    initial_state1=PauliEigenState.Z_plus,
    initial_state2=PauliEigenState.Z_plus,
    joint_pauli="XX",
    readout_basis="Z",
)
sample_outcome = run(context=ctx, program=prog, num_samples=1000)

mxx_r0 = xor_event_nodes(sample_outcome, "ckbb_stab_CKBBM_round0", tag="_T")
mxx_r0_all = xor_event_nodes(sample_outcome, "ckbb_stab_CKBBM_round0", tag="_c")
# mxx_r1 = xor_event_nodes(sample_outcome, "ckbb_stab_CKBBM_round1", tag="_T")
# mxx_r1_all = xor_event_nodes(sample_outcome, "ckbb_stab_CKBBM_round1", tag="_c")
# mxx_r2 = xor_event_nodes(sample_outcome, "ckbb_stab_CKBBM_round2", tag="_T")
# mxx_r2_all = xor_event_nodes(sample_outcome, "ckbb_stab_CKBBM_round2", tag="_c")
p1 = xor_event_nodes(sample_outcome, "p1", tag=None)
p2 = xor_event_nodes(sample_outcome, "p2", tag=None)

cps = concat_events_per_sample(
    {
        **p1,
        **p2,
        **mxx_r0
    }
)

print("p1, p2, mxx, mxx_r1")
# for sample, values in cps.items():
#     if not all(x == values[0] for x in values[:3]):
#         print(f"*** Inconsistent MXX outcomes in {sample}: {values[:3]} ***")
#     else:
#         print(f"{sample}: {values[2:]}")

# Convert lists to tuples for hashing
counter = Counter(tuple(v) for v in cps.values())

print(counter)
print(len(counter.keys()))

# Compiler for InitializeCode_init_f9f069de-ad5c-486b-99f2-4dbf9f9abbce
RZ 0
RZ 1
RZ 2
RZ 3
RZ 4
RZ 5
RZ 6
RZ 7
RZ 8
RZ 9
RZ 10
RZ 11
RZ 12
RZ 13
RZ 14
RZ 15
RZ 16
# Compiler for InitializeCode_init_b9ab7fc1-4d09-4d62-9d7e-54b3e95a142b
RZ 17
RZ 18
RZ 19
RZ 20
RZ 21
RZ 22
RZ 23
RZ 24
RZ 25
RZ 26
RZ 27
RZ 28
RZ 29
RZ 30
RZ 31
RZ 32
RZ 33
# Compiler for StabMeasurement_isb_2_f9f069de-ad5c-486b-99f2-4dbf9f9abbce
REPEAT 3 {
    RZ 1
    RZ 2
    RZ 4
    RZ 7
    RZ 8
    RZ 12
    RZ 15
    RZ 16
    # Stab: c_1_p1, type: CheckType.X
    H 1
    CX 1 13
    CX 1 5
    H 1
    # Stab: c_5_p1, type: CheckType.Z
    CX 3 2
    CX 13 2
    CX 14 2
    CX 5 2
    # Stab: c_2_p1, type: CheckType.X
    H 4
    CX 4 3
    CX 4 0
    CX 4 14
    CX 4 10
    H 4
    # Stab: c_3_p1, type: CheckType.X
    H 7
    CX 7 0
    CX 7 11
    H 7
    # Stab: c_4_p1, type: CheckType.Z
    CX 6 8
    CX 9 8
    # Stab: c_7_p1, type: CheckType.Z
    CX 3 12
    CX 10 12
    # Stab: c_0_p1, type: CheckType.X
   