In [1]:
from collections import deque
from functools import cache
from tqdm import tqdm

In [2]:
# Select which puzzle file to load; swap which assignment is commented out to
# switch files (presumably the worked example vs. the full puzzle input).
filename = "sample2.txt"
# filename = "input.txt"
# Read the whole file into memory as UTF-8 text.
with open(filename, encoding="utf-8") as f:
    data = f.read()

# Trim leading/trailing whitespace first so a trailing newline does not yield an
# empty final row, then split into one string per line of the map.
grid = data.strip().split("\n")

https://adventofcode.com/2024/day/10

In [3]:
## Part 1
# Given a heightmap, find the score of each trailhead.
# A trailhead is a cell of height 0; its score is the number of height-9 cells
# that can be reached from it using only steps that climb by exactly 1.

# Position -> height lookup. A cell in column x of row y is keyed by the complex
# number x + y*1j. Characters that are not decimal digits are left out, so those
# cells are simply absent from the mapping.
heights = {
    complex(col, row): int(char)
    for row, text in enumerate(grid)
    for col, char in enumerate(text)
    if char.isdecimal()
}
# Unit steps in the order east, south, west, north (row index grows downwards,
# so +1j points south).
dirs = [1, 1j, -1, -1j]

def adjacent(pos: complex) -> list[complex]:
    """Return the positions one step away from ``pos``, one per entry in ``dirs``.

    The result follows the order of ``dirs``. No check is made that the
    neighbours lie on the map; callers are expected to filter them.
    """
    neighbours = []
    for offset in dirs:
        neighbours.append(pos + offset)
    return neighbours

def reachable(pos: complex) -> list[complex]:
    """Return the neighbours of ``pos`` that are exactly one unit higher.

    Neighbours missing from ``heights`` (off the map, or non-digit cells) are
    skipped. The result keeps the order produced by ``adjacent``.

    Raises:
        KeyError: if ``pos`` itself is not a key of ``heights``.
    """
    wanted = heights[pos] + 1
    # heights.get() yields None for unknown positions, which never equals an int,
    # so missing neighbours drop out of the comparison naturally.
    return [nbr for nbr in adjacent(pos) if heights.get(nbr) == wanted]

In [4]:
def get_score(trailhead: complex) -> int:
    """Count the distinct height-9 cells reachable from ``trailhead``.

    Explores the map with an explicit stack, following only the uphill steps
    offered by ``reachable``. Each position is processed at most once, so a peak
    that can be reached along several routes still counts once.
    """
    visited = set()
    stack = deque([trailhead])
    peaks = 0
    while stack:
        here = stack.pop()
        if here in visited:
            # Already handled via another route.
            continue
        visited.add(here)

        if heights[here] == 9:
            # A peak: count it and stop climbing from here.
            peaks += 1
        else:
            # Not at the top yet; queue every valid uphill neighbour.
            stack.extend(reachable(here))

    return peaks

In [5]:
# Every height-0 cell is a trailhead.
trailheads = [cell for cell, level in heights.items() if level == 0]

# Part 1 answer: the sum of all trailhead scores.
# (To debug a single trailhead, loop over `trailheads` and print each
# get_score(t) instead of summing them directly.)
total_score = sum(map(get_score, tqdm(trailheads)))
total_score

100%|██████████| 9/9 [00:00<00:00, 9615.06it/s]


36

In [6]:
## Part 2
# A trailhead's rating is the number of distinct hiking trails that start there.
# One (trailhead, peak) pair may be joined by many different trails, and many
# trails share the same tail, so results are memoised per position.
@cache
def trails_from(pos: complex) -> int:
    """Return the number of distinct uphill trails from ``pos`` to any height-9 cell.

    A height-9 cell counts as one complete trail. Any other cell contributes the
    sum over its ``reachable`` neighbours, which is 0 when there are none.
    Results are cached per position by ``functools.cache``.
    """
    if heights[pos] != 9:
        return sum(map(trails_from, reachable(pos)))
    # Reached a peak: exactly one (empty) continuation.
    return 1

In [7]:
# trails_from is memoised on position only, yet its result depends on the global
# `heights`. Clear the cache first so that re-running this cell after loading a
# different grid cannot silently reuse ratings computed for the previous one.
trails_from.cache_clear()

# Part 2 answer: the sum of all trailhead ratings.
total_rating = sum(trails_from(t) for t in tqdm(trailheads))
total_rating

100%|██████████| 9/9 [00:00<00:00, 12114.49it/s]


81