In [1]:
import os
import random

import torch
from torch import optim, nn, utils, Tensor
import numpy as np
from PIL import Image
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import pytorch_lightning as pl

In [4]:
class LitVAE(pl.LightningModule):
    def __init__(self, alpha = 1):
        #Autoencoder only requires 1 dimensional argument since input and output-size is the same
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(784,196),nn.ReLU(),nn.BatchNorm1d(196,momentum = 0.7),
                                     nn.Linear(196,49),nn.ReLU(),nn.BatchNorm1d(49,momentum = 0.7),
                                     nn.Linear(49,28),nn.LeakyReLU())
        self.hidden2mu = nn.Linear(28,28)
        self.hidden2log_var = nn.Linear(28,28)
        self.alpha = alpha
        self.decoder = nn.Sequential(nn.Linear(28,49),nn.ReLU(),
                                     nn.Linear(49,196),nn.ReLU(),
                                     nn.Linear(196,784),nn.Tanh())
        self.data_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=(0.5,), std=(0.5,))])


    def encode(self,x):
       hidden = self.encoder(x)
       mu = self.hidden2mu(hidden)
       log_var = self.hidden2log_var(hidden)
       return mu,log_var

    def decode(self,x):
        x = self.decoder(x)
        return x

    def reparametrize(self,mu,log_var):
        #Reparametrization Trick to allow gradients to backpropagate from the
        #stochastic part of the model
        sigma = torch.exp(0.5*log_var)
        z = torch.randn(size = (mu.size(0),mu.size(1)))
        z= z.type_as(mu) # Setting z to be .cuda when using GPU training
        return mu + sigma*z

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=1e-3)

    def training_step(self, batch, batch_idx):
        x, _ = batch
        batch_size = x.size(0)
        x = x.view(batch_size, -1)
        mu, log_var = self.encode(x)

        kl_loss =  (-0.5*(1+log_var - mu**2- torch.exp(log_var)).sum(dim = 1)).mean(dim =0)
        hidden = self.reparametrize(mu, log_var)
        x_out = self.decode(hidden)

        recon_loss_criterion = nn.MSELoss() #Reconstruction Loss
        recon_loss = recon_loss_criterion(x,x_out)
        loss = recon_loss*self.alpha + kl_loss

        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, _ = batch
        batch_size = x.size(0)
        x = x.view(batch_size, -1)
        mu, log_var = self.encode(x)

        kl_loss =  (-0.5*(1+log_var - mu**2- torch.exp(log_var)).sum(dim = 1)).mean(dim =0)
        hidden = self.reparametrize(mu, log_var)
        x_out = self.decode(hidden)

        recon_loss_criterion = nn.MSELoss() #Reconstruction Loss
        recon_loss = recon_loss_criterion(x,x_out)
        loss = recon_loss*self.alpha + kl_loss

        self.log('val_kl_loss', kl_loss, on_step=True, on_epoch=True)
        self.log('val_recon_loss', recon_loss, on_step=True, on_epoch=True)
        self.log('val_loss', loss, on_step=True, on_epoch=True)

        return x_out, loss

    def validation_epoch_end(self, outputs):
        if not os.path.exists('vae_images'):
            os.makedirs('vae_images')
        choice = random.choice(outputs)
        output_sample = choice[0]
        output_sample = output_sample.reshape(-1,1,28,28)
        output_sample = self.scale_image(output_sample)
        utils.save_image(output_sample, f"vae_images/epoch_{self.current_epoch+1}.png")

    def forward(self,x):
       batch_size = x.size(0)
       x = x.view(batch_size,-1)
       mu,log_var = self.encode(x)
       hidden = self.reparametrize(mu,log_var)
       return self.decoder(hidden)


In [5]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=(0.5,), std=(0.5,))])

mnist_train = MNIST('data/', download=True, train=True, transform=transform)
train_dataloader = utils.data.DataLoader(mnist_train, batch_size=64)

mnist_val = MNIST('data/', download=True, train=False, transform=transform)
val_loader = utils.data.DataLoader(mnist_val, batch_size=64)

trainer = pl.Trainer(auto_lr_find=True, max_epochs=25)
trainer.fit(LitVAE)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw



GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


TypeError: Unwrapping the module did not yield a `LightningModule`, got <class 'type'> instead.