In [1]:
from collections import deque

In [2]:
def parse_input(file_name):
    """Read comma-separated integer rows from a text file.

    Each non-blank line is split on "," and every field is converted with
    ``int``. Blank lines are skipped. This includes the empty "line" after a
    trailing newline, which would otherwise reach ``int("")`` and raise.

    Args:
        file_name: Path of the file to read.

    Returns:
        A list with one ``list[int]`` per non-blank line, in file order.

    Raises:
        ValueError: If a field is not a valid integer literal.
    """
    with open(file_name) as f:
        lines = f.read().splitlines()
    return [[int(field) for field in line.split(",")] for line in lines if line.strip()]
# NOTE: `input` shadows the builtin input(). The name is kept because the
# part 1 / part 2 cells below refer to it.
input, example = parse_input("input.txt"), parse_input("example.txt")

In [3]:
def dist(x, y):
    """Return the *squared* Euclidean distance between two 3-component points.

    The square root is deliberately omitted. The callers in this notebook only
    sort by this value, and squaring preserves the order of non-negative
    distances. With integer inputs the result also stays exact.

    Raises:
        ValueError: If either point does not have exactly three components.
    """
    (ax, ay, az), (bx, by, bz) = x, y
    dx, dy, dz = ax - bx, ay - by, az - bz
    return dx ** 2 + dy ** 2 + dz ** 2

In [4]:
class UnionFind:
    """Disjoint-set forest over the integers ``0 .. n-1``.

    Uses union by size plus full path compression in ``find``.

    Attributes:
        parent: ``parent[i]`` is i's parent; a root is its own parent.
        size: number of elements in the set rooted at i. Only meaningful while
            i is a root. The entry is not updated after i is attached under
            another root, so it goes stale.
        count: current number of disjoint sets.
    """

    def __init__(self, n):
        self.parent = list(range(n))
        self.size = [1] * n
        self.count = n

    def find(self, x):
        """Return the root of x's set, pointing every node on the path at it."""
        root = x
        while self.parent[root] != root:
            root = self.parent[root]
        # Second pass: re-parent each node on the path directly to the root.
        node = x
        while self.parent[node] != root:
            next_node = self.parent[node]
            self.parent[node] = root
            node = next_node
        return root

    def union(self, x, y):
        """Merge the sets of x and y.

        The smaller set is attached under the larger. On a size tie, x's root
        goes under y's root.

        Returns:
            True if two distinct sets were merged, False if already joined.
        """
        px, py = self.find(x), self.find(y)
        if px == py:
            return False
        big, small = (px, py) if self.size[px] > self.size[py] else (py, px)
        self.parent[small] = big
        self.size[big] += self.size[small]
        self.count -= 1
        return True

In [5]:
def part_1(input, max_circuits, largest=3):
    """Join the closest pairs of points and multiply the biggest circuit sizes.

    All unordered pairs are ranked by ``dist`` (squared distance). Ties keep
    generation order because ``sorted`` is stable. The first ``max_circuits``
    pairs are then passed to ``UnionFind.union``. A pair already in the same
    circuit still uses up one of the ``max_circuits`` slots but changes nothing.

    Args:
        input: Sequence of points accepted by ``dist``.
        max_circuits: Number of closest pairs to process. It is clamped to the
            available pairs, and a value <= 0 processes none.
        largest: How many of the biggest circuits to multiply (default 3).

    Returns:
        Product of the ``largest`` biggest circuit sizes. If there are fewer
        circuits, fewer factors are used, and the result is 1 if there are none.
    """
    import math

    n = len(input)
    pairs = sorted(
        ((i, j, dist(input[i], input[j])) for i in range(n) for j in range(i + 1, n)),
        key=lambda pair: pair[2],
    )

    uf = UnionFind(n)
    # Clamp at 0: a negative slice bound would drop pairs from the end instead.
    for i, j, _ in pairs[:max(0, max_circuits)]:
        uf.union(i, j)

    # uf.size is only valid at roots. A root that was merged under another keeps
    # its old, stale size, which could otherwise sneak into the top sizes.
    # Build a fresh list rather than sorting uf.size in place.
    circuit_sizes = sorted(
        (uf.size[box] for box in range(n) if uf.find(box) == box), reverse=True
    )
    return math.prod(circuit_sizes[:largest])

In [6]:
# Sanity-check against the worked example before printing the real answer.
assert part_1(example, 10) == 40
print(part_1(input, 1000))

79560


In [7]:
def part_2(input):
    """Join closest pairs until one circuit remains; score the completing pair.

    Pairs are processed in increasing ``dist`` (squared distance), in the
    style of Kruskal's algorithm. The pair whose union first leaves a single
    circuit determines the answer.

    Args:
        input: Sequence of points accepted by ``dist``. At least two are needed.

    Returns:
        ``input[i][0] * input[j][0]`` for the pair (i, j) whose union reduced
        the circuit count to 1.

    Raises:
        ValueError: If fewer than two points are given. In that case no union
            ever happens, so there is no completing pair.
    """
    n = len(input)
    if n < 2:
        raise ValueError(f"part_2 needs at least two points, got {n}")

    pairs = sorted(
        ((i, j, dist(input[i], input[j])) for i in range(n) for j in range(i + 1, n)),
        key=lambda pair: pair[2],
    )

    uf = UnionFind(n)
    for i, j, _ in pairs:
        # Only a successful union lowers the count, so the pair that brings it
        # to 1 is exactly the last connection needed.
        if uf.union(i, j) and uf.count == 1:
            return input[i][0] * input[j][0]
    # Unreachable for n >= 2: the full set of pairs always connects every point.
    raise RuntimeError("pairs exhausted before all points were connected")

In [8]:
# Sanity-check against the worked example before printing the real answer.
assert part_2(example) == 25272
print(part_2(input))

31182420
