# Now let's explore the latent space a little

In [None]:
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
from torch.utils.data import DataLoader
from pathlib import Path
import pandas as pd

from models import AutoEncoder, VAE
import pickle

import umap

In [None]:
# https://www.youtube.com/watch?v=5WoItGTWV54&ab_channel=StanfordUniversitySchoolofEngineering

In [None]:
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
_device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)

# Raise the default figure resolution for every plot in this notebook.
mpl.rcParams["figure.dpi"] = 200

## Start by loading the model and the validation set

In [None]:
# Directory containing the saved model checkpoints (machine-specific absolute path).
out_path = Path("/home/amal/UbuntuDocuments/projects/generative_modelling/saved_models")

In [None]:
# Latent dimensionality handed to the VAE constructor.
# NOTE(review): presumably this has to match the value used when the
# checkpoint was trained - confirm against the training script.
latent_dim = 200

model = VAE(latent_dim=latent_dim)
model.to(device)

checkpoint_path = out_path / "VAE_epoch_14.pth"
# map_location remaps the stored tensors onto the active device, so a
# checkpoint written on a GPU machine still loads when only a CPU is available.
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])

# Switch to evaluation mode; as the last expression this also echoes the model.
model.eval()

In [None]:
# Names of the 40 CelebA binary attributes, in the column order of the
# dataset's attribute annotations. The position of a name in this list is used
# as the column index into the attribute tensor.
attribute_labels = [
    '5_o_Clock_Shadow',
    'Arched_Eyebrows',
    'Attractive',
    'Bags_Under_Eyes',
    'Bald',
    'Bangs',
    'Big_Lips',
    'Big_Nose',
    'Black_Hair',
    'Blond_Hair',
    'Blurry',
    'Brown_Hair',
    'Bushy_Eyebrows',
    'Chubby',
    'Double_Chin',
    'Eyeglasses',
    'Goatee',
    'Gray_Hair',
    'Heavy_Makeup',
    'High_Cheekbones',
    'Male',
    'Mouth_Slightly_Open',
    'Mustache',
    'Narrow_Eyes',
    'No_Beard',
    'Oval_Face',
    'Pale_Skin',
    'Pointy_Nose',
    'Receding_Hairline',
    'Rosy_Cheeks',
    'Sideburns',
    'Smiling',
    'Straight_Hair',
    'Wavy_Hair',
    'Wearing_Earrings',
    'Wearing_Hat',
    'Wearing_Lipstick',
    'Wearing_Necklace',
    'Wearing_Necktie',
    # Final CelebA attribute; it was missing from the original 39-entry list.
    'Young',
]

In [None]:
# Root directory for the torchvision datasets (machine-specific absolute path).
data_path = "/home/amal/UbuntuDocuments/data/torch_datasets"

# CelebA validation split; each image is converted to a tensor by PILToTensor,
# and the files are downloaded into data_path if they are not already there.
validation_data = datasets.CelebA(
    data_path,
    split="valid",
    transform=transforms.PILToTensor(),
    download=True,
)

In [None]:
# Number of images per batch when iterating over the validation split.
batch_size = 1024
val_dataloader = DataLoader(validation_data, batch_size=batch_size)

In [None]:
def encode_dataset(
    model: torch.nn.Module,
    dataloader: DataLoader,
    device,
    max_batches: "int | None" = 8,
):
    """Encode the leading batches of ``dataloader`` with ``model.encode``.

    Each batch is indexed as ``batch[0]`` (images) and ``batch[1]``
    (attributes). Images are divided by 255, cast to float, moved to
    ``device`` and passed to ``model.encode``, which must return a 3-tuple;
    only its first element is kept.

    Args:
        model: Module exposing an ``encode`` method; it is put in eval mode.
        dataloader: Iterable yielding ``(images, attributes, ...)`` batches.
        device: Device on which the encoder is run.
        max_batches: Maximum number of batches to encode. The default of 8
            matches the previously hard-coded limit; ``None`` encodes the
            whole dataloader.

    Returns:
        A ``(encoded, attributes)`` tuple: the kept encoder outputs (moved to
        the CPU) and the attribute tensors, each concatenated along dim 0.

    Raises:
        RuntimeError: From ``torch.cat`` when no batch was processed (empty
            dataloader or ``max_batches=0``).
    """
    # Local import keeps this cell self-contained.
    from itertools import islice

    encoded = []
    attributes = []

    model.eval()
    with torch.no_grad():
        for batch in islice(dataloader, max_batches):
            # as_tensor avoids the copy-construct warning that torch.tensor()
            # emits when it is handed an existing tensor.
            images = torch.as_tensor(batch[0] / 255, dtype=torch.float).to(device)
            latent, _, _ = model.encode(images)
            # Gradients are disabled here, so no explicit detach() is needed.
            encoded.append(latent.cpu())
            attributes.append(batch[1])

    return torch.cat(encoded, dim=0), torch.cat(attributes, dim=0)

In [None]:
# Encode the validation batches and collect the matching attribute rows.
encoded, attributes = encode_dataset(model, val_dataloader, device)

In [None]:
# Display the shape of the encoded tensor as a quick check.
encoded.shape

In [None]:
# Histogram nine randomly chosen latent dimensions on a 3x3 grid, with the
# mean and standard deviation of each shown in the subplot title.
fig, axes = plt.subplots(3, 3, figsize=(10, 8))
axes = axes.flatten()

# Sample distinct latent dimensions, one per subplot.
indices = np.random.choice(np.arange(latent_dim), len(axes), replace=False)

for ax, feature in zip(axes, indices):
    values = encoded[:, feature]
    ax.hist(values, bins=50)
    # Raw f-string: "\m" and "\s" are invalid escape sequences in a normal
    # string literal, and the backslashes must reach matplotlib's mathtext.
    ax.set_title(
        rf"dim {feature}, ($\mu = ${torch.mean(values):.2f}, $\sigma = ${torch.std(values):.2f})"
    )

# Run the layout after the titles exist so their height is accounted for.
fig.tight_layout()

In [None]:
# Convert both tensors to NumPy arrays for the UMAP and plotting cells below.
# Not re-runnable: after the first run these names are already NumPy arrays.
encoded = encoded.detach().cpu().numpy()
attributes = attributes.detach().cpu().numpy()

In [None]:
# UMAP reducer configured for a 2-component embedding using 4 neighbours.
dim_reduction = umap.UMAP(n_components=2, n_neighbors=4)

In [None]:
# Fit the reducer on the encoded array and keep the resulting 2-component embedding.
reduced = dim_reduction.fit_transform(encoded)

In [None]:
#sns.scatterplot(x=reduced[:, 0], y=reduced[:, 1])

In [None]:
# Display the shape of the attribute array as a quick check.
attributes.shape

In [None]:
# Show every attribute name next to its index in attribute_labels.
list(enumerate(attribute_labels))

In [None]:
# Attribute used to split the points in the plots below.
attr = "Wearing_Hat"

# Column of the attribute array that corresponds to the chosen attribute.
sel_index = attribute_labels.index(attr)
# Boolean mask: True where the selected attribute column equals 1.
labels = attributes[:, sel_index] == 1

In [None]:
# sns.scatterplot(x=reduced[:, 0], y=reduced[:, 1], hue=labels)
# plt.legend(title='Smoker', loc='upper left', labels=[attr, 'Nah Bruh'])

In [None]:
# Scatter the 2-component embedding, drawing the samples that carry the
# selected attribute on top of those that do not.
fig = plt.figure(figsize=(8, 7))

# Marker styling shared by both point groups.
size = 30
e_color = 'black'
linewidth = 0.6
alpha = 0.7
marker_style = dict(s=size, edgecolors=e_color, linewidth=linewidth, alpha=alpha)

# Build the mask once instead of recomputing it for every coordinate.
has_attr = attributes[:, sel_index] == 1
without_attr = attributes[:, sel_index] != 1

# Only the second group gets a label, so it is the only legend entry.
plt.scatter(x=reduced[without_attr, 0], y=reduced[without_attr, 1], **marker_style)
plt.scatter(
    x=reduced[has_attr, 0],
    y=reduced[has_attr, 1],
    label=attr.replace("_", " "),
    **marker_style,
)

plt.legend()

## Try and visualise the data in 3D

In [None]:
# UMAP reducer configured for a 3-component embedding using 4 neighbours.
dim_reduction_3d = umap.UMAP(n_components=3, n_neighbors=4)

In [None]:
# Fit the 3-component reducer on the encoded array and keep the embedding.
reduced_3d = dim_reduction_3d.fit_transform(encoded)

In [None]:
# Display the shape of the 3-component embedding as a quick check.
reduced_3d.shape

In [None]:
# Marker styling for the plots of the 3-component embedding.
size = 10
e_color = 'black'
linewidth = 0.6
alpha = 0.5

# Plot the first two components of the 3-component embedding against each other.
first_component = reduced_3d[:, 0]
second_component = reduced_3d[:, 1]
plt.scatter(
    x=first_component,
    y=second_component,
    s=size,
    edgecolors=e_color,
    linewidth=linewidth,
    alpha=alpha,
)

In [None]:
# 3-D scatter of the 3-component embedding.
fig = plt.figure()
ax = fig.add_subplot(projection="3d")

# Components 1, 2 and 0 are passed as the first, second and third coordinate;
# the opacity is fixed at 0.1 here rather than taken from `alpha`.
ax.scatter(
    reduced_3d[:, 1],
    reduced_3d[:, 2],
    reduced_3d[:, 0],
    s=size,
    edgecolors=e_color,
    linewidth=linewidth,
    alpha=0.1,
)