In [1]:
import pandas as pd
import numpy as np
from queue import PriorityQueue
from dataclasses import dataclass

In [2]:
grid = (
    pd.read_csv("data/16.txt", engine="python", sep=r"\s*", names=range(-1, 1000))
    .dropna(axis=1, how="all")
    .values
)

In [3]:
@dataclass(frozen=True, order=True)
class NodeState:
    position: tuple
    direction: tuple


def enhanced_heuristic(x, y, goal_x, goal_y):
    return min(np.abs(x - goal_x), np.abs(y - goal_y))


def get_neighbors(node_state: NodeState):
    x, y = node_state.position
    dx, dy = node_state.direction
    if 0 <= x + dx < grid.shape[0] and 0 <= y + dy < grid.shape[1]:
        if grid[x + dx, y + dy] != "#":
            yield NodeState(position=(x + dx, y + dy), direction=(dx, dy)), 1

    yield NodeState(position=(x, y), direction=(dy, -dx)), 1000
    yield NodeState(position=(x, y), direction=(-dy, dx)), 1000

In [4]:
start = np.where(grid == "S")
start_x, start_y = start[0][0], start[1][0]

goal = np.where(grid == "E")
goal_x, goal_y = goal[0][0], goal[1][0]

frontier = PriorityQueue()
frontier.put((0, NodeState(position=(start_x, start_y), direction=(0, 1))))

cost_so_far = {start_state: 0 for _, start_state in frontier.queue}
came_from = {start_state: None for _, start_state in frontier.queue}

while not frontier.empty():
    _, current = frontier.get()
    if current.position == goal:
        break

    for next_state, price in get_neighbors(current):
        new_cost = cost_so_far[current] + price
        if next_state not in cost_so_far or new_cost < cost_so_far[next_state]:
            cost_so_far[next_state] = new_cost
            priority = new_cost + enhanced_heuristic(
                *next_state.position, goal_x, goal_y
            )
            frontier.put((priority, next_state))
            came_from[next_state] = current

cost_so_far[current]

134588

In [5]:
frontier = PriorityQueue()
start_state = NodeState(position=(start_x, start_y), direction=(0, 1))
frontier.put((0, start_state))

cost_so_far = {start_state: 0}

# Since we want all shortest paths, we won't stop upon first reaching the goal.
while not frontier.empty():
    current_cost, current = frontier.get()
    # If the cost is higher than a known better cost, skip
    if cost_so_far[current] < current_cost:
        continue

    # Explore neighbors
    for next_state, price in get_neighbors(current):
        new_cost = current_cost + price
        if next_state not in cost_so_far or new_cost < cost_so_far[next_state]:
            cost_so_far[next_state] = new_cost
            # No heuristic for part 2, just Dijkstra's
            frontier.put((new_cost, next_state))

# At this point, we have cost_so_far for all reachable states.
# Find the minimal cost to reach the goal tile (ignore direction at first).
min_goal_cost = None
goal_states = []
for state, c in cost_so_far.items():
    if state.position == (goal_x, goal_y):
        if min_goal_cost is None or c < min_goal_cost:
            min_goal_cost = c
            goal_states = [state]
        elif c == min_goal_cost:
            goal_states.append(state)

if min_goal_cost is None:
    raise ValueError("No path found to the goal.")

# Now we want to find all states that lie on any shortest path.
# We'll reconstruct edges that conform to shortest paths:
# For each state, consider neighbors. If dist[u] + cost(u->v) = dist[v],
# then u is predecessor of v on a shortest path.

predecessors = {}  # Mapping: state -> list of predecessor states on shortest paths
for state, dist in cost_so_far.items():
    for next_state, edge_cost in get_neighbors(state):
        if next_state in cost_so_far:
            # Check if this edge is on a shortest path
            if dist + edge_cost == cost_so_far[next_state]:
                # next_state has state as a predecessor on a shortest path
                if next_state not in predecessors:
                    predecessors[next_state] = []
                predecessors[next_state].append(state)

# Starting from each goal_state, do a BFS/DFS backward using predecessors
on_shortest_path = set()

stack = list(goal_states)
while stack:
    node = stack.pop()
    if node in on_shortest_path:
        continue
    on_shortest_path.add(node)
    if node in predecessors:
        for p in predecessors[node]:
            if p not in on_shortest_path:
                stack.append(p)

# Now on_shortest_path contains all NodeStates on a shortest path.
# We only need the tile positions. Since multiple directions can be on shortest paths,
# a tile is on a shortest path if ANY of its NodeStates are in on_shortest_path.

on_shortest_path_positions = set(ns.position for ns in on_shortest_path)

# Mark these on the grid. Let's mark them with 'O' as in the puzzle description
# (except for S and E which should remain).
for x, y in on_shortest_path_positions:
    if grid[x, y] not in ["#"]:
        grid[x, y] = "O"

In [6]:
(grid == "O").sum().sum()

631