In [29]:
import re
pattern = r":(?P<A>.*)\n.*:(?P<B>.*)\n.*:(?P<C>.*)\n\n.*:(?P<Program>.*)"

In [11]:
with open('day17-p1.txt') as f:
    lines = f.read()
    for matches in re.finditer(pattern, lines):
        A = int(matches.group('A'))
        B = int(matches.group('B'))
        C = int(matches.group('C'))
        program = [int(x) for x in matches.group('Program').split(',')]

A, B, C, program

(25986278, 0, 0, [2, 4, 1, 4, 7, 5, 4, 1, 1, 4, 5, 5, 0, 3, 3, 0])

# Part 1

In [34]:
def run(program, A, B, C):    
    ins_ptr = 0
    output = []

    def combo_value(num):
        match num:
            case num if num < 4:
                return num
            case 4:
                return A
            case 5:
                return B
            case 6:
                return C
            case _:
                raise ValueError(f"Invalid value {num}")
            
    while ins_ptr < len(program):
        opcode = program[ins_ptr]
        literal = program[ins_ptr + 1]
    
        match opcode:
            case 0:
                A = A // (2 ** combo_value(literal))
            
            case 1:
                B = B ^ literal
            
            case 2:
                B = combo_value(literal) % 8
            
            case 3:
                if A != 0:
                    ins_ptr = literal
                    continue
            
            case 4:
                B = B ^ C
            
            case 5:
                output.append(combo_value(literal) % 8)
            
            case 6:
                B = A // (2 ** combo_value(literal))
            
            case 7:
                C = A // (2 ** combo_value(literal))    
        
        ins_ptr += 2

    return output

In [13]:
output = run(program, A, B, C)
print(','.join(map(str, output)))

7,0,7,3,4,1,3,0,1


# Part 2

In [42]:
from collections import defaultdict

output_maps = defaultdict(list)
for i in range(1024):
    output = run(program[:-2], i, 0, 0)
    output_maps[output[0]].append(i)

In [46]:
B_C_pairs = defaultdict(set)

for key in output_maps:
    for num in output_maps[key]:
        B = num % 8
        C = (num // (2 ** (B ^ 4))) % 8
        B_C_pairs[key].add((B, C))

In [48]:
bit_strings = {
    0: '000',
    1: '001',
    2: '010',
    3: '011',
    4: '100',
    5: '101',
    6: '110',
    7: '111'
}

In [59]:
def backtracking(ind, curr):
    num = program[ind]
    res = []
    for B, C in B_C_pairs[num]:
        bitB, bitC = bit_strings[B], bit_strings[C]
        currcpy = curr.copy()
        skip = False
        
        sind = -(ind + 1) * 3
        for i in range(3):
            if currcpy[sind + i] == '.' or currcpy[sind + i] == bitB[i]:
                currcpy[sind + i] = bitB[i]
            else:
                skip = True
        
        sind = -(ind + 1) * 3 - (B ^ 4)
        for i in range(3):
            if -(sind + i) <= len(currcpy) and (currcpy[sind + i] == '.' or currcpy[sind + i] == bitC[i]):
                currcpy[sind + i] = bitC[i]
            else:
                if -(sind + i) > len(currcpy) and C == 0:
                    break    
                skip = True
        
        if skip:
            del currcpy
            continue
        
        if ind != len(program) - 1:
            res.extend(backtracking(ind + 1, currcpy))
        else:
            res.append(currcpy)
    return res
        

In [18]:
res = backtracking(0, ['.' for _ in range(48)])
res = [int(''.join(x), 2) for x in res]
min(res)

16247842866690

In [26]:
run(program, 156985331222018, 0, 0)

[2, 4, 1, 4, 7, 5, 4, 1, 1, 4, 5, 5, 0, 3, 3, 0]