In [1]:
import heapq
import os
from collections import defaultdict
from copy import deepcopy

import numpy as np

from PIL import Image

from tol_colors import tol_cmap

import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
# Read the puzzle input as an array of strings. The newline delimiter is meant
# to keep every input line as a single field, and comments=None stops numpy
# from treating '#' as the start of a comment.
data = np.genfromtxt('day17_input.txt', dtype=str, delimiter='\n', comments=None)

In [3]:
def gen_graph(data):
    """Convert rows of digit characters into a 2-D integer array.

    Parameters
    ----------
    data : iterable of str
        One string per grid row; every character must be accepted by ``int``.

    Returns
    -------
    numpy.ndarray
        Array whose element ``[r, c]`` is the integer value of character ``c``
        of row ``r``.
    """
    rows = [[int(symbol) for symbol in row] for row in data]
    return np.array(rows)

def dijkstra(graph, start, end, min_chain=1, max_chain=3):
    """Cheapest path on a weighted grid with limits on straight runs.

    A search state is ``(row, col, direction, chain)`` where ``direction`` is
    an index into ``dxdy`` and ``chain`` is the number of consecutive steps
    already taken in that direction. Entering a cell costs ``graph[row][col]``;
    the start cell itself is free.

    Movement rules:
      * reversing direction is never allowed;
      * a turn is only allowed once ``chain >= min_chain``;
      * at most ``max_chain`` consecutive steps may go in one direction;
      * the very first step (from the start sentinel, ``chain == 0``) may go
        in any direction.

    Parameters
    ----------
    graph : 2-D array-like of non-negative numbers
        Cost of entering each cell, indexed ``graph[row][col]``.
    start, end : (row, col)
        Start and target cells.
    min_chain, max_chain : int
        Minimum run length before turning/stopping and maximum run length.

    Returns
    -------
    distances : dict
        Maps every state in which the walker may stop or turn
        (``chain >= min_chain``) to its cost. All start states are seeded
        with cost 0.
    routes : dict
        Maps every reached state to the list of
        ``(distance, (row, col), direction, chain)`` entries leading to it,
        beginning with the start entry ``(0, start, 0, 0)``.
    end_state : tuple
        The cheapest ``(row, col, direction, chain)`` state located at ``end``.

    Raises
    ------
    ValueError
        If ``end`` cannot be reached under the given run limits.
    """
    # Direction index -> (row delta, col delta)
    # 0 >
    # 1 v
    # 2 <
    # 3 ^
    dxdy = [[0,1], [1,0], [0,-1], [-1,0]]

    start = (start[0], start[1])
    n_rows = len(graph)
    n_cols = len(graph[0]) if n_rows else 0

    # Sentinel entry for "standing on the start, not yet moving".
    start_entry = (0, start, 0, 0)  # distance, position, direction, chain

    # Seed every start state with cost 0 so the search never re-enters the
    # start cell and so callers can look up any state located at the start.
    distances = {(start[0], start[1], 0, 0): 0}
    routes = {(start[0], start[1], 0, 0): [start_entry]}
    for direction in range(len(dxdy)):
        for chain in range(1, max_chain+1):
            distances[(start[0], start[1], direction, chain)] = 0
            routes[(start[0], start[1], direction, chain)] = [start_entry]

    # Every state that has ever been queued. This is tracked separately from
    # `distances` because states with chain < min_chain are not valid stopping
    # points (so they stay out of `distances`) but must still be expanded only
    # once; otherwise later, worse pushes would overwrite their route.
    visited = set(distances)

    queue = [start_entry]
    while queue:
        cur_dist, cur_pos, cur_dir, cur_chain = heapq.heappop(queue)
        cur_route = routes[(cur_pos[0], cur_pos[1], cur_dir, cur_chain)]

        # Only the start sentinel has chain 0; it has no real heading yet.
        at_start = cur_chain == 0

        for nxt_dir in range(len(dxdy)):
            if at_start:
                nxt_chain = 1
            else:
                # cannot backtrack
                if nxt_dir == (cur_dir+2)%4:
                    continue
                # must move at least min_chain cells before turning
                if cur_chain < min_chain and nxt_dir != cur_dir:
                    continue
                nxt_chain = cur_chain+1 if nxt_dir == cur_dir else 1

            # cannot travel more than max_chain in one direction
            if nxt_chain > max_chain:
                continue

            nxt_pos = (cur_pos[0]+dxdy[nxt_dir][0], cur_pos[1]+dxdy[nxt_dir][1])
            # position must exist...
            if not (0 <= nxt_pos[0] < n_rows and 0 <= nxt_pos[1] < n_cols):
                continue
            # ...and the state must not have been reached before. Marking on
            # push is sufficient here: every predecessor of a state sits on
            # the same cell and pays the same entry cost, so the first
            # predecessor popped from the heap yields the cheapest arrival.
            nxt_key = (nxt_pos[0], nxt_pos[1], nxt_dir, nxt_chain)
            if nxt_key in visited:
                continue
            visited.add(nxt_key)

            nxt_dist = cur_dist+graph[nxt_pos[0]][nxt_pos[1]]
            if nxt_chain >= min_chain:
                distances[nxt_key] = nxt_dist

            nxt_entry = (nxt_dist, nxt_pos, nxt_dir, nxt_chain)
            heapq.heappush(queue, nxt_entry)
            # Route entries are immutable tuples, so a shallow copy is enough.
            routes[nxt_key] = cur_route + [nxt_entry]

    # Candidate end states, in the same (direction, chain) order as before so
    # ties resolve to the same state.
    ends = [(end[0], end[1], direction, chain)
            for direction in range(len(dxdy))
            for chain in range(1, max_chain+1)
            if (end[0], end[1], direction, chain) in distances]
    if not ends:
        raise ValueError(
            'end %r is not reachable from %r with min_chain=%r, max_chain=%r'
            % ((end[0], end[1]), start, min_chain, max_chain))

    best_end = min(ends, key=lambda state: distances[state])

    return distances, routes, best_end

In [4]:
def create_frames(data, min_chain=1, max_chain=3):
    """Render the cheapest route through the grid as a series of PNG frames.

    The grid is built with ``gen_graph`` and solved with ``dijkstra`` from the
    top-left to the bottom-right cell. One frame is written per step of the
    final route. Frame ``h`` shows the cheapest recorded route to the cell the
    final route occupies at step ``h``, drawn over a background coloured by
    cell weight.

    Parameters
    ----------
    data : iterable of str
        Grid rows as strings of digits (see ``gen_graph``).
    min_chain, max_chain : int
        Run-length limits forwarded to ``dijkstra``.

    Side effects
    ------------
    Creates ``./Day17-Frames`` if necessary and writes
    ``day17_<min_chain>_<max_chain>_<frame>.png`` files into it.
    """
    graph = gen_graph(data)
    goal = (len(graph)-1, len(graph[0])-1)
    distances, routes, end_state = dijkstra(graph, (0,0), goal, min_chain, max_chain)

    # Total cost of the final route, used to normalise route colours into
    # [0, 1]. Falls back to 1 so a zero-cost route cannot divide by zero.
    max_los = distances[end_state] or 1

    # Background: colour every cell by its weight in one vectorised call.
    # Weights are divided by 9, i.e. they are assumed to be single digits.
    background_cmap = tol_cmap('rainbow_PuRd')
    blank_frame = np.asarray(background_cmap(1-(graph/9)))[..., :3]*255

    route_cmap = tol_cmap('iridescent')
    n_dirs = 4  # dijkstra encodes the heading as an index 0..3

    out_dir = './Day17-Frames'
    os.makedirs(out_dir, exist_ok=True)

    final_route = routes[end_state]
    for h, place in enumerate(final_route):
        frame = blank_frame.copy()
        row, col = place[1]

        # All recorded states located on this cell, in (direction, chain)
        # order so that ties resolve to the first candidate.
        candidates = [(row, col, direction, chain)
                      for direction in range(n_dirs)
                      for chain in range(1, max_chain+1)
                      if (row, col, direction, chain) in distances]
        if candidates:
            best_state = min(candidates, key=lambda state: distances[state])
            partial_route = routes[best_state]
        else:
            # With min_chain > 1 a cell may only ever be crossed mid-run and
            # so has no recorded state; show the final route up to here.
            partial_route = final_route[:h+1]

        # Each route entry carries its own cumulative cost. Using it directly
        # avoids looking up intermediate states (chain < min_chain) that
        # dijkstra does not store in `distances`.
        for los, (r, c), _direction, _chain in partial_route:
            colour = np.array(route_cmap(los/max_los)[:3])*255
            frame[r, c] = colour.astype(int)

        image = Image.fromarray(frame.astype('uint8'), mode='RGB')
        # PIL sizes are (width, height) = (columns, rows).
        image = image.resize((frame.shape[1]*8, frame.shape[0]*8), resample=Image.NEAREST)
        image.save(out_dir+'/day17_'+str(min_chain)+'_'+str(max_chain)+'_'+str(h).zfill(4)+'.png')

In [5]:
# Render the frames with the default run limits (min_chain=1, max_chain=3).
create_frames(data)

# Render a second set of frames with runs of at least 4 and at most 10 cells.
create_frames(data, 4, 10)