# ⚡️ Energy-Based Models

In this notebook, we'll walk through the steps required to train your own energy-based model (EBM) to learn the distribution of a demo dataset (MNIST).

The code is adapted from the excellent ['Deep Energy-Based Generative Models' tutorial](https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial8/Deep_Energy_Models.html) created by Phillip Lippe.

In [None]:
# %%
# IPython magics: load the autoreload extension and set it to mode 2.
%load_ext autoreload
%autoreload 2

In [None]:
# %%
import numpy as np
import random
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from torchvision import datasets

## 0. Parameters <a name="parameters"></a>

In [None]:
# %%
# Image geometry (MNIST digits are padded up to IMAGE_SIZE x IMAGE_SIZE).
IMAGE_SIZE = 32
CHANNELS = 1

# Langevin sampler: number of steps, gradient step size, std of the noise
# added each step (NOISE is also added to real training images) and the
# elementwise clamp applied to the sampler's gradients.
STEPS = 60
STEP_SIZE = 10
NOISE = 0.005
GRADIENT_CLIP = 0.03

# Training: batch / replay-buffer sizes, optimiser learning rate, weight of
# the squared-score regulariser, epoch count, and whether to restore
# ./models/model.pt instead of starting from fresh weights.
BATCH_SIZE = 128
BUFFER_SIZE = 8192
LEARNING_RATE = 1e-4
ALPHA = 0.1
EPOCHS = 60
LOAD_MODEL = False

## 1. Prepare the data <a name="prepare"></a>

In [None]:
# %%
# Load the MNIST train and test splits; download=True lets torchvision fetch
# the files into ./data when they are not already there.
train_set, test_set = (
    datasets.MNIST(root="./data", train=is_train, download=True)
    for is_train in (True, False)
)

In [None]:
# %%
def preprocess(imgs):
    """
    Convert a stack of raw 8-bit images into padded, channel-first floats.

    Args:
        imgs: numpy array of shape (N, H, W) with pixel values in [0, 255].

    Returns:
        float32 array of shape (N, 1, H + 4, W + 4). Pixels are rescaled to
        [-1, 1], and every image gets a 2-pixel border filled with -1, which
        is the rescaled value of a black (0) pixel.
    """
    centred = (imgs.astype("float32") - 127.5) / 127.5
    framed = np.pad(centred, [(0, 0), (2, 2), (2, 2)], constant_values=-1.0)
    # Insert the singleton channel axis right after the batch axis.
    return framed.reshape((framed.shape[0], 1) + framed.shape[1:])

In [None]:
# %%
x_train = preprocess(train_set.data.numpy())
x_test = preprocess(test_set.data.numpy())


def _make_loader(array, shuffle=False):
    """Wrap a preprocessed image array in a single-tensor, batched DataLoader."""
    return DataLoader(
        TensorDataset(torch.from_numpy(array)),
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
    )


# Only the training split is reshuffled each epoch.
train_loader = _make_loader(x_train, shuffle=True)
test_loader = _make_loader(x_test)

In [None]:
# %%
def display(images, n=10):
    """
    Show up to ``n`` images of a batch side by side in a single row.

    Args:
        images: array or tensor indexed as ``images[i, 0]``, i.e. a batch of
            channel-first images; only channel 0 is drawn. Torch tensors are
            detached and moved to the CPU first, so device or grad-tracking
            tensors can be passed directly.
        n: maximum number of images to draw. A smaller batch is shown in
            full instead of raising IndexError; an empty batch draws nothing.
    """
    if isinstance(images, torch.Tensor):
        images = images.detach().cpu()
    images = images[:n]
    count = len(images)
    if count == 0:
        return

    # squeeze=False keeps ``axes`` 2-D even when count == 1; otherwise
    # plt.subplots returns a bare Axes object that cannot be iterated.
    _, axes = plt.subplots(1, count, figsize=(count * 1.2, 1.2), squeeze=False)
    for ax, img in zip(axes[0], images):
        ax.imshow(img[0], cmap="gray")
        ax.axis("off")
    plt.show()

# Preview the first images of one shuffled training batch. The loader yields
# one-element batches (a single tensor), so unpack it directly.
(sample,) = next(iter(train_loader))
display(sample)

## 2. Build the EBM network <a name="build"></a>

In [None]:
# %%
class EBMNet(nn.Module):
    """
    Convolutional network mapping each image in a batch to one scalar score.

    Four stride-2 convolutions shrink a 32x32 input to 2x2
    (32 -> 16 -> 8 -> 4 -> 2). The first linear layer's input width of
    ``64 * 2 * 2`` relies on that, so other spatial sizes fail at that layer.

    Args:
        in_channels: number of input image channels. It defaults to 1, so
            existing callers are unaffected. Pass e.g. ``CHANNELS`` to keep it
            in sync with the notebook parameters.
    """

    def __init__(self, in_channels=1):
        super().__init__()
        # Layer order is unchanged, so state_dict keys (and saved
        # checkpoints) stay compatible.
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, 16, 5, stride=2, padding=2),
            nn.SiLU(),
            nn.Conv2d(16, 32, 3, stride=2, padding=1),
            nn.SiLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.SiLU(),
            nn.Conv2d(64, 64, 3, stride=2, padding=1),
            nn.SiLU(),
            nn.Flatten(),
            nn.Linear(64 * 2 * 2, 64),
            nn.SiLU(),
            nn.Linear(64, 1),
        )

    def forward(self, x):
        """Return a (batch, 1) tensor of scores for the image batch ``x``."""
        return self.net(x)

In [None]:
# %%
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = EBMNet().to(device)

if LOAD_MODEL:
    # map_location remaps stored tensors onto the current device, so a
    # checkpoint saved on a GPU machine also loads on a CPU-only one.
    model.load_state_dict(torch.load("./models/model.pt", map_location=device))

## 3. Set up a Langevin sampler function <a name="sampler"></a>

In [None]:
# %%
def generate_samples(
    model, imgs, steps, step_size, noise, return_img_per_step=False,
    gradient_clip=None,
):
    """
    Refine ``imgs`` by noisy gradient ascent on the model's score.

    Each step:
      1. adds Gaussian noise with std ``noise`` and clamps to [-1, 1];
      2. moves the images ``step_size`` along the elementwise-clipped
         gradient of ``model(imgs).sum()``;
      3. clamps to [-1, 1] again.

    The model is put in eval mode while sampling and its previous mode is
    restored afterwards, even if an error occurs.

    Args:
        model: network mapping an image batch to per-image scores.
        imgs: starting images. They are copied, never modified in place.
        steps: number of update steps.
        step_size: multiplier applied to the clipped gradient.
        noise: standard deviation of the noise added at each step.
        return_img_per_step: if True, return every intermediate batch.
        gradient_clip: elementwise gradient clamp. ``None`` uses the global
            ``GRADIENT_CLIP``, read at call time.

    Returns:
        If ``return_img_per_step`` is False, the final images on ``device``.
        Otherwise a CPU tensor of shape ``(steps, *imgs.shape)`` holding the
        images after each step.
    """
    clip = GRADIENT_CLIP if gradient_clip is None else gradient_clip

    imgs = imgs.clone().detach().to(device)
    imgs.requires_grad_(True)

    imgs_per_step = []

    was_training = model.training
    model.eval()
    try:
        for _ in range(steps):
            # In-place updates under no_grad replace the discouraged `.data`
            # mutation while keeping `imgs` a grad-requiring leaf.
            with torch.no_grad():
                imgs.add_(noise * torch.randn_like(imgs))
                imgs.clamp_(-1.0, 1.0)

            score = model(imgs).sum()
            # Gradient w.r.t. the images only; parameter .grad is untouched.
            (grads,) = torch.autograd.grad(score, imgs)
            grads = torch.clamp(grads, -clip, clip)

            with torch.no_grad():
                imgs.add_(step_size * grads)
                imgs.clamp_(-1.0, 1.0)

            if return_img_per_step:
                imgs_per_step.append(imgs.detach().cpu())
    finally:
        model.train(was_training)

    if return_img_per_step:
        return torch.stack(imgs_per_step)
    return imgs.detach()

## 4. Set up a buffer to store examples <a name="buffer"></a>

In [None]:
# %%
class Buffer:
    """
    Replay buffer of past sampler outputs, used as starting points for sampling.

    Examples are stored as a list of (1, CHANNELS, IMAGE_SIZE, IMAGE_SIZE)
    CPU tensors, newest first, capped at ``BUFFER_SIZE`` entries.
    """

    def __init__(self, model):
        self.model = model
        # Seed the buffer with one batch of uniform noise in [-1, 1].
        self.examples = [
            torch.rand(1, CHANNELS, IMAGE_SIZE, IMAGE_SIZE) * 2 - 1
            for _ in range(BATCH_SIZE)
        ]

    def sample_new_exmps(self, steps, step_size, noise):
        """
        Produce a batch of ``BATCH_SIZE`` refined samples and store them.

        About 5% of the starting points (binomially distributed) are fresh
        uniform noise. The rest are drawn with replacement from the buffer.
        All of them are refined with ``generate_samples``.

        Returns:
            The refined batch on ``device``.
        """
        n_new = np.random.binomial(BATCH_SIZE, 0.05)
        rand_imgs = torch.rand(n_new, CHANNELS, IMAGE_SIZE, IMAGE_SIZE) * 2 - 1

        n_old = BATCH_SIZE - n_new
        parts = [rand_imgs]
        if n_old > 0:
            # torch.cat([]) raises, so skip the draw when every slot is
            # re-seeded (random.choices(k=0) consumes no randomness anyway).
            parts.append(torch.cat(random.choices(self.examples, k=n_old), dim=0))

        inp_imgs = torch.cat(parts, dim=0).to(device)

        inp_imgs = generate_samples(
            self.model, inp_imgs, steps, step_size, noise
        )

        # torch.split returns a tuple; `tuple + list` raises TypeError, so
        # convert before prepending the new samples to the list buffer.
        new_examples = list(torch.split(inp_imgs.cpu(), 1, dim=0))
        self.examples = (new_examples + self.examples)[:BUFFER_SIZE]

        return inp_imgs

In [None]:
# %%
class EBM:
    """
    Training wrapper around a score network.

    Holds the network, a replay ``Buffer`` of negative samples, and an Adam
    optimiser over the network's parameters (learning rate ``LEARNING_RATE``).
    """

    def __init__(self, model):
        self.model = model
        self.buffer = Buffer(model)
        self.optimizer = torch.optim.Adam(
            model.parameters(), lr=LEARNING_RATE
        )

    def train_step(self, real_imgs):
        """
        Run one optimisation step on a batch of real images.

        Real images get Gaussian noise (std ``NOISE``) and are clamped to
        [-1, 1]; negatives come from the replay buffer. The loss lowers the
        mean score of the negatives relative to the real images (``cdiv``).
        It also adds an ``ALPHA``-weighted penalty on squared scores (``reg``).

        Args:
            real_imgs: batch of real images. It is not modified.

        Returns:
            dict of Python floats: ``loss``, ``cdiv``, ``reg`` and the mean
            ``real`` / ``fake`` scores.
        """
        real_imgs = real_imgs.to(device)
        # Out-of-place: `.to` returns the caller's own tensor when it is
        # already on `device`, so in-place ops would mutate the input batch.
        real_imgs = (real_imgs + NOISE * torch.randn_like(real_imgs)).clamp(
            -1.0, 1.0
        )

        fake_imgs = self.buffer.sample_new_exmps(
            STEPS, STEP_SIZE, NOISE
        )

        inp_imgs = torch.cat([real_imgs, fake_imgs], dim=0)

        # Split by the actual sizes: the buffer always yields BATCH_SIZE
        # samples, but the last DataLoader batch can be smaller, in which
        # case torch.chunk(..., 2) would mix real and fake scores.
        real_out, fake_out = torch.split(
            self.model(inp_imgs),
            [real_imgs.size(0), fake_imgs.size(0)],
            dim=0,
        )

        cdiv_loss = fake_out.mean() - real_out.mean()
        reg_loss = ALPHA * (real_out.pow(2).mean() + fake_out.pow(2).mean())
        loss = cdiv_loss + reg_loss

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return {
            "loss": loss.item(),
            "cdiv": cdiv_loss.item(),
            "reg": reg_loss.item(),
            "real": real_out.mean().item(),
            "fake": fake_out.mean().item(),
        }

In [None]:
# EBM requires the score network; wrap the `model` created above.
ebm = EBM(model)

## 5. Train the EBM network <a name="train"></a>

In [None]:
# %%
from pathlib import Path

for epoch in range(EPOCHS):
    metrics = [ebm.train_step(x) for (x,) in train_loader]

    # Log every 5th epoch, and always the final one so the end state is visible.
    if epoch % 5 == 0 or epoch == EPOCHS - 1:
        avg = {k: np.mean([m[k] for m in metrics]) for k in metrics[0]}
        print(f"Epoch {epoch:03d} |", avg)

# Save the trained weights to the path the LOAD_MODEL branch reads from.
Path("./models").mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), "./models/model.pt")

## 6. Generate images <a name="generate"></a>

In [None]:
# %%
# Ten starting images drawn uniformly from [-1, 1].
start_imgs = 2 * torch.rand(10, 1, IMAGE_SIZE, IMAGE_SIZE) - 1
display(start_imgs)

In [None]:
# %%
# Run a long 1000-step chain from the noise images and keep every
# intermediate batch so the trajectory can be inspected afterwards.
sampler_config = dict(steps=1000, step_size=STEP_SIZE, noise=NOISE)
gen_imgs = generate_samples(
    model, start_imgs, return_img_per_step=True, **sampler_config
)

In [None]:
# %%
# Images after the last sampler step.
final_imgs = gen_imgs[-1]
display(final_imgs)

In [None]:
# %%
# Follow the chain at batch index 6 through a handful of steps to show how
# its sample evolves over the run.
snapshot_steps = (0, 1, 3, 5, 10, 30, 50, 100, 300, 999)
imgs = [gen_imgs[step][6] for step in snapshot_steps]

display(torch.stack(imgs))