In [8]:
from enum import Enum
from dataclasses import dataclass

class Direction(int, Enum):
    """Heading of a light beam on the grid.

    Rows grow downward: moving N decreases y, moving S increases it; moving E
    increases x, moving W decreases it. The int mixin makes each member compare
    equal to its integer value.
    """
    N = 0  # up (y - 1)
    E = 1  # right (x + 1)
    S = 2  # down (y + 1)
    W = 3  # left (x - 1)

class Square(Enum):
    """What occupies one grid cell, keyed by the character used in the input.

    Splitter names describe which beams they split: SplitterHor ('|') turns a
    beam travelling E or W into two beams going N and S, while SplitterVer ('-')
    turns a beam travelling N or S into two beams going W and E. A beam that
    travels along a splitter's axis passes straight through.
    """
    Empty = '.'         # beam continues in the same direction
    MirrorNE = '/'      # swaps headings N<->E and S<->W
    MirrorNW = '\\'     # a single backslash; swaps headings N<->W and E<->S
    SplitterHor = '|'   # splits horizontally-travelling beams
    SplitterVer = '-'   # splits vertically-travelling beams

@dataclass
class Ray:
    """A beam state: the grid cell it is on and the direction it is travelling."""
    x: int  # column index into a grid row, 0-based
    y: int  # row index, 0-based; row 0 is the top (moving N decreases y)
    d: Direction  # current heading

In [9]:
def next_ray(ray: Ray, sq: Square, lenx: int, leny: int) -> list[Ray]:
    match sq:
        case Square.Empty:
            match ray.d:
                case Direction.N:
                    if ray.y == 0:
                        return []
                    else:
                        return [Ray(ray.x, ray.y-1, ray.d),]
                case Direction.E:
                    if ray.x == lenx-1:
                        return []
                    else:
                        return [Ray(ray.x+1, ray.y, ray.d),]
                case Direction.S:
                    if ray.y == leny-1:
                        return []
                    else:
                        return [Ray(ray.x, ray.y+1, ray.d),]
                case Direction.W:
                    if ray.x == 0:
                        return []
                    else:
                        return [Ray(ray.x-1, ray.y, ray.d),]
        case Square.MirrorNE:
            match ray.d:
                case Direction.N:
                    if ray.x == lenx-1:
                        return []
                    else:
                        return [Ray(ray.x+1, ray.y, Direction.E),]
                case Direction.E:
                    if ray.y == 0:
                        return []
                    else:
                        return [Ray(ray.x, ray.y-1, Direction.N)]
                case Direction.S:
                    if ray.x == 0:
                        return []
                    else:
                        return [Ray(ray.x-1, ray.y, Direction.W)]
                case Direction.W:
                    if ray.y == leny-1:
                        return []
                    else:
                        return [Ray(ray.x, ray.y+1, Direction.S)]
        case Square.MirrorNW:
            match ray.d:
                case Direction.N:
                    if ray.x == 0:
                        return []
                    else:
                        return [Ray(ray.x-1, ray.y, Direction.W)]
                case Direction.E:
                    if ray.y == leny-1:
                        return []
                    else:
                        return [Ray(ray.x, ray.y+1, Direction.S)]
                case Direction.S:
                    if ray.x == lenx-1:
                        return []
                    else:
                        return [Ray(ray.x+1, ray.y, Direction.E)]
                case Direction.W:
                    if ray.y == 0:
                        return []
                    else:
                        return [Ray(ray.x, ray.y-1, Direction.N)]
        case Square.SplitterHor:
            match ray.d:
                case Direction.N:
                    if ray.y == 0:
                        return []
                    else:
                        return [Ray(ray.x, ray.y-1, ray.d),]
                case Direction.S:
                    if ray.y == leny-1:
                        return []
                    else:
                        return [Ray(ray.x, ray.y+1, ray.d),]
                case Direction.E | Direction.W:
                    r = []
                    if ray.y != 0:
                        r.append(Ray(ray.x, ray.y-1, Direction.N))
                    if ray.y != leny-1:
                        r.append(Ray(ray.x, ray.y+1, Direction.S))
                    return r
        case Square.SplitterVer:
            match ray.d:
                case Direction.E:
                    if ray.x == lenx-1:
                        return []
                    else:
                        return [Ray(ray.x+1, ray.y, ray.d),]
                case Direction.W:
                    if ray.x == 0:
                        return []
                    else:
                        return [Ray(ray.x-1, ray.y, ray.d),]
                case Direction.N | Direction.S:
                    r = []
                    if ray.x != 0:
                        r.append(Ray(ray.x-1, ray.y, Direction.W))
                    if ray.x != lenx-1:
                        r.append(Ray(ray.x+1, ray.y, Direction.E))
                    return r

In [30]:
def part1(grid: list[str], start: Ray | None = None) -> int:
    """Count the grid cells crossed by at least one beam.

    Args:
        grid: rows of the layout, e.g. straight from ``readlines()``. Surrounding
            whitespace/newlines on each row and trailing blank rows are ignored.
        start: the beam's initial cell and heading. Defaults to the top-left
            cell heading east.

    Returns:
        Number of distinct (x, y) cells visited by any beam; 0 for an empty grid.

    Raises:
        ValueError: if rows have different lengths or contain an unknown character.
    """
    rows = [line.strip() for line in grid]
    # readlines() can leave blank lines at the end of a file; they are not grid rows.
    while rows and not rows[-1]:
        rows.pop()
    if not rows:
        return 0
    lenx, leny = len(rows[0]), len(rows)
    if any(len(row) != lenx for row in rows):
        raise ValueError('grid rows must all have the same length')

    if start is None:
        start = Ray(0, 0, Direction.E)
    # A beam state (x, y, heading) always produces the same successors, so
    # never re-expanding a seen state avoids repeated work and breaks cycles.
    seen = {(start.x, start.y, start.d)}
    pending = [start]
    while pending:
        ray = pending.pop()
        for r in next_ray(ray, Square(rows[ray.y][ray.x]), lenx, leny):
            state = (r.x, r.y, r.d)
            if state not in seen:
                seen.add(state)
                pending.append(r)
    return len({(x, y) for x, y, _ in seen})

In [25]:
# Small example grid used to sanity-check part1; rows keep their trailing '\n'.
with open('test.txt') as fh:
    test = list(fh)

In [31]:
# Pin the expected example result so a regression fails loudly instead of just printing a different number.
test_answer = part1(test)
assert test_answer == 46, f'expected 46 for the example grid, got {test_answer}'
test_answer

46

In [32]:
# Full puzzle input; rows keep their trailing '\n'.
with open('input') as fh:
    inp = list(fh)

In [33]:
# Write the part-1 answer as bare text, with no trailing newline.
with open('output1', 'w') as out_file:
    print(part1(inp), file=out_file, end='')