In [947]:
import re
from collections import deque
from typing import Optional

# (destinations, operator) for one module: operator is "%" (flip-flop),
# "&" (conjunction) or None for an unprefixed module such as "broadcaster".
NodeConfig = tuple[list[str], Optional[str]]
# Per-module memory: flip-flops store {"self": pulse}, conjunctions store
# {input_name: last pulse received}; pulses are 1 (low) or 1j (high).
State = dict[str, dict[str, complex]]


def parse(input_string: str) -> dict[str, NodeConfig]:
    """Parse the module configuration into ``{name: (destinations, operator)}``.

    Each non-blank line has the form ``[%|&]name -> dest1, dest2, ...``.
    ``operator`` is ``"%"``, ``"&"`` or ``None`` when the name has no prefix
    (e.g. ``broadcaster``). Blank lines are ignored.

    Raises:
        ValueError: if a non-blank line does not match that format.
    """
    output: dict[str, NodeConfig] = {}

    for line_number, raw_line in enumerate(input_string.splitlines(), start=1):
        line = raw_line.strip()
        if not line:
            # Tolerate blank lines (e.g. an empty last line in the input file).
            continue

        match = re.fullmatch(r"(%|&)?(\w+) -> (.+)", line)
        if match is None:
            raise ValueError(f"line {line_number}: cannot parse {raw_line!r}")

        operator, name, destinations = match.groups()
        output[name] = ([dest.strip() for dest in destinations.split(",")], operator)

    return output

In [948]:
def get_initial_state(nodes: dict[str, NodeConfig]) -> State:
    """Build the memory every module has before the first button press.

    ``%`` modules start as ``{"self": 1}`` (low/off). ``&`` modules start
    remembering a low pulse (``1``) from each module that sends to them.
    Modules with any other operator (e.g. ``broadcaster``) get no entry.
    """
    # Reverse adjacency, in configuration order: who sends pulses to whom.
    senders: dict[str, list[str]] = {}
    for source, (destinations, _) in nodes.items():
        for destination in destinations:
            senders.setdefault(destination, []).append(source)

    initial: State = {}
    for node_id, (_, operator) in nodes.items():
        if operator == "%":
            initial[node_id] = {"self": 1}
        elif operator == "&":
            # fromkeys de-duplicates while keeping first-seen order.
            initial[node_id] = dict.fromkeys(senders.get(node_id, []), 1)
    return initial


def iterate(
    nodes: dict[str, NodeConfig],
    state: State,
    target_node: Optional[str] = None,
    debug: bool = False,
):
    """Simulate one button press, propagating pulses until none are pending.

    Pulses are complex numbers so a whole press can be tallied with one running
    sum: ``1`` is a low pulse and ``1j`` a high pulse.

    Args:
        nodes: configuration as returned by ``parse``.
        state: module memory as built by ``get_initial_state``; it is updated
            in place and also returned.
        target_node: if set, record whether this module receives a low pulse
            at any point during the press.
        debug: print every pulse as ``source -low-> dest`` / ``source -high-> dest``.

    Returns:
        ``(count, state, target_activation)``: ``count`` is
        ``low_pulses + high_pulses * 1j`` for the press (including the button's
        low pulse to ``broadcaster``), ``state`` is the updated memory, and
        ``target_activation`` is True iff ``target_node`` received a low pulse.
        The press is always simulated to completion, so ``state`` stays
        consistent and can be passed to the next call.
    """
    count = 1 + 0j  # the button's own low pulse to `broadcaster`
    # FIFO queue of (destination, pulse, sender); deque gives O(1) pops.
    todos = deque((node, 1, "broadcaster") for node in nodes["broadcaster"][0])
    target_activation = False

    if debug:
        for node, _, _ in todos:
            print(f"broadcaster -low-> {node}")

    while todos:
        current_node, pulse, previous_node = todos.popleft()
        count += pulse

        if target_node and current_node == target_node and pulse == 1:
            target_activation = True

        if current_node not in nodes:
            # Module without a configuration line (e.g. `output`): receive only.
            continue

        next_nodes, operator = nodes[current_node]

        if operator == "%":
            if pulse == 1j:
                continue  # `%` modules ignore high pulses
            # Toggle: off (1) -> on (1j) emits high; on -> off emits low.
            next_pulse = state[current_node]["self"] = (
                1 if state[current_node]["self"] == 1j else 1j
            )
        elif operator == "&":
            # Remember the latest pulse from this input; emit low only when
            # every remembered input is high.
            state[current_node][previous_node] = pulse
            next_pulse = (
                1 if all(value == 1j for value in state[current_node].values()) else 1j
            )
        else:
            # Unprefixed modules (e.g. `broadcaster`) forward the pulse unchanged.
            next_pulse = pulse

        for node_id in next_nodes:
            todos.append((node_id, next_pulse, current_node))
            if debug:
                print(
                    f"{current_node} -{'low' if next_pulse == 1 else 'high'}-> {node_id}"
                )

    return count, state, target_activation


def sum_complexes(complexes: list[complex]) -> tuple[int, int]:
    """Return ``(sum of real parts, sum of imaginary parts)`` as ints.

    With pulses encoded as ``1`` (low) and ``1j`` (high), this is
    ``(low_count, high_count)``. Components are assumed to be integral
    (pulse counts). Any iterable is accepted, since it is consumed in one pass.
    """
    total = sum(complexes, 0j)
    return int(total.real), int(total.imag)


def get_pulse_count(nodes: dict[str, NodeConfig], iterations: int) -> int:
    """Return ``low_pulses * high_pulses`` over ``iterations`` button presses.

    Presses are simulated until the state is back to its initial value or
    ``iterations`` presses have been made, whichever comes first. Each press
    depends only on the state it starts from, so the per-press counts repeat
    with period ``len(counts)``. The total is therefore ``cycles`` full periods
    plus the first ``remainder`` presses of one more.
    """
    initial_state = get_initial_state(nodes)  # never mutated; used for comparison
    state = get_initial_state(nodes)  # separate copy, mutated in place by `iterate`
    counts = []

    while not counts or (state != initial_state and len(counts) < iterations):
        count, state, _ = iterate(nodes=nodes, state=state)
        counts.append(count)

    cycles, remainder = divmod(iterations, len(counts))
    cycle_low, cycle_high = sum_complexes(counts)
    rest_low, rest_high = sum_complexes(counts[:remainder])
    # Combine the totals before multiplying: (q*L + rL) * (q*H + rH) keeps the
    # cross terms that multiplying the partial products separately would drop.
    low = cycles * cycle_low + rest_low
    high = cycles * cycle_high + rest_high
    return int(low * high)


# Part 2 approach: to find the first button press on which `rx` receives a low
# pulse, find the cycle of each module feeding the conjunction in front of `rx`
# and take the LCM of all of them. Each cycle is measured with `iterate`'s
# `target_node` parameter: the number of presses until that module first
# receives a low pulse. This assumes each module then keeps repeating with that
# same period.

In [949]:
# Example 1 from the puzzle statement. One press sends 8 low and 4 high pulses
# and leaves every module back in its initial state.
# Distinct names keep this cell from overwriting the puzzle's global `nodes`.
example_1_input = """\
broadcaster -> a, b, c
%a -> b
%b -> c
%c -> inv
&inv -> a"""

example_1_nodes = parse(example_1_input)
pulses, _, _ = iterate(nodes=example_1_nodes, state=get_initial_state(example_1_nodes))

assert pulses == 8 + 4j
assert get_pulse_count(example_1_nodes, 1000) == 32000000

In [950]:
# Example 2 from the puzzle statement; `output` has no configuration line,
# so it only receives pulses. Distinct names keep this cell from overwriting
# the puzzle's global `nodes`.
example_2_input = """\
broadcaster -> a
%a -> inv, con
&inv -> b
%b -> con
&con -> output"""
example_2_nodes = parse(example_2_input)
assert get_pulse_count(example_2_nodes, 1000) == 11687500

In [951]:
# Read the puzzle input with a context manager so the file handle is closed.
with open("20.txt", encoding="utf-8") as puzzle_file:
    nodes = parse(puzzle_file.read())
value = get_pulse_count(nodes, 1000)

print(f"Part 1: {value}")

assert value == 812721756

Part 1: 812721756.0


In [952]:
import math


# Module whose first low pulse Part 2 is looking for.
FINAL_NODE = "rx"

# The module that sends pulses to FINAL_NODE (derived instead of hard-coded).
ANTE_LAST_NODE = next(
    (node_id for node_id, (next_nodes, _) in nodes.items() if FINAL_NODE in next_nodes),
    None,
)
assert ANTE_LAST_NODE is not None, f"no module sends pulses to {FINAL_NODE!r}"

# Every module that sends pulses to ANTE_LAST_NODE.
targets = [
    node_id for node_id, (next_nodes, _) in nodes.items() if ANTE_LAST_NODE in next_nodes
]
# math.lcm() with no arguments returns 1, which would be a silently wrong answer.
assert targets, f"no module sends pulses to {ANTE_LAST_NODE!r}"


def count_presses_until_activation(nodes: dict[str, NodeConfig], target: str) -> int:
    """Return the number of presses, from the initial state, until ``target``
    first receives a low pulse (as reported by ``iterate``)."""
    state = get_initial_state(nodes)
    presses = 0
    activation = False
    while not activation:
        _, state, activation = iterate(nodes=nodes, state=state, target_node=target)
        presses += 1
    return presses


# Presses needed by each target; assumes each then repeats with that period,
# so the first press where they all coincide is their LCM.
iterations = [count_presses_until_activation(nodes, target) for target in targets]

value = math.lcm(*iterations)
print(f"Part 2: {value}")
assert value == 233338595643977

Part 2: 233338595643977
