In [None]:
from __future__ import annotations

import heapq
import math
from collections import Counter
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable

# Directory from which the puzzle files (test.txt, input.txt) are read: the kernel's
# current working directory at the time this cell runs.
HERE: Path = Path.cwd()
HERE

## Approach

We implement a Disjoint Set Union (Union-Find) to maintain circuits (connected components).

We generate all pairs $(i, j)$ with $i < j$ and compute squared distance:
$$d^2 = (x_i-x_j)^2 + (y_i-y_j)^2 + (z_i-z_j)^2$$
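We never need square roots: since $d \ge 0$ and squaring is strictly increasing on non-negative numbers, ordering by $d$ equals ordering by $d^2$. Keeping $d^2$ as an exact integer also avoids floating-point rounding.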

To keep memory low, we pack each edge into a single integer key, so that plain integer sorting orders edges by `(dist2, i, j)` automatically:
- reserve 10 bits each for `i` and `j` (indices 0–1023, i.e. up to 1024 points),
- store `key = (dist2 << 20) | (i << 10) | j`.

Then:
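- **Part 1:** iterate over the first 1000 sorted keys and union their endpoints. An edge whose endpoints are already in the same circuit still counts toward the 1000; it just changes nothing. Then compute the component sizes and multiply the three largest.
- **Part 2:** iterate over all sorted keys, unioning endpoints. The union that first reduces the number of components to one is the final connection, and the answer is the product of its two endpoints' X coordinates.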

In [None]:
def parse_points(text: str) -> list[tuple[int, int, int]]:
    """Parse one ``x,y,z`` integer triple per line.

    Leading/trailing whitespace on each line is ignored and blank lines are skipped.
    A non-blank line that does not split into exactly three comma-separated integers
    raises ``ValueError``.
    """
    result: list[tuple[int, int, int]] = []
    for row in (raw.strip() for raw in text.splitlines()):
        if not row:
            continue
        xs, ys, zs = row.split(",")
        result.append((int(xs), int(ys), int(zs)))
    return result


class DSU:
    """Disjoint Set Union (union-find) over the integers ``0..n-1``.

    Attributes:
        parent: parent pointer per element; an element is a root iff it is its own parent.
        size: number of elements in the tree, only meaningful at root indices.
        components: current number of disjoint sets.
    """

    def __init__(self, n: int):
        self.parent = [v for v in range(n)]
        self.size = [1 for _ in range(n)]
        self.components = n

    def find(self, a: int) -> int:
        """Return the root of ``a``'s set, halving the path along the way.

        Every node visited is re-pointed at its grandparent, which keeps trees shallow.
        """
        links = self.parent
        node = a
        while True:
            up = links[node]
            if up == node:
                return node
            grand = links[up]
            links[node] = grand
            node = grand

    def union(self, a: int, b: int) -> bool:
        """Merge the sets containing ``a`` and ``b``.

        Returns:
            True if two distinct sets were merged, False if ``a`` and ``b`` were
            already in the same set (nothing changes in that case).
        """
        root_a, root_b = self.find(a), self.find(b)
        if root_a == root_b:
            return False
        # Union by size: hang the smaller tree under the larger; on a tie root_a stays on top.
        if self.size[root_a] < self.size[root_b]:
            keep, absorb = root_b, root_a
        else:
            keep, absorb = root_a, root_b
        self.parent[absorb] = keep
        self.size[keep] += self.size[absorb]
        self.components -= 1
        return True


def build_sorted_edge_keys(points: list[tuple[int, int, int]]) -> list[int]:
    """Return every pair ``i < j`` as a packed integer key, sorted ascending.

    Each key is ``(dist2 << 20) | (i << 10) | j`` where ``dist2`` is the squared
    Euclidean distance between ``points[i]`` and ``points[j]``. Because ``i`` and ``j``
    each fit in 10 bits, ordinary integer order equals ``(dist2, i, j)`` order.

    Raises:
        ValueError: if there are more than 1024 points (indices would overflow 10 bits).
    """
    n = len(points)
    if n > 1024:
        raise ValueError(f"Packing expects n<=1024, got n={n}")

    packed: list[int] = []
    # The last point has no partner with a larger index, so it is skipped as `i`.
    for i, (xi, yi, zi) in enumerate(points[: n - 1]):
        for j in range(i + 1, n):
            xj, yj, zj = points[j]
            dist2 = (xi - xj) ** 2 + (yi - yj) ** 2 + (zi - zj) ** 2
            packed.append((dist2 << 20) | (i << 10) | j)
    return sorted(packed)


def unpack_edge_key(key: int) -> tuple[int, int, int]:
    """Split a packed edge key back into ``(dist2, i, j)``.

    Inverse of the ``(dist2 << 20) | (i << 10) | j`` packing: the top bits hold the
    squared distance, and the two low 10-bit fields hold the point indices.
    """
    index_mask = (1 << 10) - 1
    return key >> 20, (key >> 10) & index_mask, key & index_mask


def part1(points: list[tuple[int, int, int]], keys: list[int], k: int = 1000) -> int:
    """Product of the three largest circuit sizes after the ``k`` shortest connections.

    Every one of the first ``k`` sorted edge keys counts as a connection attempt,
    even when its endpoints already share a circuit (the union is then a no-op).

    Args:
        points: point coordinates; only ``len(points)`` is used here.
        keys: packed edge keys from ``build_sorted_edge_keys``, in ascending order.
        k: number of shortest edges to process; values above ``len(keys)`` use every edge.

    Returns:
        The product of the three largest component sizes. When fewer than three
        components exist, the missing sizes count as 1.

    Raises:
        ValueError: if ``k`` is negative.
    """
    if k < 0:
        # keys[:k] with a negative k would silently process all but the last |k| edges.
        raise ValueError(f"k must be non-negative, got k={k}")

    n = len(points)
    dsu = DSU(n)

    # Slicing already clamps to len(keys), so no explicit min() is needed.
    for key in keys[:k]:
        _, i, j = unpack_edge_key(key)
        dsu.union(i, j)

    # Component size = number of vertices sharing the same root.
    sizes = Counter(dsu.find(v) for v in range(n))
    top3 = heapq.nlargest(3, sizes.values())
    # Pad with 1s (the multiplicative identity) if fewer than three circuits exist.
    top3 += [1] * (3 - len(top3))
    return math.prod(top3)


def part2(points: list[tuple[int, int, int]], keys: list[int]) -> int:
    """X-coordinate product of the connection that first joins everything into one circuit.

    Edges are processed in ``keys`` order. The first union that actually merges two sets
    and leaves a single component determines the answer: ``x_i * x_j`` of its endpoints.

    Raises:
        RuntimeError: if no edge ever reduces the component count to one.
    """
    dsu = DSU(len(points))

    for key in keys:
        _, a, b = unpack_edge_key(key)
        merged = dsu.union(a, b)
        if not merged or dsu.components != 1:
            continue
        return points[a][0] * points[b][0]

    raise RuntimeError("Never connected all points (unexpected for complete graph).")

## Validate against the example (`test.txt`)

From the problem statement:
- after the **10** shortest attempted connections, the three largest circuit sizes multiply to **40**
- the final connection that makes everything one circuit has X coordinates 216 and 117, product **25272**

In [None]:
# Validate both parts against the worked example from the puzzle statement.
# Each assertion reports the actual value on failure instead of a bare AssertionError.
test_points = parse_points((HERE / "test.txt").read_text())
test_keys = build_sorted_edge_keys(test_points)

test_ans1 = part1(test_points, test_keys, k=10)
assert test_ans1 == 40, f"part1 on example: got {test_ans1}, expected 40"

test_ans2 = part2(test_points, test_keys)
assert test_ans2 == 25272, f"part2 on example: got {test_ans2}, expected 25272"

print("test ok")

## Solve the real input (`input.txt`)

In [None]:
# Solve the real puzzle input; the sorted edge list is built once and shared by both parts.
points = parse_points((HERE / "input.txt").read_text())
keys = build_sorted_edge_keys(points)

ans1, ans2 = part1(points, keys, k=1000), part2(points, keys)

(ans1, ans2)