# In [7]:
from collections import defaultdict, deque
from itertools import zip_longest
from typing import Iterable


def parser(file='input.txt'):
    """Lazily yield the lines of a text file without their trailing newline.

    Args:
        file: Path of the file to read.  Defaults to ``'input.txt'``.

    Yields:
        Each line of the file, in order, with the terminating ``'\\n'``
        removed.  Other whitespace is preserved.

    Raises:
        OSError: If the file cannot be opened (raised on first iteration,
            because this is a generator).
    """
    # Use a separate name for the handle so the ``file`` argument is not
    # rebound, and pin the encoding so reading does not depend on the
    # platform's locale default.
    with open(file, 'r', encoding='utf-8') as handle:
        for line in handle:
            yield line.rstrip('\n')

def _load_grid(lines):
    """Materialise the layout as a list of rows, ignoring blank lines.

    Blank lines (for example a stray empty line at the end of the input)
    contain no tiles and would otherwise raise ``IndexError`` as soon as a
    beam reached them.
    """
    return [row for row in lines if row]


def _redirect(tile, di):
    """Return the directions in which a beam moving along ``di`` leaves ``tile``.

    Directions are ``(dy, dx)`` steps.  A splitter hit on its flat side
    produces two directions; an unrecognised tile produces none, so the beam
    simply stops there.
    """
    dy, dx = di
    if tile == '.':
        return (di,)
    if tile == '/':
        # Swap the components and negate them: right <-> up, left <-> down.
        return ((-dx, -dy),)
    if tile == '\\':
        # Swap the components: right <-> down, left <-> up.
        return ((dx, dy),)
    if tile == '|':
        return ((-1, 0), (1, 0)) if dx else (di,)
    if tile == '-':
        return ((0, -1), (0, 1)) if dy else (di,)
    return ()


def _trace_beam(grid, start):
    """Follow a beam (and every beam it splits into) through ``grid``.

    Args:
        grid: Rows of the layout; every row is expected to be as wide as the
            first one.
        start: ``(y, x, (dy, dx))`` - the cell the beam leaves from and the
            direction it leaves in.  Callers pass a cell just outside the
            grid so that the first step lands on the entry tile.

    Returns:
        A ``defaultdict`` mapping every visited ``(y, x)`` - including the
        starting cell - to the list of directions in which a beam has left
        that cell.
    """
    height = len(grid)
    width = len(grid[0]) if grid else 0
    places = defaultdict(list)
    queue = deque([start])
    while queue:
        y, x, di = queue.popleft()
        # A beam leaving the same cell in the same direction was already
        # followed; skipping it is what terminates loops between mirrors.
        if di in places[y, x]:
            continue
        places[y, x].append(di)

        ny, nx = y + di[0], x + di[1]
        if not (0 <= ny < height and 0 <= nx < width):
            continue  # the beam left the grid
        for new_di in _redirect(grid[ny][nx], di):
            queue.append((ny, nx, new_di))
    return places


def solver(lines: Iterable[str]):
    """Count the tiles energised by a beam entering the top-left tile heading right.

    Args:
        lines: Rows of the layout, made of ``.``, ``|``, ``-``, ``/`` and
            ``\\`` characters.  Blank lines are ignored.

    Returns:
        A ``(count, places)`` tuple: ``count`` is the number of energised
        tiles and ``places`` maps each visited ``(y, x)`` to the directions a
        beam left it in.  ``places`` also contains the off-grid starting cell
        ``(0, -1)``, which is why ``count`` is ``len(places) - 1``.
    """
    places = _trace_beam(_load_grid(lines), (0, -1, (0, 1)))
    return len(places) - 1, places


def solver2(lines: Iterable[str]):
    """Return the largest energised-tile count over every possible edge entry.

    A beam may enter from any tile of the top row (heading down), the bottom
    row (heading up), the left column (heading right) or the right column
    (heading left).

    Args:
        lines: Rows of the layout, as accepted by ``solver``.

    Returns:
        The maximum number of energised tiles, or ``0`` for an empty layout.
    """
    grid = _load_grid(lines)
    height = len(grid)
    width = len(grid[0]) if grid else 0
    # Each start is the cell just outside the grid, paired with the inward
    # direction, so the first step of the trace lands on the edge tile.
    starts = [
        *((-1, x, (1, 0)) for x in range(width)),
        *((height, x, (-1, 0)) for x in range(width)),
        *((y, -1, (0, 1)) for y in range(height)),
        *((y, width, (0, -1)) for y in range(height)),
    ]
    # Subtract one per trace for the off-grid starting cell.
    return max((len(_trace_beam(grid, start)) - 1 for start in starts), default=0)

if __name__ == '__main__':
    # Part one: check the sample layout against its known answer.
    sample_part_one, _sample_places = solver(parser('test_input.txt'))
    print(sample_part_one)
    assert sample_part_one == 46

    # Part two: check the sample layout against its known answer.
    sample_part_two = solver2(parser('test_input.txt'))
    print(sample_part_two)
    assert sample_part_two == 51

    # Part two on the real puzzle input.
    print(solver2(parser('input.txt')))

# Output:
# 46
# 51
# 7741
