<a href="https://colab.research.google.com/github/YifanXu1999/AI-Learning/blob/master/VAEMNISTLatent5.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms

from torch.utils.data.dataloader import DataLoader
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
!pip install tsne
from tsne import bh_sne

In [0]:
# Data loaders
# Training split: shuffled mini-batches of 128 images, downloaded into ./data
# on first use and converted to tensors by ToTensor().
trainloader = DataLoader(
    MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor()),
    batch_size=128, shuffle=True)
# Test split: large shuffled batches of 5000 images, used later for
# visualisation of reconstructions and latent codes.
testloader = DataLoader(
    MNIST(root='./data', train=False, download=True, transform=transforms.ToTensor()),
    batch_size=5000, shuffle=True)
# Prefer the first GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA instead of failing when the model is moved to `device`.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(torch.cuda.device_count())

In [0]:
class VAE(nn.Module):
    """Fully connected variational autoencoder for 28x28 single-channel images.

    The encoder maps a flattened 784-pixel image through one hidden layer to
    the mean and log-variance of a diagonal Gaussian over the latent space.
    The decoder maps a latent vector through one hidden layer back to 784
    sigmoid-activated pixel values.

    Args:
        latent_dim: Size of the latent vector z.
        hidden_dim: Width of the single hidden layer in encoder and decoder.
    """

    def __init__(self, latent_dim=5, hidden_dim=500):
        super(VAE, self).__init__()
        self.encoder_l1 = nn.Linear(784, hidden_dim)
        self.encoder_mean = nn.Linear(hidden_dim, latent_dim)
        self.encoder_logvar = nn.Linear(hidden_dim, latent_dim)
        self.decoder_l1 = nn.Linear(latent_dim, hidden_dim)
        self.decoder_output = nn.Linear(hidden_dim, 784)

    def encode(self, x_in):
        """Return (mean, logvar) of q(z|x); the input is flattened to (-1, 784)."""
        x = F.relu(self.encoder_l1(x_in.view(-1, 784)))
        mean = self.encoder_mean(x)
        logvar = self.encoder_logvar(x)
        return mean, logvar

    def decode(self, z):
        """Map latent vectors to images in [0, 1], shaped (-1, 1, 28, 28)."""
        z = F.relu(self.decoder_l1(z))
        x_out = torch.sigmoid(self.decoder_output(z))
        return x_out.view(-1, 1, 28, 28)

    def sample(self, mu, log_var):
        """Draw z ~ N(mu, exp(log_var)) with the reparameterisation trick.

        z = mu + standard deviation * eps, where eps ~ N(0, I).
        """
        # randn_like creates the noise directly on mu's device and dtype, so
        # this works on CPU as well as GPU and avoids a host-to-device copy.
        eps = torch.randn_like(mu)
        sd = torch.exp(log_var * 0.5)
        z = mu + sd * eps
        return z

    def forward(self, x_in):
        """Encode, sample and decode; returns (reconstruction, mean, logvar)."""
        z_mean, z_logvar = self.encode(x_in)
        z = self.sample(z_mean, z_logvar)
        x_out = self.decode(z)
        return x_out, z_mean, z_logvar

In [0]:
# Loss function
def criterion(x_out, x_in, z_mu, z_logvar, kld_weight=1.0):
    """Negative ELBO of the VAE, averaged over the batch.

    ELBO = -KL(q(z|x) || p(z)) + log p_theta(x|z)
         = 1/2 * sum(1 + log(var) - mu^2 - var) + log p_theta(x|z)

    The loss is -ELBO: summed binary cross-entropy (reconstruction term) plus
    the closed-form KL divergence between the diagonal Gaussian q(z|x) and the
    standard normal prior.

    Args:
        x_out: Reconstructed images with values in [0, 1].
        x_in: Target images, same shape as ``x_out``.
        z_mu: Mean of q(z|x).
        z_logvar: Log-variance of q(z|x).
        kld_weight: Multiplier on the KL term. 1.0 gives the standard ELBO;
            0.0 reduces the loss to reconstruction error only.

    Returns:
        Scalar tensor: (BCE + kld_weight * KL) divided by the batch size.
    """
    bce_loss = F.binary_cross_entropy(x_out, x_in, reduction='sum')
    # The KL term must stay in the objective: without it the model is trained
    # as a plain autoencoder and the latent space is not regularised.
    kld_loss = -0.5 * torch.sum(1 + z_logvar - (z_mu ** 2) - torch.exp(z_logvar))
    loss = (bce_loss + kld_weight * kld_loss) / x_out.size(0)  # normalize by batch size
    return loss

In [0]:
def remove_imgs(imgs, labels, digits_to_remove):
    """Drop every image whose label appears in ``digits_to_remove``.

    Args:
        imgs: Iterable of image tensors; each kept item is concatenated along
            dim 0 and the result is reshaped to (-1, 1, 28, 28).
        labels: Sequence indexable by position, aligned with ``imgs``.
        digits_to_remove: Container of label values to filter out.

    Returns:
        (images, labels): a tensor of shape (k, 1, 28, 28) holding the k kept
        images, and a list with the corresponding k labels.
    """
    kept_imgs = []
    new_labels = []
    for i, img in enumerate(imgs):
        if labels[i] not in digits_to_remove:
            kept_imgs.append(img)
            new_labels.append(labels[i])
    # Concatenate once instead of growing a tensor inside the loop, which
    # re-copied all previously kept images on every iteration. The leading
    # empty tensor keeps the result a tensor when nothing is kept.
    new_imgs = torch.cat([torch.tensor([])] + kept_imgs, 0)
    return new_imgs.view(-1, 1, 28, 28), new_labels

In [0]:
# Training
def train(model, optimizer, dataloader, epochs=4):
    """Train ``model`` with ``criterion`` and return the per-batch losses.

    Args:
        model: Module whose forward pass returns (x_out, z_mu, z_logvar).
        optimizer: Optimizer already bound to ``model``'s parameters.
        dataloader: Iterable yielding (images, labels) batches; labels are
            ignored.
        epochs: Number of full passes over ``dataloader``.

    Returns:
        List with one scalar loss value per optimisation step.
    """
    # Run on whatever device the model lives on rather than assuming CUDA.
    model_device = next(model.parameters()).device
    losses = []
    for epoch in range(epochs):
        for images, labels in dataloader:
            x_in = images.to(model_device)
            optimizer.zero_grad()
            x_out, z_mu, z_logvar = model(x_in)
            loss = criterion(x_out, x_in, z_mu, z_logvar)
            loss.backward()
            optimizer.step()
            # .item() extracts the Python scalar without keeping the graph.
            losses.append(loss.item())
    print("done")

    return losses

In [0]:
# Instantiate the VAE on the selected device and train it with Adam.
model = VAE().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
train_losses = train(model, optimizer, trainloader)

# Plot the loss recorded after every optimisation step.
loss_fig, loss_ax = plt.subplots(figsize=(10, 5))
loss_ax.plot(train_losses)
plt.show()

In [0]:
def imshow(img):
    """Display a channel-first image tensor with matplotlib, axes hidden."""
    chw = img.cpu().numpy()
    # matplotlib expects channels last, so move axis 0 to the end.
    hwc = chw.transpose(1, 2, 0)
    plt.imshow(hwc)
    plt.axis('off')
    plt.show()

# Fetch one batch from the test loader. Python 3 iterators are advanced with
# the built-in next(); the `.next()` method is a Python 2 idiom that current
# DataLoader iterators do not expose.
imgs, label = next(iter(testloader))

def add_noise(x, noise_factor=0.2):
    """Add scaled standard-normal noise to ``x`` and clip the result to [0, 1].

    Args:
        x: Array-like with a ``shape`` attribute and a ``clip`` method.
        noise_factor: Multiplier applied to the standard-normal noise.

    Returns:
        The noisy values, limited to the range [0, 1].
    """
    noisy = x + noise_factor * np.random.randn(*x.shape)
    return noisy.clip(0., 1.)
def visualize(images, label, model):
    """Show 16 inputs and their reconstructions; return the sampled latents.

    Args:
        images: Batch of images on the same device as ``model``.
        label: Unused; kept so existing calls keep working.
        model: Object providing ``encode``, ``sample`` and ``decode``.

    Returns:
        NumPy array with one sampled latent vector per input image.
    """
    x_in = images
    # No gradients are needed for visualisation, so skip building the graph.
    with torch.no_grad():
        z_mu, z_logvar = model.encode(x_in)
        # Sample once and decode that same sample, so the reconstructions on
        # screen correspond to the latent codes that are returned.
        z = model.sample(z_mu, z_logvar)
        # Only the first 16 images are displayed, so only those are decoded.
        x_out = model.decode(z[0:16])
    imshow(make_grid(x_in[0:16]))
    imshow(make_grid(x_out))
    return z.cpu().numpy()

In [0]:
# Draw a test batch with the built-in next() (the `.next()` method is a
# Python 2 idiom) and pass this batch's own labels along with its images.
imgs, labels = next(iter(testloader))
z = visualize(imgs.to(device), labels, model)

In [0]:
# Embed the latent codes with bh_sne (presumably Barnes-Hut t-SNE; confirm
# against the tsne package), passing them as float64.
# NOTE: this overwrites `z`, so re-running the cell embeds the embedding.
z = bh_sne(z.astype(np.float64))

In [0]:
# Scatter the first two embedding coordinates, coloured by label, with a
# colour bar that maps colours back to label values.
embedding_points = plt.scatter(z[:, 0], z[:, 1], c=labels)
plt.colorbar(embedding_points)