In [1]:
from pathlib import Path
import re

from dataclasses import dataclass
from graphlib import TopologicalSorter
from collections import ChainMap
from functools import cache

In [2]:
init, ops_str = Path("data/24.txt").read_text().strip().split("\n\n")
# Initial wire values, one "name: 0|1" line each -> {name: bool}.
l = {
    name: bool(int(value))
    for name, value, *_ in (line.split(": ") for line in init.split("\n"))
}

def swaps(r: str, pairs=(('vmv', 'z07'), ('kfm', 'z20'), ('hnv', 'z28'), ('hth', 'tqr'))):
    """Return the wire that ``r`` is exchanged with, or ``r`` unchanged.

    :param r: output wire name as written in the input file.
    :param pairs: wire pairs whose names are exchanged; each pair applies in
        both directions. Defaults to the four pairs used by this notebook;
        pass ``()`` to leave every wire name as-is.
    :return: the partner wire if ``r`` appears in a pair, else ``r``.
    """
    forward = dict(pairs)
    backward = {b: a for a, b in forward.items()}
    # Forward mapping takes precedence if a name appears on both sides.
    return forward.get(r, backward.get(r, r))

@dataclass(frozen=True)
class OP:
    """A single two-input gate whose output wire ``r`` is ``a <op> b``.

    Frozen (with the default generated ``__eq__``), so instances are hashable
    and can serve as nodes in a ``TopologicalSorter`` graph.
    """

    a: str   # first input wire name
    b: str   # second input wire name
    r: str   # output wire name
    op: str  # operator symbol: "^", "&" or "|"


# One gate per line, e.g. "x00 AND y00 -> z00": groups are (a, operator, b, output).
pattern = re.compile(r"([a-z0-9]+) (AND|XOR|OR) ([a-z0-9]+) -> ([a-z0-9]+)")
# Operator keyword -> symbol stored in OP.op.
pops = dict(XOR="^", OR="|", AND="&")


def p(s: str):
    """Parse one gate line such as ``"x00 AND y00 -> z00"`` into an ``OP``.

    The output wire name is passed through ``swaps`` before the OP is built,
    and the operator keyword is translated via ``pops``.

    :param s: a single line of the gate section of the input.
    :return: the parsed ``OP``.
    :raises ValueError: if ``s`` does not match ``pattern``.
    """
    m = pattern.match(s)
    if m is None:
        # Without this check a malformed line fails later with an opaque
        # AttributeError on None.
        raise ValueError(f"Unparseable gate line: {s!r}")
    a, op_name, b, r = m.groups()
    return OP(a=a, b=b, r=swaps(r), op=pops[op_name])


# Gate table keyed by output wire (names already passed through `swaps` by `p`).
ops = {gate.r: gate for gate in map(p, ops_str.split("\n"))}
# Highest z-wire index; used elsewhere in this notebook as the x/y input width.
nbits = max(int(wire[1:]) for wire in ops if wire.startswith('z'))

In [3]:
def run(l, ops, x=None, y=None):
    """Simulate the gate network and read the z-wires as a binary number.

    :param l: initial wire values (name -> bool). Never mutated; it is copied
        before any x/y override is written.
    :param ops: gate table, output wire name -> OP.
    :param x: if given, overrides x00..x{nbits-1} with the bits of this int.
    :param y: if given, overrides y00..y{nbits-1} with the bits of this int.
    :return: integer whose bit k is the value of wire z{k}.

    Gates with an operator other than "^", "&" or "|" are silently skipped.
    """
    overrides = {"x": x, "y": y}
    inputs = l
    if any(value is not None for value in overrides.values()):
        inputs = l.copy()
    for prefix, number in overrides.items():
        if number is None:
            continue
        for bit in range(nbits):
            inputs[f"{prefix}{bit:02d}"] = bool((number >> bit) & 1)

    # Each gate depends on the gates that drive its input wires (primary inputs have no gate).
    dependencies = {}
    for gate in ops.values():
        dependencies[gate] = {ops[wire] for wire in (gate.a, gate.b) if wire in ops}
    order = list(TopologicalSorter(dependencies).static_order())

    evaluators = {
        "^": lambda lhs, rhs: lhs ^ rhs,
        "&": lambda lhs, rhs: lhs & rhs,
        "|": lambda lhs, rhs: lhs | rhs,
    }
    # Computed wires go into a fresh front layer so `inputs` is left untouched.
    state = ChainMap({}, inputs)
    for gate in order:
        evaluate = evaluators.get(gate.op)
        if evaluate is not None:
            state[gate.r] = evaluate(state[gate.a], state[gate.b])

    z_wires = sorted(
        (wire for wire in state if wire.startswith("z")),
        key=lambda wire: int(wire[1:]),
        reverse=True,
    )
    return int("".join(str(int(state[wire])) for wire in z_wires), 2)

In [4]:
# For each width i, add 0 and then 1 to x = 2**i - 1 (the i low bits set)
# and print the first case whose simulated sum is wrong for that i.
for i in range(1, nbits):
    ones = (1 << i) - 1
    for addend in (0, 1):
        got = run(l, ops, x=ones, y=addend)
        if got != ones + addend:
            print(i, got, ones, addend)
            break

In [5]:
@cache
def deps(s: str):
    """Return the set of x??/y?? input wires that wire ``s`` transitively reads.

    Results are memoised with ``functools.cache`` against the *global* ``ops``;
    call ``deps.cache_clear()`` if ``ops`` is rebuilt. The returned sets are
    the cached objects themselves, so callers must not mutate them.
    """
    if s.startswith(('x', 'y')):
        return {s}
    gate = ops[s]
    return deps(gate.a) | deps(gate.b)

def ideal_sum_deps(i):
    """Input wires that bit ``i`` of x + y can depend on.

    That is x/y bits 0..i, clipped to the ``nbits`` inputs that exist. For
    ``i >= nbits`` this returns every x/y input wire.
    """
    upto = min(i + 1, nbits)
    return {f"{side}{j:02d}" for side in "xy" for j in range(upto)}


def candidate_wires_for_bit(i):
    """Return gate outputs whose input cone is plausible for sum bit ``i``.

    A wire qualifies when every x/y input it depends on lies within
    ``ideal_sum_deps(i)`` and it depends on both ``x{i}`` and ``y{i}``.

    :param i: sum bit index.
    :return: list of wire names from ``ops`` (in ``ops`` iteration order).
    """
    target = ideal_sum_deps(i)
    # Previously `and f"y{i:02d}"` tested a non-empty string (always truthy),
    # so the y_i requirement was never enforced.
    required = {f"x{i:02d}", f"y{i:02d}"}
    candidates = []
    for wire in ops:
        cone = deps(wire)
        if required <= cone <= target:
            candidates.append(wire)
    return candidates

def run_with_intermediates(inputs: dict[str, bool], ops: dict[str, OP]) -> dict[str, bool]:
    """
    Simulate the circuit and return the value of every wire, not just the z-wires.

    :param inputs: wire_name -> bool for the primary inputs (x??, y??, or any
        other wire not produced by a gate). Not mutated; a copy is extended.
    :param ops: wire_name -> OP, the gate that drives each wire.
    :return: wire_name -> bool covering the inputs plus every gate output.
    :raises ValueError: if a gate carries an operator other than "^", "&", "|".
    """
    # A gate's prerequisites are the gates driving its two input wires, if any.
    dependency_graph = {
        gate: {ops[wire] for wire in (gate.a, gate.b) if wire in ops}
        for gate in ops.values()
    }
    evaluation_order = list(TopologicalSorter(dependency_graph).static_order())

    values = dict(inputs)
    for gate in evaluation_order:
        # Both operands are read before the operator is checked.
        lhs, rhs = values[gate.a], values[gate.b]
        if gate.op == '^':
            result = lhs ^ rhs
        elif gate.op == '&':
            result = lhs & rhs
        elif gate.op == '|':
            result = lhs | rhs
        else:
            raise ValueError(f"Unknown op: {gate.op}")
        values[gate.r] = result

    return values


def int_to_x_y_bits(x_val, y_val, nbits):
    """Spread two integers over the x??/y?? input wires, least-significant bit first.

    :return: dict ordered x00, y00, x01, y01, ... with ``nbits`` entries per operand.
    """
    return {
        f"{prefix}{bit:02d}": bool((value >> bit) & 1)
        for bit in range(nbits)
        for prefix, value in (("x", x_val), ("y", y_val))
    }


def test_bit_candidates(ops, wires_to_test, bit_index, nbits=nbits):
    """
    Keep only the candidate wires that act like the real ``bit_index`` sum bit.

    Every test vector is simulated with ``run_with_intermediates``. A wire is
    discarded as soon as its value differs from bit ``bit_index`` of
    ``x_val + y_val`` on any vector.

    :param ops: wire_name -> OP gate table to simulate.
    :param wires_to_test: iterable of candidate wire names.
    :param bit_index: which sum bit the candidates should reproduce.
    :param nbits: width of the x/y inputs.
    :return: list of surviving wires, in their original order.
    """
    candidates = list(wires_to_test)

    test_vectors = [
        (0, 1 << bit_index),
        (1 << bit_index, 0),
        (1 << bit_index, 1 << bit_index),
        ((1 << bit_index) - 1, 1 << bit_index),
        (1 << bit_index, (1 << bit_index) - 1),
        # At bit_index == 0 this is (-1, 1); int_to_x_y_bits encodes -1 as all-ones.
        (2**bit_index - 2, 1),
    ]

    failed = set()
    for x_val, y_val in test_vectors:
        inputs = int_to_x_y_bits(x_val, y_val, nbits=nbits)
        wirevals = run_with_intermediates(inputs, ops)
        correct_bit = bool(((x_val + y_val) >> bit_index) & 1)
        failed.update(w for w in candidates if wirevals[w] != correct_bit)

    # Previously `return` sat inside the loop, so only the first vector was used.
    return [w for w in candidates if w not in failed]

In [12]:
# Every wire named in the swap pairs, sorted and comma-joined.
','.join(sorted(wire for pair in (('vmv', 'z07'), ('kfm', 'z20'), ('hnv', 'z28'), ('hth', 'tqr')) for wire in pair))

'hnv,hth,kfm,tqr,vmv,z07,z20,z28'