In [1]:
import numpy as np
from itertools import permutations, product
from collections import Counter, defaultdict

In [2]:
# fname = 'example_short.txt'
# fname = 'example.txt'
fname = 'input.txt'
with open(fname) as f:
    lines = f.readlines()


def parse_scanner_reports(report_lines):
    """Group raw puzzle lines into one list of [x, y, z] integer coordinates per scanner.

    A line containing "scanner" starts a new report; every other non-blank line is a
    comma-separated coordinate triple belonging to the most recent report.

    Blank lines are ignored entirely, so any number of them (including trailing ones)
    is tolerated. Reports whose header is followed by no coordinates are dropped.

    Raises
    ------
    ValueError
        If a coordinate line appears before any "scanner" header, or a field is not an int.
    """
    reports = []
    for raw_line in report_lines:
        line = raw_line.strip()
        if not line:
            continue
        if "scanner" in line:
            reports.append([])
        elif not reports:
            raise ValueError(f"beacon line before any scanner header: {line!r}")
        else:
            reports[-1].append([int(num) for num in line.split(',')])
    return [report for report in reports if report]


scanners = parse_scanner_reports(lines)
scanner_dict = {ii: np.array(beacons) for ii, beacons in enumerate(scanners)}

In [3]:
# The six signed unit vectors. Any two non-parallel ones, plus their cross product,
# form a right-handed orthonormal basis.
all_basis_vectors = [
    np.array([1, 0, 0]),
    np.array([0, 1, 0]),
    np.array([0, 0, 1]),
    np.array([-1, 0, 0]),
    np.array([0, -1, 0]),
    np.array([0, 0, -1]),
]

COBs_3D = []      # change-of-basis matrices whose columns are e_i, e_j, e_i x e_j
COBs_inv_3D = []  # exact integer inverses of the matrices above, same order
for e_i in all_basis_vectors:
    for e_j in all_basis_vectors:
        # Skip e_j == e_i and e_j == -e_i: parallel vectors cannot span a basis.
        if (e_i != e_j).any() and (e_i != -1 * e_j).any():
            e_k = np.cross(e_i, e_j)  # completes a right-handed (det +1) basis
            COB = np.vstack([e_i, e_j, e_k]).T
            COBs_3D.append(COB)
            # COB is a signed permutation matrix, hence orthogonal, so its inverse is
            # exactly its transpose. This avoids np.linalg.inv's float result followed
            # by astype(int), which truncates toward zero instead of rounding.
            COBs_inv_3D.append(COB.T.copy())
# 6 choices for e_i times 4 non-parallel choices for e_j.
assert len(COBs_3D) == len(COBs_inv_3D) == 24

In [4]:
def find_relative_difference_with_inv(reference_pts, relative_pts, print_stuff=False, rotations=None):
    """Find the rotation and offset that best overlay `relative_pts` onto `reference_pts`.

    For every candidate rotation R, each (reference point a, rotated relative point R @ b)
    pair casts one vote for the offset a - R @ b. The (offset, rotation index) pair with the
    most votes wins; its vote count is the number of points that coincide under it.

    Parameters
    ----------
    reference_pts : iterable of length-3 sequences
        Points already expressed in the home frame (e.g. a set of tuples).
    relative_pts : np.ndarray, shape (n, 3)
        Points in the other scanner's own frame.
    print_stuff : bool, default False
        If True, print the ten most-voted ((offset, rotation index), votes) entries.
    rotations : sequence of (3, 3) integer arrays, optional
        Candidate rotations; defaults to the module-level ``COBs_inv_3D``.

    Returns
    -------
    matches : int
        Votes for the winning candidate; 0 if there was nothing to compare.
    location : tuple or None
        Winning offset, such that ``location + R @ b`` maps b into the home frame.
    inv_idx : int or None
        Index into `rotations` of the winning rotation R.
    """
    if rotations is None:
        rotations = COBs_inv_3D
    # Loop-invariant: convert the reference points once rather than once per rotation.
    reference = np.array(list(reference_pts))
    votes = Counter()
    if reference.size and relative_pts.size:
        for inv_idx, inv in enumerate(rotations):
            rotated = (inv @ relative_pts.T).T
            # diffs[r, c] = reference[r] - rotated[c], flattened row-major into one row per pair
            diffs = (reference[:, None, :] - rotated[None, :, :]).reshape(-1, reference.shape[-1])
            votes.update((tuple(d_vect), inv_idx) for d_vect in diffs)
    if print_stuff:
        print(sorted(votes.items(), key=lambda pair: -1 * pair[1])[0:10])
    if not votes:
        # No points or no rotations: report zero overlap instead of failing inside max().
        return 0, None, None
    (location, inv_idx), matches = max(votes.items(), key=lambda pair: pair[1])
    return matches, location, inv_idx

In [5]:
def transform_to_home_basis(points, location, inv_idx, rotations=None):
    """Map points from a scanner's own frame into the home frame.

    Each row p of `points` becomes ``location + R @ p`` with ``R = rotations[inv_idx]``.
    This matches the convention of the (matches, location, inv_idx) triple returned by
    ``find_relative_difference_with_inv``.

    Parameters
    ----------
    points : np.ndarray, shape (n, 3)
        Points in the scanner's own frame.
    location : length-3 sequence
        Offset added after rotating.
    inv_idx : int
        Index of the rotation to apply.
    rotations : sequence of (3, 3) integer arrays, optional
        Rotation list `inv_idx` indexes into; defaults to the module-level ``COBs_inv_3D``.

    Returns
    -------
    list of tuple
        Transformed points as tuples, ready to be added to a set of home points.
    """
    if rotations is None:
        rotations = COBs_inv_3D
    inv = rotations[inv_idx]
    transformed = location + (inv @ points.T).T
    return [tuple(tx) for tx in transformed]

In [9]:
%%time

locations = []
home_basis_points = set([tuple(pt) for pt in scanner_dict[0]])
finished_scanners = set([0])
while len(finished_scanners) < len(scanner_dict):
    for scanner in set(scanner_dict).difference(finished_scanners):
        matches, location, inv_idx = find_relative_difference_with_inv(home_basis_points, scanner_dict[scanner])
        locations.append(location)
        if matches >= 12:
#             print(f"adding {scanner}")
            finished_scanners.add(scanner)
            new_pts = transform_to_home_basis(scanner_dict[scanner], location, inv_idx)
            home_basis_points.update(new_pts)

manhattan_max = 0
for loc1, loc2 in product(locations, locations):
    manhattan = sum([abs(a - b) for a, b in zip(loc1, loc2)])
    manhattan_max = max(manhattan_max, manhattan)
    
print(f"part 1: {len(home_basis_points)}")
print(f"part 2: {manhattan_max}")

part 1: 323
part 2: 10685
Wall time: 8.04 s


### TODO: speed up alignment by tracking which points belong to which scanner, then testing each unaligned scanner against those per-scanner point sets (possibly in a prioritized order) instead of against all transformed points at once