In [1]:
import torch

from torch.utils.data import Dataset
import torchvision.transforms as transforms
from torch.utils.data import Subset

from torchvision.datasets import MNIST, CIFAR10

import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.normal import Normal
from torch.distributions import kl_divergence

import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.utilities.model_summary import ModelSummary

import wandb

import psutil

Encoder and decoder architectures are defined in `oord_encoders.py`:

In [2]:
from oord_encoders import OordEncoder as encoder
from oord_encoders import OordDecoder as decoder

Set the user-defined parameters:

In [3]:
batch_size = 128                             # batch size
f_valid = 0.2                                # validation fraction (taken from training data)
learning_rate = 0.001                        # learning rate for optimiser
seed = 42                                    # random seed
max_epochs = 100                             # maximum number of training epochs

# Apply the seed to the python, numpy and torch RNGs (and to dataloader workers)
# so that weight initialisation, shuffling and latent sampling are reproducible
# and the seed value recorded with the run actually describes it.
pl.seed_everything(seed, workers=True)

Configure the Weights & Biases (wandb) logger:

In [4]:
# hyperparameters attached to the wandb run
config = {
    'learning_rate': learning_rate,
    'batch_size': batch_size,
    'seed': seed,
}

# lightning logger reporting to the 'vae_tests' wandb project;
# log_model=True asks the logger to upload model checkpoints as well
wandb_logger = WandbLogger(project='vae_tests', log_model=True, config=config)

# NOTE(review): this handle is not used in the visible cells, and no explicit
# wandb.init() call precedes it here -- confirm it is still needed.
wandb_config = wandb.config

Detect the local hardware (CPU count and GPU availability):

In [5]:
# number of logical CPUs reported by psutil
num_cpus = psutil.cpu_count(logical=True)

# run on the GPU whenever torch can see a CUDA device, otherwise on the CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'

Load the dataset:

In [6]:
# root directory under which torchvision stores / looks for the MNIST files
# (a previous assignment to a different path was dead code: it was overwritten
# by this one before ever being read)
datadir = '/Users/user/src'

# individual preprocessing steps
totensor = transforms.ToTensor()             # image -> float tensor
normalise = transforms.Normalize(0.5, 0.5)   # (x - 0.5) / 0.5
crop = transforms.CenterCrop(28)             # centre crop to 28x28

# applied in this order to every image
transform = transforms.Compose([
    totensor,
    normalise,
    crop,
])

# download=True fetches the files into datadir if they are not already there
train_data = MNIST(root=datadir, train=True, download=True, transform=transform)
test_data = MNIST(root=datadir, train=False, download=True, transform=transform)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: unable to get local issuer certificate (_ssl.c:1007)>



RuntimeError: Error downloading train-images-idx3-ubyte.gz

Build the data loaders:

In [7]:
# split out a validation set:
# split out a validation set: the leading (1 - f_valid) fraction of the
# training data is kept for training, the remainder is used for validation
f_train = 1. - f_valid
n_train = int(f_train*len(train_data))
indices = list(range(len(train_data)))

train_sampler = Subset(train_data, indices[:n_train])
valid_sampler = Subset(train_data, indices[n_train:])

# Leave one core for the main process. psutil.cpu_count() can return None, and
# on a single-core machine this gives zero workers, in which case DataLoader
# rejects persistent_workers=True -- so only enable it when workers exist.
num_workers = max((num_cpus or 1) - 1, 0)
worker_kwargs = dict(num_workers=num_workers,
                     persistent_workers=num_workers > 0)

# build data loaders:
train_loader = torch.utils.data.DataLoader(train_sampler,
                                           batch_size=batch_size,
                                           shuffle=True,
                                           **worker_kwargs)

# NOTE(review): the validation batch size is hard-coded to 15 rather than
# using batch_size -- confirm this is intentional.
val_loader = torch.utils.data.DataLoader(valid_sampler,
                                         batch_size=15,
                                         shuffle=False,
                                         **worker_kwargs)

# the whole test set is served as a single batch
test_loader = torch.utils.data.DataLoader(test_data,
                                          batch_size=len(test_data),
                                          shuffle=False,
                                          **worker_kwargs)


Define the model:

In [8]:
class VAE(nn.Module):
    """Variational autoencoder built from the imported encoder/decoder pair.

    The encoder maps an input to the mean and log-variance of a diagonal
    Gaussian posterior; a latent sample drawn from that posterior is passed to
    the decoder. Each forward pass also stores the KL divergence between the
    posterior and a standard normal prior in ``self.kl_div``.

    Args:
        n_chan: forwarded to both the encoder and the decoder constructors.
        latent_dim: forwarded to both the encoder and the decoder constructors.
    """

    def __init__(self, n_chan, latent_dim):
        super().__init__()

        self.encoder = encoder(n_chan, latent_dim)
        self.decoder = decoder(n_chan, latent_dim)

    def forward(self, x):
        """Encode ``x``, sample a latent, decode it and return the reconstruction.

        Side effect: ``self.kl_div`` is overwritten with the KL term for this batch.
        """
        mu, logvar = self.encoder(x)

        # reparameterisation trick: z = mu + sigma * eps with eps ~ N(0, 1)
        eps = torch.randn_like(mu)
        sigma = torch.exp(0.5 * logvar)
        z = eps * sigma + mu

        reconstruction = self.decoder(z)

        # keep the KL term around so the training code can read it afterwards
        self.kl_div = self._kldiv(mu, logvar)

        return reconstruction

    def _kldiv(self, mu, logvar):
        """KL(q(z|x) || p(z)), summed over all elements and divided by ``mu.size(0)``.

        Closed form from https://arxiv.org/pdf/1312.6114 Appendix B; the torch
        result was verified equivalent to:
        -1.*(0.5*(1 + logvar - mu.pow(2) - logvar.exp())).sum()/mu.size(0)
        """
        posterior = Normal(mu, torch.exp(0.5 * logvar))                    # q(z|x)
        prior = Normal(torch.zeros_like(mu), torch.ones_like(logvar))      # p(z) = N(0, 1)

        elementwise_kl = kl_divergence(posterior, prior)

        return elementwise_kl.sum() / mu.size(0)


In [9]:
class Compressor(pl.LightningModule):

    """Lightning module that trains the VAE with a beta-weighted objective.

    The optimised loss is ``nll + beta * kl_div``, where ``nll`` is the negative
    log-likelihood of the input under a unit-variance Gaussian centred on the
    reconstruction and ``kl_div`` is the KL term stored by the VAE on its last
    forward pass. A summed squared error (``recon``) is logged for monitoring
    but does not enter the loss.

    Args:
        n_chan: forwarded to ``VAE``; also the channel count of the example input.
        latent_dim: forwarded to ``VAE``.
        lr: learning rate handed to the Adam optimiser.
        beta: weight applied to the KL term (default 1).
    """

    def __init__(self, n_chan, latent_dim, lr, beta=1):

        super().__init__()

        self.lr = lr
        self.beta = beta
        self.model = VAE(n_chan, latent_dim)

        # dummy 28x28 input, only used to create the model summary
        self.example_input_array = torch.zeros(1, n_chan, 28, 28)

    def forward(self, x):
        """Return the VAE reconstruction of ``x``."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Log the training metrics and return the loss to optimise."""
        return self._evaluate(batch, batch_idx, mode='train')

    def validation_step(self, batch, batch_idx):
        """Log the validation metrics; nothing is returned."""
        self._evaluate(batch, batch_idx, mode='val')

    def test_step(self, batch, batch_idx):
        """Log the test metrics; nothing is returned."""
        self._evaluate(batch, batch_idx, mode='test')

    def _evaluate(self, batch, batch_idx, mode):
        """Compute all loss terms for ``batch``, log them under ``mode/`` and return the loss."""
        recon, nll, kl_div = self._get_losses(batch)
        loss = nll + self.beta*kl_div

        # logged in the same order every step: recon, nll, kl_div, loss
        metrics = {'recon': recon, 'nll': nll, 'kl_div': kl_div, 'loss': loss}
        for name, value in metrics.items():
            self.log(f'{mode}/{name}', value)

        return loss

    def _get_losses(self, batch):
        """Return ``(recon, nll, kl_div)`` for one ``(input, label)`` batch.

        ``recon`` and ``nll`` are summed over all elements and divided by the
        batch size; the labels are ignored.
        """
        x, _ = batch
        x_tilde = self.model(x)
        n_batch = x.size(0)

        # MSE / reconstruction loss
        squared_error = F.mse_loss(torch.squeeze(x_tilde), torch.squeeze(x), reduction='none')
        recon = squared_error.sum()/n_batch

        # negative log-likelihood under a unit-variance Gaussian around x_tilde
        likelihood = Normal(x_tilde, torch.ones_like(x_tilde))
        nll = -1.*likelihood.log_prob(x).sum()/n_batch

        # KL divergence stored by the VAE during the forward pass above
        kl_div = self.model.kl_div

        return recon, nll, kl_div

    def configure_optimizers(self):
        """Return a single Adam optimiser over all parameters."""
        # should update this at some point to take optimizer from config file
        adam = torch.optim.Adam(self.parameters(), lr=self.lr)

        return [adam]

In [10]:
# single-channel input with a latent dimension of 8, moved to the selected device
model = Compressor(n_chan=1, latent_dim=8, lr=learning_rate)
model = model.to(device)

# layer-by-layer summary, descending three levels into the module tree
summary = ModelSummary(model, max_depth=3)

In [11]:
# show the summary table: module names, types, parameter counts and in/out sizes
print(summary)

  | Name                    | Type        | Params | Mode  | In sizes       | Out sizes                   
----------------------------------------------------------------------------------------------------------------
0 | model                   | VAE         | 40.5 K | train | [1, 1, 28, 28] | [1, 1, 28, 28]              
1 | model.encoder           | OordEncoder | 20.3 K | train | [1, 1, 28, 28] | [[1, 8, 7, 7], [1, 8, 7, 7]]
2 | model.encoder.layers    | Sequential  | 20.2 K | train | [1, 1, 28, 28] | [1, 8, 7, 7]                
3 | model.encoder.to_latent | Conv2d      | 72     | train | [1, 8, 7, 7]   | [1, 8, 7, 7]                
4 | model.decoder           | OordDecoder | 20.2 K | train | [1, 8, 7, 7]   | [1, 1, 28, 28]              
5 | model.decoder.layers    | Sequential  | 20.2 K | train | [1, 8, 7, 7]   | [1, 1, 28, 28]              
----------------------------------------------------------------------------------------------------------------
40.5 K    Trainable param

In [None]:
# single-device trainer that reports to the wandb logger
trainer = pl.Trainer(
    max_epochs=max_epochs,
    num_sanity_val_steps=0,  # 0 : turn off validation sanity check
    accelerator=device,
    devices=1,
    logger=wandb_logger,
)

# train the model, validating on the held-out split
trainer.fit(model, train_loader, val_dataloaders=val_loader)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
[34m[1mwandb[0m: Network error (SSLError), entering retry loop.
[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Network error (SSLError), entering retry loop.
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.



  | Name  | Type | Params | Mode  | In sizes       | Out sizes     
-------------------------------------------------------------------------
0 | model | VAE  | 40.5 K | train | [1, 1, 28, 28] | [1, 1, 28, 28]
-------------------------------------------------------------------------
40.5 K    Trainable params
0         Non-trainable params
40.5 K    Total params
0.162     Total estimated model params size (MB)
49        Modules in train mode
0         Modules in eval mode


Epoch 0: 100%|████████████████████| 375/375 [00:47<00:00,  7.88it/s, v_num=69cv]
[Aidation: |                                             | 0/? [00:00<?, ?it/s]
[Aidation:   0%|                                       | 0/800 [00:00<?, ?it/s]
[Aidation DataLoader 0:   0%|                          | 0/800 [00:00<?, ?it/s]
[Aidation DataLoader 0:   0%|                  | 1/800 [00:00<00:30, 25.82it/s]
[Aidation DataLoader 0:   0%|                  | 2/800 [00:00<00:28, 28.48it/s]
[Aidation DataLoader 0:   0%|                  | 3/800 [00:00<00:22, 35.80it/s]
[Aidation DataLoader 0:   0%|                  | 4/800 [00:00<00:18, 42.14it/s]
[Aidation DataLoader 0:   1%|                  | 5/800 [00:00<00:19, 41.79it/s]
[Aidation DataLoader 0:   1%|▏                 | 6/800 [00:00<00:17, 45.69it/s]
[Aidation DataLoader 0:   1%|▏                 | 7/800 [00:00<00:16, 49.07it/s]
[Aidation DataLoader 0:   1%|▏                 | 8/800 [00:00<00:16, 47.28it/s]
[Aidation DataLoader 0:   1