In [None]:
# Mount Google Drive so later cells can write reconstructed images and the
# exported TorchScript model under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary


batch_size = 128
SEED = 0
torch.manual_seed(SEED)  # reproducible shuffling and weight initialisation

# One shared preprocessing pipeline for both splits: resize every image to 32x32
# (the Encoder/Decoder layer sizes assume 32x32 inputs) and convert it to a float
# tensor with values in [0, 1] via ToTensor.
mnist_transform = transforms.Compose([transforms.Resize((32, 32)), transforms.ToTensor()])

train_dataset = torchvision.datasets.MNIST(root='./dataset', train=True, transform=mnist_transform, download=True)

test_dataset = torchvision.datasets.MNIST(root='./dataset', train=False, transform=mnist_transform, download=True)

train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

# Evaluation does not need shuffling; a fixed order keeps test passes comparable.
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)




Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./dataset/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 144943405.04it/s]

Extracting ./dataset/MNIST/raw/train-images-idx3-ubyte.gz to ./dataset/MNIST/raw






Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./dataset/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 38813102.80it/s]


Extracting ./dataset/MNIST/raw/train-labels-idx1-ubyte.gz to ./dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./dataset/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 34562521.35it/s]


Extracting ./dataset/MNIST/raw/t10k-images-idx3-ubyte.gz to ./dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 17820887.53it/s]


Extracting ./dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./dataset/MNIST/raw



In [None]:
latent_dim = 200  # dimensionality of the latent vector z shared by Encoder and Decoder


class Encoder(nn.Module):
    """Convolutional encoder producing the parameters of q(z | x).

    Five 5x5 convolutions with padding='same' (stride 1) keep the spatial size
    unchanged, so the flattened feature map has 64 * 32 * 32 entries; inputs are
    therefore expected to be single-channel 32x32 images.

    forward(x) returns (mu, log_var), each of shape (batch, latent_dim).
    """

    def __init__(self):
        super().__init__()
        # Registration order is kept identical so parameter initialisation and
        # state_dict keys (layer1..layer5, z_mean, z_log_var) do not change.
        self.layer1 = nn.Conv2d(1, 8, kernel_size=(5, 5), padding='same')  # stride=(2, 2) disabled
        self.layer2 = nn.Conv2d(8, 16, kernel_size=(5, 5), padding='same')
        self.layer3 = nn.Conv2d(16, 32, kernel_size=(5, 5), padding='same')  # stride=(2, 2) disabled
        self.layer4 = nn.Conv2d(32, 64, kernel_size=(5, 5), padding='same')
        self.layer5 = nn.Conv2d(64, 64, kernel_size=(5, 5), padding='same')  # stride=(2, 2) disabled

        flat_features = 64 * 32 * 32
        self.z_mean = nn.Linear(flat_features, latent_dim)
        self.z_log_var = nn.Linear(flat_features, latent_dim)

    def forward(self, x):
        features = F.relu(self.layer1(x))
        features = F.relu(self.layer2(features))
        features = F.relu(self.layer3(features))
        features = F.relu(self.layer4(features))
        features = F.relu(self.layer5(features))
        # Flatten everything except the batch dimension.
        flat = features.view(features.size(0), -1)
        return self.z_mean(flat), self.z_log_var(flat)

class Decoder(nn.Module):
    """Map a latent vector z back to a single-channel 32x32 image.

    A dense layer expands z to a 64x32x32 feature map; the five kernel_size=1
    transposed convolutions then mix channels per pixel only, so all spatial
    structure comes from the dense layer.

    forward(x): x of shape (batch, latent_dim) -> (batch, 1, 32, 32) with
    values strictly inside (0, 1).
    """

    def __init__(self):
        super(Decoder, self).__init__()

        self.dense = nn.Linear(latent_dim, 64 * 32 * 32)
        self.layer1 = nn.ConvTranspose2d(64, 64, kernel_size=1)
        self.layer2 = nn.ConvTranspose2d(64, 32, kernel_size=1)
        self.layer3 = nn.ConvTranspose2d(32, 16, kernel_size=1)
        self.layer4 = nn.ConvTranspose2d(16, 8, kernel_size=1)
        self.layer5 = nn.ConvTranspose2d(8, 1, kernel_size=1)

    def forward(self, x):
        x = self.dense(x)
        x = x.view(x.size(0), 64, 32, 32)
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        x = F.relu(self.layer3(x))
        x = F.relu(self.layer4(x))
        # Output activation is a sigmoid, not ReLU: the targets are ToTensor
        # pixels in [0, 1]. A final ReLU is unbounded above and gives zero
        # gradient to any pixel whose pre-activation goes negative. It also
        # makes a binary-cross-entropy reconstruction loss unusable.
        return torch.sigmoid(self.layer5(x))

class VAE(nn.Module):
    """Variational autoencoder composed of an Encoder and a Decoder.

    forward(image) returns (reconstruction, mu, log_var).
    """

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, image):
        mu, log_var = self.encoder(image)
        z = self.latent_sample(mu, log_var)
        return self.decoder(z), mu, log_var

    def latent_sample(self, mu, log_var):
        """Draw z = mu + eps * exp(0.5 * log_var) with eps ~ N(0, I) while training.

        In eval mode the posterior mean is returned unchanged, which makes
        reconstructions deterministic.
        """
        if not self.training:
            return mu
        std = torch.exp(0.5 * log_var)
        noise = torch.empty_like(std).normal_()
        return mu + noise * std


def vae_loss(recon_x, x, mu, log_var):
    """Negative ELBO summed over the whole batch.

    Args:
        recon_x: decoder output; must contain the same number of elements as x.
        x: target images.
        mu, log_var: posterior parameters returned by the encoder.

    Returns:
        Scalar tensor: summed squared reconstruction error plus the analytic KL
        divergence between N(mu, exp(log_var)) and N(0, I).
    """
    # With reduction='sum' the shape is irrelevant, so flatten both tensors
    # fully. The former view(-1, 1024) only worked when the element count was
    # a multiple of 1024, i.e. 32x32 single-channel images.
    reconstruction_loss = F.mse_loss(recon_x.reshape(-1), x.reshape(-1), reduction='sum')
    # Alternative: F.binary_cross_entropy(..., reduction='sum'), which requires
    # recon_x to lie in [0, 1].
    kl = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return reconstruction_loss + kl


In [None]:
epochs_VAE = 50  # alternative setting noted in the original: 200
total_step = len(train_loader)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
modelVAE = VAE().to(device)

# Adam over every encoder and decoder parameter of the VAE.
adam_optimizer = torch.optim.Adam(modelVAE.parameters(), lr=0.001)

# Make training mode explicit: VAE.latent_sample only samples z (the
# reparameterisation trick) while self.training is True.
modelVAE.train()

for epoch in range(epochs_VAE):
    # `step` is the number of completed optimisation steps in this epoch, so the
    # printed value matches the work actually done.
    for step, (image, _) in enumerate(train_loader, start=1):
        image = image.to(device)

        # Forward pass: reconstruction plus posterior parameters.
        outputs, mu, log_var = modelVAE(image)

        # Negative ELBO summed over the batch; backpropagate and update.
        loss = vae_loss(outputs, image, mu, log_var)
        adam_optimizer.zero_grad()
        loss.backward()
        adam_optimizer.step()

        if step % 100 == 0:
            print(f'Epoch [{epoch + 1}/{epochs_VAE}], Step [{step}/{total_step}], Loss: {loss.item():.4f}')


Epoch [1/50], Step [100/469], Loss: 7513.2295
Epoch [1/50], Step [200/469], Loss: 6227.8120
Epoch [1/50], Step [300/469], Loss: 5665.9331
Epoch [1/50], Step [400/469], Loss: 5773.9663
Epoch [2/50], Step [100/469], Loss: 5016.7959
Epoch [2/50], Step [200/469], Loss: 4982.6235
Epoch [2/50], Step [300/469], Loss: 5049.9497
Epoch [2/50], Step [400/469], Loss: 5098.4824
Epoch [3/50], Step [100/469], Loss: 4817.0791
Epoch [3/50], Step [200/469], Loss: 4636.5557
Epoch [3/50], Step [300/469], Loss: 4652.8613
Epoch [3/50], Step [400/469], Loss: 4635.1289
Epoch [4/50], Step [100/469], Loss: 4490.4824
Epoch [4/50], Step [200/469], Loss: 4159.7109
Epoch [4/50], Step [300/469], Loss: 4396.6226
Epoch [4/50], Step [400/469], Loss: 4445.0874
Epoch [5/50], Step [100/469], Loss: 4208.6104
Epoch [5/50], Step [200/469], Loss: 4169.9834
Epoch [5/50], Step [300/469], Loss: 4391.2539
Epoch [5/50], Step [400/469], Loss: 4122.5586
Epoch [6/50], Step [100/469], Loss: 4066.9685
Epoch [6/50], Step [200/469], Loss

In [None]:
# modelVAE returns (reconstruction, mu, log_var), not class scores, so a
# classification accuracy cannot be computed from it. Report the average VAE
# loss (reconstruction + KL) per test image instead.
was_training = modelVAE.training
modelVAE.eval()  # latent_sample returns mu in eval mode -> deterministic reconstructions
total_loss = 0.0
total = 0
with torch.no_grad():
    for images, _ in test_loader:
        images = images.to(device)
        outputs, mu, log_var = modelVAE(images)
        total_loss += vae_loss(outputs, images, mu, log_var).item()
        total += images.size(0)
modelVAE.train(was_training)  # restore whatever mode later cells were relying on

print('Average test loss per image: {:.4f}'.format(total_loss / total))


(tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        ...,


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0

AttributeError: ignored

In [None]:
# Save VAE reconstructions of training images to disk.
def save_images(title=None, n_images=100, out_dir="/content/drive/MyDrive/VAE images/MNIST/"):
    """Save the reconstruction of the first image of each of the first ``n_images`` training batches.

    Files are written as ``0.jpg``, ``1.jpg``, ... inside ``out_dir``.
    The model is used in whatever train/eval mode it is currently in.

    Args:
        title: unused; kept so existing calls keep working.
        n_images: number of batches (and therefore files) to process.
        out_dir: destination directory; created if it does not exist.
    """
    import os

    os.makedirs(out_dir, exist_ok=True)
    # Inference only: no autograd graph is needed for saving images.
    with torch.no_grad():
        for i, (img, _) in enumerate(train_loader):
            if i == n_images:
                break
            recon, _, _ = modelVAE(img.to(device))
            torchvision.utils.save_image(recon[0], os.path.join(out_dir, str(i) + ".jpg"), normalize=True)

In [None]:
# Write reconstructions of training images to Google Drive (see save_images).
save_images()

In [None]:
# Compile the trained VAE to TorchScript and store it on Google Drive so it can
# be loaded later with torch.jit.load.
export_path = '/content/drive/MyDrive/VAE images/modelVAE_MNIST.pt'
modelVAE_scripted = torch.jit.script(modelVAE)
modelVAE_scripted.save(export_path)

In [None]:
import matplotlib.pyplot as plt
plt.ion()  # turn on matplotlib interactive mode
import torchvision.utils

# Switch to evaluation mode: VAE.latent_sample then returns mu instead of a
# random sample, so the reconstructions below are deterministic.
modelVAE.eval()

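# Display helpers: `to_img` clamps values into [0, 1] for display, `show_image`
# plots one image tensor, and `visualise_output` (below) takes as input the
# images to reconstruct and the model with which the reconstructions are performed.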
def to_img(x):
    """Return a copy of ``x`` with every value clamped into the displayable range [0, 1]."""
    return torch.clamp(x, 0, 1)

def show_image(img):
    """Plot one image tensor laid out as (channels, height, width).

    Values are clamped into [0, 1] with ``to_img`` first; ``.numpy()`` needs a
    CPU tensor.
    """
    pixels = to_img(img).numpy()
    plt.imshow(pixels.transpose(1, 2, 0))

def visualise_output(images, model):
    """Reconstruct ``images`` with ``model`` and plot items 1..49 as a grid.

    ``model`` must return a 3-tuple whose first element is the reconstruction,
    as VAE.forward does. The grid has 10 images per row and 5 pixels of padding.
    """
    with torch.no_grad():
        recon, _, _ = model(images.to(device))
        recon = to_img(recon.cpu())
        grid = torchvision.utils.make_grid(recon[1:50], nrow=10, padding=5)
        plt.imshow(grid.numpy().transpose(1, 2, 0))
        plt.show()

# Grab one batch of training images. Use the built-in next(): newer PyTorch
# DataLoader iterators no longer expose a `.next()` method.
images, labels = next(iter(train_loader))

# First visualise the original images (same slice as visualise_output uses).
print('Original images')
show_image(torchvision.utils.make_grid(images[1:50], nrow=10, padding=5))
plt.show()

# Reconstruct and visualise the images using the VAE.
print('\nVAE reconstruction:')
visualise_output(images, modelVAE)

In [None]:
def comparison(original, reconstructed):
    """Return (PSNR, SSIM) between an original image and its reconstruction.

    The parameters are not named ``in``/``out``: ``in`` is a reserved keyword,
    which made this cell a SyntaxError.

    NOTE(review): ``psnr`` and ``ssim`` are not imported anywhere in the visible
    notebook. Presumably they are metric modules exposing ``psnr(a, b)`` and
    ``ssim(a, b)``; import them before calling this.
    """
    return psnr.psnr(original, reconstructed), ssim.ssim(original, reconstructed)

def difference_images():
  psnr_img = []
  ssim_img = []
  for image, _ in train_dataset:
    image = image.to(device)
    img, _, _ = modelVAE(image)
    temp1, temp2 = comparison(image, img)
    psnr_img.append(temp1)
    ssim_img.append(temp2)