In [None]:
import copy
import re

from pydantic import BaseModel
from tabulate import tabulate

# Puzzle input files; these relative paths resolve against the kernel's current
# working directory.
EXAMPLE = "../example.txt"
INPUT = "../input.txt"

In [None]:
class Machine(BaseModel):
    """A machine with two buttons (``A`` and ``B``) and a ``prize`` target.

    For each axis ``c`` in ``"x"`` / ``"y"`` the solver looks for press counts
    ``(na, nb)`` with ``A[c] * na + B[c] * nb == prize[c]``. ``reduce`` divides
    these per-axis values by their gcd in place, and sets ``winnable`` to False
    when some axis has no integer solution.
    """
    A: dict[str, int] = {"x": 0, "y": 0}  # per-press step of button A on each axis
    B: dict[str, int] = {"x": 0, "y": 0}  # per-press step of button B on each axis
    prize: dict[str, int] = {"x": 0, "y": 0}  # target value on each axis
    # NOTE(review): these mutable dict defaults are mutated in place by reduce();
    # this relies on pydantic copying defaults per instance — confirm for the
    # pydantic version in use.
    winnable: bool = True  # set to False by reduce() when the prize is unreachable

In [None]:
def get_machines(input_file_name):
    """Parse a puzzle file into a list of ``Machine`` objects.

    The file holds groups of three lines separated by blank lines::

        Button A: X+94, Y+34
        Button B: X+22, Y+67
        Prize: X=8400, Y=5400

    A ``Button A`` line starts a new machine. The following ``Button B`` and
    ``Prize`` lines fill in that (most recent) machine.

    Args:
        input_file_name: path of the text file to read.

    Returns:
        The machines in file order.

    Raises:
        ValueError: on a line with an unknown label, or one that does not hold
            exactly two integers.
        IndexError: if a ``Button B`` or ``Prize`` line appears before any
            ``Button A`` line.
    """
    machines = []
    with open(input_file_name, 'r') as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                # Blank separator between machines; strip() also tolerates
                # stray whitespace that previously broke the ':' split.
                continue
            label, _, values = line.partition(':')
            # Buttons use "X+94" and prizes use "X=8400"; extracting the two
            # integers by pattern works for both separators.
            x, y = (int(number) for number in re.findall(r"-?\d+", values))
            coords = {"x": x, "y": y}
            if label == "Button A":
                machines.append(Machine(A=coords))
            elif label == "Button B":
                # Attach to the latest machine rather than tracking an index,
                # so a group missing its Prize line cannot shift later groups.
                machines[-1].B = coords
            elif label == "Prize":
                machines[-1].prize = coords
            else:
                raise ValueError(f"unrecognised line in {input_file_name}: {raw_line!r}")
    return machines

In [None]:
# Parse the example file and print the parsed machines as a table.
machines = get_machines(EXAMPLE)
print(tabulate(machines))

In [None]:
import math as m

def reduce(machine: Machine):
    """Divide each axis of ``machine`` by the gcd of its button steps, in place.

    For each axis, ``A[c]``, ``B[c]`` and ``prize[c]`` are divided by
    ``gcd(A[c], B[c])``, which leaves the set of integer solutions unchanged.
    If ``prize[c]`` is not a multiple of that gcd, the axis has no integer
    solution. In that case ``machine.winnable`` is set to False and processing
    stops; axes handled earlier stay divided.

    Raises:
        ZeroDivisionError: if ``A[c] == B[c] == 0`` on an axis (gcd is 0).
    """
    for axis in ("x", "y"):
        step_a, step_b, target = machine.A[axis], machine.B[axis], machine.prize[axis]
        divisor = m.gcd(step_a, step_b)
        if target % divisor:
            machine.winnable = False
            return
        machine.A[axis] = step_a // divisor
        machine.B[axis] = step_b // divisor
        machine.prize[axis] = target // divisor

In [None]:
# Walk through the solver steps on the first example machine. Reduce a deep
# copy so that reduce()'s in-place division does not alter the `machines`
# list printed in the cell above.
machine = copy.deepcopy(machines[0])
reduce(machine)
print(machine)

In [None]:
def solve_bezout(a, b):
    """Return Bezout coefficients ``(u, v)`` with ``a * u + b * v == 1``.

    Uses the iterative extended Euclidean algorithm.

    Args:
        a, b: coprime integers (``math.gcd(a, b) == 1``); may be negative.

    Returns:
        A tuple ``(u, v)`` of integers satisfying ``a * u + b * v == 1``.

    Raises:
        ValueError: if ``a`` and ``b`` are not coprime.
    """
    g = m.gcd(a, b)
    if g != 1:
        raise ValueError(f"solve_bezout needs coprime inputs, got gcd({a}, {b}) = {g}")
    # Invariant: r == a*u + b*v and r1 == a*u1 + b*v1.
    r, u, v = a, 1, 0
    r1, u1, v1 = b, 0, 1
    while r1 != 0:
        q = r // r1
        r, r1 = r1, r - q * r1
        u, u1 = u1, u - q * u1
        v, v1 = v1, v - q * v1
    # The loop stops with |r| == gcd(a, b) == 1. With negative operands, floor
    # division can leave r == -1. Negating both coefficients then restores
    # a*u + b*v == 1, where the old code wrongly raised.
    if r < 0:
        u, v = -u, -v
    return (u, v)

In [None]:
# Sanity-check the extended-Euclid helper on a coprime pair.
print(solve_bezout(120, 23))

In [None]:
def find_bezout_solutions(machine: Machine):
    """Return one integer solution per axis of ``A[c]*na + B[c]*nb == prize[c]``.

    The result maps ``"x"`` / ``"y"`` to a pair ``(na, nb)``. Each pair is the
    Bezout coefficients of ``(A[c], B[c])`` scaled by ``prize[c]``. The counts
    may be negative. The button steps must be coprime on each axis (e.g. after
    ``reduce``); otherwise ``solve_bezout`` raises ``ValueError``.
    """
    solutions = {}
    for axis in ("x", "y"):
        step_a, step_b, target = machine.A[axis], machine.B[axis], machine.prize[axis]
        coeff_a, coeff_b = solve_bezout(step_a, step_b)
        solutions[axis] = (target * coeff_a, target * coeff_b)
    return solutions

In [None]:
# One particular (A presses, B presses) solution per axis for the reduced example
# machine. The values can be negative; find_range shifts them into the valid range.
bezout_solutions = find_bezout_solutions(machine)
print(bezout_solutions)

In [None]:
def _k_bounds(start, step, upper):
    """Inclusive ``(k_min, k_max)`` with ``0 <= start + k * step <= upper``.

    Returns ``(None, None)`` when ``step == 0`` and every k qualifies, and
    ``None`` when no k qualifies. ``k_min > k_max`` also means "no k".
    """
    if step == 0:
        return (None, None) if 0 <= start <= upper else None
    low, high = -start, upper - start  # constraint: low <= k * step <= high
    if step < 0:
        step, low, high = -step, -high, -low
    # ceil(low / step) and floor(high / step) in exact integer arithmetic
    return -(-low // step), high // step


def find_range(u, v, a, b, max_presses=100):
    """Return every press pair on one axis with both counts in ``[0, max_presses]``.

    ``(u, v)`` is one integer solution of ``a * na + b * nb == P``. When
    ``gcd(a, b) == 1``, every solution is ``(u + k*b, v - k*a)`` for integer k.
    The admissible k values come straight from the two bound constraints.
    The old directional search only walked toward negative k whenever
    ``u >= 0``, so it missed solutions. ``u == 0`` is what ``solve_bezout``
    yields when ``b == 1``.

    Args:
        u, v: particular solution (press counts for A and B); may be negative.
        a, b: per-press steps of buttons A and B on this axis.
        max_presses: inclusive upper bound on each press count (default 100).

    Returns:
        A set of ``(presses_a, presses_b)`` tuples; empty if none qualify.
    """
    press_a_bounds = _k_bounds(u, b, max_presses)   # 0 <= u + k*b <= max_presses
    press_b_bounds = _k_bounds(v, -a, max_presses)  # 0 <= v - k*a <= max_presses
    if press_a_bounds is None or press_b_bounds is None:
        return set()
    lows = [k for k in (press_a_bounds[0], press_b_bounds[0]) if k is not None]
    highs = [k for k in (press_a_bounds[1], press_b_bounds[1]) if k is not None]
    if not lows:
        # a == b == 0: every k gives the same, already in-range point.
        return {(u, v)}
    k_min, k_max = max(lows), min(highs)
    return {(u + k * b, v - k * a) for k in range(k_min, k_max + 1)}

In [None]:
# Feasible (A presses, B presses) pairs for each axis separately; a winning
# strategy has to appear in both sets.
range_x, range_y = (
    find_range(bezout_solutions[axis][0], bezout_solutions[axis][1],
               machine.A[axis], machine.B[axis])
    for axis in ("x", "y")
)
print(range_x)
print(range_y)
print(range_x & range_y)

In [None]:
def find_strategies(machine: Machine):
    """Return every ``(presses_a, presses_b)`` pair that hits the prize on both axes.

    ``machine`` is reduced in place first. An empty set is returned when
    ``reduce`` marks it unwinnable. Press counts are limited to what
    ``find_range`` allows.
    """
    reduce(machine)
    if machine.winnable:
        particular = find_bezout_solutions(machine)
        feasible_x, feasible_y = (
            find_range(particular[axis][0], particular[axis][1],
                       machine.A[axis], machine.B[axis])
            for axis in ("x", "y")
        )
        return feasible_x & feasible_y
    return set()
    

In [None]:
# All press combinations that reach the example machine's prize on both axes.
print(find_strategies(machine))

In [None]:
def count_tokens(strategy, cost_a=3, cost_b=1):
    """Return the token cost of a press strategy.

    Args:
        strategy: ``(presses_a, presses_b)``, the number of presses of buttons
            A and B.
        cost_a: tokens per press of button A (default 3).
        cost_b: tokens per press of button B (default 1).

    Returns:
        ``presses_a * cost_a + presses_b * cost_b``.
    """
    return strategy[0] * cost_a + strategy[1] * cost_b

In [None]:
def find_min(machine: Machine):
    """Return the fewest tokens that win ``machine``, or 0 if it cannot be won.

    Goes through ``find_strategies``, which reduces ``machine`` in place.
    """
    strategies = find_strategies(machine)
    if machine.winnable:
        return min((count_tokens(strategy) for strategy in strategies), default=0)
    return 0

In [None]:
# Cheapest token cost for the example machine (0 when it cannot be won).
print(find_min(machine))

In [None]:
# Cheapest token cost for every example machine. Note: the loop rebinds the
# global `machine` used by the cells above, and find_min() reduces each machine
# in place.
for machine in machines:
    print(find_min(machine))

In [None]:
def part_1(input_file_name):
    """Print the total of the fewest tokens needed for each machine in the file.

    Machines that cannot be won contribute 0. Prints the total and returns None.
    """
    total = sum(find_min(machine) for machine in get_machines(input_file_name))
    print(total)

In [None]:
# Part 1 total for the example file.
part_1(EXAMPLE)

In [None]:
# Part 1 total for the full input file.
part_1(INPUT)