In [None]:
import os
import pickle
from itertools import islice
from pathlib import Path
from typing import Optional

import torch
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torchsummary import summary
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
from torch.utils.data import DataLoader
import pandas as pd
import umap

from models import AutoEncoder

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Seed torch's RNG so the sampled latent vectors (and therefore the generated
# images) are reproducible on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)
# Render inline figures at a higher resolution.
mpl.rcParams.update({'figure.dpi': 150})

In [None]:
# Directory holding the trained model checkpoints. The default is a
# machine-specific absolute path; set SAVED_MODELS_DIR to point elsewhere.
out_path = Path(os.environ.get(
    "SAVED_MODELS_DIR",
    "/home/amal/UbuntuDocuments/projects/generative_modelling/saved_models",
))

In [None]:
latent_dim = 200
checkpoint_epoch = 9  # which saved training epoch to load

model = AutoEncoder(latent_dim=latent_dim)
model.to(device)
checkpoint_path = out_path / f"autoencoder_epoch_{checkpoint_epoch}.pth"
# map_location remaps tensors saved on another device (e.g. a GPU) onto `device`,
# so the checkpoint also loads on a CPU-only machine.
# NOTE: depending on the torch version, torch.load may unpickle arbitrary
# objects -- only load checkpoints you trust.
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()

In [None]:
# Root directory for torchvision datasets. The default is a machine-specific
# absolute path; set TORCH_DATASETS_DIR to point elsewhere.
data_path = os.environ.get(
    "TORCH_DATASETS_DIR", "/home/amal/UbuntuDocuments/data/torch_datasets"
)
# CelebA validation split; PILToTensor keeps raw pixel values (no scaling here).
validation_data = datasets.CelebA(
    data_path,
    split="valid",
    transform=transforms.PILToTensor(),
    download=True,
)

In [None]:
batch_size = 1024
# Batches over the validation split in dataset order (shuffle defaults to False).
val_dataloader = DataLoader(validation_data, batch_size=batch_size)

In [None]:
def encode_dataset(
    model: torch.nn.Module,
    dataloader: DataLoader,
    device,
    max_batches: Optional[int] = 8,
):
    """Encode (a prefix of) a dataset into latent vectors.

    Puts ``model`` in eval mode and, without tracking gradients, feeds the
    first ``max_batches`` batches of ``dataloader`` through ``model.encode``.

    Args:
        model: Module exposing an ``encode(images)`` method.
        dataloader: Yields indexable batches whose element 0 holds images and
            element 1 holds the per-image targets/attributes. Images are
            expected to contain raw pixel values in [0, 255] (e.g. the uint8
            output of ``transforms.PILToTensor``); they are divided by 255.
        device: Device the images are moved to before encoding.
        max_batches: Maximum number of batches to encode. The default of 8
            matches the previous hard-coded limit; pass ``None`` to encode the
            whole dataloader.

    Returns:
        ``(encoded, attributes)``: the encodings (moved to CPU) and the batch
        targets, each concatenated along dim 0.

    Raises:
        ValueError: if no batch was consumed (empty dataloader or
            ``max_batches == 0``).
    """
    encoded = []
    attributes = []

    model.eval()
    with torch.no_grad():
        # islice stops after `max_batches` items without fetching an extra
        # batch from the loader; a limit of None means "no limit".
        for batch in islice(dataloader, max_batches):
            # Convert to float on the target device and scale to [0, 1].
            # Using .to() avoids torch.tensor(<tensor>), which copies and warns.
            images = batch[0].to(device=device, dtype=torch.float) / 255
            encoded.append(model.encode(images).cpu())
            attributes.append(batch[1])

    if not encoded:
        raise ValueError("dataloader yielded no batches to encode")
    return torch.cat(encoded, dim=0), torch.cat(attributes, dim=0)

In [None]:
# Encode validation images into latent vectors (plus their attribute targets).
encoded, attributes = encode_dataset(
    model=model,
    dataloader=val_dataloader,
    device=device,
)

In [None]:
# Inspect the size of the encoded batch.
encoded.size()

In [None]:
# Per-dimension mean/std of the encodings (reduced over dim 0) and the number
# of images to generate.
mus = encoded.mean(dim=0)
sigmas = encoded.std(dim=0)
N_samples = 6
# Display the global mean / std over every encoded value.
encoded.mean(), encoded.std()

In [None]:
# Tile the per-dimension statistics to one row per sample. Recomputing them
# from `encoded` (instead of rebinding `mus`/`sigmas` from their own previous
# value) keeps the cell idempotent: a second run no longer fails in `repeat`.
mus = torch.mean(encoded, dim=0).unsqueeze(0).repeat(N_samples, 1)
sigmas = torch.std(encoded, dim=0).unsqueeze(0).repeat(N_samples, 1)

In [None]:
# Draw N_samples latent vectors to decode. By default every latent dimension is
# sampled from one Gaussian with the mean/std of *all* encoded values; set
# USE_PER_DIM_STATS = True to use the tiled per-dimension `mus` / `sigmas`.
USE_PER_DIM_STATS = False
if USE_PER_DIM_STATS:
    sample_mean, sample_std = mus, sigmas
else:
    sample_shape = (N_samples, latent_dim)
    sample_mean = torch.full(sample_shape, torch.mean(encoded).item())
    sample_std = torch.full(sample_shape, torch.std(encoded).item())
random_vectors = torch.normal(mean=sample_mean, std=sample_std).to(device)

In [None]:
# Peek at the first sampled latent vector.
random_vectors[0]

In [None]:
model.eval()
# Inference only: skip building an autograd graph for the decoder pass.
with torch.no_grad():
    generated_image = model.decode(random_vectors)

In [None]:
# Detach from any autograd graph and bring the decoded batch to the CPU for plotting.
generated_image = generated_image.detach().to("cpu")

In [None]:
# Convert decoder output (presumably in [0, 1] -- TODO confirm the decoder's
# output activation) to 8-bit pixel values. Clamping guards against values
# outside that range, which would otherwise wrap or be clipped by imshow.
# The float-dtype guard makes the cell idempotent: re-running it no longer
# multiplies by 255 a second time.
if generated_image.is_floating_point():
    generated_image = (generated_image * 255).round().clamp(0, 255).to(torch.uint8)

In [None]:
# Where generated images are written. The default is a machine-specific
# absolute path; set BLOG_IMAGES_DIR to point elsewhere.
output_dir = Path(os.environ.get(
    "BLOG_IMAGES_DIR",
    "/home/amal/UbuntuDocuments/writing/blog/generative_ml/images/vae",
))
SAVE = True
if SAVE:
    # savefig does not create missing directories.
    output_dir.mkdir(parents=True, exist_ok=True)

for i in range(N_samples):
    fig, ax = plt.subplots()
    # permute(1, 2, 0) moves the leading (channel) axis last, the (H, W, C)
    # layout imshow expects for colour images.
    ax.imshow(generated_image[i, :].squeeze().permute(1, 2, 0))
    ax.axis('off')
    if SAVE:
        fig.savefig(output_dir / f"ae_random_gen_{i}.png", bbox_inches='tight', pad_inches=0)
    plt.show()
    # pyplot keeps figures alive until closed; release each one in the loop.
    plt.close(fig)