In [None]:
def registry(registry_key):
    """Return a metaclass that enables fetching subclasses by an arbitrary key.

    Args:
        registry_key (str): class attribute to use for indexing/retrieving subclasses
    Returns:
        type: metaclass to be used in the base class

    Example usage:
        >>> class Base(metaclass=registry('platform')):
        ...    platform = None
        >>>
        >>> class Youtube(Base):
        ...    platform = 'yt'
        >>>
        >>> class Facebook(Base):
        ...   platform = 'fb'
        >>>

        >>> Base.registry['yt']
        <class 'swissarmy.class_registry.Youtube'>

        >>> Base.registry['fb']
        <class 'swissarmy.class_registry.Facebook'>
    """

    class Registry(type):
        def __init__(cls, name, bases, classdict):
            super(Registry, cls).__init__(name, bases, classdict)
            if not hasattr(cls, 'registry'):
                cls.registry = {}
            key = classdict.get(registry_key, None)
            if key is not None:
                if key in cls.registry:
                    raise Exception('Duplicate registry key found: %s' % key)
                cls.registry[key] = cls

    return Registry


In [34]:
from collections import deque


class Op(metaclass=registry('opcode')):
    opcode = None
    arg_count = None

    def __init__(self, mem, input_func, output_func = None):
        self.mem = mem
        self.read_input = input_func
        self.write_output = output_func

    def run(self, ip, arg_modes):
        self.arg_modes = '0000' + str(arg_modes)
        self.ip = ip
        new_ip = self._run()
        if new_ip is None:
            new_ip = ip + self.arg_count + 1
        return new_ip

    def val(self, arg_num):
        val = self.arg(arg_num)
        if self.arg_modes[-arg_num] == '0':
            return self.mem[val]
        else:
            return val

    def arg(self, arg_num):
        return self.mem[self.ip + arg_num]
    
    def store(self, value, arg_num=3):
        self.mem[self.arg(arg_num)] = value
    
    def _run(self):
        pass

class Add(Op):
    opcode = 1
    arg_count = 3
    
    def _run(self):
        self.store(self.val(1) + self.val(2))

class Mult(Op):
    opcode = 2
    arg_count = 3
    
    def _run(self):
        self.store(self.val(1) * self.val(2))

class Halt(Op):
    opcode = 99
    arg_count = 0

    def _run(self):
        raise Halt()

class Input(Op):
    opcode = 3
    arg_count = 1
    
    def _run(self):
        self.store(self.read_input(), 1)

class Output(Op):
    opcode = 4
    arg_count = 1

    def _run(self):
        self.write_output(self.val(1))

class JumpIfTrue(Op):
    opcode = 5
    arg_count = 2

    def _run(self):
        if self.val(1):
            return self.val(2)

class JumpIfFalse(Op):
    opcode = 6
    arg_count = 2
    
    def _run(self):
        if not self.val(1):
            return self.val(2)

class LessThan(Op):
    opcode = 7
    arg_count = 3

    def _run(self):
        self.store(1 if self.val(1) < self.val(2) else 0)

class Equals(Op):
    opcode = 8
    arg_count = 3
    
    def _run(self):
        self.store(1 if self.val(1) == self.val(2) else 0)

class NoInput(Exception):
    pass


class Halt(Exception):
    pass

class Amp:
    next_id = 0
    
    @classmethod
    def generate_id(cls):
        cls.next_id += 1
        return 'amp_{}'.format(cls.next_id)
    
    def __init__(self, program, inputs, output_callback=None, name=None, debug=False):
        self.mem = [int(o) for o in program.strip().split(',')]
        self.ip = 0
        self.inputs = deque(inputs)
        self.last_output = None
        self.name = name or self.generate_id()
        self._debug = debug
        self.ops = [None] * 100
        for opcode, klass in Op.registry.items():
            self.ops[opcode] = klass(self.mem, self._read_input)
        self.set_output_callback(output_callback)
        self.set_state('init')
        
    def set_output_callback(self, func):
        def cb(val):
            self.last_output = val
            self.debug(val)
            if func:
                func(val)
        self._write_output = cb
        for op in self.ops:
            if op is not None:
                op.write_output = self._write_output
        
    def _read_input(self):
        if not self.inputs:
            raise NoInput()
        output = self.inputs.popleft()
        self.debug("read input {}".format(output))
        return output

    def add_input(self, val):
        self.inputs.append(val)
        if self.state == 'paused':
            self.run()

    def run(self):
        self.set_state('running')
        while True:
            arg_modes, opcode = divmod(self.mem[self.ip], 100)
            try:
#                 self.debug(self.ops[opcode].__class__)
#                 self.debug(self.ip)
                self.ip = self.ops[opcode].run(self.ip, arg_modes)
            except NoInput:
                self.set_state('paused')
                return
            except Halt:
                self.set_state('done')
                return self.last_output
    
    def set_state(self, state):
        self.state = state
        self.debug(state)
    
    def debug(self, msg):
        if self._debug:
            print('{}: {}'.format(self.name, msg))


In [35]:
with open('input.txt', 'r') as f:
    program = f.read()

from itertools import permutations

max_output = -1000000000
for phases in permutations(range(5)):
    output = 0
    for phase_setting in phases:
        output = Amp(program, [phase_setting, output], debug=False).run()
    max_output = max(max_output, output)

max_output
    

118936

In [36]:
def part2(prog, perms, debug=False):
    max_output = -1000000000
    max_phases = None
    for phases in perms:
        amps = [Amp(prog, [phase_setting], debug=debug) for phase_setting in phases]
        for amp0, amp1 in zip(amps[-1:] + amps[:-1], amps):
            amp0.set_output_callback(amp1.add_input)
        amps[0].add_input(0)
        while amps[-1].state != 'done':
            for amp in amps:
                amp.run()
        if amps[-1].last_output > max_output:
            max_phases = phases
            max_output = amps[-1].last_output

    return max_output, max_phases


In [37]:
test = '3,26,1001,26,-4,26,3,27,1002,27,2,27,1,27,26,27,4,27,1001,28,-1,28,1005,28,6,99,0,0,5'
part2(test, [(9,8,7,6,5)], debug=False)


(139629729, (9, 8, 7, 6, 5))

In [38]:
part2(test, permutations(range(5,10)))

(139629729, (9, 8, 7, 6, 5))

In [39]:
test2 = '3,52,1001,52,-5,52,3,53,1,52,56,54,1007,54,5,55,1005,55,26,1001,54,-5,54,1105,1,12,1,53,54,53,1008,54,0,55,1001,55,1,55,2,53,55,53,4,53,1001,56,-1,56,1005,56,6,99,0,0,0,0,10'
part2(test2, permutations(range(5,10)))

(18216, (9, 7, 8, 5, 6))

In [33]:
part2(program, permutations(range(5,10)))

(57660948, (9, 7, 6, 5, 8))

In [41]:
test2 = '3,52,1001,52,-5,52,3,53,1,52,56,54,1007,54,5,55,1005,55,26,1001,54,-5,54,1105,1,12,1,53,54,53,1008,54,0,55,1001,55,1,55,2,53,55,53,4,53,1001,56,-1,56,1005,56,6,99,0,0,0,0,10'
import profile

profile.run('part2(test2, permutations(range(5,10)))')

         833645 function calls (747245 primitive calls) in 4.265 seconds

   Ordered by: standard name

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
     6120    0.012    0.000    0.012    0.000 :0(append)
    98520    0.218    0.000    0.218    0.000 :0(divmod)
        1    0.000    0.000    4.265    4.265 :0(exec)
     7200    0.019    0.000    0.019    0.000 :0(format)
      600    0.002    0.000    0.002    0.000 :0(items)
     6600    0.014    0.000    0.014    0.000 :0(popleft)
        1    0.000    0.000    0.000    0.000 :0(setprofile)
      600    0.004    0.000    0.004    0.000 :0(split)
      600    0.002    0.000    0.002    0.000 :0(strip)
      600    0.003    0.000    0.005    0.000 <ipython-input-34-d34973b197ab>:112(generate_id)
      600    0.038    0.000    0.089    0.000 <ipython-input-34-d34973b197ab>:117(__init__)
      600    0.012    0.000    0.012    0.000 <ipython-input-34-d34973b197ab>:118(<listcomp>)
98520/29400    0.585    0.000 