In [83]:
import random
import torch
import torch.nn as nn
import numpy as np
from torchvision import datasets, transforms
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt

# HYPERPARAMETERS
NUM_FEATURES = 16        # size of the latent vector z
NUM_EPOCHS = 100
BATCH_SIZE = 128
LEARNING_RATE = 0.0001
SEED = 0                 # fixed seed so a fresh "Restart & Run All" is repeatable

# Seed every random number generator this notebook imports.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Prefer the GPU, but fall back to the CPU so the notebook still runs without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [84]:
# class LinearVAE(nn.Module):
#   def __init__(self):
#     super(LinearVAE, self).__init__()

#     # encoder
#     self.encoder = nn.Sequential(
#         nn.Linear(in_features=784, out_features=512),
#         nn.ReLU(),
#         nn.Linear(in_features=512, out_features=128),
#         nn.ReLU(),
#         nn.Linear(in_features=128, out_features=NUM_FEATURES*2)
#     )

#     # decoder 
#     self.decoder = nn.Sequential(
#       nn.Linear(in_features=NUM_FEATURES, out_features=128),
#       nn.ReLU(),
#       nn.Linear(in_features=128, out_features=512),
#       nn.ReLU(),
#       nn.Linear(in_features=512, out_features=784),
#       nn.Sigmoid()
#     )

#   def reparameterize(self, mu, log_var):
#     """
#     :param mu: mean from the encoder's latent space
#     :param log_var: log variance from the encoder's latent space
#     """
#     std = torch.exp(0.5*log_var) # standard deviation
#     eps = torch.randn_like(std) # `randn_like` as we need the same size
#     sample = mu + (eps * std) # sampling as if coming from the input space
#     return sample
 
#   def forward(self, x):
#     # encoding
#     x = self.encoder(x)
#     x = x.view(-1, 2, NUM_FEATURES)
#     # get `mu` and `log_var`
#     mu = x[:, 0, :] # the first feature values as mean
#     log_var = x[:, 1, :] # the other feature values as variance
#     # get the latent vector through reparameterization
#     z = self.reparameterize(mu, log_var)

#     # decoding
#     reconstruction = self.decoder(z)
#     return reconstruction, mu, log_var

#   def generate(self, sample):
#     generated = self.decoder(sample)
#     return generated

# model = LinearVAE()
# print(model)

In [85]:
class VAE(nn.Module):
  """Convolutional variational autoencoder for 1x28x28 images.

  Encoder: four 3x3 convolutions (strides 2, 1, 1, 2), each followed by
  LeakyReLU(0.2) and channel dropout, reduce the image to a 256x7x7 feature
  map; a linear layer then produces `mu` and `log_var`.
  Decoder: a linear layer expands a latent vector back to 256x7x7 and four
  transposed convolutions up-sample it to a 1x28x28 image squashed by a sigmoid.

  :param num_features: size of the latent vector. When omitted, the
    notebook-level ``NUM_FEATURES`` hyperparameter is used, so ``VAE()``
    behaves as before.
  """

  def __init__(self, num_features=None):
    super().__init__()

    # Resolve the global lazily so the encoder and decoder always agree on
    # the latent size (the decoder used to hard-code 16).
    self.num_features = NUM_FEATURES if num_features is None else num_features
    # Flattened size of the 256x7x7 feature map shared by encoder and decoder.
    self._flat_size = 256 * 7 * 7

    # Encoder

    self.conv0 = nn.Conv2d(1, 32, kernel_size = 3, stride = 2, padding = 1)
    self.conv0_drop = nn.Dropout2d(0.25)
    self.conv1 = nn.Conv2d(32, 64, kernel_size = 3, stride = 1, padding = 1)
    self.conv1_drop = nn.Dropout2d(0.25)
    self.conv2 = nn.Conv2d(64, 128, kernel_size = 3, stride = 1, padding = 1)
    self.conv2_drop = nn.Dropout2d(0.25)
    self.conv3 = nn.Conv2d(128, 256, kernel_size = 3, stride = 2, padding = 1)
    self.conv3_drop = nn.Dropout2d(0.25)
    # One output half for `mu`, the other for `log_var`.
    self.fc = nn.Linear(self._flat_size, self.num_features * 2)

    # Decoder

    self.fc1 = nn.Linear(self.num_features, self._flat_size)
    self.trans_conv1 = nn.ConvTranspose2d(256, 128, kernel_size = 3, stride = 2, padding = 1, output_padding = 1)
    self.trans_conv2 = nn.ConvTranspose2d(128, 64, kernel_size = 3, stride = 1, padding = 1)
    self.trans_conv3 = nn.ConvTranspose2d(64, 32, kernel_size = 3, stride = 1, padding = 1)
    self.trans_conv4 = nn.ConvTranspose2d(32, 1, kernel_size = 3, stride = 2, padding = 1, output_padding = 1)

  def reparameterize(self, mu, log_var):
    """
    Draw a latent sample z = mu + eps * std with eps ~ N(0, I).

    :param mu: mean from the encoder's latent space
    :param log_var: log variance from the encoder's latent space
    :return: sampled latent tensor with the same shape as `mu`
    """
    std = torch.exp(0.5*log_var) # standard deviation
    eps = torch.randn_like(std) # noise with the same shape as `std`
    return mu + (eps * std)

  def forward(self, x):
    """
    Encode `x`, sample a latent vector and decode it.

    :param x: batch of images; any shape that can be viewed as (N, 1, 28, 28)
    :return: tuple (reconstruction of shape (N, 1, 28, 28), mu, log_var)
    """
    # encoding
    x = x.view(-1, 1, 28, 28)
    x = self.conv0_drop(F.leaky_relu(self.conv0(x), 0.2))
    x = self.conv1_drop(F.leaky_relu(self.conv1(x), 0.2))
    x = self.conv2_drop(F.leaky_relu(self.conv2(x), 0.2))
    x = self.conv3_drop(F.leaky_relu(self.conv3(x), 0.2))
    x = self.fc(x.view(-1, self._flat_size))

    # split the encoder output into `mu` and `log_var`
    x = x.view(-1, 2, self.num_features)
    mu = x[:, 0, :] # first half: mean
    log_var = x[:, 1, :] # second half: log variance
    # get the latent vector through reparameterization
    z = self.reparameterize(mu, log_var)

    # decoding (shared with `generate` so both paths cannot drift apart)
    reconstruction = self.generate(z)
    return reconstruction, mu, log_var

  def generate(self, sample):
    """
    Decode latent vectors into images.

    :param sample: latent tensor of shape (N, num_features)
    :return: tensor of shape (N, 1, 28, 28) with values in [0, 1]
    """
    x = self.fc1(sample)
    x = x.view(-1, 256, 7, 7)
    x = F.relu(self.trans_conv1(x))
    x = F.relu(self.trans_conv2(x))
    x = F.relu(self.trans_conv3(x))
    x = self.trans_conv4(x)
    return torch.sigmoid(x)

# Build the network and print its layer summary.
model = VAE()
print(model)
# Move to the target device and cast floating-point parameters to float;
# the bare `model` on the last line displays the module as the cell output.
model = model.to(device).float()
model

VAE(
  (conv0): Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (conv0_drop): Dropout2d(p=0.25, inplace=False)
  (conv1): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv1_drop): Dropout2d(p=0.25, inplace=False)
  (conv2): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2_drop): Dropout2d(p=0.25, inplace=False)
  (conv3): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (conv3_drop): Dropout2d(p=0.25, inplace=False)
  (fc): Linear(in_features=12544, out_features=32, bias=True)
  (fc1): Linear(in_features=16, out_features=12544, bias=True)
  (trans_conv1): ConvTranspose2d(256, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
  (trans_conv2): ConvTranspose2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (trans_conv3): ConvTranspose2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (trans_conv4): ConvTranspose2d(32, 1, kernel_size=(3, 3), st

VAE(
  (conv0): Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (conv0_drop): Dropout2d(p=0.25, inplace=False)
  (conv1): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv1_drop): Dropout2d(p=0.25, inplace=False)
  (conv2): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2_drop): Dropout2d(p=0.25, inplace=False)
  (conv3): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (conv3_drop): Dropout2d(p=0.25, inplace=False)
  (fc): Linear(in_features=12544, out_features=32, bias=True)
  (fc1): Linear(in_features=16, out_features=12544, bias=True)
  (trans_conv1): ConvTranspose2d(256, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
  (trans_conv2): ConvTranspose2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (trans_conv3): ConvTranspose2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (trans_conv4): ConvTranspose2d(32, 1, kernel_size=(3, 3), st

In [86]:
# Preprocessing: only convert each image to a tensor
transform = transforms.Compose([transforms.ToTensor()])

# MNIST training split (train=True) and held-out split (train=False),
# downloaded into ../input/data when not already present
train_data, val_data = (
  datasets.MNIST(root='../input/data', train=is_train, download=True, transform=transform)
  for is_train in (True, False)
)
# train_data  = torch.utils.data.Subset(train_data, range(0, 10000-1))
# val_data  = torch.utils.data.Subset(val_data, range(0, 2000-1))

# Batched loaders; only the training split is shuffled
train_loader, val_loader = (
  torch.utils.data.DataLoader(split, batch_size=BATCH_SIZE, shuffle=do_shuffle)
  for split, do_shuffle in ((train_data, True), (val_data, False))
)

# 25 latent vectors drawn once from N(0, I), reused whenever images are generated
normal_samples = torch.randn(25, NUM_FEATURES).to(device)

In [87]:
def final_loss(bce_loss, mu, logvar):
  """
  Total VAE loss: reconstruction loss plus the KL-divergence term.
  KL-Divergence = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
  :param bce_loss: reconstruction loss (already computed by the caller)
  :param mu: the mean from the latent vector
  :param logvar: log variance from the latent vector
  :return: scalar tensor `bce_loss + KL-Divergence`
  """
  # per-element KL terms; logvar.exp() is sigma^2
  kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
  kl_divergence = torch.sum(kl_terms) * -0.5
  return bce_loss + kl_divergence

In [88]:
# Binary cross-entropy summed (not averaged) over every element of the batch
criterion = nn.BCELoss(reduction='sum')
# Adam over all model parameters
optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

def fit(model, dataloader):
  """
  Train `model` for one pass over `dataloader`.

  Uses the notebook-level `optimizer`, `criterion` and `device`.
  :param model: network returning (reconstruction, mu, logvar)
  :param dataloader: yields (images, labels) batches; labels are ignored
  :return: summed loss divided by the number of samples in the dataset
  """
  model.train()
  total = 0.0
  for images, _labels in dataloader:
    images = images.to(device)
    optimizer.zero_grad()
    recon, mu, logvar = model(images)
    batch_loss = final_loss(criterion(recon, images), mu, logvar)
    total += batch_loss.item()
    batch_loss.backward()
    optimizer.step()
  return total / len(dataloader.dataset)

In [89]:
def validate(model, dataloader, samples):
  """
  Evaluate `model` on `dataloader` without gradient tracking.

  Uses the notebook-level `criterion` and `device`. As a side effect, the
  first 8 inputs and the first 8 reconstructions of the last *full* batch
  are concatenated and appended to `samples`.
  :param model: network returning (reconstruction, mu, logvar)
  :param dataloader: yields (images, labels) batches; labels are ignored
  :param samples: list that collects one (inputs + reconstructions) tensor per call
  :return: summed loss divided by the number of samples in the dataset
  """
  model.eval()
  running_loss = 0.0
  # Index of the last batch guaranteed to hold `batch_size` items, derived
  # from the loader that was passed in rather than from notebook globals.
  last_full_batch = len(dataloader.dataset) // dataloader.batch_size - 1
  with torch.no_grad():
    for i, (data, _) in enumerate(dataloader):
      data = data.to(device)
      reconstruction, mu, logvar = model(data)
      bce_loss = criterion(reconstruction, data)
      loss = final_loss(bce_loss, mu, logvar)
      running_loss += loss.item()

      # keep the input and output of the last full batch of every epoch
      if i == last_full_batch:
        num_rows = 8
        samples.append(torch.cat((data.view(-1, 1, 28, 28)[:num_rows],
                                  reconstruction.view(-1, 1, 28, 28)[:num_rows])))

  val_loss = running_loss/len(dataloader.dataset)
  return val_loss

In [90]:
def view_samples(samples, epoch):
  """
  Draw up to 25 images on a 5x5 grid and save the figure as a PNG under graphs/.

  :param samples: tensor of images; each image is reshaped to 28x28 for display
  :param epoch: epoch number, used only to build the output filename
  """
  import os  # local import: only needed to create the output directory

  samples = samples.detach().to('cpu')
  fig, axes = plt.subplots(figsize=(5,5), nrows=5, ncols=5, sharex=True, sharey=True)
  for ax, img in zip(axes.flatten(), samples):
    ax.xaxis.set_visible(False)
    ax.yaxis.set_visible(False)
    ax.imshow(img.reshape((28,28)), cmap='Greys_r')
  # Create the output directory up front; saving into a missing directory raises.
  os.makedirs('graphs', exist_ok=True)
  # The figure is intentionally not closed here, so it stays available for display.
  fig.savefig('graphs/CNN VAE Epoch ' + str(epoch) + '.png')

In [None]:
# Per-epoch loss history and the reconstructions collected by `validate`
train_loss = []
val_loss = []
samples = []

for epoch in range(NUM_EPOCHS):
  train_epoch_loss = fit(model, train_loader)
  val_epoch_loss = validate(model, val_loader, samples)
  train_loss.append(train_epoch_loss)
  val_loss.append(val_epoch_loss)
  print('Epoch [{:5d}/{:5d}] | Train loss: {:6.4f} | Val loss: {:6.4f}'.format(
                  epoch+1, NUM_EPOCHS, train_epoch_loss, val_epoch_loss))
  # Every 5th epoch (1, 6, 11, ...) decode the fixed latent vectors and save the grid
  if epoch % 5 == 0:
    model.eval()
    # Generation is inference only: skip autograd bookkeeping
    with torch.no_grad():
      generated_images = model.generate(normal_samples)
    generated_images = generated_images.view(25, 1, 28, 28)
    view_samples(generated_images, epoch+1)

Epoch [    1/  100] | Train loss: 211.5640 | Val loss: 135.7545
Epoch [    2/  100] | Train loss: 125.8046 | Val loss: 114.9909
Epoch [    3/  100] | Train loss: 115.4649 | Val loss: 110.1024
Epoch [    4/  100] | Train loss: 111.7206 | Val loss: 108.0703
Epoch [    5/  100] | Train loss: 109.5002 | Val loss: 106.1999
Epoch [    6/  100] | Train loss: 107.9087 | Val loss: 105.0341
Epoch [    7/  100] | Train loss: 106.7520 | Val loss: 104.4775
Epoch [    8/  100] | Train loss: 105.7674 | Val loss: 103.2803
Epoch [    9/  100] | Train loss: 105.0521 | Val loss: 102.8292
Epoch [   10/  100] | Train loss: 104.3404 | Val loss: 102.5800
Epoch [   11/  100] | Train loss: 103.8254 | Val loss: 101.9558
Epoch [   12/  100] | Train loss: 103.3672 | Val loss: 101.4623
Epoch [   13/  100] | Train loss: 102.9385 | Val loss: 101.1669
Epoch [   14/  100] | Train loss: 102.5843 | Val loss: 101.0965
Epoch [   15/  100] | Train loss: 102.2410 | Val loss: 100.7495
Epoch [   16/  100] | Train loss: 101.91

In [None]:
# Plot both loss curves on the axes that was just created
fig, ax = plt.subplots()
ax.plot(train_loss, label='Training loss')
ax.plot(val_loss, label='Val loss')
ax.set_title("Training Losses")
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss per sample')  # both losses are divided by the dataset size
ax.legend();  # trailing ';' keeps the Legend repr out of the cell output

In [None]:
# rows = 5
# fig, axes = plt.subplots(figsize=(7,12), nrows=2*rows, ncols=8, sharex=True, sharey=True)
# flat_axes = [ax for ax_row in axes for ax in ax_row]

# for row in range(rows):
#   sample = samples[row * int(len(samples)/(rows))]
#   view_samples(sample, flat_axes[row*16 : (row+1)*16])

In [None]:
# sample_size = 12
# indices = random.sample(range(0, 2000-1), sample_size)
# sample_subset = torch.utils.data.Subset(val_data, random.sample(range(0, 2000-1), 12))
# loader = torch.utils.data.DataLoader(sample_subset, batch_size=sample_size)
# sample, _ = next(iter(loader))
# print(sample.size())
# sample = sample.view(sample.size(0), -1)
# print(sample.size())

# model.eval()
# reconstruction, _, _ = model(sample)
# print(reconstruction.size())

# out = torch.cat((sample.view(sample_size, 1, 28, 28), 
#            reconstruction.view(sample_size, 1, 28, 28)))

# print(out.size())

# fig, axes = plt.subplots(figsize=(10,2), nrows=2, ncols=sample_size, sharex=True, sharey=True)
# flat_axes = [ax for ax_row in axes for ax in ax_row]

# view_samples(out, flat_axes)

In [None]:
# model.eval()
# generated_images = model.generate(normal_samples)
# generated_images = generated_images.view(25, 1, 28, 28)

# view_samples(generated_images, flat_axes)