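# Conditional VAE (CVAE) Latent Space Visualization Notebook
This notebook encodes the test split with a trained conditional VAE and collects the latent codes. It projects them to 2-D with PCA and t-SNE, colored by the conditioning variables, and renders latent traversals.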

In [1]:
import torch
import numpy as np
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
from torchvision.utils import make_grid

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'


## Collect Latent Codes

In [7]:
from torch.utils.data import DataLoader
import os
from generative_ai.src.dataloader import DendritePFMDataset

# Held-out test split; the loader yields (x, ctr, <unused>) triples.
test_dataset = DendritePFMDataset((3, 64, 64), os.path.join("data", "dataset_split.json"), split="test")
test_dataloader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=True)

# The checkpoint holds a whole pickled model object (its .encoder / .cMLP
# attributes are used directly below), not a state_dict. PyTorch >= 2.6
# defaults torch.load to weights_only=True, which refuses such files, and
# recent releases warn about the implicit default, so pass it explicitly.
# NOTE(review): unpickling can execute arbitrary code -- only load trusted checkpoints.
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
model = torch.load(
    os.path.join("ckpt", "CVAE.ckpt"), map_location=device, weights_only=False
).to(device)

model.eval()
all_mu, all_z, all_ctr = [], [], []
with torch.no_grad():
    for x, ctr, _ in test_dataloader:
        x, ctr = x.to(device), ctr.to(device)
        mu, logvar = model.encoder(x)
        # cMLP maps the condition to a per-dimension pair of offsets that are
        # added to the encoder's mean ([..., 0]) and log-variance ([..., 1]).
        c = model.cMLP(ctr).view(x.size(0), model.latent_size, 2)
        mu_ = mu + c[..., 0]
        logvar_ = logvar + c[..., 1]
        z = model.reparameterize(mu_, logvar_)
        all_mu.append(mu_.cpu())
        all_z.append(z.cpu())
        all_ctr.append(ctr.cpu())

# Bind the results to names so later cells can reach them (the PCA/t-SNE cell
# reads `latent`). `latent` holds the condition-adjusted means mu_; `z_samples`
# holds the reparameterized draws; `ctr_all` the matching condition rows.
# The loop variable `ctr` is intentionally left as the last single-sample batch.
latent = torch.cat(all_mu).numpy()
z_samples = torch.cat(all_z).numpy()
ctr_all = torch.cat(all_ctr).numpy()
latent, z_samples, ctr_all

  model = torch.load(os.path.join("ckpt", "CVAE.ckpt")).to(device)


(array([[ 2.6631486 , -0.20753288, -0.0453696 , ...,  0.610454  ,
          0.04829521,  0.16342273],
        [ 1.5561382 , -0.01626784,  0.2974891 , ...,  0.6042874 ,
          0.3993227 ,  0.0438895 ],
        [ 1.4306616 ,  0.11173975,  0.12958041, ...,  0.98450494,
         -0.12998372,  0.23423953],
        ...,
        [ 0.8328345 ,  0.5633985 ,  0.3937545 , ...,  0.6882454 ,
          0.18897179,  0.34280765],
        [ 0.6390162 ,  0.54820645, -0.4731377 , ...,  1.4159375 ,
          0.27273282, -0.28595456],
        [ 1.0278027 , -0.10734588, -0.18419465, ...,  1.7056677 ,
          0.30045006, -0.04322848]], dtype=float32),
 array([[ 2.1435783 ,  0.49349636, -0.6277553 , ...,  4.7876763 ,
         -1.9473677 ,  0.01667894],
        [ 3.387135  , -0.4496736 , -0.16500047, ..., -3.1944942 ,
         -0.5564869 ,  0.83264637],
        [ 1.0900619 ,  0.01170342,  0.7735894 , ...,  0.5320033 ,
         -0.85899043,  1.8355932 ],
        ...,
        [ 1.0882951 , -0.32507604, -0.9

## PCA & t-SNE

In [3]:
# Linear 2-D projection of the latent codes. The printed ratios give the
# fraction of total variance captured by each of the two principal components.
pca = PCA(n_components=2)
z2 = pca.fit_transform(latent)
variance_ratio = pca.explained_variance_ratio_
print('Explained variance:', variance_ratio)

def latent_to_2d_tsne(latent, n_samples=4000, random_state=None,
                      return_indices=False, perplexity=30):
    """Embed latent codes in 2-D with t-SNE, subsampling large inputs.

    Args:
        latent: 2-D array-like of latent codes, one row per sample.
        n_samples: when ``latent`` has more rows than this, a random subset of
            ``n_samples`` rows (drawn without replacement) is embedded instead.
        random_state: seed for both the subsampling and t-SNE. ``None`` keeps
            the unseeded behaviour (global NumPy RNG for the subset, TSNE's
            default randomness).
        return_indices: when True, also return the row indices of ``latent``
            that were embedded, so per-sample colours such as ``ctr`` can be
            subset the same way (``ctr[idx]``) before plotting.
        perplexity: t-SNE perplexity. sklearn's TSNE requires it to be smaller
            than the number of points, so it is lowered to ``n - 1`` for small
            inputs.

    Returns:
        The ``(m, 2)`` embedding, or ``(embedding, indices)`` when
        ``return_indices`` is True.
    """
    n_total = latent.shape[0]
    if n_total > n_samples:
        if random_state is None:
            idx = np.random.choice(n_total, n_samples, replace=False)
        else:
            rng = np.random.default_rng(random_state)
            idx = rng.choice(n_total, n_samples, replace=False)
        latent = latent[idx]
    else:
        # No subsampling: every row is embedded, in its original order.
        idx = np.arange(n_total)

    n_points = latent.shape[0]
    if 1 < n_points <= perplexity:
        perplexity = n_points - 1

    tsne = TSNE(n_components=2, perplexity=perplexity, init='pca',
                random_state=random_state)
    embedding = tsne.fit_transform(latent)
    return (embedding, idx) if return_indices else embedding


## Plotting Function

In [4]:
def plot_latent(z2d, ctr, k=0, title='latent space'):
    """Scatter a 2-D embedding, colouring each point by one condition column.

    Args:
        z2d: embedding whose first two columns are used as x and y.
        ctr: 2-D array of condition values; column ``k`` supplies the colours
            and must have one entry per row of ``z2d`` (matplotlib checks this).
        k: index of the condition column used for colouring.
        title: figure title.
    """
    colour_values = ctr[:, k]
    plt.figure(figsize=(6, 6))
    xs = z2d[:, 0]
    ys = z2d[:, 1]
    points = plt.scatter(xs, ys, s=5, c=colour_values)
    plt.colorbar(points, label=f'ctr[{k}]')
    plt.title(title)
    plt.show()


## Latent Traversal

In [5]:
@torch.no_grad()
def latent_traversal(model, ctr, dim, n_steps=7, span=3.0, device=device):
    """Decode a sweep along one latent coordinate and show the images in a row.

    The starting code is the condition-dependent mean ``cMLP(ctr)[..., 0]``;
    no input image is encoded. Coordinate ``dim`` is offset by ``n_steps``
    evenly spaced values in ``[-span, span]``, and all other coordinates are
    held fixed. Each code is concatenated with ``model.zMLP(ctr)``, passed
    through ``model.zAttn`` and the decoder, and the decoded images are drawn
    side by side with matplotlib.

    Args:
        model: loaded model exposing ``cMLP``, ``zMLP``, ``zAttn``, ``decoder``
            and ``latent_size``; it is switched to eval mode.
        ctr: condition for a single sample (``cMLP(ctr)`` must reshape to
            ``(1, latent_size, 2)``).
        dim: index of the latent coordinate to sweep.
        n_steps: number of images in the sweep.
        span: half-width of the sweep range.
        device: device the model runs on.
    """
    model.eval()
    ctr = ctr.to(device)
    base = model.cMLP(ctr).view(1, model.latent_size, 2)[..., 0].squeeze(0)

    def _shifted(offset):
        # Copy of the base code with only coordinate `dim` moved by `offset`.
        code = base.clone()
        code[dim] += offset
        return code

    codes = torch.stack([_shifted(offset) for offset in torch.linspace(-span, span, n_steps)])
    cond = model.zMLP(ctr).repeat(n_steps, 1)
    codes = model.zAttn(torch.cat([codes, cond], dim=1))
    images = model.decoder(codes)

    grid = make_grid(images, nrow=n_steps)
    plt.figure(figsize=(15, 3))
    plt.imshow(grid.permute(1, 2, 0).cpu())
    plt.axis('off')
    plt.show()
