# In [1]:
from collections import defaultdict
from heapq import heappop, heappush

# Path of the puzzle input file that the script reads.
input_file = "data/input.txt"

# Unit steps expressed as (dx, dy). Rows are indexed by y, so TOP moves to a
# smaller row index and BOTTOM to a larger one.
TOP, RIGHT, BOTTOM, LEFT = (0, -1), (1, 0), (0, 1), (-1, 0)

# Pseudo-direction carried by the initial state, before any step is taken.
START_DIR = (0, 0)

DIRECTIONS = [TOP, RIGHT, BOTTOM, LEFT]

# Each real direction maps to its reverse (the negated step vector); the
# start pseudo-direction has no reverse.
OPPOSITE = {(dx, dy): (-dx, -dy) for dx, dy in DIRECTIONS}
OPPOSITE[START_DIR] = None

# Grid width and height. Placeholders here; they must be assigned before
# within_bounds() can be used.
MAX_X = MAX_Y = None

def within_bounds(x, y):
    """Return True when (x, y) lies inside the grid.

    The grid spans 0 <= x < MAX_X and 0 <= y < MAX_Y, where MAX_X and MAX_Y
    are the module-level dimensions.
    """
    if x < 0 or x >= MAX_X:
        return False
    return 0 <= y < MAX_Y

def get_node_after_step(node, direction):
    """Return the (x, y) position reached by moving one step from node.

    Both arguments are indexable pairs: node is (x, y) and direction is
    (dx, dy). The result is always a new tuple.
    """
    x, y = node[0], node[1]
    dx, dy = direction[0], direction[1]
    return (x + dx, y + dy)

def valid_direction(node, prev_direction, new_direction, straight_count, max_straight, min_straight):
    """Decide whether stepping from node in new_direction is allowed.

    Args:
        node: Current (x, y) position.
        prev_direction: Direction of the step that led to node, or START_DIR
            when no step has been taken yet.
        new_direction: Direction of the candidate step.
        straight_count: Length of the current run of steps in prev_direction.
        max_straight: A run may be extended only while it is shorter than this.
        min_straight: A turn is allowed only once the run is at least this long.

    Returns:
        False if the step reverses prev_direction or leaves the grid;
        otherwise whether the run-length rule for continuing (or turning)
        is satisfied.
    """
    # Turning straight back is never permitted.
    if OPPOSITE[prev_direction] == new_direction:
        return False

    target = get_node_after_step(node, new_direction)
    if not within_bounds(target[0], target[1]):
        return False

    # The very first move (from START_DIR) is judged like a continuation, so
    # it is not subject to the minimum-run requirement for turning.
    continues_run = prev_direction in (new_direction, START_DIR)
    if not continues_run:
        return straight_count >= min_straight
    return straight_count < max_straight

def dijkstra(matrix, start_node, max_straight, min_straight=0):
    """Find the cheapest cost of reaching every cell in every movement state.

    A search state is a cell together with the direction of the step that
    entered it and the length of the current run of steps in that direction.
    Moves are filtered by ``valid_direction``; entering a cell costs
    ``matrix[y][x]``. The cost of the start cell itself is not counted.

    Args:
        matrix: Grid of per-cell entry costs, indexed as ``matrix[y][x]``.
            Its dimensions are expected to match the module-level
            ``MAX_X``/``MAX_Y``, which size the result table and drive the
            bounds checks.
        start_node: ``(x, y)`` cell the search starts from.
        max_straight: Forwarded to ``valid_direction`` as the run-length cap.
        min_straight: Forwarded to ``valid_direction`` as the minimum run
            length required before turning.

    Returns:
        A dict mapping every ``(x, y)`` in the grid to a ``defaultdict``
        keyed by ``(dx, dy, straight_count)``. Each value is the lowest cost
        found for arriving at that cell in that state. States never reached
        have no entry (looking one up yields ``inf``).
    """
    distances = {
        (x, y): defaultdict(lambda: float('inf'))
        for y in range(MAX_Y)
        for x in range(MAX_X)
    }
    queue = [(0, start_node, START_DIR, 0)]

    while queue:
        heat_loss, node, prev_direction, straight_count = heappop(queue)

        # Lazy deletion: a state may sit in the heap several times if it was
        # improved after being queued. Only the copy matching the recorded
        # best cost needs expanding; outdated copies cannot improve anything.
        # The initial entry has no recorded state, so it is always expanded.
        if prev_direction != START_DIR:
            state = (*prev_direction, straight_count)
            if heat_loss > distances[node][state]:
                continue

        for new_direction in DIRECTIONS:
            if not valid_direction(
                node,
                prev_direction,
                new_direction,
                straight_count,
                max_straight,
                min_straight
            ):
                continue

            new_node = get_node_after_step(node, new_direction)
            new_node_x, new_node_y = new_node

            new_heat_loss = heat_loss + matrix[new_node_y][new_node_x]

            # The first move out of START_DIR starts a run of length 1, just
            # like a turn does.
            if prev_direction == new_direction:
                new_straight_count = straight_count + 1
            else:
                new_straight_count = 1

            new_state = (*new_direction, new_straight_count)
            best_known = distances[new_node]

            if new_heat_loss < best_known[new_state]:
                best_known[new_state] = new_heat_loss
                heappush(queue, (new_heat_loss, new_node, new_direction, new_straight_count))
    return distances

# Read the grid up front so the file is closed before the searches run.
with open(input_file, 'r') as f:
    # Drop blank lines (e.g. trailing newlines at the end of the file): they
    # would otherwise become empty rows, inflate MAX_Y and break the search.
    lines = [stripped for stripped in (l.strip() for l in f) if stripped]

# Every character of a line is one cell's cost.
matrix = [[int(d) for d in l] for l in lines]

# Publish the grid dimensions for within_bounds() and dijkstra().
MAX_X = len(matrix[0])
MAX_Y = len(matrix)

# Part 1: runs of at most 3 steps. Part 2: runs of 4 to 10 steps.
distances1 = dijkstra(matrix, (0, 0), 3, 0)
distances2 = dijkstra(matrix, (0, 0), 10, 4)

# Arrival states recorded for the bottom-right cell.
final_distances1 = distances1[MAX_X - 1, MAX_Y - 1]
final_distances2 = distances2[MAX_X - 1, MAX_Y - 1]

ans1 = min(final_distances1.values())
# For part 2 only arrivals whose final run is at least 4 steps long count
# (4 is the min_straight passed to the second search).
ans2 = min(cost for (_, _, run), cost in final_distances2.items() if run >= 4)

print(f"{ans1 = }")
print(f"{ans2 = }")

# Output:
# ans1 = 866
# ans2 = 1010
