In [None]:
from functools import cache

In [None]:
# Each non-blank input line describes one machine, presumably like
# "[.##.] (3) (1,3) (2) {3,5,4,7}":
#   - first token: light diagram; inner chars are '#' (on -> 1) or anything else (off -> 0)
#   - middle tokens: buttons, each a comma-separated list of indices
#   - last token: comma-separated joltage targets
# Every token carries a one-character delimiter on each side, stripped with [1:-1].
test = False
filename = "test.txt" if test else "input.txt"
lights, wiring, jolts = [], [], []
with open(filename) as handle:
    for raw in handle:
        record = raw.strip()
        if not record:
            continue
        tokens = record.split(" ")
        diagram, buttons, targets = tokens[0], tokens[1:-1], tokens[-1]
        lights.append([int(ch == "#") for ch in diagram[1:-1]])
        jolts.append(tuple(map(int, targets[1:-1].split(","))))
        wiring.append([tuple(map(int, group[1:-1].split(","))) for group in buttons])

In [None]:
@cache
def push_button(state, sequence, wiring):
    """Collect combinations of distinct button presses that turn every light off.

    Pressing a button toggles each light index it lists. Pressing the same
    button twice cancels out, so each button is used at most once.

    Args:
        state: tuple of 0/1 light values still to clear (1 = on).
        sequence: sorted tuple of buttons already pressed on this path.
        wiring: tuple of buttons, each a tuple of light indices.

    Returns:
        frozenset of sorted button tuples (each extending ``sequence``) that
        leave all lights off. It is empty when no combination works. Callers
        take the shortest entry.
    """
    # Already solved: the presses so far form a complete answer. This also
    # covers a machine whose lights start all off, which needs zero presses.
    if not any(state):
        return frozenset({sequence})
    lit = {i for i, s in enumerate(state) if s == 1}
    pressed = set(sequence)
    found = set()
    for button in wiring:
        # The remaining presses must toggle some lit light, so only buttons
        # touching a lit light are worth trying, and each at most once.
        if button in pressed or lit.isdisjoint(button):
            continue
        new_state = list(state)
        for idx in button:
            new_state[idx] = 1 - new_state[idx]
        # Sorting makes different press orders share one cache entry.
        new_sequence = tuple(sorted(sequence + (button,)))
        found.update(push_button(tuple(new_state), new_sequence, wiring))
    # @cache returns the same object to every caller, so the result must be
    # immutable to keep the cache safe.
    return frozenset(found)
                


In [None]:
# Part 1: sum, over all machines, the fewest distinct button presses that
# switch every light off.
total = sum(
    min(len(combo) for combo in push_button(tuple(machine_lights), (), tuple(machine_wiring)))
    for machine_lights, machine_wiring in zip(lights, wiring)
)
print(total)

In [None]:
@cache
def _press_patterns(wiring, width):
    """Group every press-each-button-at-most-once choice by the parity of its effect.

    Returns a dict mapping a parity tuple (length ``width``) to a tuple of
    ``(presses, effect)`` pairs. ``presses`` holds one 0/1 flag per button in
    ``wiring``. ``effect`` holds the resulting per-counter increments.
    """
    patterns = {}
    for mask in range(1 << len(wiring)):
        presses = tuple((mask >> i) & 1 for i in range(len(wiring)))
        effect = [0] * width
        for is_pressed, button in zip(presses, wiring):
            if is_pressed:
                for idx in button:
                    effect[idx] += 1
        parity = tuple(e % 2 for e in effect)
        patterns.setdefault(parity, []).append((presses, tuple(effect)))
    return {parity: tuple(options) for parity, options in patterns.items()}


@cache
def _min_presses_part2(target, wiring):
    """Find per-button press counts with the fewest total presses that add up exactly to ``target``.

    Any count vector x splits as x = p + 2*h with every p_i in {0, 1}. The
    single presses p must reproduce ``target``'s parity without overshooting,
    and h must solve the halved residual, so recursing on that residual is exact.

    Returns a tuple of counts aligned with ``wiring``, or None if unreachable.
    """
    if not any(target):
        return (0,) * len(wiring)
    parity = tuple(t % 2 for t in target)
    best = None
    for presses, effect in _press_patterns(wiring, len(target)).get(parity, ()):
        if any(e > t for e, t in zip(effect, target)):
            continue  # counters only increase, so overshooting is final
        half = _min_presses_part2(tuple((t - e) // 2 for t, e in zip(target, effect)), wiring)
        if half is None:
            continue
        counts = tuple(p + 2 * h for p, h in zip(presses, half))
        if best is None or sum(counts) < sum(best):
            best = counts
    return best


@cache
def push_button_part2(state, sequence, wiring, joltage):
    """Find a fewest-press way to raise counters from ``state`` exactly to ``joltage``.

    Each press of a button adds 1 to every counter index it lists. Buttons may
    be pressed any number of times.

    Args:
        state: tuple of current counter values.
        sequence: sorted tuple of buttons already pressed on this path.
        wiring: tuple of buttons, each a tuple of counter indices.
        joltage: tuple of target counter values.

    Returns:
        frozenset holding one sorted button tuple (``sequence`` plus the extra
        presses) of minimal length, or an empty frozenset if the target cannot
        be reached. The minimum is memoized on the remaining target rather than
        on the press history.
    """
    remaining = tuple(j - s for s, j in zip(state, joltage))
    if any(r < 0 for r in remaining):
        return frozenset()  # a counter is already past its target
    counts = _min_presses_part2(remaining, wiring)
    if counts is None:
        return frozenset()
    extra = tuple(button for button, n in zip(wiring, counts) for _ in range(n))
    return frozenset({tuple(sorted(sequence + extra))})

In [None]:
# Part 2: each press adds 1 to the counters a button lists. Sum, over all
# machines, the fewest presses that bring every counter exactly to its target.
total = 0
for joltage, _wiring in zip(jolts, wiring):
    sequences = push_button_part2((0,) * len(joltage), (), tuple(_wiring), tuple(joltage))
    if not sequences:
        # Fail with context instead of the opaque "min() arg is an empty sequence".
        raise ValueError(f"no button presses reach joltage {joltage}")
    total += min(len(sequence) for sequence in sequences)
print(total)