In [1]:
import pathlib
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt
import torchvision
import torchvision.transforms as T
import pytorch_lightning as pl
import torchsummary
import numpy as np

from types import SimpleNamespace
from omegaconf import OmegaConf
from hydra.utils import instantiate, get_class

import tml
import wandb
# Finish any wandb run that is still active (e.g. one left open by an earlier,
# interrupted execution of this notebook) before a new one is started below.
wandb.finish()

class Module(pl.LightningModule):
    """Lightning module that trains a model to reconstruct its own input.

    ``model``, ``criterion`` and ``optimiser`` are configs that are built here
    with ``hydra.utils.instantiate``. Every batch is expected to be a 1-tuple
    ``(x,)``; the criterion is evaluated as ``criterion(model(x), x)``.

    Image and plot logging assumes the attached logger offers ``log_image``
    and a wandb run as ``experiment`` (e.g. ``WandbLogger``).
    """

    def __init__(self, model, criterion, optimiser):
        """Instantiate the model, the criterion and the optimiser from configs."""
        super().__init__()
        self.model = instantiate(model)
        self.criterion = instantiate(criterion)
        # The optimiser config carries only hyper-parameters; the parameters to
        # optimise are supplied positionally at instantiation time.
        self.optimiser = instantiate(optimiser, _args_=[self.model.parameters()])

    def configure_optimizers(self):
        """Return the optimiser built in ``__init__``."""
        return self.optimiser

    def training_step(self, batch, _):
        """Compute, log and return the reconstruction loss for one batch."""
        x, = batch
        y = self.model(x)
        loss = self.criterion(y, x)
        self.log("train/loss", loss.item())
        return loss

    def validation_step(self, batch, batch_i):
        """Log the validation loss; on the first batch also log sample images."""
        x, = batch
        y = self.model(x)
        self.log("validation/loss", self.criterion(y, x).item())
        if batch_i == 0:  # log images on the first batch only
            y = self._reconstruction(y)
            # Inputs and reconstructions are stacked vertically by _get_image.
            self.logger.log_image("validation/reconstruction", self._get_image(x[:16], y[:16]))

    def test_step(self, batch, batch_i):
        """Return inputs, reconstructions and a per-sample anomaly score.

        The score is the squared error summed over all non-batch dimensions.
        """
        x, = batch
        y = self._reconstruction(self.model(x))
        score = F.mse_loss(y, x, reduction='none').view(x.shape[0], -1).sum(-1)
        return x, y, score

    def test_epoch_end(self, outputs):
        """Rank all test samples by score and log the top anomalies and scores."""
        if not outputs:
            # Nothing was tested; there is nothing to concatenate or rank.
            return
        x, y, score = [torch.cat(parts) for parts in zip(*outputs)]
        index = torch.argsort(score, descending=True)  # largest scores first
        # show the top anomalies according to reconstruction error
        x_ranked, y_ranked = x[index], y[index]
        self.logger.log_image("test/top_ground_truth", self._get_image(x_ranked[:128], n=16))
        self.logger.log_image("test/top_reconstruction", self._get_image(y_ranked[:128], n=16))
        # show the raw scores, sorted in descending order
        self.logger.experiment.log({"test/score": self._get_line_plot(score[index], columns=['x', 'score'], title="Score")})

    def _reconstruction(self, y):
        """Map the raw model output to image space.

        If the criterion's string representation contains "logit" (e.g. a
        ``...WithLogitsLoss``), the model output is treated as logits and passed
        through a sigmoid; otherwise it is returned unchanged.
        """
        if y is not None and "logit" in str(self.criterion).lower():
            y = torch.sigmoid(y)
        return y

    def _get_line_plot(self, x, y=None, columns=('x', 'y'), title="Line Plot"):
        """Build a wandb line plot from 1-D tensors.

        When ``y`` is omitted, ``x`` is used as the y-values and plotted
        against its indices. ``columns`` names the two table columns.
        """
        if y is None:
            y, x = x, torch.arange(x.shape[0])
        x, y = x.cpu().numpy(), y.cpu().numpy()
        columns = list(columns)
        data = [[i, j] for i, j in zip(x, y)]
        table = wandb.Table(data=data, columns=columns)
        return wandb.plot.line(table, columns[0], columns[1], title=title)

    def _get_image(self, *x, n=16):
        """Turn one or more image batches into a list of image grids.

        The batches are concatenated along dim 2, clipped to [0, 1] and split
        into chunks of ``n`` images; each chunk becomes one grid with ``n``
        images per row.
        """
        stacked = torch.cat(x, dim=2)
        stacked = torch.clip(stacked, 0, 1)
        return [torchvision.utils.make_grid(chunk, nrow=n, pad_value=1) for chunk in torch.split(stacked, n)]

class DataModule(pl.LightningDataModule):
    """MNIST data module that holds every split as an in-memory TensorDataset.

    The first ``train_val_split`` fraction of the MNIST training set is used
    for training and the remainder for validation; the MNIST test set is used
    for testing. Each dataset yields 1-tuples ``(x,)`` (images only, no labels).
    """

    def __init__(self, batch_size):
        super().__init__()
        self.batch_size = batch_size
        self.train_dataset = None
        self.test_dataset = None
        self.validate_dataset = None
        self.train_val_split = 0.8  # fraction of the training set used for training

    @staticmethod
    def _load_mnist(train):
        """Load (downloading if needed) the MNIST train or test set."""
        root = pathlib.Path("~/.data/MNIST/").expanduser().resolve()
        return torchvision.datasets.MNIST(root, train=train, download=True)

    @staticmethod
    def _as_dataset(data):
        """Wrap raw image data in a TensorDataset of floats scaled by 1/255.

        A singleton axis is inserted at dim 1 and the whole tensor is moved to
        the first CUDA device when one is available, otherwise it stays on CPU.
        """
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        return TensorDataset(data.unsqueeze(1).to(device).float() / 255.)

    def _split_index(self, data):
        """Return the index separating the training and validation portions."""
        return int(self.train_val_split * data.shape[0])

    def prepare_train_data(self):
        """Build ``train_dataset`` from the leading part of the training set."""
        data = self._load_mnist(train=True).data
        self.train_dataset = self._as_dataset(data[:self._split_index(data)])

    def prepare_validation_data(self):
        """Build ``validate_dataset`` from the trailing part of the training set."""
        data = self._load_mnist(train=True).data
        self.validate_dataset = self._as_dataset(data[self._split_index(data):])

    def prepare_test_data(self):
        """Build ``test_dataset`` from the full MNIST test set."""
        self.test_dataset = self._as_dataset(self._load_mnist(train=False).data)

    def prepare_data(self):
        """Download and build all three datasets.

        NOTE(review): Lightning's documentation recommends assigning state in
        ``setup()`` rather than ``prepare_data()``; confirm before using this
        with more than one process.
        """
        self.prepare_train_data()
        self.prepare_validation_data()
        self.prepare_test_data()

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.validate_dataset, batch_size=self.batch_size, shuffle=False)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False)
    
class MLPAutoEncoder(nn.Sequential):
    """Fully connected auto-encoder: flatten -> 512 -> latent -> 512 -> unflatten.

    Args:
        input_shape: shape of one input sample (without the batch dimension);
            normalised with ``tml.shape.as_shape``.
        latent_shape: shape of the bottleneck; only its total size is used.
        output_activation: module applied after the last linear layer. The
            default ``nn.Identity()`` instance is shared between all models
            built with the default.
    """

    def __init__(self, input_shape, latent_shape, output_activation=nn.Identity()):
        shape_in = tml.shape.as_shape(input_shape)
        shape_latent = tml.shape.as_shape(latent_shape)
        # np.prod returns NumPy scalars (a float for an empty shape); the layer
        # sizes below are meant to be plain Python ints.
        latent_size = int(np.prod(shape_latent))
        input_size = int(np.prod(shape_in))
        super().__init__(
            tml.module.View(shape_in, (input_size,)),
            nn.Linear(input_size, 512), nn.LeakyReLU(),
            nn.Linear(512, latent_size), nn.LeakyReLU(),
            nn.Linear(latent_size, 512), nn.LeakyReLU(),
            nn.Linear(512, input_size), output_activation,
            tml.module.View((input_size,), shape_in),
        )
        # Attributes are assigned only after nn.Module.__init__ has run.
        self.input_shape = shape_in
        self.latent_shape = shape_latent
            
class ConvAutoEncoder(nn.Sequential):
    """Convolutional auto-encoder built from unpadded, stride-1 convolutions.

    The encoder kernels (7, 7, 5, 5, 5, 4) shrink each spatial dimension by 27
    in total, so a 1x28x28 input is reduced to a 64x1x1 bottleneck; the
    transposed convolutions of the decoder grow it back by the same amount.

    Args:
        input_shape: shape of one input sample; stored as ``input_shape``.
            The first layer is fixed to a single input channel.
        latent_shape: stored as ``latent_shape`` but NOT used to size the
            network; the bottleneck always has 64 channels.
        output_activation: module applied after the last transposed
            convolution. The default ``nn.Identity()`` instance is shared
            between all models built with the default.
    """

    def __init__(self, input_shape, latent_shape, output_activation=nn.Identity()):
        shape_in = tml.shape.as_shape(input_shape)
        shape_latent = tml.shape.as_shape(latent_shape)
        super().__init__(
            # encoder
            nn.Conv2d(1, 8, kernel_size=7, stride=1), nn.LeakyReLU(),
            nn.Conv2d(8, 16, kernel_size=7, stride=1), nn.LeakyReLU(),
            nn.Conv2d(16, 32, kernel_size=5, stride=1), nn.LeakyReLU(),
            nn.Conv2d(32, 64, kernel_size=5, stride=1), nn.LeakyReLU(),
            nn.Conv2d(64, 64, kernel_size=5, stride=1), nn.LeakyReLU(),
            nn.Conv2d(64, 64, kernel_size=4, stride=1), nn.LeakyReLU(),
            # decoder
            nn.ConvTranspose2d(64, 64, kernel_size=4, stride=1), nn.LeakyReLU(),
            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=1), nn.LeakyReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=5, stride=1), nn.LeakyReLU(),
            nn.ConvTranspose2d(32, 16, kernel_size=5, stride=1), nn.LeakyReLU(),
            nn.ConvTranspose2d(16, 8, kernel_size=7, stride=1), nn.LeakyReLU(),
            nn.ConvTranspose2d(8, 1, kernel_size=7, stride=1), output_activation,
        )
        # Attributes are assigned only after nn.Module.__init__ has run.
        self.input_shape = shape_in
        self.latent_shape = shape_latent
    


In [None]:
#torchsummary.summary(MLPAutoEncoder(input_shape, latent_shape), device="cpu", input_size=input_shape)
#torchsummary.summary(ConvAutoEncoder(input_shape, latent_shape), device="cpu", input_size=input_shape)

# Experiment configuration. The top-level scalars are shared through ${...}
# interpolation; every `_target_` is built with hydra's `instantiate`.
config = OmegaConf.create(
"""
input_shape : [1,28,28]
latent_shape : [16]
batch_size : 512
learning_rate : 0.0005

module:
    _target_ : __main__.Module
    optimiser :
        _target_ : torch.optim.Adam
        lr : ${learning_rate}
    model :
        _target_ : __main__.ConvAutoEncoder
        input_shape : ${input_shape}
        latent_shape: ${latent_shape}

    criterion:
        _target_ : torch.nn.MSELoss

data_module:
    _target_ : __main__.DataModule
    batch_size : ${batch_size}

trainer:
    _target_: pytorch_lightning.Trainer
    gpus: 1
    max_epochs: 30
    min_epochs: 10
    check_val_every_n_epoch: 4
    log_every_n_steps: 10
    logger:
        _target_: pytorch_lightning.loggers.WandbLogger
        project: thesis-reconstruction
        log_model: all
        mode: online
""")

# Replace all ${...} interpolations with their values in place.
OmegaConf.resolve(config)
# _recursive_=False: Module receives the nested model/criterion/optimiser
# configs untouched and instantiates them itself.
module = instantiate(config.module, _recursive_=False)
data_module = instantiate(config.data_module)

trainer = instantiate(config.trainer)
try:
    trainer.fit(module, datamodule=data_module)
    trainer.test(module, datamodule=data_module)
finally:
    # Always close the wandb run, even when training or testing fails.
    wandb.finish()

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type            | Params
----------------------------------------------
0 | model     | ConvAutoEncoder | 477 K 
1 | criterion | MSELoss         | 0     
----------------------------------------------
477 K     Trainable params
0         Non-trainable params
477 K     Total params
1.911     Total estimated model params size (MB)
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")


Validation sanity check: 0it [00:00, ?it/s]

  rank_zero_warn(
[34m[1mwandb[0m: Currently logged in as: [33mbenedict-wilkins[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.12.11 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  rank_zero_warn(


Testing: 0it [00:00, ?it/s]

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{}
--------------------------------------------------------------------------------


VBox(children=(Label(value=' 39.30MB of 39.30MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.…