In [126]:
from aocd.models import Puzzle

puzzle = Puzzle(year=2024, day=21)

def parses(data):
    """Split the raw puzzle text into a list with one string per line.

    Leading and trailing whitespace is stripped first, so a trailing
    newline does not produce an empty last entry.
    """
    trimmed = data.strip()
    return trimmed.split('\n')

# import re
# def parses(data):
#     return [[int(i) for i in re.findall("-?\d+", line)] 
#              for line in data.strip().split('\n')]

data = parses(puzzle.input_data)

In [219]:
sample = parses("""029A
980A
179A
456A
379A""")

(41713, 37581)

In [211]:
from collections import Counter
from functools import cache
"""
In order to solve this problem, it is key we exploit the structure of the setup.

- A key realization is that clicking any button on the numpad requires clicking A on ALL
   directional pads: it requires clicking A on the first robot, which is controlled by the second robot, etc.
-  This means that each digit can be solved independently. E.g. typing 029A reduces to solving
   1. move from A to 0  = <A
   2. move from 0 to 2  = ^A
   3. move from 2 to 9  = >^^A
   4. move from 9 to A  = vvvA
- We can treat the numpad as a "seed" sequence and then expand it at every step. See the example above for 029A
   
- We can empirically find that when moving, there's a priority
   0. Same move, e.g. we'll always prefer <<^ to <^<, as repeating a move is simply A. In practice this means 
      that we ignore any zigzaggy pattern and focus solely on L-shaped moves
   1. < moving left is the most expensive move. We want to avoid sequences of the
        form A<A, as that "trip" could have been better used to click on any other button
   2. v similarly, moving down requires two moves from A
   3. ^ and > are equally distant from A
   
- For part 2 the trick is realizing that the order doesn't really matter, as rule expansion is purely local.
  This is further simplified by the fact that, when translating a move through an extra layer
  of indirection, we can safely assume that it starts from "A", since the previous character had to be
  entered by finishing on A
"""

@cache
def moves():
    """Return the (src, dst) -> key presses lookup for the directional keypad.

    Every value is the arrow sequence from the table below followed by 'A'
    (the press itself). The four opposite-arrow pairs (^/v and </>) are
    marked None in the table and get no entry in the result.
    """
    keys = '^<v>A'

    # Rows are the source key and columns the destination key, both in `keys`
    # order. The table is deliberate: when two orders are possible (e.g. <v
    # vs v<) the < goes first, then v, and only then ^ or >.
    table = (
        ('', 'v<', None, 'v>', '>'),
        ('>^', '', '>', None, '>>^'),
        (None, '<', '', '>', '^>'),
        ('<^', None, '<', '', '^'),
        ('<', 'v<<', '<v', 'v', ''),
    )

    return {
        (src, dst): path + 'A'
        for src, row in zip(keys, table)
        for dst, path in zip(keys, row)
        if path is not None
    }

def numpad_seq(src, dst):
    numpad = ['789','456','123','X0A']
    nums = { v: (i,j)
        for i, row in enumerate(numpad)
        for j, v in enumerate(row)}
    y1, x1 = nums[src]
    y2, x2 = nums[dst]
    hmoves = '<>'[x2>x1]*abs(x1-x2)
    vmoves = '^v'[y2>y1]*abs(y1-y2)
    if (src in '741' and dst in '0A'): # can't go forbidden space
        return hmoves+vmoves+'A'
    if (src in '0A' and dst in '741'): # can't go forbidden space
        return vmoves+hmoves+'A'
    if hmoves.startswith('<'):
        return hmoves+vmoves+'A'
    if vmoves.startswith('v'):
        return vmoves+hmoves+'A'
    return hmoves+vmoves+'A' # doesn't matter

def solve_code(code):
    """Compute the part-1 press count for one door code.

    Each numpad move (starting from 'A') is turned into arrow presses by
    `numpad_seq` and then expanded twice through the `moves()` table; the
    lengths of the fully expanded strings are added up.

    Prints the total and the code, and returns (total length, integer value
    of the code without its last character).
    """
    table = moves()
    total_len = 0
    previous = 'A'
    for target in code:
        seq = numpad_seq(previous, target)
        for _ in range(2):
            pieces = []
            at = 'A'  # every expansion starts with the arm on 'A'
            for key in seq:
                pieces.append(table[at, key])
                at = key
            seq = ''.join(pieces)
        total_len += len(seq)
        previous = target
    print(total_len, code)
    return total_len, int(code[:-1])
        
def solve_a(data):
    """Sum length * numeric value over all codes, as reported by `solve_code`."""
    total = 0
    for length, num in map(solve_code, data):
        total += length * num
    return total


In [216]:

### PART 2

def as_transitions(seq):
    """Count the (from_key, to_key) transitions made while typing `seq`.

    The arm is taken to start on 'A', so the first transition is
    ('A', seq[0]). Returns a Counter keyed by (from_key, to_key) tuples;
    an empty `seq` gives an empty Counter.

    Relies on `Counter` from `collections`, which the file previously used
    without importing.
    """
    return Counter(zip('A' + seq, seq))

@cache
def move_transitions():
    """Return the `moves()` table with each press string replaced by its transition counts."""
    table = {}
    for pair, presses in moves().items():
        table[pair] = as_transitions(presses)
    return table

def remote(counts):
    """Push a transition Counter through one more directional keypad.

    Each (src, dst) transition occurring n times is replaced by the
    transitions of its expansion from `move_transitions()`, each scaled by n.
    Returns a new Counter; `counts` is not modified.
    """
    new_counts = Counter()
    for (src, dst), n in counts.items():
        expansion = move_transitions()[src, dst]
        for step in expansion:
            new_counts[step] += n * expansion[step]
    return new_counts
    
def solve_code_counts(code, n=25):
    """Compute the press count for one door code through `n` directional keypads.

    Works on transition counts instead of explicit strings: each numpad move
    is converted with `as_transitions` and expanded `n` times with `remote`.

    Returns (total press count, integer value of the code without its last
    character).
    """
    per_key_lengths = []
    for src, dst in zip('A' + code, code):
        trs = as_transitions(numpad_seq(src, dst))
        for _ in range(n):
            trs = remote(trs)
        per_key_lengths.append(sum(trs.values()))
    return sum(per_key_lengths), int(code[:-1])

def solve_b(data):
    """Sum length * numeric value over all codes, as reported by `solve_code_counts`."""
    total = 0
    for length, num in map(solve_code_counts, data):
        total += length * num
    return total

In [218]:
solve_b(sample)

154154076501218

In [None]:
a = pushes(remote_n('v<A', 10))
b = pushes(remote_n('<vA', 10))
a, b

In [217]:
solve_b(data)

157055032722640

In [215]:
solve_code_counts('029A')

(83373109466, 29)

In [160]:
numpad_seqs(None, None)

{(0, 0): '7',
 (0, 1): '8',
 (0, 2): '9',
 (1, 0): '4',
 (1, 1): '5',
 (1, 2): '6',
 (2, 0): '1',
 (2, 1): '2',
 (2, 2): '3',
 (3, 0): 'X',
 (3, 1): '0',
 (3, 2): 'A'}

In [None]:
init_seq = []

code = '029A'

def solve_code(code):
    """Brute-force length for one code: expand every candidate twice, keep the shortest.

    Returns (sum of the per-key minimum lengths, integer value of the code
    without its last character).
    """
    shortest = []
    for src, dst in zip('A' + code, code):
        candidates = numpad_seqs(src, dst)
        for _ in range(2):
            candidates = next_level(candidates)
        shortest.append(min(len(presses) for presses, _ in candidates))
    return sum(shortest), int(code[:-1])

solve_code('029A')

In [154]:
remote(as_transitions('<A'))

Counter({('A', 'v'): 1,
         ('v', '<'): 1,
         ('<', '<'): 1,
         ('<', 'A'): 1,
         ('A', '>'): 1,
         ('>', '>'): 1,
         ('>', '^'): 1,
         ('^', 'A'): 1})

In [7]:
# Pre-compute, for every (src, dst) entry of `fromto`, how often each
# (from_key, to_key) transition occurs while typing its sequence.
# NOTE(review): `'A'+seq` needs `seq` to be a string, so this presumably
# predates the list-valued `fromto` built further down -- confirm.
fromto_counts = {}
for (src, dst), seq in fromto.items():
    counts = Counter()
    # The arm starts on 'A', hence the 'A' prefix on the left-hand side.
    for a, b in zip('A'+seq, seq):
        counts[a,b] += 1
    fromto_counts[src, dst] = counts

In [8]:
fromto_counts

{('^', '^'): Counter({('A', 'A'): 1}),
 ('^', '<'): Counter({('A', 'v'): 1, ('v', '<'): 1, ('<', 'A'): 1}),
 ('^', '>'): Counter({('A', 'v'): 1, ('v', '>'): 1, ('>', 'A'): 1}),
 ('^', 'A'): Counter({('A', '>'): 1, ('>', 'A'): 1}),
 ('<', '^'): Counter({('A', '>'): 1, ('>', '^'): 1, ('^', 'A'): 1}),
 ('<', '<'): Counter({('A', 'A'): 1}),
 ('<', 'v'): Counter({('A', '>'): 1, ('>', 'A'): 1}),
 ('<',
  'A'): Counter({('A', '>'): 1, ('>', '>'): 1, ('>', '^'): 1, ('^', 'A'): 1}),
 ('v', '<'): Counter({('A', '<'): 1, ('<', 'A'): 1}),
 ('v', 'v'): Counter({('A', 'A'): 1}),
 ('v', '>'): Counter({('A', '>'): 1, ('>', 'A'): 1}),
 ('v', 'A'): Counter({('A', '>'): 1, ('>', '^'): 1, ('^', 'A'): 1}),
 ('>', '^'): Counter({('A', '<'): 1, ('<', '^'): 1, ('^', 'A'): 1}),
 ('>', 'v'): Counter({('A', '<'): 1, ('<', 'A'): 1}),
 ('>', '>'): Counter({('A', 'A'): 1}),
 ('>', 'A'): Counter({('A', '^'): 1, ('^', 'A'): 1}),
 ('A', '^'): Counter({('A', '<'): 1, ('<', 'A'): 1}),
 ('A',
  '<'): Counter({('A', 'v'):

In [13]:
def remote(counts):
    """Expand a transition Counter by one keypad layer using `fromto_counts`.

    Each transition occurring n times contributes n copies of the
    transitions stored for it. Returns a new Counter.
    """
    new_counts = Counter()
    for pair, n in counts.items():
        src, dst = pair
        for step, m in fromto_counts[src, dst].items():
            new_counts[step] += n * m
    return new_counts
    

In [23]:
sum(remote(remote(remote(Counter({('A', '>'): 1})))).values())

16

In [107]:
def as_transitions(seq):
    """Count (from_key, to_key) transitions in `seq`, with the arm starting on 'A'."""
    counts = Counter()
    prev = 'A'
    for key in seq:
        counts[prev, key] += 1
        prev = key
    return counts

In [51]:
def pushes(counts):
    """Return the total number of button presses in a transition mapping (sum of its values)."""
    total = 0
    for n in counts.values():
        total += n
    return total

In [None]:
179A

In [52]:
def remote_n(seq, n):
    """Convert `seq` to transition counts and apply `remote` to them `n` times."""
    transitions = as_transitions(seq)
    for _level in range(n):
        transitions = remote(transitions)
    return transitions

In [58]:
a = pushes(remote_n('<<^A', 2))
b = pushes(remote_n('^<<A', 2))
a, b

(22, 26)

In [125]:
a = pushes(remote_n('>vA', 2))
b = pushes(remote_n('v>A', 2))
a, b

(21, 17)

In [124]:
a = pushes(remote_n('>^A', 1))
b = pushes(remote_n('^>A', 1))
a, b

(7, 7)

In [49]:
as_transitions('<<^A')

Counter({('A', '>'): 2,
         ('A', 'v'): 1,
         ('v', '<'): 1,
         ('<', '<'): 1,
         ('<', 'A'): 1,
         ('A', 'A'): 1,
         ('>', '^'): 1,
         ('^', 'A'): 1,
         ('>', 'A'): 1})

In [None]:
codes = {
    
}

In [None]:
'^': 

In [95]:
sample = parses("""029A
980A
179A
456A
379A""")

In [81]:
def parse_keypad(keypad):
    """Parse a ':'-separated keypad layout into a {position: key} dict.

    Positions are complex numbers: the real part is the row index and the
    imaginary part is the column index.
    """
    return {
        complex(row_idx, col_idx): key
        for row_idx, row in enumerate(keypad.split(':'))
        for col_idx, key in enumerate(row)
    }

In [82]:
# Position -> key maps for both keypads; 'X' marks the gap in each layout.
numeric_keypad = parse_keypad("789:456:123:X0A")
directional_keypad = parse_keypad('X^A:<v>')
# Short aliases; these are shallow copies, so they are separate dict objects.
nump, dirp = numeric_keypad.copy(), directional_keypad.copy()

In [155]:
nump

{0j: '7',
 1j: '8',
 2j: '9',
 (1+0j): '4',
 (1+1j): '5',
 (1+2j): '6',
 (2+0j): '1',
 (2+1j): '2',
 (2+2j): '3',
 (3+0j): 'X',
 (3+1j): '0',
 (3+2j): 'A'}

In [83]:
# Reverse lookups: key -> position for the directional and numeric keypads.
rdirp = {v:k for k, v in dirp.items()}
rnump = {v:k for k, v in nump.items()}

In [88]:
# Arrow key -> position delta, in the same complex encoding as the keypads
# (real part = row, imaginary part = column). The trailing comments give the
# equivalent (row, col) tuple.
dirs = {
    '>': 1j, #(0,1),
    '<': -1j, #(0,-1),
    '^': -1, # (-1,0),
    'v': 1, #(1,0),
}

In [39]:
# def heuristic(state, final):
#     ent, num, d1, d2, d3 = state
#     missing = len(final)-len(ent)
# #     for c in final[missing:]:
#     return missing

In [41]:
# from heapq import heappop, heappush
# def search(char):
#     start_state = (0, keyp['A'], dirp['A'], dirp['A'], dirp['A'])
#     heap = [start_state]
#     while heap:
#         cost, heappop(heap)
        

In [335]:
# Sanity check: walking every stored sequence from `src` on the directional
# keypad must end on `dst`, and every sequence must finish with an 'A' press.
# NOTE(review): the loop variable `moves` rebinds the module-level name
# `moves`, which is also used for the cached table function -- running this
# cell afterwards breaks calls to `moves()`.
for (src, dst), moves in fromto.items():
    for seq in moves:
        pos = rdirp[src]
        # Skip the final character: it is the 'A' press, not a movement.
        for step in seq[:-1]:
            pos += dirs[step]
        assert seq[-1] == 'A'
        assert pos == rdirp[dst]

In [250]:
# Scratch run of one expansion step on a single (sequence, start key) pair.
seqs = [('<A', 'A')]

newseqs = []

for seq, src in seqs:
    # One list of alternatives per transition; a missing pair raises KeyError here.
    options = [fromto[a,b] for a, b in zip(src+seq, seq)]
    # Every combination of alternatives becomes one candidate sequence.
    newseqs += [(''.join(steps), seq[-1]) for steps in itertools.product(*options)]
newseqs

[('v<<A>>^A', 'A')]

In [251]:
def next_level(seqs):
    """Expand every (sequence, start key) pair by one keypad layer.

    Each transition found in `fromto` contributes its list of alternatives,
    and every combination of alternatives yields one new candidate.
    Transitions missing from `fromto` are skipped. The second element of
    each result is the last key of the sequence it was expanded from.
    """
    expanded = []
    for presses, start in seqs:
        choices = []
        for pair in zip(start + presses, presses):
            if pair in fromto:
                choices.append(fromto[pair])
        for combo in itertools.product(*choices):
            expanded.append((''.join(combo), presses[-1]))
    return expanded

In [252]:
init_seq = []

code = '029A'

def solve_code(code):
    """Return (shortest total length after two expansions, numeric value of `code`).

    For each numpad move, every candidate from `numpad_seqs` is expanded
    twice with `next_level` and the shortest resulting string is counted.
    """
    total = 0
    previous = 'A'
    for key in code:
        seqs = next_level(next_level(numpad_seqs(previous, key)))
        total += min(len(presses) for presses, _ in seqs)
        previous = key
    return total, int(code[:-1])

solve_code('029A')

(68, 29)

In [262]:
as_transitions('<A', 'A')

Counter({('A', '<'): 1, ('<', 'A'): 1})

In [253]:
def solve_a(data):
    """Print each code's (length, value) pair and return the sum of their products."""
    complexities = []
    for code in data:
        length, value = solve_code(code)
        print(length, value)
        complexities.append(length * value)
    return sum(complexities)

In [254]:
solve_a(sample) == 126384, solve_a(data) == 125742

68 29
60 980
68 179
64 456
64 379


126384

72 279
68 286
72 508
70 463
70 246


125742

In [90]:
def as_transitions(input):
    """Count (from_key, to_key) transitions for a (sequence, start key) pair.

    `input` is a 2-tuple; the start key is prepended to the sequence so the
    first transition is (start, sequence[0]).
    """
    presses, origin = input
    return Counter(zip(origin + presses, presses))

In [None]:
def next_level(seqs):
    """Return every one-layer expansion of each (sequence, start key) pair.

    Transitions that have no entry in `fromto` are ignored; all other
    transitions contribute their alternatives, combined exhaustively.
    """
    results = []
    for seq, src in seqs:
        known_pairs = [pair for pair in zip(src + seq, seq) if pair in fromto]
        alternatives = [fromto[pair] for pair in known_pairs]
        for steps in itertools.product(*alternatives):
            results.append((''.join(steps), seq[-1]))
    return results

In [269]:
c = as_transitions(('<A', 'A'))

In [274]:
next_level([('<A', 'A')])

[('v<<A>>^A', 'A')]

In [275]:
as_transitions(('v<<A>>^A', 'A'))

Counter({('A', 'v'): 1,
         ('v', '<'): 1,
         ('<', '<'): 1,
         ('<', 'A'): 1,
         ('A', '>'): 1,
         ('>', '>'): 1,
         ('>', '^'): 1,
         ('^', 'A'): 1})

In [290]:
next_level_trans([c])

[Counter({('A', 'v'): 1,
          ('v', '<'): 1,
          ('<', '<'): 1,
          ('<', 'A'): 1,
          ('A', '>'): 1,
          ('>', '>'): 1,
          ('>', '^'): 1,
          ('^', 'A'): 1})]

In [75]:
# ONE
# Arrow sequences between directional-keypad keys, one chosen sequence per
# pair. Rows are the source key and columns the destination key, both in
# `buttons` order ('^<v>A'); None marks pairs that get no entry.
# NOTE(review): this assignment rebinds the module-level name `moves`, which
# is also used for the cached table function -- confirm cell order before
# running both.
moves = [
    ['', 'v<', None, 'v>', '>'],
    ['>^', '', '>', None, '>>^'],
    [None, '<', '', '>', '^>'],
    ['<^', None, '<', '', '^'],
    ['<', 'v<<', '<v', 'v', '']
]

# # ALL
# Alternative table listing several comma-separated options per pair.
# moves = [
#     ['', 'v<', None, 'v>,>v', '>'],
#     ['>^', '', '>', None, '>>^'],
#     [None, '<', '', '>', '>^,^>'],
#     ['<^,^<', None, '<', '', '^'],
#     ['<', 'v<<', 'v<,<v', 'v', '']
# ]

buttons = '^<v>A'

# (src, dst) -> list of candidate press sequences, each ending in 'A'.
fromto = {}
for i, a in enumerate(buttons):
    for j, b in enumerate(buttons):
        seq = moves[i][j]
        if seq is None: continue
        # A cell may hold several comma-separated alternatives.
        options = seq.split(',')
        fromto[a,b] = [option+'A' for option in options]

In [102]:
def next_level_trans(cs):
    """Expand each transition Counter in `cs` by one keypad layer.

    For every transition, the transitions of every sequence listed in
    `fromto` for it are added (scaled by the transition's count) into the
    same output Counter. Note that alternatives are accumulated together,
    not branched into separate candidates.
    """
    def expand(c):
        newc = Counter()
        for (a, b), count in c.items():
            for seq in fromto[a, b]:
                for pair in zip('A' + seq, seq):
                    newc[pair] += count
        return newc

    return [expand(c) for c in cs]


def solve_code(code):
    """Return (minimum press count after 25 expansions, numeric value of `code`).

    Each numpad candidate is converted to transition counts, expanded 25
    times with `next_level_trans`, and the smallest total per key is summed.
    If a key yields no candidates, its contribution is float('inf').
    """
    total = 0
    for src, dst in zip('A' + code, code):
        counters = [as_transitions(candidate) for candidate in numpad_seqs(src, dst)]
        for _ in range(25):
            counters = next_level_trans(counters)
        total += min(
            (sum(c.values()) for c in counters),
            default=float('inf'),
        )
    return total, int(code[:-1])

solve_code('029A')

(82050061710, 29)

In [103]:
def solve_b(data):
    """Sum length * numeric value over all codes, as reported by `solve_code`."""
    return sum(length * value for length, value in map(solve_code, data))

In [104]:
solve_b(sample) == 126384, solve_b(data) == 125742

(False, False)

In [105]:
solve_b(data)

157055032722640

In [100]:
solve_b(data)

125742

In [240]:
solve_a(data)

72 279
68 286
72 508
70 463
70 246


125742

In [185]:
data

['279A', '286A', '508A', '463A', '246A']

In [184]:
solve_a(data)

128918

In [154]:
for code in sample:
    print(code, solve_code(code))

029A (68, 29)
980A (60, 980)
179A (64, 179)
456A (60, 456)
379A (64, 379)


In [85]:
def is_valid(src, moves):
    """Return True if following `moves` from numpad key `src` never lands on 'X'.

    Raises KeyError if `src` is unknown or a step leaves the keypad.
    """
    def visited():
        here = rnump[src]
        for arrow in moves:
            here += dirs[arrow]
            yield here

    return all(nump[pos] != 'X' for pos in visited())

In [86]:
def numpad_seqs(src, dst):
    """List the valid L-shaped press sequences from numpad key `src` to `dst`.

    Both orders (horizontal-first and vertical-first) are tried; an order is
    kept only if `is_valid` accepts it. Each result is a
    (sequence + 'A', 'A') tuple, with duplicates removed through a set, so
    the order of the returned list is not specified.
    """
    src_pos = rnump[src]
    dst_pos = rnump[dst]
    delta = dst_pos - src_pos
    vmoves = ('v' if delta.real > 0 else '^') * int(abs(delta.real))
    hmoves = ('>' if delta.imag > 0 else '<') * int(abs(delta.imag))

    to_return = [
        (path + 'A', 'A')
        for path in (hmoves + vmoves, vmoves + hmoves)
        if is_valid(src, path)
    ]
    return list(set(to_return))
    
    
   
# #     return list(set([hmoves+vmoves+'A', vmoves+hmoves+'A']))
#     return [(moves+'A', 'A')]

In [226]:
numpad_seqs('A', '0')

[('<', 'A')]

In [160]:
rdirp

{'X': 0j, '^': 1j, 'A': 2j, '<': (1+0j), 'v': (1+1j), '>': (1+2j)}

In [114]:
moves

'^^^<<'

In [None]:
seq = 

In [None]:
def solve_a(data):
    # Empty template stub: does nothing and returns None.
    # NOTE(review): if this cell runs, it replaces any earlier `solve_a`.
    pass

In [None]:
solve_a(sample)

In [None]:
solve_a(data)

In [None]:
def solve_b(data):
    # Empty template stub: does nothing and returns None.
    # NOTE(review): if this cell runs, it replaces any earlier `solve_b`.
    pass

In [None]:
solve_b(sample)

In [None]:
solve_b(data)