# Latent Exploration

> Scripts to explore the Latent Space

In [36]:
#| default_exp latent_exploration

In [37]:
#| export
#| hide
import numpy as np
from itertools import combinations
from scipy.spatial.distance import cdist
import numpy as np
from itertools import combinations
from scipy.spatial import distance
from sklearn.mixture import GaussianMixture

In [38]:
#| hide
from fastcore.test import test_eq, test_close

In [39]:
#| export
def linear_interpolation(z1, z2, steps):
    """
    Sample evenly spaced points on the straight segment from z1 to z2.

    Parameters:
    - z1 (array-like): Start point of the segment (included in the output).
    - z2 (array-like): End point of the segment (included in the output).
    - steps (int): Number of points to generate.

    Returns:
    - path (np.ndarray): The interpolated points, one per row.
    """
    path = np.linspace(start=z1, stop=z2, num=steps)
    return path

def slerp(z1, z2, steps):
    """
    Perform spherical linear interpolation between two points.

    The angle between the two vectors is measured on their normalized
    directions, while the interpolation weights are applied to the original
    (unnormalized) vectors, so the path starts at z1 and ends at z2.

    Degenerate inputs fall back to plain linear interpolation, because the
    spherical formula divides by sin(omega):
    - either vector has zero norm (direction undefined),
    - the vectors are (numerically) parallel or antiparallel, so
      sin(omega) is ~0 and the weights would blow up or lose z2.

    Parameters:
    - z1 (array-like): 1-D start vector.
    - z2 (array-like): 1-D end vector, same length as z1.
    - steps (int): Number of points to generate (endpoints included).

    Returns:
    - path (np.ndarray): Array of shape (steps, latent_dim).
    """
    z1 = np.asarray(z1, dtype=float)
    z2 = np.asarray(z2, dtype=float)

    norm1 = np.linalg.norm(z1)
    norm2 = np.linalg.norm(z2)
    if norm1 == 0 or norm2 == 0:
        # No direction to rotate from/to; a straight line is the only
        # well-defined path and avoids dividing by a zero norm.
        return np.linspace(z1, z2, steps)

    # Clip guards arccos against round-off pushing the cosine outside [-1, 1].
    cos_omega = np.clip(np.dot(z1 / norm1, z2 / norm2), -1.0, 1.0)
    omega = np.arccos(cos_omega)
    sin_omega = np.sin(omega)

    if sin_omega < 1e-6:
        # (Anti)parallel vectors: the great-circle path is either trivial or
        # not unique, and dividing by sin_omega is numerically meaningless.
        return np.linspace(z1, z2, steps)

    # Column of interpolation parameters so the weights broadcast over dims.
    t = np.linspace(0, 1, steps)[:, np.newaxis]
    weight_start = np.sin((1 - t) * omega) / sin_omega
    weight_end = np.sin(t * omega) / sin_omega
    return weight_start * z1 + weight_end * z2

def interpolate_sample(centroids, granularity=10, variance=0.0):
    """
    Sample points along the paths joining every pair of centroids.

    Pairs are visited in `itertools.combinations` order. Latent spaces with
    at most two dimensions use linear interpolation; higher-dimensional
    spaces use spherical linear interpolation (`slerp`).

    Parameters:
    - centroids (np.ndarray): Array of shape (n_centroids, latent_dim).
    - granularity (int): Number of interpolation steps between each pair.
    - variance (float): Scale passed to `np.random.normal` for the Gaussian
      noise added to each path; no noise is added unless it is > 0.

    Returns:
    - samples (np.ndarray): All paths stacked row-wise, or an empty
      (0, latent_dim) array when there is no pair to interpolate.
    """
    latent_dim = centroids.shape[1]
    use_linear = latent_dim <= 2
    segments = []

    for start, end in combinations(centroids, 2):
        path = (
            linear_interpolation(start, end, granularity)
            if use_linear
            else slerp(start, end, granularity)
        )

        if variance > 0:
            # Jitter the freshly created path in place.
            path += np.random.normal(0, variance, path.shape)

        segments.append(path)

    # Fewer than two centroids: nothing to stack.
    if not segments:
        return np.empty((0, latent_dim))
    return np.vstack(segments)

In [40]:
# Example: three centroids in a 2-dimensional latent space.
# With latent_dim <= 2, interpolate_sample uses linear interpolation.
centroids = np.array([
    [1.0, 2.0],
    [3.0, 4.0],
    [5.0, 6.0]
])

granularity = 3
variance = 0.0  # Set to 0 for deterministic interpolation (no noise added)

sampled_points = interpolate_sample(centroids, granularity, variance)

# Expected result for granularity=3: three points per centroid pair, with the
# pairs in `combinations` order (c0-c1, c0-c2, c1-c2) stacked row-wise.
expected_data = np.array([
    [1.0, 2.0],  # c0 -> c1
    [2.0, 3.0],
    [3.0, 4.0],
    [1.0, 2.0],  # c0 -> c2
    [3.0, 4.0],
    [5.0, 6.0],
    [3.0, 4.0],  # c1 -> c2
    [4.0, 5.0],
    [5.0, 6.0]
])

# Check the sampled points against the expected data
test_eq(sampled_points, expected_data)

In [41]:
#| export
def compute_centroids(latents, labels, method='mean', return_labels=False, **kwargs):
    """
    Compute the centroid of each class in the latent space using various methods.

    Parameters:
    - latents (array-like): Array of shape (n_samples, latent_dim).
    - labels (array-like): Array of shape (n_samples,) with class labels.
    - method (str): Method to compute centroids. Options: 'mean', 'median', 'geom_median', 'medoid', 'trimmed_mean', 'gmm'.
    - return_labels (bool): If True, also return the unique labels corresponding to the centroids.
    - kwargs: Additional arguments for specific methods:
        - tol (float, default 1e-5): convergence tolerance for 'geom_median'.
        - trim_ratio (float, default 0.1): trimmed fraction for 'trimmed_mean'.
        - n_components (int, default 1): must be 1 for 'gmm'.

    Returns:
    - centroids (np.ndarray): Array of shape (n_classes, latent_dim) containing centroids,
      ordered like the sorted unique labels.
    - unique_labels (np.ndarray, optional): Array of shape (n_classes,) with unique class labels.

    Raises:
    - ValueError: If `method` is not supported, or if 'gmm' is requested with
      n_components != 1.
    """
    supported_methods = ('mean', 'median', 'geom_median', 'medoid', 'trimmed_mean', 'gmm')

    # Validate the configuration up front so a bad call fails even when there
    # are no labels to iterate over.
    if method not in supported_methods:
        raise ValueError(f"Unsupported centroid computation method: {method}")
    if method == 'gmm' and kwargs.get('n_components', 1) != 1:
        raise ValueError("GMM-based centroids require n_components=1 for simple centroid computation.")

    # Plain lists would make `labels == label` a scalar comparison and break
    # the boolean-mask selection below.
    latents = np.asarray(latents)
    labels = np.asarray(labels)

    unique_labels = np.unique(labels)
    centroids = []

    for label in unique_labels:
        class_points = latents[labels == label]

        if method == 'mean':
            centroid = class_points.mean(axis=0)

        elif method == 'median':
            centroid = np.median(class_points, axis=0)

        elif method == 'geom_median':
            centroid = geometric_median(class_points, tol=kwargs.get('tol', 1e-5))

        elif method == 'medoid':
            centroid = compute_medoid(class_points)

        elif method == 'trimmed_mean':
            trim_ratio = kwargs.get('trim_ratio', 0.1)
            centroid = trimmed_mean_centroid(class_points, trim_ratio=trim_ratio)

        else:  # method == 'gmm' (validated above)
            gmm = GaussianMixture(n_components=1)
            gmm.fit(class_points)
            centroid = gmm.means_[0]

        centroids.append(centroid)

    if centroids:
        centroids = np.array(centroids)
    else:
        # Keep the latent dimension in the shape so downstream code that reads
        # centroids.shape[1] still works on empty input.
        centroids = np.empty((0,) + latents.shape[1:])

    if return_labels:
        return centroids, unique_labels
    return centroids

# Auxiliary Functions
def geometric_median(points, tol=1e-5, max_iter=1000):
    """
    Approximate the geometric median of a point set by iterative re-weighting.

    Starting from the arithmetic mean, each iteration replaces the estimate
    with the average of the points weighted by the inverse of their distance
    to the current estimate. Points that coincide exactly with the estimate
    are left out of that step to avoid dividing by zero.

    Parameters:
    - points (array-like): Array of shape (n_points, latent_dim).
    - tol (float): Stop once two successive estimates are closer than this.
    - max_iter (int): Upper bound on the number of iterations. Guards against
      a never-ending loop when the tolerance cannot be met (e.g. tol=0 or
      NaN values in the input).

    Returns:
    - median (np.ndarray): Array of shape (latent_dim,). If `max_iter` is
      exhausted, the last estimate is returned and a RuntimeWarning is issued.
    """
    import warnings

    # Boolean-mask indexing below requires an ndarray, not a plain list.
    points = np.asarray(points, dtype=float)
    y = np.mean(points, axis=0)

    for _ in range(max_iter):
        D = distance.cdist([y], points, 'euclidean')[0]
        nonzeros = (D != 0)

        # Every point coincides with the estimate: nothing left to refine.
        if not np.any(nonzeros):
            return y

        D = D[nonzeros]
        points_nonzero = points[nonzeros]
        y1 = np.sum(points_nonzero / D[:, np.newaxis], axis=0) / np.sum(1 / D)

        if np.linalg.norm(y - y1) < tol:
            return y1
        y = y1

    warnings.warn(
        f"geometric_median did not converge within {max_iter} iterations; "
        "returning the last estimate.",
        RuntimeWarning,
    )
    return y

def compute_medoid(points):
    """
    Return the medoid of a point set.

    The medoid is the member of `points` whose summed Euclidean distance to
    all points is smallest; on ties the first such point is returned.

    Parameters:
    - points (np.ndarray): Array of shape (n_points, latent_dim).

    Returns:
    - medoid (np.ndarray): The selected row of `points`.
    """
    pairwise = distance.cdist(points, points, 'euclidean')
    total_distance = np.sum(pairwise, axis=1)
    best = np.argmin(total_distance)
    return points[best]

def trimmed_mean_centroid(points, trim_ratio=0.1):
    """
    Compute a per-dimension trimmed mean of a point set.

    Each dimension is sorted independently, the `int(trim_ratio * n_points)`
    smallest and largest values are discarded, and the remainder is averaged.

    Parameters:
    - points (array-like): Array of shape (n_points, latent_dim).
    - trim_ratio (float): Fraction of values removed from EACH end of every
      dimension. Must satisfy 0 <= trim_ratio < 0.5 so at least one value
      is always kept.

    Returns:
    - centroid (np.ndarray): Array of shape (latent_dim,).

    Raises:
    - ValueError: If `trim_ratio` is outside [0, 0.5).
    """
    if not 0 <= trim_ratio < 0.5:
        raise ValueError(f"trim_ratio must be in [0, 0.5), got {trim_ratio}")

    points = np.asarray(points)
    n_points = points.shape[0]
    trim = int(trim_ratio * n_points)

    # Sort every column on its own; rows no longer correspond to samples.
    ordered = np.sort(points, axis=0)

    # Use an explicit upper bound: with trim == 0 a negative-index slice
    # `[trim:-trim]` would be `[0:0]`, i.e. empty, and the mean would be NaN.
    return ordered[trim:n_points - trim].mean(axis=0)

In [42]:
#| hide

# Example latent representations: six points in a 2-dimensional latent space
latents = np.array([
    [1.0, 2.0],
    [1.5, -1.0],
    [2.0, -1.0],
    [1.2, 2.2],
    [1.4, 1.8],
    [1.7, 7.0]
])

# Class label of each row in `latents` (three points in class 1, two in class 2, one in class 3)
labels = np.array([1,1,1,2,2,3])

# Compute centroids with the default method ('mean'), without returning labels
centroids = compute_centroids(latents, labels)

# Expected per-class means, ordered by sorted unique label
expected_centroids = np.array([
    [1.5, 0.],        # Centroid for label 1
    [1.3, 2.],        # Centroid for label 2
    [1.7, 7.]         # Centroid for label 3
])

# Check the computed centroids against the expected data (approximate float comparison)
test_close(centroids, expected_centroids)

In [43]:
#| hide
# Run nbdev's export step for this notebook.
import nbdev; nbdev.nbdev_export()