In [None]:
import numpy as np
from scipy.spatial.distance import pdist, squareform

In [None]:
# read the data
test = False
filename = "test.txt" if test else "input.txt"
# explicit int64: dtype=int is int32 on Windows with NumPy < 2, and coordinates are
# multiplied together later on (part 2), which could overflow silently.
# ndmin=2 keeps one row per point even if the file has a single line.
points = np.loadtxt(filename, dtype=np.int64, delimiter=",", ndmin=2)
niter = 10 if test else 1000  # number of closest pairs to connect in part 1
# calculate distance matrix
dists = squareform(pdist(points))
# sentinel larger than any real distance: used on the diagonal and for used-up pairs
maxval = np.max(dists) * 10
np.fill_diagonal(dists, maxval)  # so we don't self connect


In [None]:
# version two combining connecting and merging
def solver(niter, dists=None, points=None, maxval=None, verbose=False):
    """Connect the closest pairs of points one at a time, tracking the resulting circuits.

    Each iteration picks the closest not-yet-used pair (i, j), records it, and merges
    it with every circuit containing i or j. Stops after ``niter`` iterations, as soon
    as every point belongs to a single circuit, or when no unused pair is left.

    Parameters
    ----------
    niter : int
        Maximum number of pairs to connect.
    dists : ndarray, optional
        Square, symmetric distance matrix whose diagonal is already set to ``maxval``.
        Defaults to the notebook-level ``dists``, looked up at call time.
    points : ndarray, optional
        Point coordinates, one row per point. Defaults to notebook-level ``points``.
    maxval : float, optional
        Sentinel larger than any real distance; used-up pairs are overwritten with it.
        Defaults to notebook-level ``maxval``.
    verbose : bool
        Print every connection and the circuit count (very noisy for large ``niter``).

    Returns
    -------
    circuits : list of set
        Disjoint sets of point indices; points never connected are not included.
    connections : list of tuple
        Every ``(i, j)`` pair connected, in order, with ``i < j``.
    last_two : tuple or None
        The pair whose connection joined all points into one circuit, else ``None``.
    """
    # Resolve defaults at call time (not definition time) so re-running the data cell
    # with a different input is picked up without re-defining this function.
    if dists is None:
        dists = globals()["dists"]
    if points is None:
        points = globals()["points"]
    if maxval is None:
        maxval = globals()["maxval"]

    n_points = points.shape[0]
    _dists = dists.copy()
    circuits = []
    connections = []
    for it in range(niter):
        # closest remaining pair; argmin returns the first minimum in row-major order,
        # so ties are fine and the upper-triangle entry (i < j) is the one found
        i, j = (int(k) for k in np.unravel_index(np.argmin(_dists), _dists.shape))
        if _dists[i, j] >= maxval:
            # only sentinel values remain: every pair has already been used
            break
        # now just for checking
        connections.append((i, j))
        # overwrite their distance with max (both halves of the symmetric matrix)
        _dists[i, j] = maxval
        _dists[j, i] = maxval
        if verbose:
            print(f"iter {it}: connect {i} to {j}, dist={dists[i,j]}, {points[i]} and {points[j]}")
        # the new edge joins every existing circuit containing i or j; circuits stay
        # disjoint, so one pass is enough
        merged = {i, j}
        untouched = []
        for circuit in circuits:
            if i in circuit or j in circuit:
                merged |= circuit
            else:
                untouched.append(circuit)
        circuits = untouched + [merged]
        if verbose:
            print(f" circuits count: {len(circuits)}")
        if len(merged) == n_points:
            print("All points connected!")
            return circuits, connections, (i, j)
    return circuits, connections, None


In [None]:
# part 1: connect the `niter` closest pairs, then multiply the sizes of the three largest circuits
circuits, connections, last_two = solver(niter)
circuit_sizes = list(map(len, circuits))
largest_three = sorted(circuit_sizes, reverse=True)[:3]
result = np.prod(largest_three)
print(f"Result Part 1: {result}")

In [None]:
# part 2: keep connecting closest pairs until every point is in a single circuit.
# A complete graph on n points has n*(n-1)/2 edges, so that many iterations always suffices.
n_pairs = points.shape[0] * (points.shape[0] - 1) // 2
circuits, connections, last_two = solver(n_pairs)
if last_two is None:
    raise RuntimeError("solver finished without connecting all points into one circuit")
point_a, point_b = points[last_two[0]], points[last_two[1]]
print(point_a, point_b)
# cast to Python int: NumPy int32 (the default int on Windows with NumPy < 2) can overflow
print(f"Result Part 2: {int(point_a[0]) * int(point_b[0])}")
