In [47]:
import re, copy, numpy as np
from collections import defaultdict

In [48]:
class Device(object):
    def __init__(self) -> None:
        self.registers = [0, 0, 0, 0]

    def exec_instruction(self, opcode, A, B, C):
        match opcode:
            case 'addr':
                self.registers[C] = self.registers[A] + self.registers[B]
            case 'addi':
                self.registers[C] = self.registers[A] + B
            case 'mulr':
                self.registers[C] = self.registers[A] * self.registers[B]
            case 'muli':
                self.registers[C] = self.registers[A] * B
            case 'banr':
                self.registers[C] = self.registers[A] & self.registers[B]
            case 'bani':
                self.registers[C] = self.registers[A] & B
            case 'borr':
                self.registers[C] = self.registers[A] | self.registers[B]
            case 'bori':
                self.registers[C] = self.registers[A] | B
            case 'setr':
                self.registers[C] = self.registers[A]
            case 'seti':
                self.registers[C] = A
                
            case 'gtir':
                self.registers[C] = 1 if A > self.registers[B] else 0
            case 'gtri':
                self.registers[C] = 1 if self.registers[A] > B else 0
            case 'gtrr':
                self.registers[C] = 1 if self.registers[A] > self.registers[B] else 0
            
            case 'eqir':
                self.registers[C] = 1 if A == self.registers[B] else 0
            case 'eqri':
                self.registers[C] = 1 if self.registers[A] == B else 0
            case 'eqrr':
                self.registers[C] = 1 if self.registers[A] == self.registers[B] else 0

In [75]:
opcodes = ['addr', 'addi', 'mulr', 'muli', 'banr', 'bani', 'borr', 'bori', 'setr', 'seti', \
            'gtir', 'gtri', 'gtrr', 'eqir', 'eqri', 'eqrr']

sample = """
Before: [3, 2, 1, 1]
9 2 1 2
After:  [3, 2, 2, 1]
""".strip().splitlines()

samples = open("16.txt").read().strip()
samples = samples.split("\n\n")
testprogram = samples.pop(-1).strip()
_ = samples.pop(-1)

In [50]:
# Part 1
result = 0
for sample in samples:
    sample = sample.splitlines()
    before = tuple([int(x) for x in re.findall(r'\d+', sample[0])])
    instruction = [int(x) for x in re.findall(r'\d+', sample[1])]
    after = tuple([int(x) for x in re.findall(r'\d+', sample[2])])

    dev = Device()
    matching = set()
    for opcode in opcodes:
        dev.registers = list(before)
        #print(before, dev.registers)
        dev.exec_instruction(opcode, *instruction[1:])
        #print(opcode, dev.registers)
        if dev.registers == list(after):
            matching.add(opcode)
    if len(matching) >= 3:
        result += 1
print(result)

605


In [78]:
## Part 2

# Translation
matching = defaultdict(set)
for sample in samples:
    sample = sample.splitlines()
    before = tuple([int(x) for x in re.findall(r'\d+', sample[0])])
    instruction = [int(x) for x in re.findall(r'\d+', sample[1])]
    after = tuple([int(x) for x in re.findall(r'\d+', sample[2])])

    dev = Device()
    for opcode in opcodes:
        dev.registers = list(before)
        #print(before, dev.registers)
        dev.exec_instruction(opcode, *instruction[1:])
        #print(opcode, dev.registers)
        if dev.registers == list(after):
            matching[instruction[0]].add(opcode)

matched_opcodes = set()
while len(matched_opcodes) < len(opcodes):
    for key in sorted(matching, key = lambda x: len(matching[x])):
        if len(matching[key]) > 1:
            matching[key] -= matched_opcodes
        if len(matching[key]) == 1:
            matched_opcodes |= matching[key]
for m in matching:
    matching[m] = matching[m].pop()

# Test program
dev = Device()
for instruction in testprogram.splitlines():
    instruction = [int(x) for x in instruction.split()]
    dev.exec_instruction(matching[instruction[0]], *instruction[1:])
dev.registers

[653, 1, 4, 653]

In [74]:
matching

defaultdict(set,
            {0: 'eqri',
             6: 'seti',
             11: 'addr',
             9: 'bori',
             5: 'addi',
             7: 'gtir',
             8: 'muli',
             3: 'gtrr',
             4: 'banr',
             13: 'borr',
             2: 'gtri',
             12: 'bani',
             14: 'eqir',
             1: 'mulr',
             10: 'setr',
             15: 'eqrr'})