In [1]:
import numpy as np

In [2]:
# Small example input (one "x,y,z" position per entry) with known answers,
# used below as a quick check of run_part before the real input is solved.
test = """
162,817,812
57,618,57
906,360,560
592,479,940
352,342,300
466,668,158
542,29,236
431,825,988
739,650,466
52,470,668
216,146,977
819,987,18
117,168,530
805,96,715
346,949,466
970,615,88
941,993,340
862,61,35
984,92,344
425,690,689
""".split()

In [3]:
def get_separations(boxes):
    """Pairwise Euclidean distance matrix between box positions.

    ``boxes`` is expected to be a numpy array with one coordinate row per box.
    Row ``k`` of the result holds the distances from box ``k`` to every box.
    The diagonal (a box's distance to itself) is overwritten with 1e12 so that
    an argmin over the matrix does not select a box paired with itself.
    """
    rows = [np.sqrt(np.sum(np.square(boxes - origin), axis=1)) for origin in boxes]
    for k, row in enumerate(rows):
        row[k] = 1e12  # sentinel, assumed larger than any real separation
    return np.array(rows)

def where_equal(array, value):
    """Locate the first element equal to ``value`` in a nested sequence.

    Rows are scanned in order, and elements within each row in order.
    Returns ``(row, column)`` of the first match, or ``(-1, -1)`` if no
    element compares equal.
    """
    hits = (
        (row_idx, col_idx)
        for row_idx, row in enumerate(array)
        for col_idx, item in enumerate(row)
        if item == value
    )
    return next(hits, (-1, -1))

def is_in(array, value):
    """Return True if any element of the nested sequence equals ``value``."""
    return any(item == value for row in array for item in row)

def flat_list(lst):
    """Concatenate the inner sequences of ``lst`` into one new list."""
    return [item for line in lst for item in line]

def merge_circuits(circuits, a, b):
    """Record a connection between boxes ``a`` and ``b`` in ``circuits``.

    ``circuits`` is a list of lists of box indices and is updated in place:
    a new two-box circuit is created when neither box is connected yet, a
    lone box joins its partner's circuit, and two different circuits are
    merged into one. Connecting two boxes already in the same circuit is a
    no-op.
    """
    def _owner(box):
        # Index of the circuit containing ``box``, or None if it is unconnected.
        return next((k for k, circuit in enumerate(circuits) if box in circuit), None)

    owner_a, owner_b = _owner(a), _owner(b)
    if owner_a is None and owner_b is None:
        circuits.append([a, b])
    elif owner_b is None:
        circuits[owner_a].append(b)
    elif owner_a is None:
        circuits[owner_b].append(a)
    elif owner_a != owner_b:
        circuits[owner_a] += circuits[owner_b]
        circuits.pop(owner_b)


def _pairs_by_distance(boxes):
    """Return every index pair ``(i, j)`` with ``i < j``, closest first.

    Sorting once replaces repeatedly taking the argmin of the separation
    matrix. The stable sort keeps equal distances in row-major order, which is
    the order repeated argmin picks them in on the symmetric matrix.
    """
    n_boxes = len(boxes)
    if n_boxes < 2:
        return []
    separations = get_separations(boxes)
    rows, cols = np.triu_indices(n_boxes, k=1)
    order = np.argsort(separations[rows, cols], kind='stable')
    return list(zip(rows[order].tolist(), cols[order].tolist()))


def get_circuits(boxes, connections=10):
    """Group boxes into circuits after wiring up the closest pairs.

    The ``connections`` closest pairs are connected in order of increasing
    distance; a pair whose boxes are already in the same circuit still uses up
    one connection. If ``connections`` exceeds the number of pairs, every pair
    is connected. Boxes never connected become single-box circuits.

    Returns a list of circuits, each a list of box indices; every box index
    appears in exactly one circuit.
    """
    circuits = []
    for a, b in _pairs_by_distance(boxes)[:max(connections, 0)]:
        merge_circuits(circuits, a, b)

    # Collect the connected boxes once rather than re-flattening per box.
    connected = {box for circuit in circuits for box in circuit}
    circuits.extend([box] for box in range(len(boxes)) if box not in connected)
    return circuits


def get_last_connection(boxes):
    """Connect closest pairs until every box belongs to one single circuit.

    Returns the product of the first coordinates of the two boxes whose
    connection first joins all boxes into one circuit.

    Raises ValueError if fewer than two boxes are given (nothing to connect).
    """
    n_boxes = len(boxes)
    if n_boxes < 2:
        raise ValueError('need at least two boxes to connect')

    circuits = []
    for a, b in _pairs_by_distance(boxes):
        merge_circuits(circuits, a, b)
        # Done only when ONE circuit holds every box. "Every box is in some
        # circuit" can happen earlier while two multi-box circuits are apart.
        if len(circuits) == 1 and len(circuits[0]) == n_boxes:
            return boxes[a][0] * boxes[b][0]
    # Once every pair is connected the boxes form one circuit, so the loop
    # always returns for n_boxes >= 2.
    raise RuntimeError('boxes never formed a single circuit')

def parse_data(data):
    """Parse lines of ``x,y,z`` text into an ``(n, 3)`` int64 array.

    Surrounding whitespace is ignored and blank lines (e.g. a trailing empty
    line in an input file) are skipped. An empty input gives shape ``(0, 3)``.
    The dtype is fixed to int64 rather than the platform default, which is
    32-bit on Windows before NumPy 2.0 and would overflow squared coordinate
    differences and coordinate products.
    """
    boxes = []
    for raw_line in data:
        line = raw_line.strip()
        if not line:
            continue
        x, y, z = line.split(',')
        boxes.append([int(x), int(y), int(z)])
    return np.array(boxes, dtype=np.int64).reshape(-1, 3)

def run_part(data, part=1, connections=10):
    """Solve one part of the puzzle for the raw input lines ``data``.

    Part 1: connect the ``connections`` closest pairs via ``get_circuits`` and
    return the product of the three largest circuit sizes.
    Part 2: return ``get_last_connection(boxes)`` (``connections`` is unused).

    Raises ValueError for any other ``part``; this is checked before parsing.
    """
    if part not in (1, 2):
        raise ValueError(f'part must be 1 or 2, got {part!r}')

    boxes = parse_data(data)
    if part == 1:
        circuits = get_circuits(boxes, connections)
        largest = sorted((len(circuit) for circuit in circuits), reverse=True)[:3]
        return np.prod(largest)
    return get_last_connection(boxes)

In [4]:
# Known-answer check on the example: compute once, show it, and fail loudly
# on a fresh "Run All" if the answer ever changes.
part1_test = run_part(test)
print('Part 1 test:', part1_test, part1_test == 40)
assert part1_test == 40, 'part 1 example should give 40'

Part 1 test: 40 True


In [5]:
# Real puzzle input: one "x,y,z" position per line.
with open('input_day08.txt', 'r') as f:
    data = f.readlines()  # the with-block closes the file; no explicit close needed

# The real input uses 1000 connections (run_part defaults to 10, as for the example).
print('Part 1 result:', run_part(data, connections=1000))

Part 1 result: 122430


In [6]:
# Known-answer check on the example: compute once, show it, and fail loudly
# on a fresh "Run All" if the answer ever changes.
part2_test = run_part(test, 2)
print('Part 2 test:', part2_test, part2_test == 25272)
assert part2_test == 25272, 'part 2 example should give 25272'

Part 2 test: 25272 True


In [7]:
# Part 2 on the real input loaded earlier.
part2_result = run_part(data, part=2)
print('Part 2 result:', part2_result)

Part 2 result: 8135565324
