# Tutorial 8.1: Deep Energy Models

![Status](https://img.shields.io/static/v1.svg?label=Status&message=Under%20development&color=red)

**Filled notebook:** 
[![View on Github](https://img.shields.io/static/v1.svg?logo=github&label=Repo&message=View%20On%20Github&color=lightgrey)](https://github.com/phlippe/uvadlc_notebooks/blob/master/docs/tutorial_notebooks/tutorial8/Deep_Energy_Models.ipynb)
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/phlippe/uvadlc_notebooks/blob/master/docs/tutorial_notebooks/tutorial8/Deep_Energy_Models.ipynb)  
**Empty notebook:** 
[![View on Github](https://img.shields.io/static/v1.svg?logo=github&label=Repo&message=View%20On%20Github&color=lightgrey)](https://github.com/phlippe/uvadlc_notebooks/blob/master/docs/tutorial_notebooks/tutorial8/Deep_Energy_Models.ipynb)
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/phlippe/uvadlc_notebooks/blob/master/docs/tutorial_notebooks/tutorial8/Deep_Energy_Models.ipynb)  
**Pre-trained models:** 

In [12]:
## Standard libraries
import os
import json
import math
import numpy as np 
import random

## Imports for plotting
import matplotlib.pyplot as plt
# %matplotlib inline 
# from IPython.display import set_matplotlib_formats
# set_matplotlib_formats('svg', 'pdf') # For export
# from matplotlib.colors import to_rgb
# import matplotlib
# matplotlib.rcParams['lines.linewidth'] = 2.0
import seaborn as sns
sns.reset_orig()
sns.set()

## Progress bar
# from tqdm.notebook import tqdm

## PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim
# Torchvision
import torchvision
from torchvision.datasets import MNIST
from torchvision import transforms
# PyTorch Lightning
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateLogger, ModelCheckpoint

# Path to the folder where the datasets are/should be downloaded (e.g. MNIST)
DATASET_PATH = "../data"
# Path to the folder where the pretrained models are saved
CHECKPOINT_PATH = "../saved_models/tutorial8"

# Seed the random number generators through PyTorch Lightning so runs are repeatable
pl.seed_everything(42)

# Ensure that all operations are deterministic on GPU (if used) for reproducibility.
# The attribute name must be spelled exactly "deterministic"; assigning to a
# misspelled name does not change the cuDNN setting.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Use the first CUDA device when one is available, otherwise fall back to the CPU
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [13]:
# Per-image preprocessing: convert to a tensor, then normalize with mean 0.5 and std 0.5
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# MNIST training data, divided below into a training part and a validation part
train_dataset = MNIST(root=DATASET_PATH, train=True, transform=transform, download=True)
pl.seed_everything(42)  # re-seed right before the random split
train_set, val_set = data.random_split(train_dataset, [50000, 10000])

# MNIST test data
test_set = MNIST(root=DATASET_PATH, train=False, transform=transform, download=True)

# Data loaders shared by the cells below. Only the training loader shuffles and
# drops the last incomplete batch; validation and test loaders keep every sample.
train_loader = data.DataLoader(
    train_set,
    batch_size=256,
    shuffle=True,
    drop_last=True,
    pin_memory=True,
    num_workers=4,
)
val_loader, test_loader = (
    data.DataLoader(subset, batch_size=64, shuffle=False, drop_last=False, num_workers=4)
    for subset in (val_set, test_set)
)

## CNN Model

In [14]:
# Lookup table from a lowercase activation name to the corresponding module class
act_fn_by_name = {
    "leakyrelu": nn.LeakyReLU,
    "gelu": nn.GELU,
    "relu": nn.ReLU
}

class CNNModel(nn.Module):
    """Small convolutional network that maps a single-channel image to `out_dim` values.

    Args:
        hidden_features: Base channel width; the convolutions use half, one and
            two times this number of channels.
        out_dim: Number of outputs per image. With the default of 1, the trailing
            dimension is squeezed away and the output has shape [batch].
        act_fn_name: Key into `act_fn_by_name`. Unknown names fall back to LeakyReLU.
    """

    def __init__(self, hidden_features=64, out_dim=1, act_fn_name="leakyrelu"):
        super().__init__()
        act_fn = act_fn_by_name.get(act_fn_name, nn.LeakyReLU)
        c_hid1 = hidden_features//2
        c_hid2 = hidden_features
        c_hid3 = hidden_features*2
        # The spatial sizes in the comments below hold for 28x28 inputs
        self.cnn_layers = nn.Sequential(
                nn.Conv2d(1, c_hid1, kernel_size=5, stride=2, padding=4), # [16x16]
                act_fn(),
                nn.Conv2d(c_hid1, c_hid2, kernel_size=3, stride=2, padding=1), #  [8x8]
                act_fn(),
                nn.Conv2d(c_hid2, c_hid3, kernel_size=3, stride=1, padding=1),
                act_fn(),
                nn.Conv2d(c_hid3, c_hid3, kernel_size=3, stride=2, padding=1), # [4x4]
                act_fn(),
                nn.Conv2d(c_hid3, c_hid3, kernel_size=3, stride=1, padding=1),
                nn.AvgPool2d(kernel_size=4) # [1x1]
        )
        self.out_layer = nn.Sequential(
                nn.Linear(c_hid3, c_hid3),
                act_fn(),
                nn.Linear(c_hid3, c_hid3),
                act_fn(),
                nn.Linear(c_hid3, out_dim)
        )


    def forward(self, x):
        """Return one row of outputs per input image.

        The result has shape [batch] when `out_dim` is 1 and [batch, out_dim] otherwise.
        """
        # Flatten everything except the batch dimension. A bare squeeze() would
        # also drop the batch dimension when the batch holds a single image.
        x = self.cnn_layers(x).flatten(start_dim=1)
        x = self.out_layer(x).squeeze(dim=-1)
        return x

In [15]:
class Sampler:
    """Keeps a buffer of previously generated images and refines them with noisy gradient steps.

    Args:
        model: Network whose output is increased by the sampling steps.
        img_shape: Shape of a single image, without the batch dimension.
        sample_size: Number of images returned by `sample_new_exmps`.
        max_len: Maximum number of images kept in the buffer.
    """

    def __init__(self, model, img_shape, sample_size, max_len=8192):
        super().__init__()
        self.model = model
        self.img_shape = img_shape
        self.sample_size = sample_size
        self.max_len = max_len
        # Fill the buffer with uniform noise in [-1, 1], the same range that is
        # used for fresh samples and for clamping during sampling.
        self.examples = [torch.rand((1,)+img_shape) * 2 - 1 for _ in range(self.sample_size)]

    def sample_new_exmps(self, steps=60, step_size=10):
        """Draw a batch of starting images, refine them and push them into the buffer.

        Each of the `sample_size` slots is started from fresh uniform noise with
        probability 0.05 and from a random buffer entry otherwise.

        Args:
            steps: Number of refinement iterations.
            step_size: Multiplier applied to the clamped input gradient.

        Returns:
            The refined images as one tensor on the model's device.
        """
        n_new = np.random.binomial(self.sample_size, 0.05)
        rand_imgs = torch.rand((n_new,) + self.img_shape) * 2 - 1
        old_imgs = random.choices(self.examples, k=self.sample_size-n_new)
        # Run on whatever device the model lives on instead of a global default
        first_param = next(self.model.parameters(), None)
        model_device = first_param.device if first_param is not None else torch.device("cpu")
        # Concatenating everything in one call also works when no buffer entry was drawn
        inp_imgs = torch.cat([rand_imgs] + old_imgs, dim=0).detach().to(model_device)

        inp_imgs = Sampler.generate_samples(self.model, inp_imgs, steps=steps, step_size=step_size)

        # Store detached CPU chunks so the buffer does not hold on to autograd history
        new_examples = list(inp_imgs.detach().to(torch.device("cpu")).chunk(self.sample_size, dim=0))
        self.examples = (new_examples + self.examples)[:self.max_len]
        self.model.train()
        return inp_imgs

    @staticmethod
    def generate_samples(model, inp_imgs, steps=60, step_size=10, return_img_per_step=False):
        """Refine `inp_imgs` in place with noisy gradient steps that increase the model output.

        The model's train/eval mode and the `requires_grad` flags of its
        parameters are restored to their previous values before returning.
        Gradient tracking is enabled locally, so this also works when called
        inside a `torch.no_grad()` block.

        Args:
            model: Network to evaluate; run in eval mode with frozen parameters.
            inp_imgs: Leaf tensor of starting images; modified in place.
            steps: Number of iterations.
            step_size: Multiplier applied to the clamped input gradient.
            return_img_per_step: If True, return a stack with a copy of the images
                after every iteration instead of only the final images.

        Returns:
            `inp_imgs` after the last step, or a tensor of shape
            [steps, *inp_imgs.shape] when `return_img_per_step` is True.
        """
        was_training = model.training
        params = list(model.parameters())
        prev_requires_grad = [p.requires_grad for p in params]
        model.eval()
        # Only the gradient with respect to the input is needed
        for p in params:
            p.requires_grad = False
        inp_imgs.requires_grad = True
        # Noise buffer on the same device as the images; refilled in every iteration
        noise = torch.randn(inp_imgs.shape, device=inp_imgs.device)
        imgs_per_step = []
        try:
            with torch.enable_grad():
                for _ in range(steps):
                    # Perturb the current images and keep them in [-1, 1]
                    noise.normal_(0, 0.01)
                    inp_imgs.data.add_(noise.data)
                    inp_imgs.data.clamp_(min=-1.0, max=1.0)

                    # Gradient of the negated model output with respect to the images
                    out_imgs = -model(inp_imgs)
                    out_imgs.sum().backward()
                    inp_imgs.grad.data.clamp_(-0.03, 0.03)

                    # Step against that gradient, i.e. towards a higher model output
                    inp_imgs.data.add_(-step_size * inp_imgs.grad.data)
                    inp_imgs.grad.detach_()
                    inp_imgs.grad.zero_()
                    inp_imgs.data.clamp_(min=-1.0, max=1.0)
                    if return_img_per_step:
                        imgs_per_step.append(inp_imgs.clone().detach())
        finally:
            # Put the model back into the state the caller left it in
            for p, flag in zip(params, prev_requires_grad):
                p.requires_grad = flag
            model.train(was_training)

        if return_img_per_step:
            return torch.stack(imgs_per_step, dim=0)
        else:
            return inp_imgs

In [16]:
class DeepEnergyModel(pl.LightningModule):
    """Lightning module that trains a `CNNModel` against images produced by a `Sampler`.

    Args:
        img_shape: Shape of a single image, without the batch dimension.
        batch_size: Number of generated images drawn per training step.
        alpha: Weight of the squared-output regularization term.
        **CNN_args: Forwarded unchanged to `CNNModel`.
    """

    def __init__(self, img_shape, batch_size, alpha=.1, **CNN_args):
        super().__init__()
        self.save_hyperparameters()

        self.alpha = alpha
        self.cnn = CNNModel(**CNN_args)
        self.sampler = Sampler(self.cnn, img_shape=img_shape, sample_size=batch_size)
        # Single all-zero image used by Lightning as an example input
        self.example_input_array = torch.zeros(1, *img_shape)

    def forward(self, x):
        """Return the CNN output for `x`."""
        return self.cnn(x)

    def configure_optimizers(self):
        """Adam with learning rate 1e-4 and a StepLR schedule (step size 40, factor 0.5)."""
        adam = optim.Adam(self.parameters(), lr=1e-4, betas=(0.9, 0.999))
        step_lr = optim.lr_scheduler.StepLR(adam, 40, gamma=0.5)
        return [adam], [step_lr]

    def training_step(self, batch, batch_idx):
        """Compute the loss on one batch of real images plus freshly sampled images.

        The loss is the mean output on generated images minus the mean output on
        real images, plus `alpha` times the mean of the squared outputs.
        """
        real_imgs, _labels = batch
        fake_imgs = self.sampler.sample_new_exmps(steps=60, step_size=10)

        # Score real and generated images in a single forward pass, then split again
        scores = self.cnn(torch.cat([real_imgs, fake_imgs], dim=0))
        real_out, fake_out = scores.chunk(2, dim=0)
        real_mean = real_out.mean()
        fake_mean = fake_out.mean()

        reg_loss = self.alpha * (real_out ** 2 + fake_out ** 2).mean()
        diff_loss = fake_mean - real_mean
        loss = reg_loss + diff_loss

        result = pl.TrainResult(minimize=loss)
        logged_values = (
            ('loss', loss),
            ('loss_regularization', reg_loss),
            ('loss_difference', diff_loss),
            ('metrics_avg_real', real_mean),
            ('metrics_avg_fake', fake_mean),
        )
        for key, value in logged_values:
            result.log(key, value)
        return result

In [17]:
class GenerateCallback(pl.Callback):
    """Callback that generates images from noise and logs the intermediate steps to the logger.

    Args:
        batch_size: Number of images to generate.
        vis_steps: Number of steps within generation to visualize.
        num_steps: Number of steps to take during generation.
        every_n_epochs: Only log images every N epochs (otherwise tensorboard gets quite large).
    """

    def __init__(self, batch_size=8, vis_steps=8, num_steps=256, every_n_epochs=1):
        super().__init__()
        self.batch_size = batch_size # Number of images to generate
        self.vis_steps = vis_steps # Number of steps within generation to visualize
        self.num_steps = num_steps # Number of steps to take during generation
        self.every_n_epochs = every_n_epochs # Only save those images every N epochs (otherwise tensorboard gets quite large)

    def on_epoch_end(self, trainer, pl_module):
        """Generate `batch_size` images and log one image grid per generated image."""
        if trainer.current_epoch % self.every_n_epochs == 0:
            # Generate images starting from uniform noise in [-1, 1]
            pl_module.eval()
            start_imgs = torch.rand((self.batch_size,) + pl_module.hparams["img_shape"]).to(pl_module.device)
            start_imgs = start_imgs * 2 - 1
            imgs_per_step = Sampler.generate_samples(pl_module.cnn, start_imgs, steps=self.num_steps, step_size=10, return_img_per_step=True)
            pl_module.train()
            # Distance between visualized steps. It is the same for every image, so
            # compute it once; keep it at least 1 because a slice step of 0 is
            # invalid when vis_steps exceeds num_steps.
            stride = max(1, self.num_steps // self.vis_steps)
            # Plot and add to tensorboard
            for i in range(imgs_per_step.shape[1]):
                imgs_to_plot = imgs_per_step[stride-1::stride,i]
                grid = torchvision.utils.make_grid(imgs_to_plot, nrow=imgs_to_plot.shape[0], normalize=True, range=(-1,1))
                trainer.logger.experiment.add_image("generation_%i" % i, grid, global_step=trainer.current_epoch)

In [18]:
class SamplerCallback(pl.Callback):
    """Callback that logs a grid of randomly chosen images from the module's sampler buffer.

    Args:
        num_imgs: Number of buffer entries to draw (with replacement) per logged grid.
        every_n_epochs: Log only when the current epoch is a multiple of this value.
    """

    def __init__(self, num_imgs=32, every_n_epochs=1):
        super().__init__()
        self.every_n_epochs = every_n_epochs
        self.num_imgs = num_imgs

    def on_epoch_end(self, trainer, pl_module):
        """Add a "sampler" image grid to the logger for the current epoch."""
        epoch = trainer.current_epoch
        if epoch % self.every_n_epochs != 0:
            return
        picked = random.choices(pl_module.sampler.examples, k=self.num_imgs)
        exmp_imgs = torch.cat(picked, dim=0)
        grid = torchvision.utils.make_grid(exmp_imgs, nrow=4, normalize=True, range=(-1,1))
        trainer.logger.experiment.add_image("sampler", grid, global_step=trainer.current_epoch)

In [19]:
class OutlierCallback(pl.Callback):
    """Callback that logs the mean CNN output on uniform random noise images after each epoch.

    Args:
        batch_size: Number of random images to evaluate.
    """

    def __init__(self, batch_size=1024):
        super().__init__()
        self.batch_size = batch_size

    def on_epoch_end(self, trainer, pl_module):
        """Score a batch of noise in [-1, 1] without gradients and log the mean as "rand_out"."""
        with torch.no_grad():
            pl_module.eval()
            noise_shape = (self.batch_size,) + pl_module.hparams["img_shape"]
            rand_imgs = torch.rand(noise_shape).to(pl_module.device) * 2 - 1.0
            rand_out = pl_module.cnn(rand_imgs).mean()
            pl_module.train()

        tensorboard = trainer.logger.experiment
        tensorboard.add_scalar("rand_out", rand_out, global_step=trainer.current_epoch)

In [20]:
def train_model(**kwargs):
    """Return a `DeepEnergyModel`, loaded from a checkpoint if present, otherwise trained.

    If `CHECKPOINT_PATH/MNIST.ckpt` exists, the model is loaded from it and
    `kwargs` are not used. Otherwise a new model is built from `kwargs` and
    trained on `train_loader`.

    Args:
        **kwargs: Constructor arguments for `DeepEnergyModel`.

    Returns:
        The loaded or trained model.
    """
    # Create a PyTorch Lightning trainer with the generation callback.
    # Request a GPU only when one is available so the notebook also runs on CPU-only machines.
    trainer = pl.Trainer(default_root_dir=os.path.join(CHECKPOINT_PATH, "MNIST"),
                         checkpoint_callback=ModelCheckpoint(save_weights_only=True, save_last=True),
                         gpus=1 if torch.cuda.is_available() else 0,
                         max_epochs=40,
                         gradient_clip_val=0.1,
                         callbacks=[GenerateCallback(every_n_epochs=5),
                                    SamplerCallback(every_n_epochs=5),
                                    OutlierCallback(),
                                    LearningRateLogger("epoch")
                                   ],
                        progress_bar_refresh_rate=100)
    # Check whether pretrained model exists. If yes, load it and skip training
    pretrained_filename = os.path.join(CHECKPOINT_PATH, "MNIST.ckpt")
    if os.path.isfile(pretrained_filename):
        model = DeepEnergyModel.load_from_checkpoint(pretrained_filename)
    else:
        model = DeepEnergyModel(**kwargs)
        trainer.fit(model, train_loader)
    # No testing as we are more interested in other properties
    return model

In [21]:
# Build the model for 1x28x28 images; train_model loads the pretrained checkpoint
# if it finds one and otherwise trains on train_loader.
model = train_model(img_shape=(1,28,28), batch_size=train_loader.batch_size)

GPU available: True, used: True
I0924 21:10:31.854144 140265729595200 distributed.py:41] GPU available: True, used: True
TPU available: False, using: 0 TPU cores
I0924 21:10:31.855967 140265729595200 distributed.py:41] TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]
I0924 21:10:31.857367 140265729595200 distributed.py:41] CUDA_VISIBLE_DEVICES: [0]

  | Name | Type     | Params | In sizes       | Out sizes
---------------------------------------------------------------
0 | cnn  | CNNModel | 421 K  | [1, 1, 28, 28] | ?        
I0924 21:10:31.987430 140265729595200 lightning.py:1449] 
  | Name | Type     | Params | In sizes       | Out sizes
---------------------------------------------------------------
0 | cnn  | CNNModel | 421 K  | [1, 1, 28, 28] | ?        


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

Saving latest checkpoint..
I0924 21:10:51.487762 140265729595200 training_loop.py:1136] Saving latest checkpoint..







## Out-of-distribution detection

In [11]:
# One binomial draw with 128 trials and success probability 0.05.
# NOTE(review): looks like a leftover scratch check of the 5% fresh-sample draw used in
# Sampler.sample_new_exmps, and its execution count is out of order; confirm whether it can be removed.
np.random.binomial(128,0.05)

5