In [1]:
from collections import namedtuple
from typing import Optional
import heapq


In [2]:
with open("input.txt", "rt") as f:
    heat_loss_map = []
    for row in f.read().strip().split("\n"):
        heat_loss_map.append(list(map(int, row)))

n_rows = len(heat_loss_map)
n_cols = len(heat_loss_map[0])

In [3]:
Point = namedtuple("Point", ["x", "y"])


def top(p: Point) -> Point:
    return Point(p.x, p.y - 1)


def left(p: Point) -> Point:
    return Point(p.x - 1, p.y)


def right(p: Point) -> Point:
    return Point(p.x + 1, p.y)


def bottom(p: Point) -> Point:
    return Point(p.x, p.y + 1)


def move(p: Point, direction: str) -> Optional[Point]:
    if direction == "t":
        return top(p) if p.y > 0 else None
    elif direction == "l":
        return left(p) if p.x > 0 else None
    elif direction == "r":
        return right(p) if p.x < n_cols - 1 else None
    elif direction == "b":
        return bottom(p) if p.y < n_rows - 1 else None

# Part 1

In [4]:
final_cost = float("inf")
cache = {}
paths = [
    (
        0,  # heat loss it takes to get here
        Point(0,0),  # point at which we are now
        0,  # number of steps in straight line
        "r",  # direction from which we came
    )
]

_cache_hits = 0
_heat_loss_hits = 0

while paths:
    heat_loss, position, steps, direction = heapq.heappop(paths)

    if heat_loss >= final_cost:
        _heat_loss_hits += 1
        continue

    if position.x == n_cols - 1 and position.y == n_rows - 1:
        final_cost = min(heat_loss, final_cost)
        continue
    
    if direction == "t":
        to_go = "tlr"
    elif direction == "l":
        to_go = "ltb"
    elif direction == "r":
        to_go = "rtb"
    elif direction == "b":
        to_go = "blr"

    for new_direction in to_go:
        new_steps = 0 if direction != new_direction else steps + 1
        if new_steps == 3:
            continue
        
        new_position = move(position, new_direction)
        if new_position is None:
            continue

        new_heat_loss = heat_loss + heat_loss_map[new_position.y][new_position.x]

        key = (new_position, new_direction, new_steps)
        if key in cache and cache[key] <= new_heat_loss:
            _cache_hits += 1
            continue
        cache[key] = new_heat_loss

        heapq.heappush(paths, (new_heat_loss, new_position, new_steps, new_direction))

print(f"{_cache_hits = }")
print(f"{_heat_loss_hits = }")
final_cost

_cache_hits = 365572
_heat_loss_hits = 1393


742

# Part 2

In [5]:
final_cost = float("inf")
cache = {}
paths = [
    (
        0,  # heat loss it takes to get here
        Point(0,0),  # point at which we are now
        "r",  # direction from which we came
    ),
]

_cache_hits = 0
_heat_loss_hits = 0


while paths:
    heat_loss, position, direction = heapq.heappop(paths)

    if heat_loss >= final_cost:
        _heat_loss_hits += 1
        continue

    if position.x == n_cols - 1 and position.y == n_rows - 1:
        final_cost = min(heat_loss, final_cost)
        continue

    new_position = position
    new_heat_loss = heat_loss
    for step in range(1, 10 + 1):


        new_position = move(new_position, direction)
        if new_position is None:
            break
        
        new_heat_loss += heat_loss_map[new_position.y][new_position.x]

        if step >= 4:
            if direction in "lr":
                to_go = "tb"
            elif direction in "tb":
                to_go = "lr"
            
            for new_direction in to_go:
                key = (new_direction, new_position)
                if key in cache and cache[key] <= new_heat_loss:
                    _cache_hits += 1
                    continue
                cache[key] = new_heat_loss

                heapq.heappush(paths, (new_heat_loss, new_position, new_direction))

print(f"{_cache_hits = }")
print(f"{_heat_loss_hits = }")
final_cost

_cache_hits = 2266644
_heat_loss_hits = 1279


918