In [1]:
from functools import cache
from collections import defaultdict

import networkx as nx


def parse_input(path):
    """Read a directed graph from a text file.

    Each non-blank line has the form ``"src: dst1 dst2 ..."``. A line may list
    no destinations at all (``"src:"``); blank lines are ignored.

    Args:
        path: Path of the file to read.

    Returns:
        A ``(graph, graph_rev)`` pair of ``defaultdict(set)`` objects:
        ``graph[src]`` is the set of destinations listed for ``src`` and
        ``graph_rev[dst]`` is the set of sources that list ``dst``.

    Raises:
        ValueError: If a non-blank line contains no ``":"`` separator.
    """
    graph = defaultdict(set)
    graph_rev = defaultdict(set)

    with open(path) as f:
        for line_no, raw in enumerate(f, start=1):
            text = raw.strip()
            if not text:
                # Tolerate blank lines (e.g. an extra newline at end of file).
                continue
            # partition (rather than split(": ")) also accepts "src:" with no
            # destinations, which strip() would leave without a trailing space.
            src, sep, dsts = text.partition(":")
            if not sep:
                raise ValueError(
                    f"{path}:{line_no}: expected 'src: dst ...', got {text!r}"
                )
            src = src.strip()
            for dst in dsts.split():
                graph[src].add(dst)
                graph_rev[dst].add(src)

    return graph, graph_rev


def count_paths(graph, start_node, end_node):
    """Count the distinct directed paths from ``start_node`` to ``end_node``.

    Uses a memoised depth-first search driven by an explicit stack, so deep
    graphs do not hit Python's recursion limit. A path ends as soon as it
    reaches ``end_node``; edges leaving ``end_node`` are never followed.
    Repeated entries in a successor collection are counted as separate edges.

    Args:
        graph: Mapping from node to an iterable of successor nodes. Missing
            nodes are treated as having no successors; lookups use ``.get`` so
            a ``defaultdict`` is not mutated.
        start_node: Node every path starts at.
        end_node: Node every path ends at.

    Returns:
        The number of paths: 1 when ``start_node == end_node``, 0 when
        ``end_node`` is unreachable.

    Raises:
        ValueError: If a cycle is reachable from ``start_node`` without
            passing through ``end_node`` (the graph is assumed acyclic).
    """
    if start_node == end_node:
        return 1

    # counts[node] = number of paths node -> end_node, filled in post-order.
    counts = {end_node: 1}
    in_progress = {start_node}  # nodes currently on the DFS stack
    stack = [(start_node, iter(graph.get(start_node, ())))]

    while stack:
        node, successors = stack[-1]
        for nxt in successors:
            if nxt in counts:
                continue
            if nxt in in_progress:
                raise ValueError(
                    f"cycle detected through node {nxt!r}; path count is unbounded"
                )
            # Descend into nxt; resume this node's iterator once nxt is done.
            in_progress.add(nxt)
            stack.append((nxt, iter(graph.get(nxt, ()))))
            break
        else:
            # All successors are resolved, so this node's count is final.
            stack.pop()
            in_progress.discard(node)
            counts[node] = sum(counts[nxt] for nxt in graph.get(node, ()))

    return counts[start_node]


def main(input_file, part):
    """Solve one part of the puzzle for the graph in ``input_file``.

    Paths are counted on the reversed graph starting from ``"out"``: each path
    ``X -> ... -> out`` in the original graph corresponds to exactly one path
    ``out -> ... -> X`` in the reversed graph, so the counts are equal.

    Part 1 counts the paths from ``"you"`` to ``"out"``. Part 2 counts the
    paths from ``"svr"`` to ``"out"`` that pass through both ``"dac"`` and
    ``"fft"``, in whichever order they occur.

    Args:
        input_file: Path to the puzzle input, in the format read by
            ``parse_input``.
        part: 1 or 2.

    Returns:
        The number of qualifying paths.

    Raises:
        ValueError: If ``part`` is neither 1 nor 2.
    """
    graph, graph_reversed = parse_input(input_file)

    def count_via(first, second):
        # Paths svr -> first -> second -> out; the three segments are
        # independent, so their counts multiply.
        return (
            count_paths(graph_reversed, start_node="out", end_node=second)
            * count_paths(graph_reversed, start_node=second, end_node=first)
            * count_paths(graph_reversed, start_node=first, end_node="svr")
        )

    if part == 1:
        return count_paths(graph_reversed, start_node="out", end_node="you")
    if part == 2:
        # The waypoints may appear in either order. In a DAG at most one order
        # can have paths (both would imply a cycle), so summing is safe and
        # avoids assuming that fft always precedes dac.
        return count_via("fft", "dac") + count_via("dac", "fft")
    raise ValueError(f"part must be 1 or 2, got {part!r}")

In [2]:
# Part 1: verify against the worked example, then solve the real input.
part1_example = main("example1.txt", part=1)
assert part1_example == 5
main("input.txt", part=1)

523

In [3]:
# Part 2: verify against the worked example, then solve the real input.
part2_example = main("example2.txt", part=2)
assert part2_example == 2
main("input.txt", part=2)

517315308154944