In [63]:
from collections import deque
from functools import cache
from itertools import product

import numpy as np

In [64]:
# Maps a (row_delta, col_delta) grid step to the arrow key that produces it.
# Insertion order is the order in which neighbours are explored by the search.
move2dir = {
    (-1, 0): '^',
    (1, 0): 'v',
    (0, -1): '<',
    (0, 1): '>',
}

In [65]:
def parse_input(file):
    """Read the text file at `file` and return its lines as a list of strings.

    Line terminators are stripped (via `str.splitlines`).
    """
    with open(file) as handle:
        contents = handle.read()
    return contents.splitlines()


def get_keys2coords(grid):
    """Map every key label in `grid` to the coordinates of its first occurrence.

    Parameters
    ----------
    grid : numpy array of single-character labels. Cells labelled 'Z' are
        treated as gaps and are left out of the result.

    Returns
    -------
    dict mapping each label (as a plain Python `str`) to a tuple of integer
    indices into `grid`. Keys are inserted in the sorted order produced by
    `np.unique`, so the iteration order is the same on every run.
    """
    keys2coords = {}
    # np.unique already de-duplicates and sorts; wrapping it in a set would
    # only make the ordering depend on string-hash randomisation.
    for key in np.unique(grid):
        if key == 'Z':
            continue
        first_hit = np.argwhere(grid == key)[0]
        keys2coords[key.item()] = tuple(first_hit.tolist())
    return keys2coords


def path2dirs(path):
    """Convert a sequence of grid coordinates into arrow keys followed by 'A'.

    Each consecutive pair of coordinates is turned into its (row, col)
    difference and looked up in `move2dir`; the resulting symbols are
    concatenated and terminated with 'A'.
    """
    arrows = ''.join(
        move2dir[(row_to - row_from, col_to - col_from)]
        for (row_from, col_from), (row_to, col_to) in zip(path, path[1:])
    )
    return arrows + 'A'


def get_min_paths(grid, start, end, keys2coords):
    """Return every shortest key-press sequence that moves from `start` to `end`.

    Parameters
    ----------
    grid : 2-D numpy array of key labels; cells equal to 'Z' cannot be entered.
    start, end : key labels, looked up in `keys2coords`.
    keys2coords : dict mapping a key label to its (row, col) position in `grid`.

    Returns
    -------
    list of str, one per shortest path, each produced by `path2dirs` (arrow
    symbols followed by 'A'), in breadth-first discovery order.

    Raises
    ------
    ValueError if `end` cannot be reached from `start`.
    """
    start = keys2coords[start]
    end = keys2coords[end]
    n_rows, n_cols = grid.shape

    # Breadth-first search over whole paths (not just nodes) so that *all*
    # shortest routes are collected rather than only the first one found.
    queue = deque([[start]])
    min_paths = []

    while queue:
        path = queue.popleft()

        # The queue yields paths in non-decreasing length, so as soon as one is
        # longer than a path that already reached the target, nothing shorter
        # or equal remains and the search can stop.
        if min_paths and len(path) > len(min_paths[0]):
            break

        x_current, y_current = path[-1]
        if (x_current, y_current) == end:
            min_paths.append(path)
            # Extending a finished path can never reach `end` again without
            # revisiting it, so there is no point in expanding it.
            continue

        for nx, ny in move2dir:
            x_next, y_next = x_current + nx, y_current + ny
            inside = 0 <= x_next < n_rows and 0 <= y_next < n_cols
            if inside and grid[x_next, y_next] != 'Z' and (x_next, y_next) not in path:
                queue.append(path + [(x_next, y_next)])

    if not min_paths:
        raise ValueError(f"no path from {start} to {end} on the given grid")

    return [path2dirs(p) for p in min_paths]


def get_all_minpaths(keypad_str):
    """Build the table of shortest key sequences between every pair of keys.

    `keypad_str` is a newline-separated layout, one character per key. The
    result is a nested dict: `table[src][dst]` holds whatever
    `get_min_paths` returns for moving from key `src` to key `dst`.
    """
    rows = [list(line) for line in keypad_str.splitlines()]
    grid = np.array(rows)
    keys2coords = get_keys2coords(grid)

    return {
        src: {
            dst: get_min_paths(grid, src, dst, keys2coords)
            for dst in keys2coords
        }
        for src in keys2coords
    }

def get_all_minpaths_keypads():
    """Return the shortest-path tables for both keypads.

    The first element is the table for the numeric layout, the second the
    table for the directional layout. In both layouts 'Z' marks the gap.
    """
    numeric_layout = '789\n456\n123\nZ0A'
    directional_layout = 'Z^A\n<v>'
    return get_all_minpaths(numeric_layout), get_all_minpaths(directional_layout)


def get_next_sequences(current_seq, minpaths):
    """Expand `current_seq` into every candidate sequence one keypad up.

    Starting from key 'A', each consecutive key transition in `current_seq`
    is replaced by one of the options in `minpaths[src][dst]`; the result
    is the list of all combinations, concatenated into strings.
    """
    options_per_step = []
    previous = 'A'
    for key in current_seq:
        options_per_step.append(minpaths[previous][key])
        previous = key
    return [''.join(choice) for choice in product(*options_per_step)]


def get_minpaths_dir_lengths(minpaths_dir):
    """Reduce a table of candidate sequences to a table of shortest lengths.

    For each `src`/`dst` pair in `minpaths_dir`, the result stores the length
    of the shortest string in `minpaths_dir[src][dst]`.
    """
    return {
        src: {dst: min(map(len, sequences)) for dst, sequences in targets.items()}
        for src, targets in minpaths_dir.items()
    }


@cache
def compute_min_length(key1, key2, depth):
    """Minimum number of presses needed to go from `key1` to `key2` and press it.

    Parameters
    ----------
    key1, key2 : keys of the directional keypad.
    depth : number of directional keypads stacked on top of this one.
        `depth == 0` means the keys are pressed directly, so one transition
        costs exactly one press.

    Returns
    -------
    The smallest total length over all candidate sequences in the
    module-level `minpaths_dir` table, expanded recursively `depth` times.

    Raises
    ------
    ValueError if `depth` is negative.

    Notes
    -----
    Results are memoised with `functools.cache` and depend on the module-level
    `minpaths_dir` / `minpaths_dir_lengths` tables; call
    `compute_min_length.cache_clear()` if those tables are rebuilt.
    """
    if depth < 0:
        raise ValueError(f"depth must be non-negative, got {depth}")
    if depth == 0:
        # No keypad above this one: moving to a key and pressing it is one press.
        # Without this base case depth=0 would recurse without ever terminating.
        return 1
    if depth == 1:
        return minpaths_dir_lengths[key1][key2]

    min_length = float('inf')
    for seq in minpaths_dir[key1][key2]:
        # Each candidate sequence is typed on the keypad above, whose arm starts at 'A'.
        current_length = sum(
            compute_min_length(src, dst, depth - 1)
            for src, dst in zip('A' + seq, seq)
        )
        min_length = min(min_length, current_length)

    return min_length


def main(file, n_intermediary_robots):
    """Compute the summed complexity of every code listed in `file`.

    Parameters
    ----------
    file : path of a text file with one code per line.
    n_intermediary_robots : value passed as `depth` to `compute_min_length`.

    Returns
    -------
    The sum, over all codes, of (shortest sequence length) * (numeric part of
    the code). The numeric part is the integer formed by the code's decimal
    digits; a code without any digit contributes 0.
    """
    codes = parse_input(file)

    total_complexity = 0

    for code in codes:
        if not code.strip():
            # Tolerate blank lines in the input file.
            continue

        min_len_code = float('inf')
        for seq in get_next_sequences(code, minpaths_num):
            # Positional `depth` matches the recursive calls inside
            # compute_min_length, so both share the same cache entries.
            seq_len = sum(
                compute_min_length(key1, key2, n_intermediary_robots)
                for key1, key2 in zip('A' + seq, seq)
            )
            min_len_code = min(min_len_code, seq_len)

        # int() already ignores leading zeros, so no stripping is needed; the
        # explicit empty check avoids int('') on codes such as "000A" or "A".
        digits = ''.join(char for char in code if char.isdecimal())
        code_num_part = int(digits) if digits else 0
        total_complexity += min_len_code * code_num_part

    return total_complexity

In [66]:
# Precompute the shortest-path tables once. main() reads `minpaths_num` and
# compute_min_length() reads `minpaths_dir` / `minpaths_dir_lengths` as
# module-level names, so this cell must run before either is called.
minpaths_num, minpaths_dir = get_all_minpaths_keypads()
# Shortest sequence length for every pair of directional-keypad keys.
minpaths_dir_lengths = get_minpaths_dir_lengths(minpaths_dir)

In [67]:
# Sanity check: example1.txt with 2 intermediary robots must total 126384.
assert main('example1.txt', n_intermediary_robots=2) == 126384

In [68]:
# Total complexity for input.txt with 2 intermediary robots.
main('input.txt', n_intermediary_robots=2)

162740

In [69]:
# Total complexity for input.txt with 25 intermediary robots.
main('input.txt', n_intermediary_robots=25)

203640915832208