In [28]:
from math import prod
from pprint import pprint
import re
from functools import reduce
import itertools
from collections import Counter
import numpy as np
from dataclasses import dataclass


# Example puzzle text in the format parse_input expects: one workflow per
# line ("name{rule,rule,...}"), a blank line, then one part per line
# ("{x=..,m=..,a=..,s=..}").
sample = """px{a<2006:qkq,m>2090:A,rfg}
pv{a>1716:R,A}
lnx{m>1548:A,A}
rfg{s<537:gd,x>2440:R,A}
qs{s>3448:A,lnx}
qkq{x<1416:A,crn}
crn{x>2662:A,R}
in{s<1351:px,qqz}
qqz{s>2770:qs,m<1801:hdj,R}
gd{a>3333:R,R}
hdj{m>838:A,pv}

{x=787,m=2655,a=1222,s=2876}
{x=1679,m=44,a=2067,s=496}
{x=2036,m=264,a=79,s=2244}
{x=2461,m=1339,a=466,s=291}
{x=2127,m=1623,a=2188,s=1013}"""

@dataclass
class Part:
    # Plain record with four integer fields named x, m, a and s.
    # NOTE(review): nothing in the code visible here instantiates this class;
    # parts are handled as dicts. Confirm before relying on it.
    x: int
    m: int
    a: int
    s: int


def make_rule(key, operator, value, next_sorter):
    """Build a rule that tests one field of a single part.

    Args:
        key: Key of the field to look up in the part mapping.
        operator: Comparison symbol; ``'<'`` and ``'>'`` are what the rule
            parser produces, the other standard comparisons are accepted too.
        value: Threshold to compare against, as an int or a decimal string.
        next_sorter: Value returned by the rule when the comparison holds.

    Returns:
        A one-argument function ``rule(x)`` that returns ``next_sorter`` when
        ``x[key] <operator> value`` is true and ``None`` otherwise.

    Raises:
        ValueError: If ``operator`` is not a supported comparison symbol.
    """
    # A lookup table replaces the previous eval() of a formatted string:
    # no code is built from input text, and an unknown operator fails here
    # at construction time rather than on first use.
    comparisons = {
        '<': lambda left, right: left < right,
        '>': lambda left, right: left > right,
        '<=': lambda left, right: left <= right,
        '>=': lambda left, right: left >= right,
        '==': lambda left, right: left == right,
        '!=': lambda left, right: left != right,
    }
    if operator not in comparisons:
        raise ValueError(f'unsupported operator: {operator!r}')
    compare = comparisons[operator]
    threshold = int(value)  # convert once instead of on every call

    def rule(x):
        if compare(x[key], threshold):
            return next_sorter
        return None
    return rule

def make_range_rule(key, operator, value, next_sorter):
    """Build a rule that splits an inclusive range at a threshold.

    Args:
        key: Key of the field whose ``(low, high)`` range is split.
        operator: ``'<'`` or ``'>'``.
        value: Threshold, as an int or a decimal string.
        next_sorter: Destination for the piece of the range that satisfies
            ``field <operator> value``.

    Returns:
        A one-argument function ``rule(x)``. ``x`` maps field names to
        inclusive ``(low, high)`` tuples and is not modified. The function
        returns a list of ``(piece, target)`` pairs, lower piece first, where
        ``piece`` is a copy of ``x`` with ``x[key]`` narrowed and ``target``
        is ``next_sorter`` for the matching piece or ``None`` for the piece
        that does not match. Empty pieces are omitted.

    Raises:
        ValueError: If ``operator`` is neither ``'<'`` nor ``'>'``.
    """
    thresh = int(value)
    if operator == '>':
        # field > thresh: [low, thresh] fails, [thresh + 1, high] matches.
        low_end, high_start = thresh, thresh + 1
        low_target, high_target = None, next_sorter
    elif operator == '<':
        # field < thresh: [low, thresh - 1] matches, [thresh, high] fails.
        low_end, high_start = thresh - 1, thresh
        low_target, high_target = next_sorter, None
    else:
        raise ValueError(f'unsupported operator: {operator!r}')

    def rule(x):
        xmin, xmax = x[key]
        steps = []

        # Each guard is exactly "this piece is non-empty". The earlier strict
        # comparisons against thresh dropped the boundary value when the
        # range started ('>') or ended ('<') exactly on the threshold.
        if xmin <= low_end:
            lower = dict(x)
            lower[key] = (xmin, min(low_end, xmax))
            steps.append((lower, low_target))
        if xmax >= high_start:
            upper = dict(x)
            upper[key] = (max(xmin, high_start), xmax)
            steps.append((upper, high_target))
        return steps

    return rule

class Sorter():
    """A named, ordered list of rules parsed from a comma-separated string.

    Each rule is either ``<field><op><number>:<target>`` (field one of
    ``x m a s``, op ``<`` or ``>``) or a bare alphabetic target that applies
    unconditionally. Two parallel rule lists are built: ``rules`` works on a
    single part, ``range_rules`` on a mapping of inclusive ranges.
    """
    rules: list
    raw_rules: str
    name: str

    # Compiled once for the class instead of once per parsed rule.
    _RULE_PATTERN = re.compile(r'([xmas])([<>])(\d+):(\w+)')

    def __init__(self, name, rules):
        """Parse ``rules`` (the text between the braces) for sorter ``name``.

        Raises:
            ValueError: If a non-alphabetic rule does not match the
                ``<field><op><number>:<target>`` form.
        """
        self.rules = []
        self.range_rules = []
        self.name = name
        self.raw_rules = rules
        for r in rules.split(','):
            if r.isalpha():
                # Bind the target as a default argument. A closure over the
                # loop variable would see whatever `r` holds once the loop
                # has finished, i.e. always the last rule.
                self.rules.append(lambda x, target=r: target)
                self.range_rules.append(lambda x, target=r: [(x, target)])
            else:
                match = self._RULE_PATTERN.match(r)
                if match is None:
                    raise ValueError(f'cannot parse rule {r!r} in sorter {name!r}')
                key, operator, value, next_sorter = match.groups()
                self.rules.append(make_rule(key, operator, value, next_sorter))
                self.range_rules.append(make_range_rule(key, operator, value, next_sorter))

    def __repr__(self) -> str:
        return f'{self.name}: {self.raw_rules}'

    def evaluate(self, part):
        """Return the target of the first rule that applies to ``part``.

        Returns ``None`` when no rule yields a truthy target.
        """
        for r in self.rules:
            next_sorter = r(part)
            if next_sorter:
                return next_sorter
        return None

    def evaluate_range(self, part, sorters):
        """Yield every piece of ``part`` that ends up at target ``'A'``.

        Args:
            part: Mapping of field name to inclusive ``(low, high)`` range.
            sorters: Mapping of sorter name to ``Sorter``, used to follow
                targets other than ``'A'`` and ``'R'``.

        Yields:
            Range mappings, in rule order, whose target resolved to ``'A'``.
        """
        for r in self.range_rules:
            steps = r(part)
            if not steps:
                # The rule produced no pieces; offer the same range to the
                # next rule.
                continue
            remainder = None
            for pi, next_sorter in steps:
                if not next_sorter:
                    # Piece that did not match: it goes on to the next rule.
                    remainder = pi
                elif next_sorter not in ('A', 'R'):
                    yield from sorters[next_sorter].evaluate_range(pi, sorters)
                elif next_sorter == 'A':
                    yield pi
                # 'R': the piece is discarded.
            if remainder is None:
                # The rule consumed the whole range. Handing the unsplit
                # range to later rules would count it a second time.
                return
            part = remainder


def get_input(n):
    """Read ``input_<n>.txt`` from the working directory.

    Args:
        n: String inserted between ``input_`` and ``.txt`` in the file name.

    Returns:
        The file's text with leading and trailing whitespace removed.
    """
    filename = ''.join(('input_', n, '.txt'))
    with open(filename, 'r') as handle:
        contents = handle.read()
    return contents.strip()
# Read input_19.txt when this cell/module runs; the file must exist in the
# working directory or get_input raises.
puzzle = get_input('19')

def parse_input(puzzle):
    """Split puzzle text into sorters and parts.

    Args:
        puzzle: Text with one ``name{rules}`` line per sorter, a blank line,
            then one ``{key=int,...}`` line per part.

    Returns:
        A tuple ``(sDict, partsList)``: ``sDict`` maps sorter name to its
        ``Sorter``; ``partsList`` holds one ``dict`` of field name to ``int``
        per part line, in input order.

    Raises:
        ValueError: If the text does not contain exactly one blank-line
            separator, or a sorter line does not contain exactly one ``{``.
    """
    sorter_text, part_text = puzzle.split('\n\n')

    sDict = {}
    for line in sorter_text.split('\n'):
        sorter_name, rules = line.split('{')
        # rules[:-1] drops the closing '}' of the sorter definition.
        sDict[sorter_name] = Sorter(sorter_name, rules[:-1])

    # Pull the key=value pairs out with a regex instead of rewriting each
    # line into a dict(...) call and eval()-ing it: input text is never
    # executed as code.
    partsList = [
        {field: int(number) for field, number in re.findall(r'(\w+)=(-?\d+)', line)}
        for line in part_text.split('\n')
    ]
    return sDict, partsList

def score_part(sorters, part):
    """Route ``part`` through the sorters, starting at ``'in'``, and score it.

    Args:
        sorters: Mapping of sorter name to an object with ``evaluate(part)``.
        part: Mapping of field name to numeric value.

    Returns:
        The sum of ``part``'s values if routing ends at ``'A'``, else 0.
    """
    current = 'in'
    # Keep following targets until a terminal one is reached.
    while current not in 'AR':
        current = sorters[current].evaluate(part)
    return sum(part.values()) if current == 'A' else 0

def part_2(sorters):
    """Count the value combinations that reach ``'A'`` from sorter ``'in'``.

    Every field in ``xmas`` starts with the inclusive range ``(1, 4000)``.
    The accepted range mappings produced by ``evaluate_range`` are printed,
    one per line, after the total has been computed.

    Args:
        sorters: Mapping of sorter name to an object providing
            ``evaluate_range(part, sorters)``; must contain ``'in'``.

    Returns:
        The sum, over all accepted range mappings, of the product of the
        sizes of their ranges.
    """
    start = {field: (1, 4000) for field in 'xmas'}
    accepted_parts = list(sorters['in'].evaluate_range(start, sorters))

    # Ranges are inclusive, so each one holds high - low + 1 values.
    score = sum(
        prod(high - low + 1 for low, high in box.values())
        for box in accepted_parts
    )
    for p in accepted_parts:
        print(p)
    return score

def solve1(puzzle):
    """Parse ``puzzle`` and return the summed ``score_part`` of every part."""
    sorters, parts = parse_input(puzzle)
    return sum(score_part(sorters, candidate) for candidate in parts)

def solve2(puzzle):
    """Parse ``puzzle`` and return ``part_2`` of its sorters; parts are ignored."""
    parsed = parse_input(puzzle)
    workflows = parsed[0]
    return part_2(workflows)

# Run both solvers on the embedded sample text. solve2 also prints each
# accepted range mapping (via part_2) before returning its total.
solve1(sample)
solve2(sample)

{'x': (1, 1415), 'm': (1, 4000), 'a': (1, 2005), 's': (1, 1350)}
{'x': (2663, 4000), 'm': (1, 4000), 'a': (1, 2005), 's': (1, 1350)}
{'x': (1, 4000), 'm': (2091, 4000), 'a': (2006, 4000), 's': (1, 1350)}
{'x': (1, 2440), 'm': (1, 2090), 'a': (2006, 4000), 's': (537, 1350)}
{'x': (1, 4000), 'm': (1, 4000), 'a': (1, 4000), 's': (3449, 4000)}
{'x': (1, 4000), 'm': (1549, 4000), 'a': (1, 4000), 's': (2771, 3448)}
{'x': (1, 4000), 'm': (1, 1548), 'a': (1, 4000), 's': (2771, 3448)}
{'x': (1, 4000), 'm': (839, 1800), 'a': (1, 4000), 's': (1351, 2770)}
{'x': (1, 4000), 'm': (1, 838), 'a': (1, 1716), 's': (1351, 2770)}


167409079868000

In [None]:
# Run solve1 on the text loaded from input_19.txt.
solve1(puzzle)

383682

In [30]:
# Run solve2 on the text loaded from input_19.txt; accepted ranges are printed.
solve2(puzzle)

{'x': (1, 1006), 'm': (1, 1613), 'a': (1669, 2150), 's': (1, 695)}
{'x': (1, 1006), 'm': (1, 1613), 'a': (2151, 2504), 's': (1, 695)}
{'x': (1, 1914), 'm': (1, 1613), 'a': (757, 1668), 's': (1, 695)}
{'x': (1207, 1914), 'm': (1, 673), 'a': (1, 756), 's': (322, 695)}
{'x': (3595, 4000), 'm': (1, 2371), 'a': (1, 2504), 's': (1, 695)}
{'x': (3595, 4000), 'm': (2372, 4000), 'a': (1, 2504), 's': (1, 695)}
{'x': (1915, 2465), 'm': (1, 2496), 'a': (1648, 2504), 's': (1, 695)}
{'x': (2466, 2714), 'm': (1, 4000), 'a': (1648, 2195), 's': (414, 695)}
{'x': (2466, 2714), 'm': (1, 4000), 'a': (1648, 2195), 's': (1, 241)}
{'x': (2466, 2714), 'm': (1, 4000), 'a': (1648, 1898), 's': (242, 413)}
{'x': (1915, 2447), 'm': (1, 4000), 'a': (1, 1647), 's': (1, 247)}
{'x': (2448, 2905), 'm': (2223, 4000), 'a': (1, 1647), 's': (1, 247)}
{'x': (1915, 2106), 'm': (1, 4000), 'a': (1, 1647), 's': (248, 695)}
{'x': (1, 4000), 'm': (553, 920), 'a': (2505, 4000), 's': (305, 695)}
{'x': (2015, 4000), 'm': (1, 202), '

117954800808317