In [81]:
import os 
import json 
import math 
import numpy as np 

import torch 
import torch.nn as nn 
import torch.nn.functional as F 
import torch.optim as optim 
import torch.utils.data as data
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets

import random

import logging 

In [7]:
# Make the sibling ``../utils`` directory importable so that the local
# ``misc`` module can be found by the import below.
import sys 
sys.path.append("../utils")

# NOTE(review): ``misc`` is a project-local module that is not shown here.
from misc import seed_everything

In [8]:
# Root directory handed to the torchvision MNIST dataset (download target).
DATASET_PATH = "./data"
# Directory name reserved for model checkpoints.
# NOTE(review): not referenced anywhere in the visible cells — confirm it is used later.
CHECKPOINT_PATH = "./checkpoint"

In [9]:
# Fix the random seed before any stochastic work.
# NOTE(review): seed_everything lives in ../utils/misc.py (not shown); presumably
# it seeds python/numpy/torch RNGs — confirm, since the sampler uses all three.
seed_everything(42)

In [67]:
from dataclasses import dataclass


@dataclass
class Params:
    """Hyper-parameters shared by the notebook.

    Every attribute is annotated so that ``@dataclass`` registers it as a
    real field: the generated ``__init__`` accepts overrides (for example
    ``Params(lr=1e-3)``) and ``repr``/``==`` include all values. Without
    annotations the decorator sees no fields at all.
    """

    # Compute device; evaluated once, when the class body is executed.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    # Number of images per batch; also used as the sampler's ``sample_size``.
    batch_size: int = 128
    # Image shape passed to the sampler as ``img_shape``.
    img_shape: tuple = (1, 28, 28)
    # Adam learning rate.
    lr: float = 1e-4
    # Weight of the squared-output regularisation term in the training loss.
    alpha: float = 0.1
    # Adam beta coefficients.
    beta1: float = 0.0
    beta2: float = 0.99
    # Number of training epochs; read by NetTrainer.train().
    # NOTE(review): the default is a placeholder — set it to the intended run length.
    epochs: int = 60


params = Params()


'cpu'

# Get Dataset

In [None]:
# MNIST images have a single channel, so Normalize must receive exactly one
# mean and one std value (three values fail to broadcast against a 1xHxW
# tensor). With mean = std = 0.5 the [0, 1] output of ToTensor is mapped to
# [-1, 1], the range the sampler clamps its images to.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)


train_dataset = datasets.MNIST(root=DATASET_PATH, train=True, download=True, transform=transform)
valid_dataset = datasets.MNIST(root=DATASET_PATH, train=False, download=True, transform=transform)


# The training batch size follows params.batch_size so that real batches have
# the same size as the batches drawn from the sampler (sample_size is also
# params.batch_size).
train_dataloader = data.DataLoader(train_dataset, batch_size=params.batch_size, shuffle=True, drop_last=True, num_workers=4, pin_memory=True)
valid_dataloader = data.DataLoader(valid_dataset, batch_size=128, shuffle=False, drop_last=False, num_workers=1, pin_memory=True)

# Define Model


In [12]:
class Swish(nn.Module):
    """Swish activation: every element is multiplied by its own sigmoid."""

    def forward(self, x):
        """Return ``x * sigmoid(x)``, element-wise, with the shape of ``x``."""
        gate = torch.sigmoid(x)
        return x * gate


class Net(nn.Module):
    """Small convolutional network producing ``out_dim`` scores per image.

    Args:
        hidden_features: Base channel width; the conv stack uses
            ``hidden_features // 2``, ``hidden_features`` and
            ``2 * hidden_features`` channels.
        out_dim: Number of outputs per image. With ``out_dim == 1`` the
            trailing dimension is squeezed away, giving shape ``(batch,)``.
        **kwargs: Accepted and ignored.

    The spatial sizes in the comments below assume 1x28x28 inputs. Four
    stride-2 convolutions are required to reach the 2x2 feature map that the
    first linear layer (``c_hid3 * 4`` inputs) expects; with only three the
    flattened size is ``c_hid3 * 16`` and the forward pass fails.
    """

    def __init__(self, hidden_features = 32, out_dim = 1, **kwargs):
        super().__init__()

        c_hid1 = hidden_features // 2
        c_hid2 = hidden_features
        c_hid3 = hidden_features * 2

        self.layers = nn.Sequential(
            nn.Conv2d(1, c_hid1, kernel_size=5, stride=2, padding=4),       # 28x28 -> 16x16
            Swish(),
            nn.Conv2d(c_hid1, c_hid2, kernel_size=3, stride=2, padding=1),  # 16x16 -> 8x8
            Swish(),
            nn.Conv2d(c_hid2, c_hid3, kernel_size=3, stride=2, padding=1),  # 8x8 -> 4x4
            Swish(),
            nn.Conv2d(c_hid3, c_hid3, kernel_size=3, stride=2, padding=1),  # 4x4 -> 2x2
            Swish(),
            nn.Flatten(),                                                   # c_hid3 * 2 * 2
            nn.Linear(c_hid3 * 4, c_hid3),
            Swish(),
            nn.Linear(c_hid3, out_dim)
        )

    def forward(self, x):
        """Score a batch of images; a size-1 last dimension is squeezed."""
        return self.layers(x).squeeze(dim=-1)

# Define Sampler Buffer 

In [56]:
class Sampler:
    """Replay buffer of MCMC samples drawn from a model's score landscape.

    Args:
        model: Network that scores a batch of images (one value per image).
        img_shape: Shape of a single image, e.g. ``(1, 28, 28)``.
        sample_size: Number of images returned by ``sample_new_exmps``.
        max_len: Maximum number of images kept in the buffer.

    The buffer (``self.examples``) is a list of CPU tensors of shape
    ``(1, *img_shape)`` and starts with ``sample_size`` uniform-noise images
    in ``[-1, 1]``.
    """

    def __init__(self, model, img_shape, sample_size, max_len = 8192):
        self.model = model
        self.img_shape = tuple(img_shape)
        self.sample_size = sample_size
        self.max_len = max_len
        self.examples = [
            (torch.rand((1,) + self.img_shape) * 2 - 1)
            for _ in range(self.sample_size)
        ]

    @staticmethod
    def _model_device(model):
        """Return the device of the model's first parameter (CPU if it has none)."""
        first_param = next(model.parameters(), None)
        if first_param is None:
            return torch.device("cpu")
        return first_param.device

    def sample_new_exmps(self, steps = 60, step_size = 10):
        """Draw a batch of ``sample_size`` images and refresh the buffer.

        About 5% of the starting images are fresh uniform noise; the rest are
        picked (with replacement) from the buffer. All of them are refined by
        ``generate_samples`` and then pushed to the front of the buffer, which
        is truncated to ``max_len`` entries.

        Args:
            steps: Number of MCMC iterations.
            step_size: Multiplier applied to the clamped input gradient.

        Returns:
            Tensor of shape ``(sample_size, *img_shape)`` on the model's device.
        """
        n_new = np.random.binomial(self.sample_size, 0.05)
        n_old = self.sample_size - n_new

        start_parts = [torch.rand((n_new,) + self.img_shape) * 2 - 1]
        if n_old > 0:
            # torch.cat rejects an empty list, so the buffer is only consulted
            # when at least one old image is needed.
            start_parts.append(torch.cat(random.choices(self.examples, k=n_old), dim=0))

        # Follow the model's device instead of a notebook-level global so the
        # starting images always live where the model does.
        device = Sampler._model_device(self.model)
        inp_imgs = torch.cat(start_parts, dim=0).detach().to(device)

        # Perform MCMC sampling
        inp_imgs = Sampler.generate_samples(self.model, inp_imgs, steps, step_size)

        # Add new images to the buffer and remove old ones if needed. The
        # stored chunks are detached so the buffer holds no autograd state.
        fresh = list(inp_imgs.detach().to(torch.device('cpu')).chunk(self.sample_size, dim=0))
        self.examples = (fresh + self.examples)[:self.max_len]
        return inp_imgs

    @staticmethod
    def generate_samples(model, inp_imgs, steps=60, step_size = 10, return_img_per_step=False):
        """Refine ``inp_imgs`` in place with noisy, clamped gradient steps.

        Each iteration adds Gaussian noise (std 0.005), clamps the images to
        ``[-1, 1]``, back-propagates the summed negated model output to the
        images, clamps that gradient to ``[-0.03, 0.03]`` and adds
        ``-step_size * grad`` before clamping to ``[-1, 1]`` again.

        Args:
            model: Network scoring a batch of images.
            inp_imgs: Starting images; modified in place. Gradient tracking is
                switched on for this tensor.
            steps: Number of iterations.
            step_size: Multiplier applied to the clamped gradient.
            return_img_per_step: If True, return a snapshot after every step.

        Returns:
            ``inp_imgs`` after the last step, or a stacked tensor of shape
            ``(steps, *inp_imgs.shape)`` when ``return_img_per_step`` is True.

        The model's train/eval mode, each parameter's ``requires_grad`` flag
        and the global grad mode are restored on exit, even if an error
        occurs part-way through.
        """
        is_training = model.training
        model.eval()
        # Remember every flag so frozen parameters stay frozen afterwards.
        param_flags = [(p, p.requires_grad) for p in model.parameters()]
        for p, _ in param_flags:
            p.requires_grad_(False)

        # Noise buffer with the input's shape, dtype and device; refilled each step.
        noise = torch.empty_like(inp_imgs)
        inp_imgs.requires_grad_(True)  # Gradient with respect to the input image
        imgs_per_step = []

        try:
            # Gradients are needed even if the caller runs under no_grad.
            with torch.enable_grad():
                # Loop over K (steps)
                for _ in range(steps):
                    with torch.no_grad():
                        noise.normal_(0, 0.005)
                        inp_imgs.add_(noise).clamp_(min=-1.0, max=1.0)

                    out_imgs = -model(inp_imgs)
                    out_imgs.sum().backward()

                    with torch.no_grad():
                        grad = inp_imgs.grad
                        grad.clamp_(-0.03, 0.03)
                        # Apply gradients to current samples
                        inp_imgs.add_(grad, alpha=-step_size).clamp_(min=-1.0, max=1.0)
                        grad.zero_()

                    if return_img_per_step:
                        imgs_per_step.append(inp_imgs.clone().detach())
        finally:
            for p, flag in param_flags:
                p.requires_grad_(flag)
            model.train(is_training)

        if return_img_per_step:
            return torch.stack(imgs_per_step, dim=0)
        else:
            return inp_imgs

# Training the model

In [57]:
# Display the configured batch size (it is used as the sampler's sample_size below).
params.batch_size

128

In [59]:
# Put the model on the configured device before anything holds a reference to
# it: the sampler moves its images to params.device, so a model left on the
# CPU would clash with them on a GPU machine.
net = Net().to(params.device)
# Replay buffer producing params.batch_size refined samples per call.
sampler = Sampler(net, img_shape=params.img_shape, sample_size=params.batch_size)
# Single all-zero image batch of the configured shape.
# NOTE(review): not used in the visible cells — confirm it is needed later.
example_input_array = torch.zeros(1, *params.img_shape)
optimizer = optim.Adam(net.parameters(), lr=params.lr, betas=(params.beta1, params.beta2))
# Multiply the learning rate by 0.97 on every scheduler.step() call.
scheduler = optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.97)

# Callback

In [61]:
class GenerateCallback:
    """Periodically generate images from noise to inspect sampling quality.

    Args:
        batch_size: Number of images to generate.
        vis_steps: Number of intermediate steps intended for visualisation.
        num_steps: Number of sampling iterations per generation run.
        every_n_epochs: Run only on epochs divisible by this value.
    """

    def __init__(self, batch_size =8, vis_steps = 8, num_steps = 256, every_n_epochs = 5):
        self.batch_size = batch_size
        self.vis_steps = vis_steps
        self.num_steps = num_steps
        self.every_n_epochs = every_n_epochs

    def on_epoch_end(self, trainer , model):
        """Generate images when ``trainer.current_epoch`` is due.

        Returns:
            The per-step image tensor from ``generate_imgs`` on epochs where
            the callback runs, otherwise ``None``.
        """
        if trainer.current_epoch % self.every_n_epochs != 0:
            return None

        # Generate images
        imgs_per_step = self.generate_imgs(model)

        # Plotting sketch kept from the original cell (not executed):
        # for i in range(imgs_per_step.shape[1]):
        #     step_size = self.num_steps // self.vis_steps
        #     imgs_to_plot = imgs_per_step[step_size -1 :: step_size, i]
        #     grid = torchvision.utils.make_grid(imgs_to_plot, nrow=imgs_to_plot.shape[0], normalize=True, value_range=(-1, 1))
        return imgs_per_step

    def generate_imgs(self, model):
        """Run the sampler from uniform noise and return every intermediate step.

        The model's train/eval mode and the global autograd mode are restored
        to what they were on entry; unconditionally disabling gradients here
        would silently break any training step that follows.

        Returns:
            Tensor of shape ``(num_steps, batch_size, *params.img_shape)``.
        """
        was_training = model.training
        grad_was_enabled = torch.is_grad_enabled()

        # Start on the model's own device; fall back to the configured one
        # only for a parameter-free model.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else params.device

        model.eval()
        start_imgs = torch.rand((self.batch_size,) + params.img_shape).to(device)
        start_imgs = start_imgs * 2 - 1
        try:
            torch.set_grad_enabled(True)
            imgs_per_step = Sampler.generate_samples(
                model, start_imgs, steps=self.num_steps, step_size = 10, return_img_per_step = True
            )
        finally:
            torch.set_grad_enabled(grad_was_enabled)
            model.train(was_training)
        return imgs_per_step
        
    

In [62]:
class SamplerCallbacl():
    """Periodically build an image grid from random entries of the sampler buffer.

    Args:
        num_imgs: Number of buffer images drawn (with replacement).
        every_n_epochs: Run only on epochs divisible by this value.
    """

    def __init__(self, num_imgs=32, every_n_epochs =5):
        self.num_imgs = num_imgs
        self.every_n_epochs = every_n_epochs

    def on_epoch_end(self, trainer, sampler):
        """Return a grid of buffer images when ``trainer.current_epoch`` is due.

        Returns:
            The grid produced by ``torchvision.utils.make_grid`` on epochs
            where the callback runs, otherwise ``None``.
        """
        if trainer.current_epoch % self.every_n_epochs != 0:
            return None

        exmp_imgs = torch.cat(random.choices(sampler.examples, k=self.num_imgs), dim=0)
        # ``value_range`` is the current name of make_grid's normalisation
        # bounds; the older ``range`` keyword is no longer accepted.
        grid = torchvision.utils.make_grid(
            exmp_imgs, nrow=4, normalize=True, value_range=(-1, 1)
        )
        return grid


# Correctly spelled alias; the original (misspelled) name stays valid.
SamplerCallback = SamplerCallbacl

In [64]:
class OutlierCallback():
    """Report the model's mean output on uniform-noise images.

    Args:
        batch_size: Number of random images scored per call.
    """

    def __init__(self, batch_size=1024):
        self.batch_size = batch_size

    def on_epoch_end(self, trainer, model):
        """Score a batch of uniform noise in ``[-1, 1]``.

        The model is evaluated in eval mode without gradient tracking and is
        put back into whichever mode it was in on entry.

        Returns:
            The mean model output over the random batch, as a Python float.
        """
        was_training = model.training

        # Score on the model's own device; fall back to the configured one
        # only for a parameter-free model.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else params.device

        model.eval()
        try:
            with torch.no_grad():
                rand_imgs = torch.rand(
                    (self.batch_size,) + params.img_shape
                ).to(device)
                rand_imgs = rand_imgs * 2 - 1.0
                rand_out = model(rand_imgs).mean()
        finally:
            model.train(was_training)
        return rand_out.item()


# Define Trainer 

In [69]:
from trainer import Trainer

In [83]:
class NetTrainer(Trainer):
    """Training loop for the network using sampler-generated negative images.

    Args:
        model: Network scoring a batch of images (one value per image).
        criterion: Stored for interface compatibility; not used by this loop.
        optimizer: Optimizer updating ``model``'s parameters.
        lr_scheduler: Learning-rate scheduler stepped once per epoch, or None.
        metrics: Stored for interface compatibility; not used by this loop.
        dataloaders: Mapping with ``'train'`` and ``'valid'`` data loaders
            yielding ``(images, labels)`` batches.
        params: Configuration object providing ``alpha`` and ``epochs`` (and
            ``img_shape`` / ``batch_size`` when no sampler is passed).
        sampler: Object with ``sample_new_exmps(steps, step_size)``; when
            omitted a ``Sampler`` is built from ``params``.

    ``train_losses`` and ``valid_losses`` collect one average per epoch.
    """

    def __init__(self, model, criterion, optimizer, lr_scheduler, metrics, dataloaders,  params, sampler=None):
        # NOTE(review): Trainer.__init__ is deliberately not called, as before;
        # the base class is not visible here — confirm it needs no set-up.
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.metrics = metrics
        self.dataloaders = dataloaders
        self.params = params
        if sampler is None:
            sampler = Sampler(model, img_shape=params.img_shape, sample_size=params.batch_size)
        self.sampler = sampler

        self.train_losses = []
        self.valid_losses = []

    def _model_device(self):
        """Return the device of the model's first parameter (CPU if it has none)."""
        first_param = next(self.model.parameters(), None)
        if first_param is None:
            return torch.device("cpu")
        return first_param.device

    def train_one_epoch(self):
        """Run one pass over the training loader.

        For every batch the real images get a little Gaussian noise, a batch
        of negatives is drawn from the sampler, and the model is updated on
        ``alpha * (mean(real^2) + mean(fake^2)) + mean(fake) - mean(real)``.

        Returns:
            The average loss over the epoch (NaN for an empty loader); the
            same value is appended to ``train_losses``.
        """
        self.model.train()
        device = self._model_device()
        loss_sum = 0.0
        n_batches = 0

        for real_imgs, _ in self.dataloaders['train']:
            real_imgs = real_imgs.to(device)
            small_noise = torch.randn_like(real_imgs) * 0.005
            real_imgs.add_(small_noise).clamp_(min=-1.0, max=1.0)

            # Obtain samples; detached because only the model is optimised here.
            fake_imgs = self.sampler.sample_new_exmps(steps=60, step_size = 10)
            fake_imgs = fake_imgs.detach().to(device)

            # Predict energy score for all images. The outputs are split at
            # the real batch size so real and fake batches may differ in size.
            n_real = real_imgs.shape[0]
            scores = self.model(torch.cat([real_imgs, fake_imgs], dim=0))
            real_out, fake_out = scores[:n_real], scores[n_real:]

            # Equals mean(real^2 + fake^2) when both batches have the same size.
            reg_loss = self.params.alpha * ((real_out ** 2).mean() + (fake_out ** 2).mean())
            cdiv_loss = fake_out.mean() - real_out.mean()
            loss = reg_loss + cdiv_loss

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            loss_sum += loss.item()
            n_batches += 1

        avg_loss = loss_sum / n_batches if n_batches else float("nan")
        self.train_losses.append(avg_loss)
        logging.info(f"train loss: {avg_loss:05.3f}")
        return avg_loss

    @torch.no_grad()
    def valid_one_epoch(self):
        """Compare model outputs on validation images and on uniform noise.

        Returns:
            The average of ``mean(fake) - mean(real)`` over the validation
            loader (NaN for an empty loader); also appended to ``valid_losses``.
        """
        self.model.eval()
        device = self._model_device()
        cdiv_sum = 0.0
        n_batches = 0

        for real_imgs, _ in self.dataloaders['valid']:
            real_imgs = real_imgs.to(device)
            fake_imgs = torch.rand_like(real_imgs) * 2 - 1

            inp_imgs = torch.cat([real_imgs, fake_imgs], dim=0)
            real_out, fake_out = self.model(inp_imgs).chunk(2, dim=0)

            cdiv = fake_out.mean() - real_out.mean()
            cdiv_sum += cdiv.item()
            n_batches += 1

        avg_cdiv = cdiv_sum / n_batches if n_batches else float("nan")
        self.valid_losses.append(avg_cdiv)
        logging.info(f"valid loss: {avg_cdiv:05.3f}")
        return avg_cdiv

    def train(self):
        """Run ``params.epochs`` epochs of training followed by validation."""
        n_epochs = self.params.epochs
        for epoch in range(n_epochs):
            logging.info(f"Epoch {epoch + 1} / {n_epochs}")
            print(f"Epoch {epoch + 1} / {n_epochs}")

            self.train_one_epoch()
            self.valid_one_epoch()

            # Advance the learning-rate schedule once per epoch.
            if self.lr_scheduler is not None:
                self.lr_scheduler.step()
