In [1]:
import numpy as np
from collections import defaultdict, deque

In [2]:
# Small example map used to exercise `run`: 'S' is the start, 'E' the end,
# '#' a wall; any other cell is walkable. Each whitespace-separated token
# below becomes one row string of the list.
test = '''
###############
#...#...#.....#
#.#.#.#.#.###.#
#S#...#.#.#...#
#######.#.#.###
#######.#.#...#
#######.#.###.#
###..E#...#...#
###.#######.###
#...###...#...#
#.#####.#.###.#
#.#...#.#.#...#
#.#.#.#.#.#.###
#...#...#...###
###############
'''.split()

In [3]:
def get_graph(data):
    """Parse a map into a 2-D character grid and locate its 'S' and 'E' cells.

    Parameters
    ----------
    data : iterable of str
        Map rows, e.g. the lines returned by ``file.readlines()``. Each row
        is stripped of surrounding whitespace (including the trailing
        newline); rows that are empty after stripping are skipped.

    Returns
    -------
    graph : numpy.ndarray
        2-D array with one single-character string per map cell.
    start : tuple
        ``(row, col)`` of the first ``'S'`` cell in row-major order.
    end : tuple
        ``(row, col)`` of the first ``'E'`` cell in row-major order.

    Raises
    ------
    ValueError
        If the map has no rows, or has no ``'S'`` or no ``'E'`` cell.
    """
    # Skip blank rows (e.g. a trailing empty line in an input file): an empty
    # row would make the rows ragged and np.array could not build a 2-D grid.
    rows = [list(line.strip()) for line in data if line.strip()]
    if not rows:
        raise ValueError('map has no rows')
    graph = np.array(rows)

    def locate(symbol):
        # argwhere lists matches in row-major order, matching np.where(...)[k][0].
        hits = np.argwhere(graph == symbol)
        if len(hits) == 0:
            raise ValueError(f'map has no {symbol!r} cell')
        return (hits[0][0], hits[0][1])

    return graph, locate('S'), locate('E')

def in_graph(graph, pos):
    """Return True if ``pos`` = (row, col) indexes a cell inside ``graph``.

    The row is checked against ``len(graph)``. The column is checked against
    the length of that particular row, so ragged lists of rows also work.
    """
    row, col = pos[0], pos[1]
    if not 0 <= row < len(graph):
        return False
    return 0 <= col < len(graph[row])

def BFS(graph, start):
    """Breadth-first search over the non-wall cells of ``graph`` from ``start``.

    Parameters
    ----------
    graph : numpy.ndarray
        2-D grid of single characters. ``'#'`` cells are walls; every other
        cell is walkable.
    start : tuple
        ``(row, col)`` of the source cell.

    Returns
    -------
    dict
        Maps each reached ``(row, col)`` to its step distance from ``start``.
        An ``'E'`` cell is recorded but never expanded, so cells reachable
        only by walking through ``'E'`` are absent.
    """
    dxdy = [[-1,0],[0,1],[1,0],[0,-1]]
    n_rows, n_cols = graph.shape

    queue = deque([start])
    dist = {start: 0}

    while queue:
        # popleft() gives FIFO order, i.e. a true BFS: the first time a cell is
        # discovered its distance is already minimal, so no re-relaxation is
        # needed. (pop() would make this a LIFO depth-first traversal.)
        cur_pos = queue.popleft()
        nxt_dst = dist[cur_pos] + 1

        for dx, dy in dxdy:
            nxt_pos = (cur_pos[0] + dx, cur_pos[1] + dy)
            # Without this check a negative index silently wraps to the
            # opposite edge, and an index past the edge raises IndexError.
            if not (0 <= nxt_pos[0] < n_rows and 0 <= nxt_pos[1] < n_cols):
                continue
            if nxt_pos in dist or graph[nxt_pos] == '#':
                continue
            dist[nxt_pos] = nxt_dst
            if graph[nxt_pos] != 'E':
                queue.append(nxt_pos)
    return dist
                
def get_clean_path(dist, start, end):
    """Recover the start-to-end route by walking the distance map backwards.

    Starting at ``end``, repeatedly step to a 4-neighbour whose distance is
    exactly one less than the current cell's, until ``start`` is reached.

    Parameters
    ----------
    dist : dict
        ``(row, col) -> distance from start``, as returned by ``BFS``.
    start, end : tuple
        ``(row, col)`` of the route's endpoints.

    Returns
    -------
    numpy.ndarray
        Integer array of shape ``(path_len, 2)`` listing the ``[row, col]``
        cells from ``start`` to ``end`` inclusive.

    Raises
    ------
    ValueError
        If ``end`` is not in ``dist`` (unreachable), or the walk hits a cell
        with no predecessor before arriving at ``start``.
    """
    if end not in dist:
        raise ValueError(f'end {end} is not reachable from start')

    dxdy = [[-1,0],[0,1],[1,0],[0,-1]]
    path = [end]

    while path[-1] != start:
        cur = path[-1]
        wanted = dist[cur] - 1
        for dx, dy in dxdy:
            prv_pos = (cur[0] - dx, cur[1] - dy)
            if dist.get(prv_pos) == wanted:
                path.append(prv_pos)
                break
        else:
            # No neighbour is one step closer, so dist cannot lead back to
            # start. Without this the loop would spin forever on this cell.
            raise ValueError(f'no predecessor found for {cur}; dist does not lead back to {start}')

    path.reverse()
    return np.array([[pos[0], pos[1]] for pos in path])

def count_shortcuts(clean_path, pos, dist, shortcuts, length=2, limit=100, prnt=False):
    """Count the cheats that start at ``pos`` and save at least ``limit`` steps.

    A cheat jumps from ``pos`` directly to another cell of ``clean_path``
    within Manhattan distance ``length``. Its saving is
    ``dist[target] - dist[pos] - manhattan(target, pos)``.

    Parameters
    ----------
    clean_path : numpy.ndarray
        ``(path_len, 2)`` array of ``[row, col]`` route cells.
    pos : array-like
        ``[row, col]`` of the cheat's start cell; ``(pos[0], pos[1])`` must
        be a key of ``dist``.
    dist : dict
        ``(row, col) -> distance from the race start``.
    shortcuts : int or collections.defaultdict(int)
        Running tally, updated and returned.
    length : int
        Maximum cheat length (Manhattan distance).
    limit : int
        Minimum saving for a cheat to be counted.
    prnt : bool
        If True, ``shortcuts`` is a mapping of saving -> count and is updated
        in place; otherwise it is an int count.

    Returns
    -------
    int or defaultdict
        The updated ``shortcuts`` tally.
    """
    start_dist = dist[(pos[0], pos[1])]

    # Manhattan distance from pos to every route cell, computed once and used
    # both to select candidates and as the cost of each cheat.
    cut_dist = np.abs(clean_path - pos).sum(axis=1)
    in_reach = cut_dist <= length
    targets = clean_path[in_reach]
    cut_dist = cut_dist[in_reach]

    target_dist = np.array([dist[(row, col)] for row, col in targets])
    delta = target_dist - start_dist - cut_dist

    saved = delta[delta >= limit]
    if prnt:
        for saving, count in zip(*np.unique(saved, return_counts=True)):
            shortcuts[saving] += count
    else:
        shortcuts += len(saved)

    return shortcuts

def get_shortcuts(graph, dist, clean_path, length=2, limit=100, prnt=False):
    """Tally cheats from every route cell that save at least ``limit`` steps.

    ``graph`` is accepted for interface compatibility but is not read here.
    Returns an int total when ``prnt`` is False; otherwise returns a
    ``defaultdict(int)`` filled by ``count_shortcuts`` (saving -> count).
    """
    tally = defaultdict(int) if prnt else 0

    for cell in clean_path:
        tally = count_shortcuts(clean_path, cell, dist, tally, length, limit, prnt)

    return tally
            
def _summarize_shortcuts(shortcuts, prnt):
    """Return the total cheat count, printing a per-saving breakdown if ``prnt``.

    ``shortcuts`` is an int when ``prnt`` is False, otherwise a mapping of
    saving -> number of cheats, whose values are summed.
    """
    if not prnt:
        return shortcuts
    for key in sorted(shortcuts):
        print('There are', shortcuts[key], 'cheats that save', key, 'picoseconds.')
    return sum(shortcuts.values())


def run(data, prnt=False, limit=(100, 100)):
    """Solve both puzzle parts for the map in ``data`` and print the results.

    Parameters
    ----------
    data : iterable of str
        Map rows, as accepted by ``get_graph``.
    prnt : bool
        If True, also print how many cheats achieve each saving.
    limit : sequence of two ints
        Minimum saving counted in part 1 and in part 2, respectively.
        A tuple default avoids the shared-mutable-default pitfall.
    """
    graph, start, end = get_graph(data)
    dist = BFS(graph, start)
    clean_path = get_clean_path(dist, start, end)

    # PART 1: get_shortcuts' default cheat length (2).
    shortcuts = get_shortcuts(graph, dist, clean_path, limit=limit[0], prnt=prnt)
    print('Part 1 result:', _summarize_shortcuts(shortcuts, prnt))

    # PART 2: cheats of length up to 20.
    shortcuts = get_shortcuts(graph, dist, clean_path, length=20, limit=limit[1], prnt=prnt)
    print('Part 2 result:', _summarize_shortcuts(shortcuts, prnt))
    
# Run on the example map, printing the per-saving breakdown; count cheats
# saving >= 1 step in part 1 and >= 50 steps in part 2.
run(test, prnt=True, limit=[1, 50])

There are 14 cheats that save 2 picoseconds.
There are 14 cheats that save 4 picoseconds.
There are 2 cheats that save 6 picoseconds.
There are 4 cheats that save 8 picoseconds.
There are 2 cheats that save 10 picoseconds.
There are 3 cheats that save 12 picoseconds.
There are 1 cheats that save 20 picoseconds.
There are 1 cheats that save 36 picoseconds.
There are 1 cheats that save 38 picoseconds.
There are 1 cheats that save 40 picoseconds.
There are 1 cheats that save 64 picoseconds.
Part 1 result: 44
There are 32 cheats that save 50 picoseconds.
There are 31 cheats that save 52 picoseconds.
There are 29 cheats that save 54 picoseconds.
There are 39 cheats that save 56 picoseconds.
There are 25 cheats that save 58 picoseconds.
There are 23 cheats that save 60 picoseconds.
There are 20 cheats that save 62 picoseconds.
There are 19 cheats that save 64 picoseconds.
There are 12 cheats that save 66 picoseconds.
There are 14 cheats that save 68 picoseconds.
There are 12 cheats that save

In [4]:
# Solve the real puzzle input. The `with` block closes the file on exit,
# so no explicit close() call is needed.
with open('input_day20.txt', 'r') as f:
    data = f.readlines()

run(data)

Part 1 result: 1372
Part 2 result: 979014
