In [25]:
from collections import deque
from functools import cache

import numpy as np

In [26]:
def parse_input(path):
    """Read a grid file and locate its start cell and splitters.

    Parameters
    ----------
    path : str or path-like
        Text file holding the grid, one row per line. ``'S'`` marks the
        start cell and ``'^'`` marks a splitter; every other character is
        ignored.

    Returns
    -------
    tuple
        ``(start, splitters, n_rows, n_cols)`` where ``start`` is the
        ``(row, col)`` of the ``'S'`` cell (the last one if several are
        present), ``splitters`` is a set of ``(row, col)`` tuples,
        ``n_rows`` is the number of lines and ``n_cols`` is the length of
        the first line.

    Raises
    ------
    ValueError
        If the file has no lines, or if no ``'S'`` cell is found.
    """
    with open(path) as file_in:
        data = file_in.read().splitlines()

    # Fail early with a clear message instead of an IndexError on data[0].
    if not data:
        raise ValueError(f"no grid rows found in {path!r}")

    start = None
    splitters = set()
    for i, row in enumerate(data):
        for j, elem in enumerate(row):
            if elem == 'S':
                start = (i, j)
            elif elem == '^':
                splitters.add((i, j))

    # Without this guard a missing 'S' surfaced as an UnboundLocalError.
    if start is None:
        raise ValueError(f"no start cell 'S' found in {path!r}")

    return start, splitters, len(data), len(data[0])

In [27]:
def get_next_splitter(pos, splitters, n_rows):
    """Return the first splitter strictly below ``pos`` in the same column.

    Rows ``pos[0] + 1`` through ``n_rows - 2`` are searched. The last row
    (``n_rows - 1``) is treated as the exit row and is never searched, so a
    splitter placed there is not reported.

    Parameters
    ----------
    pos : tuple of int
        ``(row, col)`` position to start from (exclusive).
    splitters : collection of tuple
        ``(row, col)`` positions of all splitters; only membership tests
        are performed on it.
    n_rows : int
        Total number of rows in the grid.

    Returns
    -------
    tuple of int or None
        Position of the nearest splitter below ``pos``, or ``None`` when
        the last row is reached first. ``None`` is also returned when
        ``pos`` is already on or below the last row.
    """
    col = pos[1]
    # A bounded range replaces the open-ended walk that tested
    # ``row == n_rows - 1``: that walk never terminated once it started
    # past the last row.
    for row in range(pos[0] + 1, n_rows - 1):
        candidate = (row, col)
        if candidate in splitters:
            return candidate
    return None

In [28]:
def main1(input_path):
    """Count the distinct splitters reachable from the start cell.

    Beginning at the start position, the search drops straight down to the
    next splitter in the column. Each newly reached splitter is recorded
    and the search continues from the cells immediately to its left and
    right (same row). Splitters already recorded are not expanded again.

    Parameters
    ----------
    input_path : str or path-like
        Grid file passed on to ``parse_input``.

    Returns
    -------
    int
        Number of distinct splitters reached.
    """
    start, splitters, n_rows, _n_cols = parse_input(input_path)

    visited = set()
    pending = deque([start])

    while pending:
        current = pending.popleft()
        hit = get_next_splitter(current, splitters, n_rows)
        # Nothing below, or this splitter was already expanded.
        if hit is None or hit in visited:
            continue
        visited.add(hit)
        row, col = hit
        pending.extend(((row, col - 1), (row, col + 1)))

    return len(visited)

In [29]:
def main2(input_path):
    """Count the distinct paths from the start cell down to the last row.

    A path drops straight down from the start until it meets a splitter.
    At a splitter it continues from both the left and the right
    neighbouring column, each branch counting as its own path. A path is
    finished when ``get_next_splitter`` finds nothing more below it.

    The count is accumulated per splitter: every splitter stores how many
    paths arrive at it and passes that number on to the next splitter
    below each of its two neighbouring columns.

    Parameters
    ----------
    input_path : str or path-like
        Grid file passed on to ``parse_input``.

    Returns
    -------
    int
        Total number of paths that reach the bottom. This is ``1`` when
        no splitter lies below the start cell.
    """
    start, splitters, n_rows, _n_cols = parse_input(input_path)

    # Locate the first splitter from the start cell itself rather than
    # assuming it sits in row 2.
    first_splitter = get_next_splitter(start, splitters, n_rows)
    if first_splitter is None:
        return 1

    n_incoming = {first_splitter: 1}
    n_finished = 0

    # The next splitter is always on a strictly lower row, so visiting the
    # splitters in (row, col) order guarantees each one has received all of
    # its incoming paths before it is expanded.
    for splitter in sorted(splitters):
        n_paths = n_incoming.get(splitter, 0)
        if n_paths == 0:
            # Never reached (this includes any splitter on the last row).
            continue

        row, col = splitter
        for side_col in (col - 1, col + 1):
            target = get_next_splitter((row, side_col), splitters, n_rows)
            if target is None:
                # Keeping a running total instead of per-column exit keys
                # means edge-column splitters no longer raise KeyError.
                n_finished += n_paths
            else:
                n_incoming[target] = n_incoming.get(target, 0) + n_paths

    return n_finished

In [30]:
# Check main1 against the known answer for example.txt, then run it on
# input.txt; the value of the last expression is the cell's displayed output.
assert main1("example.txt") == 21
main1("input.txt")

1630

In [31]:
# Check main2 against the known answer for example.txt, then run it on
# input.txt; the value of the last expression is the cell's displayed output.
assert main2("example.txt") == 40
main2("input.txt")

47857642990160