In [1]:
def registry(registry_key):
    """Return a metaclass that enables fetching subclasses by an arbitrary key.

    Every class built with the returned metaclass (the base class and all of
    its subclasses) shares one ``registry`` dict, created on the first class
    in the hierarchy that does not already have one. A class is added to that
    dict under the value of its ``registry_key`` attribute, but only if the
    class body defines that attribute itself (inherited values are ignored)
    and the value is not ``None``.

    Args:
        registry_key (str): class attribute to use for indexing/retrieving subclasses

    Returns:
        type: metaclass to be used in the base class

    Raises:
        ValueError: at class-creation time, if two classes declare the same key.

    Example usage:
        >>> class Base(metaclass=registry('platform')):
        ...     platform = None
        >>> class Youtube(Base):
        ...     platform = 'yt'
        >>> class Facebook(Base):
        ...     platform = 'fb'
        >>> Base.registry['yt'] is Youtube
        True
        >>> Base.registry['fb'] is Facebook
        True
    """

    class Registry(type):
        def __init__(cls, name, bases, classdict):
            super().__init__(name, bases, classdict)
            # Subclasses see the base's dict through ordinary attribute
            # inheritance, so only the first class in the hierarchy creates it.
            if not hasattr(cls, 'registry'):
                cls.registry = {}
            # Only look at this class body, so a subclass that merely inherits
            # its parent's key is not registered a second time.
            key = classdict.get(registry_key)
            if key is None:
                return
            if key in cls.registry:
                # ValueError subclasses Exception, so existing handlers still match.
                raise ValueError('Duplicate registry key found: %s' % key)
            cls.registry[key] = cls

    return Registry


In [50]:
# Load the raw program text; text mode ('r') is open()'s default.
with open('input.txt') as f:
    program = f.read()


In [82]:
from collections import defaultdict, deque


class Op(metaclass=registry('opcode')):
    """Base class for Intcode instructions.

    Subclasses set ``opcode`` (the key the metaclass registers them under in
    ``Op.registry``) and ``arg_count`` (how many parameter words follow the
    opcode word), and implement ``_run``.

    The relative base is kept in memory at address ``-1`` (it is added to
    mode-2 parameters), so ``mem`` must accept the key ``-1``.
    """
    opcode = None
    arg_count = None

    def __init__(self, mem, input_func, output_func = None):
        """Bind the instruction to shared memory and I/O callables.

        Args:
            mem: mutable mapping from address to value, shared with the VM.
            input_func: zero-argument callable returning the next input value.
            output_func: callable receiving each output value (may be set later).
        """
        self.mem = mem
        self.read_input = input_func
        self.write_output = output_func

    def run(self, ip, arg_modes):
        """Execute the instruction located at ``ip``.

        Args:
            ip: address of the opcode word.
            arg_modes: the opcode word's digits above the two-digit opcode.

        Returns:
            The next instruction pointer: whatever ``_run`` returned, or, when
            it returned ``None``, the address just past this instruction.
        """
        self.arg_modes = str(arg_modes)
        self.ip = ip
        new_ip = self._run()
        if new_ip is None:
            new_ip = ip + self.arg_count + 1
        return new_ip

    def val(self, arg_num, dereference=True):
        """Resolve parameter ``arg_num`` (1-based) of the current instruction.

        The parameter's mode is the ``arg_num``-th digit of ``arg_modes``
        counted from the right; a missing digit means mode 0.

        * ``'0'`` position: the parameter word is an address.
        * ``'1'`` immediate: the parameter word itself is returned, regardless
          of ``dereference``.
        * ``'2'`` relative: the parameter word is an offset from the relative
          base stored at ``mem[-1]``.

        With ``dereference=False`` the resolved address is returned instead of
        the value stored there (this is how ``store`` finds its target).

        Raises:
            ValueError: if the mode digit is not 0, 1 or 2.
        """
        raw = self.mem[self.ip + arg_num]
        try:
            mode = self.arg_modes[-arg_num]
        except IndexError:
            mode = '0'
        if mode == '1':
            return raw
        if mode == '2':
            raw += self.mem[-1]
        elif mode != '0':
            # Previously any other digit silently fell through as position mode.
            raise ValueError(
                'unknown parameter mode {!r} at ip {}'.format(mode, self.ip))
        if dereference:
            return self.mem[raw]
        return raw

    def store(self, value, arg_num=3):
        """Write ``value`` to the address named by parameter ``arg_num``.

        The default of 3 matches the three-parameter instructions, whose last
        parameter is the destination.
        """
        self.mem[self.val(arg_num, dereference=False)] = value

    def _run(self):
        """Instruction body; return a new ip to jump, or ``None`` to fall through."""
        pass

class Add(Op):
    """Opcode 1: store parameter 1 + parameter 2 at the address in parameter 3."""
    opcode = 1
    arg_count = 3

    def _run(self):
        lhs, rhs = self.val(1), self.val(2)
        self.store(lhs + rhs)

class Mult(Op):
    """Opcode 2: store parameter 1 * parameter 2 at the address in parameter 3."""
    opcode = 2
    arg_count = 3

    def _run(self):
        left = self.val(1)
        right = self.val(2)
        self.store(left * right)

class Input(Op):
    """Opcode 3: read one input value and store it at the address in parameter 1."""
    opcode = 3
    arg_count = 1

    def _run(self):
        value = self.read_input()
        self.store(value, arg_num=1)
        
class Output(Op):
    """Opcode 4: pass the value of parameter 1 to the output callable."""
    opcode = 4
    arg_count = 1

    def _run(self):
        emit = self.write_output
        emit(self.val(1))

class JumpIfTrue(Op):
    """Opcode 5: jump to parameter 2 when parameter 1 is non-zero."""
    opcode = 5
    arg_count = 2

    def _run(self):
        # None means "no jump": Op.run then advances past the arguments.
        return self.val(2) if self.val(1) else None

class JumpIfFalse(Op):
    """Opcode 6: jump to parameter 2 when parameter 1 is zero."""
    opcode = 6
    arg_count = 2

    def _run(self):
        if self.val(1):
            # Condition not met: fall through to the next instruction.
            return None
        return self.val(2)

class LessThan(Op):
    """Opcode 7: store 1 if parameter 1 < parameter 2, else 0, at parameter 3."""
    opcode = 7
    arg_count = 3

    def _run(self):
        is_less = self.val(1) < self.val(2)
        self.store(1 if is_less else 0)

class Equals(Op):
    """Opcode 8: store 1 if parameter 1 == parameter 2, else 0, at parameter 3."""
    opcode = 8
    arg_count = 3

    def _run(self):
        same = self.val(1) == self.val(2)
        self.store(1 if same else 0)

class Relative(Op):
    """Opcode 9: add parameter 1 to the relative base kept at mem[-1]."""
    opcode = 9
    arg_count = 1

    def _run(self):
        base = self.mem[-1]
        self.mem[-1] = base + self.val(1)

class HaltOp(Op):
    """Opcode 99: stop the program.

    Signals termination by raising the module-level ``Halt`` exception.
    The instruction class is named ``HaltOp`` so that it does not collide
    with that exception class. Reusing the name ``Halt`` only worked because
    the exception was defined later and shadowed this class.
    """
    opcode = 99
    arg_count = 0

    def _run(self):
        raise Halt()
        
class NoInput(Exception):
    """Raised when an input instruction runs but no queued input is available."""


class Halt(Exception):
    """Raised by the halt instruction (opcode 99) to end program execution."""

class Amp:
    """Intcode machine that runs one instance of a program.

    Input values are consumed from a FIFO queue. When the program asks for
    input and the queue is empty, execution pauses (state ``'paused'``) and
    resumes automatically on the next ``add_input``. States set here are
    ``'init'``, ``'running'``, ``'paused'`` and ``'done'``.
    """
    next_id = 0  # class-wide counter used to build default instance names

    @classmethod
    def generate_id(cls):
        """Return a fresh default name of the form ``amp_<n>``."""
        cls.next_id += 1
        return 'amp_{}'.format(cls.next_id)

    def __init__(self, program, inputs, output_callback=None, name=None, debug=False):
        """Load ``program`` and prepare the machine without running it.

        Args:
            program (str): comma-separated integers; surrounding whitespace is stripped.
            inputs: iterable of initial input values, consumed in order.
            output_callback: optional callable invoked with each output value.
            name: label used in debug output; generated when not given.
            debug (bool): print a trace of execution when true.
        """
        # Unwritten addresses read as 0 because memory is a defaultdict(int).
        self.mem = defaultdict(int, ((i, int(val)) for i, val in enumerate(program.strip().split(','))))
        self.ip = 0
        self.inputs = deque(inputs)
        self.last_output = None
        self.name = name or self.generate_id()
        self._debug = debug
        # Dispatch table indexed by opcode; opcodes are the low two digits, 0..99.
        self.ops = [None] * 100
        for opcode, klass in Op.registry.items():
            self.ops[opcode] = klass(self.mem, self._read_input)
        self.set_output_callback(output_callback)
        self.set_state('init')

    def set_output_callback(self, func):
        """Install ``func`` (may be ``None``) as the output sink for every op.

        Each output is first recorded in ``last_output`` and traced, then
        forwarded to ``func`` if one was given.
        """
        def cb(val):
            self.last_output = val
            self.debug(val)
            if func:
                func(val)
        self._write_output = cb
        for op in self.ops:
            if op is not None:
                op.write_output = self._write_output

    def _read_input(self):
        """Pop the next queued input, raising ``NoInput`` when the queue is empty."""
        if not self.inputs:
            raise NoInput()
        output = self.inputs.popleft()
        self.debug("read input {}".format(output))
        return output

    def add_input(self, val):
        """Queue ``val``; if the machine is paused waiting for input, resume it."""
        self.inputs.append(val)
        if self.state == 'paused':
            self.run()

    def run(self):
        """Execute until the program halts or blocks on input.

        Returns:
            The last output value when the program halts, or ``None`` when
            it pauses waiting for input.

        Raises:
            ValueError: if the instruction pointer reaches an unknown opcode.
        """
        self.set_state('running')
        while True:
            arg_modes, opcode = divmod(self.mem[self.ip], 100)
            op = self.ops[opcode]
            if op is None:
                # Without this check the failure is an opaque AttributeError on None.
                raise ValueError('{}: unknown opcode {} at ip {}'.format(
                    self.name, opcode, self.ip))
            self.debug(type(op).__name__)
            self.debug((arg_modes, opcode))
            self.debug(self.ip)
            try:
                self.ip = op.run(self.ip, arg_modes)
            except NoInput:
                # ip was not advanced, so the input instruction re-runs on resume.
                self.set_state('paused')
                return None
            except Halt:
                self.set_state('done')
                return self.last_output

    def set_state(self, state):
        """Record the machine state and trace the transition."""
        self.state = state
        self.debug(state)

    def debug(self, msg):
        """Print ``msg`` prefixed with this machine's name when debugging is on."""
        if self._debug:
            print('{}: {}'.format(self.name, msg))


In [84]:
# Run the program with the single input value 1, printing each output as it is produced.
a = Amp(program, inputs=[1], output_callback=print)
a.run()

3063082071


3063082071

In [83]:
# Run the program with the single input value 2, printing each output as it is produced.
a = Amp(program, inputs=[2], output_callback=print)
a.run()

81348


81348