# Algorithm for finding the closest pair of points in a set of 3D points

In [10]:
# Algorithm for finding the closest pair of points in a set of 3D points
# Using a divide and conquer approach

import numpy as np
import math
import sys
import time

class Point:
    """A point in 3D space, stored as the attributes ``x``, ``y`` and ``z``."""

    def __init__(self, x, y, z):
        self.x, self.y, self.z = x, y, z

    def __repr__(self):
        """Return the point formatted as ``(x, y, z)``."""
        return f"({self.x}, {self.y}, {self.z})"

    # str() and repr() produce the same text.
    __str__ = __repr__

def distance(p1, p2):
    """Return the Euclidean distance between two objects exposing x, y and z."""
    dx = p1.x - p2.x
    dy = p1.y - p2.y
    dz = p1.z - p2.z
    return math.sqrt(dx**2 + dy**2 + dz**2)

def closest_pair(points):
    """Find the closest pair among ``points`` using divide and conquer.

    Args:
        points: iterable of objects exposing ``x``, ``y`` and ``z``.
            The input is not modified; a sorted copy is used internally.

    Returns:
        The result of ``closest_pair_rec`` on the x-sorted points: a tuple
        ``(distance, p1, p2)``, or ``None`` when there are fewer than two
        points (``brute_force`` then finds no pair).
    """
    # The recursion halves the list at its middle index and builds the strip
    # around points[mid].x, so the list must be ordered by x and only by x.
    # A subsequent stable sort by y would replace that ordering and let
    # cross-half pairs fall outside the strip.
    by_x = sorted(points, key=lambda p: p.x)
    return closest_pair_rec(by_x)

def closest_pair_rec(points):
    """Recursively find the closest pair in ``points``.

    Lists of at most three points are solved by ``brute_force``. Larger lists
    are halved at the middle index, each half is solved recursively, and
    ``closest_pair_split`` is asked for a pair closer than the best of the two
    halves.

    Returns:
        A tuple ``(distance, p1, p2)``; for three or fewer points, whatever
        ``brute_force`` returns.
    """
    size = len(points)
    if size <= 3:
        return brute_force(points)

    half = size // 2
    left_pair = closest_pair_rec(points[:half])
    right_pair = closest_pair_rec(points[half:])

    # Best distance found inside either half; the split search must beat it.
    delta = min(left_pair[0], right_pair[0])
    split_pair = closest_pair_split(points, delta)
    if split_pair is not None:
        return split_pair

    # On a tie the right-half pair is returned.
    return left_pair if left_pair[0] < right_pair[0] else right_pair
    
    

def brute_force(points):
    """Find the closest pair by comparing every pair of points.

    Args:
        points: indexable sequence of objects accepted by ``distance``.

    Returns:
        A tuple ``(distance, p1, p2)`` for the closest pair, where ``p1``
        precedes ``p2`` in ``points``; ``None`` if there are fewer than two
        points.
    """
    n = len(points)
    # Start from infinity rather than sys.maxsize so that pairs farther apart
    # than sys.maxsize are still reported instead of yielding None.
    best = float("inf")
    best_pair = None
    for i, p1 in enumerate(points):
        for j in range(i + 1, n):
            p2 = points[j]
            d = distance(p1, p2)
            if d < best:
                best = d
                best_pair = (d, p1, p2)

    return best_pair

def closest_pair_split(points, delta):
    """Search the strip around the middle x-coordinate for a pair closer than ``delta``.

    Args:
        points: non-empty sequence of objects exposing ``x``, ``y`` and ``z``.
            The strip is centred on ``points[len(points) // 2].x``; this
            assumes the caller passes the points ordered by x.
        delta: distance that a pair must beat to be reported.

    Returns:
        A tuple ``(distance, p1, p2)`` for the closest pair in the strip with
        distance strictly below ``delta``, or ``None`` if there is none.
    """
    mid_x = points[len(points) // 2].x

    # Only points closer than delta to the dividing x value can be part of a
    # pair that beats delta; order them by y for the forward scan below.
    strip = sorted(
        (p for p in points if abs(p.x - mid_x) < delta),
        key=lambda p: p.y,
    )

    best = delta
    best_pair = None
    count = len(strip)
    for i, p1 in enumerate(strip):
        for j in range(i + 1, count):
            p2 = strip[j]
            # The fixed "next 7 neighbours" bound only holds in 2D. In 3D the
            # strip points can differ arbitrarily in z, so scan until the y
            # gap alone rules out an improvement; the strip is y-sorted, so
            # every later point is at least as far away in y.
            if p2.y - p1.y >= best:
                break
            d = distance(p1, p2)
            if d < best:
                best = d
                best_pair = (d, p1, p2)

    return best_pair

In [5]:
%matplotlib qt

In [13]:
import matplotlib.pyplot as plt

# Generate 50 random points with integer coordinates drawn from [0, 100).
points = [
    Point(np.random.randint(0, 100), np.random.randint(0, 100), np.random.randint(0, 100))
    for _ in range(50)
]
print(points)

# Show 3D points
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
for p in points:
    # One scatter call per point, as before.
    ax.scatter(p.x, p.y, p.z)

# perf_counter() is a monotonic, high-resolution clock intended for measuring
# short intervals; time.time() can be too coarse for a sub-millisecond run.
start = time.perf_counter()
min_dist, p1, p2 = closest_pair(points)
end = time.perf_counter()
print(f"Closest pair: {p1} and {p2} with distance {min_dist}")
print(f"Time: {end - start}")

# Highlight the closest pair with a red segment.
ax.plot([p1.x, p2.x], [p1.y, p2.y], [p1.z, p2.z], color='red')
plt.show()


[(89, 23, 28), (87, 28, 12), (32, 90, 14), (44, 1, 71), (32, 25, 89), (86, 17, 96), (92, 65, 63), (69, 86, 13), (71, 26, 17), (42, 18, 0), (81, 83, 7), (37, 86, 91), (50, 61, 78), (11, 89, 9), (83, 97, 37), (37, 47, 10), (17, 34, 18), (59, 35, 47), (96, 94, 48), (36, 73, 67), (55, 10, 0), (95, 1, 80), (95, 77, 92), (87, 50, 82), (33, 72, 72), (66, 27, 17), (94, 13, 22), (16, 35, 84), (96, 78, 56), (41, 56, 16), (1, 74, 45), (90, 14, 20), (44, 73, 78), (35, 92, 17), (65, 86, 61), (37, 17, 46), (43, 98, 30), (49, 18, 3), (63, 9, 98), (91, 39, 94), (10, 54, 34), (82, 32, 7), (21, 50, 13), (26, 99, 67), (86, 61, 35), (23, 71, 72), (97, 59, 41), (51, 69, 11), (48, 24, 69), (99, 67, 58)]
Closest pair: (94, 13, 22) and (90, 14, 20) with distance 4.58257569495584
Time: 0.00042176246643066406
