In [1]:
year = 2023
day = 17

In [148]:
from aocd import submit
from aocd.models import Puzzle
from functools import reduce
import numpy as np

puzzle = Puzzle(year=year, day=day)
data = puzzle.input_data
# data = puzzle.examples[0].input_data

data = data.strip()
data = data.split("\n")
data = [list(line) for line in data]
data = np.array(data).astype("int")


headings = {
    "N": (-1, 0),
    "S": (1, 0),
    "E": (0, 1),
    "W": (0, -1),
}

neighbours_90_deg = {
    "N": [(-1, 0), (0, -1), (0, 1)],
    "S": [(1, 0), (0, -1), (0, 1)],
    "E": [(0, 1), (-1, 0), (1, 0)],
    "W": [(0, -1), (-1, 0), (1, 0)],
}

lookup_heading = {
    (-1, 0): "N",
    (1, 0): "S",
    (0, 1): "E",
    (0, -1): "W",
}

H, W = data.shape

data

array([[5, 3, 4, ..., 3, 3, 2],
       [1, 4, 3, ..., 3, 2, 4],
       [4, 3, 1, ..., 4, 4, 2],
       ...,
       [4, 3, 3, ..., 2, 3, 5],
       [4, 2, 3, ..., 3, 3, 4],
       [1, 2, 4, ..., 1, 1, 2]])

In [191]:
def get_neighbours(y, x, h, t, max_t=2, min_t=0):
    n = []
    for d_y, d_x in neighbours_90_deg[h]:
        new_h = lookup_heading[(d_y, d_x)]
        new_t = t + 1 if new_h == h else 0
        if y + d_y < 0 or y + d_y >= H or x + d_x < 0 or x + d_x >= W:
            continue
        if (t+1 < min_t) and (h != new_h):
            continue
        if t+1 > max_t:
            continue
        n.append(((y + d_y, x + d_x), new_h, new_t))
    return n

In [194]:
from collections import defaultdict


def find_best_path(data, max_straight, min_straight=0):
    cache = defaultdict(lambda: 1_000_000_000)
    cache[((1,0), "S", 0)] = data[1,0]
    cache[((0,1), "E", 0)] = data[0,1]
    front = set(cache.keys())

    for i in range(1_000):
        new_front = set()
        for key in front:
            loc, h, t = key
            y, x = loc
            curr_loss = cache[key]
            neigh = get_neighbours(y, x, h, t, max_straight, min_straight)

            for next_key in neigh:
                loc, _, _ = next_key
                n_y, n_x = loc
                new_loss = data[n_y, n_x]

                if curr_loss + new_loss < cache[next_key]:
                    cache[next_key] = curr_loss + new_loss
                    new_front.add(next_key)

        if not new_front:
            print("done", i)
            break
        front = new_front
    return cache

In [195]:
cache = find_best_path(data, max_straight=3)
answer = min([value for key, value in cache.items() if key[0] == (H-1, W-1)])
submit(answer, part="a", year=year, day=day)

coerced int64 value 963 for 2023/17


done 365
Part a already solved with same answer: 963


In [196]:
cache = find_best_path(data, max_straight=10, min_straight=4)
answer = min([value for key, value in cache.items() if key[0] == (H-1, W-1)])
submit(answer, part="b", year=year, day=day)

coerced int64 value 1178 for 2023/17


done 378
Part b already solved with same answer: 1178
