In [1]:
from qldpc_sim import *
from qldpc_sim.qldpc_experiment import *
from qldpc_sim.qec_code import *
from qldpc_sim.data_structure import *
from qldpc_sim.ckbb_surgery import *
import stim
import numpy as np

In [2]:
def setup_cnot_exp():
    patch1 = RSC3()
    patch2 = RSC3()
    patch3 = RSC3()
    qm = qldpc_experiment.QuantumMemory(size=600)
    patches = [patch1, patch2, patch3]
    mapqb = {}
    for c in patches:
        for lq in c.logical_qubits:
            mapqb[lq.logical_x] = c
            mapqb[lq.logical_z] = c

    ctx = Context(
        codes=[patch1, patch2, patch3], logical_qubits=patch1.logical_qubits + patch2.logical_qubits + patch3.logical_qubits, initial_assignement=mapqb, memory=qm
    )

    return ctx

In [3]:
def compile_and_sample(ctx, program, num_samples=10):
    compilers = []

    for p in program:
        compilers.extend(p.build_compiler_instructions())

    stim_instructions = []
    meas_tag = []

    for c in compilers:
        stim_instructions.append(f"# Compiler for {c.__class__.__name__}")
        si, tag = c.compile(ctx.memory)
        if tag:
            if not isinstance(tag, list):
                tag = [tag]
            for t in tag:
                ctx.record.add_event(t)
        # print(f"Compiler {c.__class__.__name__} produced {len(si)} stim instructions")
        # print(si)
        stim_instructions.extend(si)

    circ = stim.Circuit("\n".join(stim_instructions))

    sampler = circ.compile_sampler()

    samples = sampler.sample(num_samples)

    return samples

In [4]:
from collections import defaultdict
import re


def run(context, program, interest_list=None):
    num_samples = 1000
    samples = compile_and_sample(context, program, num_samples=num_samples)
    filtered_outcomes = defaultdict(list)
    global_idx = 0
    for r in context.record.events:
        for il in interest_list:
            if il in r.tag:
                start = global_idx
                for i, n in enumerate(r.measured_nodes):
                    if "final_" in r.tag:
                        filtered_outcomes[il].append(start + i)
                    elif "initial_" in r.tag:
                        filtered_outcomes[il].append(start + i)
                    elif "readout_" in r.tag:
                        filtered_outcomes[il].append(start + i)
                    elif "ckbb_stab_Merge" in r.tag:
                        if "_T" in n.tag:
                            # Extract round number and create separate key
                            round_match = re.search(r'round(\d+)', r.tag)
                            if round_match:
                                round_num = round_match.group(1)
                                key = f"{il}_round{round_num}"
                            else:
                                key = il
                            filtered_outcomes[key].append(start + i)
        global_idx += r.size
    out_array = np.array(samples)
    outcome = {}
    for k, v in filtered_outcomes.items():
        row_matter = out_array[:, v]
        outcome[k] = np.bitwise_xor.reduce(row_matter, axis=1)

    return outcome

In [5]:
context = setup_cnot_exp()

control = context.codes[0]
ancilla = context.codes[1]
target = context.codes[2]
print(ancilla.logical_qubits[0].logical_x)
print(target.logical_qubits[0].logical_x)

cnot_program = (
    [
        InitializeCode(
            code=control,
            context=context,
            tag=f"init_{control.id}",
            initial_state=PauliEigenState.Z_minus,
        )
    ]
    + [
        InitializeCode(
            code=target,
            context=context,
            tag=f"init_{target.id}",
            initial_state=PauliEigenState.Z_plus,
        )
    ]
    + [
        InitializeCode(
            code=ancilla,
            context=context,
            tag=f"init_{ancilla.id}",
            initial_state=PauliEigenState.X_plus,
        )
    ]
    + [
        StabMeasurement(code=p, context=context, tag=f"isb_2_{p.id}", round=3)
        for p in [control, ancilla, target]
    ]
    + [
        CKBBMeasurement(
            distance=3,
            context=context,
            tag="Merge_C-A",
            logical_target=[
                control.logical_qubits[0].logical_z,
                ancilla.logical_qubits[0].logical_z,
            ],
        ),
    ]
    + [
        StabMeasurement(code=p, context=context, tag=f"isb_2_{p.id}", round=3)
        for p in [control, ancilla, target]
    ]
    + [
        CKBBMeasurement(
            distance=3,
            context=context,
            tag="Merge_T-A",
            logical_target=[
                ancilla.logical_qubits[0].logical_x,
                target.logical_qubits[0].logical_x,
            ],
        ),
    ]
    + [
        StabMeasurement(code=p, context=context, tag=f"isb_3_{p.id}", round=3)
        for p in [control, ancilla, target]
    ]
    + [
        LM(
            logical_target=[ancilla.logical_qubits[0].logical_z],
            context=context,
            tag=f"final_ancilla",
            basis=PauliChar.X,
        ),
        LM(
            logical_target=[target.logical_qubits[0].logical_z],
            context=context,
            tag=f"final_target",
            basis=PauliChar.Z,
        ),
        LM(
            logical_target=[control.logical_qubits[0].logical_z],
            context=context,
            tag=f"final_control",
            basis=PauliChar.Z,
        ),
    ]
)

logical_type=<PauliChar.X: 'X'> operator=PauliString(string=(<PauliChar.X: 'X'>, <PauliChar.X: 'X'>, <PauliChar.X: 'X'>)) target_nodes=(VariableNode(id=UUID('b763aff2-0a6a-46f9-ae54-5a77886d920c'), tag='v_0_RSC3'), VariableNode(id=UUID('fa29ab3e-d3d7-45f1-8ab8-59d9a57a0617'), tag='v_3_RSC3'), VariableNode(id=UUID('03e53170-edd9-4ddd-9882-8dec03ae1d85'), tag='v_6_RSC3'))
logical_type=<PauliChar.X: 'X'> operator=PauliString(string=(<PauliChar.X: 'X'>, <PauliChar.X: 'X'>, <PauliChar.X: 'X'>)) target_nodes=(VariableNode(id=UUID('361791a0-f298-45a9-94f0-081818188ba7'), tag='v_0_RSC3'), VariableNode(id=UUID('a83d822a-d2fa-46da-a8ca-8b45f519dc88'), tag='v_3_RSC3'), VariableNode(id=UUID('e9356916-9a24-4d30-a713-a1acd9ba57a3'), tag='v_6_RSC3'))


In [6]:
outcome = run(
    context=context,
    program=cnot_program,
    interest_list=[
        "ckbb_stab_Merge_T-A",
        "ckbb_stab_Merge_C-A",
        "final_ancilla",
        "final_control",
        "final_target",
        "readout_ckbb_ancilla",
    ],
)

# Check all rounds are equal and use the latest one
merged_outcomes = {}
processed_keys = set()

for key in outcome.keys():
    if "round" in key:
        # Extract base key (without round number)
        base_key = re.sub(r"_round\d+", "", key)
        if base_key not in processed_keys:
            processed_keys.add(base_key)
            # Get all round keys for this base key
            all_round_keys = sorted(
                [k for k in outcome.keys() if k.startswith(base_key + "_round")]
            )
            print(f"\n{base_key}:")
            print(f"  Rounds found: {all_round_keys}")

            # Check that all rounds are equal
            all_equal = True
            for i in range(len(all_round_keys) - 1):
                key_a = all_round_keys[i]
                key_b = all_round_keys[i + 1]
                if not np.array_equal(outcome[key_a], outcome[key_b]):
                    all_equal = False
                    diff_count = np.sum(np.bitwise_xor(outcome[key_a], outcome[key_b]))
                    print(f"  ✗ {key_a} != {key_b}: {diff_count} differences")
                else:
                    print(f"  ✓ {key_a} == {key_b}")

            if all_equal:
                print(f"  ✓ All rounds are equal")
            else:
                print(f"  ⚠ WARNING: Rounds differ!")

            # Use the latest round
            latest_round_key = all_round_keys[-1]
            merged_outcomes[base_key] = outcome[latest_round_key].copy()
    else:
        # Keep non-round keys as-is
        merged_outcomes[key] = outcome[key]

# Use merged outcomes
merge_CA = np.where(merged_outcomes["ckbb_stab_Merge_C-A"], 1, 0)
merge_TA = np.where(merged_outcomes["ckbb_stab_Merge_T-A"], 1, 0)
control = np.where(merged_outcomes["final_control"], 1, 0)
target = np.where(merged_outcomes["final_target"], 1, 0)
ancilla = np.where(merged_outcomes["final_ancilla"], 1, 0)
ro = np.where(merged_outcomes["readout_ckbb_ancilla"], 1, 0)

from collections import Counter

# X on target if Merge_C-A == 1
target_corrected = (target + merge_CA + ancilla) % 2
control_corrected = (control) % 2
ct = [f"{int(c)}{int(t)}" for c, t in zip(control_corrected, target_corrected)]
print(f"corr ct: {Counter(ct)}")
bitstrings = [
    f"{int(control)}{int(target)}" for control, target in zip(control, target)
]

bitstrings2 = [
    f"{bitstrings}{int(ancilla)}"
    for bitstrings, ancilla in zip(bitstrings, ancilla)
]
bitstrings2 = [
    f"{bitstrings2}{int(merge_CA)}"
    for bitstrings2, merge_CA in zip(bitstrings2, merge_CA)
]
bitstrings2 = [
    f"{bitstrings2}{int(merge_TA)}"
    for bitstrings2, merge_TA in zip(bitstrings2, merge_TA)
]

# bitstrings2 = [f"{bitstrings2}{int(ro)}" for bitstrings2, ro in zip(bitstrings2, ro)]

print("ct: ", Counter(bitstrings))
print("ct MZZ A: ", Counter(bitstrings2))
print(len(Counter(bitstrings2).keys()))


ckbb_stab_Merge_C-A:
  Rounds found: ['ckbb_stab_Merge_C-A_round0', 'ckbb_stab_Merge_C-A_round1', 'ckbb_stab_Merge_C-A_round2']
  ✓ ckbb_stab_Merge_C-A_round0 == ckbb_stab_Merge_C-A_round1
  ✓ ckbb_stab_Merge_C-A_round1 == ckbb_stab_Merge_C-A_round2
  ✓ All rounds are equal

ckbb_stab_Merge_T-A:
  Rounds found: ['ckbb_stab_Merge_T-A_round0', 'ckbb_stab_Merge_T-A_round1', 'ckbb_stab_Merge_T-A_round2']
  ✓ ckbb_stab_Merge_T-A_round0 == ckbb_stab_Merge_T-A_round1
  ✓ ckbb_stab_Merge_T-A_round1 == ckbb_stab_Merge_T-A_round2
  ✓ All rounds are equal
corr ct: Counter({'11': 504, '10': 496})
ct:  Counter({'11': 513, '10': 487})
ct MZZ A:  Counter({'10100': 72, '11010': 69, '11111': 69, '11101': 67, '10010': 64, '11011': 63, '11001': 63, '11100': 62, '10110': 61, '11000': 61, '10000': 61, '10111': 61, '10101': 61, '11110': 59, '10011': 55, '10001': 52})
16
