In [271]:
import pandas as pd
import numpy as np
from functools import cache
from itertools import permutations

In [272]:
# Load the door codes; each line of the input file is one code, e.g. '382A'.
with open('../input/day21.txt', 'r') as infile:
    raw_text = infile.read()
lines = raw_text.splitlines()

lines[:2]

['382A', '463A']

In [273]:
def _digit_row_col(key):
    """Return (row, col) of digit key '1'-'9' on the numeric keypad.

    Rows count upward from the bottom '0'/'A' row (row 0), so '1'-'3' are on
    row 1 and '7'-'9' on row 3. Columns run 0-2 from left to right.
    Integer arithmetic replaces the float-based ``ceil(d / 3)``.
    """
    digit = int(key)
    return (digit + 2) // 3, (digit - 1) % 3


@cache
def dir_to_num_kp(key_cur, key_new):
    """Directional presses that move the numeric-keypad arm from key_cur to key_new.

    Keypad layout (the bottom-left cell is a gap the arm must never enter)::

        7 8 9     row 3
        4 5 6     row 2
        1 2 3     row 1
          0 A     row 0

    Move orders are chosen so the arm never passes over the gap. Where the gap
    allows, '<' comes first, then vertical moves, then '>'.

    Parameters
    ----------
    key_cur, key_new : str
        Single characters from '0123456789A'.

    Returns
    -------
    str
        The '<', '>', '^', 'v' moves, excluding the final 'A' press (callers
        append it).

    Raises
    ------
    ValueError
        If a key outside '0123456789A' reaches the digit arithmetic.
    """
    if key_cur == key_new:
        return ''

    # Moves within the bottom row.
    if key_cur == '0' and key_new == 'A':
        return '>'
    if key_cur == 'A' and key_new == '0':
        return '<'

    if key_cur == 'A':
        row_new, col_new = _digit_row_col(key_new)
        if col_new == 1:
            # Middle column: step left onto '0', up onto '2', then climb.
            return '<^' + dir_to_num_kp('2', key_new)
        # Go up first: moving left along row 0 from 'A' could reach the gap.
        return row_new * '^' + (2 - col_new) * '<'

    if key_cur == '0':
        # '0' sits directly below '2'; route through it.
        return '^' + dir_to_num_kp('2', key_new)

    row_cur, col_cur = _digit_row_col(key_cur)

    if key_new == 'A':
        if col_cur == 0:
            # Left column: move right before descending so we skip the gap.
            return '>>' + row_cur * 'v'
        return row_cur * 'v' + (2 - col_cur) * '>'

    if key_new == '0':
        if col_cur == 0:
            return '>' + row_cur * 'v'
        # Reach '2' (directly above '0') and then step down.
        return dir_to_num_kp(key_cur, '2') + 'v'

    # Digit to digit: rows 1-3 have no gap, so the fixed order is always safe.
    row_new, col_new = _digit_row_col(key_new)
    return (
        max(col_cur - col_new, 0) * '<'
        + max(row_cur - row_new, 0) * 'v'
        + max(row_new - row_cur, 0) * '^'
        + max(col_new - col_cur, 0) * '>'
    )

In [274]:
@cache
def dir_to_dir_kp(key_cur, key_new):
    """Moves that take the directional-keypad arm from key_cur to key_new.

    Layout (row 1 on top, the top-left cell is a gap)::

          ^ A     row 1
        < v >     row 0

    The result excludes the final 'A' press. Moves are emitted in the order
    'v', '<', '^', '>'. When leaving '<', the order is '>' then '^' instead,
    because going up from '<' would enter the gap.

    Raises KeyError for a key outside '<v>^A'.
    """
    positions = {'<': (0, 0), 'v': (0, 1), '>': (0, 2), '^': (1, 1), 'A': (1, 2)}
    r_to, c_to = positions[key_new]
    r_from, c_from = positions[key_cur]

    down = max(r_from - r_to, 0) * 'v'
    up = max(r_to - r_from, 0) * '^'
    left = max(c_from - c_to, 0) * '<'
    right = max(c_to - c_from, 0) * '>'

    if key_cur == '<':
        return right + up
    return down + left + up + right

In [275]:
def get_first_code(line):
    """Directional presses that make the numeric-keypad robot type ``line``.

    The numeric arm starts on 'A'. For each target key, the moves from
    dir_to_num_kp are emitted, followed by an 'A' press to activate it.
    """
    previous_keys = ['A', *line]
    return ''.join(
        dir_to_num_kp(src, dst) + 'A'
        for src, dst in zip(previous_keys, line)
    )

In [276]:
def get_next_code(command):
    """Expand a directional-keypad press sequence by one robot level.

    The arm starts on 'A'. For each press in ``command``, the moves from
    dir_to_dir_kp are emitted, followed by an 'A' press to activate it.
    """
    pieces = []
    arm = 'A'
    for target in command:
        pieces += [dir_to_dir_kp(arm, target), 'A']
        arm = target
    return ''.join(pieces)

In [277]:
@cache
def get_complexity(line, iters):
    """Press sequence after expanding door code ``line`` through ``iters`` robot levels.

    Despite the name, this returns the press sequence itself, not a complexity
    score. The caller multiplies its length by the code's numeric part.
    It starts from get_first_code(line) and applies get_next_code ``iters``
    times. Results are memoised per (line, iters).
    """
    sequence = get_first_code(line)
    for _level in range(iters):
        sequence = get_next_code(sequence)
    return sequence

In [278]:
# Part 1: two directional-keypad robot levels. Each code contributes
# (length of the expanded press sequence) * (numeric part of the code).
p1 = sum(len(get_complexity(line, 2)) * int(line[:3]) for line in lines)
p1

179444

In [279]:
# Moves between every ordered pair of directional-keypad buttons (the final
# 'A' press excluded). Layout, with a gap in the top-left cell:
#
#       ^ A
#     < v >
#
# get_all_paths permutes these strings, so only the multiset of moves in each
# entry matters; is_valid discards orderings that would cross the gap.
dpad_to_dpad = {
    ('<', '<'): '',    ('<', '^'): '>^', ('<', '>'): '>>', ('<', 'v'): '>',  ('<', 'A'): '>>^',
    ('^', '<'): 'v<',  ('^', '^'): '',   ('^', '>'): 'v>', ('^', 'v'): 'v',  ('^', 'A'): '>',
    ('>', '<'): '<<',  ('>', '^'): '<^', ('>', '>'): '',   ('>', 'v'): '<',  ('>', 'A'): '^',
    ('v', '<'): '<',   ('v', '^'): '^',  ('v', '>'): '>',  ('v', 'v'): '',   ('v', 'A'): '^>',
    ('A', '<'): 'v<<', ('A', '^'): '<',  ('A', '>'): 'v',  ('A', 'v'): 'v<', ('A', 'A'): '',
}

In [280]:
def is_valid(cur, nxt, path):
    """Return whether the move sequence ``path`` from ``cur`` to ``nxt`` avoids the gap.

    On the directional keypad the gap is the cell directly above '<'. Two cases
    would enter it: starting on '<' with an upward '^' move, or arriving on '<'
    with a final 'v' move (which means the arm was in the gap just before).
    An empty path is always valid.
    """
    if not path:
        return True
    enters_gap_first = cur == '<' and path[0] == '^'
    leaves_gap_last = nxt == '<' and path[-1] == 'v'
    return not (enters_gap_first or leaves_gap_last)

In [281]:
def get_all_paths(cur, nxt):
    """Every gap-safe way to press ``nxt`` when the directional arm is on ``cur``.

    Each candidate is an ordering of the moves in dpad_to_dpad[(cur, nxt)],
    followed by the 'A' press that activates ``nxt``. Orderings that would
    cross the keypad's gap are rejected by is_valid.

    itertools.permutations treats equal characters as distinct, so a move
    string such as '>>^' yields each ordering more than once. Duplicates are
    dropped (first-seen order kept) so callers do not evaluate them repeatedly.

    Returns
    -------
    list[str]
        Candidate press strings, each ending in 'A'.
    """
    unique_orders = dict.fromkeys(permutations(dpad_to_dpad[(cur, nxt)]))
    return ["".join(order) + 'A' for order in unique_orders if is_valid(cur, nxt, order)]

In [282]:
@cache
def totes(code, bot):
    """Minimum number of outermost presses needed to produce ``code`` through ``bot`` levels.

    ``code`` is a sequence of directional-keypad presses and ``bot`` is the
    number of directional-keypad levels still to expand. With ``bot == 0`` the
    cost is just ``len(code)``. Otherwise each transition (the arm starts on
    'A') is expanded into every candidate from get_all_paths, and the
    cheapest one is taken one level up. Memoised on (code, bot).
    """
    if bot == 0:
        return len(code)
    return sum(
        min(totes(candidate, bot - 1) for candidate in get_all_paths(src, dst))
        for src, dst in zip('A' + code, code)
    )

In [286]:
# Part 2, with part 1 re-derived as a cross-check: sum of complexities when
# 2 and 25 directional-keypad levels are expanded with the memoised totes().
for bot in (2, 25):
    length = sum(
        totes(get_first_code(line), bot) * int(line[:3])
        for line in lines
    )
    print(f'For {bot=} bots, the min {length=}')

For bot=2 bots, the min length=179444
For bot=25 bots, the min length=223285811665866
