In [65]:
import itertools
from heapq import heappush, heappop
import json

import numpy as np

In [66]:
def parse_input(file):
    """Load a grid of single-digit costs from a text file.

    Each line of *file* is one grid row; every character on the line is
    converted with ``int`` and becomes one cell.

    Returns:
        A 2-D NumPy integer array built from the rows.
    """
    with open(file, 'r') as handle:
        lines = handle.read().splitlines()

    return np.array([list(map(int, line)) for line in lines])

In [67]:
def get_neighbours(grid, n, max_straight=3):
    """Return the search states reachable in one step from state *n*.

    A state is ``(row, col, from_dir, steps)``:
      * ``from_dir`` is the side the cell was entered from ('N', 'S', 'W',
        'E'), or ``None`` for the start state; the move back through that
        side is never generated.
      * ``steps`` counts consecutive moves in the current direction: turning
        resets it to 1, going straight increments it, and going straight is
        only allowed while ``steps < max_straight``.

    Only in-bounds cells are returned. Order is: the two turns first
    (West/East when travelling vertically, North/South when travelling
    horizontally), then the straight move. The start state may leave East
    then South.

    Args:
        grid: 2-D array; only its ``shape`` is used here.
        n: current state tuple.
        max_straight: maximum number of consecutive moves in one direction
            (defaults to 3, the original hard-coded limit).

    Returns:
        List of neighbour state tuples.
    """
    # Heading -> (row delta, col delta).
    deltas = {'N': (-1, 0), 'S': (1, 0), 'W': (0, -1), 'E': (0, 1)}
    # Moving with a given heading enters the next cell from the opposite
    # side; the same mapping turns an entry side back into a heading.
    opposite = {'N': 'S', 'S': 'N', 'W': 'E', 'E': 'W'}

    x, y, from_dir, n_last_turn = n
    n_rows, n_cols = grid.shape

    if from_dir is None:
        # Start state: leave East, then South.
        candidates = [('E', 1), ('S', 1)]
    else:
        heading = opposite[from_dir]
        turns = ('W', 'E') if heading in ('N', 'S') else ('N', 'S')
        candidates = [(turn, 1) for turn in turns]
        if n_last_turn < max_straight:
            candidates.append((heading, n_last_turn + 1))

    neighbours = []
    for heading, steps in candidates:
        dx, dy = deltas[heading]
        nx, ny = x + dx, y + dy
        # Bounds check applies to the start state too, so tiny grids never
        # produce off-grid neighbours.
        if 0 <= nx < n_rows and 0 <= ny < n_cols:
            neighbours.append((nx, ny, opposite[heading], steps))

    return neighbours

In [68]:
def dijkstra(grid):
    """Return the minimum total cost from the top-left to the bottom-right cell.

    Runs Dijkstra's algorithm over the ``(row, col, from_dir, steps)`` states
    produced by ``get_neighbours``. Entering a cell adds ``grid[row, col]``
    to the cost; the starting cell itself is not counted.

    Because cell costs are non-negative, the first bottom-right state popped
    from the heap has the minimum cost over all bottom-right states, so the
    search stops there.

    Raises:
        ValueError: if no state at the bottom-right cell is reachable.
    """
    start = (0, 0, None, None)
    target = (grid.shape[0] - 1, grid.shape[1] - 1)

    best = {start: 0}
    queue = [(0, start)]

    while queue:
        cost, node = heappop(queue)
        if cost > best[node]:
            # Stale heap entry: this state was already reached more cheaply.
            continue
        if (node[0], node[1]) == target:
            return cost
        for nxt in get_neighbours(grid, node):
            new_cost = cost + grid[nxt[0], nxt[1]]
            if new_cost < best.get(nxt, float('inf')):
                best[nxt] = new_cost
                heappush(queue, (new_cost, nxt))

    raise ValueError("bottom-right cell is unreachable")

In [84]:
file = 'puzzle.txt'
grid = parse_input(file)

# Dijkstra over (row, col, from_dir, steps) states, also recording each
# state's predecessor in came_from so the route can be reconstructed.
priority_queue = []
s_deb = (0, 0, None, None)
heappush(priority_queue, (0, s_deb))

cost_so_far = {s_deb: 0}
came_from = {s_deb: None}

while priority_queue:
    current_cost, current_node = heappop(priority_queue)
    if current_cost > cost_so_far[current_node]:
        # Stale entry: a cheaper path to this state was already expanded,
        # so re-expanding it cannot improve any neighbour.
        continue
    for next_node in get_neighbours(grid, current_node):
        new_cost = current_cost + grid[next_node[0], next_node[1]]
        if next_node not in cost_so_far or new_cost < cost_so_far[next_node]:
            cost_so_far[next_node] = new_cost
            heappush(priority_queue, (new_cost, next_node))
            came_from[next_node] = current_node

end_cell = (grid.shape[0] - 1, grid.shape[1] - 1)
min_heat_loss = min(v for k, v in cost_so_far.items() if (k[0], k[1]) == end_cell)
min_heat_loss

970

In [70]:
def get_path(grid, came_from, n_end, start=(0, 0, None, None)):
    """Reconstruct the state sequence from *start* to *n_end*.

    Follows the predecessor map *came_from* backwards from *n_end* until
    *start* is reached.

    Args:
        grid: unused; kept for call compatibility.
        came_from: mapping of state -> predecessor state.
        n_end: final state of the path.
        start: first state of the path. Previously this was read from the
            module-level ``s_deb``; the default is that same start state.

    Returns:
        List of states ordered from *start* to *n_end* inclusive.

    Raises:
        KeyError: if a state on the way back has no entry in *came_from*.
    """
    path = []
    node = n_end
    while node != start:
        path.append(node)
        node = came_from[node]
    path.append(start)
    path.reverse()
    return path


def dict_to_file(d, file):
    """Write *d* as indented JSON to the path *file*.

    Keys and values are converted with ``str`` first, so non-JSON types
    such as tuple keys and NumPy integers are stored as their string form.

    The original ignored *file* and always wrote ``test.json``. It also
    reused the name ``file`` for the handle.
    """
    serialisable = {str(k): str(v) for k, v in d.items()}
    # ensure_ascii=False may emit non-ASCII text, so pin the encoding.
    with open(file, 'w', encoding='utf-8') as handle:
        json.dump(serialisable, handle, ensure_ascii=False, indent=4)


def plot_path(grid, path, file):
    """Save a text rendering of *grid* with the cells of *path* marked.

    *grid* is converted to strings, every ``(row, col, _, _)`` state in
    *path* has its cell replaced by ``"#"``, and the result is written to
    *file* with ``np.savetxt`` using ``"%s"`` per cell.
    """
    canvas = grid.astype(str)
    for state in path:
        row, col, _entered_from, _steps = state
        canvas[row, col] = "#"
    np.savetxt(file, canvas, fmt="%s")