In [None]:
import itertools
from collections import Counter

from scipy.optimize import LinearConstraint, milp
from tqdm import tqdm

In [None]:
# Develop against the worked example; swap the active line to run the real puzzle input.
filename = "sample.txt"
# filename = "input.txt"
with open(filename, encoding="utf-8") as puzzle_file:
    data = puzzle_file.read()

# Drop surrounding whitespace (including the final newline) so no empty trailing entry appears.
lines = data.strip().split("\n")

In [None]:
def parse_machine(line: str) -> tuple[list[int], list[list[int]], list[int]]:
    """Parse one puzzle line into a ``(lights, buttons, joltages)`` tuple.

    The first whitespace-separated token is the indicator-light diagram, the last
    token is the joltage list, and every token in between is a button. Button and
    joltage tokens are comma-separated ints wrapped in one delimiter character on
    each side. Presumably the line looks like ``[.##.] (3) (1,3) {3,5,4,7}``
    (assumed from the puzzle input; confirm).

    Returns:
        lights: 1 for each ``#`` and 0 for each ``.`` in the diagram. Any other
            characters, such as the surrounding brackets, are ignored.
        buttons: one list of ints per button token.
        joltages: the list of ints from the final token.
    """
    # split() (no argument) tolerates repeated/trailing whitespace; split(" ") would
    # produce empty tokens that later fail in int("").
    raw_lights, *raw_buttons, raw_joltages = line.split()
    lights = [1 if ch == "#" else 0 for ch in raw_lights if ch in ".#"]
    buttons = [[int(i) for i in token[1:-1].split(",")] for token in raw_buttons]
    joltages = [int(j) for j in raw_joltages[1:-1].split(",")]
    return lights, buttons, joltages


# (lights, buttons, joltages) per machine; blank lines are skipped rather than crashing.
machines = [parse_machine(line) for line in lines if line.strip()]

In [None]:
def counter_modulo(c: Counter, modulus=2) -> Counter:
    """Return a new Counter with every count of ``c`` reduced modulo ``modulus``.

    Keys whose reduced count is not strictly positive are omitted, so with the
    default ``modulus=2`` only keys with an odd count survive (each with count 1).
    ``c`` itself is left unmodified.
    """
    # The walrus keeps the remainder so it is computed once; filtering on > 0
    # mirrors Counter's unary plus, which discards zero and negative counts.
    return Counter(
        {key: remainder for key, count in c.items() if (remainder := count % modulus) > 0}
    )

In [None]:
## Part 1
# It's a Lights Out-style puzzle: each button toggles a fixed set of lights.
# Order doesn't matter, and pressing the same button twice cancels out,
#  so each button is pressed at most once and we only need to search subsets of buttons.
total_presses = 0
solutions = []

for target_lights, buttons, _target_joltages in tqdm(machines):
    # {light index: 1} for every light that must end up on; `+` drops the "off" (0) entries.
    target = +Counter(dict(enumerate(target_lights)))
    # Subsets of size 0, then 1, then 2, ... so the first match uses the fewest presses.
    attempts = itertools.chain.from_iterable(
        itertools.combinations(buttons, r) for r in range(len(buttons) + 1)
    )
    for pressed_buttons in attempts:
        # Count how often each light is toggled; its final state is the parity of that count.
        toggles = Counter()
        for button in pressed_buttons:
            toggles.update(button)
        if counter_modulo(toggles, 2) == target:
            total_presses += len(pressed_buttons)
            solutions.append(pressed_buttons)
            break
    else:
        # Previously an unsolvable machine was silently skipped, under-counting
        # total_presses and misaligning `solutions` with `machines`.
        raise ValueError(f"No button combination produces {target_lights=} using {buttons=}")

total_presses

In [None]:
# ## Part 2
# Abandoned: brute force takes too long because there are far too many press combinations.
# # Joltages instead
# total_presses = 0
# solutions = []

# for _target_lights, buttons, target_joltages in tqdm(machines):
#     # target = {k: v for k, v in enumerate(target_lights)}
#     target = {k: v for k, v in enumerate(target_joltages)}
#     target = +Counter(target)
#     # Each press adds at most 1 to any single joltage, so at least max(joltages) presses are needed;
#     #  each press adds at least 1 to the total (if every button affects some counter), so at most sum(joltages).
#     presses_search_range = range(max(target_joltages), sum(target_joltages) + 1)
#     print(f"{target_joltages=} {presses_search_range=}")
#     # continue
#     # for i in presses_search_range:
#     #     attempts = 

#     # Try all presses of length 1 before length 2, etc.
#     attempts = itertools.chain(*(itertools.combinations_with_replacement(buttons, r) for r in presses_search_range))
#     # all_attempts = list(attempts)
#     # len(all_attempts)
#     for pressed_buttons in attempts:
#         joltages = Counter()
#         for b in pressed_buttons:
#             joltages.update(b)
#         # print(f"{joltages=} {target=}")
#         if joltages == target:
#             n = len(pressed_buttons)
#             print(f"{n=} {pressed_buttons=}")
#             total_presses += n
#             solutions.append(pressed_buttons)
#             break

# total_presses

In [None]:
## Part 2 - Attempt 2
# Brute force is far too slow, so model it as a MILP (mixed integer linear program),
# as suggested on reddit. Button i is pressed x_i times (a non-negative integer):
#   minimise    sum(x_i)
#   subject to  A @ x == target_joltages
# where A[j][i] = 1 if button i increases joltage counter j, else 0.
# Based on https://realpython.com/linear-programming-python/#using-scipy
# and the milp example in https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.milp.html

total_presses = 0
solutions = []
for _target_lights, buttons, target_joltages in tqdm(machines):
    n_buttons = len(buttons)

    # Objective: every press costs 1; every variable must be an integer.
    c = [1] * n_buttons
    integrality = [1] * n_buttons

    # Row j, column i: coefficient 1 if button i increases joltage j, else 0.
    A = [[int(j in button) for button in buttons] for j in range(len(target_joltages))]
    # Joltages must match exactly, so the lower and upper constraint bounds are both the target.
    # No `bounds` argument: milp's default already restricts every variable to be >= 0.
    constraints = LinearConstraint(A, target_joltages, target_joltages)
    res = milp(c=c, constraints=constraints, integrality=integrality)
    if not res.success:
        print(res)
        raise RuntimeError(f"Failed to optimise {buttons=} {target_joltages=}: {res.message}")

    # The solver works in floating point, so an integral optimum may come back as e.g.
    # 2.9999999; int() would truncate that to 2, so round to the nearest integer instead.
    result_presses = [round(x) for x in res.x]
    # Verify the rounded solution against the exact integer constraints.
    achieved = [sum(row[i] * p for i, p in enumerate(result_presses)) for row in A]
    assert achieved == list(target_joltages), (
        f"Rounded solution {result_presses} gives {achieved}, not {target_joltages}"
    )
    n_presses = sum(result_presses)
    total_presses += n_presses
    solutions.append(result_presses)

total_presses