In [7]:
from collections import namedtuple

# Define a 2D point
Point = namedtuple('Point', ['x', 'y'])

# KD-Tree Node
class KDNode:
    """A single node of a KD-tree.

    Attributes are set directly from the constructor arguments:
        point: The point stored at this node (indexable, e.g. ``(x, y)``).
        axis:  Index of the coordinate this node splits on (0 = x, 1 = y).
        left:  Subtree passed as ``left``, or ``None``.
        right: Subtree passed as ``right``, or ``None``.
    """

    def __init__(self, point, axis, left=None, right=None):
        self.point = point      # The (x, y) point
        self.axis = axis        # Splitting axis: 0 = x, 1 = y
        self.left = left        # Left subtree
        self.right = right      # Right subtree

    def __repr__(self):
        # Deliberately shallow: children are not rendered, so printing a
        # node never dumps (or recurses through) the whole tree.
        return f"{type(self).__name__}(point={self.point!r}, axis={self.axis!r})"

# Function to build the KD-Tree
def build_kdtree(points, depth=0):
    """Recursively build a 2-D KD-tree from a collection of points.

    At each level the points are ordered along the splitting axis
    (``depth % 2``: x on even depths, y on odd depths). The median point
    becomes the node, the points before it form the left subtree and the
    points after it form the right subtree.

    Args:
        points: Sequence or iterable of indexable 2-D points (for example
            ``Point`` instances or plain tuples). It is not modified.
        depth: Depth of the node being built; selects the splitting axis.

    Returns:
        The root ``KDNode`` of the (sub)tree, or ``None`` when ``points``
        is empty or ``None``.
    """
    k = 2  # number of dimensions
    axis = depth % k

    # sorted() works on a copy, so the caller's collection keeps its order
    # (list.sort() used to reorder it in place) and tuples or other
    # iterables are accepted too. Both sorts are stable, so the resulting
    # tree is the same as before.
    ordered = sorted(points or (), key=lambda point: point[axis])
    if not ordered:
        return None

    median = len(ordered) // 2

    return KDNode(
        point=ordered[median],
        axis=axis,
        left=build_kdtree(ordered[:median], depth + 1),
        right=build_kdtree(ordered[median + 1:], depth + 1),
    )

# Function to search nearest neighbor
def nearest_neighbor(root, target, depth=0, best=None):
    """Return the point in the KD-tree rooted at ``root`` closest to ``target``.

    The search first descends into the side of the splitting plane that
    contains ``target``, then visits the other side only if the plane is
    closer to ``target`` than the best point found so far.

    Args:
        root: ``KDNode`` to search from, or ``None`` for an empty tree.
        target: Indexable 2-D query point.
        depth: Depth of ``root``. Kept for backward compatibility and
            passed down the recursion; the splitting axis itself is read
            from each node.
        best: Best candidate found so far, or ``None``.

    Returns:
        The closest point found, or ``best`` unchanged when ``root`` is
        ``None``.
    """
    if root is None:
        return best

    # Read the axis recorded on the node instead of re-deriving it from
    # ``depth``: a search started at a subtree with the default depth=0
    # would otherwise prune against the wrong coordinate.
    axis = root.axis
    plane_offset = target[axis] - root.point[axis]

    # Choose branch to search
    if target[axis] < root.point[axis]:
        next_branch = root.left
        opposite_branch = root.right
    else:
        next_branch = root.right
        opposite_branch = root.left

    # Update best point so far
    best = nearer_point(target, best, root.point)
    best = nearest_neighbor(next_branch, target, depth + 1, best)

    # The far side can only hold a closer point if the splitting plane is
    # nearer than the current best. Comparing squared values avoids the
    # float square root and stays exact for integer coordinates.
    if plane_offset ** 2 < distance_squared(target, best):
        best = nearest_neighbor(opposite_branch, target, depth + 1, best)

    return best

# Helper function to compute squared Euclidean distance
def distance_squared(p1, p2):
    """Return the squared Euclidean distance between two 2-D points.

    Only the coordinates at indices 0 and 1 of each point are used.
    """
    dx = p1[0] - p2[0]
    dy = p1[1] - p2[1]
    return dx**2 + dy**2

# Helper to decide which point is closer
def nearer_point(target, p1, p2):
    """Return whichever of ``p1`` and ``p2`` lies closer to ``target``.

    A ``None`` candidate always loses to the other one (so the result is
    ``None`` only when both are ``None``). When both candidates are equally
    far from ``target``, ``p2`` is returned.
    """
    if p1 is None or p2 is None:
        return p2 if p1 is None else p1

    first_distance = distance_squared(target, p1)
    second_distance = distance_squared(target, p2)
    if first_distance < second_distance:
        return p1
    return p2

# Demo: index six sample points, then query the point nearest to (5, 9)
points = [
    Point(x, y)
    for x, y in ((2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2))
]
target = Point(x=5, y=9)

kdtree = build_kdtree(points)
closest = nearest_neighbor(kdtree, target)

# Report the query point and the match that was found
print("Nearest neighbor to", target, "is", closest)

# Report the input size together with the textbook KD-tree complexity figures.
# NOTE(review): build_kdtree sorts the points again at every recursion level,
# which looks closer to O(n log^2 n) than the O(n log n) printed below -
# confirm before quoting the build figure.
n = len(points)
print(
    "\n--- Time and Space Complexity ---",
    f"Number of points (n): {n}",
    "Build Time Complexity: O(n log n)",
    "Search Time Complexity (average): O(log n)",
    "Search Time Complexity (worst case): O(n)",
    "Space Complexity: O(n)",
    sep="\n",
)


Nearest neighbor to Point(x=5, y=9) is Point(x=4, y=7)

--- Time and Space Complexity ---
Number of points (n): 6
Build Time Complexity: O(n log n)
Search Time Complexity (average): O(log n)
Search Time Complexity (worst case): O(n)
Space Complexity: O(n)
