In [2]:
# Mount Google Drive into the Colab filesystem at /content/drive
from google.colab import drive
drive.mount('/content/drive')

# Enable IPython's autoreload extension; mode 2 picks up edits to imported .py modules live
%load_ext autoreload
%autoreload 2

Mounted at /content/drive


In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import os

# MNIST Dataset
# Both splits are built with identical settings; only the `train` flag differs.
train_dataset, test_dataset = (
    datasets.MNIST(root='./mnist_data/', train=use_train_split, transform=transforms.ToTensor(), download=True)
    for use_train_split in (True, False)
)

# DataLoader
batch_size = 64
# Only the training loader shuffles; the test loader keeps the dataset order.
train_dataloader, test_dataloader = (
    torch.utils.data.DataLoader(dataset=split, batch_size=batch_size, shuffle=reshuffle)
    for split, reshuffle in ((train_dataset, True), (test_dataset, False))
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./mnist_data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./mnist_data/MNIST/raw/train-images-idx3-ubyte.gz to ./mnist_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./mnist_data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./mnist_data/MNIST/raw/train-labels-idx1-ubyte.gz to ./mnist_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./mnist_data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./mnist_data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./mnist_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./mnist_data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./mnist_data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./mnist_data/MNIST/raw



In [4]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

device(type='cuda')

In [5]:
# Switch the working directory to the project folder on the mounted Drive.
# NOTE(review): presumably this is where vae.py, loss.py and plot_utils.py live, since the
# cells below import from them by bare module name — confirm. The path is Colab/Drive specific.
%cd /content/drive/MyDrive/Deep-Learning-Paper-Practice/Auto-Encoding Variational Bayes

/content/drive/MyDrive/Deep-Learning-Paper-Practice/Auto-Encoding Variational Bayes


In [6]:
from vae import VAE

# Build the VAE (784-dim input, two hidden widths, 2-dim latent) and move it to `device`.
vae = VAE(x_dim=784, h_dim1=512, h_dim2=256, z_dim=2).to(device)
vae  # last expression: display the module's layer summary

VAE(
  (fc1): Linear(in_features=784, out_features=512, bias=True)
  (fc2): Linear(in_features=512, out_features=256, bias=True)
  (fc31): Linear(in_features=256, out_features=2, bias=True)
  (fc32): Linear(in_features=256, out_features=2, bias=True)
  (fc4): Linear(in_features=2, out_features=256, bias=True)
  (fc5): Linear(in_features=256, out_features=512, bias=True)
  (fc6): Linear(in_features=512, out_features=784, bias=True)
)

In [7]:
# Adam over every VAE parameter; all hyperparameters are left at the library defaults.
optimizer = optim.Adam(params=vae.parameters())
# loss_function lives in loss.py; per the original note it returns
# reconstruction error + KL divergence losses.
from loss import loss_function

In [8]:
from plot_utils import plot_latent_space, plot_label_clusters

In [9]:
# Loader that yields the whole training set as one unshuffled batch
# (batch size equals the dataset length).
all_dataloader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=len(train_dataset),
    shuffle=False,
)

In [10]:
# Pull the first (and, given the batch size above, only) batch: every training image and label.
# These two names are read later as globals by train() for the label-cluster plot.
images, labels = next(iter(all_dataloader))

In [11]:
# Sanity check: flattening each image to 784 values should give one row per training image.
images.view(-1,784).shape

torch.Size([60000, 784])

In [12]:
def train(vae, dataloader, epoch, log_interval=100, save_epochs=(1, 20)):
    """Run one training epoch, print progress, and draw the latent-space plots.

    Relies on names defined in other notebook cells: ``optimizer``, ``device``,
    ``loss_function``, ``plot_latent_space``, ``plot_label_clusters``,
    ``images`` and ``labels``.

    Args:
        vae: Model to train; called as ``vae(x)`` and expected to return
            ``(recon, mu, log_var)``.
        dataloader: Iterable of ``(x, label)`` batches (labels are ignored).
            Must expose ``.dataset`` so the epoch loss can be averaged.
        epoch: Epoch number, used in log lines and forwarded to the plot helpers.
        log_interval: Print a progress line every this many batches
            (batch 0 is never logged).
        save_epochs: Epoch numbers for which ``save_plot=True`` is passed to
            the plot helpers.

    Returns:
        The sum of all batch losses divided by ``len(dataloader.dataset)``.
    """
    print("="*70)
    print(f"Epoch:{epoch: 3d}")
    print("Train")

    vae.train()

    train_loss = 0

    for batch_idx, (x, _) in enumerate(dataloader):
        optimizer.zero_grad()
        x = x.to(device)

        recon, mu, log_var = vae(x)
        loss = loss_function(recon, x, mu, log_var)

        loss.backward()
        batch_loss = loss.item()
        train_loss += batch_loss
        optimizer.step()

        if batch_idx % log_interval == 0 and batch_idx > 0:
            # Normalise by the size of *this* batch rather than the notebook-level
            # `batch_size`, so the figure stays right for a short final batch.
            # NOTE(review): assumes loss_function sums over the batch — confirm in loss.py.
            print("-"*70)
            print(f"| epoch {epoch:3d} | {batch_idx:5d}/{len(dataloader):5d} batch | loss {batch_loss / x.size(0):8.3f}")

    avg_loss = train_loss / len(dataloader.dataset)
    print(f"End of Epoch: {epoch:3d}, Train Loss: {avg_loss}")

    # plot and save
    save_root = "./fig/"
    os.makedirs(save_root, exist_ok=True)
    save_plot = epoch in save_epochs
    plot_latent_space(vae=vae, n=30, figsize=15, device=device, save_root=save_root, save_plot=save_plot, epoch=epoch)
    # `images` / `labels` are the full training set loaded once via all_dataloader
    plot_label_clusters(vae=vae, data=images.view(-1, 784), labels=labels, device=device, save_root=save_root, save_plot=save_plot, epoch=epoch)

    return avg_loss

In [13]:
def test(vae, dataloader, epoch):
    """Evaluate the model on ``dataloader`` and return the average loss.

    Uses the notebook-level ``device`` and ``loss_function``. The model is put
    in eval mode and gradients are disabled for the whole pass.

    Args:
        vae: Model to evaluate; ``vae(x)`` must return ``(recon, mu, log_var)``.
        dataloader: Iterable of ``(x, label)`` batches exposing ``.dataset``.
        epoch: Epoch number, used only in the summary line.

    Returns:
        Sum of the per-batch losses divided by ``len(dataloader.dataset)``.
    """
    print("=" * 70)
    print("Test")

    vae.eval()

    # Collect one scalar per batch, then add them up in order
    batch_losses = []
    with torch.no_grad():
        for batch, _ in dataloader:
            batch = batch.to(device)
            recon, mu, log_var = vae(batch)
            batch_losses.append(loss_function(recon, batch, mu, log_var).item())

    mean_loss = sum(batch_losses) / len(dataloader.dataset)
    print(f"End of Epoch: {epoch:3d}, Test Loss: {mean_loss:8.3f}")

    return mean_loss

In [14]:
# Per-epoch loss history: one entry is appended for every epoch.
# (Previously each list was overwritten by that epoch's scalar, so the history was lost.)
train_loss_epoch = []
test_loss_epoch = []
for epoch in range(1, 21):
    train_loss_epoch.append(train(vae, train_dataloader, epoch))
    test_loss_epoch.append(test(vae, test_dataloader, epoch))

Output hidden; open in https://colab.research.google.com to view.

In [43]:
# files to plot
filelist = [
    f'./fig/{plot_kind}_epoch{epoch_no}.png'
    for plot_kind in ('latent_space', 'label_clusters')  # both epochs of one plot type, then the next
    for epoch_no in (1, 20)
]

In [42]:
import matplotlib.pyplot as plt

# Load the first four saved figures, in the order listed in `filelist`
img = [plt.imread(filelist[k]) for k in range(4)]

fig, axarr = plt.subplots(2, 2, figsize=(20, 20), dpi=300)

# One caption per panel, matching the order of `img`
captions = (
    "latent space @epoch1",
    "latent space @epoch20",
    "label clusters for all train dataset @epoch1",
    "label clusters for all train dataset @epoch20",
)

# axarr.flat walks the 2x2 grid row by row: [0,0], [0,1], [1,0], [1,1]
for panel, picture, caption in zip(axarr.flat, img, captions):
    panel.imshow(picture)
    panel.set_title(caption, fontsize=18)
    panel.axis('off')

# NOTE(review): tight_layout() below recomputes the subplot spacing, so these manual
# values may be largely overridden — confirm which of the two is actually intended.
plt.subplots_adjust(left=0.125, bottom=0.1, right=0.9, top=0.3, wspace=0.2, hspace=0.35)
fig.tight_layout()

Output hidden; open in https://colab.research.google.com to view.