In [None]:
import math
from collections import deque
from functools import cache

In [None]:
# Pick which puzzle input to load (leave exactly one assignment uncommented).
filename = "sample.txt"
# filename = "sample2.txt"
# filename = "input.txt"

with open(filename, mode="r", encoding="utf-8") as f:
    data = f.read()

# Trim surrounding whitespace (e.g. the trailing newline) before splitting into rows.
lines = data.strip().split("\n")

In [None]:
# Parse each "name: out1 out2 ..." row into an adjacency map {device: [outputs, ...]}.
# A row without exactly one ":" fails to unpack and raises ValueError.
devices = {
    name: targets.split()
    for name, targets in (row.split(":") for row in lines)
}
# devices  # uncomment to inspect the parsed graph

In [None]:
# Exploration: Visualise the graph. Are there any obvious cycles?
# -> Looks like the connections originate from svr, and "you" is actually quite near "out"
def to_graphviz(devices: dict[str, list[str]]) -> str:
    """Render the device graph as Graphviz DOT edge statements.

    Each entry becomes one ``src -> {dst1 dst2 ...}`` statement, and statements
    are joined with ";\\n". No ``digraph { ... }`` header is emitted, so wrap the
    result in one to get a complete DOT document before pasting it into a viewer
    such as https://dreampuf.github.io/GraphvizOnline/.

    Args:
        devices: adjacency map {node: [child, ...]}.

    Returns:
        The edge statements as a single string ("" for an empty map).
    """
    statements = []
    for src, dsts in devices.items():
        targets = " ".join(dsts)
        statements.append(f"{src} -> {{{targets}}}")
    return ";\n".join(statements)

print(to_graphviz(devices=devices))  # DOT edges for pasting into a Graphviz viewer

In [None]:
## Part 1
# How many different paths lead from you to out?
# For the real input we probably need to count the number of ways to reach each intermediate node, rather than enumerating every full path

# Option A: track each node's parents ({node: n_parents}) and propagate path counts forward?
# Option B: a cached recursive DFS that returns the number of paths from each node to the end

In [None]:
def n_paths(devices: dict[str, list[str]], start: str, end: str) -> int:
    """Count every distinct path from `start` to `end` by exhaustive BFS.

    Each frontier entry stands for one partial path (represented by its last
    node), so the work grows with the number of paths rather than with the size
    of the graph. That is fine for small queries but impractical when path
    counts are large. A path is not extended once it reaches `end`.

    Nodes that have no entry in `devices` (presumably terminal nodes such as
    "out") are treated as dead ends. The graph must be acyclic: a cycle
    reachable from `start` makes the frontier grow forever.

    Args:
        devices: adjacency map {node: [child, ...]}.
        start: node to start from.
        end: node whose arrivals are counted.

    Returns:
        Number of distinct start -> end paths (0 if `end` is unreachable or
        `start` is not in the graph).
    """
    paths = 0
    frontier = deque([start])
    while frontier:
        current = frontier.popleft()
        if current == end:
            paths += 1
            continue
        # .get: nodes without outgoing edges are absent from `devices`; indexing
        # directly raised KeyError as soon as a path reached one of them.
        frontier.extend(devices.get(current, []))
    return paths

# Part 1 answer: number of distinct you -> out paths.
result = n_paths(devices, start="you", end="out")
result

In [None]:
## Part 2
# Number of paths from svr -> out including fft and dac?
# Looking at the graph, fft should come before dac
# So result = (svr -> fft) * (fft -> dac) * (dac -> out)
# Exhaustive path enumeration (as in Part 1) is too slow for svr -> fft and fft -> dac, so we need a better algorithm

In [None]:
# cached+recursive approach
def cached_paths(
    current: str, end: str, graph: "dict[str, list[str]] | None" = None
) -> int:
    """Count the distinct paths from `current` to `end` with a memoised DFS.

    Each node's path count to `end` is computed at most once per call. This
    makes a call linear in the size of the subgraph reachable from `current`,
    instead of in the number of paths.

    The memo is created fresh on every call, so results always reflect the graph
    as it is at call time. A module-level ``@cache`` keyed only on
    ``(current, end)`` kept returning counts from a previously loaded input after
    `devices` was rebuilt (e.g. after switching `filename`).

    Nodes with no entry in the graph contribute 0 paths unless they are `end`.
    The graph must be acyclic: a cycle reachable from `current` recurses until
    RecursionError. Recursion depth also grows with path length, so it is bounded
    by Python's recursion limit.

    Args:
        current: node to count paths from.
        end: target node.
        graph: adjacency map {node: [child, ...]}; defaults to the notebook's
            global `devices` (looked up at call time).

    Returns:
        Number of distinct current -> end paths.
    """
    if graph is None:
        graph = devices

    @cache
    def count_from(node: str) -> int:
        if node == end:
            return 1
        return sum(count_from(child) for child in graph.get(node, []))

    return count_from(current)

In [None]:
# Every counted path must visit both fft and dac. Assuming the graph is acyclic
# (the recursive counting already requires it), at most one visiting order can have
# any paths, since paths both ways would form a cycle. Summing both orders therefore
# gives the answer without relying on having spotted the order by eye.
p2_path_nodes = ["svr", "fft", "dac", "out"]
p2_alt_path_nodes = ["svr", "dac", "fft", "out"]
p2_segments = [cached_paths(a, b) for a, b in zip(p2_path_nodes, p2_path_nodes[1:])]
p2_alt_segments = [
    cached_paths(a, b) for a, b in zip(p2_alt_path_nodes, p2_alt_path_nodes[1:])
]
p2_result = math.prod(p2_segments) + math.prod(p2_alt_segments)
p2_result