## Advent of Code 2021, Day 12

According to the classifier, this puzzle might be a good example of depth-first search.

#### Input Processing

In [113]:
from collections import defaultdict
from copy import deepcopy
from dataclasses import dataclass, field
from typing import Tuple, Optional

In [30]:
# Smallest example graph, one "a-b" edge per entry.
# Expected answers: 10 paths for part 1, 36 for part 2.
example1 = [
    "start-A",
    "start-b",
    "A-c",
    "A-b",
    "b-d",
    "A-end",
    "b-end",
]

In [31]:
# Medium example graph, one "a-b" edge per entry.
# Expected answers: 19 paths for part 1, 103 for part 2.
example2 = [
    "dc-end", "HN-start", "start-kj", "dc-start", "dc-HN",
    "LN-dc", "HN-end", "kj-sa", "kj-HN", "kj-dc",
]

In [32]:
# Largest example graph, one "a-b" edge per entry.
# Expected answers: 226 paths for part 1, 3509 for part 2.
example3 = [
    "fs-end", "he-DX", "fs-he", "start-DX", "pj-DX", "end-zg",
    "zg-sl", "zg-pj", "pj-he", "RW-he", "fs-DX", "pj-RW",
    "zg-RW", "start-pj", "he-WI", "zg-he", "pj-fs", "start-RW",
]

In [33]:
# Puzzle input: one "a-b" edge per line, with surrounding whitespace stripped.
# Blank lines (e.g. a trailing empty line) are dropped so they never reach `parse`.
# NOTE: `input` shadows the builtin; the name is kept because the parsing cell uses it.
with open("inputs/2021d12", encoding="utf-8") as f:
    input = [line.strip() for line in f if line.strip()]

In [34]:
# Adjacency lists: cave name -> caves reachable from it in one step.
Graph = dict[str, list[str]]

def parse(input: list[str]) -> Graph:
    """
    Returns a graph represented as a hashmap, built from a list of "a-b" edge lines.

    Since the graph in question is (mostly) undirected, we automatically add backedges
    (i.e., if `A` is connected to `B`, we add both `A -> [B, ...]` and `B -> [A, ...]`) entries.
    The only exceptions are `start` and `end`, since those are source and sink nodes:
    only edges leaving `start` and edges entering `end` are recorded.

    Lines are stripped of surrounding whitespace, and blank lines are skipped.
    The result is a `defaultdict(list)`, so looking up a cave that has no outgoing
    edges (e.g. `end`) yields an empty list instead of raising `KeyError`.
    """
    graph = defaultdict(list)

    for line in input:
        edge = line.strip()
        if not edge:
            # Tolerate blank lines (e.g. a trailing newline in the input file).
            continue
        v1, v2 = edge.split("-")

        if v1 == "start" or v2 == "end":
            # Edge already points away from the source / into the sink.
            graph[v1].append(v2)
        elif v1 == "end" or v2 == "start":
            # Same kind of edge written the other way round: flip it.
            graph[v2].append(v1)
        else:
            graph[v1].append(v2)
            graph[v2].append(v1)

    return graph

In [35]:
# Replace each example's raw edge lines with its parsed graph (in order), then parse the puzzle input.
example1, example2, example3 = map(parse, (example1, example2, example3))
graph = parse(input)

In [39]:
# Display the parsed puzzle-input graph (rendered as this cell's output).
graph

defaultdict(list,
            {'LA': ['sn', 'mo', 'zs', 'end'],
             'sn': ['LA', 'mo', 'mh', 'vx', 'RD', 'JQ'],
             'mo': ['LA', 'sn', 'mh', 'JQ', 'zs', 'RD'],
             'zs': ['LA', 'end', 'JI', 'mo', 'rk', 'JQ'],
             'RD': ['end', 'mo', 'sn'],
             'start': ['vx', 'mh', 'JQ'],
             'mh': ['mo', 'sn', 'JQ', 'vx'],
             'JI': ['zs'],
             'JQ': ['mo', 'mh', 'zs', 'vx', 'sn'],
             'rk': ['zs'],
             'vx': ['sn', 'mh', 'JQ']})

### Part 1: Depth-First Search

We can then use a standard depth-first search algorithm with an added condition:

- lowercase nodes can only be visited once (uppercase can be visited any number of times).

We extract the logic for this condition, together with the path bookkeeping, into a separate dataclass `State`
to keep the DFS implementation clean.

In [180]:
@dataclass
class State:
    """
    Per-branch bookkeeping for the cave-path search: the path walked so far and the
    caves that may no longer be entered.
    """

    # Caves walked so far, in visiting order.
    path: list[str] = field(default_factory=list)
    # Caves that may not be entered again on this branch.
    visited: set[str] = field(default_factory=set)
    # Only relevant for p2: the single small cave allowed a second visit (None = no exception).
    small_cave_exception: Optional[str] = None
    # Only relevant for p2: whether `small_cave_exception` has already been entered once.
    exception_visited_once: bool = False

    def conditionally_visit(self, node: str) -> None:
        """
        Record a visit to `node`, closing it off if it may not be entered again.

        Nodes whose name is not all-lowercase (big caves) are never closed off.
        Lowercase nodes are closed off on their first visit, except
        `small_cave_exception`, which is closed off only on its second visit.
        """
        # only relevant for p2
        if node == self.small_cave_exception:
            if self.exception_visited_once:
                self.visited.add(node)
            else:
                self.exception_visited_once = True
        # relevant for both p1 and p2
        elif node.islower():
            self.visited.add(node)

    def is_visited(self, node: str) -> bool:
        """Return True if `node` has been closed off on this branch."""
        return node in self.visited

    def advance(self, node: str) -> None:
        """Append `node` to the path walked so far."""
        self.path.append(node)

    def get_path(self) -> tuple[str, ...]:
        """Return an immutable, hashable snapshot of the path (a tuple, not a list)."""
        return tuple(self.path)

In [181]:
def dfs(graph: Graph, small_cave_exception: Optional[str]=None) -> list[tuple[str, ...]]:
    """
    Enumerate paths from "start" to "end" in `graph` with an explicit-stack DFS.

    Lowercase caves may be entered at most once per path, except `small_cave_exception`
    (if given), which may be entered twice. Other caves may be entered any number of times.

    Returns one tuple of cave names per path found, each beginning with "start" and
    ending with "end". Tuples are hashable and compare by value, so callers can
    deduplicate paths across several runs with a set.

    NOTE: assumes no two big (non-lowercase) caves are adjacent; otherwise the search
    could bounce between them forever and never terminate.
    """
    paths = []

    stack = [("start", State(small_cave_exception=small_cave_exception))]
    while stack:
        curr, state = stack.pop()

        if curr == "end":
            state.advance(curr)
            # Call get_path(): appending the bound method itself would store an object
            # that is not the path and never compares equal to an identical path.
            paths.append(state.get_path())
            continue

        if state.is_visited(curr):
            continue
        state.conditionally_visit(curr)
        # `state` is owned by this stack entry and is not reused after branching,
        # so the current cave can be appended once before copying.
        state.advance(curr)

        for neighbor in graph[curr]:
            # Each branch needs its own copy so its visits don't leak into siblings.
            stack.append((neighbor, deepcopy(state)))

    return paths

In [182]:
def p1(graph: Graph) -> int:
    """Part 1: count start-to-end paths that enter each small cave at most once."""
    all_paths = dfs(graph)
    return len(all_paths)

In [183]:
# Expected part-1 path counts for the three example graphs.
assert p1(example1) == 10
assert p1(example2) == 19
assert p1(example3) == 226

In [186]:
# Pinned part-1 result for the puzzle input (inputs/2021d12).
assert p1(graph) == 4970

### Part 2: Apply DFS multiple times

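For part 2, we re-run the DFS once per small cave, each time marking that cave as the one that may be visited twice,
then take the union of the resulting paths (so a path found by several runs is counted only once) and count them.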

To support this, we added the fields `small_cave_exception` and `exception_visited_once` to `State`, an
`if node == self.small_cave_exception` branch to `State.conditionally_visit`, and a default parameter
`small_cave_exception` to `dfs`.

In [142]:
def p2(graph: Graph) -> int:
    """
    Part 2: count start-to-end paths in which at most one small cave is entered twice.

    Runs `dfs` once per small cave (other than start/end), each time allowing that cave
    to be entered twice, and unions the results so a path found by several runs is
    counted once. This relies on `dfs` returning value-comparable paths (e.g. tuples).
    If the graph has no such small caves, a single plain `dfs` run is used, so the
    paths through big caves are still counted.
    """
    # Every cave name appearing either as a source or as a destination.
    caves = set(graph).union(*graph.values())
    small_caves = sorted(cave for cave in caves - {"start", "end"} if cave.islower())

    paths = set()
    # `[None]` fallback: with no candidate cave, run dfs once without an exception.
    for cave in small_caves or [None]:
        paths.update(dfs(graph, small_cave_exception=cave))
    return len(paths)

In [148]:
# Expected part-2 path counts for the three example graphs.
assert p2(example1) == 36
assert p2(example2) == 103
assert p2(example3) == 3509 

In [150]:
# Pinned part-2 result for the puzzle input (inputs/2021d12).
assert p2(graph) == 137948

## Upshot

It seems like a good example as well! We still need to check whether we can provide a nice visualization for it, though.