In [None]:
import math
from tqdm import tqdm

In [None]:
# Input selection: swap which line is commented out to run the worked example.
# filename = "sample.txt"
filename = "input.txt"
with open(filename, encoding="utf-8") as fh:
    data = fh.read()

# Each line holds one point written as comma-separated integer coordinates.
lines = data.strip().split("\n")
points = [tuple(int(coord) for coord in line.split(",")) for line in lines]

In [None]:
def sq_dist(a: tuple[int, ...], b: tuple[int, ...]) -> int:
    """Squared Euclidean distance between two integer points.

    The square root is deliberately skipped so the result stays an exact int;
    ordering pairs by squared distance is the same as ordering by distance.
    Coordinates are paired with ``zip``, so any extra trailing coordinates of
    the longer tuple are ignored.
    """
    total = 0
    for u, v in zip(a, b):
        delta = v - u
        total += delta * delta
    return total

In [None]:
## Part 1
# Every unordered pair of points as (squared distance, point1, point2), sorted
# closest-first. Equal distances fall back to comparing the point tuples.
dists = sorted(
    (sq_dist(p1, p2), p1, p2)
    for i, p1 in tqdm(enumerate(points), total=len(points))
    for p2 in points[i + 1 :]
)

In [None]:
# Circuits with two or more points; unconnected points are implicit singletons.
circuits: list[set[tuple[int, ...]]] = []


def find_circuit(p: tuple[int, ...]) -> set[tuple[int, ...]] | None:
    """Return the set in ``circuits`` that contains ``p``, or None.

    ``None`` means ``p`` has not been connected yet (it is an implicit
    singleton). This is a linear scan over ``circuits``.
    """
    for c in circuits:
        if p in c:
            return c
    return None


# Number of closest pairs to connect (the sample input uses 10).
n_connections = 1000

# Connect the n_connections closest pairs, merging circuits as we go.
for _d, p1, p2 in tqdm(dists[:n_connections]):
    s1 = find_circuit(p1)
    s2 = find_circuit(p2)
    if s1 is not None and s2 is not None:
        if s1 is s2:
            # Already in the same circuit; this connection changes nothing.
            continue
        # Fold the second circuit into the first and drop it from the list.
        s1.update(s2)
        circuits.remove(s2)
    elif s1 is not None:
        s1.add(p2)
    elif s2 is not None:
        s2.add(p1)
    else:
        circuits.append({p1, p2})

# Points never connected remain singleton circuits and are not stored above.

In [None]:
# Product of the sizes of the three largest circuits. Unconnected points are
# size-1 circuits that are not stored in `circuits`; multiplying by 1 is a
# no-op, so leaving them out does not change the product.
circuit_sizes = sorted(map(len, circuits), reverse=True)
top3 = circuit_sizes[:3]
math.prod(top3)

In [None]:
# Part 2
# Keep adding connections closest-first until every point belongs to a single
# circuit; the answer comes from the pair whose connection completes it.
# Circuits are tracked with a union-find (disjoint-set) forest keyed by point,
# so lookups don't scan a list of sets, and this cell does not redefine the
# Part 1 helpers.
parent: dict[tuple[int, ...], tuple[int, ...]] = {p: p for p in points}


def find_root(p: tuple[int, ...]) -> tuple[int, ...]:
    """Return the representative point of the circuit that contains ``p``.

    Uses path halving: each visited node is re-pointed at its grandparent,
    keeping the trees shallow across repeated calls.
    """
    while parent[p] != p:
        parent[p] = parent[parent[p]]
        p = parent[p]
    return p


# Each distinct point starts as its own singleton circuit.
n_circuits = len(parent)
# Reset so a re-run can never report a stale answer from a previous execution.
end_p1 = end_p2 = None
for _d, p1, p2 in tqdm(dists):
    root1, root2 = find_root(p1), find_root(p2)
    if root1 == root2:
        # They're already in the same circuit
        continue
    parent[root2] = root1
    n_circuits -= 1
    if n_circuits == 1:
        # This connection joined the last two circuits together.
        end_p1, end_p2 = p1, p2
        break

if end_p1 is None:
    print("Part 2: the points never joined into a single circuit")
else:
    print(f"{end_p1=} {end_p2=}")
    print(f"Part 2: {end_p1[0] * end_p2[0]=}")