Day 17: Chronospatial Computer

https://adventofcode.com/2024/day/17

In [1]:
from pathlib import Path
import sys
from itertools import batched

sys.path.append("..")

from tools import get_input


def parse(text: str) -> tuple[list[int], list[tuple[int, int]]]:
    regs_s, ops_s = text.split("\n\n")
    regs = [int(line.split(":")[1]) for line in regs_s.splitlines()]
    ops = list(batched([int(i) for i in ops_s.split(":")[1].split(",")], 2))
    return regs, ops


# Worked example, read from ``test.txt`` relative to the working directory.
tst = parse(Path("test.txt").read_text(encoding="utf-8"))
# Day 17 input via the shared helper (presumably returns the raw puzzle text;
# see ``tools.get_input`` to confirm).
inp = parse(get_input(17))

In [2]:
def solution_1(regs: list[int], ops: list[tuple[int, int]]) -> int:
    regs = regs.copy()
    p = [0]
    output = []

    def adv(op: int):  # 0
        regs[0] //= 2 ** (op if op < 4 else regs[op - 4])

    def bxl(op: int):  # 1
        regs[1] ^= op

    def bst(op: int):  # 2
        regs[1] = (op if op < 4 else regs[op - 4]) % 8

    def jnz(op: int):  # 3
        if regs[0] != 0:
            p[0] = (op // 2) - 1

    def bxc(_: int):  # 4
        regs[1] ^= regs[2]

    def out(op: int):  # 5
        output.append((op if op < 4 else regs[op - 4]) % 8)

    def bdv(op: int):  # 6
        regs[1] = regs[0] // 2 ** (op if op < 4 else regs[op - 4])

    def cdv(op: int):  # 7
        regs[2] = regs[0] // 2 ** (op if op < 4 else regs[op - 4])

    program = {0: adv, 1: bxl, 2: bst, 3: jnz, 4: bxc, 5: out, 6: bdv, 7: cdv}
    while p[0] < len(ops):
        program[ops[p[0]][0]](ops[p[0]][1])
        p[0] += 1
    return output, regs

In [3]:
# tests from description

# bst with combo operand 6 reads register C: B becomes 9 % 8.
final_regs = solution_1([0, 0, 9], [(2, 6)])[1]
assert final_regs[1] == 1

# out with operands 0, 1 and 4 (register A holds 10).
printed = solution_1([10, 0, 0], [(5, 0), (5, 1), (5, 4)])[0]
assert printed == [0, 1, 2]

# Halve A, print A % 8, repeat until A reaches 0.
printed, final_regs = solution_1([2024, 0, 0], [(0, 1), (5, 4), (3, 0)])
assert printed == [4, 2, 5, 6, 7, 7, 7, 7, 3, 1, 0]
assert final_regs == [0, 0, 0]

# bxl: B ^= 7.
final_regs = solution_1([0, 29, 0], [(1, 7)])[1]
assert final_regs[1] == 26

# bxc: B ^= C.
final_regs = solution_1([0, 2024, 43690], [(4, 0)])[1]
assert final_regs[1] == 44354

In [4]:
# The worked example has to print the sequence given in the puzzle text.
example_output = solution_1(*tst)[0]
assert example_output == [4, 6, 3, 5, 6, 3, 5, 2, 1, 0]

# Part 1 answer: the printed values joined by commas (kept as the cell output).
",".join(map(str, solution_1(*inp)[0]))  # '7,3,5,7,5,7,4,3,0'

'7,3,5,7,5,7,4,3,0'

In [5]:
from heapq import heappush, heappop


# both programs consume THE RIGHTMOST THREE BITS
# of register A each cycle until A = 0
# so the input can be recovered by checking which A at time t produces which output
# start with A in [0, 7] and look for the last number of the (flattened) program
# keep every A in [0, 7] where output[0] == program[-1]
# and extend it to A* = A << 3 | [0, 7] where output[0] == program[-2], and so on
def solution_2(regs: list[int], ops: list[tuple[int, int]]) -> int:
    """Find the smallest register A value that makes the program print itself.

    The search relies on the shape of the programs it is used on: every loop
    iteration prints one value and drops the lowest three bits of A. The last
    printed value then depends only on the top three bits of A, the one before
    it on the top six, and so on. A is therefore built three bits at a time,
    from the most significant digit down, keeping only prefixes whose output
    equals the matching tail of the program.

    Args:
        regs: Initial registers. Only the values after the first (B and C)
            are used; A is replaced by each candidate. The list is not
            modified.
        ops: The program as ``(opcode, operand)`` pairs.

    Returns:
        The smallest A for which the program's output equals the flattened
        program, or ``-1`` if the search runs out of candidates.

    Raises:
        ValueError: If ``ops`` is empty.
    """
    comp = [value for pair in ops for value in pair]
    if not comp:
        raise ValueError("program is empty, there is nothing to reproduce")

    # Heap entries are (p, x): ``x`` already reproduces ``comp[-p + 1:]`` and
    # ``comp[-p]`` is the next number to match. ``p`` runs from
    # ``-(len(comp) - 1)`` up to 0, so shorter prefixes are expanded first
    # and, within one length, the smallest prefix first.
    heap = [(1 - len(comp), 0)]
    while heap:
        p, x = heappop(heap)
        target = comp[-p:]
        for i in range(8):
            nx = (x << 3) | i
            # Run on a fresh register list so the caller's ``regs`` stays
            # intact, and compare the whole output rather than only its first
            # value: a prefix with a zero leading digit prints too few values
            # and must not survive.
            if solution_1([nx, *regs[1:]], ops)[0] == target:
                if p == 0:
                    return nx
                heappush(heap, (p + 1, nx))
    return -1


# Second worked example, read from ``test2.txt`` relative to the working directory.
tst2_text = Path("test2.txt").read_text(encoding="utf-8")
tst2 = parse(tst2_text)
assert solution_2(tst2[0], tst2[1]) == 117440

# Part 2 answer (kept as the cell output).
solution_2(*inp)  # 105734774294938

105734774294938