In [1]:
from pathlib import Path
import re
from dataclasses import dataclass, field
from itertools import zip_longest

In [2]:
# The input file is expected to hold the register lines, a blank line, and
# then the program line; every line has the form "<label>: <value>".
register_block, program_line = Path("data/17.txt").read_text().strip().split("\n\n")

# Comma-separated integers that follow the label on the program line.
program = list(map(int, program_line.split(": ")[1].split(",")))

# Register lines are taken positionally as A, B, C (the labels are not inspected).
registers = {
    name: int(line.split(": ")[1])
    for name, line in zip("ABC", register_block.split("\n"))
}

# Opcode number -> mnemonic; the position in the list is the opcode.
instructions = dict(
    enumerate(["adv", "bxl", "bst", "jnz", "bxc", "out", "bdv", "cdv"])
)

In [3]:
@dataclass
class ProgramState:
    A: int
    B: int
    C: int

    def combo(self, operand: int):
        if operand <= 3:
            return operand
        elif operand == 4:
            return self.A
        elif operand == 5:
            return self.B
        elif operand == 6:
            return self.C
        raise ValueError(operand)

    def adv(self, operand: int):
        self.A = self.A // 2 ** self.combo(operand)

    def bxl(self, operand: int):
        self.B = self.B ^ operand

    def bst(self, operand: int):
        self.B = self.combo(operand) % 8

    def jnz(self, operand: int) -> int | None:
        if self.A == 0:
            return
        return operand

    def bxc(self, _: int):
        self.B = self.B ^ self.C

    def out(self, operand: int):
        return self.combo(operand) % 8

    def bdv(self, operand: int):
        self.B = self.A // 2 ** self.combo(operand)

    def cdv(self, operand: int):
        self.C = self.A // 2 ** self.combo(operand)


def run(A=None):
    """Execute the module-level ``program`` and yield each output value.

    This is a generator: nothing runs until it is iterated.

    Args:
        A: Initial value for register A. When ``None`` (the default) the
            value from the module-level ``registers`` is used. Registers B
            and C always start from ``registers``.

    Yields:
        The result of every ``out`` instruction, in execution order.
    """
    # Compare against None explicitly: `A or registers["A"]` would treat an
    # explicit A=0 as "not given" and silently run with the file's value.
    initial_a = registers["A"] if A is None else A
    state = ProgramState(**(registers | {"A": initial_a}))

    ip = 0  # instruction pointer: index of the current opcode in `program`

    while ip < len(program):
        name = instructions[program[ip]]
        # Each mnemonic is also the name of the ProgramState method implementing it.
        result = getattr(state, name)(program[ip + 1])

        if name == "out":
            yield result
            ip += 2
        elif result is not None:
            # Only a taken jump returns a value: the new instruction pointer
            # (which may legitimately be 0, hence the `is not None` test).
            ip = result
        else:
            ip += 2

In [4]:
# Run the program with the registers from the input file and show the
# outputs as one comma-separated line.
print(*run(), sep=",")

1,6,7,4,3,0,5,0,6


In [5]:
def check_iter_equal_to(it, expected):
    """Return True when ``it`` yields exactly the items of ``expected``.

    Comparison stops at the first mismatch, so ``it`` is consumed only as far
    as needed (plus one look-ahead item to detect surplus output).

    Args:
        it: Iterable/iterator producing the actual values.
        expected: Iterable of the values that should be produced, in order.

    Returns:
        ``True`` if both have the same length and pairwise-equal items,
        ``False`` otherwise.
    """
    # A private sentinel marks "ran out of items". Using None for this would
    # confuse an exhausted iterator with one that genuinely yields None.
    missing = object()
    return all(
        got is not missing and want is not missing and got == want
        for got, want in zip_longest(it, expected, fillvalue=missing)
    )


def bt(program, A, done):
    """Backtracking search for values of A whose run output matches ``program``.

    ``A`` is extended by one base-8 digit per level. A digit is kept only if
    running with the extended value produces exactly the last ``done + 1``
    entries of ``program``; the search then recurses with ``done + 1``.

    Args:
        program: Sequence of values the output has to reproduce.
        A: Prefix built so far (0 at the start of the search).
        done: Number of trailing entries of ``program`` already matched.

    Yields:
        Every value of A for which all ``len(program)`` levels matched.
    """
    if done == len(program):
        yield A
        return

    # The target does not depend on the digit being tried, so slice it once.
    wanted_tail = program[-(done + 1) :]
    for digit in range(8):
        candidate = A * 8 + digit
        if check_iter_equal_to(run(candidate), wanted_tail):
            yield from bt(program, candidate, done + 1)


# Smallest starting value of A found by the search, i.e. one for which the
# run output matched every trailing slice of `program` up to the whole list.
min(bt(program, A=0, done=0))

216148338630253