# Stage 1 — 
Learn H: Query oracle o1 with error vectors e and receive a clean syndrome from a hidden CSS code H = [H_X | H_Z]. Recover an equivalent H (bonus: estimate distance d).

# Stage 2 —
 Decode clean syndromes: Given (H_X, H_Z, s), output an error ê in the correct coset (i.e., syndrome_css(H_X,H_Z,ê) = s). Bonus: exact e.

# Stage 3 — 
Decode noisy redundant syndromes: Given H_ext (with redundant rows) and an extended syndrome s_ext in which approximately 13.5% of the bits have been flipped (at unknown positions), construct a robust decoder. Bonus: estimate p.

In [None]:
# TODO: Implement Stage 1 learner and (optional) distance estimator

def learn_H_stage1(n: int = 70, budget: int = 2000):
    """
    Learn H by querying o1 with one unit error vector per query.

    Every one of the 2n coordinates of e is probed once.  Following the
    convention used throughout this function (syndrome = [HX·eZ | HZ·eX]),
    a unit entry at position n+i yields column i of HX in the first mX
    syndrome bits, and a unit entry at position i yields column i of HZ in
    the remaining bits.

    Args:
        n: Number of qubits; every error vector has length 2n.
        budget: Maximum number of oracle queries that may be sent.

    Returns: (H, queries_used) where H is whatever stack_css_rows(HX, HZ)
        builds (30×140 block-diagonal for the default sizes).

    Raises:
        RuntimeError: If one more query would exceed ``budget``.  The check
            happens *before* the request is sent, so the budget is never
            overshot.
    """
    mX = mZ = 15  # Known from problem statement
    queries_used = 0

    # Storage for the learned columns of HX and HZ
    HX = np.zeros((mX, n), dtype=np.uint8)
    HZ = np.zeros((mZ, n), dtype=np.uint8)
    print(f"Learning H with n={n}, budget={budget}")

    def query_unit_error(position: int) -> np.ndarray:
        """Send the unit vector e_position to the oracle and return its syndrome."""
        nonlocal queries_used
        if queries_used >= budget:
            raise RuntimeError(
                f"Query budget exhausted: {queries_used}/{budget} queries already used"
            )
        e = np.zeros(2 * n, dtype=np.uint8)
        e[position] = 1
        response = post_json(ORACLE_O1, {"e": e.tolist()})
        queries_used += 1
        return np.array(response["syndrome"], dtype=np.uint8)

    # Phase 1: Learn HX by querying e = [0 | eZ_i] for each qubit i
    print(f"Phase 1: Learning HX (querying {n} basis vectors)...")
    for i in range(n):
        # With only eZ_i set, the first mX syndrome bits are column i of HX
        HX[:, i] = query_unit_error(n + i)[:mX]

    # Phase 2: Learn HZ by querying e = [eX_i | 0] for each qubit i
    print(f"Phase 2: Learning HZ (querying {n} basis vectors)...")
    for i in range(n):
        # With only eX_i set, the remaining syndrome bits are column i of HZ
        HZ[:, i] = query_unit_error(i)[mX:]

    # Validate CSS commutation: HX·HZ^T ≡ 0 (mod 2)
    css_check = (HX.astype(int) @ HZ.T.astype(int)) % 2
    if np.any(css_check):
        print(f"⚠️  CSS commutation check failed: {np.count_nonzero(css_check)} violations")
    else:
        print("✅ CSS commutation validated")

    # Construct block-diagonal H
    H = stack_css_rows(HX, HZ)

    print(f"✅ Stage 1 complete: learned H ({H.shape}) using {queries_used} queries")
    return H, queries_used


def nullspace_gf2(A: np.ndarray):
    """Return a basis of the null space of ``A`` over GF(2).

    ``A`` is reduced modulo 2 and brought to reduced row-echelon form.  Each
    non-pivot ("free") column then yields one basis vector: that free
    coordinate is set to 1 and every pivot coordinate takes the value the
    reduced matrix holds in the free column (over GF(2), -1 == 1).

    Args:
        A: 2-D array of shape (m, n); entries are interpreted modulo 2.  The
            matrix may be rectangular.

    Returns:
        uint8 array of shape (n - rank(A), n) whose rows span
        {x : A x = 0 (mod 2)}.  Shape (0, n) when the null space is trivial.
    """
    R = (np.asarray(A) % 2).astype(np.uint8)  # `%` already made a fresh array
    m, n = R.shape

    pivot_cols = []
    row = 0
    for col in range(n):
        if row >= m:
            break

        # First row at or below `row` with a 1 in this column
        hits = np.flatnonzero(R[row:, col])
        if hits.size == 0:
            continue
        pivot_row = row + int(hits[0])
        if pivot_row != row:
            R[[row, pivot_row]] = R[[pivot_row, row]]

        # Clear the column everywhere except the pivot row (XOR == add mod 2)
        targets = np.flatnonzero(R[:, col])
        targets = targets[targets != row]
        R[targets] ^= R[row]

        pivot_cols.append(col)
        row += 1

    pivot_set = set(pivot_cols)
    free_cols = [c for c in range(n) if c not in pivot_set]
    pivot_idx = np.array(pivot_cols, dtype=np.intp)

    basis = np.zeros((len(free_cols), n), dtype=np.uint8)
    for k, col in enumerate(free_cols):
        basis[k, col] = 1
        # Pivot row i owns pivot column pivot_cols[i]
        basis[k, pivot_idx] = R[:row, col]
    return basis


def _dist_rref_gf2(M):
    """Row-reduce ``M`` over GF(2); return (non-zero rows of the RREF, pivot columns)."""
    R = (np.asarray(M) % 2).astype(np.uint8)
    n_rows, n_cols = R.shape
    pivots = []
    rank = 0
    for col in range(n_cols):
        if rank == n_rows:
            break
        hits = np.flatnonzero(R[rank:, col])
        if hits.size == 0:
            continue
        src = rank + int(hits[0])
        if src != rank:
            R[[rank, src]] = R[[src, rank]]
        targets = np.flatnonzero(R[:, col])
        targets = targets[targets != rank]
        R[targets] ^= R[rank]  # XOR == addition mod 2
        pivots.append(col)
        rank += 1
    return R[:rank], pivots


def _dist_nullspace_gf2(A):
    """Return a (n - rank) x n basis of {x : A x = 0} over GF(2)."""
    n_cols = np.asarray(A).shape[1]
    R, pivots = _dist_rref_gf2(A)
    pivot_set = set(pivots)
    free = [c for c in range(n_cols) if c not in pivot_set]
    pivot_idx = np.array(pivots, dtype=np.intp)
    basis = np.zeros((len(free), n_cols), dtype=np.uint8)
    for k, col in enumerate(free):
        basis[k, col] = 1
        basis[k, pivot_idx] = R[:, col]
    return basis


def _dist_min_logical_weight(kernel, detector, num_trials, rng):
    """Lowest weight found for a vector of span(kernel) not orthogonal to ``detector``.

    Each trial permutes the columns, row-reduces ``kernel`` and inspects the
    reduced rows and all pairwise sums of them.  Returns None if no
    qualifying vector was seen.
    """
    n = kernel.shape[1]
    if kernel.shape[0] == 0 or detector.shape[0] == 0:
        return None
    det = detector.astype(np.int64)
    best = None
    for trial in range(num_trials + 1):
        # Trial 0 uses the natural column order; later trials randomise it
        perm = np.arange(n) if trial == 0 else rng.permutation(n)
        R, _ = _dist_rref_gf2(kernel[:, perm])
        candidates = R
        if R.shape[0] > 1:
            a, b = np.triu_indices(R.shape[0], k=1)
            candidates = np.vstack([R, R[a] ^ R[b]])
        weights = candidates.sum(axis=1)
        if best is not None:
            # Only candidates that would improve the bound need checking
            keep = weights < best
            candidates, weights = candidates[keep], weights[keep]
        if candidates.shape[0] == 0:
            continue
        nontrivial = ((candidates.astype(np.int64) @ det[:, perm].T) % 2).any(axis=1)
        if nontrivial.any():
            best = int(weights[nontrivial].min())
    return best


def estimate_distance(H: np.ndarray, num_trials: int = 200, seed=None) -> int:
    """
    Optional: estimate the minimum distance d for +5 bonus.

    H is assumed to be block-diagonal, [[HX, 0], [0, HZ]], with n = columns/2
    qubits.  A nontrivial logical operator of one type is a vector in the
    kernel of one block that is NOT in the row space of the other block;
    the distance is the smallest weight of such a vector over both types.
    Membership in a row space is tested through its orthogonal complement:
    v lies in rowspace(A) exactly when v is orthogonal to every vector of
    ker(A).

    The search is randomised (column permutations followed by row reduction
    of the kernel basis), so the result is an upper bound on d that is
    usually tight for small codes.

    Args:
        H: Stacked check matrix of shape (m, 2n); entries are taken mod 2.
        num_trials: Number of random column permutations tried per type, in
            addition to one pass in the natural order.
        seed: Optional seed for a private RandomState.  When None the global
            ``np.random`` state is used.

    Returns:
        The smallest logical-operator weight found.  If the code has no
        nontrivial logical operator (k = 0), n is returned as a trivial bound.

    Raises:
        ValueError: If H is not 2-D or has an odd number of columns.
    """
    H = (np.asarray(H) % 2).astype(np.uint8)
    if H.ndim != 2 or H.shape[1] % 2:
        raise ValueError("H must be a 2-D array with an even number of columns")
    n = H.shape[1] // 2

    # All-zero rows contributed by the opposite block change neither the row
    # space nor the kernel, so the full half-width blocks can be used directly.
    left, right = H[:, :n], H[:, n:]

    ker_left = _dist_nullspace_gf2(left)
    ker_right = _dist_nullspace_gf2(right)
    rank_left = n - ker_left.shape[0]
    rank_right = n - ker_right.shape[0]
    print(f"rank(HX) = {rank_left}, dim(null(HX)) = {ker_left.shape[0]}")
    print(f"rank(HZ) = {rank_right}, dim(null(HZ)) = {ker_right.shape[0]}")
    print(f"Logical qubits k = {n - rank_left - rank_right}")

    rng = np.random if seed is None else np.random.RandomState(seed)
    print(f"\nSearching for minimum distance ({num_trials} random trials per type)...")
    found = [
        d
        for d in (
            _dist_min_logical_weight(ker_right, ker_left, num_trials, rng),
            _dist_min_logical_weight(ker_left, ker_right, num_trials, rng),
        )
        if d is not None
    ]
    if not found:
        print("⚠️  No nontrivial logical operator found; returning n as a trivial bound")
        return int(n)

    min_distance = min(found)
    print(f"✅ Distance estimate: d ≈ {min_distance}")
    return int(min_distance)

# Stage 1 evaluation driver. NOTE: this code is live (not commented out) and
# runs as soon as the cell/module executes: it learns H through the oracle,
# estimates the distance, and submits both to EVAL_STAGE1.
# The globals H_learned and res_1 are reused by the later stages.
print("\n" + "="*60)
print("=== STAGE 1: Learn H ===")
print("="*60)

H_learned, queries_1 = learn_H_stage1(n=70, budget=2000)
d_estimate = estimate_distance(H_learned)

# Payload is sent as plain Python lists/ints (via .tolist() and int())
res_1 = post_json(EVAL_STAGE1, {
    "H": H_learned.tolist(),
    "d_estimate": int(d_estimate)
})

print(f"\n✅ Stage 1 Score: {res_1.get('score')}/40")
print(f"Details: {res_1.get('details')}")

In [None]:
# TODO: Implement Stage 2 decoder and a GF(2) linear solver if you choose a linear approach

def decode_stage2(HX: np.ndarray, HZ: np.ndarray, s: np.ndarray) -> np.ndarray:
    """
    Decode a clean syndrome into some error of the matching coset.

    The syndrome is split after the first HX.shape[0] bits; the first part
    is solved against HX (giving eZ) and the remainder against HZ (giving
    eX).  If either system has no solution, that half falls back to the
    all-zero vector.  The result is re-checked with syndrome_css and a
    warning is printed on mismatch; the estimate is returned regardless.

    Returns:
        e_hat = [eX | eZ] as a 1-D array of length 2 * HX.shape[1].
    """
    num_x_checks, num_qubits = HX.shape
    syndrome = (np.asarray(s) % 2).astype(np.uint8)

    def or_zeros(solution):
        # solve_gf2 signals an inconsistent system with None
        if solution is None:
            return np.zeros(num_qubits, dtype=np.uint8)
        return solution

    # HX·eZ = sX, then HZ·eX = sZ
    z_part = or_zeros(solve_gf2(HX, syndrome[:num_x_checks]))
    x_part = or_zeros(solve_gf2(HZ, syndrome[num_x_checks:]))

    e_hat = np.concatenate([x_part, z_part])

    # Sanity check only: report, but do not alter, a wrong estimate
    if not np.array_equal(syndrome_css(HX, HZ, e_hat), syndrome):
        print("⚠️  Syndrome mismatch detected")

    return e_hat


def solve_gf2(A: np.ndarray, b: np.ndarray) -> np.ndarray:
    """
    Find one solution x of A x = b over GF(2).

    The augmented matrix [A | b] is brought to reduced row-echelon form,
    pivoting on the columns of A from left to right.  Free variables are
    fixed to 0, so each pivot variable equals the right-hand side of its
    pivot row.

    Returns:
        A uint8 vector of length A.shape[1], or None when the system is
        inconsistent.
    """
    lhs = (np.asarray(A) % 2).astype(np.uint8)
    rhs = (np.asarray(b) % 2).astype(np.uint8)
    n_rows, n_cols = lhs.shape

    aug = np.concatenate((lhs, rhs.reshape(-1, 1)), axis=1)

    rank = 0
    pivots = []
    for col in range(n_cols):
        # Rows at or below the current rank that have a 1 in this column
        below = np.flatnonzero(aug[rank:, col])
        if below.size == 0:
            continue
        src = rank + int(below[0])
        aug[[rank, src]] = aug[[src, rank]]

        # Clear the column in every other row (XOR is addition mod 2)
        others = np.flatnonzero(aug[:, col])
        others = others[others != rank]
        aug[others] ^= aug[rank]

        pivots.append(col)
        rank += 1

    # A leftover row reading 0 = 1 means there is no solution
    if aug[rank:, -1].any():
        return None

    # In reduced form with free variables at 0, pivot values are just the RHS
    x = np.zeros(n_cols, dtype=np.uint8)
    x[np.array(pivots, dtype=np.intp)] = aug[:rank, -1]
    return x

# Stage 2 evaluation driver. NOTE: this code is live and calls the evaluator
# (EVAL_STAGE2) once per test; it depends on H_learned from Stage 1.
print("\n" + "="*60)
print("=== STAGE 2: Decode (clean) ===")
print("="*60)

# Slice the two diagonal blocks out of H_learned (sizes hard-coded:
# 15 checks per type, 70 qubits). HX and HZ are reused by Stage 3.
HX = H_learned[:15, :70]
HZ = H_learned[15:, 70:]

# Test with multiple syndromes
num_tests = 5
scores_2 = []

for test_idx in range(num_tests):
    # Generate a uniformly random 140-bit test error and its syndrome
    e_test = np.random.randint(0, 2, 140, dtype=np.uint8)
    s_test = syndrome_css(HX, HZ, e_test)

    # Decode
    e_hat = decode_stage2(HX, HZ, s_test)

    # Evaluate (the true error is sent along with the estimate)
    res_2 = post_json(EVAL_STAGE2, {
        "HX": HX.tolist(),
        "HZ": HZ.tolist(),
        "s": s_test.tolist(),
        "e_hat": e_hat.tolist(),
        "e_true": e_test.tolist()
    })

    # A missing 'score' key counts as 0
    scores_2.append(res_2.get('score', 0))
    print(f"  Test {test_idx+1}: Score {res_2.get('score', 0):.1f}/45")

avg_stage2 = np.mean(scores_2)
print(f"\n✅ Stage 2 Average: {avg_stage2:.1f}/45")

In [None]:
# TODO: Implement redundancy design, denoising, and overall Stage 3 decode

def design_redundant_rows(HX: np.ndarray, HZ: np.ndarray, r: int = 5):
    """
    Return (H_ext, r_actual).

    H_ext is built by stack_css_rows from HX and HZ, each extended with
    redundant rows: r // 2 extra X-type rows and r - r // 2 extra Z-type
    rows.  Every redundant row is the GF(2) sum of 2 or 3 distinct, randomly
    chosen rows of its own block (fewer when the block has fewer rows), so
    it lies in that block's row space.

    Args:
        HX: X-type check matrix, shape (mX, n).
        HZ: Z-type check matrix, shape (mZ, n).
        r: Total number of redundant rows requested.  For r <= 0 the base
            matrix is returned unchanged with r_actual = 0.

    Returns:
        (H_ext, r) for r > 0, else (stack_css_rows(HX, HZ), 0).
    """
    if r <= 0:
        return stack_css_rows(HX, HZ), 0

    n = HX.shape[1]

    def random_combinations(base: np.ndarray, count: int) -> np.ndarray:
        """Return `count` rows, each a mod-2 sum of 2-3 random rows of `base`."""
        rows = np.zeros((count, n), dtype=np.uint8)
        m = base.shape[0]
        if m == 0:
            # Nothing to combine; all-zero rows are trivially in the row space
            return rows
        # Clamp the number of terms so blocks with fewer than 2 rows still
        # work (randint's upper bound is exclusive)
        fewest, most = min(2, m), min(3, m)
        for i in range(count):
            num_terms = np.random.randint(fewest, most + 1)
            picked = np.random.choice(m, num_terms, replace=False)
            rows[i] = base[picked].sum(axis=0) % 2
        return rows

    # Split redundancy between X and Z types
    r_X = r // 2
    r_Z = r - r_X
    RX = random_combinations(HX, r_X)
    RZ = random_combinations(HZ, r_Z)

    # Redundant rows are appended below the base rows of their own block
    HX_ext = np.vstack([HX, RX])
    HZ_ext = np.vstack([HZ, RZ])

    H_ext = stack_css_rows(HX_ext, HZ_ext)

    return H_ext, r


def _dn_eliminate_gf2(M, n_pivot_cols):
    """Gauss-Jordan over GF(2), pivoting only on the first ``n_pivot_cols`` columns.

    Returns (R, pivots): all rows are kept, the first len(pivots) rows are
    the pivot rows, and every row operation is applied across all columns.
    """
    R = (np.asarray(M) % 2).astype(np.uint8)
    n_rows = R.shape[0]
    pivots = []
    rank = 0
    for col in range(n_pivot_cols):
        if rank == n_rows:
            break
        hits = np.flatnonzero(R[rank:, col])
        if hits.size == 0:
            continue
        src = rank + int(hits[0])
        if src != rank:
            R[[rank, src]] = R[[src, rank]]
        targets = np.flatnonzero(R[:, col])
        targets = targets[targets != rank]
        R[targets] ^= R[rank]  # XOR == addition mod 2
        pivots.append(col)
        rank += 1
    return R, pivots


def _dn_min_weight_pattern(P, target, max_candidates=200000):
    """Return a low-weight 0/1 vector eps with P @ eps = target (mod 2).

    Patterns are enumerated by increasing weight, so the first match has
    minimum weight.  If ``max_candidates`` patterns are tried without a
    match, a (not necessarily minimal) solution from Gaussian elimination is
    returned instead.  P is expected to have full row rank.
    """
    from itertools import combinations

    n_checks, n_bits = P.shape
    eps = np.zeros(n_bits, dtype=np.uint8)
    if not target.any():
        return eps

    # Pack each column (and the target) into a Python int for fast XOR
    col_keys = [
        sum(int(bit) << i for i, bit in enumerate(P[:, j])) for j in range(n_bits)
    ]
    target_key = sum(int(bit) << i for i, bit in enumerate(target))

    remaining = max_candidates
    for weight in range(1, n_checks + 1):
        for combo in combinations(range(n_bits), weight):
            if remaining == 0:
                break
            remaining -= 1
            key = 0
            for j in combo:
                key ^= col_keys[j]
            if key == target_key:
                eps[list(combo)] = 1
                return eps
        if remaining == 0:
            break

    # Fallback: any solution of P eps = target (pivot variables only)
    aug = np.concatenate((P, target.reshape(-1, 1)), axis=1)
    R, pivots = _dn_eliminate_gf2(aug, n_bits)
    eps[np.array(pivots, dtype=np.intp)] = R[:len(pivots), -1]
    return eps


def denoise_syndrome(H: np.ndarray,H_ext: np.ndarray, s_ext: np.ndarray):
    """
    Return (s_clean, p_estimate, r).

    Every row of H_ext that lies in the row space of H is a known GF(2)
    combination of H's rows, so a noise-free extended syndrome must satisfy
    linear parity checks.  This function

      1. expresses each row of H_ext in terms of the rows of H,
      2. derives the parity checks those relations impose on s_ext,
      3. flips a minimum-weight set of bits of s_ext so that all checks hold,
      4. solves for the base syndrome that produces the corrected s_ext.

    The row order of H_ext does not matter.  Rows of H_ext outside the row
    space of H carry no usable relation and are ignored.  With only a few
    redundant rows many error patterns are indistinguishable, so the result
    is the nearest consistent syndrome, not necessarily the true one.

    Args:
        H: Base check matrix, shape (m, 2n).
        H_ext: Extended check matrix whose rows produced s_ext.
        s_ext: Noisy extended syndrome, one bit per row of H_ext.

    Returns:
        s_clean: uint8 vector of length m, ordered like the rows of H.  Bits
            that H_ext does not determine are set to 0.
        p_estimate: Fraction of the checked syndrome bits that were flipped.
        r: len(s_ext) - m, the number of redundant rows.

    Raises:
        ValueError: If H_ext, s_ext and H have inconsistent shapes.
    """
    H = (np.asarray(H) % 2).astype(np.uint8)
    H_ext = (np.asarray(H_ext) % 2).astype(np.uint8)
    y = (np.asarray(s_ext) % 2).astype(np.uint8).ravel()

    m_base = H.shape[0]           # base syndrome size
    m_ext = y.shape[0]            # extended syndrome size
    r = m_ext - m_base            # redundancy count

    if H_ext.shape[0] != m_ext or H_ext.shape[1] != H.shape[1]:
        raise ValueError("H_ext must have one row per bit of s_ext and as many columns as H")

    print(f"Base syndrome size: {m_base}")
    print(f"Extended syndrome size: {m_ext}")
    print(f"Redundancy r: {r}")

    # 1) Reduce [H^T | H_ext^T] on the H^T columns: this solves
    #    H^T t = (row of H_ext) for every extended row at once.
    R, pivots = _dn_eliminate_gf2(np.concatenate((H.T, H_ext.T), axis=1), m_base)
    rank = len(pivots)
    pivot_idx = np.array(pivots, dtype=np.intp)
    rhs = R[:, m_base:]
    # A non-zero entry below the pivot rows means that row is not in rowspace(H)
    usable = np.flatnonzero(~rhs[rank:].any(axis=0))
    T = np.zeros((m_ext, m_base), dtype=np.uint8)
    T[:, pivot_idx] = rhs[:rank].T        # row j satisfies T[j] @ H == H_ext[j]
    T_u, y_u = T[usable], y[usable]
    n_used = usable.size

    # 2) Parity checks on the syndrome: a basis P of {p : p @ T_u = 0}
    Rt, pivots_t = _dn_eliminate_gf2(T_u.T, n_used)
    pivot_set_t = set(pivots_t)
    free_t = [c for c in range(n_used) if c not in pivot_set_t]
    pivot_idx_t = np.array(pivots_t, dtype=np.intp)
    P = np.zeros((len(free_t), n_used), dtype=np.uint8)
    for k, col in enumerate(free_t):
        P[k, col] = 1
        P[k, pivot_idx_t] = Rt[:len(pivots_t), col]

    # 3) Smallest set of flips that makes every check pass
    violated = ((P.astype(int) @ y_u.astype(int)) % 2).astype(np.uint8)
    eps = _dn_min_weight_pattern(P, violated)
    corrected = y_u ^ eps
    flips = int(eps.sum())

    # 4) Recover the base syndrome on an independent set of base rows ...
    aug = np.concatenate((T_u, corrected.reshape(-1, 1)), axis=1)
    R2, pivots2 = _dn_eliminate_gf2(aug, m_base)
    s_indep = np.zeros(m_base, dtype=np.uint8)
    s_indep[np.array(pivots2, dtype=np.intp)] = R2[:len(pivots2), -1]
    # ... and extend it to rows of H that depend on that set (column i of the
    # reduced H^T expresses base row i in terms of the pivot rows).
    coeffs = R[:rank, :m_base].astype(int)
    s_clean = ((coeffs.T @ s_indep[pivot_idx].astype(int)) % 2).astype(np.uint8)

    # ===== ESTIMATE NOISE RATE =====
    p_estimate = flips / n_used if n_used > 0 else 0.0

    print(f"Bits flipped by denoising: {flips}/{n_used}")
    print(f"Estimated noise rate: {p_estimate:.4f}")

    return s_clean, p_estimate, r


def decode_stage3(H: np.ndarray, H_ext: np.ndarray, s_ext: np.ndarray):
    """
    Combine denoising + Stage 2 style linear decoding.

    The noisy extended syndrome is first passed to denoise_syndrome; the
    resulting base syndrome is split in half and each half is solved against
    the matching diagonal block of H with solve_gf2.  A half whose system has
    no solution falls back to the all-zero vector.

    H is assumed to be block-diagonal with equally many X-type and Z-type
    rows (mX = mZ = H.shape[0] // 2) and n = H.shape[1] // 2 qubits.

    Returns (e_hat, p_estimate, r): the error estimate [eX | eZ], the noise
    rate reported by denoise_syndrome, and the redundancy count it reports.
    """
    # Denoise
    s_clean, p_est, r = denoise_syndrome(H, H_ext, s_ext)

    # Derive the block sizes from H instead of hard-coding 15 / 30
    m = H.shape[0]
    n = H.shape[1] // 2
    mX = m // 2

    HX = H[:mX, :n]
    HZ = H[mX:, n:]

    # Use only the base syndrome
    s_clean_base = s_clean[:m]
    sX = s_clean_base[:mX]
    sZ = s_clean_base[mX:]

    # Decode: HX·eZ = sX and HZ·eX = sZ
    eZ = solve_gf2(HX, sX)
    eX = solve_gf2(HZ, sZ)

    # solve_gf2 returns None for an inconsistent system
    if eZ is None:
        eZ = np.zeros(n, dtype=np.uint8)
    if eX is None:
        eX = np.zeros(n, dtype=np.uint8)

    e_hat = np.concatenate([eX, eZ])

    return e_hat, p_est, r

# Stage 3 evaluation driver. NOTE: this code is live (not commented out); it
# queries ORACLE_O3 and EVAL_STAGE3_QID over the network and depends on
# H_learned, HX, HZ, res_1 and avg_stage2 from the earlier stages.
print("\n" + "="*60)
print("=== STAGE 3: Decode (noisy, redundant) ===")
print("="*60)

# One extended matrix is designed up front and reused for every test
H_ext, r_used = design_redundant_rows(HX, HZ, r=5)
print(f"Redundancy: r = {r_used}")

num_tests_3 = 3
scores_3 = []

for test_idx in range(num_tests_3):
    # Query oracle: it returns a syndrome plus a query_id that identifies
    # this sample when the estimate is submitted below
    resp = post_json(ORACLE_O3, {"Hext": H_ext.tolist()})
    qid = resp["query_id"]
    s_ext = np.array(resp["syndrome"], dtype=np.uint8)

    print(f"\n  Test {test_idx+1}: Query ID = {qid[:8]}...")

    # Decode
    e_hat, p_est, r = decode_stage3(H_learned, H_ext, s_ext)

    # Evaluate
    res_3 = post_json(EVAL_STAGE3_QID, {
        "H": H_learned.tolist(),
        "e_hat": e_hat.tolist(),
        "query_id": qid,
        "r": r,
        "p_estimate": float(p_est) if p_est is not None else None
    })

    # A missing 'score' key counts as 0
    scores_3.append(res_3.get('score', 0))
    print(f"  Score: {res_3.get('score', 0):.1f}/30")

avg_stage3 = np.mean(scores_3)
print(f"\n✅ Stage 3 Average: {avg_stage3:.1f}/30")

# Total Score: Stage 1 score plus the Stage 2 and Stage 3 averages
# (printed maxima: 40 + 45 + 30 = 115)
total_score = res_1.get('score', 0) + avg_stage2 + avg_stage3
print("\n" + "="*60)
print(f"TOTAL ESTIMATED SCORE: {total_score:.1f}/115")
print("="*60)