# Day 14
https://adventofcode.com/2020/day/14

In [1]:
import aocd
# Fetch the puzzle input for Advent of Code 2020, day 14, through the aocd package.
data = aocd.get_data(year=2020, day=14)

In [16]:
import numpy as np
import re

##### Part 1: Masking the written value

In [14]:
def written_value(value, mask):
    """Apply the Part 1 bitmask to ``value`` and return the resulting integer.

    ``value`` is rendered with ``np.binary_repr`` at width 36. Position by
    position, an ``'X'`` in ``mask`` keeps the value's own bit; any other
    mask character replaces that bit.

    Args:
        value: Integer about to be written to memory.
        mask: Mask string, indexed in step with the binary form of ``value``.

    Returns:
        The masked bit string parsed back as a base-2 integer.
    """
    value_bits = np.binary_repr(value, 36)
    masked_bits = []
    for position, bit in enumerate(value_bits):
        override = mask[position]
        # 'X' is the pass-through symbol; everything else wins over the bit.
        masked_bits.append(bit if override == 'X' else override)
    return int(''.join(masked_bits), 2)

In [56]:
# Patterns for the two instruction forms: a mask assignment and a memory write.
re_changemask = re.compile(r'mask = ([X10]+)')
re_write = re.compile(r'mem\[(\d+)\] = (\d+)')

def memory_total_after_program(text):
    """Execute the program with Part 1 (value-masking) rules and sum memory.

    The text is split on newlines. A line containing a mask assignment
    replaces the current mask; a line containing a memory write stores
    ``written_value(value, mask)`` at the given address, overwriting any
    earlier value there.

    Args:
        text: Program source, one instruction per line.

    Returns:
        The sum of every value left in memory after the last line.
    """
    cells = {}
    current_mask = 'XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX'
    for instruction in text.split('\n'):
        mask_hit = re_changemask.search(instruction)
        if mask_hit is not None:
            current_mask = mask_hit.group(1)

        # Checked independently of the mask pattern, not as an alternative.
        write_hit = re_write.search(instruction)
        if write_hit is None:
            continue
        target, raw = (int(field) for field in write_hit.groups())
        cells[target] = written_value(raw, current_mask)

    return sum(cells.values())

In [58]:
# Run the Part 1 interpreter over the puzzle input and report the memory total.
p1 = memory_total_after_program(data)
print('Part 1: {}'.format(p1))

Part 1: 7440382076205


##### Part 2: Masking the written address
0 : memory address bit unchanged

1 : memory address bit overwritten with a 1

X : memory address bit is floating - the value is written to every address formed by setting each floating bit to both 0 and 1

In [83]:
def all_memory_addresses(original, mask):
    """Lazily yield every address string that ``mask`` derives from ``original``.

    The two strings are walked in step from the left. For each position the
    mask character selects the candidate bits:

    * ``'0'`` - the bit from ``original`` is kept,
    * ``'1'`` - the bit is forced to ``'1'``,
    * ``'X'`` - both ``'0'`` and ``'1'`` are produced.

    Any other mask character contributes no candidates, so nothing is yielded.

    Args:
        original: Address as a string of bits.
        mask: Mask string aligned with ``original`` from index 0.

    Yields:
        Address strings of the same length as ``original``. For each
        completed tail, all candidates for the leading bit are emitted
        before moving on to the next tail.
    """
    if not original:
        yield ''
        return

    rule = mask[0]
    if rule == '0':
        leading_bits = (original[0],)
    elif rule == 'X':
        leading_bits = ('0', '1')
    elif rule == '1':
        leading_bits = ('1',)
    else:
        leading_bits = ()

    for tail in all_memory_addresses(original[1:], mask[1:]):
        for bit in leading_bits:
            yield bit + tail

In [89]:
def memory_total_after_part2_program(text):
    """Execute the program with Part 2 (address-masking) rules and sum memory.

    The text is split on newlines. A line containing a mask assignment
    replaces the current mask. A line containing a memory write converts the
    address to its 36-bit binary form and stores the value, unmasked, at
    every address string produced by ``all_memory_addresses`` for that
    address under the current mask.

    Args:
        text: Program source, one instruction per line.

    Returns:
        The sum of every value left in memory after the last line.
    """
    memory = {}
    # Identity mask for this part: '0' leaves an address bit unchanged.
    # An all-'X' default would mark every bit as floating, so a write that
    # appears before the first mask line would expand to 2**36 addresses.
    mask = '0' * 36
    for line in text.split('\n'):
        match = re_changemask.search(line)
        if match:
            mask = match.group(1)

        match = re_write.search(line)
        if match:
            original = np.binary_repr(int(match.group(1)), 36)
            value = int(match.group(2))
            # Memory is keyed by the 36-character bit string of each address.
            for address in all_memory_addresses(original, mask):
                memory[address] = value

    return sum(memory.values())

In [91]:
# Run the Part 2 interpreter over the puzzle input and report the memory total.
p2 = memory_total_after_part2_program(data)
print('Part 2: {}'.format(p2))

Part 2: 4200656704538
