In [2]:
from aocd import get_data, submit
import numpy as np
import sys
import re
import math
# Never summarise numpy array reprs with '...', so whole puzzle grids can be printed.
np.set_printoptions(threshold=sys.maxsize)
from IPython.display import display, HTML
# Widen the notebook's content area to the full browser width.
display(HTML("<style>.container { width:100% !important; }</style>"))

from queue import PriorityQueue
from collections import defaultdict, Counter, deque
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, List, Tuple, Optional
from functools import reduce, cache
from operator import mul
from bisect import bisect_right

# Neighbour offsets as (dy, dx) pairs: the four orthogonal steps, then all eight
# steps including diagonals.  Several searches below iterate them in this order.
DIRECTIONS_4 = [(1, 0), (0, 1), (0, -1), (-1, 0)]
DIRECTIONS_8 = [(1, 1), (1, 0), (1, -1), (0, 1), (0, -1), (-1, 1), (-1, 0), (-1, -1)]


def raw_read_input(day, hardcoded_input=None):
    """Return the raw puzzle text for `day` of AoC 2023.

    A non-empty `hardcoded_input` (e.g. a sample from the puzzle text) is
    returned as-is instead of downloading the real input.
    """
    if hardcoded_input:
        return hardcoded_input
    return get_data(day=day, year=2023, block=True)


def read_input(day, dtype=None, hardcoded_input=None):
    """Return the puzzle input split into lines.

    When `dtype` is given, each non-empty line is converted with it and each
    empty line becomes None.
    """
    text = raw_read_input(day=day, hardcoded_input=hardcoded_input)
    rows = text.splitlines()
    if dtype is None:
        return rows
    return [None if not row else dtype(row) for row in rows]


def read_matrix(day, dtype=np.int32, hardcoded_input=None):
    """Return the puzzle input as a 2-D numpy array, one cell per character, converted with `dtype`."""
    grid = [list(map(dtype, row)) for row in read_input(day, hardcoded_input=hardcoded_input)]
    return np.array(grid, dtype=dtype)

# Day 1

In [2]:
lines = read_input(day=1)
# Keep only the literal digit characters of each line.
digits_per_line = [list(filter(str.isdigit, line)) for line in lines]

def calc_sum(for_lines):
    """Sum the calibration values: each entry's first and last digit read as a two-digit number.

    Entries are sequences of digit characters (lists or strings); empty entries count as 0.
    """
    total = 0
    for digits in for_lines:
        if digits:
            total += int(digits[0] + digits[-1])
    return total
print('part 1', calc_sum(digits_per_line))

def replace_digits(for_line):
    """Return the digits found in `for_line`, in order, counting spelled-out digits too.

    Every start position is examined: a literal digit character contributes
    itself, and a digit word ('zero'..'nine') beginning there contributes its
    numeral.  Words may overlap (e.g. 'twone' -> '21'), so the scan never skips ahead.
    """
    digit_names = ['zero', 'one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine']
    result = []
    for index, char in enumerate(for_line):
        if char.isdigit():
            # A literal digit is emitted once.  (It used to sit inside the
            # digit-name loop and was appended once per name, i.e. 10 times.)
            result.append(char)
            continue
        for digit, digit_name in enumerate(digit_names):
            if for_line.startswith(digit_name, index):
                result.append(str(digit))
    return ''.join(result)
        
# Re-read every line, this time also converting spelled-out digits.
true_digits_per_line = list(map(replace_digits, lines))
print('part 2', calc_sum(true_digits_per_line))

part 1 54597
part 2 54504


# Day 2

In [3]:
class Color(Enum):
    """The three cube colours; each member's value is its name as written in the input."""
    red = 'red'
    green = 'green'
    blue = 'blue'

@dataclass
class Game:
    """One game: its id and the handfuls of cubes ("shows") revealed during it."""
    game_number: int
    # each show maps every Color to the number of cubes drawn
    shows: List[Dict[Color, int]] = field(default_factory=list)

    def is_possible(self, combination: Dict[Color, int]) -> bool:
        """True when no show drew more cubes of any colour than `combination` contains."""
        return all(
            show[color] <= combination[color]
            for show in self.shows
            for color in Color
        )

    def power_number(self):
        """Product over all colours of the largest count of that colour seen in any show."""
        minimum_needed = dict.fromkeys(Color, 0)
        for show in self.shows:
            for color, amount in show.items():
                if amount > minimum_needed[color]:
                    minimum_needed[color] = amount
        return reduce(mul, minimum_needed.values())


lines = read_input(day=2)
games = []

for line in lines:
    # "Game 12: 3 blue, 4 red; 1 red, 2 green"
    game_title, showings = line.split(':')
    game = Game(int(game_title.split(' ')[1]))
    for show in showings.split(';'):
        # every show records all colours, 0 for those not mentioned
        counts_by_color = dict.fromkeys(Color, 0)
        for color_group in show.split(', '):
            amount, color_name = color_group.strip().split(' ')
            counts_by_color[Color(color_name)] = int(amount)
        game.shows.append(counts_by_color)
    games.append(game)

part_1_check = {Color.red: 12, Color.green: 13, Color.blue: 14}
possible_ids = [game.game_number for game in games if game.is_possible(part_1_check)]
print('part 1', sum(possible_ids))
print('part 2', sum(map(Game.power_number, games)))

part 1 2331
part 2 71585


# Day 3

In [4]:
matrix = read_matrix(day=3, dtype=str)
# Pad with '.' so neighbourhood slices never run off the grid and every row
# ends with a non-digit, which flushes a number touching the right edge.
expand = np.pad(matrix, (1, 1), constant_values=['.'])
# A symbol is any character that is neither a digit nor '.'.
is_symbol = np.vectorize(lambda x: not x.isdigit() and x != '.')(expand)
is_star = expand == '*'

# gears[(row, col) of a '*'] = part numbers adjacent to that star
gears = defaultdict(list)
total = 0

for i, line in enumerate(expand):
    # column where the current run of digits began, or None outside a number
    number_start = None
    for j, char in enumerate(line):
        if char.isdigit():
            if number_start is None:
                number_start = j
        elif number_start is not None:
            # The number spans columns number_start..j-1; look at the 3-row box
            # around it (one extra column on each side) for any symbol.
            if np.any(is_symbol[i - 1: i + 2, number_start - 1: j + 1]):
                number = int(''.join(line[number_start:j]))
                total += number
                # Record the number against every '*' in that same box,
                # converting box-relative offsets back to grid coordinates.
                ys, xs = np.where(is_star[i - 1: i + 2, number_start - 1: j + 1])
                for star_y, star_x in zip(ys.tolist(), xs.tolist()):
                    gears[(i - 1 + star_y, number_start - 1 + star_x)].append(number)
            number_start = None

print('part 1', total)
# Only stars adjacent to exactly two part numbers count; add up the products.
gear_ratio_sum = sum(
    gear[0] * gear[1]
    for gear in gears.values()
    if len(gear) == 2
)
print('part 2', gear_ratio_sum)

part 1 530849
part 2 84900879


# Day 4

In [5]:
lines = read_input(day=4)
total = 0
# counts[i] = copies of card i held (the original plus every copy won)
counts = [1 for _ in range(len(lines))]


for index, line in enumerate(lines):
    winners, numbers = line.split(': ')[1].split(' | ')
    # str.split() with no argument ignores leading/trailing whitespace.  The
    # previous re.split('\s+', ...) (also an invalid escape in a non-raw string)
    # produced a '' token when a list began with a space-padded one-digit number,
    # and a '' on both sides was counted as a winning match.
    winners = set(winners.split())
    numbers = set(numbers.split())
    winning_numbers = winners & numbers
    if winning_numbers:
        # part 1: 1 point for the first match, doubled for each further one
        total += 1 << (len(winning_numbers) - 1)
    # every copy of this card wins one copy of each of the next len(winning_numbers) cards
    for offset in range(len(winning_numbers)):
        counts[index + offset + 1] += counts[index]
print('part 1', total)
print('part 2', sum(counts))

part 1 22674
part 2 5747443


# Day 5

In [6]:
lines = read_input(day=5)
# category chain in mapping order, starting from 'seed'
traverse_order = ['seed']
# maps[(source, destination)] = list of [dest_start, source_start, length]
maps = defaultdict(list)
initial_seeds = [int(token) for token in lines[0].split(': ')[1].split()]
current_map = None
index = 1
while index < len(lines):
    row = lines[index]
    if row.strip():
        maps[current_map].append([int(token) for token in row.split()])
        index += 1
    else:
        # a blank line is followed by a "<source>-to-<destination> map:" header
        line = lines[index + 1]
        map_title = line.split()[0].split('-to-')
        current_map = tuple(map_title)
        traverse_order.append(map_title[1])
        index += 2

def get_location(start_id, order=None, tables=None):
    """Map a single id through every category mapping and return the final id.

    order  -- category chain to follow; defaults to the module-level `traverse_order`
    tables -- {(source, destination): [[dest_start, source_start, length], ...]};
              defaults to the module-level `maps`
    In each stage the first triple whose source range contains the id translates
    it; an id covered by no triple passes through that stage unchanged.
    """
    if order is None:
        order = traverse_order
    if tables is None:
        tables = maps
    for mapping_type in zip(order, order[1:]):
        for dest_start, source_start, length in tables.get(mapping_type, []):
            if source_start <= start_id < source_start + length:
                start_id = dest_start + (start_id - source_start)
                break
    return start_id
 

def get_location_range(current_ranges, order=None, tables=None):
    """Return the lowest final id reachable from any of the given id ranges.

    current_ranges -- list of (start_id, length) pairs
    order          -- category chain; defaults to the module-level `traverse_order`
    tables         -- {(source, destination): [[dest_start, source_start, length], ...]};
                      defaults to the module-level `maps`
    Each stage splits the ranges at the boundaries of its triples.  Overlapping
    parts are translated, leftovers are offered to the stage's next triple, and
    whatever no triple touched carries over unchanged.
    """
    if order is None:
        order = traverse_order
    if tables is None:
        tables = maps
    for mapping_type in zip(order, order[1:]):
        result_ranges = []
        for dest_start_id, mapping_start_id, mapping_length in tables.get(mapping_type, []):
            new_current_ranges = []
            mapping_end_id = mapping_start_id + mapping_length - 1
            for start_id, length in current_ranges:
                end_id = start_id + length - 1
                if end_id < mapping_start_id or mapping_end_id < start_id:
                    new_current_ranges.append((start_id, length))
                    continue
                overlap_start_id = max(mapping_start_id, start_id)
                overlap_end_id = min(mapping_end_id, end_id)
                if start_id < overlap_start_id:
                    new_current_ranges.append((start_id, overlap_start_id - start_id))
                if overlap_end_id < end_id:
                    # Tail covers overlap_end_id + 1 .. end_id inclusive.  The old
                    # length (end_id - overlap_end_id - 1) dropped end_id and could
                    # produce zero-length ranges whose start still counted in min().
                    new_current_ranges.append((overlap_end_id + 1, end_id - overlap_end_id))

                new_start_id = dest_start_id + (overlap_start_id - mapping_start_id)
                new_length = overlap_end_id - overlap_start_id + 1
                result_ranges.append((new_start_id, new_length))
            current_ranges = new_current_ranges
        current_ranges += result_ranges

    return min(start_id for start_id, length in current_ranges)

part_1_locations = [get_location(seed_id) for seed_id in initial_seeds]
print('part 1', min(part_1_locations))
# Part 2 reads the seed list as consecutive (start, length) pairs.
seed_ranges = [
    (initial_seeds[pair_start], initial_seeds[pair_start + 1])
    for pair_start in range(0, len(initial_seeds), 2)
]
print('part 2', get_location_range(seed_ranges))

part 1 218513636
part 2 81956384


# Day 6

In [7]:
lines = read_input(day=6)
# The first line lists the race times and the second the record distances, as ints.
times, goals = [list(map(int, x.split(':')[1].split())) for x in lines]

def ways_to_beat(race_time, race_goal):
    """Count the whole hold times in 1..race_time-1 whose distance beats `race_goal`.

    Holding for `t` gives distance t * (race_time - t).  That value is symmetric
    about race_time / 2 and non-decreasing up to it, so the winning hold times
    form one band [first, race_time - first].  `first` is found by binary
    search with exact integer arithmetic: O(log race_time) instead of the
    previous linear scan, which took tens of millions of steps for part 2.
    """
    def beats(hold):
        return hold * (race_time - hold) > race_goal

    half = race_time // 2
    # The peak distance is reached at `half`; if even that loses, nothing wins.
    if race_time < 2 or not beats(half):
        return 0
    low, high = 1, half
    while low < high:
        middle = (low + high) // 2
        if beats(middle):
            high = middle
        else:
            low = middle + 1
    return race_time - 2 * low + 1
    
# Part 1: product over all races of the number of winning hold times.
print('part 1', reduce(mul, (
    ways_to_beat(race_time, race_goal)
    for race_time, race_goal in zip(times, goals)
)))

# Part 2: concatenate all numbers on each line into one single race.
race_time, race_goal = [int("".join(x.split(':')[1].split())) for x in lines]
print('part 2', ways_to_beat(race_time, race_goal))

part 1 6209190
part 2 28545089


# Day 7

In [8]:
lines = read_input(day=7)  # each line: "<hand> <bid>"

JOKER_STRENGHT = 1


def hand_strength(hand, use_joker):
    """Encode `hand` as a string whose lexicographic order is the hand ranking.

    The key is the hand's group sizes, largest first, followed by each card's
    face value.  Every number is mapped to a letter via chr(ord('a') + n).
    With `use_joker`, 'J' is worth JOKER_STRENGHT (the lowest value) and the
    jokers are added to the largest non-joker group.
    """
    face_cards = {'T': 10, 'J': 11, 'Q': 12, 'K': 13, 'A': 14}
    if use_joker:
        face_cards['J'] = JOKER_STRENGHT

    face_values = [face_cards[card] if card in face_cards else int(card) for card in hand]
    card_counts = Counter(face_values)
    groups = card_counts.most_common()
    # A hand made only of jokers keeps its single group untouched.
    if JOKER_STRENGHT in card_counts and len(groups) != 1:
        jokers = card_counts[JOKER_STRENGHT]
        groups = [(value, size) for value, size in groups if value != JOKER_STRENGHT]
        top_value, top_size = groups[0]
        groups[0] = (top_value, top_size + jokers)
    key_numbers = [size for _, size in groups] + face_values
    return ''.join(chr(ord('a') + number) for number in key_numbers)


def deck_sum(deck, use_joker):
    """Total winnings: each bid times its 1-based rank when hands are sorted weakest first."""
    scored = []
    for entry in deck:
        fields = entry.split()
        scored.append((hand_strength(fields[0], use_joker), int(fields[1])))
    scored.sort(key=lambda pair: pair[0])
    return sum(rank * bid for rank, (_, bid) in enumerate(scored, start=1))
    
print('part 1', deck_sum(lines, use_joker=False))
# Part 2: 'J' is a joker - weakest single card, but it joins the largest group.
print('part 2', deck_sum(lines, use_joker=True))

part 1 251136060
part 2 249400220


# Day 8

In [9]:
lines = read_input(day=8)
instructions, _, *mappings = lines
# directions[node] = (left neighbour, right neighbour)
directions = {}
for mapping in mappings:
    # "AAA = (BBB, CCC)"
    source, destinations = mapping.split(' = ')
    dest_a, dest_b = destinations[1:-1].split(', ')
    directions[source] = (dest_a, dest_b)

def distance(start):
    """Steps from `start`, cycling through `instructions`, until reaching a node ending in 'Z'."""
    node = start
    steps = 0
    instruction_count = len(instructions)
    while not node.endswith('Z'):
        go_right = instructions[steps % instruction_count] != 'L'
        node = directions[node][1 if go_right else 0]
        steps += 1
    return steps

ghost_starts = [node for node in directions if node.endswith('A')]
print('part 1', distance('AAA'))
# NOTE(review): the LCM assumes each ghost's walk is periodic with a period equal
# to its first arrival at a 'Z' node; that is not true for arbitrary inputs.
print('part 2', math.lcm(*map(distance, ghost_starts)))

part 1 16043
part 2 15726453850399


# Day 9

In [10]:
lines = read_input(day=9)
lines = [[int(token) for token in line.split()] for line in lines]

total_right = 0
total_left = 0
for line in lines:
    # Build successive difference rows until one is all zeros.
    rows = []
    current = line
    while True:
        current = [b - a for a, b in zip(current, current[1:])]
        rows.append(current)
        if not any(current):
            break
    # Forward extrapolation: the last value plus every row's last difference.
    total_right += line[-1] + sum(row[-1] for row in rows)
    # Backward extrapolation alternates signs down the difference rows.
    alternating = sum(row[0] if depth % 2 == 0 else -row[0] for depth, row in enumerate(rows))
    total_left += line[0] - alternating
print('part 1', total_right)
print('part 2', total_left)

part 1 1762065988
part 2 1066


# Day 10

In [11]:
matrix = read_matrix(day=10, dtype=str)
# Pad with '.' so neighbour lookups from any real cell stay in bounds.
expand = np.pad(matrix, (1, 1), constant_values=['.'])
# NEXTS[tile] = the (dy, dx) offsets that the pipe on that tile connects to.
NEXTS = {
    '|': [(1, 0), (-1, 0)],
    '-': [(0, -1), (0, 1)],
    'L': [(-1, 0), (0, 1)],
    'J': [(-1, 0), (0, -1)],
    '7': [(1, 0), (0, -1)],
    'F': [(1, 0), (0, 1)],
    '.': []
}

# (row, col) of the 'S' tile, in padded coordinates.
s_coords = np.where(expand == 'S')
start = s_coords[0][0], s_coords[1][0]

def find_loop(current, last):
    """Follow the pipe from `current` (just entered from `last`) until it leads back to `start`.

    Returns the visited cells [last, current, ...] in walk order, ending with
    the cell whose pipe connects back to `start`.  Returns [] if the walk
    reaches a tile with no connecting continuation.
    """
    path = [last, current]
    while current != start:
        for dy, dx in NEXTS[expand[current]]:
            next_pos = current[0] + dy, current[1] + dx
            # Never step straight back.  This test must come before the `start`
            # test: on the first step `last` IS `start`, and a pipe that lists
            # its connection to S first (e.g. 'L'/'J' below S, '-' right of S)
            # used to "close" the loop at once with a bogus 2-cell path.
            if next_pos == last:
                continue
            if next_pos == start:
                return path
            # the neighbouring pipe must connect back towards us
            if (-dy, -dx) not in NEXTS[expand[next_pos]]:
                continue
            path.append(next_pos)
            current, last = next_pos, current
            break
        else:
            return []
    return path
        

# Try every neighbour of S whose pipe connects back to S until a loop is traced.
for dy, dx in DIRECTIONS_4:
    next_pos = (start[0] + dy, start[1] + dx)
    if (-dy, -dx) not in NEXTS[expand[next_pos]]:
        continue
    path = find_loop(next_pos, start)
    if path:
        break
# The farthest loop cell is half the loop length away (rounded up).
print('part 1', len(path) // 2 + len(path) % 2)
# Work out which pipe S stands for from the offsets to its two loop neighbours.
start_char_nexts = {
    (path[1][0] - path[0][0], path[1][1] - path[0][1]),
    (path[-1][0] - path[0][0], path[-1][1] - path[0][1])
}
start_char = [char for char, nexts in NEXTS.items() if set(nexts) == start_char_nexts][0]
expand[start[0], start[1]] = start_char
path = set(path)

def is_inside_polygon(y, x):
    """Cast a ray from (y, x) to the right edge; True if it crosses the loop an odd number of times.

    Only loop cells are considered.  A '|' is one crossing.  A horizontal run
    F...J or L...7 is one crossing; F...7 and L...J are not crossings.
    """
    count = 0
    # corner that would turn the currently open horizontal run into a crossing
    on_edge = None
    for x in range(x, expand.shape[1]):
        if (y, x) not in path:
            continue
        if expand[y, x] in ['|', on_edge]:
            count += 1
            on_edge = None
        elif expand[y, x] in ['F', 'L']:
            on_edge = 'J' if expand[y, x] == 'F' else '7'
                
    return count % 2 == 1
# NOTE(review): y/x range over the unpadded shape but are used as padded
# coordinates.  The last real row/column is skipped and padding row/column 0 is
# scanned instead; none of those cells can lie strictly inside the loop, so the
# count is unaffected.
print('part 2', sum(
    is_inside_polygon(y, x)
    for y in range(matrix.shape[0])
    for x in range(matrix.shape[1])
    if (y, x) not in path
))

part 1 6649
part 2 601


# Day 11

In [12]:
matrix = read_matrix(day=11, dtype=str)

# NOTE: despite its name, `emptiness` is True on every cell that is NOT '.'.
emptiness = matrix != '.'
# empty_rows[i] / empty_columns[j]: number of all-'.' rows / columns at or before i / j
empty_rows = np.cumsum(~emptiness.any(axis=1))
empty_columns = np.cumsum(~emptiness.any(axis=0))
# galaxy coordinates
ys, xs = np.nonzero(matrix == '#')

def diffs(coords, empty, mult):
    """Return the sum of |a - b| over all ORDERED pairs of expanded coordinates.

    coords -- galaxy indices along one axis (an integer numpy array)
    empty  -- empty[i] = number of empty lines at or before index i
    mult   -- every empty line is widened to `mult` lines
    Each unordered pair is counted twice, so callers halve the result.

    Sorting gives sum_{i<j} (s_j - s_i) = sum_k s_k * (2k - n + 1): O(n log n)
    time and O(n) memory instead of the previous n x n tiled matrix.
    """
    expanded = np.sort(coords + empty[coords] * (mult - 1))
    n = len(expanded)
    weights = 2 * np.arange(n) - (n - 1)
    return 2 * np.sum(expanded * weights)

part_1_mult = 2
part_2_mult = 1000000
# diffs counts every pair of galaxies twice (once per order), hence the halving.
for part_label, expansion in (('part 1', part_1_mult), ('part 2', part_2_mult)):
    print(part_label, (diffs(xs, empty_columns, expansion) + diffs(ys, empty_rows, expansion)) // 2)

part 1 9522407
part 2 544723432977


# Day 12

In [13]:
lines = read_input(day=12)

@cache
def calculate_possibilities(num_count, line, groups):
    """Count the ways to fill the '?' cells of `line` so its runs of '#' match `groups`.

    num_count -- length of the '#' run already open when `line` begins
                 (callers pass 0 for a fresh row)
    line      -- the remaining row text, made of '#', '.' and '?'
    groups    -- tuple of run lengths still to be matched, in order; an open
                 run is matched against groups[0]
    Memoised with functools.cache, so every argument must be hashable.
    """
    # No runs left to place: valid only if nothing is open and no '#' remains.
    if len(groups) == 0:
        return int(num_count == 0 and '#' not in line)
    # Consume fixed characters one by one; the first '?' branches recursively.
    for index in range(len(line)):
        if line[index] == '#':
            # Extend the open run; longer than the current group is a dead end.
            num_count += 1
            if num_count > groups[0]:
                return 0
        elif line[index] == '.':
            if num_count > 0:
                # A '.' closes the open run, which must equal groups[0] exactly.
                if num_count != groups[0]:
                    return 0
                else:
                    num_count = 0
                    groups = groups[1:]
                    if not groups:
                        # All runs placed: the rest may not hold another '#'.
                        return int('#' not in line[index + 1:])
        elif line[index] == '?':
            # Option 1: the '?' is a '#', extending (or opening) a run.
            with_group = calculate_possibilities(num_count + 1, line[index + 1:], groups)
            if num_count == 0:
               # Option 2 with no open run: the '?' is simply a '.'.
               return (
                    with_group + 
                    calculate_possibilities(0, line[index + 1:], groups)
                )
            else:
                # Option 2 with an open run: a '.' closes it, allowed only when
                # the run has exactly the length groups[0].
                with_dot = (
                    calculate_possibilities(0, line[index + 1:], groups[1:])
                    if num_count == groups[0]
                    else 0
                )
                return with_dot + with_group
    # End of row: either the open run exactly completes the only remaining
    # group, or nothing is open and no groups remain.
    return int(
        (num_count > 0 and len(groups) == 1 and groups[0] == num_count) or 
        num_count == 0 and not groups
    )

total = 0
total_2 = 0
for line in lines:
    values, groups_str = line.split()
    groups = tuple(map(int, groups_str.split(',')))
    total += calculate_possibilities(0, values, groups)
    # Part 2 unfolds each row: five copies joined by '?', and the groups repeated five times.
    total_2 += calculate_possibilities(0, '?'.join([values] * 5), groups * 5)
print('part 1', total)
print('part 2', total_2)

part 1 7694
part 2 5071883216318


# Day 13

In [14]:
lines = read_input(day=13)  # patterns separated by blank lines

def calc_horizontal(matrix, error):
    """Return the number of rows above the first horizontal mirror line, or 0 if there is none.

    A split after row `i` is accepted when the overlapping rows, compared as
    mirror images, differ in exactly `error` cells: 0 for a perfect
    reflection, 1 for a reflection with one smudge.
    """
    rows = matrix.shape[0]
    for i in range(1, rows):
        height = min(i, rows - i)
        top = matrix[i - height:i, :]
        bottom = matrix[i:i + height, :]
        if np.sum(top != bottom[::-1, :]) == error:
            return i
    return 0
def calc_vertical(matrix, error):
    """Return the number of columns left of the first vertical mirror line, or 0 if there is none.

    A vertical mirror of `matrix` is a horizontal mirror of its transpose, so
    this reuses calc_horizontal instead of duplicating it with the axes swapped.
    """
    return calc_horizontal(matrix.T, error)

total = 0
total_2 = 0
start_index = 0
# A trailing '' sentinel makes the last pattern flush like the others.
for line_index, line in enumerate(lines + ['']):
    if line:
        continue
    matrix = np.array([list(row) for row in lines[start_index:line_index]])
    total += 100 * calc_horizontal(matrix, 0) + calc_vertical(matrix, 0)
    total_2 += 100 * calc_horizontal(matrix, 1) + calc_vertical(matrix, 1)
    start_index = line_index + 1
print('part 1', total)
print('part 2', total_2)

part 1 33047
part 2 28806


# Day 14

In [15]:
matrix = read_matrix(day=14, dtype=str)
# Border the platform with '#' so every rolling 'O' has a stone to stop against.
expand = np.pad(matrix, (1, 1), constant_values=['#'])

def turn_up(table, reverse):
    """Roll every 'O' of `table` along axis 0 until it rests against a '#' or another 'O'.

    reverse=False rolls towards row 0 (north); reverse=True rolls towards the
    last row (south).  `table` is modified in place.  Each ball is placed
    relative to the nearest '#' before it, so every column must start with a
    '#' in the rolling direction (the np.pad border guarantees this).
    """
    working_table = table if not reverse else table[::-1, :].copy()
    stones = list(zip(*np.where(working_table == '#')))
    balls = list(zip(*np.where(working_table == 'O')))
    for index in range(working_table.shape[1]):
        # np.where scans in row-major order, so these row indexes are ascending
        stone_indexes = [y for y, x in stones if x == index]
        ball_indexes = [y for y, x in balls if x == index]
        new_column = np.full_like(working_table[:, index], '.')
        new_column[stone_indexes] = '#'
        # offsets[k]: distance from the k-th stone to the next free cell after it
        offsets = defaultdict(lambda: 1)

        for ball in ball_indexes:
            stone_index = bisect_right(stone_indexes, ball) - 1
            new_column[stone_indexes[stone_index] + offsets[stone_index]] = 'O'
            offsets[stone_index] += 1
        table[:, index] = new_column if not reverse else new_column[::-1]


def turn_left(table, reverse):
    """Roll every 'O' of `table` along axis 1: to column 0 (west), or to the last column (east) if reverse.

    Rolling along the rows of `table` is rolling along the columns of its
    transpose.  `table.T` is a view, so turn_up writes straight back into
    `table` (this replaces a line-for-line copy of turn_up with swapped axes).
    """
    turn_up(table.T, reverse)


def start_cycling(table=None, total_cycles=1000000000):
    """Spin `table` `total_cycles` times and return the resulting grid.

    table        -- padded platform to spin in place; defaults to the module-level `expand`
    total_cycles -- number of spins (north, west, south, east)
    Prints the part 1 load (after the very first northward roll).  Once a
    state repeats, the answer is read from the recorded history instead of
    simulating the remaining spins.
    """
    if table is None:
        table = expand
    # old_states[k] is the grid after k + 1 spins
    old_states = []
    for cycle in range(total_cycles):
        turn_up(table, False)
        if cycle == 0:
            print('part 1', sum(table.shape[0] - np.where(table == 'O')[0] - 1))
        turn_left(table, False)
        turn_up(table, True)
        turn_left(table, True)

        for match_index, old_state in enumerate(old_states):
            if np.all(old_state == table):
                # the state after cycle + 1 spins equals the one after match_index + 1 spins
                period = cycle - match_index
                # Spins still needed after that first occurrence, modulo the period.
                # The old index, match_index + (total - match_index) % period - 1,
                # pointed before the cycle whenever that remainder was 0.
                offset = (total_cycles - match_index - 1) % period
                return old_states[match_index + offset]
        old_states.append(table.copy())
    # No repeat seen: the table itself is the state after all spins.
    return table.copy()

end_state = start_cycling()
# North load: each rock weighs its row number counted from the south edge
# (1 for the last real row; the '#' padding row is excluded by the "- 1").
print('part 2', sum(end_state.shape[0] - np.where(end_state == 'O')[0] - 1))

part 1 110090
part 2 95254


# Day 15

In [16]:
line = read_input(day=15)[0]  # a single line of comma-separated steps

def HASH(word):
    """Return the puzzle's HASH of `word`: for each character, add its code, multiply by 17, keep mod 256."""
    return reduce(lambda value, char: (value + ord(char)) * 17 % 256, word, 0)

total = 0
# boxes[box_id] = ordered [label, focal_length_text] lenses currently in that box
boxes = defaultdict(list)
for word in line.split(','):
    # part 1: HASH of every step as written
    total += HASH(word)

    # "label=N" inserts or replaces a lens, "label-" removes it
    label, op, strenght = word.partition('=' if '=' in word else '-')
    box = HASH(label)
    if op == '-':
        boxes[box] = [lens for lens in boxes[box] if lens[0] != label]
    else:
        for lens in boxes[box]:
            if lens[0] == label:
                # replace in place so the lens keeps its slot
                lens[1] = strenght
                break
        else:
            boxes[box].append([label, strenght])

# focusing power: (box number + 1) * (slot + 1) * focal length
total_2 = sum(
    (1 + box_id) * (slot + 1) * int(lens[1])
    for box_id, lenses in boxes.items()
    for slot, lens in enumerate(lenses)
)
print('part 1', total)
print('part 2', total_2)

part 1 510273
part 2 212449


# Day 16

In [91]:
matrix = read_matrix(day=16, dtype=str)
rows, cols = matrix.shape
# Every beam that can enter from outside: (cell just beyond an edge, heading inwards).
all_beams = (
    {((y, -1), (0, 1)) for y in range(rows)}
    | {((y, cols), (0, -1)) for y in range(rows)}
    | {((-1, x), (1, 0)) for x in range(cols)}
    | {((rows, x), (-1, 0)) for x in range(cols)}
)

def answer(starting_beam):
    """Return how many grid cells a beam entering as `starting_beam` energises.

    A beam is ((y, x), (dy, dx)): the cell it occupies (possibly just outside
    the grid) and its heading.  All live beams advance one cell per round.
    """
    # (cell, heading on entry) pairs already visited; also stops endless loops
    found = set()
    beams = {starting_beam}
    while beams:
        beams = {
            next_beam
            for current_beam in beams
            for next_beam in calc_next(current_beam, found)
        }
    # a cell counts once, however many headings passed through it
    return len({pos for pos, _ in found})

def calc_next(beam, found):
    """Advance `beam` by one cell and return the resulting beam(s).

    Returns [] if the next cell is off the grid or was already entered with the
    same heading.  Otherwise the visit is recorded in `found` and the tile there
    is applied.
    """
    (y, x), (dy, dx) = beam
    next_pos = y + dy, x + dx
    pos_id = (next_pos, (dy, dx))
    if not (0 <= next_pos[0] < matrix.shape[0] and 0 <= next_pos[1] < matrix.shape[1]) or pos_id in found:
        return []
    found.add(pos_id)
    if matrix[next_pos] == '/':
        # '/' mirror: (dy, dx) -> (-dx, -dy)
        next_dirs = [(-dx, -dy)]
    elif matrix[next_pos] == '\\':
        # '\' mirror: (dy, dx) -> (dx, dy)
        next_dirs = [(dx, dy)]
    elif matrix[next_pos] == '-' and dy != 0:
        # '-' hit while moving vertically splits left and right
        next_dirs = [(0, -1), (0, 1)]
    elif matrix[next_pos] == '|' and dx != 0:
        # '|' hit while moving horizontally splits up and down
        next_dirs = [(-1, 0), (1, 0)]
    else:
        # '.', or a splitter met end-on: keep going straight
        next_dirs = [(dy, dx)]
    return [(next_pos, next_dir) for next_dir in next_dirs]

# Part 1: enter the top-left cell heading right.
print('part 1', answer(((0, -1), (0, 1))))
print('part 2', max(answer(beam) for beam in all_beams))

part 1 7798
part 2 8026


# Day 17

In [218]:
matrix = read_matrix(day=17)
# Surround the grid with -1 cells; the search never steps onto a -1 cell.
matrix = np.pad(matrix, (1, 1), constant_values=[-1])

# Search state.  order=True compares states field by field in declaration
# order, so the PriorityQueue pops the lowest `cost` first.
@dataclass(order=True)
class State:
    cost: int  # heat loss accumulated so far
    pos: Tuple[int, int]  # (row, col) in the padded grid
    last: Optional[Tuple[int, int]]  # previous position; None only for the start state
    straights: int  # cells moved so far in the current direction

# Top-left and bottom-right cells of the original grid, in padded coordinates.
start = (1, 1)
end = (matrix.shape[0] - 2, matrix.shape[1] - 2)

def calculate_min_run(min_steps, max_steps):
    """Return the least total heat loss from `start` to `end` on the padded `matrix`.

    The crucible moves one cell at a time and never reverses.  It moves at most
    `max_steps` cells in a straight line, and must move at least `min_steps`
    cells in a direction before turning, or before stopping at `end`.
    Cells holding -1 (the padding) are never entered.  Returns None if `end`
    is unreachable.
    """
    queue = PriorityQueue()
    queue.put(State(0, start, None, 0))
    # best_found[(pos, straight-run length, heading)] = cheapest state queued for it
    best_found = {}
    best_end = None

    while not queue.empty():
        current = queue.get()
        for dy, dx in DIRECTIONS_4:
            next_pos = current.pos[0] + dy, current.pos[1] + dx
            same_direction = (
                current.last and
                (current.pos[0] - current.last[0] == dy) and
                (current.pos[1] - current.last[1] == dx)
            )
            if (
                (current.last and current.straights < min_steps and not same_direction) or
                (matrix[next_pos] == -1 or next_pos == current.last or (current.straights == max_steps and same_direction))
            ):
                continue

            new_straights = current.straights + 1 if same_direction else 1
            new_cost = current.cost + matrix[next_pos]
            new_best_index = (next_pos, new_straights, (dy, dx))
            # costs only grow, so nothing as expensive as a known finish can help
            if best_end and new_cost >= best_end:
                continue
            if new_best_index not in best_found or best_found[new_best_index].cost > new_cost:
                new_state = State(
                    new_cost,
                    next_pos,
                    current.pos,
                    new_straights,
                )
                queue.put(new_state)
                best_found[new_best_index] = new_state
                # The crucible may only stop once its final run is at least
                # min_steps long; previously any arrival at `end` counted.
                if next_pos == end and new_straights >= min_steps:
                    best_end = new_cost if not best_end else min(new_cost, best_end)
    return best_end

print('part 1', calculate_min_run(0, 3))
print('part 2', calculate_min_run(4, 10))

part 1 1256
part 2 1382


# Day 18

In [141]:
lines = read_input(day=18)  # each line: "<direction letter> <steps> (#<hex colour>)"
directions = [(0, 1), (1, 0), (0, -1), (-1, 0)]
letter_to_direction = {letter: directions[index] for index, letter in enumerate(['R', 'D', 'L', 'U'])}

def calc_answer(hex_parsing):
    current = (0, 0)
    perimeter = 0
    ys, xs = [], []
    for line in lines:
        direction, steps, hex = line.split()
        if hex_parsing:
            steps = int(hex[2:-2], 16)
            dy, dx = directions[int(hex[-2])]
        else:
            steps = int(steps)
            dy, dx = letter_to_direction[direction]
            
        next_point = (current[0] + dy * steps, current[1] + dx * steps)
        ys.append(next_point[0])
        xs.append(next_point[1])
        perimeter += abs(next_point[0] - current[0]) + abs(next_point[1] - current[1])
        current = next_point

    return int(0.5*np.abs(np.dot(xs,np.roll(ys,1))-np.dot(ys,np.roll(xs,1)))) + perimeter // 2 + 1
print('part 1', calc_answer(hex_parsing=False))
# Part 2 takes the step count and direction from the hex colour field instead.
print('part 2', calc_answer(hex_parsing=True))

part 1 56923
part 2 66296566363189


# Day 19

In [306]:
lines = read_input(day=19)
# Workflows come before the first blank line, part ratings after it.  (The
# dead stores `parts = ...` and `workflows = {}` that followed were always
# overwritten before use and have been removed.)
line_break = lines.index('')

@dataclass
class Rule:
    """One workflow rule: an optional `part_stat <check_type> check_value` test and where it sends the part."""
    redirect: str
    check_type: Optional[str] = None
    part_stat: Optional[str] = None
    check_value: Optional[int] = None

def parse_rule(rule_line: str) -> Rule:
    """Parse 'x<123:target', 'x>123:target', or a bare fallback 'target' into a Rule."""
    # '<' takes precedence over '>' when both appear
    check_symbol = next((symbol for symbol in '<>' if symbol in rule_line), None)
    if check_symbol is None:
        return Rule(redirect=rule_line)
    condition, redirect = rule_line.split(':')
    part_stat, threshold = condition.split(check_symbol)
    return Rule(redirect=redirect, check_type=check_symbol, part_stat=part_stat, check_value=int(threshold))


workflows = {}
for workflow_line in lines[:line_break]:
    # "name{rule,rule,...}" -> name and its comma-separated rules
    workflow_name, rule_text = workflow_line[:-1].split('{')
    workflows[workflow_name] = [parse_rule(rule_line) for rule_line in rule_text.split(',')]

parts = []
for part_line in lines[line_break + 1:]:
    # "{x=787,m=2655,a=1222,s=2876}" -> {'x': 787, 'm': 2655, 'a': 1222, 's': 2876}
    stat_pairs = (stat_line.split('=') for stat_line in part_line[1:-1].split(','))
    parts.append({part_stat: int(part_value) for part_stat, part_value in stat_pairs})

def is_accepted(part):
    """Run `part` through the workflows starting at 'in'; return True if it ends at 'A'.

    Rules are tried in order and the first that applies decides: 'A' accepts,
    'R' rejects, and any other target names the next workflow.

    Raises ValueError if no rule of a workflow applies.  Previously that case
    looped forever.
    """
    current_workflow = 'in'
    while True:
        for rule in workflows[current_workflow]:
            if (
                rule.check_type is None or
                (rule.check_type == '<' and part[rule.part_stat] < rule.check_value) or
                (rule.check_type == '>' and part[rule.part_stat] > rule.check_value)
            ):
                if rule.redirect == 'A':
                    return True
                if rule.redirect == 'R':
                    return False
                current_workflow = rule.redirect
                break
        else:
            raise ValueError(f'no rule of workflow {current_workflow!r} applies to part {part!r}')

print('part 1', sum(
    sum(part.values())
    for part in parts
    if is_accepted(part)
))

def update_limit(old_limit, greater, new_value, inclusive):
    """Narrow the closed range `old_limit` = (low, high) by one comparison against `new_value`.

    greater=True raises the lower bound; greater=False lowers the upper bound.
    When `inclusive` is true the bound is moved one step past `new_value`, so
    `new_value` itself is excluded (a strict > or <).  Returns the narrowed
    (low, high), or None if the range becomes empty.
    """
    low, high = old_limit
    shift = int(inclusive)
    if greater:
        low = max(low, new_value + shift)
    else:
        high = min(high, new_value - shift)
    return (low, high) if low <= high else None
    

def find_distinctive(workflow_name, part_limits):
    """Count the rating combinations within `part_limits` that end up accepted from `workflow_name`.

    part_limits maps each stat to an inclusive (low, high) range.  Each rule
    splits the current ranges: the matching slice recurses into the rule's
    target, and the non-matching remainder falls through to the next rule.
    """
    # work on a copy: the fall-through narrowing below must not leak to the caller
    part_limits = part_limits.copy()
    if workflow_name == 'R':
        return 0
    if workflow_name == 'A':
        # every combination inside the ranges is accepted
        return reduce(mul, [(max_value - min_value + 1) for min_value, max_value in part_limits.values()])
    total = 0
    for rule in workflows[workflow_name]:
        if rule.check_type is None:
            # unconditional fallback takes whatever is left
            total += find_distinctive(rule.redirect, part_limits)
            break
        greater = rule.check_type == '>'
        # slice that satisfies the strict comparison -> follow the rule
        new_limit = update_limit(part_limits[rule.part_stat], greater, rule.check_value, inclusive=True)
        if new_limit:
            new_part_limits = part_limits.copy()
            new_part_limits[rule.part_stat] = new_limit
            total += find_distinctive(rule.redirect, new_part_limits)
        # complementary slice (including check_value itself) -> next rule
        new_limit = update_limit(part_limits[rule.part_stat], not greater, rule.check_value, inclusive=False)
        if new_limit is None:
            break
        part_limits[rule.part_stat] = new_limit
    return total
part_limits = {part_stat: (1, 4000) for part_stat in 'xmas'}
print('part 2', find_distinctive('in', part_limits))

part 1 418498
part 2 123331556462603


# Day 20

In [397]:
from tqdm import tqdm
# Small sample network; currently unused because hardcoded_input=None below.
test = '''broadcaster -> a
%a -> inv, con
&inv -> b
%b -> con
&con -> output'''
lines = read_input(day=20, hardcoded_input=None)

BROADCASTER = 'broadcaster'
FLIP_FLOP = '%'
CONJUNCTOR = '&'
# NOTE(review): input-specific module names, presumably the conjunctions that
# feed the module upstream of 'rx'; verify against the puzzle input.
target_conjunctors = {'jg', 'rh', 'jm', 'hf'}
target_conjunctors_cycle = []

# module name -> list of modules it sends pulses to
module_targets = {}
# module name -> FLIP_FLOP or CONJUNCTOR (the broadcaster has no entry)
module_types = {}
# flip-flop name -> on/off state, all starting off
flip_flop_states = {}
# conjunction name -> {input module: whether its last pulse was high}
conjunctor_inputs = defaultdict(dict)

for line in lines:
    module, targets = line.split(' -> ')
    if module != BROADCASTER:
        # the first character is the type prefix ('%' or '&')
        module_type = module[0]
        module = module[1:]
        module_types[module] = module_type
        if module_type == FLIP_FLOP:
            flip_flop_states[module] = False
    module_targets[module] = targets.split(', ')

# Register every input of each conjunction, remembering a low pulse initially.
for module, targets in module_targets.items():
    for target in targets:
        if target not in module_types:
            continue
        if module_types[target] == CONJUNCTOR:
            conjunctor_inputs[target][module] = False

low_pulses = 0
high_pulses = 0
# first button press (1-based) on which each watched conjunction sent a high pulse
first_high_press = {}

for pulse_index in range(5000):
    if pulse_index == 1000:
        print('part 1', low_pulses * high_pulses)
    # FIFO of (is_high, target, source); deque.popleft is O(1), whereas the
    # previous `first, *rest = pulses` unpacking copied the whole queue per pulse.
    pulses = deque((False, target, BROADCASTER) for target in module_targets[BROADCASTER])
    low_pulses += 1  # the button's own low pulse to the broadcaster
    while pulses:
        is_high, target, source = pulses.popleft()
        # Record only the FIRST firing per conjunction.  The old list could take
        # repeats from one conjunction and print an LCM of the wrong numbers.
        if is_high and source in target_conjunctors and source not in first_high_press:
            first_high_press[source] = pulse_index + 1
            if len(first_high_press) == len(target_conjunctors):
                # NOTE(review): assumes each conjunction fires with a period equal
                # to its first firing press.
                print('part 2', math.lcm(*first_high_press.values()))
        if is_high:
            high_pulses += 1
        else:
            low_pulses += 1
        if target not in module_types:
            # untyped modules (sinks) just absorb the pulse
            continue
        if module_types[target] == FLIP_FLOP and not is_high:
            flip_flop_states[target] = not flip_flop_states[target]
            pulses.extend((flip_flop_states[target], response_target, target) for response_target in module_targets[target])
        if module_types[target] == CONJUNCTOR:
            conjunctor_inputs[target][source] = is_high
            all_high = all(conjunctor_inputs[target].values())
            pulses.extend((not all_high, response_target, target) for response_target in module_targets[target])

part 1 821985143
part 2 240853834793347


# Day 21