In [33]:
from collections import defaultdict

import numpy as np
from aocd.models import Puzzle

puzzle = Puzzle(year=2021, day=12)

def parses(input):
    """Parse the ``a-b`` edge list into a directed adjacency map of caves.

    Each non-blank line names an undirected edge between two caves. Both
    directions are added, except that no edge ever leads *into* ``'start'`` and
    none leads *out of* ``'end'``, so the path searches never have to special-case
    revisiting the start or walking past the end.

    Returns a ``defaultdict(list)`` mapping each cave to the caves reachable from
    it. A repeated edge line is recorded only once. Lines may end in ``\\n`` or
    ``\\r\\n``; surrounding whitespace and blank lines are ignored.
    """
    graph = defaultdict(list)

    def add_edge(src, dst):
        # Short-circuit order matters: graph[src] is only touched when we are
        # about to append, so no empty entries are created (e.g. for 'end').
        if dst != 'start' and src != 'end' and dst not in graph[src]:
            graph[src].append(dst)

    for line in input.strip().splitlines():
        line = line.strip()
        if not line:
            continue
        src, dst = line.split('-')
        add_edge(src, dst)
        add_edge(dst, src)
    return graph

data = parses(puzzle.input_data)

In [34]:
# The three example cave systems, as raw edge-list text.
samples = [
"""start-A
start-b
A-c
A-b
b-d
A-end
b-end""",
"""dc-end
HN-start
start-kj
dc-start
dc-HN
LN-dc
HN-end
kj-sa
kj-HN
kj-dc""",
"""fs-end
he-DX
fs-he
start-DX
pj-DX
end-zg
zg-sl
zg-pj
pj-he
RW-he
fs-DX
pj-RW
zg-RW
start-pj
he-WI
zg-he
pj-fs
start-RW"""]
# Rebinds `samples` from raw text to the adjacency maps built by `parses`.
samples = [parses(s) for s in samples]
# Expected part-a path counts for each sample, in the same order as `samples`.
solutions_a = [10, 19, 226]

In [35]:
def solve_a(graph):
    """Count start->end paths that visit each small (lowercase) cave at most once.

    `graph` maps a cave name to the caves reachable from it, as built by `parses`
    (no edges into 'start', none out of 'end'). Big (uppercase) caves may be
    revisited freely. A repeated edge in `graph` does not produce duplicate paths.
    """
    valid_paths = 0
    # Each entry: (current cave, tuple of caves visited so far). A tuple gives
    # whole-name membership; a comma-joined string would make `'a' in path`
    # true merely because 'a' is a substring of 'start'.
    path_stack = [('start', ('start',))]

    while path_stack:
        node, path = path_stack.pop()
        # The DFS yields every distinct path exactly once, so the only possible
        # duplicates come from a repeated successor; dict.fromkeys drops those
        # while keeping order. .get avoids inserting into the caller's mapping.
        for child in dict.fromkeys(graph.get(node, ())):
            if child.isupper() or child not in path:
                if child == 'end':
                    valid_paths += 1
                else:
                    path_stack.append((child, path + (child,)))
    return valid_paths
                    

In [36]:
solve_a(samples[0])

10

In [37]:
solve_a(samples[1])

19

In [38]:
solve_a(samples[2])

226

In [39]:
solve_a(data)

3708

In [117]:
# Visited big caves are not tracked, so this assumes no two big (uppercase)
# caves are directly connected: such an edge would allow infinitely many paths,
# so a valid puzzle input cannot contain one.
def solve(graph, part='a'):
    """Count start->end paths through the cave `graph`, keeping each path as a string.

    Part 'a': each small (lowercase) cave is visited at most once.
    Part 'b': additionally, a single small cave may be visited twice.
    `graph` maps each cave to its successors, as built by `parses`
    (no edges into 'start', none out of 'end').
    """
    valid_paths = 0
    # Paths look like ',start,A,b,': every name is wrapped in commas so the
    # membership test matches whole cave names. With a bare 'start,A,b' string,
    # `'a' in path` would be true just because 'a' is a substring of 'start'.
    path_stack = [('start', ',start,', part == 'b')]

    while path_stack:
        # wildcard: the single small-cave revisit of part 'b' is still unused.
        node, path, wildcard = path_stack.pop()
        for child in graph.get(node, ()):
            can_visit = child.isupper() or f',{child},' not in path
            if can_visit or wildcard:
                if child == 'end':
                    valid_paths += 1
                else:
                    # Spend the wildcard if this step needed it.
                    child_wildcard = wildcard and can_visit
                    new_path = path + child + ','
                    path_stack.append((child, new_path, child_wildcard))

    return valid_paths

In [118]:
%%timeit
solve(data), solve(data, 'b')

225 ms ± 10.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [119]:
def solve(graph, part="a"):
    """Count start->end paths through the cave `graph`, keeping each path as a list.

    Part "a": each small (lowercase) cave is visited at most once.
    Part "b": additionally, a single small cave may be visited twice.
    `graph` maps each cave to its successors, as built by `parses`
    (no edges into "start", none out of "end").
    """
    valid_paths = 0
    # Each entry: (current cave, caves visited so far, part-"b" revisit still unused).
    path_stack = [("start", ["start"], part == "b")]

    while path_stack:
        node, path, wildcard = path_stack.pop()
        # .get rather than [] so a cave with no successors neither inserts an
        # empty entry into the caller's defaultdict nor raises on a plain dict.
        for child in graph.get(node, ()):
            can_visit = child.isupper() or child not in path
            if can_visit or wildcard:
                if child == "end":
                    valid_paths += 1
                else:
                    # The revisit is consumed when this step needed it.
                    child_wildcard = wildcard and can_visit
                    path_stack.append((child, path + [child], child_wildcard))
    return valid_paths


In [120]:
%%timeit
solve(data), solve(data, 'b')

278 ms ± 16.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [93]:
def solve(graph, part='a'):
    """Count start->end paths through the cave `graph`.

    Part 'a': each small (lowercase) cave is visited at most once.
    Part 'b': additionally, a single small cave may be visited twice.
    `graph` maps each cave to its successors, as built by `parses`
    (no edges into 'start', none out of 'end'). A repeated edge in `graph`
    does not produce duplicate paths.
    """
    valid_paths = 0
    # Each entry: (current cave, tuple of caves visited so far, wildcard), where
    # wildcard means the part-'b' small-cave revisit is still unused. A tuple
    # (not a comma-joined string) makes `in` compare whole names rather than
    # substrings, so a cave 'a' is not mistaken as visited via 'start'.
    path_stack = [('start', ('start',), part == 'b')]

    while path_stack:
        node, path, wildcard = path_stack.pop()
        # The DFS produces each distinct path exactly once, so no set of seen
        # paths is needed; the only duplicates come from a repeated successor,
        # which dict.fromkeys removes while preserving order.
        for child in dict.fromkeys(graph.get(node, ())):
            valid = child.isupper() or child not in path
            if valid or wildcard:
                if child == 'end':
                    valid_paths += 1
                else:
                    # The revisit is consumed when this step needed it.
                    child_wildcard = wildcard and valid
                    path_stack.append((child, path + (child,), child_wildcard))
    return valid_paths
                    

In [96]:
solve(data), solve(data, 'b')

(3708, 93858)

In [92]:
"HDASD".isupper()

True

In [103]:
import itertools

In [104]:
def run_step(energy):
    """Advance the 2-D energy array one step in place; return the number of flashes.

    Every cell gains 1. Any cell above 9 flashes, adding 1 to each of its (up to
    eight) neighbours, which can cascade. A cell flashes at most once per step,
    and every cell that flashed is reset to 0 at the end.
    """
    rows, cols = energy.shape
    energy += 1
    pending = np.argwhere(energy > 9).tolist()
    flash_count = 0
    while pending:
        r, c = pending.pop()
        flash_count += 1
        for nr in range(r - 1, r + 2):
            for nc in range(c - 1, c + 2):
                # Bounds check first so negative indices never wrap around;
                # cells already at 10+ have flashed or are queued to.
                if not (0 <= nr < rows and 0 <= nc < cols) or energy[nr, nc] >= 10:
                    continue
                energy[nr, nc] += 1
                if energy[nr, nc] == 10:
                    pending.append([nr, nc])
    energy[energy > 9] = 0
    return flash_count

In [105]:
def solve_a(energy, steps):
    """Return the total number of flashes over `steps` steps of `run_step`.

    `energy` is copied into a fresh array first, so the caller's grid is untouched.
    """
    grid = np.array(energy)
    total = 0
    for _ in range(steps):
        total += run_step(grid)
    return total

In [106]:
# Small digit-grid example for the flash simulation. It is parsed inline because
# `parses` in this notebook is the cave-graph parser (it splits lines on '-'),
# which raises on a digit grid.
tiny = [[int(ch) for ch in row] for row in """11111
19991
19191
19991
11111""".split()]

In [107]:
solve_a(tiny, 2)

9

In [108]:
solve_a(sample, 10)

204

In [109]:
solve_a(sample, 100)

1656

In [110]:
solve_a(data, 100)

1757

In [111]:
def solve_b(energy):
    """Return the first step number (counting from 1) after which every cell is 0.

    Works on a copy of `energy`. The loop has no upper bound, so it never
    returns for a grid that never reaches an all-zero state.
    """
    grid = np.array(energy)
    step = 0
    while True:
        step += 1
        run_step(grid)
        # No nonzero cell left means every cell was reset this step.
        if not grid.any():
            return step

In [64]:
solve_b(sample)

195

In [65]:
solve_b(data)

422