<a href="https://colab.research.google.com/github/Yyzhang2000/learning-generative-models/blob/main/energy/01_energy_based_contrastive_divergence.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os
import json
import math
import numpy as np
import random

import torch
from torch import optim
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data

import torchvision
from torchvision.datasets import MNIST
from torchvision import transforms

from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

In [2]:
DATASET_PATH = "./data"
CHECKPOINT_PATH = "./models"

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)

Device: cuda:0


In [3]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
                               ])

# Loading the training dataset. We need to split it into a training and validation part
train_set = MNIST(root=DATASET_PATH, train=True, transform=transform, download=True)

# Loading the test set
test_set = MNIST(root=DATASET_PATH, train=False, transform=transform, download=True)

# We define a set of data loaders that we can use for various purposes later.
# Note that for actually training a model, we will use different data loaders
# with a lower batch size.
train_loader = data.DataLoader(train_set, batch_size=128, shuffle=True,  drop_last=True,  num_workers=2, pin_memory=True)
test_loader  = data.DataLoader(test_set,  batch_size=256, shuffle=False, drop_last=False, num_workers=2)

100%|██████████| 9.91M/9.91M [00:01<00:00, 5.07MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 58.3kB/s]
100%|██████████| 1.65M/1.65M [00:01<00:00, 1.27MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 7.42MB/s]


In [4]:
class Swish(nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(x)


class CNNModel(nn.Module):
    def __init__(
            self,
            hidden_features = 32,
            out_dim = 1,
    ):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(1, hidden_features // 2, kernel_size = 5, stride = 2, padding = 4),
            Swish(),
            nn.Conv2d(hidden_features // 2, hidden_features,kernel_size=3, stride=2, padding=1),
            Swish(),
            nn.Conv2d(hidden_features, hidden_features * 2, kernel_size=3, stride=2, padding=1),
            Swish(),
            nn.Conv2d(hidden_features * 2, hidden_features * 2, kernel_size=3, stride=2, padding=1),
            Swish(),
            nn.Flatten(),
            nn.Linear(hidden_features * 2 * 4, hidden_features * 2),
            Swish(),
            nn.Linear(hidden_features * 2, out_dim)
        )


    def forward(self, x):
        x = self.layers(x).squeeze(dim = -1)

        return x

In [5]:
class Sampler:
    def __init__(
            self,
            model,
            img_shape,
            sample_size,
            max_len = 8192
    ):
        self.model = model
        self.img_shape = img_shape
        self.sample_size = sample_size
        self.max_len = max_len
        self.examples = [(torch.rand((1,) + img_shape) * 2 - 1) for _ in range(self.sample_size)]

    def sample_new_exmps(self, steps = 60, step_size = 10):
        # Choose 95% of the batch from the buffer, 5% generate from scratch
        n_new = np.random.binomial(self.sample_size, 0.05)
        rand_imgs = torch.rand((n_new, ) + (self.img_shape) )* 2 - 1
        old_imgs = torch.cat(
            random.choices(self.examples, k = self.sample_size - n_new), dim = 0
        )
        inp_imgs = torch.cat(
            (rand_imgs, old_imgs), dim = 0
        )

        # Perform MCMC Sampling
        inp_imgs = Sampler.generate_samples(
            self.model,
            inp_imgs,
            steps = steps,
            step_size = step_size
        )

        return inp_imgs

    @staticmethod
    def generate_samples(
            model,
            inp_imgs,
            steps = 60,
            step_size = 10,
            return_img_per_step = False
    ):
        is_training = model.training

        model.eval()
        for p  in model.parameters():
            p.requires_grad = False
        inp_imgs.requires_grad = True # To get Score

        had_gradients_enable = torch.is_grad_enabled()
        torch.set_grad_enabled(True)


        noise = torch.randn(inp_imgs.shape, device = inp_imgs.device)

        imgs_per_step = []
        for _ in range(steps):

            # 1. Add Noise to the input
            noise.normal_(0, 0.005)
            inp_imgs.data.add_(noise.data)
            inp_imgs.data.clamp_(-1.0, 1.0)

            # 2. Calculate gradients for the input
            out_imgs = -model(inp_imgs)
            out_imgs.sum().backward()
            inp_imgs.grad.data.clamp_(-0.03, 0.03) # For stabilizing and preventing too high gradients

            # 3. Step according to the gradient
            inp_imgs.data.add_(-step_size * inp_imgs.grad.data)
            inp_imgs.data.clamp_(-1.0, 1.0)

            # Clear Gradient
            inp_imgs.grad.detach_()
            inp_imgs.grad.zero_()


            if return_img_per_step:
                imgs_per_step.append(inp_imgs.clone().detach())


        for p in model.parameters():
            p.requires_grad = True

        model.train(is_training)
        torch.set_grad_enabled(had_gradients_enable)

        if return_img_per_step:
            return torch.stack(imgs_per_step, dim = 0)
        else:
            return inp_imgs

In [7]:
def training_step(model, batch, sampler, alpha = 0.1):
    real_imgs, _ = batch
    small_noise = torch.randn_like(real_imgs) * 0.005
    real_imgs.add_(small_noise).clamp_(-1.0, 1.0)

    # Obtain samples
    fake_imgs = sampler.sample_new_exmps(
        steps = 60,
        step_size = 10
    )

    # Predict energy score for all imgs
    inp_imgs = torch.cat([real_imgs, fake_imgs], dim = 0)
    real_out, fake_out = model(inp_imgs).chunk(2, dim = 0)

    # Calculate loss
    reg_loss = alpha * (real_out ** 2 + fake_out ** 2).mean()
    cdiv_loss = fake_out.mean() - real_out.mean()
    loss = reg_loss + cdiv_loss
    return loss

In [15]:
model = CNNModel()
sampler = Sampler(model, img_shape=(1,28,28),sample_size=train_loader.batch_size)
optimizer = optim.Adam(
    model.parameters(),
    lr=1e-4,
    betas = (0.0, 0.999)
)
scheduler = optim.lr_scheduler.StepLR(
    optimizer,
    1,
    gamma = 0.97
)

In [18]:
losses = []
for epoch in range(60):
    print(f"{epoch + 1} / 60")
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}", leave=False)
    for batch in progress_bar:
        loss = training_step(model, batch, sampler)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        losses.append(loss.item())
        progress_bar.set_postfix(loss=f"{loss.item():.4f}")

    avg_loss = sum(losses[-len(train_loader):]) / len(train_loader)
    print(f"Average loss: {avg_loss:.4f}")

1 / 60


Epoch 1:   0%|          | 0/468 [00:00<?, ?it/s]

Average loss: -0.1286
2 / 60


Epoch 2:   0%|          | 0/468 [00:00<?, ?it/s]

Average loss: -0.1289
3 / 60


Epoch 3:   0%|          | 0/468 [00:00<?, ?it/s]

Average loss: -0.1289
4 / 60


Epoch 4:   0%|          | 0/468 [00:00<?, ?it/s]

Average loss: -0.1289
5 / 60


Epoch 5:   0%|          | 0/468 [00:00<?, ?it/s]

Average loss: -0.1289
6 / 60


Epoch 6:   0%|          | 0/468 [00:00<?, ?it/s]

Average loss: -0.1289
7 / 60


Epoch 7:   0%|          | 0/468 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:

plt.plot(losses)

In [None]:
torch.save(model.state_dict(), os.path.join(CHECKPOINT_PATH, "energy_mnist.pth"))