In [3]:
import itertools

import networkx as nx
import numpy as np


In [4]:
def parse_input(file):

    with open(file, 'r') as file_in:
        rows = file_in.read().splitlines()

    rows = [[int(x) for x in r] for r in rows]

    return np.array(rows)

In [6]:
def get_neighbours(grid, n):
    x, y, from_dir, n_last_turn = n
    n_rows, n_cols = grid.shape
    neighbours = []

    if n == (0, 0, None, None):
        neighbours.extend([(0, 1, 'W', 1), (1, 0, 'N', 1)])

    # When coming from North, can't move North
    if from_dir == 'N':
        if y > 0:  # Move West
            neighbours.append((x, y - 1, 'E', 1))
        if y < n_cols - 1:  # Move East
            neighbours.append((x, y + 1, 'W', 1))
        if (x < n_rows - 1) and (n_last_turn < 3):  # Move South
            neighbours.append((x + 1, y, 'N', n_last_turn+1))

    # When coming from West, can't move West
    elif from_dir == 'W':
        if x > 0:  # Move North
            neighbours.append((x - 1, y, 'S', 1))
        if x < n_rows - 1:  # Move South
            neighbours.append((x + 1, y, 'N', 1))
        if (y < n_cols - 1) and (n_last_turn < 3):  # Move East
            neighbours.append((x, y + 1, 'W', n_last_turn+1))

    # When coming from South, can't move South
    elif from_dir == 'S':
        if y > 0:  # Move West
            neighbours.append((x, y - 1, 'E', 1))
        if y < n_cols - 1:  # Move East
            neighbours.append((x, y + 1, 'W', 1))
        if x > 0 and (n_last_turn < 3):  # Move North
            neighbours.append((x - 1, y, 'S', n_last_turn+1))

    # When coming from East, can't move East
    elif from_dir == 'E':
        if x > 0:  # Move North
            neighbours.append((x - 1, y, 'S', 1))
        if x < n_rows - 1:  # Move South
            neighbours.append((x + 1, y, 'N', 1))
        if y > 0 and (n_last_turn < 3):  # Move West
            neighbours.append((x, y - 1, 'E', n_last_turn+1))

    return neighbours

In [31]:
def dijkstra(grid):
    d = {}
    s_deb = (0, 0, None, None)
    d[s_deb] = 0
    P = set()

    previous = {}

    for i in range(10000):
        a = min({k: d[k] for k in d if k not in P}, key=d.get)
        P.add(a)
        for b in get_neighbours(grid, a):
            if b not in d:
                d[b] = float('inf')
            w_a_b = grid[b[0], b[1]]
            if d[b] > d[a] + w_a_b:
                d[b] = d[a] + w_a_b
                previous[b] = a

    # min_heat_loss = min([v for k, v in d.items() if (k[0] == grid.shape[0] - 1) and k[1] == grid.shape[1] - 1])

    return d

In [32]:
file = 'puzzle.txt'
grid = parse_input(file)
dijkstra(grid)

{(0, 0, None, None): 0,
 (0, 1, 'W', 1): 3,
 (1, 0, 'N', 1): 2,
 (1, 1, 'W', 1): 6,
 (2, 0, 'N', 2): 3,
 (1, 1, 'N', 1): 7,
 (0, 2, 'W', 2): 6,
 (2, 1, 'W', 1): 4,
 (3, 0, 'N', 3): 7,
 (1, 1, 'S', 1): 8,
 (3, 1, 'N', 1): 5,
 (2, 2, 'W', 2): 8,
 (3, 0, 'E', 1): 9,
 (3, 2, 'W', 1): 7,
 (4, 1, 'N', 2): 8,
 (0, 1, 'S', 1): 9,
 (2, 1, 'N', 1): 7,
 (1, 2, 'W', 2): 8,
 (1, 2, 'N', 1): 8,
 (0, 3, 'W', 3): 9,
 (1, 0, 'E', 1): 9,
 (1, 2, 'W', 1): 9,
 (2, 1, 'N', 2): 8,
 (3, 1, 'W', 1): 8,
 (2, 2, 'S', 1): 11,
 (4, 2, 'N', 1): 9,
 (3, 3, 'W', 2): 9,
 (2, 0, 'E', 1): 8,
 (2, 2, 'W', 1): 11,
 (3, 1, 'N', 2): 8,
 (0, 1, 'S', 2): 11,
 (1, 2, 'S', 1): 10,
 (3, 2, 'N', 1): 10,
 (2, 3, 'W', 3): 9,
 (4, 0, 'E', 1): 9,
 (4, 2, 'W', 1): 10,
 (5, 1, 'N', 3): 9,
 (0, 2, 'S', 1): 11,
 (2, 2, 'N', 1): 12,
 (1, 3, 'W', 3): 12,
 (1, 1, 'E', 1): 12,
 (1, 3, 'W', 1): 12,
 (2, 2, 'N', 2): 12,
 (3, 1, 'N', 3): 9,
 (2, 1, 'S', 1): 9,
 (4, 1, 'N', 1): 11,
 (3, 2, 'W', 2): 10,
 (1, 0, 'S', 1): 10,
 (3, 0, 'N', 1): 12,


In [2]:
def plot_path(grid, path):
    grid_plot = grid.copy()
    for x, y, _, _ in path:
        grid_plot[x, y] = 99
    print(grid_plot)


def heat_path(grid, path):
    heat = 0
    for x, y, _, _ in path:
        heat += grid[x, y]
    print(heat)


path = []
s = (12, 12, 'W', 1)

while s != s_deb:
    path.append(s)
    s = previous[s]

path = list(reversed(path))


plot_path(grid, path)
print(heat_path(grid, path))


NameError: name 's_deb' is not defined