# Latent Exploration

> Utilities for exploring a latent space by interpolating (and optionally jittering) points between centroids

In [13]:
#| default_exp latent_exploration

In [14]:
#| export
#| hide
import numpy as np
from itertools import combinations
from scipy.spatial.distance import cdist

In [15]:
#| hide
from fastcore.test import test_eq

In [16]:
#| export
def linear_interpolation(z1, z2, steps):
    """Return `steps` evenly spaced points on the straight line from `z1` to `z2`.

    Both endpoints are included. For vector endpoints the points are stacked
    along a new leading axis, so two length-d vectors give shape (steps, d).
    """
    start, stop = z1, z2
    points = np.linspace(start, stop, num=steps)
    return points

def slerp(z1, z2, steps):
    """Perform spherical linear interpolation between two points.

    Returns `steps` points from `z1` to `z2`, endpoints included. The angle
    `omega` is measured between the *directions* of `z1` and `z2`. The
    unnormalized vectors are then blended with the slerp weights
    sin((1 - t) * omega) / sin(omega) and sin(t * omega) / sin(omega)
    for t evenly spaced in [0, 1].

    The slerp weights are undefined or numerically unstable when either input
    is the zero vector, or when the two directions are (anti)parallel so that
    sin(omega) is ~0. In those cases plain linear interpolation is returned.
    That is the limit of the slerp formula as omega -> 0, so parallel inputs
    of different length still end exactly at `z2`.

    Parameters:
    - z1, z2 (array-like): 1-D vectors of equal length.
    - steps (int): Number of points to return.

    Returns:
    - np.ndarray: Float array of shape (steps, len(z1)).
    """
    # Below this sin(omega) the slerp weights lose precision (or divide by ~0
    # for antipodal vectors); linear interpolation differs from slerp there
    # only by O(omega**2).
    eps = 1e-6

    z1 = np.asarray(z1, dtype=float)
    z2 = np.asarray(z2, dtype=float)
    # Column vector of interpolation parameters so it broadcasts over dims.
    t = np.linspace(0.0, 1.0, steps)[:, None]

    norm1 = np.linalg.norm(z1)
    norm2 = np.linalg.norm(z2)
    if norm1 == 0 or norm2 == 0:
        # No direction is defined for a zero vector.
        return (1.0 - t) * z1 + t * z2

    # Clip guards arccos against dot products a rounding error outside [-1, 1].
    dot_product = np.clip(np.dot(z1 / norm1, z2 / norm2), -1.0, 1.0)
    omega = np.arccos(dot_product)
    sin_omega = np.sin(omega)
    if sin_omega < eps:
        return (1.0 - t) * z1 + t * z2

    weight1 = np.sin((1.0 - t) * omega) / sin_omega
    weight2 = np.sin(t * omega) / sin_omega
    return weight1 * z1 + weight2 * z2

def interpolate_sample(centroids, granularity=10, variance=0.0, rng=None):
    """
    Perform interpolating sampling between all pairs of centroids.

    For every unordered pair of centroids, in `itertools.combinations`
    order, `granularity` points are generated from the first centroid to the
    second with both endpoints included. Latent spaces of dimension <= 2 use
    `linear_interpolation`; higher-dimensional spaces use `slerp`. Optional
    zero-mean Gaussian noise is then added to every coordinate.

    Parameters:
    - centroids (array-like): Array of shape (n_centroids, latent_dim).
    - granularity (int): Number of interpolation steps between each pair.
    - variance (float): Standard deviation (not variance) of the Gaussian
      noise added to each coordinate. No noise is added when <= 0. The name
      is kept for backward compatibility.
    - rng (np.random.Generator, optional): Source of randomness for the
      noise. Defaults to NumPy's legacy global random state, so
      `np.random.seed` keeps controlling the result.

    Returns:
    - samples (np.ndarray): Float array of sampled points. Points belonging
      to the same pair are contiguous. With fewer than two centroids the
      result is an empty array of shape (0, latent_dim).

    Raises:
    - ValueError: If `centroids` is not two-dimensional.
    """
    centroids = np.asarray(centroids, dtype=float)
    if centroids.ndim != 2:
        raise ValueError(
            "centroids must be 2-D (n_centroids, latent_dim), "
            f"got shape {centroids.shape}"
        )
    latent_dim = centroids.shape[1]
    noise_source = np.random if rng is None else rng
    samples = []

    for z1, z2 in combinations(centroids, 2):
        if latent_dim <= 2:
            interpolated = linear_interpolation(z1, z2, granularity)
        else:
            interpolated = slerp(z1, z2, granularity)

        if variance > 0:
            noise = noise_source.normal(0, variance, interpolated.shape)
            # Out-of-place add: an in-place `+=` raises a casting error if
            # the interpolator ever returns an integer array.
            interpolated = interpolated + noise

        samples.append(interpolated)

    if not samples:
        return np.empty((0, latent_dim))
    return np.vstack(samples)

In [17]:
# Define example centroids for a 2-dimensional latent space
centroids = np.array([
    [1.0, 2.0],
    [3.0, 4.0],
    [5.0, 6.0]
])

granularity = 3
variance = 0.0  # Set to 0 for deterministic interpolation

sampled_points = interpolate_sample(centroids, granularity, variance)

# Expected output for granularity=3: one block of 3 points per centroid pair,
# in itertools.combinations order (0, 1), (0, 2), (1, 2). Each block runs
# from the first centroid of the pair to the second, endpoints included.
expected_data = np.array([
    [1.0, 2.0], [2.0, 3.0], [3.0, 4.0],   # centroid 0 -> centroid 1
    [1.0, 2.0], [3.0, 4.0], [5.0, 6.0],   # centroid 0 -> centroid 2
    [3.0, 4.0], [4.0, 5.0], [5.0, 6.0],   # centroid 1 -> centroid 2
])

# Check the sampled points against the expected data
test_eq(sampled_points, expected_data)

# Latent spaces with more than 2 dimensions go through slerp. Every pair must
# still start and end exactly on its centroids. The centroids below are
# mutually orthogonal unit vectors, so each pair's midpoint is known in
# closed form: sqrt(0.5) * (e_i + e_j).
centroids_3d = np.eye(3)
sampled_3d = interpolate_sample(centroids_3d, granularity=3)
test_eq(sampled_3d.shape, (9, 3))
for block, (i, j) in zip(np.split(sampled_3d, 3), combinations(range(3), 2)):
    assert np.allclose(block[0], centroids_3d[i])
    assert np.allclose(block[-1], centroids_3d[j])
    assert np.allclose(block[1], (centroids_3d[i] + centroids_3d[j]) * np.sqrt(0.5))

# Noise keeps the output shape but perturbs the points
noisy = interpolate_sample(centroids, granularity, variance=0.1)
test_eq(noisy.shape, expected_data.shape)
assert not np.allclose(noisy, expected_data)

In [18]:
#| hide
import nbdev

# Write the `#| export` cells of this notebook out to the library module
# named by `#| default_exp`.
nbdev.nbdev_export()