In [10]:
import numpy as np
from sklearn.cluster import KMeans
import torch.autograd as autograd
class Node:
    """A node of the cluster tree holding a block of data points.

    Args:
        data: Points (one per row) assigned to this node, or ``None``.
        parent: Parent ``Node``; ``None`` for the root.
    """

    def __init__(self, data=None, parent=None):
        self.data = data
        self.parent = parent
        self.left = None
        self.right = None
        self.is_leaf = True
        # Column-wise mean, computed once here; nothing in this class recomputes it.
        self.centroid = np.mean(data, axis=0) if data is not None else None
        # Initial weight is half the number of points (0 when there is no data).
        self.weight = 0.5 * len(data) if data is not None else 0

    def set_children(self, left, right):
        """Attach ``left`` and ``right`` as children and mark this node internal."""
        self.left = left
        self.right = right
        self.is_leaf = False

    def remove(self):
        """Detach this node from its parent.

        The parent's matching child slot is cleared. If the parent is left
        with no children, it becomes a leaf again. This is a no-op for a
        node without a parent.
        """
        parent = self.parent
        if parent is None:
            return
        # Identity, not equality: we are locating this exact object in the parent.
        if parent.left is self:
            parent.left = None
        elif parent.right is self:
            parent.right = None
        if not parent.left and not parent.right:
            parent.is_leaf = True
        self.parent = None



In [13]:
class ClusterTree:
    """Online hierarchical clustering tree grown by 2-means splits.

    Each ``build_tree`` call routes a batch of points to the leaf with the
    nearest centroid and prunes low-weight nodes. Every ``T`` iterations it
    also splits the leaf with the largest total absolute deviation from its
    centroid. Finally it decays all node weights.

    Args:
        D: Initial data placed in the root node.
        L: Stored as ``self.L``; not read by any method of this class.
        T: Grow period; ``grow_tree`` runs when ``iter % T == 0``.
        P: Prune threshold; non-root nodes whose weight is below ``P`` are
            detached from their parent.
    """

    def __init__(self, D, L, T, P):
        self.root = Node(data=D)
        self.L = L
        self.T = T
        self.P = P
        self.iter = 0

    def update_weights(self, node):
        """Recursively update weights in the subtree rooted at ``node``.

        The new weight is ``0.5 * weight + 0.5 * len(data)``, or
        ``0.5 * weight`` when the node holds no data.
        """
        if not node:
            return
        if node.data is not None:
            node.weight = node.weight * 0.5 + len(node.data) * 0.5
        else:
            node.weight = node.weight * 0.5
        self.update_weights(node.left)
        self.update_weights(node.right)

    def prune(self, node):
        """Detach every non-root node in the subtree whose weight is below ``P``.

        The children of a detached node are not visited, since they are no
        longer reachable from the root.
        """
        if node is None:
            return
        if node.weight < self.P and node.parent is not None:
            node.remove()
            return
        if node.left:
            self.prune(node.left)
        if node.right:
            self.prune(node.right)

    def split_node(self, node):
        """Split ``node`` into two children with 2-means on its data.

        The node is left unchanged when it has fewer than two points, or
        when k-means puts every point in one cluster.
        """
        if node.data is None or len(node.data) <= 1:
            return
        # Boolean-mask indexing below requires an ndarray, not a list.
        data = np.asarray(node.data)
        kmeans = KMeans(n_clusters=2, random_state=0, n_init='auto').fit(data)
        labels = kmeans.labels_

        left_data = data[labels == 0]
        right_data = data[labels == 1]

        if len(left_data) > 0 and len(right_data) > 0:
            left_node = Node(data=left_data, parent=node)
            right_node = Node(data=right_data, parent=node)
            node.set_children(left_node, right_node)

    def grow_tree(self):
        """Split the leaf whose points have the largest summed |point - centroid|.

        Leaves with no data or no centroid are ignored. Nothing happens if no
        leaf qualifies.
        """
        candidates = [
            leaf for leaf in self.get_leaf_nodes(self.root)
            if leaf.data is not None and leaf.centroid is not None
        ]
        if not candidates:
            return
        max_diff_node = max(
            candidates,
            key=lambda leaf: np.sum(np.abs(np.asarray(leaf.data) - leaf.centroid)),
        )
        self.split_node(max_diff_node)

    def get_leaf_nodes(self, node, leaves=None):
        """Return the leaves of the subtree rooted at ``node`` in pre-order."""
        if leaves is None:
            leaves = []
        if node.is_leaf:
            leaves.append(node)
        if node.left:
            self.get_leaf_nodes(node.left, leaves)
        if node.right:
            self.get_leaf_nodes(node.right, leaves)
        return leaves

    def assign_points_to_leaves_and_update(self, data):
        """Append each point of ``data`` to the leaf with the nearest centroid.

        Distance is Euclidean. Ties go to the first leaf in
        ``get_leaf_nodes`` order. Centroids are not recomputed, so every
        point is routed against the same centroids.

        Raises:
            ValueError: if no leaf has a finite distance to a point (e.g.
                NaN coordinates or centroids).
        """
        # Tree shape and centroids are fixed during assignment, so collect once.
        leaves = self.get_leaf_nodes(self.root)
        buckets = [[] for _ in leaves]
        for point in data:
            best_idx, min_dist = None, float('inf')
            for idx, leaf in enumerate(leaves):
                dist = np.linalg.norm(point - leaf.centroid)
                if dist < min_dist:
                    best_idx, min_dist = idx, dist
            if best_idx is None:
                raise ValueError("no leaf centroid at a finite distance from point")
            buckets[best_idx].append(point)

        # Stack once per leaf instead of once per point.
        for leaf, points in zip(leaves, buckets):
            if not points:
                continue
            if leaf.data is not None:
                leaf.data = np.vstack([leaf.data, *points])
            else:
                leaf.data = np.array(points)

    def calculate_loss_NC(self):
        """Mean over leaves of ||centroid - mean(data)||.

        Leaves with no data contribute 0 but still count in the denominator.
        """
        leaf_nodes = self.get_leaf_nodes(self.root)
        total_loss = 0
        for leaf in leaf_nodes:
            if leaf.data is not None and len(leaf.data) > 0:
                data_mean = np.mean(leaf.data, axis=0)
                total_loss += np.linalg.norm(leaf.centroid - data_mean)
        return total_loss / len(leaf_nodes)

    def get_sibling_pairs(self, node=None, sibling_pairs=None):
        """Collect ``(a, b)`` and ``(b, a)`` for every pair of sibling nodes.

        The whole subtree under ``node`` (default: root) is searched,
        including below nodes that have only one remaining child.
        """
        # Fresh list per top-level call; a ``[]`` default would be shared
        # across calls and instances and keep accumulating pairs.
        if sibling_pairs is None:
            sibling_pairs = []
        if node is None:
            node = self.root

        if node.left is not None and node.right is not None:
            sibling_pairs.append((node.left, node.right))
            sibling_pairs.append((node.right, node.left))
        if node.left is not None:
            self.get_sibling_pairs(node.left, sibling_pairs)
        if node.right is not None:
            self.get_sibling_pairs(node.right, sibling_pairs)

        return sibling_pairs

    def compute_P(self, n_centroid, m_centroid):
        """Return the unit vector pointing from ``m_centroid`` to ``n_centroid``.

        Returns a zero vector when the centroids coincide, instead of NaN.
        """
        numerator = n_centroid - m_centroid
        denominator = np.linalg.norm(numerator)
        if denominator == 0:
            return np.zeros(np.shape(numerator))
        return numerator / denominator

    def calculate_loss_DC(self):
        """Directional loss averaged over sibling pairs and datapoints.

        For each ordered sibling pair ``(n, m)`` and each point ``x`` in
        ``n.data``, add ``P(n, m) . (n.centroid - x)``. The total is divided
        by ``(#pairs * #points visited)``. Returns 0.0 when there is nothing
        to average.
        """
        total_N = 0
        total_B = 0
        total_loss = 0
        for n, m in self.get_sibling_pairs():
            total_N += 1
            if n.data is None:
                continue
            points = np.asarray(n.data)
            total_B += len(points)
            if len(points) == 0:
                continue
            # P depends only on the pair, not the datapoint.
            P_n = self.compute_P(n.centroid, m.centroid)
            total_loss += np.sum(P_n * (n.centroid - points))
        if total_N == 0 or total_B == 0:
            return 0.0
        return total_loss / (total_N * total_B)

    def build_tree(self, new_D):
        """Run one online iteration with the batch ``new_D``.

        The steps are: route points to leaves, prune, grow (every ``T``
        iterations), advance the counter, then update weights.
        """
        self.assign_points_to_leaves_and_update(new_D)

        print(f"Iteration {self.iter}")

        self.prune(self.root)
        if self.iter % self.T == 0:
            self.grow_tree()
        self.iter += 1
        self.update_weights(self.root)

    def print_tree(self, node, level=0):
        """Print each centroid pre-order, indented one space per depth level."""
        if node:
            print(" " * level + str(node.centroid))
            self.print_tree(node.left, level + 1)
            self.print_tree(node.right, level + 1)
