In [None]:
from tabulate import tabulate
from queue import PriorityQueue
from dataclasses import dataclass, field
from typing import Any

# Puzzle input files. The paths are relative, so open() resolves them against
# the kernel's current working directory.
EXAMPLE_1 = "../example_1.txt"
EXAMPLE_2 = "../example_2.txt"
INPUT = "../input.txt"

In [None]:
def parse_input(input_file_name):
    """Read a maze file into a grid of single-character cells.

    Each non-blank line becomes one row (a list of its characters) after
    surrounding whitespace -- the trailing newline and any carriage return
    included -- has been stripped. Blank lines are skipped so that, for
    example, a trailing empty line in the file does not add an empty row to
    the grid.

    Args:
        input_file_name: Path of the text file to read.

    Returns:
        list[list[str]]: The grid, indexed as ``grid[row][col]``.
    """
    grid = []
    with open(input_file_name, "r") as f:
        for line in f:
            # strip() already removes "\n" (and "\r"), so no extra replace is needed.
            stripped = line.strip()
            if stripped:
                grid.append(list(stripped))
    return grid

In [None]:
# Sanity check: load the first example maze and pretty-print it as a table.
# NOTE(review): the global name `map` shadows the builtin; the start/end cell
# below reads this global, so rename both together if you change it.
map = parse_input(EXAMPLE_1)
example_table = tabulate(map)
print(example_table)

In [None]:
def find_start_end(map):
    """Locate the start ("S") and end ("E") tiles of the maze.

    The grid is scanned row by row, left to right, and the scan stops as soon
    as both tiles have been seen. Each row is walked over its own length, so
    an empty grid, an empty first row, or rows of differing width are handled
    instead of raising IndexError or skipping cells beyond the first row's
    width.

    Args:
        map: Grid of single-character strings indexed as ``map[row][col]``.

    Returns:
        tuple: ``(start, end)``, each a ``(row, col)`` tuple, or ``(-1, -1)``
        for a tile that was not found.
    """
    not_found = (-1, -1)
    start, end = not_found, not_found
    for row, cells in enumerate(map):
        for col, cell in enumerate(cells):
            if cell == "S":
                start = (row, col)
            if cell == "E":
                end = (row, col)
            # Stop early once both tiles are known.
            if start != not_found and end != not_found:
                return (start, end)
    return (start, end)

In [None]:
# Locate the S and E tiles of the example maze loaded above.
start, end = find_start_end(map)
print(f"{start} {end}")

In [None]:
def get_possible_moves(map, position, direction, score, step_cost=1, turn_cost=1000):
    """List the states reachable from ``(position, direction)`` in one action.

    Moves are produced in this order:
      * one step forward in the current direction, costing ``step_cost``.
        This move is only produced if the target tile lies inside the grid
        and is not a wall ("#").
      * a 90-degree rotation in place to each perpendicular direction,
        costing ``turn_cost``. A 180-degree turn is never generated directly.

    Args:
        map: Grid of single-character strings indexed as ``map[row][col]``.
        position: Current ``(row, col)``.
        direction: Current heading as a ``(row_delta, col_delta)`` unit step.
        score: Score accumulated so far.
        step_cost: Cost of moving forward one tile (default 1).
        turn_cost: Cost of one 90-degree rotation (default 1000).

    Returns:
        list[tuple]: ``(position, direction, new_score)`` triples.
    """
    row, col = position
    r_dir, c_dir = direction
    possible_moves = []
    next_row, next_col = row + r_dir, col + c_dir
    # Bounds check: otherwise a -1 index silently wraps to the opposite edge and
    # an index past the end raises IndexError. For a maze fully enclosed by "#"
    # this never triggers.
    if (
        0 <= next_row < len(map)
        and 0 <= next_col < len(map[next_row])
        and map[next_row][next_col] != "#"
    ):
        possible_moves.append(((next_row, next_col), direction, score + step_cost))
    for r_d, c_d in [(0, 1), (0, -1), (1, 0), (-1, 0)]:
        # Keep only the two perpendicular headings (not the same, not the reverse).
        if (r_d, c_d) != (r_dir, c_dir) and (r_d, c_d) != (-r_dir, -c_dir):
            possible_moves.append(((row, col), (r_d, c_d), score + turn_cost))
    return possible_moves

In [None]:
@dataclass(order=True)
class PrioritizedPosition:
    """Priority-queue entry for the maze search.

    With ``order=True`` the generated comparisons use only ``score``. Every
    other field is declared ``compare=False``, so entries with equal scores
    compare equal and the payload fields never take part in ordering.
    """

    score: int  # accumulated score of reaching this state
    position: Any = field(compare=False)  # (row, col) tile of this state
    direction: Any = field(compare=False)  # (row_delta, col_delta) heading of this state
    path: Any = field(compare=False)  # set of positions visited on the way to this state


def move(map, queue: PriorityQueue[PrioritizedPosition], visited_positions: set, end):
    """Pop the lowest-score queued state and expand it by one search step.

    The popped entry's ``path`` set gains its position in place, and its
    ``(position, direction)`` state is recorded in ``visited_positions``. If
    the position is ``end``, the entry is not expanded. Otherwise every move
    from ``get_possible_moves`` whose state has not been visited (i.e. popped)
    yet is queued with its own copy of the path.

    Because states are marked visited only when popped, the same state can sit
    in the queue several times. Equal-score routes into one state therefore
    each carry their own path.

    Callers must make sure the queue is non-empty: ``PriorityQueue.get()``
    waits indefinitely on an empty queue.

    Returns:
        tuple: ``(end_reached, score, path)`` for the popped entry.
    """
    # Scores only grow along a route, so the first time `end` is popped its score is minimal.
    entry = queue.get()
    here, heading = entry.position, entry.direction
    path: set = entry.path
    path.add(here)
    visited_positions.add((here, heading))
    if here == end:
        return True, entry.score, path
    for next_position, next_direction, next_score in get_possible_moves(map, here, heading, entry.score):
        if (next_position, next_direction) in visited_positions:
            continue
        queue.put(PrioritizedPosition(next_score, next_position, next_direction, path.copy()))
    return False, entry.score, path


def find_best_paths(map, start_direction=(0, 1)):
    """Find the lowest maze score and every tile on some lowest-score route.

    Runs a priority-queue search over ``(position, direction)`` states via
    ``move``, starting at the "S" tile. The first time the "E" tile is popped,
    its score is the minimum. Every remaining entry with that same score is
    then popped as well, and the paths of those that reach "E" are merged.

    Args:
        map: Grid of single-character strings indexed as ``map[row][col]``.
        start_direction: Heading at the start tile as ``(row_delta, col_delta)``.
            Defaults to ``(0, 1)``, i.e. facing increasing column index.

    Returns:
        tuple: ``(min_score, best_spots)``, where ``best_spots`` is the set of
        ``(row, col)`` tiles on at least one minimum-score route. Returns
        ``(-1, set())`` when "S" or "E" is missing or "E" is unreachable.
    """
    start, end = find_start_end(map)
    not_found = (-1, -1)
    if start == not_found or end == not_found:
        # Searching from/to the (-1, -1) sentinel would index the grid with -1
        # (wrapping around). With both missing, start == end, and the draining
        # loop would wait forever on an empty queue.
        return -1, set()

    visited_positions = set()
    queue: PriorityQueue[PrioritizedPosition] = PriorityQueue()
    queue.put(PrioritizedPosition(0, start, start_direction, set()))
    min_score = -1
    best_spots = set()

    while not queue.empty():
        end_reached, score, path = move(map, queue, visited_positions, end)
        if not end_reached:
            continue
        # First arrival at the end: this is the minimum score.
        min_score = score
        best_spots.update(path)
        # Entries with equal scores come out in arbitrary order, so a second
        # best route to the end may still be queued behind non-end entries.
        # Pop through every entry with the minimum score. Check emptiness
        # first, because PriorityQueue.get() inside move() would block forever
        # on an empty queue.
        while not queue.empty():
            end_reached, score, path = move(map, queue, visited_positions, end)
            if score != min_score:
                break
            if end_reached:
                best_spots.update(path)
        break
    return min_score, best_spots

In [None]:
def part_1(input_file_name):
    """Print the minimum score reported by find_best_paths for the given maze file."""
    lowest_score = find_best_paths(parse_input(input_file_name))[0]
    print(lowest_score)

In [None]:
# Part 1 on the first example maze.
part_1(EXAMPLE_1)

In [None]:
# Part 1 on the second example maze.
part_1(EXAMPLE_2)

In [None]:
# Part 1 on the real puzzle input.
part_1(INPUT)

In [None]:
def part_2(input_file_name):
    """Print how many distinct tiles lie on at least one best route found by find_best_paths."""
    maze = parse_input(input_file_name)
    best_spots = find_best_paths(maze)[1]
    print(len(best_spots))

In [None]:
# Part 2 on the first example maze.
part_2(EXAMPLE_1)

In [None]:
# Part 2 on the second example maze.
part_2(EXAMPLE_2)

In [None]:
# Part 2 on the real puzzle input.
part_2(INPUT)