In [None]:
!pip install pycm livelossplot
!pip install torchsummary 
!pip install tsne_torch

%pylab inline

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
import torchvision.datasets
import torchvision.transforms as transforms
from torchsummary import summary
from torchvision.datasets import FashionMNIST, MNIST
import torch.nn.functional as F
from collections import Counter, defaultdict
from tqdm import tqdm
from tsne_torch import TorchTSNE as TSNE
from livelossplot import PlotLosses
import random

import sys
sys.path.insert(1, '..')
import models

In [None]:
def set_seed(seed):
    """Seed every random number generator the notebook relies on.

    Seeds Python's `random`, NumPy's global RNG and PyTorch (CPU and all CUDA
    devices), then turns off cuDNN autotuning and disables cuDNN altogether,
    trading speed for run-to-run reproducibility.

    Args:
        seed: integer seed passed to every generator.

    Returns:
        True.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False  # no autotuner: it may pick different conv algorithms run to run
    cudnn.enabled = False    # bypass cuDNN entirely (slower, but avoids its nondeterministic kernels)
    return True



In [None]:
# Run on the GPU only when CUDA reports at least one usable device.
if not (torch.cuda.device_count() > 0 and torch.cuda.is_available()):
    print("No GPU available!")
    device = 'cpu'
else:
    print("Cuda installed! Running on GPU!")
    device = 'cuda'

In [None]:
# Colab only: mount Google Drive so that files under "My Drive" become
# reachable below /content/gdrive (the saved training plot is read from there).
from google.colab import drive

mount_point = '/content/gdrive/'
drive.mount(mount_point)

<br>

---

<br>

In [None]:
# Mini-batch size used by the training DataLoader.
batch_size = 100

In [None]:
# FashionMNIST training split, downloaded into the working directory if absent.
# Images become tensors and are mirrored left-right with probability 0.5 as a
# light augmentation; each integer label is wrapped in a 1-element LongTensor,
# so batched labels come out with shape (batch, 1).
fashion_mnist_dataset = FashionMNIST(
    root="./",
    train=True,
    download=True,
    transform=transforms.Compose(
        [transforms.ToTensor(), transforms.RandomHorizontalFlip(p=0.5)]
    ),
    target_transform=torchvision.transforms.Compose(
        [lambda target: torch.LongTensor([target])]
    ),
)

In [None]:
# Display the dataset object's summary as a quick check of how it was configured.
fashion_mnist_dataset

In [None]:
# torchvision's class-name -> integer-label mapping for this dataset
# (the integers are the label values used in `targets`).
class_to_idx = fashion_mnist_dataset.class_to_idx
class_to_idx

In [None]:
# Number of training samples per class label, to check the split is balanced.
Counter(fashion_mnist_dataset.targets.cpu().detach().numpy())

In [None]:
# Size of the first stored image tensor in `dataset.data` (not a transformed sample).
fashion_mnist_dataset.data[0].size()

#### Plotting 10 samples from each class

In [None]:
def plot_classes(dataset, num_per_class=10):
  """Show `num_per_class` example images for every class, one class per row.

  Images are read directly from `dataset.data`. torchvision applies a
  dataset's `transform` only in `__getitem__`, so augmentations are not shown.
  Examples are taken in dataset order, and the scan stops as soon as every row
  is full.

  Args:
    dataset: object exposing `.data` (one image per sample) and `.targets`
      (integer class labels starting at 0). If it also has `.classes`, those
      names label the rows.
    num_per_class: number of examples (grid columns) drawn for each class.
  """
  targets = torch.as_tensor(dataset.targets)
  if num_per_class <= 0 or targets.numel() == 0:
    return
  class_names = list(getattr(dataset, 'classes', None) or [])
  num_classes = max(len(class_names), int(targets.max().item()) + 1)

  # squeeze=False keeps `ax` 2-D even for a single row or column.
  _, ax = plt.subplots(num_classes, num_per_class,
                       figsize=[2 * num_per_class, 2 * num_classes],
                       squeeze=False)
  filled = [0] * num_classes  # images already placed in each class row
  remaining = num_classes * num_per_class
  for img, label_tensor in zip(dataset.data, targets):
    label = int(label_tensor.item())
    col = filled[label]
    if col >= num_per_class:
      continue
    ax[label, col].imshow(torch.as_tensor(img).cpu().numpy().squeeze(), cmap='gray')
    filled[label] += 1
    remaining -= 1
    if remaining == 0:
      break  # every class has its quota; no need to scan the rest of the dataset

  for axis in ax.flat:
    axis.set_xticks([])
    axis.set_yticks([])
  for row, name in enumerate(class_names):
    ax[row, 0].set_ylabel(name)

# Grid of example training images, one row per class.
plot_classes(dataset=fashion_mnist_dataset)

In [None]:
# Shuffled mini-batches of (image, label) pairs for training.
train_loader = DataLoader(
    fashion_mnist_dataset,
    batch_size=batch_size,
    shuffle=True,
)

### Load the cVAE model

In [1]:
# Sanity check that the project's `models.Conditional_VAE` builds and moves to
# `device`. NOTE(review): `C` is not used by any later cell shown here, and it
# is built with 20 latent dims while the training cell uses `latent_dims = 50`.
C = models.Conditional_VAE(20)
C = C.to(device)

In [None]:
# Live training-loss plot: a single 'Loss' panel tracking the 'VAE_Loss' metric.
groups = {'Loss': ['VAE_Loss']}
liveloss = PlotLosses(groups=groups)

# Hyperparameters for the cVAE run.
lr = 1e-4         # Adam learning rate
latent_dims = 50  # dimensionality of the latent vector z
epochs = 30       # passes over the training set

# Per-batch histories that `train` appends to.
recon_losses, kl_divs = [], []

def train(vae, data, kl_div_on=True, epochs=10, device='cpu', lr=1e-4,
          save_every=5, checkpoint_fmt="./VAE_{}.pth"):
  """Train a conditional VAE with Adam on (image, label) mini-batches.

  The per-batch objective is the mean squared reconstruction error plus, when
  `kl_div_on` is True, the KL term returned by the model.

  Side effects on module-level state:
    * appends every batch's reconstruction loss to `recon_losses` and its KL
      term to `kl_divs`;
    * pushes the epoch's mean loss to the `liveloss` plotter after each epoch;
    * saves `vae.state_dict()` every `save_every` epochs.

  Args:
    vae: model called as `vae(batch, label)` and returning `(x_hat, kl_div)`.
    data: iterable of `(batch, label)` tensor pairs, e.g. a DataLoader.
    kl_div_on: include the KL term in the optimised loss. Set it to False to
      train a plain conditional autoencoder; the KL term is still logged.
    epochs: number of passes over `data`.
    device: device the batches are moved to (`vae` must already be there).
    lr: Adam learning rate.
    save_every: checkpoint interval in epochs; 0 disables checkpointing.
    checkpoint_fmt: checkpoint path template, formatted with the epoch number.

  Returns:
    The trained `vae` (the same object, left in training mode).
  """
  opt = torch.optim.Adam(vae.parameters(), lr=lr, betas=(0.5, 0.9))
  vae.train()
  for epoch in range(1, epochs + 1):
    epoch_losses = []
    for batch, label in data:
      batch = batch.to(device)
      label = label.to(device)
      opt.zero_grad()

      x_hat, kl_div = vae(batch, label)
      recon = ((batch - x_hat) ** 2).mean()
      loss = (recon + kl_div) if kl_div_on else recon

      loss.backward()
      opt.step()

      epoch_losses.append(loss.item())
      recon_losses.append(recon.detach().cpu().numpy())
      kl_divs.append(kl_div.detach().cpu().numpy())

    if epoch_losses:
      # Report the mean over the whole epoch, not just the final batch's loss.
      mean_loss = sum(epoch_losses) / len(epoch_losses)
      liveloss.update({'VAE_Loss': mean_loss})
      liveloss.draw()
      print(f"epoch {epoch}/{epochs}: mean loss {mean_loss:.6f}")

    if save_every and epoch % save_every == 0:
      # One file per checkpoint epoch instead of overwriting a single file.
      torch.save(vae.state_dict(), checkpoint_fmt.format(epoch))
  return vae


# Build the cVAE from the project's `models` module. The bare name
# `Conditional_VAE` is never imported in this notebook, so it raised NameError.
conditional_vae = models.Conditional_VAE(latent_dims).to(device)
# `train` switches the model to training mode itself.
conditional_vae = train(conditional_vae, train_loader, lr=lr, epochs=epochs, device=device)
# Evaluation mode for the plotting and sampling cells that follow.
conditional_vae.eval()


In [None]:
# Show a previously saved training-loss plot from the mounted Google Drive
# (presumably captured from the 30-epoch run, judging by the variable name).
training_plot_path = "/content/gdrive/My Drive/images/vae_training.png"
vae_30_epochs = plt.imread(training_plot_path, format='png')

f, axarr = plt.subplots(figsize=(10, 10))
axarr.title.set_text('VAE Training')
axarr.imshow(vae_30_epochs)

#### Plotting the latent space

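We use t-SNE to project a sample of latent vectors down to two dimensions so that they can be plotted, coloured by class label. The embedded points look roughly normally distributed, with a few outliers. Keep in mind that t-SNE distorts global distances and densities, so this is only a qualitative impression.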

In [None]:

def plot_2D_latent_space(autoencoder, data, num_batches=40, perplexity=30, n_iter=100):
    """Scatter a 2-D t-SNE embedding of latent codes, coloured by class label.

    The latent codes of the first `num_batches` batches of `data` are gathered
    and embedded with a single t-SNE fit, so all points share one coordinate
    system. Separate fits per batch are not comparable and cannot be overlaid.
    Draws into the current matplotlib figure; call `plt.show()` afterwards.

    Args:
        autoencoder: model exposing `encoder(x, y)` and `vae_latent_space(h)`,
            where the latter returns `(z, kl)`.
        data: iterable of `(x, y)` batches, e.g. a DataLoader.
        num_batches: number of batches to embed.
        perplexity: t-SNE perplexity.
        n_iter: number of t-SNE optimisation iterations.
    """
    latents, labels = [], []
    with torch.no_grad():  # inference only: don't build an autograd graph
        for n, (x, y) in enumerate(data):
            if n >= num_batches:
                break
            z, _ = autoencoder.vae_latent_space(autoencoder.encoder(x.to(device), y.to(device)))
            latents.append(z.cpu())
            labels.append(y.reshape(-1).cpu())
    if not latents:
        return

    z_all = torch.cat(latents).numpy()
    y_all = torch.cat(labels).numpy()
    z_2d = TSNE(n_components=2, perplexity=perplexity, n_iter=n_iter, verbose=True).fit_transform(z_all)
    plt.scatter(z_2d[:, 0], z_2d[:, 1], c=y_all, cmap='tab10')
    plt.colorbar(label='class label')
    plt.title('t-SNE of cVAE latent codes')
    plt.xlabel('t-SNE dim 1')
    plt.ylabel('t-SNE dim 2')
# Embed latent codes of training batches in 2-D and render the scatter.
plot_2D_latent_space(autoencoder=conditional_vae, data=train_loader)
plt.show()


#### Reconstructed Images

In [None]:
# Reconstruct a few random training images.
# Top row: originals. Bottom row: reconstructions conditioned on the true label.
images, labels = next(iter(train_loader))  # one shuffled batch

_, ax = plt.subplots(2, 5, figsize=[15, 6])
with torch.no_grad():  # inference only
  # randperm yields 5 distinct indices (randint could pick the same image twice).
  for n, idx in enumerate(torch.randperm(images.shape[0])[:5]):
    # `.to(device)` instead of `.cuda()` so the cell also runs on CPU-only machines.
    recon, _ = conditional_vae(images[idx].unsqueeze(0).to(device), labels[idx].unsqueeze(0).to(device))
    ax[0, n].imshow(images[idx].squeeze(), cmap="gray")
    ax[1, n].imshow(recon.cpu().squeeze(), cmap="gray")
ax[0, 0].set_ylabel("original")
ax[1, 0].set_ylabel("reconstruction")

#### Generated Images 

Generate new images by decoding random latent vectors z, each conditioned on a class label (one class per column of the grid).

In [None]:
# Fix every RNG so the generated sample grid below is reproducible.
set_seed(seed=0)

def plot_samples(vae, num_classes=10, num_per_class=10):
  """Decode random latent vectors into a grid of generated images.

  Column c of the grid holds samples conditioned on class label c; each of the
  `num_per_class` rows is a fresh set of random latent vectors.

  Args:
    vae: model exposing `latentOut`, `activationOut` and `decoder(h, labels)`.
    num_classes: number of class labels (grid columns), labels 0..num_classes-1.
    num_per_class: samples generated per class (grid rows).
  """
  num_samples = num_classes * num_per_class
  with torch.no_grad():
    # Sample z from a standard normal; torch.rand would draw from U[0, 1).
    # NOTE(review): assumes the model's KL term uses an N(0, I) prior; confirm
    # in models.Conditional_VAE.
    test_z = torch.randn(num_samples, latent_dims, device=device)
    # Labels 0..num_classes-1 repeated per row, so the row-major grid puts one class per column.
    labels = torch.arange(num_classes, device=device).repeat(num_per_class).view(-1, 1)
    hidden = vae.activationOut(vae.latentOut(test_z))
    generated = vae.decoder(hidden, labels).cpu().numpy()

  _, axarr = plt.subplots(num_per_class, num_classes,
                          figsize=(1.2 * num_classes, 1.2 * num_per_class),
                          squeeze=False)
  for ax, img in zip(axarr.flatten(), generated):
    ax.imshow(img.squeeze(0), cmap="gray")
  for c in range(num_classes):
    axarr[0, c].set_title(f"class {c}")

# Generated images from the trained cVAE.
plot_samples(vae=conditional_vae)