In [48]:
from pathlib import Path
from heapq import heappush, heappop

# Compass directions, numbered clockwise starting from north.
NORTH, EAST, SOUTH, WEST = range(4)

def load_file(path):
    """Read a grid of single-digit integers, one row per line.

    Each non-blank line becomes a list of ints, one per character (after
    stripping surrounding whitespace). Blank lines, e.g. an extra empty line
    at the end of the file, are skipped: keeping them would add empty rows
    and break the rectangular-grid indexing done by ``trace``/``trace2``.

    Args:
        path: path-like or str pointing at the text file.

    Returns:
        list[list[int]]: the grid rows in file order.
    """
    with Path(path).open() as f:
        return [[int(ch) for ch in line.strip()] for line in f if line.strip()]

In [49]:
# Inspect the parsed example grid.
load_file("17_test.txt")

[[2, 4, 1, 3, 4, 3, 2, 3, 1, 1, 3, 2, 3],
 [3, 2, 1, 5, 4, 5, 3, 5, 3, 5, 6, 2, 3],
 [3, 2, 5, 5, 2, 4, 5, 6, 5, 4, 2, 5, 4],
 [3, 4, 4, 6, 5, 8, 5, 8, 4, 5, 4, 5, 2],
 [4, 5, 4, 6, 6, 5, 7, 8, 6, 7, 5, 3, 6],
 [1, 4, 3, 8, 5, 9, 8, 7, 9, 8, 4, 5, 4],
 [4, 4, 5, 7, 8, 7, 6, 9, 8, 7, 7, 6, 6],
 [3, 6, 3, 7, 8, 7, 7, 9, 7, 9, 6, 5, 3],
 [4, 6, 5, 4, 9, 6, 7, 9, 8, 6, 8, 8, 7],
 [4, 5, 6, 4, 6, 7, 9, 9, 8, 6, 4, 5, 3],
 [1, 2, 2, 4, 6, 8, 6, 8, 6, 5, 5, 6, 3],
 [2, 5, 4, 6, 5, 4, 8, 8, 8, 7, 7, 3, 5],
 [4, 3, 2, 2, 6, 7, 4, 6, 5, 5, 5, 3, 3]]

In [103]:
def trace(grid, max_run=3):
    """Return the minimal total cost from the top-left to the bottom-right cell.

    Dijkstra search over states ``(row, col, run, direction)``, where ``run``
    counts the consecutive cells entered while moving in ``direction``.
    Entering a cell costs its value; the start cell itself is free. A move
    may never reverse the current direction, and at most ``max_run`` cells
    may be entered in a straight line before turning.

    Args:
        grid: rectangular list of lists of non-negative ints.
        max_run: longest allowed straight run (default 3, the original
            hard-coded limit).

    Returns:
        The minimal total cost, or ``None`` if the grid is empty or the
        bottom-right cell is unreachable.
    """
    if not grid or not grid[0]:
        return None
    m = len(grid)
    n = len(grid[0])
    steps = {NORTH: (-1, 0), EAST: (0, 1), SOUTH: (1, 0), WEST: (0, -1)}
    opposite = {NORTH: SOUTH, SOUTH: NORTH, EAST: WEST, WEST: EAST}
    # run == 0 marks the start state: both EAST and SOUTH begin a fresh run.
    queue = [(0, 0, 0, 0, EAST)]
    # States are marked when first pushed rather than when popped. This is
    # still exact: the cost of entering a state depends only on its cell, and
    # states are popped in non-decreasing cost order, so the first pusher is
    # always a cheapest predecessor.
    seen = set()
    while queue:
        h, i, j, cnt, d = heappop(queue)
        if i == m - 1 and j == n - 1:
            return h
        for nd, (di, dj) in steps.items():
            if nd == opposite[d]:
                continue
            ni, nj = i + di, j + dj
            if not (0 <= ni < m and 0 <= nj < n):
                continue
            if nd == d:
                if cnt >= max_run:
                    continue
                new_cnt = cnt + 1
            else:
                new_cnt = 1
            state = (ni, nj, new_cnt, nd)
            if state not in seen:
                seen.add(state)
                heappush(queue, (h + grid[ni][nj], ni, nj, new_cnt, nd))
    return None

def trace2(grid, min_run=4, max_run=10):
    """Return the minimal total cost and its path under min/max run limits.

    Dijkstra search over states ``(row, col, run, direction)``, where ``run``
    counts the consecutive cells entered while moving in ``direction``.
    Entering a cell costs its value; the start cell itself is free. A move
    may never reverse the current direction. At least ``min_run`` cells must
    be entered in a direction before turning, or before stopping on the
    bottom-right cell. At most ``max_run`` cells may be entered in a straight
    line.

    Args:
        grid: rectangular list of lists of non-negative ints.
        min_run: shortest straight run before a turn or stop (default 4).
        max_run: longest allowed straight run (default 10).

    Returns:
        ``(cost, path)``, where ``path`` lists the ``(row, col)`` cells
        entered after the start, ending at the bottom-right cell. The start
        ``(0, 0)`` is not included. Returns ``None`` if the grid is empty or
        no valid path exists; with ``min_run >= 1`` this includes a 1x1 grid.
    """
    if not grid or not grid[0]:
        return None
    m = len(grid)
    n = len(grid[0])
    steps = {NORTH: (-1, 0), EAST: (0, 1), SOUTH: (1, 0), WEST: (0, -1)}
    opposite = {NORTH: SOUTH, SOUTH: NORTH, EAST: WEST, WEST: EAST}
    # Two zero-length start states: the first run may head EAST or SOUTH.
    queue = [(0, 0, 0, 0, EAST), (0, 0, 0, 0, SOUTH)]
    # parent maps each pushed state to the state it was first pushed from.
    # It doubles as the visited set. Marking on push is exact because the cost
    # of entering a state depends only on its cell. The path is rebuilt once
    # at the end instead of copying a growing list into every heap entry.
    parent = {}
    while queue:
        h, i, j, cnt, d = heappop(queue)
        if i == m - 1 and j == n - 1 and cnt >= min_run:
            path = []
            state = (i, j, cnt, d)
            # Start states (run == 0) are never pushed, so they have no parent.
            while state in parent:
                path.append(state[:2])
                state = parent[state]
            path.reverse()
            return h, path
        for nd, (di, dj) in steps.items():
            if nd == opposite[d]:
                continue
            ni, nj = i + di, j + dj
            if not (0 <= ni < m and 0 <= nj < n):
                continue
            if nd == d:
                if cnt >= max_run:
                    continue
                new_cnt = cnt + 1
            elif cnt < min_run:
                continue
            else:
                new_cnt = 1
            state = (ni, nj, new_cnt, nd)
            if state not in parent:
                parent[state] = (i, j, cnt, d)
                heappush(queue, (h + grid[ni][nj], ni, nj, new_cnt, nd))
    return None

In [104]:
grid = load_file("17_test.txt")
part1_test = trace(grid)
# Turn the previously recorded expected value into an executable check.
assert part1_test == 102, part1_test
part1_test

102

In [105]:
grid = load_file("17_input.txt")
part1 = trace(grid)
# Regression check against the answer previously recorded for this input.
assert part1 == 886, part1
part1

886

In [106]:
grid = load_file("17_test.txt")
part2_test = trace2(grid)
# Turn the previously recorded expected value into an executable check.
assert part2_test is not None and part2_test[0] == 94, part2_test
part2_test

(94,
 [(0, 1),
  (0, 2),
  (0, 3),
  (0, 4),
  (0, 5),
  (0, 6),
  (0, 7),
  (0, 8),
  (1, 8),
  (2, 8),
  (3, 8),
  (4, 8),
  (4, 9),
  (4, 10),
  (4, 11),
  (4, 12),
  (5, 12),
  (6, 12),
  (7, 12),
  (8, 12),
  (9, 12),
  (10, 12),
  (11, 12),
  (12, 12)])

In [112]:
grid = load_file("17_test2.txt")
res, path = trace2(grid)
# Turn the previously recorded expected value into an executable check.
assert res == 71, res
res

71

In [113]:
# Mark the path on a copy so `grid` keeps the loaded costs and can still be
# passed to trace/trace2 after this cell runs.
marked = [row[:] for row in grid]
for i, j in path:
    marked[i][j] = 0
marked

[[1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1],
 [9, 9, 9, 9, 9, 9, 9, 0, 9, 9, 9, 1],
 [9, 9, 9, 9, 9, 9, 9, 0, 9, 9, 9, 1],
 [9, 9, 9, 9, 9, 9, 9, 0, 9, 9, 9, 1],
 [9, 9, 9, 9, 9, 9, 9, 0, 0, 0, 0, 0]]

In [114]:
grid = load_file("17_input.txt")
res, path = trace2(grid)
# Regression check against the answer previously recorded for this input.
assert res == 1055, res
res

1055

Exception in thread Thread-9:
Traceback (most recent call last):
  File "/opt/homebrew/Caskroom/miniconda/base/envs/aoc/lib/python3.12/threading.py", line 1052, in _bootstrap_inner
    self.run()
  File "/opt/homebrew/Caskroom/miniconda/base/envs/aoc/lib/python3.12/site-packages/zmq/utils/garbage.py", line 48, in run
    if msg == b'DIE':
       ^^^
  File "/opt/homebrew/Caskroom/miniconda/base/envs/aoc/lib/python3.12/site-packages/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_trace_dispatch_regular.py", line 203, in trace_dispatch
    py_db.enable_tracing(thread_trace_func)
  File "/opt/homebrew/Caskroom/miniconda/base/envs/aoc/lib/python3.12/site-packages/debugpy/_vendored/pydevd/pydevd.py", line 1107, in enable_tracing
    pydevd_tracing.SetTrace(thread_trace_func)
  File "/opt/homebrew/Caskroom/miniconda/base/envs/aoc/lib/python3.12/site-packages/debugpy/_vendored/pydevd/pydevd_tracing.py", line 87, in SetTrace
    if set_trace_to_threads(tracing_func, thread_idents=[thread.get_id