In [13]:
from enum import Enum
import re

In [4]:
reg = [0, 0, 0, 0]

In [45]:
class OpCodes:
    # add
    @staticmethod
    def addr(a, b, c, reg):
        reg[c]=reg[a]+reg[b]
    
    @staticmethod
    def addi(a, b, c, reg):
        reg[c]=reg[a]+b
        
    # multiply
    @staticmethod
    def mulr(a, b, c, reg):
        reg[c]=reg[a]*reg[b]
    
    @staticmethod
    def muli(a, b, c, reg):
        reg[c]=reg[a]*b
        
    # bitwise and
    @staticmethod
    def banr(a, b, c, reg):
        reg[c]=reg[a]&reg[b]
    
    @staticmethod
    def bani(a, b, c, reg):
        reg[c]=reg[a]&b
        
    # bitwise or
    @staticmethod
    def borr(a, b, c, reg):
        reg[c]=reg[a]|reg[b]
    
    @staticmethod
    def bori(a, b, c, reg):
        reg[c]=reg[a]|b
        
    # assign
    @staticmethod
    def setr(a, b, c, reg):
        reg[c]=reg[a]
    
    @staticmethod
    def seti(a, b, c, reg):
        reg[c]=a
        
    # greater than
    @staticmethod
    def gtir(a, b, c, reg):
        if a > reg[b]:
            reg[c]=1  
        else: 
            reg[c]=0
    
    @staticmethod
    def gtri(a, b, c, reg):
        if reg[a] > b: 
            reg[c]=1 
        else: 
            reg[c]=0
    
    @staticmethod
    def gtrr(a, b, c, reg):
        if reg[a] > reg[b]:
            reg[c]=1 
        else:
            reg[c]=0
        
    # equal
    @staticmethod
    def eqir(a, b, c, reg):
        if a == reg[b]:
            reg[c]=1 
        else: 
            reg[c]=0
    
    @staticmethod
    def eqri(a, b, c, reg):
        if reg[a] == b:
            reg[c]=1 
        else:
            reg[c]=0
    
    @staticmethod
    def eqrr(a, b, c, reg):
        if reg[a] == reg[b]:
            reg[c]=1 
        else:
            reg[c]=0

In [88]:
op_codes = [
    OpCodes.addr,
    OpCodes.addi,
    OpCodes.mulr,
    OpCodes.muli,
    OpCodes.banr,
    OpCodes.bani,
    OpCodes.borr,
    OpCodes.bori,
    OpCodes.setr,
    OpCodes.seti,
    OpCodes.gtir,
    OpCodes.gtri,
    OpCodes.gtrr,
    OpCodes.eqir,
    OpCodes.eqri,
    OpCodes.eqrr,
]

In [31]:
# instruction opcode, A, B, C (A, B, C = index of registers)
BEFORE_PATTERN = re.compile(r'Before: \[(.+)\]')
AFTER_PATTERN  = re.compile(r'After:  \[(.+)\]')

In [47]:
def get_transitions(filepath):
    with open(filepath) as file:
        transitions = []
        lines = list(file.readlines())
        for i in range(0, len(lines), 4):
            before_match = BEFORE_PATTERN.findall(lines[i])
            if not before_match:
                break
            before = [int(n) for n in before_match[0].split(', ')]
            instruction = [int(n) for n in lines[i+1].split(' ')]
            after = [int(n) for n in AFTER_PATTERN.findall(lines[i+2])[0].split(', ')]
            transitions.append((before, instruction, after))
    return transitions

In [89]:
opcode_transitions_filepath = 'data/day16_opcodes.txt'
program_filepath = 'data/day16_program.txt'
transitions = get_transitions(opcode_transitions_filepath)

In [90]:
total_greater = 0
for before, op_code, after in transitions:
    total_op_codes = 0
    for op in op_codes:
        before_copy = [i for i in before]
        op(*op_code[1:], before_copy)
        if before_copy == after:
            total_op_codes += 1
    if total_op_codes >= 3:
        total_greater += 1

In [91]:
print(total_greater)

651


In [92]:
op_code_numbers = {}
for before, op_code, after in transitions:
    op_number = op_code[0]
    op_code_numbers.setdefault(op_number, [i for i in op_codes])
    remove_ops = []
    for op in op_code_numbers[op_number]:
        if op not in op_code_numbers[op_number]:
            continue
        before_copy = [i for i in before]
        op(*op_code[1:], before_copy)
        if before_copy != after:
            remove_ops.append(op)
    for op in remove_ops:
        op_code_numbers[op_number].remove(op)

In [93]:
total_op_codes = 16
assigned_op_codes = []
assigned_op_code_numbers = {}
while len(assigned_op_codes) < total_op_codes-1:
    for op_number, op_codes in op_code_numbers.items():
        for assigned_op_code in assigned_op_codes:
            if assigned_op_code in op_codes:
                op_codes.remove(assigned_op_code)
        # elimination successful
        if len(op_codes) == 1:
            assigned_op_codes.append(op_codes[0])
            assigned_op_code_numbers[op_number] = op_codes[0]

In [94]:
assigned_op_code_numbers

{0: <function __main__.OpCodes.bori(a, b, c, reg)>,
 1: <function __main__.OpCodes.muli(a, b, c, reg)>,
 2: <function __main__.OpCodes.banr(a, b, c, reg)>,
 3: <function __main__.OpCodes.bani(a, b, c, reg)>,
 4: <function __main__.OpCodes.gtir(a, b, c, reg)>,
 5: <function __main__.OpCodes.setr(a, b, c, reg)>,
 6: <function __main__.OpCodes.addr(a, b, c, reg)>,
 7: <function __main__.OpCodes.eqir(a, b, c, reg)>,
 8: <function __main__.OpCodes.seti(a, b, c, reg)>,
 9: <function __main__.OpCodes.addi(a, b, c, reg)>,
 10: <function __main__.OpCodes.eqrr(a, b, c, reg)>,
 11: <function __main__.OpCodes.eqri(a, b, c, reg)>,
 12: <function __main__.OpCodes.borr(a, b, c, reg)>,
 13: <function __main__.OpCodes.gtrr(a, b, c, reg)>,
 14: <function __main__.OpCodes.mulr(a, b, c, reg)>,
 15: <function __main__.OpCodes.gtri(a, b, c, reg)>}

In [95]:
def get_program(filepath):
    with open(filepath) as file:
        for line in file.readlines():
            yield [int(n) for n in line.split(' ')]

In [99]:
program = list(get_program(program_filepath))

In [101]:
registers = [0, 0, 0, 0]
for op_codes in program:
    op_number = op_codes[0]
    assigned_op_code_numbers[op_number](*op_codes[1:], registers)

In [102]:
print(registers)

[706, 0, 4, 706]
