In [1]:
import numpy as np
#from collections import deque
from collections import defaultdict
from copy import deepcopy
import heapq

In [2]:
# Small example grid: one string per row, one digit character per cell.
# gen_graph() turns each digit into an integer cell cost.
test = ['2413432311323',
        '3215453535623',
        '3255245654254',
        '3446585845452',
        '4546657867536',
        '1438598798454',
        '4457876987766',
        '3637877979653',
        '4654967986887',
        '4564679986453',
        '1224686865563',
        '2546548887735',
        '4322674655533']

# Hand-written drawing of a route through `test`: the cells on the route hold
# an arrow character ('>', 'v', '<', '^') instead of their digit.
# NOTE(review): presumably the expected route for the default chain limits
# (min_chain=1, max_chain=3) -- confirm against the puzzle statement.
test_path = ['2>>34^>>>1323',
             '32v>>>35v5623',
             '32552456v>>54',
             '3446585845v52',
             '4546657867v>6',
             '14385987984v4',
             '44578769877v6',
             '36378779796v>',
             '465496798688v',
             '456467998645v',
             '12246868655<v',
             '25465488877v5',
             '43226746555v>']

# Puzzle input read as strings, split on newlines. comments=None switches off
# genfromtxt's comment-character handling.
# NOTE(review): the file path is relative to the notebook's working directory.
data = np.genfromtxt('day17_input.txt', dtype=str, delimiter='\n', comments=None)

In [3]:
def print_test_path(test_path):
    """Print a route drawing with every digit character masked as '.'.

    Args:
        test_path: sequence of rows; each row is a string (or sequence of
            one-character strings) mixing digits and route symbols.

    Each row is printed on its own line with non-digit characters kept
    unchanged, and a single blank line is printed after the last row.
    """
    for row in test_path:
        masked = ''.join('.' if symbol.isdigit() else symbol for symbol in row)
        print(masked)
    print()

def gen_graph(data):
    """Convert rows of digit characters into a numpy array of integers.

    Args:
        data: iterable of rows, each an iterable of single characters that
            ``int()`` accepts.

    Returns:
        ``np.array`` built from the nested list of per-character integers
        (one inner list per row of ``data``).
    """
    return np.array([[int(char) for char in row] for row in data])

def dijkstra(graph, start, end, min_chain=1, max_chain=3, prnt=False):
    """Cheapest path from ``start`` to ``end`` under straight-run limits.

    A search state is ``(row, col, direction, chain)`` where ``chain`` is the
    number of consecutive steps already taken in ``direction``.  Entering a
    cell costs that cell's value; the start cell itself costs nothing.

    Movement rules:
      * no reversing (a step may not go opposite to the previous step);
      * a turn is only allowed once ``chain >= min_chain``;
      * at most ``max_chain`` consecutive steps in one direction;
      * the path may only finish on ``end`` with ``chain >= min_chain``;
      * the start cell is never re-entered.

    The very first step is free to go in any direction: the seed state has
    ``chain == 0`` and therefore no direction constraints apply to it.

    Args:
        graph: 2-D grid of non-negative cell costs, indexable as
            ``graph[row][col]`` (numpy array or nested lists).
        start: ``(row, col)`` of the starting cell.
        end: ``(row, col)`` of the target cell.
        min_chain: minimum run length before turning or stopping.
        max_chain: maximum run length in a single direction.
        prnt: when true, also return the route that was taken.

    Returns:
        ``(distance, route)``.  ``route`` is ``None`` unless ``prnt`` is true,
        in which case it is a list of ``(distance, (row, col), direction,
        chain)`` tuples beginning with ``(0, start, 0, 0)``.  Direction codes
        are 0 '>', 1 'v', 2 '<', 3 '^'.

    Raises:
        ValueError: if ``end`` cannot be reached under the given limits.
    """
    # (row delta, column delta) per direction code: 0 '>', 1 'v', 2 '<', 3 '^'
    dxdy = [(0, 1), (1, 0), (0, -1), (-1, 0)]
    n_rows, n_cols = len(graph), len(graph[0])
    start_pos = (start[0], start[1])
    end_pos = (end[0], end[1])

    # chain == 0 marks "no step taken yet"; the direction slot is a placeholder.
    origin = (0, start_pos, 0, 0)
    if start_pos == end_pos:
        return 0, ([origin] if prnt else None)

    # Every edge into a given state costs the same (the value of its cell) and
    # the heap pops in non-decreasing distance order, so the first time a
    # state is discovered is already via its cheapest predecessor.  States can
    # therefore be finalised when pushed, and each state is pushed only once.
    seen = set()
    parents = {}  # heap entry -> heap entry it was reached from (prnt only)
    queue = [origin]  # entries are (distance, position, direction, chain)
    goal = None
    while queue:
        entry = heapq.heappop(queue)
        cur_dist, cur_pos, cur_dir, cur_chain = entry

        # First valid end state popped is the cheapest one; ties resolve by
        # (direction, chain) through tuple ordering in the heap.
        if cur_pos == end_pos and cur_chain >= min_chain:
            goal = entry
            break

        for nxt_dir, (d_row, d_col) in enumerate(dxdy):
            if cur_chain > 0:
                # cannot backtrack
                if nxt_dir == (cur_dir + 2) % 4:
                    continue
                # must move at least min_chain in the current direction
                if cur_chain < min_chain and nxt_dir != cur_dir:
                    continue

            # cannot travel more than max_chain in one direction
            nxt_chain = cur_chain + 1 if nxt_dir == cur_dir else 1
            if nxt_chain > max_chain:
                continue

            nxt_pos = (cur_pos[0] + d_row, cur_pos[1] + d_col)
            # position must exist...
            if not (0 <= nxt_pos[0] < n_rows and 0 <= nxt_pos[1] < n_cols):
                continue
            # ...must not be the start cell...
            if nxt_pos == start_pos:
                continue
            # ...and must not have been discovered in this state before
            state = (nxt_pos[0], nxt_pos[1], nxt_dir, nxt_chain)
            if state in seen:
                continue
            seen.add(state)

            nxt_entry = (cur_dist + graph[nxt_pos[0]][nxt_pos[1]],
                         nxt_pos, nxt_dir, nxt_chain)
            heapq.heappush(queue, nxt_entry)
            if prnt:
                parents[nxt_entry] = entry

    if goal is None:
        raise ValueError('no route from start to end satisfies the chain limits')

    if not prnt:
        return goal[0], None

    # Walk the parent links back to the origin, then flip into travel order.
    route = []
    node = goal
    while node is not None:
        route.append(node)
        node = parents.get(node)
    route.reverse()
    return goal[0], route

def print_path(graph, route, max_chain=3):
    """Print ``route`` as arrows laid over a grid shaped like ``graph``.

    Args:
        graph: 2-D grid; only its row count and row lengths are used.
        route: iterable of ``(distance, (row, col), direction, chain)``
            tuples.  Entries positioned at ``(0, 0)`` are not drawn.
        max_chain: accepted for signature compatibility; not used.

    Cells on the route show the arrow for the direction stored in the route
    entry (0 '>', 1 'v', 2 '<', 3 '^'); when a cell occurs more than once the
    last entry wins.  All other cells show '.'.  A blank line follows the grid.
    """
    arrows = ('>', 'v', '<', '^')

    # later entries for the same cell overwrite earlier ones
    heading_at = {step[1]: step[2] for step in route if step[1] != (0, 0)}

    for row_idx, row in enumerate(graph):
        cells = [
            arrows[heading_at[(row_idx, col_idx)]] if (row_idx, col_idx) in heading_at else '.'
            for col_idx in range(len(row))
        ]
        print(''.join(cells))
    print()

def least_loss(data, min_chain=1, max_chain=3, prnt=False):
    """Return the minimum path cost from the top-left to the bottom-right cell.

    Args:
        data: rows of digit characters, converted to a grid by ``gen_graph``.
        min_chain: minimum run length, forwarded to ``dijkstra``.
        max_chain: maximum run length, forwarded to ``dijkstra``.
        prnt: when true, the route found is also drawn with ``print_path``.

    Returns:
        The distance reported by ``dijkstra`` for the route from ``(0, 0)``
        to the last row / last column of the grid.
    """
    grid = gen_graph(data)
    last_row = len(grid) - 1
    last_col = len(grid[0]) - 1

    total, steps = dijkstra(grid, (0, 0), (last_row, last_col),
                            min_chain, max_chain, prnt)
    if prnt:
        print_path(grid, steps, max_chain)
    return total

# Show the hand-written reference drawing for `test` with the digits masked.
print_test_path(test_path)

# Run with the default limits (min_chain=1, max_chain=3): first on the
# example grid, then on the puzzle input.
print(least_loss(test, prnt=False))
print('Part 1 result:', least_loss(data))

.>>..^>>>....
..v>>>..v....
........v>>..
..........v..
..........v>.
...........v.
...........v.
...........v>
............v
............v
...........<v
...........v.
...........v>

102
Part 1 result: 953


In [4]:
# Second example grid: cost-1 cells along the top row and the right-hand
# column, cost-9 cells everywhere else.
test2 = ['111111111111',
         '999999999991',
         '999999999991',
         '999999999991',
         '999999999991']

# Hand-written drawing of a route through `test`.
# NOTE(review): presumably the expected route for min_chain=4, max_chain=10 --
# confirm against the puzzle statement.
test_path_p2 = ['2>>>>>>>>1323',
                '32154535v5623',
                '32552456v4254',
                '34465858v5452',
                '45466578v>>>>',
                '143859879845v',
                '445787698776v',
                '363787797965v',
                '465496798688v',
                '456467998645v',
                '122468686556v',
                '254654888773v',
                '432267465553v']

# Hand-written drawing of a route through `test2` (same caveat as above).
test2_path = ['1>>>>>>>1111',
              '9999999v9991',
              '9999999v9991',
              '9999999v9991',
              '9999999v>>>>']

# Re-run with min_chain=4 and max_chain=10: both example grids (each preceded
# by its reference drawing), then the puzzle input.
print_test_path(test_path_p2)
print(least_loss(test, 4, 10, prnt=False))
print_test_path(test2_path)
print(least_loss(test2, 4, 10, prnt=False))
print('Part 2 result:', least_loss(data, 4, 10))

.>>>>>>>>....
........v....
........v....
........v....
........v>>>>
............v
............v
............v
............v
............v
............v
............v
............v

94
.>>>>>>>....
.......v....
.......v....
.......v....
.......v>>>>

71
Part 2 result: 1180
