In [1]:
import torch
import torch.nn as nn

import numpy as np

from pathlib import Path

import random

import sys
import import_ipynb
# Put the sibling `notebooks` directory first on sys.path so `import globals`
# (loaded through import_ipynb from globals.ipynb) can be resolved.
# NOTE(review): the name `dir` shadows the built-in dir(); left unchanged in case
# other cells reference it.
dir = Path("notebooks")
sys.path.insert(0, f"{dir.resolve()}")
import globals

importing Jupyter notebook from globals.ipynb


In [2]:
# Training / sampling hyperparameters shared by the cells below.
BUFFER_SIZE = 5_000   # maximum number of past samples kept in the replay buffer
NOISE = 5e-3          # std of the Gaussian noise added to real images and to each sampler step
STEP_SIZE = 10        # multiplier on the score gradient in each sampler step
STEPS = 100           # number of sampler steps per call
ALPHA = 1e-1          # weight of the squared-score regularisation term in the loss


In [3]:
class Swish(nn.Module):
    """Swish activation, ``x * sigmoid(x)``, applied element-wise."""

    def forward(self, x):
        gate = torch.sigmoid(x)
        return gate * x

class Model(nn.Module):
    """Convolutional scoring network: maps an image batch to one scalar score per image.

    Four stride-2 conv layers with Swish activations, then two dense layers.
    Inputs must have the shape used for the size probe in ``__init__``:
    ``(batch, globals.CHANNELS, globals.IMAGE_SIZE, globals.IMAGE_SIZE)``.
    ``forward`` returns a ``(batch, 1)`` tensor.

    Args:
        hidden_features: width of the hidden dense layer (default 64). The previous
            code used the current batch size here, which made the architecture depend
            on the batch.
    """

    def __init__(self, hidden_features=64):
        super(Model, self).__init__()
        self.conv2d_1 = nn.Conv2d(in_channels = globals.CHANNELS, out_channels = 32, kernel_size = 5, stride = 2, padding = 1)
        self.conv2d_2 = nn.Conv2d(in_channels = 32, out_channels = 64, kernel_size = 4, stride = 2, padding = 1)
        self.conv2d_3 = nn.Conv2d(in_channels = 64, out_channels = 128, kernel_size = 4, stride = 2, padding = 1)
        self.conv2d_4 = nn.Conv2d(in_channels = 128, out_channels = 128, kernel_size = 4, stride = 2, padding = 1)
        self.flatten = nn.Flatten()
        self.swish = Swish()

        # The dense layers used to be constructed inside forward(). Each call therefore
        # got fresh random weights that were never registered as submodules, so the
        # optimizer never saw them and scores were not reproducible. Build them once
        # here instead, using a dummy pass through the conv stack to get the flattened
        # feature size for the configured image size.
        with torch.no_grad():
            probe = torch.zeros(1, globals.CHANNELS, globals.IMAGE_SIZE, globals.IMAGE_SIZE)
            n_flat = self._conv_features(probe).shape[1]
        self.linear_1 = nn.Linear(in_features = n_flat, out_features = hidden_features)
        self.linear_2 = nn.Linear(in_features = hidden_features, out_features = 1)

        self._initialize_weights()

    def _initialize_weights(self):
        # Only conv kernels are re-initialised (N(0, 0.1)); conv biases and the
        # dense layers keep PyTorch's default initialisation.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, 0.0, 0.1)

    def _conv_features(self, x):
        # Conv -> Swish for each of the four conv layers, then flatten to (batch, features).
        for conv in (self.conv2d_1, self.conv2d_2, self.conv2d_3, self.conv2d_4):
            x = self.swish(conv(x))
        return self.flatten(x)

    def forward(self, data):
        x = self._conv_features(data)
        x = self.swish(self.linear_1(x))
        return self.linear_2(x)


In [4]:
# Langevin Sampling Function

def generate_samples(model, inp_imgs, steps, step_size, noise):
    """Refine images by noisy gradient ascent on the model's score (Langevin-style sampling).

    Each of the ``steps`` iterations does
    ``imgs <- imgs + step_size * d mean(model(imgs)) / d imgs + N(0, noise**2)``.

    Args:
        model: callable mapping an image batch to per-image scores.
        inp_imgs: starting images. Not modified in place, and its ``requires_grad``
            flag is left untouched.
        steps: number of update iterations.
        step_size: multiplier applied to the gradient at every step.
        noise: standard deviation of the Gaussian noise added at every step.

    Returns:
        The refined images, detached from the autograd graph.
    """
    imgs = inp_imgs.detach()
    for _ in range(steps):
        # Fresh noise each step with a FIXED scale. The old code rebound `noise` to
        # the noise tensor, so after step one the scale compounded (randn*randn*...*noise).
        # randn_like also keeps the noise on the images' device and dtype.
        step_noise = torch.randn_like(imgs) * noise

        # A detached leaf copy: gradients flow only within this step, and the
        # caller's tensor is never flagged requires_grad.
        imgs = imgs.detach().requires_grad_(True)
        score = torch.mean(model(imgs))
        (grad,) = torch.autograd.grad(outputs = score, inputs = imgs)

        imgs = imgs.detach() + (step_size * grad + step_noise)

    return imgs.detach()



In [5]:
class Buffer:
    """Replay buffer of past sampler outputs, used as starting points for new samples.

    Attributes:
        model: scoring model passed to ``generate_samples``.
        examples: stored images, newest first, truncated to ``BUFFER_SIZE`` rows.
    """

    def __init__(self, model):
        self.model = model
        # Seed with globals.BATCH_SIZE images drawn uniformly from [-1, 1).
        self.examples = torch.rand(
            size = (globals.BATCH_SIZE, globals.CHANNELS, globals.IMAGE_SIZE, globals.IMAGE_SIZE)
        ) * 2 - 1

    def sample_new_exmps(self, steps, step_size, noise, batch_size):
        """Draw ``batch_size`` starting images, refine them, store them and return them.

        About 5% of the starting images are fresh uniform noise in [-1, 1); the rest
        are drawn with replacement from ``self.examples``.

        Args:
            steps, step_size, noise: forwarded to ``generate_samples``.
            batch_size: number of images to produce.

        Returns:
            Tensor with ``batch_size`` rows of refined images.
        """
        # Draw from the requested batch size. Drawing from globals.BATCH_SIZE could give
        # more fresh images than a smaller (e.g. final, partial) batch holds, which made
        # the recycled count zero or negative and crashed torch.stack.
        n_new = int(np.random.binomial(n = batch_size, p = 0.05))
        n_old = batch_size - n_new

        rand_imgs = torch.rand(size = (n_new, globals.CHANNELS, globals.IMAGE_SIZE, globals.IMAGE_SIZE)) * 2 - 1
        # Uniform sampling with replacement via tensor indexing; works for n_old == 0 too.
        old_idx = torch.randint(0, self.examples.shape[0], (n_old,))
        old_imgs = self.examples[old_idx]

        inp_imgs = torch.cat([rand_imgs, old_imgs], dim = 0)
        inp_imgs = generate_samples(model = self.model, inp_imgs = inp_imgs, steps = steps, step_size = step_size, noise = noise)

        # Newest samples go first, so truncation discards the oldest entries.
        self.examples = torch.cat([inp_imgs, self.examples], dim = 0)[:BUFFER_SIZE]

        return inp_imgs

In [6]:
class EBM(nn.Module):
    """Energy-based model trainer: owns the scoring network, its replay buffer and the optimizer.

    Calling the module runs ONE optimisation step on a batch of real images and
    returns the loss tensor.
    """

    def __init__(self):
        super(EBM, self).__init__()
        self.model = Model()
        self.buffer = Buffer(model = self.model)
        self.opt = torch.optim.Adam(params = self.model.parameters(), lr = 0.0001, betas = (0.9, 0.999))

    def forward(self, real_imgs):
        """Do one training step on ``real_imgs`` and return the loss.

        The loss is ``mean(fake) - mean(real) + ALPHA * mean(real**2 + fake**2)``.
        Fake images come from the replay buffer's sampler.
        """
        batch_size = real_imgs.shape[0]
        self.opt.zero_grad()
        # Out-of-place on purpose: the old `+=` wrote the noise into the caller's batch.
        # randn_like also keeps the noise on the images' device and dtype.
        real_imgs = real_imgs + NOISE * torch.randn_like(real_imgs)
        fake_imgs = self.buffer.sample_new_exmps(steps = STEPS, step_size = STEP_SIZE, noise = NOISE, batch_size = batch_size)

        real_scores = self.model(real_imgs)
        fake_scores = self.model(fake_imgs)

        # Minimising this pushes fake-image scores down and real-image scores up.
        cdiv_loss = torch.mean(fake_scores) - torch.mean(real_scores)

        # Penalises large score magnitudes on both real and fake images.
        reg_loss = ALPHA * torch.mean(
            real_scores ** 2 + fake_scores ** 2
        )

        loss = cdiv_loss + reg_loss

        print('real_scores is: {}'.format(torch.mean(real_scores)))
        print('fake_scores is: {}'.format(torch.mean(fake_scores)))

        print('cdiv_loss is: {}'.format(cdiv_loss))
        print('reg_loss is: {}'.format(reg_loss))

        print('Overall loss is: {}'.format(loss))

        loss.backward()
        self.opt.step()

        return loss
    
    