In [None]:
from tabulate import tabulate
from pydantic import BaseModel
import math as m

# Puzzle input files, relative to the kernel's working directory.
EXAMPLE, INPUT = "../example.txt", "../input.txt"

In [None]:
class Machine(BaseModel):
    """One claw machine: two buttons and a prize, each as an {"x", "y"} dict.

    ``get_machines`` fills ``A`` and ``B`` from the "Button A: X+.., Y+.." and
    "Button B: X+.., Y+.." lines, and ``prize`` from the "Prize: X=.., Y=.."
    line. ``reduce`` divides these values in place and may clear ``winnable``.
    """
    # NOTE(review): the dict defaults below are mutable; pydantic is expected
    # to copy them per instance -- confirm for the installed pydantic version.
    # Per-press coordinate increments for button A ("X+..", "Y+..").
    A: dict[str, int] = {"x": 0, "y": 0}
    # Per-press coordinate increments for button B ("X+..", "Y+..").
    B: dict[str, int] = {"x": 0, "y": 0}
    # Target coordinates of the prize ("X=..", "Y=..").
    prize: dict[str, int] = {"x": 0, "y": 0}
    # Set to False by reduce() when an axis equation has no integer solution.
    winnable: bool = True

In [None]:
def get_machines(input_file_name):
    """Parse a claw-machine description file into a list of ``Machine`` objects.

    The file is expected to contain groups of three lines, in this order::

        Button A: X+94, Y+34
        Button B: X+22, Y+67
        Prize: X=8400, Y=5400

    separated by blank lines. A new ``Machine`` is created on each "Button A"
    line; the following "Button B" and "Prize" lines fill in that machine.

    Args:
        input_file_name: Path of the text file to read.

    Returns:
        list[Machine]: The parsed machines, in file order.
    """

    def _parse_coords(text, separator):
        # " X+94, Y+34" with separator "+" (or "=" for the prize) -> {"x": 94, "y": 34}
        x_part, y_part = text.split(",")
        return {
            "x": int(x_part.strip().split(separator)[1]),
            "y": int(y_part.strip().split(separator)[1]),
        }

    machines = []
    machine_index = 0
    with open(input_file_name, 'r') as f:
        for raw_line in f:
            # strip() rather than comparing with "\n": this also tolerates CRLF
            # line endings, whitespace-only separator lines and stray indentation,
            # which would otherwise reach split(':') and raise ValueError.
            line = raw_line.strip()
            if not line:
                continue
            prefix, coords = line.split(':')
            if prefix == "Prize":
                machines[machine_index].prize = _parse_coords(coords, '=')
                # The prize line closes the current machine.
                machine_index += 1
            elif prefix.split()[1] == 'A':
                machines.append(Machine(A=_parse_coords(coords, '+')))
            else:
                machines[machine_index].B = _parse_coords(coords, '+')
    return machines

In [None]:
# Parse the worked example and show the parsed machines as a table.
machines = get_machines(input_file_name=EXAMPLE)
print(tabulate(tabular_data=machines))

In [None]:
def reduce(machine: Machine):
    """Divide each per-axis Diophantine equation of ``machine`` by its gcd, in place.

    For each axis c in ("x", "y") the machine defines
    ``A[c]*a + B[c]*b == prize[c]``. This has integer solutions only if
    gcd(A[c], B[c]) divides prize[c]. If it does not, ``machine.winnable`` is
    set to False and the function returns at once; an axis processed before
    that stays reduced. Otherwise all three values are divided by the gcd,
    leaving A[c] and B[c] coprime as ``solve_bezout`` requires.

    Args:
        machine: The machine to reduce; mutated in place.

    Returns:
        None.
    """
    for c in ["x", "y"]:
        cA = machine.A[c]
        cB = machine.B[c]
        cP = machine.prize[c]
        gcd = m.gcd(cA, cB)
        if gcd == 0:
            # Both button coefficients are 0, so the equation reads 0 == cP;
            # guard here instead of evaluating `cP % 0`.
            if cP != 0:
                machine.winnable = False
                return
            # 0 == 0 imposes no constraint; leave this axis unchanged.
            # NOTE(review): solve_bezout(0, 0) downstream still raises ValueError.
            continue
        if cP % gcd != 0:
            machine.winnable = False
            return
        # Divide to get coprime button coefficients.
        machine.A[c], machine.B[c], machine.prize[c] = cA // gcd, cB // gcd, cP // gcd

In [None]:
# Reduce the first example machine in place and inspect the result.
machine = machines[0]
reduce(machine=machine)
print(machine)

In [None]:
def solve_bezout(a, b):
    """Return Bézout coefficients ``(u, v)`` such that ``a*u + b*v == 1``.

    Uses the extended Euclidean algorithm, so ``a`` and ``b`` must be coprime.
    Negative inputs are accepted.

    Args:
        a: An integer.
        b: An integer.

    Returns:
        tuple[int, int]: ``(u, v)`` with ``a*u + b*v == 1``.

    Raises:
        ValueError: If ``a`` and ``b`` are not coprime (including ``a == b == 0``).
    """
    # Invariant: r == a*u + b*v and r1 == a*u1 + b*v1.
    r, u, v, r1, u1, v1 = a, 1, 0, b, 0, 1
    while r1 != 0:
        q = r // r1
        r, u, v, r1, u1, v1 = r1, u1, v1, r - q * r1, u - q * u1, v - q * v1
    # r is now +/- gcd(a, b): with negative inputs floor division can end on -1,
    # which still means the pair is coprime, so flip the signs.
    if r == -1:
        r, u, v = 1, -u, -v
    if r != 1:
        raise ValueError(f"{a} and {b} are not coprime")
    return (u, v)

In [None]:
# Sanity check on a coprime pair: the coefficients must satisfy Bézout's
# identity 120*u + 23*v == 1, so assert it instead of only eyeballing it.
bezout_u, bezout_v = solve_bezout(120, 23)
assert 120 * bezout_u + 23 * bezout_v == 1
print((bezout_u, bezout_v))

In [None]:
def find_bezout_solutions(machine: Machine):
    """Compute one particular integer solution of each per-axis equation.

    For axis c, ``solve_bezout`` yields (u, v) with ``A[c]*u + B[c]*v == 1``.
    Scaling both by ``prize[c]`` gives a solution of
    ``A[c]*a + B[c]*b == prize[c]``. The machine must already be reduced so
    that A[c] and B[c] are coprime; otherwise ``solve_bezout`` raises ValueError.

    Args:
        machine: A reduced machine; not modified.

    Returns:
        dict[str, tuple[int, int]]: ``{"x": (a0, b0), "y": (a0, b0)}``.
    """
    def particular(axis):
        a_coef = machine.A[axis]
        b_coef = machine.B[axis]
        target = machine.prize[axis]
        u, v = solve_bezout(a_coef, b_coef)
        return (target * u, target * v)

    return {axis: particular(axis) for axis in ("x", "y")}

In [None]:
# Particular solutions of both reduced equations for the example machine.
bezout_solutions = find_bezout_solutions(machine=machine)
print(bezout_solutions)

In [None]:
def add_offset(machines: list[Machine], offset: int = 10_000_000_000_000):
    """Shift every machine's prize by ``offset`` on both axes, in place.

    ``get_total_count(..., offset=True)`` (used by ``part_2``) calls this with
    the default shift of 10**13. The shift is a parameter so that other
    values can be tried.

    Args:
        machines: Machines whose ``prize`` dicts are mutated.
        offset: Amount added to ``prize["x"]`` and ``prize["y"]``.

    Returns:
        None.
    """
    for machine in machines:
        machine.prize["x"] += offset
        machine.prize["y"] += offset

In [None]:
def find_unique_solution(bezout_solutions, machine):
    """Combine the per-axis particular solutions into one common integer solution.

    With coprime coefficients, every solution of the y equation has the form
    ``a = uy + by*k, b = vy - ay*k``. Requiring it to satisfy the x equation
    as well fixes ``k = (ax*(ux-uy) + bx*(vx-vy)) / (ax*by - bx*ay)``. A
    common integer solution exists exactly when this division is exact.

    The arithmetic is integer-only (``divmod``). The part-2 prizes make these
    numerators exceed float precision (~2**53), and float division would then
    misjudge integrality or round the result.

    Args:
        bezout_solutions: ``{"x": (ux, vx), "y": (uy, vy)}``, particular
            solutions of each axis equation.
        machine: Object with ``A`` and ``B`` coordinate dicts (reduced).

    Returns:
        tuple[int, int] | None: ``(a, b)`` satisfying both equations, or None
        if no integer solution exists.

    Raises:
        ZeroDivisionError: If ``ax*by - bx*ay == 0`` (collinear buttons).
    """
    ux, vx = bezout_solutions["x"]
    uy, vy = bezout_solutions["y"]
    ax, bx = machine.A["x"], machine.B["x"]
    ay, by = machine.A["y"], machine.B["y"]
    ky, remainder = divmod(ax * (ux - uy) + bx * (vx - vy), ax * by - bx * ay)
    if remainder != 0:
        return None
    return (uy + by * ky, vy - ay * ky)

In [None]:
def find_solution(machine: Machine):
    """Find the button-press counts that win ``machine``, via Bézout coefficients.

    Reduces the machine in place, then combines particular solutions of the
    two per-axis equations into the unique common integer solution.

    Args:
        machine: The machine to solve; mutated by ``reduce``.

    Returns:
        tuple[int, int] | None: ``(a_presses, b_presses)``, or None if the
        machine is unwinnable, has no integer solution, or its only integer
        solution needs a negative number of presses.
    """
    reduce(machine)
    if not machine.winnable:
        return None
    bezout_solutions = find_bezout_solutions(machine)
    solution = find_unique_solution(bezout_solutions, machine)
    # A negative press count solves the equations but cannot be played.
    if solution is None or min(solution) < 0:
        return None
    return solution

In [None]:
def count_tokens(solution, a_cost=3, b_cost=1):
    """Return the token cost of an ``(a_presses, b_presses)`` solution.

    Args:
        solution: Sequence whose first two items are the A and B press counts.
        a_cost: Tokens per press of button A (3 by default).
        b_cost: Tokens per press of button B (1 by default).

    Returns:
        int: ``solution[0] * a_cost + solution[1] * b_cost``.
    """
    return solution[0] * a_cost + solution[1] * b_cost

In [None]:
def get_total_count(input_file_name, offset=False):
    """Sum the token cost over every machine in the file that can be won.

    Args:
        input_file_name: Path of the puzzle file.
        offset: If true, shift every prize with ``add_offset`` first (part 2).

    Returns:
        int: Total tokens over the machines for which ``find_solution``
        returns a solution.
    """
    machines = get_machines(input_file_name)
    if offset:
        add_offset(machines)
    return sum(
        count_tokens(solution)
        for solution in map(find_solution, machines)
        if solution
    )

In [None]:
def part_1(input_file_name):
    """Print the total token cost for the file, without any prize offset."""
    total = get_total_count(input_file_name, offset=False)
    print(total)

In [None]:
part_1(input_file_name=EXAMPLE)

In [None]:
part_1(input_file_name=INPUT)

In [None]:
def part_2(input_file_name):
    """Print the total token cost for the file with every prize offset applied."""
    total = get_total_count(input_file_name=input_file_name, offset=True)
    print(total)

In [None]:
part_2(input_file_name=EXAMPLE)

In [None]:
part_2(input_file_name=INPUT)

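Let's rewrite it without the gcd reduction and Bézout coefficients: solve each machine's 2×2 linear system directly with Cramer's rule, and keep the solution only if both press counts are non-negative integers.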

In [None]:
def find_solution(machine: Machine):
    """Find the button-press counts that win ``machine`` with Cramer's rule.

    Solves ``ka*A["x"] + kb*B["x"] == prize["x"]`` and
    ``ka*A["y"] + kb*B["y"] == prize["y"]`` for ``(ka, kb)``. All arithmetic
    is exact integer arithmetic (``divmod``), so the huge part-2 prizes cannot
    be misjudged by float rounding.

    Args:
        machine: The machine to solve; not modified.

    Returns:
        tuple[int, int] | None: ``(ka, kb)`` if the unique solution is made of
        non-negative integers, otherwise None.

    Raises:
        ZeroDivisionError: If the determinant ``ay*bx - ax*by`` is 0
        (collinear buttons, no unique solution).
    """
    ax, bx, px = machine.A["x"], machine.B["x"], machine.prize["x"]
    ay, by, py = machine.A["y"], machine.B["y"], machine.prize["y"]
    kb, kb_remainder = divmod(ay * px - ax * py, ay * bx - ax * by)
    if kb_remainder != 0:
        return None
    # Back-substitute through whichever axis has a non-zero A coefficient;
    # a non-zero determinant guarantees at least one of ax, ay is non-zero.
    if ax != 0:
        ka, ka_remainder = divmod(px - bx * kb, ax)
    else:
        ka, ka_remainder = divmod(py - by * kb, ay)
    # Negative press counts solve the equations but cannot be played.
    if ka_remainder != 0 or ka < 0 or kb < 0:
        return None
    return (ka, kb)

In [None]:
def get_total_count(input_file_name, offset=False):
    """Return the total token cost of all machines in the file that can be won.

    Args:
        input_file_name: Path of the puzzle file.
        offset: When true, ``add_offset`` shifts every prize before solving.

    Returns:
        int: Sum of ``count_tokens`` over the machines that ``find_solution``
        solves.
    """
    machines = get_machines(input_file_name)
    if offset:
        add_offset(machines)
    total = 0
    for solution in map(find_solution, machines):
        if not solution:
            continue
        total += count_tokens(solution)
    return total

In [None]:
def part_1(input_file_name):
    """Print the total token cost for the file, without any prize offset."""
    answer = get_total_count(input_file_name=input_file_name)
    print(answer)

for puzzle_file in (EXAMPLE, INPUT):
    part_1(puzzle_file)

def part_2(input_file_name):
    """Print the total token cost for the file with every prize offset applied."""
    answer = get_total_count(input_file_name, True)
    print(answer)
for puzzle_file in (EXAMPLE, INPUT):
    part_2(puzzle_file)