In [1]:
# Puzzle coordinates passed to get_data / submit below.
year = 2024
day = 17

In [2]:

from aocd import get_data, submit

data = get_data(year=year, day=day).strip()
register_text, program_text = data.split("\n\n")

# Register lines have the form "<label>: <int>"; keep the integer after the last colon.
starting_registers = [int(line.split(":")[-1]) for line in register_text.split("\n")]

# The program line ends with a comma-separated list of integers after its last space.
program = [int(token) for token in program_text.split(" ")[-1].strip().split(",")]
starting_registers

[55593699, 0, 0]

In [3]:
# Indices of the three registers inside the register list used by the interpreter.
A = 0
B = 1
C = 2

# Opcode -> mnemonic; used to label each line of the debug trace.
instructions = {0: "ADV", 1: "BXL", 2: "BST", 3: "JNZ", 4: "BXC", 5: "OUT", 6: "BDV", 7: "CDV"}


def _reserved_combo_operand(reg):
    """Reject combo operand 7, which has no defined value.

    Returning a sentinel here (previously -1) silently made ``2 ** op`` equal 0.5,
    so the division instructions produced floats and OUT emitted ``-1 % 8 == 7``.
    Failing loudly surfaces an invalid program instead of computing garbage.
    """
    raise ValueError("combo operand 7 is reserved and cannot be evaluated")


# Combo operand -> function of the register list returning the operand's value:
# 0-3 are the literal values themselves, 4-6 read registers A, B and C.
combo_op = {
    0: lambda reg: 0,
    1: lambda reg: 1,
    2: lambda reg: 2,
    3: lambda reg: 3,
    4: lambda reg: reg[A],
    5: lambda reg: reg[B],
    6: lambda reg: reg[C],
    7: _reserved_combo_operand,
}

In [4]:
def single_cycle(pc, reg, program, print_debug=True):
    """Execute the one instruction located at ``program[pc]``.

    Args:
        pc: index of the opcode; its operand is read from ``program[pc + 1]``.
        reg: register list indexed with ``A``/``B``/``C``; updated in place.
        program: flat list of opcode/operand integers.
        print_debug: if true, print a fixed-width trace line for this step
            (followed by a blank line after a taken jump).

    Returns:
        ``(next_pc, reg, out)`` where ``out`` is the value emitted by OUT, or
        ``-1`` when the instruction produced no output.
    """
    opcode, operand = program[pc], program[pc + 1]
    assert opcode < 8
    assert operand < 8

    value = None    # operand value actually used (shown in the trace)
    result = None   # value written to a register, if any
    emitted = None  # value produced by OUT, if any
    jumped = False

    if opcode in (0, 6, 7):
        # ADV / BDV / CDV perform the same division and differ only in the
        # register that receives the quotient.
        value = combo_op[operand](reg)
        result = reg[A] // (2 ** value)
        reg[{0: A, 6: B, 7: C}[opcode]] = result
    elif opcode == 1:  # BXL: B ^= literal operand
        value = int(operand)
        result = reg[B] ^ value
        reg[B] = result
    elif opcode == 2:  # BST: B = combo operand % 8
        value = combo_op[operand](reg)
        result = value % 8
        reg[B] = result
    elif opcode == 3:  # JNZ: jump to the literal operand when A is non-zero
        value = int(operand)
        if reg[A] != 0:
            pc = value
            jumped = True
    elif opcode == 4:  # BXC: B ^= C (operand is read but not used)
        value = operand
        result = reg[B] ^ reg[C]
        reg[B] = result
    elif opcode == 5:  # OUT: emit combo operand % 8
        value = combo_op[operand](reg)
        emitted = value % 8

    mnemonic = instructions[opcode]
    # -1 stands in for "nothing written" / "nothing emitted" in the trace and return value.
    result = -1 if result is None else result
    emitted = -1 if emitted is None else emitted
    assert value is not None
    if print_debug:
        print(
            f"{mnemonic}: {value:14d} -> {result:14d} ({emitted:14d}) REG A: {reg[A]:14d}, REG B: {reg[B]:14d}, REG C: {reg[C]:14d}")
        if jumped:
            print()
    next_pc = pc if jumped else pc + 2
    return next_pc, reg, emitted


def run_program(reg_A, print_debug=True, prog=None):
    """Run a program to completion and collect everything it outputs.

    The registers start as a copy of ``starting_registers`` (which is never
    mutated), with register A overridden by ``reg_A``.

    Args:
        reg_A: initial value for register A.
        print_debug: forwarded to ``single_cycle`` to print a per-instruction trace.
        prog: program to execute; defaults to the parsed puzzle ``program``.
            It is copied, so the caller's list is never touched.

    Returns:
        List of values emitted by OUT instructions, in execution order.
    """
    reg = starting_registers.copy()
    reg[A] = reg_A
    prog = list(program if prog is None else prog)
    outputs = []
    pc = 0
    # Halt as soon as the opcode *or* its operand would lie past the end of the
    # program; checking only ``pc`` would let a trailing opcode read
    # ``prog[pc + 1]`` inside single_cycle and raise IndexError.
    while pc + 1 < len(prog):
        pc, reg, out = single_cycle(pc, reg, prog, print_debug)
        if out != -1:
            outputs.append(out)
    return outputs

In [5]:
# Part 1: run the program on the puzzle's own register A (with the debug trace)
# and submit its outputs joined by commas.
output = run_program(starting_registers[A])
ans1 = ",".join(str(value) for value in output)
submit(ans1, part="a", year=year, day=day)

BST:       55593699 ->              3 (            -1) REG A:       55593699, REG B:              3, REG C:              0
BXL:              3 ->              0 (            -1) REG A:       55593699, REG B:              0, REG C:              0
CDV:              0 ->       55593699 (            -1) REG A:       55593699, REG B:              0, REG C:       55593699
ADV:              3 ->        6949212 (            -1) REG A:        6949212, REG B:              0, REG C:       55593699
BXL:              5 ->              5 (            -1) REG A:        6949212, REG B:              5, REG C:       55593699
BXC:              4 ->       55593702 (            -1) REG A:        6949212, REG B:       55593702, REG C:       55593699
OUT:       55593702 ->             -1 (             6) REG A:        6949212, REG B:       55593702, REG C:       55593699
JNZ:              0 ->             -1 (            -1) REG A:        6949212, REG B:       55593702, REG C:       55593699

BST:        694

In [6]:
# Part 2: find register-A values for which the program outputs exactly itself.
# Each iteration fixes 3 more low bits of A, so the search grows the matched output
# prefix one value at a time instead of brute-forcing all bits of A at once.
sols = set()
# start with every combination of the lowest 3 bits
possible_bits = list(range(8))
# one iteration per program value, each time adding 3 bits to the input
for N in range(len(program)):
    new_possible_bits = set()
    # keep the low 3 * (N + 1) bits of a match: the part that was (partially)
    # responsible for the first N outputs (loop-invariant, so computed once per N)
    mask = (1 << (3 * (N + 1))) - 1
    # loop over all bit prefixes that survived so far
    for offset in possible_bits:
        # try 6 extra bits above the known prefix
        # NOTE: for N == 0 the offset and i overlap (plain addition), which only
        # yields duplicate candidates
        for i in range(1 << 6):
            val = offset + (i << (3 * N))
            res = run_program(val, print_debug=False)
            # the first N outputs must already match the program
            if len(res) >= N and res[:N] == program[:N]:
                new_possible_bits.add(val & mask)
            # exact match: this value is a full solution
            if res == program:
                sols.add(val)
    possible_bits = list(new_possible_bits)

In [7]:
# The answer is the smallest register-A value found that reproduces the program;
# fail with a clear message instead of min()'s bare "empty sequence" error.
assert sols, "search found no register-A value that reproduces the program"
ans2 = min(sols)

In [8]:
# Display the parsed program so it can be compared with the output of run_program(ans2) below.
program

[2, 4, 1, 3, 7, 5, 0, 3, 1, 5, 4, 4, 5, 5, 3, 0]

In [9]:
# Sanity check: the part-2 answer must make the program output exactly itself.
quine_output = run_program(ans2, print_debug=False)
assert quine_output == program, "ans2 does not reproduce the program"
quine_output

[2, 4, 1, 3, 7, 5, 0, 3, 1, 5, 4, 4, 5, 5, 3, 0]

In [10]:
# Submit the smallest self-reproducing register-A value as the part-2 answer.
submit(ans2, year=year, day=day, part="b")

Part b already solved with same answer: 236539226447469
