In [18]:
import helper

# Real puzzle input for day 8, obtained through the project-local `helper` module.
INPUTS = helper.get_input(day=8)
# Sample input used for the test runs of both parts: one "x,y,z" integer
# coordinate triple per line. The surrounding newlines are stripped.
TEST_INPUTS = """
162,817,812
57,618,57
906,360,560
592,479,940
352,342,300
466,668,158
542,29,236
431,825,988
739,650,466
52,470,668
216,146,977
819,987,18
117,168,530
805,96,715
346,949,466
970,615,88
941,993,340
862,61,35
984,92,344
425,690,689
""".strip()

print(TEST_INPUTS)

162,817,812
57,618,57
906,360,560
592,479,940
352,342,300
466,668,158
542,29,236
431,825,988
739,650,466
52,470,668
216,146,977
819,987,18
117,168,530
805,96,715
346,949,466
970,615,88
941,993,340
862,61,35
984,92,344
425,690,689


## Part 1

I'll admit my experience with DSA is a bit lacking:
I'm not the one who writes these algorithms,
I just learn how to recognize them and then apply them correctly.

In this case we need a [disjoint set](https://en.wikipedia.org/wiki/Disjoint-set_data_structure),
also known as a "union-find" data structure.
Points are arranged into individual sets inside the one `UnionFind` data structure
(installed from the [aoc-lube](https://github.com/salt-die/aoc_lube/blob/main/aoc_lube/utils.py) package, thanks salt-die!).
As we determine which two items go together
(after calculating their distances; `math.dist` does the Euclidean distance for us),
we call `.merge()` to combine the sets of the two items, forming a single set containing all of those items.

Once that's done, we find the sizes of the 3 largest circuits by calling
[`heapq.nlargest`](https://docs.python.org/3/library/heapq.html#heapq.nlargest),
and multiply them together with a simple call to `math.prod` to get our final result.

In [None]:
from dataclasses import dataclass
import math


@dataclass
class Point3D:
    """A junction-box position in 3D space with integer coordinates.

    Instances iterate as ``x, y, z`` and hash like the equivalent
    ``(x, y, z)`` tuple, so they can be used as set members / dict keys.
    """

    x: int
    y: int
    z: int

    @classmethod
    def from_input(cls, val: str):
        """Build a point from a comma-separated ``"x,y,z"`` string.

        Whitespace around the whole string is ignored. ``ValueError`` is
        raised when a field is not an integer or when the string does not
        contain exactly three fields.
        """
        coords = [int(field) for field in val.strip().split(",")]
        x, y, z = coords
        return cls(x=x, y=y, z=z)

    def __iter__(self):
        """Yield the coordinates in ``x``, ``y``, ``z`` order."""
        yield from (self.x, self.y, self.z)

    def __hash__(self):
        """Hash on the coordinate triple, matching the generated ``__eq__``."""
        # Defined explicitly: a non-frozen dataclass with eq=True would
        # otherwise be left unhashable.
        return hash((self.x, self.y, self.z))


def prep_inputs(inputs: str) -> list[Point3D]:
    """Parse the puzzle text into a list of points, one per line.

    Whitespace around the whole text is dropped before it is split into
    lines; each line is then handed to ``Point3D.from_input``.
    """
    rows = inputs.strip().splitlines()
    return [Point3D.from_input(row) for row in rows]


# Sanity check: distance between two points taken from the sample input.
# The Point3D instances are passed straight to math.dist, which reads their
# coordinates by iterating them (Point3D.__iter__).
p1 = Point3D.from_input("162,817,812")
p2 = Point3D.from_input("431,825,988")
math.dist(p1, p2)

321.560258738545

In [None]:
import itertools
from aoc_lube.utils import UnionFind
import heapq

# NOTE(review): bare expression with no effect — it is not the last statement
# of the cell, so nothing is displayed. Looks like a leftover; safe to delete.
math.dist


def part1(inputs: tuple[str, int]):
    """Connect the closest pairs of points and multiply the largest circuit sizes.

    Args:
        inputs: A ``(text, pairs)`` tuple: the puzzle text with one ``x,y,z``
            point per line, and how many of the closest pairs to connect.

    Returns:
        The product of the sizes of the (up to) three largest components
        after the ``pairs`` shortest connections have been merged.
    """
    raw_points, pairs = inputs
    circuits = UnionFind(prep_inputs(inputs=raw_points))

    # Sort the pairs by distance only. Sorting (distance, p1, p2) tuples would
    # fall back to comparing Point3D objects whenever two distances tie, and
    # Point3D defines no ordering, so that raises TypeError. With a key the
    # sort is stable: tied pairs keep the order itertools.combinations yields.
    closest_first = sorted(
        itertools.combinations(circuits.elements(), 2),
        key=lambda pair: math.dist(*pair),
    )

    for p1, p2 in closest_first[:pairs]:
        circuits.merge(p1, p2)

    # Only the component sizes matter for the answer.
    return math.prod(heapq.nlargest(3, map(len, circuits.components)))


# Each inputs tuple is (puzzle text, number of closest pairs to connect):
# 10 connections for the sample input, 1000 for the real input.
helper.run_part(
    func=part1,
    test_inputs=(TEST_INPUTS, 10),
    test_expected=40,
    real_inputs=(INPUTS, 1000),
)

--- TEST ---
>> EXPECTED: 40
>> RESULT:   40 ✅ 

--- REAL ---
>> RESULT:   50568


## Part 2

The previous problem had us connect the `n` closest pairs of junction boxes to form different circuits,
but this one has us continue the process until *every* set has been combined into one.
At that moment, the last 2 junction boxes to be connected are used instead to calculate our result,
which is simply `p1.x * p2.x`.

In [None]:
def part2(inputs: str) -> int:
    """Keep connecting the closest pairs until all points form one circuit.

    Args:
        inputs: Puzzle text with one ``x,y,z`` point per line.

    Returns:
        The product of the ``x`` coordinates of the pair whose connection
        leaves exactly one component.

    Raises:
        ValueError: If every pair has been merged without ever reaching a
            single component (e.g. there were no pairs to merge).
    """
    circuits = UnionFind(prep_inputs(inputs=inputs))

    # Sort the pairs by distance only. Sorting (distance, p1, p2) tuples would
    # fall back to comparing Point3D objects whenever two distances tie, and
    # Point3D defines no ordering, so that raises TypeError. With a key the
    # sort is stable: tied pairs keep the order itertools.combinations yields.
    closest_first = sorted(
        itertools.combinations(circuits.elements(), 2),
        key=lambda pair: math.dist(*pair),
    )

    for p1, p2 in closest_first:
        circuits.merge(p1, p2)
        # The first merge that leaves a single component is the answer pair.
        if len(circuits.components) == 1:
            return p1.x * p2.x
    raise ValueError("Something went wrong and we ran out of points to combine!")


# part2 takes only the puzzle text: there is no connection limit, it merges
# until a single circuit remains.
helper.run_part(
    func=part2,
    test_inputs=TEST_INPUTS,
    test_expected=25272,
    real_inputs=INPUTS,
)

--- TEST ---
>> EXPECTED: 25272
>> RESULT:   25272 ✅ 

--- REAL ---
>> RESULT:   36045012
