### Test inputs

In [231]:
# Example graph from the puzzle text, one "<node>: <child> <child> ..." row per line.
inputs = """\
aaa: you hhh
you: bbb ccc
bbb: ddd eee
ccc: ddd eee fff
ddd: ggg
eee: out
fff: out
ggg: out
hhh: ccc fff iii
iii: out""".splitlines()

## Task 1
Build the graph from the input lines and count every path from `you` to `out`.

In [232]:
from collections import defaultdict
from functools import cache

# Version of graph that takes also limiting list as input
class Graph:
    def __init__(self, graph, end_node, node_list=None):
        self.graph = graph
        self.end_node = end_node
        self.node_list = node_list
    
    @cache
    def DFS(self, node):        
        # If node is out exit
        if node == self.end_node:
            return 1

        output = 0
        for child in self.graph[node]:
            if self.node_list is None:
                output += self.DFS(child)
            elif child in self.node_list:
                output += self.DFS(child)
            
        return output

# Parse "<node>: <child> <child> ..." rows into an adjacency list (node -> children).
graph = defaultdict(list)
for row in inputs:
    fields = row.split()  # split() tolerates repeated / trailing whitespace
    if not fields:
        continue  # skip blank rows instead of creating a "" node
    key, *children = fields
    graph[key.strip(':')] = children

# Count every path from "you" to "out"; the last expression is the cell output.
graphh = Graph(graph, end_node="out")
graphh.DFS('you')


5

## Task 2

In [233]:
# Task 2 example graph from the puzzle text (fft is visited before dac here).
# Note: the following cell replaces `inputs` with the contents of the "input" file.
inputs = """\
svr: aaa bbb
aaa: fft
fft: ccc
bbb: tty
tty: ccc
ccc: ddd eee
ddd: hub
hub: fff
eee: dac
dac: fff
fff: ggg hhh
ggg: out
hhh: out""".splitlines()


def find_reach(graph_inv, start):
    """Return the set of nodes from which ``start`` is reachable.

    ``graph_inv`` maps each node to its parents (reversed edges); walking it
    from ``start`` collects every ancestor. ``start`` itself is always included.
    """
    reached = {start}
    frontier = [start]
    while frontier:
        current = frontier.pop()
        # Mark parents when they are discovered so each is expanded only once.
        for parent in graph_inv.get(current, []):
            if parent not in reached:
                reached.add(parent)
                frontier.append(parent)
    return reached

In [234]:

# Task 2: count svr -> out paths that pass through both "fft" and "dac".
with open("input", encoding="utf-8") as f:
    inputs = f.read().splitlines()

# Forward adjacency (node -> children) and reverse adjacency (node -> parents).
graph = defaultdict(list)
graph_inv = defaultdict(list)
for row in inputs:
    fields = row.split()
    if not fields:
        continue  # skip blank rows, e.g. a trailing empty line in the file
    key, *children = fields
    key = key.strip(':')
    graph[key] = children
    for child in children:
        graph_inv[child].append(key)

# For each target, the set of nodes that can still reach it. The DFS only steps
# into children from the target's set, pruning branches that can never arrive.
can_reach_fft = find_reach(graph_inv, "fft")
can_reach_dac = find_reach(graph_inv, "dac")
can_reach_out = find_reach(graph_inv, "out")

# Order svr -> fft -> dac -> out.
num_svr_fft = Graph(graph, "fft", can_reach_fft).DFS("svr")
num_fft_dac = Graph(graph, "dac", can_reach_dac).DFS("fft")
num_dac_out = Graph(graph, "out", can_reach_out).DFS("dac")

# Order svr -> dac -> fft -> out. The memoised DFS assumes an acyclic graph,
# and in a DAG fft->dac and dac->fft paths cannot both exist, so at most one of
# the two products is non-zero; summing them handles either waypoint order.
num_svr_dac = Graph(graph, "dac", can_reach_dac).DFS("svr")
num_dac_fft = Graph(graph, "fft", can_reach_fft).DFS("dac")
num_fft_out = Graph(graph, "out", can_reach_out).DFS("fft")

print(num_svr_fft * num_fft_dac * num_dac_out
      + num_svr_dac * num_dac_fft * num_fft_out)

520476725037672


Unfortunately, the reachability node lists do not speed things up much when the DFS is memoised. Without the cache, however, they are what makes the problem feasible at all, as the cells below show.

In [244]:
class Graph:
    """Uncached path counter that tracks the current path in a ``visited`` set.

    Redefines the ``Graph`` class from the Task 1 cell; cells executed after
    this one use this version. Without memoisation it enumerates paths one by
    one, so the ``node_list`` restriction is what keeps the search tractable.
    """

    def __init__(self, graph, end_node, node_list=None):
        """
        Args:
            graph: mapping node -> iterable of child nodes. Nodes missing from
                the mapping are treated as having no children.
            end_node: node at which a path is counted (contributes 1).
            node_list: optional container of nodes the DFS may step into;
                ``None`` means no restriction (consistent with the cached class).
        """
        self.graph = graph
        self.end_node = end_node
        self.node_list = node_list

    def DFS(self, node, visited=None):
        """Return the number of simple paths from ``node`` to ``self.end_node``.

        Args:
            node: current node.
            visited: nodes on the current path; a fresh set is used when None.
                The set is restored to its prior contents when this returns
                (``node`` itself is removed on the way out).
        """
        # If node is the end node, this is exactly one complete path.
        if node == self.end_node:
            return 1
        if visited is None:
            visited = set()

        visited.add(node)
        try:
            output = 0
            # .get() so a defaultdict passed in is not grown with leaf nodes.
            for child in self.graph.get(node, ()):
                if child in visited:
                    continue  # never revisit a node already on this path
                if self.node_list is None or child in self.node_list:
                    output += self.DFS(child, visited)
        finally:
            # Backtrack even if recursion raises, so the shared set stays clean.
            visited.remove(node)
        return output

In [246]:
# Same Task 2 answer using the uncached, visited-set DFS.
reach = {"fft": can_reach_fft, "dac": can_reach_dac, "out": can_reach_out}


def count_paths_uncached(start, end):
    """Number of start -> end paths, only stepping into nodes that can reach ``end``."""
    return Graph(graph, end, reach[end]).DFS(start, set())


def count_via(first, second):
    """Number of svr -> first -> second -> out paths.

    The middle segment is counted first: if ``first`` cannot reach ``second``,
    none of its children are in ``reach[second]`` and this returns 0 before
    the expensive outer segments are enumerated.
    """
    middle = count_paths_uncached(first, second)
    if middle == 0:
        return 0
    return count_paths_uncached("svr", first) * middle * count_paths_uncached(second, "out")


# The waypoints may appear in either order; at most one order has paths in a DAG.
print(count_via("fft", "dac") + count_via("dac", "fft"))


520476725037672
