In [20]:
from collections import deque
from itertools import count
import math

# Load the puzzle input from a file. The encoding is given explicitly so the
# decoded text does not depend on the platform's default locale encoding.
with open('/content/AoC_2023_Day20.txt', 'r', encoding='utf-8') as f:
    puzzle_input = f.read()

def parse_input(puzzle_input):
    """
    Parse the puzzle input to create the graph structure,
    initialize the states of flip-flop modules and conjunction memory.

    Each non-blank line has the form ``[%|&]name -> dest1, dest2, ...``, where a
    ``%`` prefix marks a flip-flop and a ``&`` prefix marks a conjunction.
    Blank lines and surrounding whitespace (for example a trailing newline or
    Windows-style line endings) are ignored.

    Args:
        puzzle_input (str): The raw puzzle text.

    Returns:
        graph (dict): Maps each module name to its list of destination modules.
        flip_flop (dict): Maps each flip-flop module to its state (0 = off, 1 = on),
            all initially off.
        memory (dict): Maps each conjunction module to a dict from every module that
            sends to it onto the most recent pulse received from that sender
            (0 = low, 1 = high), all initially low.

    Raises:
        ValueError: If a non-blank line does not contain exactly one ' -> '.
    """
    graph = {}
    flip_flop = {}
    memory = {}

    for line in puzzle_input.splitlines():
        line = line.strip()
        if not line:
            continue  # Tolerate blank lines, e.g. one produced by a trailing newline

        # Split each line into a source module and its destinations
        source, destinations = line.split(' -> ')
        module_name = source.lstrip('%&')  # Remove type prefixes (% or &) from the source name

        # Build the graph representing module connections
        graph[module_name] = [dest.strip() for dest in destinations.split(',')]

        if source.startswith('%'):
            # Flip-flop modules start in the off state
            flip_flop[module_name] = 0
        elif source.startswith('&'):
            # Conjunction inputs are filled in once every module is known
            memory[module_name] = {}

    # A single pass over all edges gives each conjunction a low-pulse entry for
    # every module that feeds it. Senders are visited in graph order, so each
    # memory dict keeps the same insertion order as a per-conjunction scan.
    for sender, destinations in graph.items():
        for dest in destinations:
            if dest in memory:
                memory[dest][sender] = 0  # Initially remember low pulses

    return graph, flip_flop, memory

def part1(puzzle_input, *, presses=1000):
    """
    Simulate the pulse propagation for a number of button presses and calculate the
    product of the total number of low and high pulses sent.

    Args:
        puzzle_input (str): The raw puzzle text, as accepted by ``parse_input``.
        presses (int): How many times the button is pushed. Defaults to 1000, the
            count the puzzle asks for.

    Returns:
        int: The product of the total low and high pulses sent, counting the
            button's own low pulse to the broadcaster on every press.
    """
    # Parse input to create the graph and initialize states
    graph, flip_flop, memory = parse_input(puzzle_input)

    # signal_count[0] counts low pulses, signal_count[1] counts high pulses
    signal_count = [0, 0]

    for _ in range(presses):
        # The button sends one low pulse to the broadcaster
        signal_count[0] += 1

        # The broadcaster forwards that low pulse to every destination. Queue
        # entries are (source module, destination module, pulse) and are handled
        # in FIFO order, so pulses are processed in the order they were sent.
        queue = deque(('broadcaster', dest, 0) for dest in graph['broadcaster'])

        while queue:
            out_module, in_module, signal = queue.popleft()
            signal_count[signal] += 1

            if in_module in flip_flop:
                if signal == 1:
                    continue  # Flip-flops ignore high pulses
                # A low pulse toggles the flip-flop, which then emits its new state
                flip_flop[in_module] = 1 - flip_flop[in_module]
                out_signal = flip_flop[in_module]
            elif in_module in memory:
                # Remember the latest pulse from this sender. Emit low only when
                # every remembered input is high, otherwise emit high.
                memory[in_module][out_module] = signal
                out_signal = 1 if 0 in memory[in_module].values() else 0
            else:
                continue  # Modules with no type (e.g. sinks) do not send pulses further

            queue.extend((in_module, nxt, out_signal) for nxt in graph[in_module])

    return math.prod(signal_count)

def part2(puzzle_input):
    """
    Determine the fewest number of button presses required to deliver a single low pulse
    to the module named 'rx'.

    The input must have exactly one module feeding 'rx', and it must be a conjunction.
    That conjunction emits a low pulse only when the most recent pulse from every one of
    its inputs is high. For each of those inputs, the first press on which it sends a
    high pulse to the conjunction is recorded. The answer is the least common multiple
    of those presses.

    NOTE: using the LCM assumes each input fires periodically with a period equal to its
    first firing press. This is an assumption about the input that this function does
    not verify.

    Returns:
        int: The fewest number of button presses required.

    Raises:
        ValueError: If 'rx' is not fed by exactly one module, if that module is not a
            conjunction, or if that conjunction has no inputs.
    """
    # Parse input to create the graph and initialize states
    graph, flip_flop, memory = parse_input(puzzle_input)

    # Identify the single module that sends pulses directly to 'rx'. Explicit raises
    # are used instead of asserts so the checks survive `python -O`.
    final_layer = [module for module in graph if 'rx' in graph[module]]
    if len(final_layer) != 1:
        raise ValueError("There should be exactly one module pointing to rx")
    final_module = final_layer[0]
    if final_module not in memory:
        raise ValueError("The final module before rx should be a conjunction module")

    # Inputs of the final conjunction that have not yet sent it a high pulse
    pending = {module for module in graph if final_module in graph[module]}
    if not pending:
        raise ValueError("The final module before rx has no inputs")
    cycle_lengths = []  # First press on which each input sent a high pulse

    for button_push in count(1):
        # Queue holds (source module, destination module, pulse signal) tuples
        queue = deque(('broadcaster', dest, 0) for dest in graph['broadcaster'])

        while queue:
            out_module, in_module, signal = queue.popleft()

            # Watch the pulses arriving at the final conjunction rather than the
            # internals of its inputs, so detection works whatever type of module
            # feeds it (a flip-flop feeder would never be seen otherwise).
            if in_module == final_module and signal == 1 and out_module in pending:
                cycle_lengths.append(button_push)
                pending.remove(out_module)

            if in_module in flip_flop and signal == 0:
                # A low pulse toggles the flip-flop, which then emits its new state
                flip_flop[in_module] = 1 - flip_flop[in_module]
                out_signal = flip_flop[in_module]
            elif in_module in memory:
                # Conjunction: emit low only when every remembered input is high
                memory[in_module][out_module] = signal
                out_signal = 1 if 0 in memory[in_module].values() else 0
            else:
                continue  # Flip-flops ignore high pulses. Untyped modules emit nothing.

            queue.extend((in_module, nxt, out_signal) for nxt in graph[in_module])

        # Stop once every input of the final conjunction has fired at least once
        if not pending:
            break

    return math.lcm(*cycle_lengths)

# Run each solver on the loaded input and report its answer under a part label.
for label, solver in (('Part 1', part1), ('Part 2', part2)):
    print(f'{label}:', solver(puzzle_input))


Part 1: 807069600
Part 2: 221453937522197
