In [None]:
# Puzzle input: one comma-separated "x,y,z" position per line.
with open("input.txt", "r") as file:
    data = file.read().splitlines()

# Small example set; set `example = True` to solve it instead of input.txt.
example_data = [
    "162, 817, 812",
    "57, 618, 57",
    "906, 360, 560",
    "592, 479, 940",
    "352, 342, 300",
    "466, 668, 158",
    "542, 29, 236",
    "431, 825, 988",
    "739, 650, 466",
    "52, 470, 668",
    "216, 146, 977",
    "819, 987, 18",
    "117, 168, 530",
    "805, 96, 715",
    "346, 949, 466",
    "970, 615, 88",
    "941, 993, 340",
    "862, 61, 35",
    "984, 92, 344",
    "425, 690, 689",
]
example = False

In [273]:
def into_coordinates(line):
    """Parse one comma-separated "x,y,z" line into an (x, y, z) tuple of ints.

    int() ignores surrounding whitespace, so "162, 817, 812" parses the same
    as "162,817,812". Raises ValueError if the line does not split into
    exactly three fields or a field is not an integer.
    """
    x_text, y_text, z_text = line.split(",")
    return tuple(int(part) for part in (x_text, y_text, z_text))

# Choose which input to solve: the real puzzle input or the small example set.
# Only the chosen lines are parsed (the real input is not touched in example
# mode), and blank lines are skipped instead of crashing into_coordinates.
source_lines = example_data if example else data
n_connections = 10 if example else 1000  # number of closest pairs connected in Part 1
coordinates = [into_coordinates(line) for line in source_lines if line.strip()]





## Part 1

In [274]:
def pairwise_distances(coords):
    """Build the Euclidean distance matrix for a list of (x, y, z) points.

    Returns an n x n list of lists in which entry [i][j] is the distance
    between coords[i] and coords[j] for i < j. The diagonal and the lower
    triangle (i >= j) are never filled and stay 0, so read it with i < j.
    """
    size = len(coords)
    matrix = [[0] * size for _ in range(size)]
    for row in range(size):
        for col in range(row + 1, size):
            (ax, ay, az), (bx, by, bz) = coords[row], coords[col]
            squared = (ax - bx) ** 2 + (ay - by) ** 2 + (az - bz) ** 2
            matrix[row][col] = squared ** 0.5
    return matrix
# distances[i][j] (only for i < j) is the Euclidean distance between boxes i and j.
distances = pairwise_distances(coordinates)



In [275]:
# Sort every unordered pair of boxes (i < j) by distance and keep the
# `n_connections` closest ones. sorted() is stable, so pairs at equal
# distance stay in (i, j) generation order.
all_pairs = [
    (distances[i][j], (i, j))
    for i in range(len(coordinates))
    for j in range(i + 1, len(coordinates))
]
sorted_pairs = sorted(all_pairs, key=lambda x: x[0])
closest_pairs = sorted_pairs[:n_connections]


def find_root(parent, node):
    """Return the representative index of `node`'s circuit in the union-find
    forest `parent`, halving the path on the way up."""
    while parent[node] != node:
        parent[node] = parent[parent[node]]
        node = parent[node]
    return node


# Compose "circuits" with union-find: each connection merges the circuits of
# its two boxes (a no-op when they are already in the same circuit).
parent = list(range(len(coordinates)))
for _, (i, j) in closest_pairs:
    root_i, root_j = find_root(parent, i), find_root(parent, j)
    if root_i != root_j:
        parent[root_j] = root_i

members_by_root = {}
for node in range(len(coordinates)):
    members_by_root.setdefault(find_root(parent, node), set()).add(node)

# `circuits`: every circuit with two or more boxes, as a set of box indices,
# largest first.
circuits = sorted(
    (group for group in members_by_root.values() if len(group) > 1),
    key=lambda x: -len(x),
)

# Boxes no connection touched are circuits of size 1. Counting them keeps the
# top-three product defined even when fewer than three multi-box circuits exist.
circuit_sizes = sorted((len(group) for group in members_by_root.values()), reverse=True)
assert len(circuit_sizes) >= 3, "Need at least three circuits"
top_three = circuit_sizes[:3]
print(f"Three largest circuits sizes: {top_three}")
print(f"Product of sizes: {top_three[0] * top_three[1] * top_three[2]}")

Three largest circuits sizes: [5, 4, 2]
Product of sizes: 40


## Part 2

In [276]:
# Keep connecting pairs in order of increasing distance until every box is in
# one circuit; the pair whose connection joins the last two circuits is the
# answer. This starts from the closest pair instead of continuing from Part 1's
# `circuits`, so the cell does not depend on (or mutate) Part 1's state and is
# safe to re-run. The first `n_connections` pairs reproduce Part 1's
# connections, so the result is the same.


def find_root(parent, node):
    """Return the representative index of `node`'s circuit in the union-find
    forest `parent`, halving the path on the way up."""
    while parent[node] != node:
        parent[node] = parent[parent[node]]
        node = parent[node]
    return node


parent = list(range(len(coordinates)))
n_circuits = len(coordinates)  # every box starts as its own circuit
last_connected_pair = None

for _, (i, j) in sorted_pairs:
    root_i, root_j = find_root(parent, i), find_root(parent, j)
    if root_i == root_j:
        continue  # already in the same circuit: this connection merges nothing
    parent[root_j] = root_i
    n_circuits -= 1
    last_connected_pair = (i, j)
    if n_circuits == 1:
        break

assert n_circuits == 1 and last_connected_pair is not None, "Boxes were never all connected"

print(f"Last connected pair: {last_connected_pair}")
print("Final number of circuits:", n_circuits)
p1, p2 = coordinates[last_connected_pair[0]], coordinates[last_connected_pair[1]]
print(f"Coordinates of last connected pair: {p1} and {p2}")
if example:
    assert (p1 == (216, 146, 977) and p2 == (117, 168, 530)), "Unexpected coordinates"
print(f"Result: {p1[0]*p2[0]}")


Last connected pair: (10, 12)
Final number of circuits: 1
Coordinates of last connected pair: (216, 146, 977) and (117, 168, 530)
Result: 25272
