In [1]:
import re

from z3 import Bool, If, Int, Optimize, Sum, Xor, sat

In [2]:
def parse_input(file_name):
    """Parse a machine-description file into ``(diagram, buttons, voltage)`` tuples.

    Each non-blank line has the form ``[.##.] (3) (1,3) (2) {3,5,4,7}``:
    a bracketed light diagram (``#`` = on), a sequence of parenthesised,
    comma-separated button index lists, and a braced list of integers.

    Args:
        file_name: path of the text file to read.

    Returns:
        A list with one tuple per non-blank line:
        ``diagram`` is a list of bools (True where the character is ``#``),
        ``buttons`` is a list of lists of ints, ``voltage`` is a list of ints.
    """
    with open(file_name) as f:
        # splitlines() + skipping blanks tolerates a trailing newline, which the
        # previous split("\n") turned into an empty line that crashed index(']').
        lines = f.read().splitlines()

    parsed = []
    for line in lines:
        if not line.strip():
            continue

        diagram_text = line[line.index('[') + 1:line.index(']')]
        diagram = [ch == '#' for ch in diagram_text]

        # Everything between ']' and '{' holds the "(a,b,...)" button groups;
        # a regex does not depend on how many spaces separate them.
        middle = line[line.index(']') + 1:line.index('{')]
        buttons = [
            [int(idx) for idx in group.split(',') if idx.strip()]
            for group in re.findall(r'\(([^)]*)\)', middle)
        ]

        voltage_text = line[line.index('{') + 1:line.index('}')]
        voltage = [int(x) for x in voltage_text.split(',')]

        parsed.append((diagram, buttons, voltage))
    return parsed
# Load the real puzzle input and the worked example used by the assertions below.
# NOTE(review): `input` shadows the builtin input(); the name is kept because
# later cells pass it to part_1 / part_2.
input = parse_input(file_name="input.txt")
example = parse_input(file_name="example.txt")

In [3]:
def part_1(input):
    """Sum, over all machines, the fewest button presses that produce each light diagram.

    Every button toggles the lights it lists, so pressing it twice is the same as
    not pressing it. For each machine we pick a set of buttons (one Bool per
    button) whose combined toggles (XOR) equal the target diagram, and let z3's
    ``Optimize`` minimise the number of buttons chosen.

    Args:
        input: list of ``(diagram, buttons, voltage)`` tuples as produced by
            ``parse_input``; ``voltage`` is ignored here.

    Returns:
        The total number of presses across all machines.

    Raises:
        ValueError: if some machine's diagram cannot be produced by its buttons.
    """
    total = 0
    for machine, (diagram, buttons, _) in enumerate(input):
        pressed = [Bool(f"x_{j}") for j in range(len(buttons))]
        opt = Optimize()
        for i, lit in enumerate(diagram):
            toggles = [pressed[j] for j, button in enumerate(buttons) if i in button]
            if not toggles:
                # No button reaches this light, so it can never leave the off state.
                if lit:
                    raise ValueError(f"machine {machine}: light {i} cannot be turned on")
                continue
            parity = toggles[0]
            for t in toggles[1:]:
                parity = Xor(parity, t)
            opt.add(parity == lit)

        if not pressed:
            continue  # nothing to press; every light was verified to be off above

        presses = Sum([If(p, 1, 0) for p in pressed])
        opt.minimize(presses)
        if opt.check() != sat:
            raise ValueError(f"machine {machine}: no button combination matches the diagram")
        # Read the optimum straight from the objective instead of inspecting each
        # variable (which may be missing from the model if unconstrained).
        total += opt.model().eval(presses, model_completion=True).as_long()
    return total

In [4]:
# Sanity-check against the worked example before printing the real answer.
part_1_example = part_1(example)
assert part_1_example == 7, f"part 1 example: expected 7, got {part_1_example}"
print(part_1(input))

520


In [5]:
def part_2(input):
    """Sum, over all machines, the fewest button presses that reach each voltage target.

    Each press of a button adds 1 to every counter it lists, so button ``j`` is
    pressed ``x_j >= 0`` times and counter ``i`` must satisfy
    ``sum(x_j for buttons j that list i) == voltage[i]``. z3's ``Optimize``
    minimises ``sum(x_j)`` over the integers.

    Args:
        input: list of ``(diagram, buttons, voltage)`` tuples as produced by
            ``parse_input``; ``diagram`` is ignored here.

    Returns:
        The total number of presses across all machines.

    Raises:
        ValueError: if some machine's voltage targets cannot be reached.
    """
    total = 0
    for machine, (_, buttons, voltage) in enumerate(input):
        presses = [Int(f"x_{j}") for j in range(len(buttons))]
        opt = Optimize()
        for p in presses:
            opt.add(p >= 0)
        for i, target in enumerate(voltage):
            contributing = [presses[j] for j, button in enumerate(buttons) if i in button]
            if not contributing:
                # No button feeds this counter, so it is stuck at 0.
                if target != 0:
                    raise ValueError(f"machine {machine}: counter {i} cannot reach {target}")
                continue
            opt.add(Sum(contributing) == target)

        if not presses:
            continue  # nothing to press; every target was verified to be 0 above

        objective = Sum(presses)
        opt.minimize(objective)
        if opt.check() != sat:
            raise ValueError(f"machine {machine}: voltage targets are unreachable")
        total += opt.model().eval(objective, model_completion=True).as_long()
    return total

In [6]:
# Sanity-check against the worked example before printing the real answer.
part_2_example = part_2(example)
assert part_2_example == 33, f"part 2 example: expected 33, got {part_2_example}"
print(part_2(input))

20626
