In [1]:
import sys
import numpy as np
import math
import operator
import functools
import json
from pprint import pprint
%matplotlib widget
# import sympy
from mpl_toolkits.mplot3d import axes3d

In [2]:
allRotationMatrices = [
    np.identity(3),
    np.array([[-1,0,0], [0, -1, 0], [0,0,1]]),
    np.array([[-1,0,0], [0, 1, 0], [0,0,-1]]),
    np.array([[1,0,0], [0, -1, 0], [0,0,-1]]),
    
    np.array([[-1,0,0], [0, 0, 1], [0,1,0]]),
    np.array([[1,0,0], [0, 0, -1], [0,1,0]]),
    np.array([[1,0,0], [0, 0, 1], [0,-1,0]]),
    np.array([[-1,0,0], [0, 0, -1], [0,-1,0]]),
    
    np.array([[0,-1,0], [1, 0, 0], [0,0,1]]),
    np.array([[0,1,0], [-1, 0, 0], [0,0,1]]),
    np.array([[0,1,0], [1, 0, 0], [0,0,-1]]),
    np.array([[0,-1,0], [-1, 0, 0], [0,0,-1]]),
    
    np.array([[0,1,0], [0, 0, 1], [1,0,0]]),
    np.array([[0,-1,0], [0, 0, -1], [1,0,0]]),
    np.array([[0,-1,0], [0, 0, 1], [-1,0,0]]),
    np.array([[0,1,0], [0, 0, -1], [-1,0,0]]),
    
    np.array([[0,0,1], [1, 0, 0], [0,1,0]]),
    np.array([[0,0,-1], [-1, 0, 0], [0,1,0]]),
    np.array([[0,0,-1], [1, 0, 0], [0,-1,0]]),
    np.array([[0,0,1], [-1, 0, 0], [0,-1,0]]),
    
    np.array([[0,0,-1], [0, 1, 0], [1,0,0]]),
    np.array([[0,0,1], [0, -1, 0], [1,0,0]]),
    np.array([[0,0,1], [0, 1, 0], [-1,0,0]]),
    np.array([[0,0,-1], [0, -1, 0], [-1,0,0]]),
]

In [486]:
from scipy import linalg
import matplotlib.pyplot as plt
# %matplotlib widget

class Scanner:
    def __init__(self, id):
        self.id = id
        self.beacons = np.array([])
        self.rot = np.identity(3)
        self.offset = np.array([0,0,0])
        self.world_rot = np.identity(3)
        self.world_offset = np.array([0,0,0])
        self.parent = None
        self.children = []
    
    def add_beacons(self, beacon_coords):
        self.beacons = np.array(beacon_coords)
    
    def dist_matrix_by_index(self, index):
        # Creates a matrix with distance to self.beacons[index]
        sq_dist = np.sum(np.square(self.beacons - self.beacons[index]), axis=1)
        return np.sqrt(sq_dist)
    
    @property
    def world_beacons(self):
        # Returns becaons in world coordinates
        # return (self.beacons @ self.rot + self.offset) @ self.parent.world_rot + self.parent.world_offset
        return self.beacons @ self.world_rot + self.world_offset
    
    def graph(self):
        fig = plt.figure()
        ax = plt.axes(projection='3d')
        ax.scatter(self.beacons[:,0], self.beacons[:,1], self.beacons[:,2])
        plt.show()
        
    def overlap(self, other):
        # There are always 12 points that are overlapped
        # Returns two sets of points that matches by comparing distance to a single point
        for i in range(len(self.beacons)):
            dist_0 = self.dist_matrix_by_index(i)
            for j in range(len(other.beacons)):
                dist_1 = other.dist_matrix_by_index(j)
                inter = set(dist_0).intersection(dist_1)
                if len(inter) == 12:
                    # Found overlap
                    set_self = np.empty((0, len(self.beacons[0])))
                    set_other = np.empty((0, len(self.beacons[0])))
                    for dist in inter:
                        set_self = np.append(set_self, self.beacons[np.where(dist_0 == dist)], axis=0)
                        set_other = np.append(set_other, other.beacons[np.where(dist_1 == dist)], axis=0)
                    return set_self, set_other
        # no matching 12 points found
        return None
        
    def setRelativeTo(self, base, DEBUG=False):
        # Sets rotation and translation matrices relative to base Scanner
        overlapping_sets = self.overlap(base)
        if not overlapping_sets:
            return False
        if DEBUG:
            print(overlapping_sets)
        set_self, set_base = overlapping_sets
        for rot in allRotationMatrices:
            rot_2 = np.apply_along_axis(lambda x: x @ rot, 1, set_self)
            diff = set_base - rot_2
            if len(np.where(diff == diff[0])[0]) == 12*3:
                self.parent = base
                base.children.append(self)
                self.rot = rot
                self.offset = diff[0]
                # print(base.world_offset, self.offset, base.world_offset + self.offset)
                self.world_rot = self.rot @ base.world_rot
                self.world_offset = self.offset @ base.world_rot + base.world_offset
                if DEBUG:
                    print(self.beacons @ self.rot + self.offset)
                return True
    
    def __repr__(self):
        return f"Scanner:{self.id}({self.world_offset}), c={self.children}"
    

In [525]:
def getScanners(arr):
    scanners = []
    beacon_coords = []
    for a in arr:
        if a.find("scanner") != -1:
            s = Scanner(int(a.strip("-").split()[1]))
        elif not a:
            s.add_beacons(beacon_coords)
            beacon_coords = []
            scanners.append(s)
        else:
            beacon_coords.append([float(i) for i in a.split(",")])
    s.add_beacons(beacon_coords)
    scanners.append(s)
    return scanners

In [526]:
def connectScanners(scanners, root_index):
    root = scanners[root_index]
    connected = [root]
    to_connect = [x for x in scanners if x != root]
    DEBUG = False
    while len(to_connect) > 0:
        scanner = to_connect.pop(0)
        for base in connected:
            if scanner.setRelativeTo(base):
                connected.append(scanner)
                break
        else:
            # did not find a base with matching points
            to_connect.append(scanner)
    return root

In [518]:
def part1(arr):
    scanners = getScanners(arr)
    connectScanners(scanners, 0)
    
    world_beacons = set()
    for scanner in scanners:
        scanner_beacons = set([tuple(i) for i in scanner.world_beacons])
        world_beacons |= scanner_beacons - world_beacons
    # pprint(world_beacons)
    return len(world_beacons)

In [619]:
def part2(arr):
    scanners = getScanners(arr)
    connectScanners(scanners, 0)
    
    locs = np.array([x.world_offset.tolist() for x in scanners])
    locs = np.broadcast_to(locs, (len(locs), len(locs), 3))
    
    diff_table = locs - np.transpose(locs, (1,0,2))
    dist_table = np.sum(diff_table, axis=2)
    
    return np.max(dist_table)

In [620]:
np.set_printoptions(suppress=True)
if __name__ == "__main__":
    file = sys.argv[-1] if sys.argv[-1].endswith(".txt") else "input.txt"
    with open(file, "r") as f:
        arr = f.read().splitlines()

        print(part2(arr))

12092.0
