In [1]:
%load_ext pycodestyle_magic

In [2]:
%flake8_on

In [3]:
from math import log2, ceil

In [4]:
testdata = """mask = XXXXXXXXXXXXXXXXXXXXXXXXXXXXX1XXXX0X
mem[8] = 11
mem[7] = 101
mem[8] = 0""".splitlines()

testdata2 = """mask = 000000000000000000000000000000X1001X
mem[42] = 100
mask = 00000000000000000000000000000000X0XX
mem[26] = 1""".splitlines()

In [5]:
print(2**37 - 1)

137438953471


In [6]:
def count_bits(arg):
    return sum([(arg >> b) & 1
                for b in range(ceil(log2(arg)) + 1)
                ])

In [7]:
count_bits(15)

4

In [8]:
def print_bits(arg):
    return "".join(map(str, [(arg >> b) & 1
                             for b in range(36, -1, -1)]
                       ))

In [9]:
print(print_bits(234567))

0000000000000000000111001010001000111


In [10]:
def print_int(arg):
    return sum([
        (arg[-(b + 1)] == '1') * 2**b
        for b in range(len(arg))
    ])

In [11]:
print_int('111001010001000111')

234567

In [12]:
def parse_mask(mask):
    positive, negative, floating = (0, 2**37 - 1, 0)
    mask = mask[::-1]
    for i in range(len(mask)):
        if mask[i] == '1':
            positive += 2**i
        elif mask[i] == '0':
            negative -= 2**i
        elif mask[i] == 'X':
            floating += 2**i
    return positive, negative, floating

In [13]:
print(parse_mask('XXXXXXXXXXXXXXXXXXXXXXXXXXXXX1XXXX0X'))

(64, 137438953469, 68719476669)


In [14]:
def parse_instruction(line):
    a, b = line.split(' = ')
    print(line)
    if a == 'mask':
        return a, parse_mask(b)
    elif a[0:3] == 'mem':
        return 'mem', int(a[4:-1]), int(b)

In [15]:
print(parse_instruction('mask = XXXXXXXXXXXXXXXXXXXXXXXXXXXXX1XXXX0X'))
print(parse_instruction('mem[8] = 11'))
print(parse_instruction('mem[7] = 101'))

mask = XXXXXXXXXXXXXXXXXXXXXXXXXXXXX1XXXX0X
('mask', (64, 137438953469, 68719476669))
mem[8] = 11
('mem', 8, 11)
mem[7] = 101
('mem', 7, 101)


In [16]:
# This ain't gonna work, generates a ton of addresses
def deduce_addrs(addr):
    addr = addr | MASK_P
    addrs = [addr]
    for b in range(37):
        if (MASK_F >> b) & 1:
            if (addr >> b) & 1:
                addrs += [a - 2**b for a in addrs]
            else:
                addrs += [a + 2**b for a in addrs]
    return addrs

In [17]:
def set_mask(p_n):
    global MASK_P
    global MASK_N
    global MASK_F
    MASK_P, MASK_N, MASK_F = p_n[0]
    print(f'Set masks to {MASK_P}, {MASK_N} and {MASK_F}')


def set_mem(addr_arg):
    addr, arg = (addr_arg[0], addr_arg[1])
    print(f'Set memory {addr} to {(arg | MASK_P ) & MASK_N}'
          + f' (originally {arg}, with masks {MASK_P} and {MASK_N}')
    MEM[addr] = (arg | MASK_P) & MASK_N


def set_mem2(addr_arg):
    addr, arg = (addr_arg[0], addr_arg[1])
    for addr_f in deduce_addrs(addr):
        # print(f'Set memory {print_bits(addr_f)} ({addr_f}) to {arg}'
        #       + f' (originally {addr}, with masks {MASK_P} and {MASK_N}')
        MEM[addr_f] = arg


opcodes = {
    'mask': set_mask,
    'mem': set_mem,
}

opcodes2 = {
    'mask': set_mask,
    'mem': set_mem2,
}

In [18]:
def execute(instructions):
    for instr in [parse_instruction(instruction)
                  for instruction in instructions]:
        # print(f'executing: {instr}')
        opcodes[instr[0]](instr[1:])

In [19]:
def execute2(instructions):
    for instr in [parse_instruction(instruction)
                  for instruction in instructions]:
        # print(f'executing: {instr}')
        opcodes2[instr[0]](instr[1:])

In [20]:
MASK_P, MASK_N = (0, 2**36 - 1)
MEM = dict()

execute(testdata)
print(MEM, MASK_P, MASK_N)
print(f'Sum of memory: {sum([val for val in MEM.values()])}')

mask = XXXXXXXXXXXXXXXXXXXXXXXXXXXXX1XXXX0X
mem[8] = 11
mem[7] = 101
mem[8] = 0
Set masks to 64, 137438953469 and 68719476669
Set memory 8 to 73 (originally 11, with masks 64 and 137438953469
Set memory 7 to 101 (originally 101, with masks 64 and 137438953469
Set memory 8 to 64 (originally 0, with masks 64 and 137438953469
{8: 64, 7: 101} 64 137438953469
Sum of memory: 165


In [21]:
MASK_P, MASK_N, MASK_F = (0, 2**37 - 1, 0)
MEM = dict()

# with open('input', 'r') as inp:
#     inputdata = [line.strip() for line in inp.readlines()]

# execute(inputdata)
print(MEM, MASK_P, MASK_N, MASK_F)
print(f'Sum of memory: {sum([val for val in MEM.values()])}')

{} 0 137438953471 0
Sum of memory: 0


In [22]:
MASK_P, MASK_N, MASK_F = (0, 2**37 - 1, 0)
MEM = dict()
execute2(testdata2)
print(MEM, MASK_P, MASK_N, MASK_F)
print(f'Sum of memory: {sum([val for val in MEM.values()])}')

mask = 000000000000000000000000000000X1001X
mem[42] = 100
mask = 00000000000000000000000000000000X0XX
mem[26] = 1
Set masks to 18, 68719476787 and 33
Set masks to 0, 68719476747 and 11
{58: 100, 59: 100, 26: 1, 27: 1, 24: 1, 25: 1, 18: 1, 19: 1, 16: 1, 17: 1} 0 68719476747 11
Sum of memory: 208


In [23]:
MASK_P, MASK_N, MASK_F = (0, 2**37 - 1, 0)
MEM = dict()

# with open('input', 'r') as inp:
#    inputdata = [line.strip() for line in inp.readlines()]

# execute2(inputdata)
# print(f'Sum of memory: {sum([val for val in MEM.values()])}')