In [None]:
import numpy as np
import matplotlib.pyplot as plt

import torch 
from torch.autograd import Variable
import torch.nn as nn
from torch.nn import Conv2d, InstanceNorm2d, ConvTranspose2d
from torch import optim 
import torch.nn.functional as F 
import torchvision

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [None]:
# this won't work if the drive is shared with you
# solution: right click on the shared  project folder and select 'add shortcut to my drive'
%cd gdrive/My\ Drive/CS7643_Final_Project

In [None]:
class VAE_Classifier(nn.Module):
  def __init__(self, latent_dim = 20):
    super(VAE_Classifier, self).__init__()
    self.latent_dim = latent_dim

    self.cn1 = nn.Conv2d(1, 32, kernel_size = 3, stride = 2, padding = 1)
    self.cn2 = nn.Conv2d(32, 64, kernel_size = 3, stride = 2, padding = 1)
    self.m = nn.Linear(64 * 7 * 7, latent_dim)
    self.v = nn.Linear(64 * 7 * 7, latent_dim)

    self.fc1 = nn.Linear(latent_dim, 40)
    self.fc2 = nn.Linear(40, 40)
    self.fc3 = nn.Linear(40, 30)
    self.fc4 = nn.Linear(30, 20)
    self.fc5 = nn.Linear(20, 10)

  def encode(self, x):
    x = F.relu(self.cn1(x))
    x = F.relu(self.cn2(x))
    x = x.view(x.shape[0], -1)
  
    mu = F.relu(self.m(x))
    logvar = F.relu(self.v(x))
    return mu, logvar
  
  def decode(self, z):
    out = F.relu(self.fc1(z))
    out = F.relu(self.fc2(out))
    out = F.relu(self.fc3(out))
    out = F.relu(self.fc4(out))
    out = self.fc5(out)
    return out
  
  def reparameterize(self, mu, logvar):
    std = torch.exp(0.5*logvar)
    eps = torch.randn_like(std)
    return mu + eps*std
  
  def forward(self, x):
    mu, logvar = self.encode(x)
    lz = self.reparameterize(mu, logvar)
    out = self.decode(lz)
    return out, mu, logvar

In [None]:
def loss_function(pred, label, mu, logvar):
  crossentropy = nn.CrossEntropyLoss()
  
  CE = crossentropy(pred, label)
  KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
  return CE + KLD

In [None]:
def evaluate_model(model, dataloader, criterion = nn.CrossEntropyLoss()):
    """
    Calculates the average loss and accuracy of the model on a dataset
    """
    model.eval() # Notify all layers we're in eval mode instead of training mode
    loss, total, correct = 0, 0, 0
    n = 0
    with torch.no_grad():
      for xs, ys in dataloader:
          output, _, _ = model(xs.to(device))
          loss += criterion(output, ys.to(device)).item()
          _, predictions = torch.max(output.data, 1)
          total += ys.size(0)
          correct += (predictions == ys.to(device)).sum().item()
          n += 1
    
    print(f'\nTest loss: {loss / n: .2f} | Test accuracy: {correct / total : .2f}')
    return loss / n, 100 * correct / total

In [None]:
batch_size = 128

In [None]:
train_data_loader = torch.utils.data.DataLoader(
  torchvision.datasets.MNIST('/files/', train=True, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ])),
  batch_size=batch_size, shuffle=True)

test_data_loader = torch.utils.data.DataLoader(
  torchvision.datasets.MNIST('/files/', train=False, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ])),
  batch_size=batch_size, shuffle=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /files/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting /files/MNIST/raw/train-images-idx3-ubyte.gz to /files/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /files/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting /files/MNIST/raw/train-labels-idx1-ubyte.gz to /files/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /files/MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting /files/MNIST/raw/t10k-images-idx3-ubyte.gz to /files/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /files/MNIST/raw/t10k-labels-idx1-ubyte.gz





HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting /files/MNIST/raw/t10k-labels-idx1-ubyte.gz to /files/MNIST/raw
Processing...
Done!


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [None]:
log_epoch = 10

In [None]:
classifier = VAE_Classifier().to(device)
optimizer = optim.Adam(params = classifier.parameters())

In [None]:
losses = []
test_losses = []
epochs = 200

In [None]:
for epoch in range(1, epochs + 1):
  total = 0
  n = 0
  for x, y in train_data_loader:

    optimizer.zero_grad()
    
    out, mu, logvar = classifier(x.to(device))
    loss = loss_function(out, y.to(device), mu, logvar)
    loss.backward()
    optimizer.step()

    total += loss.item()
    n += 1
  
  losses.append(total / n)
  test_loss, test_accuracy = evaluate_model(classifier, test_data_loader)
  # put model back in train mode after putting it in eval mode
  classifier.train()
  test_losses.append(test_loss)
  
  print('Epoch: {}, Avg Loss: {:.4f}'.format(epoch, total / n))


Test loss:  2.30 | Test accuracy:  0.11
Epoch: 1, Avg Loss: 2.3136

Test loss:  2.30 | Test accuracy:  0.11
Epoch: 2, Avg Loss: 2.3016

Test loss:  2.30 | Test accuracy:  0.11
Epoch: 3, Avg Loss: 2.3015

Test loss:  2.30 | Test accuracy:  0.11
Epoch: 4, Avg Loss: 2.3015

Test loss:  2.30 | Test accuracy:  0.11
Epoch: 5, Avg Loss: 2.3015

Test loss:  2.30 | Test accuracy:  0.11
Epoch: 6, Avg Loss: 2.3015

Test loss:  2.30 | Test accuracy:  0.11
Epoch: 7, Avg Loss: 2.3014

Test loss:  2.30 | Test accuracy:  0.11
Epoch: 8, Avg Loss: 2.3014

Test loss:  2.30 | Test accuracy:  0.11
Epoch: 9, Avg Loss: 2.3013

Test loss:  2.30 | Test accuracy:  0.11
Epoch: 10, Avg Loss: 2.3014

Test loss:  2.30 | Test accuracy:  0.11
Epoch: 11, Avg Loss: 2.3013

Test loss:  2.30 | Test accuracy:  0.11
Epoch: 12, Avg Loss: 2.3013

Test loss:  2.30 | Test accuracy:  0.11
Epoch: 13, Avg Loss: 2.3014

Test loss:  2.30 | Test accuracy:  0.11
Epoch: 14, Avg Loss: 2.3013

Test loss:  2.30 | Test accuracy:  0.11
Ep

KeyboardInterrupt: ignored