In [117]:
import torch
import numpy as np

import matplotlib.pyplot as plt
import pandas as pd

from collections import defaultdict
import copy

In [118]:
# cpu_gpu_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# #cpu_gpu_device = torch.device('cpu')
# print(f"Using device: {cpu_gpu_device}")

# * MyHDBScan

## ** MST

In [120]:
class MST:
    """
    Minimum spanning tree of the mutual-reachability graph of a point set.

    Args:
        number_of_nearest_neighbors: k used for the core distance (distance
            to the k-th nearest neighbour, the point itself excluded).
        cpu_gpu_device: torch device (or device string) on which the distance
            matrices are built.
    """

    # Number of sorted edges converted to Python scalars at a time in Kruskal.
    _EDGE_CHUNK = 1 << 16

    def __init__(
            self,
            number_of_nearest_neighbors=5,
            cpu_gpu_device='cpu',
            ):
        # number of nearest neighbors
        self.k = number_of_nearest_neighbors
        self.cpu_gpu_device = cpu_gpu_device

    def mutual_reachability_distance(self, X):
        """
        Compute the mutual reachability distance matrix with PyTorch and
        store it in ``self.mutual_dist_matrix``.

        Args:
            X: 2-D array-like / tensor of points, one row per point
               (``torch.cdist`` requires a 2-D input).
        """
        # as_tensor accepts lists, ndarrays and tensors alike and, unlike
        # torch.tensor, neither warns nor forces a copy for tensor input.
        X = torch.as_tensor(X, dtype=torch.float32, device=self.cpu_gpu_device)
        pairwise_dist = torch.cdist(X, X)  # Pairwise Euclidean distances

        # k + 1 because every point is its own nearest neighbour (distance 0).
        # Clamp to the number of points so topk does not raise on data sets
        # with k or fewer points; the core distance then falls back to the
        # distance to the farthest available point.
        n_neighbors = min(self.k + 1, pairwise_dist.shape[0])
        knn_distances, _ = torch.topk(pairwise_dist, n_neighbors, largest=False)
        # topk returns values sorted ascending here, so the last column is
        # the largest of the selected neighbours: the core distance.
        core_distances = knn_distances[:, -1].unsqueeze(1)

        # d_mreach(a, b) = max(d(a, b), core(a), core(b))
        mutual_dist_matrix = torch.max(pairwise_dist, core_distances)
        self.mutual_dist_matrix = torch.max(mutual_dist_matrix, core_distances.T)

    def minimum_spanning_tree(self):
        """
        Compute the Minimum Spanning Tree with Kruskal's algorithm.

        Reads ``self.mutual_dist_matrix`` and stores the result in
        ``self.mst_edges`` as a list of ``(u, v, weight)`` tuples of Python
        scalars, in order of increasing weight.
        """
        n = self.mutual_dist_matrix.shape[0]
        triu_indices = torch.triu_indices(n, n, 1, device=self.cpu_gpu_device)  # Upper-triangle indices
        edge_weights = self.mutual_dist_matrix[triu_indices[0], triu_indices[1]]  # Edge weights
        sorted_indices = torch.argsort(edge_weights)  # Edges by increasing weight

        # Union-find on plain Python ints: indexing tensors element by element
        # inside the loop is very slow (and forces a host round-trip per
        # comparison when the tensors live on an accelerator).
        parent = list(range(n))

        def find(x):
            root = x
            while parent[root] != root:
                root = parent[root]
            # Path compression: point every visited node straight at the root.
            while parent[x] != root:
                next_x = parent[x]
                parent[x] = root
                x = next_x
            return root

        self.mst_edges = []
        n_edges = sorted_indices.numel()
        # Convert the sorted edges in chunks so that we keep the early exit
        # without materialising all n*(n-1)/2 edges as Python objects.
        for start in range(0, n_edges, self._EDGE_CHUNK):
            batch = sorted_indices[start:start + self._EDGE_CHUNK]
            us = triu_indices[0][batch].tolist()
            vs = triu_indices[1][batch].tolist()
            ws = edge_weights[batch].tolist()

            for u, v, weight in zip(us, vs, ws):
                root_u, root_v = find(u), find(v)

                if root_u != root_v:  # No cycle condition
                    self.mst_edges.append((u, v, weight))
                    parent[root_v] = root_u

                    if len(self.mst_edges) == n - 1:
                        return

    def get_mst_edges(self, X):
        """
        Build the mutual reachability matrix for ``X`` and return its MST
        edges as a list of ``(u, v, weight)`` tuples.
        """
        self.mutual_reachability_distance(X)
        self.minimum_spanning_tree()
        return self.mst_edges

## ** Clustering

In [None]:
class Cluster:
    """
    Plain record describing one cluster of the hierarchy.

    Attributes:
        cluster_id: identifier of the cluster.
        death_size: size value stored for the persistence computation.
        size: number of points counted for the cluster.
        lambda_birth / lambda_death: lambda values bounding the cluster's
            life; ``None`` until assigned.
        children: list of child ``Cluster`` objects.
        nodes_ids: ids of the points held directly by this cluster.
        nodes_lambdas: lambda values recorded for held points.
        persistence: accumulated persistence score.
        is_singleton: flag marking a one-point cluster.
        is_noise: noise flag; every new cluster starts as noise.
    """

    def __init__(
        self,
        cluster_id,
        death_size,
        size,
        lambda_birth=None,
        lambda_death=None,
        children=None,
        nodes_ids=None,
        nodes_lambdas=None,
        persistence=0,
        is_singleton=False
    ):
        # Substitute a fresh list for every omitted container so that no two
        # instances ever share one; a list passed in explicitly (even an
        # empty one) is kept as-is.
        if children is None:
            children = []
        if nodes_ids is None:
            nodes_ids = []
        if nodes_lambdas is None:
            nodes_lambdas = []

        self.cluster_id = cluster_id

        self.death_size, self.size = death_size, size
        self.lambda_birth, self.lambda_death = lambda_birth, lambda_death
        self.children, self.nodes_ids, self.nodes_lambdas = children, nodes_ids, nodes_lambdas
        self.persistence, self.is_singleton = persistence, is_singleton

        # Not derived from any argument: clusters are noise until promoted.
        self.is_noise = True

class DisjointSet:
    """
    Union-find over the input points combined with the cluster hierarchy
    that is built while the MST edges are processed.

    Args:
        n_points: number of points; every point starts as its own singleton
            cluster with id equal to its index.
        min_cluster_size: clusters below this size are absorbed instead of
            being merged into a new parent cluster. A falsy value (``None``
            or ``0``) falls back to ``n_points``.
    """

    def __init__(self, n_points, min_cluster_size):
        self.parents = list(range(n_points))
        # For Union-Find optimization (tree height)
        self.clusters_ranks = [0] * n_points

        self.n_points = n_points
        self.min_cluster_size = min_cluster_size or n_points

        # Initialize singletons clusters
        self.clusters_hierarchy = {}
        for i in range(self.n_points):
            self.clusters_hierarchy[i] = Cluster(
                cluster_id=i,
                death_size=0,
                size=1,
                nodes_ids=[i],
                is_singleton=True
            )

    def find(self, u):
        """Return the union-find root of point ``u`` (with path compression)."""
        if self.parents[u] != u:
            self.parents[u] = self.find(self.parents[u])  # Path compression
        return self.parents[u]

    def union(self, root_u, root_v):
        """
        Union by rank of two distinct roots.

        Returns:
            ``(winner, loser)``: the root that remains a root and the root
            that was attached below it.
        """
        if self.clusters_ranks[root_u] < self.clusters_ranks[root_v]:
            root_u, root_v = root_v, root_u
        self.parents[root_v] = root_u
        # Rank increases only if equal height
        if self.clusters_ranks[root_u] == self.clusters_ranks[root_v]:
            self.clusters_ranks[root_u] += 1
        return root_u, root_v

    def calculate_persistence(self, cluster):
        """
        Add to ``cluster.persistence`` the contribution of the given cluster
        and clear its ``nodes_lambdas``.

        ``cluster.lambda_birth`` must already be set by the caller.
        """
        if cluster.death_size:
            delta_lambda = cluster.lambda_death - cluster.lambda_birth
            # persistence of existing nodes
            cluster.persistence += cluster.death_size * delta_lambda
        # persistence of fall off nodes
        if len(cluster.nodes_lambdas) > 0:
            nodes_lambdas = np.array(cluster.nodes_lambdas)
            cluster.persistence += np.sum(nodes_lambdas - cluster.lambda_birth)
        # Consumed: a later call must not count the same lambdas again.
        cluster.nodes_lambdas = []

    # Flatten hierarchy
    def collapse_cluster(self, cluster):
        """
        Finalize ``cluster`` and decide between it and its children.

        If the cluster's own persistence exceeds the summed persistence of
        its children, the whole subtree is folded into ``cluster`` (children
        removed from ``clusters_hierarchy``) and it is marked as a real
        cluster. Otherwise the children win: the cluster inherits their
        summed persistence and is marked as noise.
        """
        self.calculate_persistence(cluster)
        if cluster.children == []:
            return  # Wait for parent comparison

        children_persistence = sum(ch.persistence for ch in cluster.children)
        # compare persistences
        if cluster.persistence > children_persistence:
            cluster.is_noise = False
            clusters_to_delete = []
            nodes_ids = []

            def collect_nodes_and_ids(clust):
                # Post-order walk: descendants first, then the node's own ids.
                for ch in clust.children:
                    collect_nodes_and_ids(ch)
                    clusters_to_delete.append(ch.cluster_id)
                nodes_ids.extend(clust.nodes_ids)

            collect_nodes_and_ids(cluster)
            cluster.nodes_ids = nodes_ids
            cluster.children = []

            # Delete after traversal
            for cid in clusters_to_delete:
                del self.clusters_hierarchy[cid]
        else:
            cluster.persistence = children_persistence
            cluster.is_noise = True

    # merge 2 singletons
    def merge_the_singletons(self, cluster_u, cluster_v, lambda_val):
        """Turn singleton ``cluster_u`` into a two-point cluster that also holds ``cluster_v``'s point."""
        cluster_u.death_size = 0
        cluster_u.size = 2
        cluster_u.nodes_ids.extend(cluster_v.nodes_ids)
        cluster_u.nodes_lambdas = [lambda_val, lambda_val]
        cluster_u.is_singleton = False
        del self.clusters_hierarchy[cluster_v.cluster_id]

    # cluster_u absorbs cluster_v
    def absorb_cluster(self, cluster_u, cluster_v):
        """
        Move the points of ``cluster_v`` into ``cluster_u`` and drop
        ``cluster_v`` from ``clusters_hierarchy``.

        ``cluster_u`` stops being noise as soon as it directly holds at
        least ``min_cluster_size`` points.
        """
        cluster_u.size += cluster_v.size
        cluster_u.nodes_ids.extend(cluster_v.nodes_ids)
        # NOTE(review): only cluster_v's recorded lambdas are carried over; an
        # absorbed singleton has none, so it adds nothing to the persistence
        # of cluster_u - confirm this is intended.
        cluster_u.nodes_lambdas.extend(cluster_v.nodes_lambdas)
        # Test the size AFTER absorbing: checking before the extend left a
        # cluster that had just reached min_cluster_size flagged as noise.
        if len(cluster_u.nodes_ids) >= self.min_cluster_size:
            cluster_u.is_noise = False
        del self.clusters_hierarchy[cluster_v.cluster_id]

    # cluster_u and cluster_v merged into new cluster
    def merge_clusters(self, cluster_u, cluster_v, lambda_val, cluster_id):
        """
        Close ``cluster_u`` and ``cluster_v`` at ``lambda_val`` and register a
        new parent cluster ``cluster_id`` that has both as children.
        """
        cluster_u.lambda_birth = lambda_val
        self.collapse_cluster(cluster_u)

        cluster_v.lambda_birth = lambda_val
        self.collapse_cluster(cluster_v)

        size = cluster_u.size + cluster_v.size
        children = [cluster_u, cluster_v]

        self.clusters_hierarchy[cluster_id] = Cluster(
            cluster_id=cluster_id,
            death_size=size,
            size=size,
            lambda_death=lambda_val,
            children=children
        )

    def build_hierarchy(self, mst_edges, epsilon=1e-10):
        """
        Build hierarchy from MST
        Args:
            mst_edges: List of [(u, v, distance), ...] sorted by increasing distance
            epsilon: lower bound applied to every distance before it is
                inverted, so that zero-length edges (e.g. duplicate points)
                give a large finite lambda instead of a ZeroDivisionError.
        """
        # Initialize trackers: union-find root -> id of the cluster it represents
        root_clusters_map = {i: i for i in range(self.n_points)}

        # Start assigning new IDs after original points
        cluster_id = self.n_points - 1
        # Process MST edges in order of increasing distance
        for u, v, distance in mst_edges:
            # Lambda = inverse of distance (distance floored at epsilon)
            lambda_val = 1 / max(distance, epsilon)
            root_u, root_v = self.find(u), self.find(v)

            if root_u != root_v:
                # Unite the clusters; root_u is the surviving root afterwards
                root_u, root_v = self.union(root_u, root_v)
                # Get the clusters
                u_id = root_clusters_map[root_u]
                cluster_u = self.clusters_hierarchy[u_id]
                v_id = root_clusters_map[root_v]
                cluster_v = self.clusters_hierarchy[v_id]

                # Absorb or merge
                if cluster_u.size < self.min_cluster_size or cluster_v.size < self.min_cluster_size:
                    if cluster_u.is_singleton:
                        self.merge_the_singletons(
                            cluster_u=cluster_u,
                            cluster_v=cluster_v,
                            lambda_val=lambda_val
                        )
                    else:
                        self.absorb_cluster(
                            cluster_u=cluster_u,
                            cluster_v=cluster_v
                        )
                    # Update cluster map
                    root_clusters_map[root_u] = cluster_u.cluster_id
                else:
                    # Increment cluster ID and update cluster map
                    cluster_id += 1
                    root_clusters_map[root_u] = cluster_id
                    # Merge the clusters
                    self.merge_clusters(
                        cluster_u=cluster_u,
                        cluster_v=cluster_v,
                        lambda_val=lambda_val,
                        cluster_id=cluster_id
                    )

    def get_clusters_dict(self):
        """
        Return ``{cluster_id: nodes_ids}`` for every non-noise cluster in
        ``clusters_hierarchy``; the points of all noise clusters are gathered
        under key ``-1``, which is omitted when there is no noise.
        """
        clusters_dict = {
            -1: []
        }
        for cluster in self.clusters_hierarchy.values():
            if cluster.is_noise:
                clusters_dict[-1].extend(cluster.nodes_ids)
            else:
                clusters_dict[cluster.cluster_id] = cluster.nodes_ids
        if len(clusters_dict[-1]) == 0:
            del clusters_dict[-1]
        return clusters_dict

# * Tests