In [1]:
import itertools
import math
from collections import defaultdict
from pathlib import Path

import numpy as np


In [2]:
# Puzzle input as a list of lines (line terminators stripped). The encoding is
# explicit so the result does not depend on the platform's default locale.
data = Path("../08.in").read_text(encoding="utf-8").splitlines()


In [3]:
# Worked example from the puzzle description: instruction line, blank line,
# then one "NODE = (LEFT, RIGHT)" entry per line.
testdata = """RL

AAA = (BBB, CCC)
BBB = (DDD, EEE)
CCC = (ZZZ, GGG)
DDD = (DDD, DDD)
EEE = (EEE, EEE)
GGG = (GGG, GGG)
ZZZ = (ZZZ, ZZZ)
""".splitlines()


## Part I

In [4]:
def parse(data):
    """Parse the puzzle input lines into instructions and a node map.

    Args:
        data: List of input lines. Line 0 is the L/R instruction string,
            line 1 is blank, and every following non-blank line has the form
            ``"NODE = (LEFT, RIGHT)"``.

    Returns:
        A tuple ``(instructions, mapping)``:
        - ``instructions``: list of ints, 1 for ``"R"`` and 0 for any other
          character, usable as an index into the ``(left, right)`` tuples.
        - ``mapping``: dict from node name to its ``(left, right)`` successors.
    """
    # Strip so stray whitespace is not silently turned into "left" moves.
    instructions = [int(char == "R") for char in data[0].strip()]

    mapping = {}
    for line in data[2:]:
        if not line.strip():
            # Skip blank lines; otherwise they would create a bogus "" node.
            continue
        # Split on the delimiters rather than fixed columns, so node names of
        # any length are supported.
        pos, _, targets = line.partition("=")
        left, _, right = targets.strip().strip("()").partition(",")
        mapping[pos.strip()] = (left.strip(), right.strip())

    return instructions, mapping


In [5]:
def solve1(instructions, mapping, start="AAA", target="ZZZ", max_steps=None):
    """Count the steps needed to walk from ``start`` to ``target``.

    The instructions are repeated cyclically. At least one step is always
    taken, so ``start == target`` only counts as reached after returning to it.

    Args:
        instructions: Direction indices (0 = left, 1 = right) into the
            ``(left, right)`` tuples of ``mapping``.
        mapping: Node name -> ``(left, right)`` successor names.
        start: Node to start from.
        target: Node whose first visit ends the walk.
        max_steps: Optional upper bound on the number of steps. ``None``
            (the default) means unbounded, as before.

    Returns:
        The 1-based step count at which ``target`` is first reached.

    Raises:
        ValueError: If ``instructions`` is empty (no move can ever be made), or
            if ``target`` is not reached within ``max_steps`` steps.
    """
    if not instructions:
        raise ValueError("instructions must not be empty")

    pos = start
    for step, direction in enumerate(itertools.cycle(instructions), start=1):
        if max_steps is not None and step > max_steps:
            raise ValueError(
                f"{target!r} not reached from {start!r} within {max_steps} steps"
            )
        pos = mapping[pos][direction]
        if pos == target:
            return step


In [6]:
# Part 1 on the real input: walk from the default start to the default target.
instructions, mapping = parse(data)
steps = solve1(instructions=instructions, mapping=mapping)
print(f"Part 1: {steps}")


Part 1: 17141


## Part II

- Setup
  - Nodes ending in `A` are now starting nodes
  - All of them are processed simultaneously
  - The end is reached when all (not just some) of them are on a node ending in `Z`
- Observations
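  - Apparently there is a 1:1 mapping between starting (`A`) nodes and end (`Z`) nodes
  - Each walk also seems to be cyclic: it hits its `Z` node every $p$ steps, and the first hit happens after exactly $p$ steps (no offset)
  - Hence we only need the step count of the first hit for each starting node and can take the LCM of those counts
    - LCM = Least Common Multiple (see `np.lcm` or `math.lcm`)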

In [7]:
# Worked example for part 2: two start nodes (11A, 22A) that reach their
# Z nodes on different cycles.
testdata2 = """LR

11A = (11B, XXX)
11B = (XXX, 11Z)
11Z = (11B, XXX)
22A = (22B, XXX)
22B = (22C, 22C)
22C = (22Z, 22Z)
22Z = (22B, 22B)
XXX = (XXX, XXX)
""".splitlines()


In [8]:
def _z_hits(instructions, mapping, start, max_steps):
    """Return the 1-based steps (<= max_steps) at which the walk from ``start`` is on a ``..Z`` node."""
    hits = []
    node = start
    # range comes first in zip so the walk stops after exactly max_steps steps
    # (and immediately if instructions is empty).
    for step, direction in zip(range(1, max_steps + 1), itertools.cycle(instructions)):
        node = mapping[node][direction]
        if node.endswith("Z"):
            hits.append(step)
    return hits


def _check_period(node, hits):
    """Return the period if ``hits`` are exactly the multiples of ``hits[0]``; otherwise raise ValueError."""
    if len(hits) < 2:
        raise ValueError(
            f"{node}: need at least two Z hits to detect a cycle, got {hits}; "
            "increase max_steps"
        )
    period = hits[0]
    # The LCM shortcut is only valid if the hits are evenly spaced AND the
    # first hit equals that spacing (no offset).
    if any(later - earlier != period for earlier, later in zip(hits, hits[1:])):
        raise ValueError(
            f"{node}: Z hits {hits} are not multiples of a single period; "
            "the LCM shortcut does not apply"
        )
    return period


def solve2(instructions, mapping, max_steps=100_000):
    """Count the steps until all ``..A`` start nodes are simultaneously on ``..Z`` nodes.

    Instead of walking all start nodes in lock-step, each one is simulated on
    its own for ``max_steps`` steps. Every step at which it sits on a node
    ending in ``Z`` is recorded, and the hits are printed as ``NODE -> [...]``.
    If each start node's hits are exactly the multiples of its first hit, the
    answer is the least common multiple of those first hits.

    Args:
        instructions: Direction indices (0 = left, 1 = right) into the
            ``(left, right)`` tuples of ``mapping``; repeated cyclically.
        mapping: Node name -> ``(left, right)`` successor names.
        max_steps: Steps simulated per start node. Must be large enough to
            observe at least two ``Z`` hits for every start node.

    Returns:
        The least common multiple of the per-start-node first-hit step counts.

    Raises:
        ValueError: If the mapping has no start node, or a start node's hits are
            not the multiples of a single period (fewer than two hits, uneven
            spacing, or an offset first hit). In those cases the LCM would be
            wrong.
    """
    start_nodes = [node for node in mapping if node.endswith("A")]
    if not start_nodes:
        raise ValueError("no start nodes (names ending in 'A') in mapping")

    periods = []
    for node in start_nodes:
        hits = _z_hits(instructions, mapping, node, max_steps)
        print(f"{node} -> {hits}")
        periods.append(_check_period(node, hits))

    # math.lcm uses arbitrary-precision ints, so unlike np.lcm on int64 it
    # cannot silently overflow for large inputs.
    return math.lcm(*periods)


In [9]:
# Sanity check on the puzzle's worked example, whose expected answer is 6.
instructions, mapping = parse(testdata2)
example_answer = solve2(instructions, mapping, max_steps=10)
assert example_answer == 6, example_answer
example_answer


11A -> [2, 4, 6, 8, 10]
22A -> [3, 6, 9]


6

In [10]:
# Part 2 on the real input; 100_000 steps per start node yields several Z hits each.
instructions, mapping = parse(data)
solve2(instructions=instructions, mapping=mapping, max_steps=100_000)


AAA -> [17141, 34282, 51423, 68564, 85705]
XQA -> [16579, 33158, 49737, 66316, 82895, 99474]
SKA -> [18827, 37654, 56481, 75308, 94135]
NQA -> [12083, 24166, 36249, 48332, 60415, 72498, 84581, 96664]
LJA -> [13207, 26414, 39621, 52828, 66035, 79242, 92449]
NVA -> [22199, 44398, 66597, 88796]


10818234074807

I wonder how long the naive version would have taken to run :D

$10{,}818{,}234{,}074{,}807$ iterations (roughly $10^{13}$) would be no fun, I suppose.