In [None]:
import torch
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # registers 3D projection
import hdbscan
from collections import Counter

# --- CONFIG ---
# Directory that holds the per-batch embedding tensors saved as ``batch_*.pt``.
umap_dir = Path("../Results/EmbeddingDataUMAP")

# --- LOAD ALL BATCHES ---
# NOTE(review): sorted() orders paths lexicographically, so "batch_10" comes
# before "batch_2". Row order does not matter for the clustering below, but
# confirm it if rows are ever joined back to per-sample metadata.
batch_files = sorted(umap_dir.glob("batch_*.pt"))
if not batch_files:
    # Fail early with an actionable message; otherwise np.vstack([]) raises an
    # opaque "need at least one array to concatenate" ValueError.
    raise FileNotFoundError(f"No batch_*.pt files found in {umap_dir.resolve()}")

all_embeds = []

for f in batch_files:
    print(f"Loading {f.name} ...")
    # map_location="cpu" lets tensors saved from a GPU load on any machine.
    batch_tensor = torch.load(f, map_location="cpu", weights_only=True)
    # .numpy() rejects tensors that require grad or live off-CPU, so detach
    # and move to CPU first (both are no-ops for plain CPU tensors).
    all_embeds.append(batch_tensor.detach().cpu().numpy())

# Stack all batches row-wise; the 3D plot further down indexes columns 0..2,
# so the expected shape is (total_points, 3).
all_embeds = np.vstack(all_embeds)
print("Final shape:", all_embeds.shape)

# --- RUN HDBSCAN ---
# Cluster the stacked embeddings. Points HDBSCAN cannot assign to any cluster
# receive the label -1 (noise), which is filtered out before plotting.
hdbscan_params = {
    "min_cluster_size": 30,  # smaller clusters allowed
    "min_samples": 10,       # fewer points needed to avoid noise
}
clusterer = hdbscan.HDBSCAN(**hdbscan_params)
cluster_labels = clusterer.fit_predict(all_embeds)

# Number of distinct labels; this total includes -1 whenever noise is present.
n_labels = np.unique(cluster_labels).size
print(f"Found {n_labels} clusters (including -1 noise)")

# --- COUNT POINTS IN EACH CLUSTER ---
# Tally how many points received each label, then report the tallies in
# ascending label order (so the noise label -1, if present, is listed first).
counts = Counter(cluster_labels)
print("Cluster counts:")
for cid in sorted(counts):
    if cid == -1:
        name = "Noise"
    else:
        name = f"Cluster {cid}"
    print(f"{name}: {counts[cid]} points")

# --- FILTER OUT NOISE ---
# Boolean mask that is True for every point assigned to a real cluster.
mask = cluster_labels != -1
embeds_filtered, labels_filtered = all_embeds[mask], cluster_labels[mask]

# --- 3D PLOT WITH COLORS (NO NOISE) ---
fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection="3d")

# Split the first three embedding columns into per-axis coordinate arrays.
xs, ys, zs = (embeds_filtered[:, col] for col in range(3))

# Colour each point by its cluster id using a qualitative colormap.
scatter_style = {
    "c": labels_filtered,  # now only clusters
    "cmap": "tab20",       # distinct colors
    "s": 5,
    "alpha": 0.8,
}
scatter = ax.scatter(xs, ys, zs, **scatter_style)

# Label the three axes, then title and display the figure.
axis_labels = (
    (ax.set_xlabel, "UMAP-1"),
    (ax.set_ylabel, "UMAP-2"),
    (ax.set_zlabel, "UMAP-3"),
)
for set_label, text in axis_labels:
    set_label(text)
ax.set_title("3D UMAP Embeddings (HDBSCAN Clusters, Noise Removed)")
plt.show()
