In [1]:
import numpy as np
from dataclasses import dataclass
from aocd import data
from functools import reduce

In [2]:
# Compass headings as unit complex numbers: real part = column, imaginary part = row.
# Rows grow downward, so North is -1j and South is +1j.
N = complex(0, -1)
E = complex(1, 0)
S = complex(0, 1)
W = complex(-1, 0)

In [4]:
# Parse the puzzle input into a 2-D array of single-character tiles (one row per line).
grid = np.array([[*line] for line in data.splitlines()])
# NOTE: max_x is the ROW count (shape[0]) and max_y is the COLUMN count (shape[1]).
max_x, max_y = grid.shape
# Shared set of (position, direction) pairs: filled by Beam.travel, reset by energized().
visited = set()

In [5]:
def out_of_bounds(position: complex):
    """Return True if `position` lies outside the grid.

    Positions are complex numbers with real part = column and imaginary
    part = row. `max_x` is the row count and `max_y` the column count.
    """
    row, col = position.imag, position.real
    before_start = row < 0 or col < 0
    past_end = row >= max_x or col >= max_y
    return before_start or past_end

@dataclass
class Beam:
    """A beam of light occupying one grid cell and heading in one direction.

    Coordinates are complex numbers: the real part is the column and the
    imaginary part is the row (so N = -1j points one row up).
    """
    position:  complex
    direction: complex

    @property
    def index(self):
        """Integer (row, col) pair suitable for indexing the numpy `grid`."""
        row, col = self.position.imag, self.position.real
        return int(row), int(col)

    def _step(self, heading: complex) -> 'Beam':
        """New beam one cell further along `heading`, now facing `heading`."""
        return Beam(position=self.position + heading, direction=heading)

    def travel(self) -> 'list[Beam]':
        """Advance this beam through the tile it currently occupies.

        Records (position, direction) in the module-level `visited` set and
        returns the beam(s) leaving the tile: one for empty space, mirrors and
        aligned splitters, two for a splitter hit side-on. Returns [] when the
        beam is outside the grid or the tile character is unrecognised.
        """
        if out_of_bounds(self.position):
            return []
        visited.add((self.position, self.direction))

        tile = grid[self.index]
        heading = self.direction
        if tile == '/':
            return [self._step({N: E, E: N, S: W, W: S}[heading])]
        if tile == '\\':
            return [self._step({N: W, E: S, S: E, W: N}[heading])]
        if tile == '|' and heading not in (N, S):
            return [self._step(N), self._step(S)]   # split vertically
        if tile == '-' and heading not in (E, W):
            return [self._step(E), self._step(W)]   # split horizontally
        if tile in ('.', '|', '-'):
            return [self._step(heading)]            # pass straight through
        return []


In [6]:
def energized(position, direction):
    """Count the tiles energized by a beam entering at `position` heading `direction`.

    All live beams are stepped in lockstep via Beam.travel, which records each
    in-bounds (position, direction) pair in the module-level `visited` set.
    Stepping stops once a round adds nothing new to `visited`. The result is
    the number of distinct positions recorded.

    Note: clears and refills the shared module-level `visited` set.
    """
    beams = [Beam(position=position, direction=direction)]
    visited.clear()

    while True:
        num_visited = len(visited)
        # Flatten with a comprehension. sum(lists, []) re-copies the growing
        # accumulator on every addition, which is quadratic in the frontier size.
        next_beams = [child for beam in beams for child in beam.travel()]
        # Drop beams that retrace an already-seen (position, direction);
        # otherwise loops between mirrors would never terminate.
        beams = [beam for beam in next_beams
                 if (beam.position, beam.direction) not in visited]
        if len(visited) == num_visited:
            break

    return len({pos for pos, _ in visited})

In [7]:
# Part 1: the beam enters at the top-left tile (column 0, row 0) heading east.
energized(position=0, direction=E)


7498

In [8]:
# Part 2: fire a beam into the grid from every edge tile and keep the best result.
# Positions are column + row*1j; max_x is the ROW count and max_y the COLUMN count.
last_row, last_col = max_x - 1, max_y - 1
edge_starts = (
    [(row * 1j, E) for row in range(max_x)]                # left edge, heading east
    + [(row * 1j + last_col, W) for row in range(max_x)]   # right edge, heading west
    + [(col, S) for col in range(max_y)]                    # top edge, heading south
    # Bottom edge sits at row max_x - 1 (row count), not max_y - 1 (column count);
    # the two only coincide for square grids.
    + [(col + last_row * 1j, N) for col in range(max_y)]   # bottom edge, heading north
)
# energized() resets the shared `visited` set itself, so the call order is irrelevant.
max_energized = max((energized(pos, heading) for pos, heading in edge_starts), default=0)

max_energized

7846