In [1]:
import random
import numpy as np

In [9]:
class Datapoint:
    def __init__(self, coordinate, color, weight):
        assert isinstance(coordinate, list), 'Coordinate must be a list of integers'
        assert isinstance(color, str), 'Color must be a str'
        assert isinstance(weight, (float, int)), 'Weight must be a float or int'

        self.coordinate = coordinate
        self.color = color
        self.weight = float(weight)

class Node:
    def __init__(self, range_1=None, range_2=None, dp: Datapoint=None):
        assert dp is not None or (range_1 is not None and range_2 is not None), 'The Node must be a rectangle range or a datapoint(leaf)'

        self.rectangle = (range_1, range_2)
        self.datapoint = dp
        self.left_child = None
        self.right_child = None

class KDTree:
    def __init__(self):
        self.root: Node or None = None
        self.dimension = 0

    def build_tree(self, points: list[Datapoint]):
        assert len(points) > 0, 'There must be at least one datapoint to build tree'

        self.dimension = len(points[0].coordinate)
        self.root = self.__build_tree(points)

    def __build_tree(self, points: list[Datapoint]):
        if len(points) == 0:
            return None
        if len(points) == 1:
            return Node(dp=points[0])

        coordinates = np.array([p.coordinate for p in points])
        r_1 = np.min(coordinates, axis=0)
        r_2 = np.max(coordinates, axis=0)

        root = Node(range_1=r_1, range_2=r_2)

        if np.array_equal(r_1, r_2):
            root.left_child = Node(dp=points[0])
            root.right_child = Node(dp=points[1])
            return root

        variances = np.var(coordinates, axis=0)
        index_max_variance = np.argmax(variances)

        points.sort(key=lambda pts: pts.coordinate[index_max_variance])

        median_index = len(points) // 2

        root.left_child = self.__build_tree(points[:median_index])
        root.right_child = self.__build_tree(points[median_index:])

        return root


In [10]:
col = ['Red', 'Blue', 'Green', 'White', 'Black']
w = [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]
datapoints = [Datapoint([random.randint(1, 100), random.randint(1, 100), random.randint(1, 100), random.randint(1, 100), random.randint(1, 100), random.randint(1, 100)], random.choice(col), random.choice(w)) for i in range(1000000)]

In [8]:
for datapoint in datapoints:
    print(len(datapoint.coordinate), datapoint.coordinate, datapoint.weight, datapoint.color)

6 [44, 14, 2, 61, 71, 46] 0.9 Blue
6 [92, 27, 89, 22, 99, 83] 0.6 White
6 [23, 39, 69, 27, 44, 2] 0.8 Black
6 [29, 69, 87, 8, 51, 86] 0.2 Blue
6 [67, 43, 15, 50, 47, 98] 0.4 White
6 [97, 58, 100, 53, 19, 55] 1.0 Blue
6 [98, 86, 57, 30, 66, 91] 1.0 Blue
6 [73, 51, 39, 18, 98, 24] 1.0 Red
6 [21, 15, 79, 3, 74, 45] 0.7 White
6 [24, 13, 76, 27, 30, 60] 0.5 White
6 [58, 21, 24, 74, 100, 66] 0.6 White
6 [19, 37, 83, 21, 82, 90] 0.3 Green
6 [24, 38, 76, 100, 24, 97] 0.2 Black
6 [83, 33, 10, 40, 94, 29] 0.3 Green
6 [12, 90, 89, 86, 77, 99] 0.5 Black
6 [97, 65, 88, 93, 73, 92] 0.3 Blue
6 [38, 73, 56, 22, 1, 84] 0.6 Black
6 [95, 34, 57, 77, 14, 71] 0.4 Red
6 [22, 64, 35, 91, 52, 14] 0.0 Green
6 [46, 42, 60, 32, 79, 42] 0.7 Black
6 [50, 79, 3, 12, 75, 78] 0.1 Red
6 [67, 99, 100, 86, 44, 85] 0.3 Red
6 [49, 95, 7, 74, 64, 49] 0.2 Green
6 [4, 38, 8, 44, 8, 10] 0.8 Green
6 [57, 98, 40, 34, 54, 28] 0.0 Black
6 [17, 77, 34, 26, 12, 65] 0.2 White
6 [46, 64, 75, 47, 17, 73] 0.2 Blue
6 [63, 5, 71, 27, 60,

KeyboardInterrupt: 

In [11]:
tree = KDTree()
tree.build_tree(datapoints)

In [5]:
def is_balanced(root):
    def check(node):
        if node is None:
            return 0
        left = check(node.left_child)
        right = check(node.right_child)
        if left == -1 or right == -1 or abs(left - right) > 1:
            return -1
        return 1 + max(left, right)

    return check(root) != -1

is_balanced(tree.root)

True

In [13]:
print(tree.root.rectangle)

(array([1, 1, 1, 1, 1, 1]), array([100, 100, 100, 100, 100, 100]))
