In [194]:
import re

# One initialization step: (label, focal length, operation).
# The focal length is -1 for "-" steps, which carry no number.
Instruction = tuple[str, int, str]


def parse(str: str) -> list[Instruction]:
    """Parse a comma-separated initialization sequence into instructions.

    Each step has the form ``label=N`` or ``label-``. Newline characters are
    removed before splitting, and empty steps (from an empty string or a
    trailing comma) are skipped.

    Args:
        str: The raw sequence. (The name shadows the builtin; it is kept so
            keyword callers keep working.)

    Returns:
        A list of ``(label, focal, operation)`` tuples, where ``focal`` is
        ``-1`` when the step carries no number (``-`` steps).

    Raises:
        ValueError: If a non-empty step does not start with
            ``label=`` / ``label-``.
    """
    instructions: list[Instruction] = []

    for step in str.replace("\n", "").split(","):
        step = step.strip()
        if not step:
            continue

        match = re.match(r"(\w+)(=|-)(\d*)", step)
        if match is None:
            raise ValueError(f"malformed step: {step!r}")

        label, operation, digits = match.groups()
        # "-" steps have no digits; -1 marks the missing focal length.
        instructions.append((label, int(digits or -1), operation))

    return instructions


def encode(str: str, key: int = 0) -> int:
    """Apply the HASH algorithm to ``str``, starting from ``key``.

    For each character, add its code point to the running value, multiply
    by 17, and keep the remainder modulo 256.

    Args:
        str: Text to hash. (The name shadows the builtin; it is kept so
            keyword callers keep working.)
        key: Starting value; ``0`` for a fresh hash.

    Returns:
        The hash value. For a non-empty ``str`` this is in ``range(256)``;
        for an empty ``str`` it is ``key`` itself.
    """
    # Iterate rather than recurse. Recursion sliced the string on every call
    # (quadratic) and exceeded the recursion limit on long inputs.
    for char in str:
        key = ((key + ord(char)) * 17) % 256

    return key


def get_boxes(operations: list[Instruction]) -> dict[int, list[tuple[str, int]]]:
    """Run the initialization sequence and return the resulting box contents.

    Args:
        operations: ``(label, focal, operation)`` tuples as produced by
            ``parse``. The box for each step is ``encode(label)``.
            - ``"-"`` removes the lens with that label from its box, if present.
            - ``"="`` replaces the focal length of an existing lens with that
              label, keeping its position, or appends ``(label, focal)`` to
              the end of the box.
            - Any other operation is ignored.

    Returns:
        A dict mapping every box number ``0..255`` (inserted in ascending
        order) to its ordered list of ``(label, focal)`` lenses.
    """
    # One independent list per box. ``dict.fromkeys(range(256), [])`` would
    # make every key share a single list object, so mutating one box (here or
    # in a caller) would silently change all the others.
    boxes: dict[int, list[tuple[str, int]]] = {i: [] for i in range(256)}

    for label, focal, operation in operations:
        box = boxes[encode(label)]
        # Index of the lens with this label, or None if the box has none.
        slot = next(
            (pos for pos, (lens_label, _) in enumerate(box) if lens_label == label),
            None,
        )

        if operation == "-":
            if slot is not None:
                del box[slot]
        elif operation == "=":
            if slot is None:
                box.append((label, focal))
            else:
                box[slot] = (label, focal)

    return boxes

In [195]:
# HASH of the single-word example string.
test_input = "HASH"
assert 52 == encode(test_input)

In [196]:
# Part 1 on the example: sum of HASH over every comma-separated step.
test_input = "rn=1,cm-,qp=3,cm=2,qp-,pc=4,ot=9,ab=5,pc-,pc=6,ot=7"
assert sum(map(encode, test_input.split(','))) == 1320

In [197]:
# Part 1 on the real input: sum of HASH over every step.
# The context manager closes the file explicitly. Newlines are removed so a
# trailing newline in the file is not hashed into the last step.
with open("15.txt") as puzzle_file:
    steps = puzzle_file.read().replace("\n", "").split(",")

total = sum(encode(step) for step in steps)

print(total)
assert total == 511257

511257


In [198]:
# Part 2 on the example: for each box, (box number + 1) times the sum of
# slot * focal length over its lenses (slots counted from 1).
test_input = "rn=1,cm-,qp=3,cm=2,qp-,pc=4,ot=9,ab=5,pc-,pc=6,ot=7"
boxes = get_boxes(parse(test_input))

total = 0
for box_id, box in enumerate(boxes.values()):
    total += (box_id + 1) * sum(
        slot * focal for slot, (_label, focal) in enumerate(box, start=1)
    )

assert total == 145

In [None]:
# Part 2 on the real input: focusing power of the final box arrangement.
# The context manager closes the file explicitly.
with open("15.txt") as puzzle_file:
    boxes = get_boxes(parse(puzzle_file.read()))

# Take the box number from the dict key rather than the iteration position,
# so the score does not depend on the dict's key order.
total = sum(
    (box_id + 1) * slot * focal
    for box_id, box in boxes.items()
    for slot, (_label, focal) in enumerate(box, start=1)
)

print(total)