In [1]:
# Mount Google Drive at /content/drive (Colab prompts for an authorization code);
# the notebook later saves its result images under this mount point.
from google.colab import drive
drive.mount('/content/drive')

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at /content/drive


In [0]:
import os
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.utils import save_image

In [3]:
# Define hyperparameters
image_size = 784   # length of one flattened 28x28 MNIST image
hidden_dim = 400   # width of the encoder/decoder hidden layer
latent_dim = 20    # size of the latent code z
batch_size = 128
epochs = 50

# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# MNIST dataset (downloaded into ../../data on first use)

train_dataset = torchvision.datasets.MNIST(root='../../data',
                                           train=True,
                                           download=True,
                                           transform=transforms.ToTensor())
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)

test_dataset = torchvision.datasets.MNIST(root='../../data',
                                          train=False,
                                          download=True,
                                          transform=transforms.ToTensor())
# Evaluation does not need shuffling: the average loss is order-independent, and a
# fixed order means the first batch (used for the saved reconstruction grid)
# contains the same digits every epoch, so the grids are comparable.
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          shuffle=False)

# Directory for the reconstructed and sampled images; created if it is missing
PATH = '/content/drive/My Drive/Deep_Learning_Udemy/VAE/'
sample_dir = os.path.join(PATH, 'results')
# exist_ok=True makes re-running this cell safe without a separate existence check
os.makedirs(sample_dir, exist_ok=True)


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../../data/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ../../data/MNIST/raw/train-images-idx3-ubyte.gz to ../../data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../../data/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ../../data/MNIST/raw/train-labels-idx1-ubyte.gz to ../../data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../../data/MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ../../data/MNIST/raw/t10k-images-idx3-ubyte.gz to ../../data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../../data/MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ../../data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../../data/MNIST/raw
Processing...



Done!




In [0]:
#VAE Model

class VAE(nn.Module):
  """Variational autoencoder built from fully connected layers.

  The encoder maps a flattened input to the mean and log-variance of a
  diagonal Gaussian over the latent code. The decoder maps a latent code back
  to one value per input element, squashed into (0, 1) by a sigmoid.

  Args:
    in_features: Length of one flattened input. Defaults to the module-level
      `image_size` when omitted.
    hidden_features: Width of the hidden layers. Defaults to the module-level
      `hidden_dim` when omitted.
    latent_features: Size of the latent code. Defaults to the module-level
      `latent_dim` when omitted.
  """

  def __init__(self, in_features=None, hidden_features=None, latent_features=None):
    super(VAE, self).__init__()

    # Fall back to the notebook-level hyperparameters so a bare `VAE()` keeps
    # building exactly the same network as before.
    if in_features is None:
      in_features = image_size
    if hidden_features is None:
      hidden_features = hidden_dim
    if latent_features is None:
      latent_features = latent_dim

    self.fc1 = nn.Linear(in_features, hidden_features)
    self.fc2_mean = nn.Linear(hidden_features, latent_features)
    self.fc2_logvar = nn.Linear(hidden_features, latent_features)
    self.fc3 = nn.Linear(latent_features, hidden_features)
    self.fc4 = nn.Linear(hidden_features, in_features)

  def encode(self, x):
    """Return `(mu, logvar)` of the approximate posterior for flattened input `x`."""
    h = F.relu(self.fc1(x))
    mu = self.fc2_mean(h)
    logvar = self.fc2_logvar(h)
    return mu, logvar

  def reparameterize(self, mu, logvar):
    """Draw `z = mu + std * eps` with `eps ~ N(0, I)` (reparameterization trick).

    `logvar` is the log of the variance, so the standard deviation is
    `exp(0.5 * logvar)`. Using `exp(logvar)` here would sample with the
    variance as the scale and disagree with a KL term that treats
    `logvar.exp()` as the variance.
    """
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)

    return mu + std * eps

  def decode(self, z):
    """Map latent codes `z` to reconstructions with values in (0, 1)."""
    h = F.relu(self.fc3(z))
    out = torch.sigmoid(self.fc4(h))
    return out

  def forward(self, x):
    """Encode `x`, sample a latent code, and decode it.

    `x` is flattened to `(-1, in_features)` before encoding.

    Returns:
      Tuple `(reconstructed, mu, logvar)`.
    """
    # Flatten by the encoder's own input width rather than a global constant.
    mu, logvar = self.encode(x.view(-1, self.fc1.in_features))
    z = self.reparameterize(mu, logvar)
    reconstructed = self.decode(z)
    return reconstructed, mu, logvar




In [0]:
  # Build the VAE, place it on the selected device, and optimise it with Adam
  model = VAE()
  model.to(device)
  optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)


In [0]:
# Define Loss
def loss_function(reconstructed_image, original_image, mu, logvar):
  """Return the summed VAE loss: reconstruction error plus KL divergence.

  Args:
    reconstructed_image: Decoder output with values in [0, 1].
    original_image: Target images with the same number of elements as
      `reconstructed_image`; reshaped to match it.
    mu: Mean of the approximate posterior.
    logvar: Log-variance of the approximate posterior.

  Returns:
    Scalar tensor `bce + kld`, summed (not averaged) over the batch.
  """
  # Match the target to the reconstruction's shape instead of hard-coding 784,
  # so the loss follows whatever input size the model was built with.
  target = original_image.reshape(reconstructed_image.shape)
  bce = F.binary_cross_entropy(reconstructed_image, target, reduction='sum')
  # KL( N(mu, exp(logvar)) || N(0, I) ) in closed form, summed over all elements.
  kld = 0.5 * torch.sum(logvar.exp() + mu.pow(2) - 1 - logvar)
  return bce + kld


In [0]:
# Train function
def train(epoch):
  """Run one pass over `train_loader`, updating `model` with `optimizer`.

  Prints the per-sample loss of every 100th batch and, at the end, the loss
  averaged over the whole training dataset. `epoch` is used for logging only.
  """
  model.train()
  total_batches = len(train_loader)
  running_loss = 0
  for step, batch in enumerate(train_loader):
    inputs, _ = batch
    inputs = inputs.to(device)

    # Clear stale gradients, then forward / backward / update
    optimizer.zero_grad()
    recon, mean, log_variance = model(inputs)
    batch_loss = loss_function(recon, inputs, mean, log_variance)
    batch_loss.backward()
    optimizer.step()

    batch_loss_value = batch_loss.item()
    running_loss += batch_loss_value

    if step % 100 != 0:
      continue
    per_sample_loss = batch_loss_value / len(inputs)
    print("Training epochs: {} [Batch {}/{}]\tLoss: {:.3f}".format(epoch, step, total_batches, per_sample_loss))

  epoch_average = running_loss / len(train_loader.dataset)
  print("==========================>Epochs {}, Average Loss for Training: {:.3f}".format(epoch, epoch_average))

In [0]:
# Test function
def test(epoch):
  """Evaluate `model` on `test_loader` and save a reconstruction comparison.

  Accumulates the summed loss over the test set and prints it averaged over
  the dataset. For the first batch, the first five inputs and their
  reconstructions are written to `sample_dir` as `reconstruction_<epoch>.png`.
  """
  model.eval()
  test_loss = 0
  # No parameter updates happen here, so skip autograd bookkeeping entirely.
  with torch.no_grad():
    for batch_idx, (images, _) in enumerate(test_loader):
      images = images.to(device)
      reconstructed, mu, logvar = model(images)
      loss = loss_function(reconstructed, images, mu, logvar)
      test_loss += loss.item()
      if batch_idx == 0:
        # Reshape to the actual batch's shape rather than the global batch_size,
        # so a first batch smaller than batch_size does not raise.
        comparison = torch.cat([images[:5], reconstructed.view_as(images)[:5]])
        save_image(comparison.cpu(), sample_dir+'/reconstruction_' + str(epoch)+ '.png', nrow=5)

  print("==========================>Epochs {}, Average Loss for Testing: {:.3f} ".format(epoch, test_loss/ len(test_loader.dataset)))

In [9]:
# Main loop: train, evaluate, then save a grid of decoder samples after every epoch
for epoch in range(1, epochs + 1):
  train(epoch)
  test(epoch)
  with torch.no_grad():
    # Skip the encoder: draw z from the standard normal prior and decode it into
    # 64 sample images. The width comes from latent_dim (not a literal 20) so it
    # always matches the decoder input, and the noise is created on `device`.
    sample = torch.randn(64, latent_dim, device=device)
    generated = model.decode(sample).cpu()
    save_image(generated.view(64, 1, 28, 28), sample_dir+'/sample_' +  str(epoch) + '.png' )

Training epochs: 1 [Batch 0/469]	Loss: 549.412
Training epochs: 1 [Batch 100/469]	Loss: 194.161
Training epochs: 1 [Batch 200/469]	Loss: 160.623
Training epochs: 1 [Batch 300/469]	Loss: 139.082
Training epochs: 1 [Batch 400/469]	Loss: 117.197
Training epochs: 2 [Batch 0/469]	Loss: 116.225
Training epochs: 2 [Batch 100/469]	Loss: 116.435
Training epochs: 2 [Batch 200/469]	Loss: 112.324
Training epochs: 2 [Batch 300/469]	Loss: 107.715
Training epochs: 2 [Batch 400/469]	Loss: 104.583
Training epochs: 3 [Batch 0/469]	Loss: 105.166
Training epochs: 3 [Batch 100/469]	Loss: 101.726
Training epochs: 3 [Batch 200/469]	Loss: 103.773
Training epochs: 3 [Batch 300/469]	Loss: 100.210
Training epochs: 3 [Batch 400/469]	Loss: 100.751
Training epochs: 4 [Batch 0/469]	Loss: 101.309
Training epochs: 4 [Batch 100/469]	Loss: 101.296
Training epochs: 4 [Batch 200/469]	Loss: 95.138
Training epochs: 4 [Batch 300/469]	Loss: 99.501
Training epochs: 4 [Batch 400/469]	Loss: 97.506
Training epochs: 5 [Batch 0/469