### Investigating latent-space interpolation using spherical linear interpolation

More details about Slerp can be found [here](https://en.wikipedia.org/wiki/Slerp).

##### Prerequisites
- ffmpeg

In [None]:
import os
import time
from tempfile import TemporaryDirectory

import matplotlib.pyplot as plt
import mlflow
import numpy as np
import torch
from tqdm import tqdm

from synthetic_data.api.model_registry import ModelRegistry
from synthetic_data.common import helpers
from synthetic_data.common.config import LocalConfig
from synthetic_data.common.torchutils import get_device
from synthetic_data.mlops.tools.analysis import create_embedded_noise, slerp

# Shared notebook state: project config, model-registry client and MLflow endpoints.
cfg = LocalConfig()
model_registry = ModelRegistry()
# Point both MLflow's tracking and model-registry URIs at the remote registry from the config.
# NOTE(review): ModelRegistry() is constructed before these URIs are set; confirm it does not
# capture the MLflow URIs at construction time.
mlflow.set_tracking_uri(cfg.URI_MODELREG_REMOTE)
mlflow.set_registry_uri(cfg.URI_MODELREG_REMOTE)
BASE_COLOR = "#DE5D4F"  # colour for the plots and tqdm progress bars in the cells below

device = get_device()  # NOTE(review): not referenced in the cells shown here; confirm it is still needed
# Fixed seeds so the noise drawn with torch.randn below is reproducible across runs.
torch.manual_seed(1337)
np.random.seed(1337)

## Interpolate on WGAN-GP

In [None]:
def save_gan_sequence(sequence: torch.Tensor, global_step: int, max_step: int, savedir: str) -> None:
    """Plot one generated sequence and save it as a numbered JPEG frame.

    Args:
        sequence: Generator output. If it is 2-D, only its first row is plotted
            (presumably the first sample of the batch).
        global_step: Index of this frame; used in the title and the file name.
        max_step: Total number of frames; shown in the title.
        savedir: Directory the frame is written to, as ``frame{global_step:04d}.jpg``.
    """
    if sequence.ndim == 2:
        sequence = sequence[0]

    # Detach from autograd and copy to host memory: .numpy() raises on tensors
    # that live on an accelerator device.
    values = sequence.detach().cpu().numpy()
    time_steps = np.arange(values.shape[0])

    fig = plt.figure(figsize=(10, 3), dpi=200)
    try:
        plt.plot(time_steps, values, color=BASE_COLOR)
        plt.plot(time_steps, values, "o", color=BASE_COLOR)
        plt.title(f"frame {global_step:04d} / {max_step:04d}", loc="right")
        # Zero-padded names keep frames in lexical order for the GIF builder.
        plt.savefig(os.path.join(savedir, f"frame{global_step:04d}.jpg"))
    finally:
        # Always release the figure; this runs once per frame, so leaks add up.
        plt.close(fig)

with TemporaryDirectory() as tmpdir:

    wgan = model_registry.load_model("WGAN-GP", 4)

    noise_shape = (1, 100)  # batch_size, z_dim
    noise_A = torch.randn(noise_shape)  # initial noise

    n_samples = 200  # number of frames between two noise vectors
    n_classes = 10  # number of random noise targets visited (one segment each)

    global_step = 0
    max_steps = n_classes * n_samples

    # Pure inference: no_grad avoids building an autograd graph for every frame.
    with torch.no_grad(), tqdm(total=max_steps, desc="Interpolating", unit="frame", colour=BASE_COLOR) as pbar:
        for _ in range(n_classes):

            # sample new destination
            noise_B = torch.randn(noise_shape)

            # endpoint=False: slerp at t=1 yields noise_B, which is also t=0 of the
            # next segment, so including it would duplicate a frame at every transition.
            for t in np.linspace(0, 1, n_samples, endpoint=False):
                noise = slerp(float(t), noise_A, noise_B)
                sequence = wgan(noise)
                save_gan_sequence(sequence, global_step, max_steps, tmpdir)

                global_step += 1
                pbar.update(1)

            # The destination becomes the start of the next segment.
            noise_A = noise_B

    helpers.create_gif_from_image_folder(tmpdir, "wgan_gp.gif", fps=60)

## Interpolate on C-GAN

In [None]:
def save_cgan_sequence(
    sequence: torch.Tensor, global_step: int, max_step: int, indexA: int, indexB: int, save_dir: str
) -> None:
    """Plot one C-GAN sequence, annotated with its class transition, and save it as a JPEG frame.

    Args:
        sequence: Generator output. If it is 2-D, only its first row is plotted
            (presumably the first sample of the batch).
        global_step: Index of this frame; used in the title and the file name.
        max_step: Total number of frames; shown in the title.
        indexA: Class index being interpolated from (rendered as ``indexA + 1`` Hz).
        indexB: Class index being interpolated to (rendered as ``indexB + 1`` Hz).
        save_dir: Directory the frame is written to, as ``frame{global_step:04d}.jpg``.
    """
    if sequence.ndim == 2:
        sequence = sequence[0]

    # Detach from autograd and copy to host memory: .numpy() raises on tensors
    # that live on an accelerator device.
    values = sequence.detach().cpu().numpy()
    time_steps = np.arange(values.shape[0])

    # Class index i is labelled as a frequency of (i + 1) Hz.
    freq_a = f"{indexA + 1} Hz"
    freq_b = f"{indexB + 1} Hz"

    fig = plt.figure(figsize=(10, 3), dpi=200)
    try:
        plt.plot(time_steps, values, color=BASE_COLOR)
        plt.plot(time_steps, values, "o", color=BASE_COLOR)
        plt.title(f"from {freq_a} to {freq_b}", loc="left")
        plt.title(f"frame {global_step:04d} / {max_step:04d}", loc="right")
        # Save exactly once. Building the path with os.path.join avoids %-formatting,
        # which would break if save_dir contained a '%'.
        plt.savefig(os.path.join(save_dir, f"frame{global_step:04d}.jpg"))
    finally:
        # Always release the figure; this runs once per frame, so leaks add up.
        plt.close(fig)

with TemporaryDirectory() as tmpdir:
    cgan = model_registry.load_model("C-GAN", 7)

    # A plain forward pass of the C-GAN does not let us inject our own embedded
    # noise, so we use the embedder and the generator from the model separately.
    embedder = cgan.embedder
    model = cgan.model.forward

    n_samples = 20  # number of frames between two frequencies
    n_classes = 10  # class index i is rendered as (i + 1) Hz

    label_A = 0  # start at 1 Hz

    fixed_noise = torch.randn(1, 100)  # (batch_size, z_dim); shared by every class

    global_step = 0
    max_steps = n_classes * n_samples

    # Pure inference: no_grad avoids building an autograd graph for every frame.
    with torch.no_grad(), tqdm(total=max_steps, desc="Interpolating", unit="frame", colour=BASE_COLOR) as pbar:
        noise_A = create_embedded_noise(embedder, fixed_noise, label_A)

        for _ in range(n_classes):
            # Next frequency; the last class wraps back to the first, closing the loop.
            label_B = (label_A + 1) % n_classes
            noise_B = create_embedded_noise(embedder, fixed_noise, label_B)

            # endpoint=False: slerp at t=1 yields noise_B, which is also t=0 of the
            # next segment, so including it would duplicate a frame at every transition.
            for t in np.linspace(0, 1, n_samples, endpoint=False):
                dynamic_noise = slerp(float(t), noise_A, noise_B)
                sequence = model(dynamic_noise)
                save_cgan_sequence(sequence, global_step, max_steps, label_A, label_B, tmpdir)

                global_step += 1
                pbar.update(1)

            # The destination becomes the start of the next segment.
            label_A = label_B
            noise_A = noise_B

    helpers.create_gif_from_image_folder(tmpdir, "test_cgan.gif", fps=10)