In [53]:
import functools
import operator
import itertools
import math
import sys
import json
import re
from enum import Enum
from collections import deque
from dataclasses import dataclass
from typing import Dict, List, Tuple, TypeAlias, Set
import pulp
C2: TypeAlias = Tuple[int, int]
C3: TypeAlias = Tuple[int, int, int]
Grid:TypeAlias = List[List[str]]

data = open("input.txt").read().splitlines()
grid: Grid = [[x for x in line] for line in data]

directions = {
    "^":(-1,0),
    ">":(0,1),
    "v":(1,0),
    "<":(0,-1),
    "x":(0,0),
}
direction_symbol = {v:k for k,v in directions.items()}

start: C2
end: C2
min_r = 1
max_r = len(grid) - 2
min_c = 1
max_c = len(grid[0]) - 2

def add(a,b):
    return tuple([x + y for x,y in zip(a,b)])

def add_to_list_at_key(dict, key, item):
    ls = dict.get(key, [])
    ls.append(item)
    dict[key] = ls

class Blizzard():
    def __init__(self, pos: C2, dir: C2) -> None:
        self.pos = pos
        self.dir = dir
        self.symbol = direction_symbol[dir]
    def next(self, get: lambda pos : str) -> "Blizzard":
        (r,c) = add(self.pos, self.dir)
        if r < min_r:
            return Blizzard((max_r, c), self.dir)
        if r > max_r:
            return Blizzard((min_r, c), self.dir)
        if c < min_c:
            return Blizzard((r, max_c), self.dir)
        if c > max_c:
            return Blizzard((r, min_c), self.dir)
        return Blizzard((r,c), self.dir)

class Map():
    def __init__(self, grid: Grid):
        self.blizzards: Dict[C2,List[Blizzard]] = {}
        self.start = None
        self.end = None
        self.exps = set()
        self.height = len(grid)
        self.width = len(grid[0])
        for r,row in enumerate(grid):
            for c, col in enumerate(row): 
                pos = (r,c)
                if r == 0 and col == ".":
                    self.start = pos
                    self.exps.add(pos)
                if r == self.height - 1 and col == ".":
                    self.end = pos
                if dir := directions.get(col):
                    add_to_list_at_key(self.blizzards, pos, Blizzard(pos, dir))

    def __iter__(self):
        return iter([(r,c,self.get((r,c))) for r in range(0, self.height) for c in range(0, self.width)])

    def get(self, pos: C2, include_exp = False):
        r,c = pos
        if pos in self.exps and include_exp:
            return "E"
        if pos == self.end or pos == self.start:
            return "."
        if b := self.blizzards.get(pos):
            if len(b) > 1:
                return str(len(b))
            return b[0].symbol
        if r == 0 or r == self.height - 1:
            return "#"
        if c == 0 or c == self.width - 1:
            return "#"
        return "."

    def __repr__(self):
        out = ""
        for r in range(0, self.height):
            for c in range(0, self.width):
                out += self.get((r,c), include_exp=True)
            out += "\n"
        return out
    
    def step(self):
        next_blizzards = {}
        for pos, blizzards in self.blizzards.items():
            for b in blizzards:
                n = b.next(self.get)
                add_to_list_at_key(next_blizzards, n.pos, n)
        self.blizzards = next_blizzards

def compute_path_time(m: Map, start: C2, end: C2):
    m.start = start
    m.end = end
    pending = set([start])
    t = 0
    done = False
    while not done:
        t += 1
        active = set()
        m.step()
        for pos in pending:
            for d in directions.values():
                next_pos = add(d, pos)
                if next_pos == end:
                    return t
                if m.get(next_pos) == ".":
                    active.add(next_pos)
        pending = active
        m.exps = active
        #print(f"time = {t}")
        #print(m)
        if t % 100 == 0:
            print(f"At time {t} we have {len(active)} active heads")


def part_1():
    m = Map(grid)
    answer = compute_path_time(m, m.start, m.end)
    print(f"part 1 -> {answer}")
#part_1()

def part_2():
    m = Map(grid)
    start = m.start
    end = m.end
    one = compute_path_time(m, start, end)
    print(f"one -> {one}")
    two = compute_path_time(m, end, start)
    print(f"two -> {two}")
    three = compute_path_time(m, start, end)
    print(f"three -> {three}")
    answer = one + two + three
    print(f"part 2 -> {answer}")
part_2()


At time 100 we have 5343 active heads
At time 200 we have 15776 active heads
one -> {one}
At time 100 we have 5333 active heads
At time 200 we have 15851 active heads
two -> {two}
At time 100 we have 5360 active heads
At time 200 we have 15849 active heads
three -> {three}
part 2 -> 785
