In [None]:
def cosine_distance(x, centroids, eps=1e-8):
    """
    Cosine distance between every row of `x` and every row of `centroids`.

    Args:
        x (torch.Tensor): Data points [N, D]
        centroids (torch.Tensor): Cluster centers [K, D]
        eps (float): Added to each row norm so zero rows do not divide by zero

    Returns:
        torch.Tensor: [N, K] matrix of 1 - cosine similarity
    """
    def _unit_rows(matrix):
        # Scale each row to (approximately) unit length.
        return matrix / (matrix.norm(dim=1, keepdim=True) + eps)

    # Dot products of unit rows are the cosine similarities.
    similarity = torch.mm(_unit_rows(x), _unit_rows(centroids).t())  # [N, K]
    return 1.0 - similarity

def weighted_cosine_distance(x, centroids, weights=None, eps=1e-8):
    """
    Compute cosine distance between points and centroids with per-dimension weights.

    Args:
        x (torch.Tensor): Data points [N, D]
        centroids (torch.Tensor): Cluster centers [K, D]
        weights (torch.Tensor | None): Feature weights [D], larger magnitude = more
            important. When None, all dimensions are weighted equally (plain
            cosine distance).
        eps (float): Small constant for numerical stability

    Returns:
        distances (torch.Tensor): [N, K] cosine distance (1 - weighted cosine similarity)
    """
    if weights is None:
        # No weights supplied: fall back to the unweighted vectors instead of
        # failing on `None.to(...)`.
        xw, cw = x, centroids
    else:
        # Rescale the weight vector to unit length. Only the relative magnitude
        # of the weights matters: the overall scale cancels in the cosine, and
        # the sign of a weight cancels too because it is applied to both sides.
        weights = weights.to(x.device)
        weights = weights / (weights.norm() + eps)

        # Apply weights
        xw = x * weights
        cw = centroids * weights

    # Normalize weighted vectors to unit length
    xw_norm = xw / (xw.norm(dim=1, keepdim=True) + eps)
    cw_norm = cw / (cw.norm(dim=1, keepdim=True) + eps)

    # Compute weighted cosine similarity
    sim = torch.mm(xw_norm, cw_norm.t())  # [N, K]

    # Convert to distance
    return 1.0 - sim

# === Weighted cosine distance per cluster ===
def cluster_weighted_cosine_distance(x, centroids, weights, eps=1e-8):
    """
    Weighted cosine distance where each cluster has its own weight vector.

    For point n and cluster k the similarity is the cosine between
    ``x[n] * weights[k]`` and ``centroids[k] * weights[k]``. It is computed
    with matrix products so that no [N, K, D] intermediate is materialized:

        dot[n, k]     = sum_d x[n, d] * centroids[k, d] * weights[k, d]**2
        |x_w|[n, k]   = sqrt(sum_d x[n, d]**2 * weights[k, d]**2)
        |c_w|[k]      = ||centroids[k] * weights[k]||

    Args:
        x (torch.Tensor): Data points [N, D]
        centroids (torch.Tensor): Cluster centers [K, D]
        weights (torch.Tensor): Per-cluster weights [K, D]; any shape that
            expands to [K, D] (e.g. a shared [D] vector) is also accepted
        eps (float): Added to each norm for numerical stability
    Returns:
        distances [N, K]
    """
    N, D = x.shape
    K = centroids.shape[0]

    # Work in the common promoted dtype, matching what broadcasting
    # arithmetic would produce for mixed-precision inputs (matmul itself
    # requires both operands to share a dtype).
    dtype = torch.promote_types(torch.promote_types(x.dtype, centroids.dtype), weights.dtype)
    x = x.to(dtype)
    centroids = centroids.to(dtype)
    weights = weights.to(dtype).expand(K, D)  # also validates the weight shape

    w_sq = weights * weights  # [K, D]

    # Weighted dot product between every point and every centroid.
    dot = torch.mm(x, (centroids * w_sq).t())  # [N, K]

    # Norm of each point under each cluster's weights, and of each weighted centroid.
    x_norm = torch.mm(x * x, w_sq.t()).sqrt() + eps  # [N, K]
    c_norm = (centroids * weights).norm(dim=-1) + eps  # [K]

    sim = dot / (x_norm * c_norm.unsqueeze(0))

    return 1.0 - sim  # cosine distance

In [None]:
import torch

# === Basic K-Means using custom distance ===
def kmeans(x, num_clusters, distance_fn, num_iters=100, tol=1e-4, verbose=False):
    """
    Lloyd-style K-Means driven by a caller-supplied distance function.

    Args:
        x (torch.Tensor): Data points [n, d]
        num_clusters (int): Number of centroids
        distance_fn (callable): Called as ``distance_fn(x, centroids)``; its
            result is arg-minimized over dim 1 to pick each point's cluster
        num_iters (int): Maximum number of update rounds
        tol (float): Stop once the summed centroid movement falls below this
        verbose (bool): Print the centroid shift every round

    Returns:
        (cluster_assignments, centroids): labels from the last round and the
        centroids they were computed against, unless the loop ran out of
        rounds, in which case the centroids are the freshly updated ones.
    """
    n, _ = x.shape
    dev = x.device

    # Seed the centroids with distinct randomly chosen data points.
    seed_rows = torch.randperm(n, device=dev)[:num_clusters]
    centroids = x[seed_rows]

    for step in range(num_iters):
        cluster_assignments = torch.argmin(distance_fn(x, centroids), dim=1)

        updated = torch.zeros_like(centroids)
        for k in range(num_clusters):
            members = cluster_assignments == k
            if not members.any():
                # Empty cluster: re-seed it from a random data point.
                updated[k] = x[torch.randint(0, n, (1,), device=dev)]
                continue
            updated[k] = x[members].mean(dim=0)

        # Total distance travelled by all centroids this round.
        shift = torch.norm(centroids - updated, dim=1).sum()
        if verbose:
            print(f"Iteration {step+1}/{num_iters} | centroid shift = {shift.item():.6f}")
        if shift < tol:
            break
        centroids = updated

    return cluster_assignments, centroids

In [None]:
# === Compute per-cluster per-dimension variance ===
def cluster_variances(x, cluster_assignments, num_clusters):
    """
    Compute the per-dimension (population) variance of each cluster.

    Args:
        x (torch.Tensor): Data points [n, d]
        cluster_assignments (torch.Tensor): Cluster label per point [n]
        num_clusters (int): Number of clusters K

    Returns:
        torch.Tensor: [K, d] variances. Clusters with no members get a row of
        ones as a neutral fallback.
    """
    n, d = x.shape
    device = x.device
    variances = torch.zeros(num_clusters, d, device=device)

    for k in range(num_clusters):
        mask = (cluster_assignments == k)
        if mask.any():
            # unbiased=False so a single-member cluster yields 0 rather than NaN
            variances[k] = x[mask].var(dim=0, unbiased=False)
        else:
            variances[k] = torch.ones(d, device=device)  # fallback

    return variances

In [None]:
def weighted_kmeans(x, num_clusters, weights, centroids=None, distance_fn=cluster_weighted_cosine_distance, max_iters=100, tol=1e-4, verbose=True):
    """
    K-Means where assignment uses ``distance_fn(x, centroids, weights=weights)``.

    Args:
        x (torch.Tensor): Data points [N, D]
        num_clusters (int): Number of clusters K
        weights (torch.Tensor): Passed through to ``distance_fn`` unchanged
        centroids (torch.Tensor | None): Starting centroids; when None, K
            random data points are cloned as the initial centroids
        distance_fn (callable): Distance function accepting a ``weights`` keyword
        max_iters (int): Maximum number of rounds
        tol (float): Centroid-shift threshold that stops the loop
        verbose (bool): Print progress each round

    Returns:
        (assignments, centroids, shift). ``shift`` is the summed centroid
        movement of the last update that was computed, except that when the
        shift falls below ``tol`` at a loop index greater than 2 it is
        replaced by the integer 1 before returning.
    """
    num_points, _ = x.shape

    if centroids is None:
        seed_rows = torch.randperm(num_points, device=x.device)[:num_clusters]
        centroids = x[seed_rows].clone()

    last_labels = None

    for step in range(max_iters):
        # Assign every point to its nearest centroid under the weighted distance.
        assignments = distance_fn(x, centroids, weights=weights).argmin(dim=1)

        # Identical labels two rounds in a row means the partition is stable.
        labels_unchanged = last_labels is not None and torch.equal(assignments, last_labels)
        if labels_unchanged:
            if verbose:
                print(f"Converged at iteration {step}")
            break
        last_labels = assignments.clone()

        # Recompute each centroid as the mean of its members.
        updated = torch.zeros_like(centroids)
        for k in range(num_clusters):
            members = assignments == k
            if not members.any():
                # Empty cluster: re-seed it from a random data point.
                updated[k] = x[torch.randint(0, num_points, (1,), device=x.device)]
                continue
            updated[k] = x[members].mean(dim=0)

        shift = torch.norm(centroids - updated, dim=1).sum()
        if verbose:
            print(f"Iteration {step+1}/{max_iters} | centroid shift = {shift.item():.6f}")

        if shift < tol:
            # Past the first few rounds the reported shift is overwritten with 1;
            # only a below-tol shift within loop indices 0-2 is returned as-is.
            if step > 2:
                shift = 1
            break

        centroids = updated

    return assignments, centroids, shift

# === Adaptive Weighted K-Means loop ===
def weighted_cosine_kmeans(x, num_clusters, distance_fn=cluster_weighted_cosine_distance, max_iters_inner=128, max_iters_outer=8, tol=1e-4, verbose=True, rate=1e-1):
    """
    Runs iterative K-Means with weighted cosine distance and
    moving-average variance-based weight updates.

    Each outer round runs ``weighted_kmeans`` (warm-started from the previous
    centroids), derives per-cluster weights as the inverse of the per-dimension
    cluster variance, and blends them into the current weights.

    Args:
        x (torch.Tensor): Data points [N, D]
        num_clusters (int): Number of clusters K
        distance_fn (callable): Forwarded to ``weighted_kmeans``
        max_iters_inner (int): Iteration cap for each ``weighted_kmeans`` run
        max_iters_outer (int): Number of weight-update rounds
        tol (float): Tolerance for the inner loop and for the weight-change stop
        verbose (bool): Print progress
        rate (float): Blend factor; each round moves the weights this fraction
            of the way toward the new variance-based weights

    Returns:
        (assignments, centroids, weights) with weights of shape [K, D]
    """
    # Start from uniform weights
    weights = torch.ones((num_clusters, x.shape[-1]), device=x.device)
    centroids = None

    for iteration in range(max_iters_outer):
        assignments, centroids, shift = weighted_kmeans(x, num_clusters, weights, centroids=centroids, distance_fn=distance_fn, max_iters=max_iters_inner, tol=tol, verbose=verbose)

        # The inner run reports a (near-)zero shift only when it stopped almost
        # immediately, i.e. the warm-started centroids were already stable.
        if shift < 1e-8:
            break

        variances = cluster_variances(x, assignments, num_clusters)
        new_weights = 1 / (variances + 1e-8)

        # Normalize weights per cluster so the largest weight in each row is ~1
        row_max, _ = new_weights.max(dim=1, keepdim=True)
        new_weights = new_weights / (row_max + 1e-8)

        # EMA update for stability
        weights = weights - rate * (weights - new_weights)

        # Remaining gap between the blended weights and their target. Computed
        # unconditionally: the convergence check below needs it even when
        # verbose is False.
        avg_change = torch.abs(weights - new_weights).mean().item()
        if verbose:
            print(f"Iter {iteration:02d} | Avg weight Δ: {avg_change:.4f}")

        # Convergence by weight change
        if avg_change < tol:
            if verbose:
                print(f"Weight changes below tol ({tol}); stopping.")
            break

    return assignments, centroids, weights

In [None]:
import torch
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

import torch
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

def plot_tsne_clusters(data, assignments, title="t-SNE Clusters", perplexity=30, random_state=42):
    """
    Generates a 2D t-SNE plot of high-dimensional data colored by cluster assignments.

    Args:
        data (torch.Tensor or np.array): [N, D] input data
        assignments (torch.Tensor or np.array): [N] cluster labels
        title (str): Plot title
        perplexity (float): t-SNE perplexity
        random_state (int): Random seed for reproducibility
    """
    # Accept tensors (any device, with or without grad) as well as numpy
    # arrays, then hand plain CPU numpy arrays to scikit-learn/matplotlib.
    data_np = torch.as_tensor(data).detach().cpu().numpy()
    labels_np = torch.as_tensor(assignments).detach().cpu().long().numpy()

    # Compute t-SNE embedding
    tsne = TSNE(n_components=2, perplexity=perplexity, random_state=random_state)
    data_2d = tsne.fit_transform(data_np)

    # Scatter plot, one color per cluster label
    plt.figure(figsize=(8, 6))
    scatter = plt.scatter(data_2d[:, 0], data_2d[:, 1], c=labels_np, cmap="tab10", alpha=0.7)
    plt.title(title)
    plt.xlabel("t-SNE 1")
    plt.ylabel("t-SNE 2")
    plt.colorbar(scatter, label="Cluster")
    plt.show()

In [None]:
def euclidean_distance(clusters, eps=1e-8):
    """
    Pairwise Euclidean distance between cluster positions.

    Args:
        clusters: Sequence of objects exposing ``get_position()`` that returns
            a [D] tensor
        eps (float): Added under the square root for numerical stability, so
            the diagonal is sqrt(eps) rather than exactly 0

    Returns:
        torch.Tensor: [N, N] distance matrix
    """
    # Call get_position() and stack into one [N, D] tensor; a plain Python
    # list of positions has no unsqueeze/broadcasting.
    positions = torch.stack([c.get_position() for c in clusters])  # [N, D]

    diff = positions.unsqueeze(1) - positions.unsqueeze(0)  # [N, N, D]
    distances = torch.sqrt((diff ** 2).sum(dim=2) + eps)  # [N, N]
    return distances

def cosine_distance(x, y=None, eps=1e-8):
    """
    Pairwise cosine distance between the rows of `x` and the rows of `y`.

    Args:
        x (torch.Tensor): [N, D] row vectors
        y (torch.Tensor | None): [M, D] row vectors; defaults to `x` itself
        eps (float): Added to each row norm to avoid division by zero

    Returns:
        torch.Tensor: [N, M] matrix of 1 - cosine similarity
    """
    other = x if y is None else y

    # Scale rows of both operands to (approximately) unit length.
    x_unit = x / (x.norm(dim=1, keepdim=True) + eps)
    other_unit = other / (other.norm(dim=1, keepdim=True) + eps)

    # Similarity is the dot product of unit rows; distance is its complement.
    return 1.0 - torch.mm(x_unit, other_unit.T)

def variance_increase_distance(clusters, eps=1e-8):
    """
    Pairwise "variance increase" score for merging any two clusters.

    For each pair (i, j) the score is the variance of the merged cluster
    (size-weighted within-cluster variance plus a between-centroid term)
    minus the average of the two clusters' own variances. A cluster's
    variance here is the mean over dimensions of its per-dimension
    population variance.

    Args:
        clusters: Non-empty sequence of objects with a ``points`` tensor [n_k, D]
        eps (float): Accepted for signature parity; not used in the computation

    Returns:
        torch.Tensor: [N, N] symmetric matrix with a zero diagonal
    """
    count = len(clusters)
    device = clusters[0].points.device

    # Per-cluster summary statistics.
    centers = torch.stack([c.points.mean(dim=0) for c in clusters])  # [N, D]
    spreads = torch.stack([c.points.var(dim=0, unbiased=False).mean() for c in clusters])  # [N]
    counts = torch.tensor([c.points.shape[0] for c in clusters], dtype=torch.float32, device=device)  # [N]

    delta_var = torch.zeros((count, count), device=device)

    # Fill one row of the upper triangle at a time and mirror it.
    for i in range(count):
        later = torch.arange(i + 1, count, device=device)  # partners j > i

        pooled_n = counts[i] + counts[later]  # [M]
        within = (counts[i] * spreads[i] + counts[later] * spreads[later]) / pooled_n  # [M]
        gap_sq = (centers[i] - centers[later]).pow(2).sum(dim=1)  # [M]
        between = (counts[i] * counts[later] / (pooled_n ** 2)) * gap_sq  # [M]

        # Merged variance relative to the pair's average variance.
        row = (within + between) - 0.5 * (spreads[i] + spreads[later])  # [M]

        delta_var[i, later] = row
        delta_var[later, i] = row

    return delta_var

def variance_increase_cosine_distance(clusters, eps=1e-8):
    """
    Ward-like merge cost based on cosine distance between cluster centroids.

    The cost for a pair (i, j) is ``n_i * n_j / (n_i + n_j) * cosine_dist(i, j)``.

    Args:
        clusters: Either a sequence of objects with a ``points`` tensor
            [n_k, D] (the form ``hierarchical_clustering`` passes to its
            ``distance_fn``), or, for backward compatibility, a [K, D] tensor
            of centroids.
        eps (float): Numerical stability constant

    Returns:
        torch.Tensor: [K, K] symmetric cost matrix with a zero diagonal
    """
    if torch.is_tensor(clusters):
        # Legacy tensor input carries no membership information. Iterating a
        # [K, D] tensor yields rows of shape [D], so every "size" equals D and
        # the result is a uniformly scaled cosine distance. Kept as-is for
        # existing callers.
        centroids = clusters
        sizes = torch.tensor([row.shape[0] for row in clusters], dtype=torch.float32, device=clusters.device)  # [K]
    else:
        # Cluster objects: centroid is the member mean, size is the member count.
        centroids = torch.stack([c.points.mean(dim=0) for c in clusters])  # [K, D]
        sizes = torch.tensor([c.points.shape[0] for c in clusters], dtype=torch.float32, device=centroids.device)  # [K]

    # --- Normalize centroids for cosine similarity ---
    means_norm = centroids / (centroids.norm(dim=1, keepdim=True) + eps)

    # --- Pairwise cosine similarity, clamped against rounding overshoot ---
    sim = torch.mm(means_norm, means_norm.T).clamp(-1.0, 1.0)

    # --- Convert similarity to distance ---
    cosine_dist = 1.0 - sim  # [K, K]

    # --- Compute Ward-like Δ variance (weighted by cluster sizes) ---
    n_i = sizes.view(-1, 1)
    n_j = sizes.view(1, -1)

    delta_var = (n_i * n_j) / (n_i + n_j + eps) * cosine_dist
    delta_var.fill_diagonal_(0)

    return delta_var

In [None]:
from tqdm import tqdm
import torch

class Cluster:
    """A group of points together with its cached centroid and variance."""

    def __init__(self, points: torch.Tensor, indices=None):
        """
        points: [N, D] tensor of member points
        indices: optional list of original indices; defaults to 0..N-1
        """
        self.points = points
        if indices is None:
            indices = list(range(points.shape[0]))
        self.indices = indices
        self.update_stats()

    def update_stats(self):
        """Refresh the cached centroid and per-dimension population variance."""
        members = self.points
        self.centroid = members.mean(dim=0)
        self.variance = members.var(dim=0, unbiased=False)

    def get_position(self):
        """Centroid of the cluster (used for distance computations)."""
        return self.centroid

    def get_variance(self):
        """Per-dimension variance of the cluster."""
        return self.variance

    def merge(self, other):
        """Return a new Cluster combining this one with another cluster or a single point tensor."""
        if not isinstance(other, Cluster):
            # Anything else is treated as one point of shape [D]; it has no
            # original index, so a -1 placeholder is recorded.
            extra_points = other.unsqueeze(0)
            extra_indices = [-1]
        else:
            extra_points = other.points
            extra_indices = other.indices

        combined_points = torch.cat([self.points, extra_points], dim=0)
        return Cluster(combined_points, self.indices + extra_indices)


def hierarchical_clustering(data, K, distance_fn):
    """
    Custom agglomerative clustering driven by a pairwise distance function.

    Every point starts as its own cluster; the two clusters with the smallest
    pairwise distance are merged repeatedly until only K remain.

    Args:
        data: [N, D] tensor
        K: target number of clusters (must be >= 1). If K >= N the N
            single-point clusters are returned unmerged.
        distance_fn: callable taking the current list of Cluster objects and
            returning an [M, M] pairwise distance matrix

    Returns:
        list of Cluster objects

    Raises:
        ValueError: if K < 1 (the merge loop could never terminate)
    """
    if K < 1:
        raise ValueError(f"K must be at least 1, got {K}")

    N, _ = data.shape
    clusters = [Cluster(data[i].unsqueeze(0), [i]) for i in range(N)]

    # Iteratively merge until K remain; the context manager closes the bar.
    with tqdm(total=max(N - K, 0)) as pbar:
        while len(clusters) > K:
            M = len(clusters)

            # Full pairwise distance matrix with self-distances masked out.
            # Infinity (rather than a large finite constant) guarantees the
            # diagonal is never selected, whatever the distance scale.
            dist_matrix = distance_fn(clusters).clone()  # [M, M]
            dist_matrix.fill_diagonal_(float("inf"))

            # Closest pair: flat argmin converted back to (row, col)
            i, j = divmod(torch.argmin(dist_matrix).item(), M)

            # Merge those two clusters
            merged = clusters[i].merge(clusters[j])

            # Drop the two originals and append the merged cluster
            clusters = [c for idx, c in enumerate(clusters) if idx not in (i, j)]
            clusters.append(merged)

            pbar.update(1)

    return clusters

def get_cluster_assignments(clusters, n_points):
    """
    Build a flat label vector from a list of clusters.

    Args:
        clusters: Sequence of objects whose ``indices`` attribute lists the
            original point indices belonging to that cluster
        n_points (int): Total number of points

    Returns:
        torch.Tensor: [n_points] long tensor; entry i is the position of the
        cluster containing point i, or -1 if no cluster lists it
    """
    labels = torch.full((n_points,), -1, dtype=torch.long)

    for label, cluster in enumerate(clusters):
        # Negative indices are synthetic placeholders (e.g. -1), not real points.
        real_members = (idx for idx in cluster.indices if idx >= 0)
        for idx in real_members:
            labels[idx] = label

    return labels

In [None]:
# Raw string so the Windows backslashes are never interpreted as escape sequences.
inputfile = r"E:\Coding\SongAnalyzer\Analyzer\src\output_analysis\output-Myna-CLS-ALIBI-Chunking-256.csv"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Each line holds this many space-separated floats, followed by the id text.
EMBEDDING_DIM = 128

ids = []   # identifier string per row (remainder of the line after the floats)
data = []  # one [EMBEDDING_DIM] tensor per row; stacked into a matrix later

with open(inputfile, "r", encoding="utf-8") as f:
    for line in f:
        # Drop the line terminator so it does not end up inside the id.
        line = line.rstrip("\r\n")
        if not line.strip():
            continue  # skip blank lines instead of failing on float("")

        parts = line.split(" ")
        numbers = [float(value) for value in parts[:EMBEDDING_DIM]]
        # The id itself may contain spaces, so re-join everything after the floats.
        ids.append(" ".join(parts[EMBEDDING_DIM:]))
        data.append(torch.tensor(numbers, device=device))

# Generate synthetic data
# data = []
# num_clusters = 32
# nums = [x * 2 + 64 for x in range(num_clusters)]
#
# for i in range(num_clusters):
#     x = torch.randn(nums[i], 32, device=device) * torch.randn(32, device=device)
#     data.append(x)

In [None]:
Hi Brad,
I’ve been following the music information retrieval (MIR) and representation-learning work from people now at Suno — especially [paper], which aligns closely with a project I built: transformer-based audio embeddings and a contrastive VAE for music similarity. Demo/code: [github.com/Gorp5/ContrastiveMusicLearning]. Write-up with method/results (WIP PDF): [PDF link].
Would love to contribute to Suno’s ML/MIR engineering efforts.

In [None]:
torch.manual_seed(0)

# Stack the per-row tensors into one [N, D] matrix. Guarded so that re-running
# this cell (when `data` is already a tensor) does not fail.
if not torch.is_tensor(data):
    data = torch.stack(data, dim=0)

# Agglomerative clustering down to 256 clusters using the variance-increase merge cost
clusters = hierarchical_clustering(data, K=256, distance_fn=variance_increase_distance)
assignments = get_cluster_assignments(clusters, data.shape[0])

# Display results
for idx, c in enumerate(clusters):
    print(f"Cluster {idx}:")
    print(" Centroid:", c.get_position())
    print(" Variance:", c.get_variance())
    print(" Points:", c.points.shape[0])

# Title names the algorithm that actually produced these assignments
plot_tsne_clusters(data, assignments, title="t-SNE of Hierarchical Clustering (variance-increase distance)", perplexity=30)

In [None]:
# List the ids of the members of the first 16 clusters.
for cluster_index in range(16):
    print(f"Cluster {cluster_index}:")
    # Row numbers whose label equals this cluster, in ascending order.
    assign = (assignments == cluster_index).nonzero(as_tuple=True)[0].tolist()
    for assign_index in assign:
        print(ids[assign_index])

In [None]:
# Notebook display: show the full list of id strings collected while loading the input file.
ids