In [None]:
import torch
import numpy as np

import matplotlib.pyplot as plt

import copy

In [None]:
# Pick the compute device once for the whole notebook: the GPU when PyTorch
# can see one, otherwise the CPU. To force CPU execution, use the else branch.
if torch.cuda.is_available():
    cpu_gpu_device = torch.device('cuda')
else:
    cpu_gpu_device = torch.device('cpu')
print(f"Using device: {cpu_gpu_device}")

Using device: cuda


# * MyHDBScan

## ** MST

In [None]:
class MST:
    """
    Minimum spanning tree of the HDBSCAN mutual-reachability graph.

    The mutual reachability distance between points a and b is
    max(d(a, b), core(a), core(b)), where d is the Euclidean distance and
    core(p) is the distance from p to its k-th nearest neighbour (the point
    itself, at distance 0, is not counted).

    Backends:
        None    -- exact, dense O(n^2) computation with PyTorch (CPU or CUDA).
        'faiss' -- approximate k-NN graph from a FAISS HNSW index.
        'hnsw'  -- approximate k-NN graph from hnswlib (CPU only).

    With the approximate backends only k-NN pairs become edges, so the graph
    can be disconnected; the result is then a minimum spanning forest with
    fewer than n - 1 edges.
    """

    def __init__(
        self,
        number_of_nearest_neighbors=5,
        cpu_gpu_device='cpu',
        backend=None,  # Options: None, 'faiss', 'hnsw'
        faiss_M=32,
        faiss_efConstruction=64,
        faiss_efSearch=64,
        hnsw_M=16,
        hnsw_efConstruction=200,
        hnsw_efSearch=50,
    ):
        """
        Args:
            number_of_nearest_neighbors: k used for core distances (>= 2).
            cpu_gpu_device: 'cpu', 'cuda', 'cuda:N' or an equivalent torch.device.
            backend: None (exact), 'faiss' or 'hnsw'.
            faiss_M, faiss_efConstruction, faiss_efSearch: FAISS HNSW settings.
            hnsw_M, hnsw_efConstruction, hnsw_efSearch: hnswlib settings
                (defaults reproduce the previously hard-coded 16 / 200 / 50).

        Raises:
            ValueError: on invalid k, unavailable CUDA device, unknown backend,
                or unsupported FAISS parameter values.
            ImportError: if the selected backend's library is not installed.
        """
        if number_of_nearest_neighbors < 2:
            raise ValueError("number_of_nearest_neighbors must be at least 2")

        self.k = number_of_nearest_neighbors

        # Validate device (accept both strings and torch.device objects)
        self.cuda_flag = str(cpu_gpu_device).startswith('cuda')
        if self.cuda_flag:
            if not torch.cuda.is_available():
                raise ValueError("CUDA was requested but is not available.")
            try:
                torch.cuda.get_device_properties(cpu_gpu_device)
            except (AssertionError, RuntimeError):
                raise ValueError(f"CUDA device '{cpu_gpu_device}' is not available.")

        self.cpu_gpu_device = cpu_gpu_device

        if backend not in [None, 'faiss', 'hnsw']:
            raise ValueError("backend must be one of: None, 'faiss', 'hnsw'")

        self.backend = backend
        # Import here only to fail fast when the library is missing; the
        # backend methods import it again because these names are local.
        if self.backend == 'faiss':
            import faiss  # noqa: F401
        elif self.backend == 'hnsw':
            if self.cuda_flag:
                raise ValueError("HNSWlib does not support CUDA")
            import hnswlib  # noqa: F401

        self.faiss_M = faiss_M
        if self.faiss_M not in [16, 32, 48]:
            raise ValueError("faiss_M must be one of: 16, 32, 48")

        self.faiss_efConstruction = faiss_efConstruction
        if self.faiss_efConstruction not in [32, 64, 128]:
            raise ValueError("faiss_efConstruction must be one of: 32, 64, 128")

        self.faiss_efSearch = faiss_efSearch
        if self.faiss_efSearch not in [32, 64, 128, 256]:
            raise ValueError("faiss_efSearch must be one of: 32, 64, 128, 256")

        self.hnsw_M = hnsw_M
        self.hnsw_efConstruction = hnsw_efConstruction
        self.hnsw_efSearch = hnsw_efSearch

    def find(self, x, parent):
        """
        Return the root of ``x`` in the union-find array ``parent``.

        Uses path halving (each visited node is re-pointed at its grandparent),
        which keeps later lookups short. Works on lists and 1-D tensors.
        """
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    def _kruskal(self, sorted_edges, n):
        """
        Kruskal's algorithm over edges already sorted by increasing weight.

        Args:
            sorted_edges: iterable of (u, v, weight) with int endpoints.
            n: number of vertices.

        Returns:
            List of accepted (u, v, weight) tuples, in increasing weight order.
            Stops as soon as n - 1 edges are accepted.
        """
        mst_edges = []
        if n < 2:
            return mst_edges
        parent = list(range(n))
        for u, v, weight in sorted_edges:
            root_u, root_v = self.find(u, parent), self.find(v, parent)
            if root_u != root_v:  # No cycle condition
                mst_edges.append((u, v, weight))
                parent[root_v] = root_u
                if len(mst_edges) == n - 1:
                    break
        return mst_edges

    def _mutual_reachability_edges(self, knn_indices, knn_distances):
        """
        Sparse mutual-reachability edge list from a k-NN query, sorted by weight.

        Args:
            knn_indices: (n, k + 1) neighbour ids; column 0 is normally the point
                itself. Negative ids (FAISS padding for missing results) are dropped.
            knn_distances: (n, k + 1) Euclidean (not squared) distances.

        Returns:
            List of (i, j, weight) with Python int/float values, stably sorted by
            weight so ties keep their generation (row-major) order.
        """
        knn_indices = np.asarray(knn_indices).astype(np.int64)
        knn_distances = np.asarray(knn_distances)
        n_rows, n_cols = knn_indices.shape

        # Core distance = distance to the k-th nearest neighbour (last column)
        core_distances = knn_distances[:, -1]

        rows = np.repeat(np.arange(n_rows, dtype=np.int64), n_cols - 1)
        cols = knn_indices[:, 1:].reshape(-1)  # skip self (column 0)
        dists = knn_distances[:, 1:].reshape(-1)

        keep = (cols >= 0) & (cols != rows)
        rows, cols, dists = rows[keep], cols[keep], dists[keep]

        weights = np.maximum(dists, np.maximum(core_distances[rows], core_distances[cols]))
        order = np.argsort(weights, kind='stable')
        return list(zip(rows[order].tolist(), cols[order].tolist(), weights[order].tolist()))

    def mst_none(self, X, n):
        """
        Exact MST of the mutual-reachability graph using dense PyTorch ops.

        Builds the full n x n distance matrix on X's device (O(n^2) memory),
        then runs Kruskal over the upper-triangle edges in increasing order.

        Args:
            X: float tensor of shape (n, d).
            n: number of rows of X.

        Sets:
            self.mst_edges: list of (u, v, weight), u < v, by increasing weight.
        """
        pairwise_dist = torch.cdist(X, X)

        # k + 1 because every point's zero self-distance is among its nearest
        knn_distances, _ = torch.topk(pairwise_dist, self.k + 1, largest=False)
        # Core distance = last (k-th non-self) column; (n, 1) for broadcasting
        core_distances = knn_distances[:, -1].unsqueeze(1)

        mutual_dist_matrix = torch.maximum(
            torch.maximum(pairwise_dist, core_distances), core_distances.T
        )

        triu_indices = torch.triu_indices(n, n, 1, device=X.device)
        edge_weights = mutual_dist_matrix[triu_indices[0], triu_indices[1]]
        sorted_indices = torch.argsort(edge_weights)

        # Transfer sorted edges to Python in chunks: avoids one device sync per
        # edge while never materialising all n^2/2 edges as Python objects.
        chunk_size = max(4 * n, 1024)

        def sorted_edges():
            for start in range(0, sorted_indices.numel(), chunk_size):
                idx = sorted_indices[start:start + chunk_size]
                yield from zip(
                    triu_indices[0, idx].tolist(),
                    triu_indices[1, idx].tolist(),
                    edge_weights[idx].tolist(),
                )

        self.mst_edges = self._kruskal(sorted_edges(), n)

    def mst_faiss(self, X, n, d):
        """
        Approximate MST from a FAISS HNSW k-NN graph.

        Args:
            X: float32 array of shape (n, d).
            n, d: dimensions of X.

        Sets:
            self.mst_edges: list of (u, v, weight) by increasing weight.
        """
        import faiss

        X = np.ascontiguousarray(X, dtype=np.float32)

        index = faiss.IndexHNSWFlat(d, self.faiss_M)
        index.hnsw.efConstruction = self.faiss_efConstruction
        index.hnsw.efSearch = self.faiss_efSearch

        if self.cuda_flag:
            # NOTE(review): FAISS GPU cloning may not support HNSW indexes in
            # every FAISS version — confirm against the installed build.
            res = faiss.StandardGpuResources()
            device_parts = str(self.cpu_gpu_device).split(':')
            device_id = int(device_parts[1]) if len(device_parts) > 1 else 0
            index = faiss.index_cpu_to_gpu(res, device_id, index)

        index.add(X)

        squared_distances, knn_indices = index.search(X, self.k + 1)
        # FAISS L2 indexes report squared distances; take the root so the
        # weights match the exact and hnswlib backends.
        knn_distances = np.sqrt(np.maximum(squared_distances, 0))

        edges = self._mutual_reachability_edges(knn_indices, knn_distances)
        self.mst_edges = self._kruskal(edges, n)

    def mst_hnsw(self, X, n, d):
        """
        Approximate MST from an hnswlib k-NN graph (CPU only).

        Args:
            X: float32 array of shape (n, d).
            n, d: dimensions of X.

        Sets:
            self.mst_edges: list of (u, v, weight) by increasing weight.
        """
        import hnswlib

        index = hnswlib.Index(space='l2', dim=d)
        index.init_index(max_elements=n, ef_construction=self.hnsw_efConstruction, M=self.hnsw_M)
        index.add_items(X)
        # hnswlib cannot return more neighbours than its search ef
        index.set_ef(max(self.hnsw_efSearch, self.k + 1))

        # Query k+1 neighbours (including self)
        knn_indices, squared_distances = index.knn_query(X, k=self.k + 1)
        # 'l2' space returns squared L2; convert to Euclidean distances
        knn_distances = np.sqrt(squared_distances)

        edges = self._mutual_reachability_edges(knn_indices, knn_distances)
        self.mst_edges = self._kruskal(edges, n)

    def minimum_spanning_tree(self, X):
        """
        Validate X, convert it for the selected backend and build the MST.

        Args:
            X: 2-D torch.Tensor or numpy.ndarray of shape (n_samples, n_features).

        Raises:
            ValueError: if X is not a tensor/array or is not 2-D.
        """
        # Type check first so non-array inputs get a ValueError, not AttributeError
        if isinstance(X, torch.Tensor):
            X_type = 'tensor'
        elif isinstance(X, np.ndarray):
            X_type = 'numpy'
        else:
            raise ValueError("Input data X must be a PyTorch tensor or NumPy array")

        if X.ndim != 2:
            raise ValueError("Input data X must be 2D: (n_samples, n_features)")

        n, d = X.shape
        if self.backend is None:
            if X_type == 'tensor':
                X = X.detach().to(dtype=torch.float32, device=self.cpu_gpu_device)
            else:
                X = torch.as_tensor(X, dtype=torch.float32, device=self.cpu_gpu_device)
            self.mst_none(X, n)
        else:
            if X_type == 'tensor':
                X = X.detach().cpu().numpy()
            # ANN libraries expect contiguous float32 input
            X = np.ascontiguousarray(X, dtype=np.float32)
            if self.backend == 'faiss':
                self.mst_faiss(X, n, d)
            elif self.backend == 'hnsw':
                self.mst_hnsw(X, n, d)

    def get_mst_edges(self, X):
        """Build the MST for X and return it as a list of (u, v, weight)."""
        self.minimum_spanning_tree(X)
        return self.mst_edges

## ** Clustering

In [None]:
class Cluster:
    """
    One node of the cluster hierarchy.

    Plain data holder: every constructor argument is stored as an attribute of
    the same name. List-valued arguments left as None become fresh empty lists,
    so instances never share mutable state. Every cluster starts flagged as
    noise (``is_noise = True``); the hierarchy builder decides otherwise.
    """

    def __init__(
        self,
        cluster_id,
        death_size,
        size,
        lambda_birth=None,
        lambda_death=None,
        children=None,
        nodes_ids=None,
        nodes_lambdas=None,
        persistence=0,
        is_singleton=False
    ):
        # Identity and point counts
        self.cluster_id = cluster_id
        self.death_size = death_size
        self.size = size

        # Lambda levels assigned by the hierarchy builder (None until set)
        self.lambda_birth = lambda_birth
        self.lambda_death = lambda_death

        # Containers: a caller-supplied list is kept as-is, otherwise a new one
        self.children = [] if children is None else children
        self.nodes_ids = [] if nodes_ids is None else nodes_ids
        self.nodes_lambdas = [] if nodes_lambdas is None else nodes_lambdas

        self.persistence = persistence
        self.is_singleton = is_singleton

        self.is_noise = True

class DisjointSet:
    """
    Union-find over MST edges that builds and condenses the cluster hierarchy.

    Edges are expected in increasing distance order, i.e. decreasing
    lambda = 1 / (distance + epsilon). Components with fewer than
    ``min_cluster_size`` points are not clusters: when one joins a cluster, its
    points are recorded as leaving that cluster at the current lambda. When two
    components that both have at least ``min_cluster_size`` points meet, a new
    parent cluster is created, and each child is condensed by comparing its
    persistence with the summed persistence of its own children.
    """

    def __init__(
            self,
            n_points,
            min_cluster_size,
            epsilon=1e-10
    ):
        self.parents = list(range(n_points))
        # Union-by-rank bookkeeping (upper bound on each tree's height)
        self.clusters_ranks = [0] * n_points

        self.n_points = n_points
        self.min_cluster_size = min_cluster_size
        self.epsilon = epsilon

        # Every point starts as its own singleton cluster, id == point index
        self.clusters_hierarchy = {
            i: Cluster(
                cluster_id=i,
                death_size=0,
                size=1,
                nodes_ids=[i],
                is_singleton=True
            )
            for i in range(self.n_points)
        }

    def find(self, u):
        """Return the root of ``u``'s set, re-pointing every visited node at it."""
        root = int(u)
        while self.parents[root] != root:
            root = self.parents[root]
        # Iterative path compression: no recursion-depth limit on long chains
        node = int(u)
        while node != root:
            next_node = self.parents[node]
            self.parents[node] = root
            node = next_node
        return root

    def union(self, root_u, root_v):
        """
        Join two roots by rank.

        Returns:
            (new_root, attached_root) -- the root that remains and the one
            placed under it.
        """
        if self.clusters_ranks[root_u] < self.clusters_ranks[root_v]:
            root_u, root_v = root_v, root_u
        self.parents[root_v] = root_u
        # Rank increases only if equal height
        if self.clusters_ranks[root_u] == self.clusters_ranks[root_v]:
            self.clusters_ranks[root_u] += 1
        return root_u, root_v

    def calculate_persistence(self, cluster):
        """
        Add ``cluster``'s stability to ``cluster.persistence``.

        Points still present when the cluster splits (``death_size`` of them)
        each contribute ``lambda_death - lambda_birth``. Points that fell out
        earlier each contribute ``their_lambda - lambda_birth``. The per-point
        lambdas are cleared afterwards.
        """
        if cluster.death_size:
            delta_lambda = cluster.lambda_death - cluster.lambda_birth
            cluster.persistence += cluster.death_size * delta_lambda
        if len(cluster.nodes_lambdas) > 0:
            nodes_lambdas = np.asarray(cluster.nodes_lambdas, dtype=float)
            cluster.persistence += float(np.sum(nodes_lambdas - cluster.lambda_birth))
        cluster.nodes_lambdas = []

    def collapse_cluster(self, cluster):
        """
        Decide whether ``cluster`` or its descendants are kept.

        Only called on clusters with at least ``min_cluster_size`` points.
        - A leaf is marked as selected; an ancestor may later replace it.
        - An inner cluster that is more persistent than its children absorbs
          the whole subtree: its nodes are collected and its descendants are
          removed from the hierarchy.
        - Otherwise the children are kept. The cluster inherits their summed
          persistence and is marked as noise.
        """
        self.calculate_persistence(cluster)
        if not cluster.children:
            cluster.is_noise = False
            return

        children_persistence = sum(ch.persistence for ch in cluster.children)
        if cluster.persistence > children_persistence:
            cluster.is_noise = False
            clusters_to_delete = []
            nodes_ids = []

            # Iterative post-order walk (children before parent, in order)
            stack = [(cluster, False)]
            while stack:
                node, expanded = stack.pop()
                if expanded:
                    nodes_ids.extend(node.nodes_ids)
                    if node is not cluster:
                        clusters_to_delete.append(node.cluster_id)
                else:
                    stack.append((node, True))
                    for child in reversed(node.children):
                        stack.append((child, False))

            cluster.nodes_ids = nodes_ids
            cluster.children = []

            # Delete after traversal
            for cid in clusters_to_delete:
                del self.clusters_hierarchy[cid]
        else:
            cluster.persistence = children_persistence
            cluster.is_noise = True

    def merge_the_singletons(self, cluster_u, cluster_v, lambda_val):
        """
        Merge two singleton clusters into ``cluster_u`` and drop ``cluster_v``.

        If the resulting pair already reaches ``min_cluster_size``, it is a
        cluster whose two points leave it at ``lambda_val``.
        """
        cluster_u.death_size = 0
        cluster_u.size = cluster_u.size + cluster_v.size
        cluster_u.nodes_ids.extend(cluster_v.nodes_ids)
        cluster_u.is_singleton = False
        if cluster_u.size >= self.min_cluster_size:
            cluster_u.nodes_lambdas = [lambda_val] * cluster_u.size
            cluster_u.is_noise = False
        else:
            # Too small to be a cluster: no stability is accumulated yet
            cluster_u.nodes_lambdas = []
        del self.clusters_hierarchy[cluster_v.cluster_id]

    def absorb_cluster(self, cluster_u, cluster_v, lambda_val=None):
        """
        Let ``cluster_u`` absorb the (small) ``cluster_v`` and drop ``cluster_v``.

        Args:
            cluster_u: the surviving cluster.
            cluster_v: the absorbed component.
            lambda_val: merge level. If ``cluster_u`` was already a cluster,
                ``cluster_v``'s points fall out of it at this lambda. If the
                merge makes ``cluster_u`` reach ``min_cluster_size``, all of
                its points are recorded at this lambda. If None, ``cluster_v``'s
                recorded lambdas are appended instead (previous behaviour).
        """
        was_cluster = cluster_u.size >= self.min_cluster_size

        cluster_u.size += cluster_v.size
        cluster_u.nodes_ids.extend(cluster_v.nodes_ids)
        cluster_u.is_singleton = False

        if lambda_val is None:
            cluster_u.nodes_lambdas.extend(cluster_v.nodes_lambdas)
        elif was_cluster:
            cluster_u.nodes_lambdas.extend([lambda_val] * cluster_v.size)
        elif cluster_u.size >= self.min_cluster_size:
            # Two small components just formed a cluster; seen from the top
            # down, all of its points drop out together at this lambda.
            cluster_u.nodes_lambdas = [lambda_val] * cluster_u.size

        # Decide on the size AFTER the merge. Clusters with children are
        # decided by collapse_cluster instead.
        if cluster_u.size >= self.min_cluster_size and not cluster_u.children:
            cluster_u.is_noise = False

        del self.clusters_hierarchy[cluster_v.cluster_id]

    def merge_clusters(self, cluster_u, cluster_v, lambda_val, cluster_id):
        """
        Create parent cluster ``cluster_id`` over two clusters that are each at
        least ``min_cluster_size``. Both children get ``lambda_birth = lambda_val``
        and are condensed first.
        """
        cluster_u.lambda_birth = lambda_val
        self.collapse_cluster(cluster_u)

        cluster_v.lambda_birth = lambda_val
        self.collapse_cluster(cluster_v)

        size = cluster_u.size + cluster_v.size
        children = [cluster_u, cluster_v]

        self.clusters_hierarchy[cluster_id] = Cluster(
            cluster_id=cluster_id,
            death_size=size,
            size=size,
            lambda_death=lambda_val,
            children=children
        )

    def build_hierarchy(self, mst_edges):
        """
        Build and condense the hierarchy from MST edges.

        Args:
            mst_edges: iterable of (u, v, distance) sorted by increasing distance.

        The last cluster left at the top is not condensed here; its children
        keep their own selection.
        """
        # Maps each union-find root to the id of the cluster it currently holds
        root_clusters_map = {i: i for i in range(self.n_points)}

        # Ids of merge-created clusters start right after the point ids
        cluster_id = self.n_points - 1

        for u, v, distance in mst_edges:
            # Lambda = inverse of distance (epsilon guards zero distances)
            lambda_val = 1 / (distance + self.epsilon)
            root_u, root_v = self.find(u), self.find(v)
            if root_u == root_v:
                # Would close a cycle; does not happen for a true spanning tree
                continue

            root_u, root_v = self.union(root_u, root_v)
            cluster_u = self.clusters_hierarchy[root_clusters_map.pop(root_u)]
            cluster_v = self.clusters_hierarchy[root_clusters_map.pop(root_v)]

            if cluster_u.size < self.min_cluster_size or cluster_v.size < self.min_cluster_size:
                if cluster_u.is_singleton and cluster_v.is_singleton:
                    self.merge_the_singletons(
                        cluster_u=cluster_u,
                        cluster_v=cluster_v,
                        lambda_val=lambda_val
                    )
                else:
                    # The union-find root is picked by rank, not size: make the
                    # larger component the absorber so an established cluster
                    # (and its sub-hierarchy) is never discarded.
                    if cluster_v.size > cluster_u.size:
                        cluster_u, cluster_v = cluster_v, cluster_u
                    self.absorb_cluster(
                        cluster_u=cluster_u,
                        cluster_v=cluster_v,
                        lambda_val=lambda_val
                    )
                root_clusters_map[root_u] = cluster_u.cluster_id
            else:
                cluster_id += 1
                root_clusters_map[root_u] = cluster_id
                self.merge_clusters(
                    cluster_u=cluster_u,
                    cluster_v=cluster_v,
                    lambda_val=lambda_val,
                    cluster_id=cluster_id
                )

    def get_clusters_dict(self):
        """
        Return {cluster_id: [point ids]} for selected clusters.

        Noise points are grouped under key -1, which is omitted when empty.
        """
        clusters_dict = {
            -1: []
        }
        for cluster in self.clusters_hierarchy.values():
            if cluster.is_noise:
                clusters_dict[-1].extend(cluster.nodes_ids)
            else:
                clusters_dict[cluster.cluster_id] = cluster.nodes_ids
        if len(clusters_dict[-1]) == 0:
            del clusters_dict[-1]
        return clusters_dict

    def get_labels(self):
        """Return an int array of length n_points: cluster id per point, -1 for noise."""
        labels = np.full(self.n_points, -1)
        for cluster_id, cluster in self.clusters_hierarchy.items():
            # Explicit integer index: ids may mix Python ints and numpy ints
            idx = np.asarray(cluster.nodes_ids, dtype=np.intp)
            labels[idx] = -1 if cluster.is_noise else cluster_id
        return labels


## ** main class

In [None]:
class HDBSCAN:
    """
    HDBSCAN front-end.

    ``fit`` does three things:
    - ``MST`` builds the mutual-reachability minimum spanning tree.
    - ``DisjointSet`` condenses that tree into a cluster hierarchy.
    - The per-point labels are stored in ``labels_``, with -1 marking noise.

    Args:
        min_cluster_size: smallest group of points treated as a cluster.
        number_of_nearest_neighbors: k for core distances. Falls back to
            ``min_cluster_size`` when not given (or falsy) and must not
            exceed it.
        cpu_gpu_device, backend, faiss_M, faiss_efConstruction, faiss_efSearch:
            forwarded unchanged to ``MST``.
        epsilon: forwarded to ``DisjointSet``.

    Raises:
        ValueError: if number_of_nearest_neighbors > min_cluster_size, or if
            ``MST`` rejects the settings.
    """

    def __init__(
            self,
            min_cluster_size=5,
            number_of_nearest_neighbors=None,
            cpu_gpu_device='cpu',
            backend=None,  # Options: None, 'faiss', 'hnsw'
            faiss_M=32,
            faiss_efConstruction=64,
            faiss_efSearch=64,
            epsilon=1e-10
    ):
        k = number_of_nearest_neighbors if number_of_nearest_neighbors else min_cluster_size
        if k > min_cluster_size:
            raise ValueError("number_of_nearest_neighbors should be <= min_cluster_size")

        mst_settings = dict(
            number_of_nearest_neighbors=k,
            cpu_gpu_device=cpu_gpu_device,
            backend=backend,
            faiss_M=faiss_M,
            faiss_efConstruction=faiss_efConstruction,
            faiss_efSearch=faiss_efSearch,
        )
        self.mst_agent = MST(**mst_settings)

        self.min_cluster_size = min_cluster_size
        self.epsilon = epsilon

    def fit(self, X):
        """
        Cluster X and store the per-point labels in ``self.labels_``.

        Raises:
            ValueError: if X has no more rows than ``min_cluster_size``.
        """
        n_points = X.shape[0]
        if not n_points > self.min_cluster_size:
            raise ValueError("Amount of data points should be > min_cluster_size")

        edges = self.mst_agent.get_mst_edges(X)

        hierarchy = DisjointSet(
            n_points=n_points,
            min_cluster_size=self.min_cluster_size,
            epsilon=self.epsilon
        )
        hierarchy.build_hierarchy(edges)
        self.labels_ = hierarchy.get_labels()
        return self

    def fit_predict(self, X):
        """Fit on X and return ``labels_``."""
        return self.fit(X).labels_