In [1]:
import torch
import torch.utils
import torch.utils.data
from tqdm.auto import tqdm
from torch import nn
import argparse
import torch.nn.functional as F
import utils
import dataset
import os
import matplotlib.pyplot as plt
from helperClasses import PositionalEncoding, UNetBlock, UNetModel

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class NoiseScheduler():
    """
    Noise scheduler for the DDPM model

    Args:
        num_timesteps: int, the number of timesteps
        type: str, the type of scheduler to use
        **kwargs: additional arguments for the scheduler

    This object sets up all the constants like alpha, beta, sigma, etc. required for the DDPM model
    
    """
    def __init__(self, num_timesteps=50, type="linear", **kwargs):

        self.num_timesteps = num_timesteps
        self.type = type

        if type == "linear":
            self.init_linear_schedule(**kwargs)
        else:
            raise NotImplementedError(f"{type} scheduler is not implemented") # change this if you implement additional schedulers


    def init_linear_schedule(self, beta_start, beta_end):
        """
        Precompute the linear schedule for beta, alpha, and other required quantities
        """
        self.betas = torch.linspace(beta_start, beta_end, self.num_timesteps)
        self.alphas = 1.0 - self.betas
        self.alphaBar = torch.cumprod(self.alphas, dim=0)
        self.alphaBarPrev = torch.cat([torch.tensor([1.0]), self.alphaBar[:-1]])
        self.sqrtAlphaBar = torch.sqrt(self.alphaBar)
        self.sqrtOneMinusAlphaBar = torch.sqrt(1.0 - self.alphaBar)
        self.logOneMinusAlphaBar = torch.log(1.0 - self.alphaBar)
        self.sqrtRecipAlphaBar = torch.sqrt(1.0 / self.alphaBar)
        self.sqrtRecipMinusOneAlphaBar = torch.sqrt(1.0 / self.alphaBar - 1)

    def __len__(self):
        return self.num_timesteps

In [3]:
tempNoiseScheduler = NoiseScheduler(200, type="linear", beta_start=0.0, beta_end=1.0)
# print(tempNoiseScheduler.alphaBar)

In [4]:
class DDPM(nn.Module):
    def __init__(self, n_dim=3, n_steps=200):
        """
        Noise prediction network for the DDPM

        Args:
            n_dim: int, the dimensionality of the data
            n_steps: int, the number of steps in the diffusion process
        We have separate learnable modules for `time_embed` and `model`. `time_embed` can be learned or a fixed function as well

        """
        super(DDPM, self).__init__()
        self.time_embed = nn.Embedding(n_steps, n_dim)
        self.timeMLP = PositionalEncoding(n_steps)
        self.model = UNetModel(inputChannels=n_dim,
                               outputChannels=n_dim,
                               timeEmbedDimension=n_steps)

    def forward(self, x, t):
        """
        Args:
            x: torch.Tensor, the input data tensor [batch_size, n_dim]
            t: torch.Tensor, the timestep tensor [batch_size]

        Returns:
            torch.Tensor, the predicted noise tensor [batch_size, n_dim]
        """
        t_embed = self.time_embed(t)
        t_embed = self.timeMLP(t_embed)
        return self.model(x, t_embed)

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batchSize = 4
imageSize = 64
dummyImages = torch.randn(batchSize, 3, imageSize, imageSize).to(device)
dummyTimesteps = torch.randint(0, 200, (batchSize,)).to(device)
dummyModel = DDPM(n_dim=3, n_steps=200).to(device)
dummyNoise = dummyModel(dummyImages, dummyTimesteps)
print(dummyNoise.shape)

torch.Size([4, 3, 1])


RuntimeError: The size of tensor a (4) must match the size of tensor b (3) at non-singleton dimension 1

In [65]:
def train(model : DDPM, noise_scheduler : NoiseScheduler, dataloader, optimizer, epochs, run_name):
    """
    Train the model and save the model and necessary plots

    Args:
        model: DDPM, model to train
        noise_scheduler: NoiseScheduler, scheduler for the noise
        dataloader: torch.utils.data.DataLoader, dataloader for the dataset
        optimizer: torch.optim.Optimizer, optimizer to use
        epochs: int, number of epochs to train the model
        run_name: str, path to save the model
    """
    print(f"[DDPM MODEL] Training the model : {run_name}")
    prevLoss = float("inf")
    model.train()
    device = next(model.parameters()).device
    for epoch in range(epochs):
        tqdmDataloader = tqdm(dataloader, desc = f"Epoch : {epoch + 1}")
        for x, _ in tqdmDataloader:
            x = x.to(device)
            optimizer.zero_grad()

            # Fix the batch size and the number of timesteps
            batch_size = x.shape[0]
            t = torch.randint(0, noise_scheduler.num_timesteps, (batch_size,))

            noise = torch.randn_like(x)
            alphaBarT = noise_scheduler.alphaBar[t].unsqueeze(1).to(device)
            sqrtAlphaBarT = noise_scheduler.sqrtAlphaBar[t].unsqueeze(1).to(device)
            sqrtOneMinusAlphaBarT = noise_scheduler.sqrtOneMinusAlphaBar[t].unsqueeze(1).to(device)

            xTilde = sqrtAlphaBarT * x + sqrtOneMinusAlphaBarT * noise
            noisePred = model(xTilde, t)
            loss = F.mse_loss(noisePred, noise)

            # Compute the loss
            loss.backward()
            optimizer.step()

            tqdmDataloader.set_postfix({"Loss" : loss.item()})
        if loss.item() < prevLoss:
            # Save the model if the loss is less than the previous best loss
            prevLoss = loss.item()
            torch.save(model.state_dict(), "models/" + run_name + ".pth")
            print(f"Model saved with loss {loss}")
        else:
            print(f"Model not saved with loss {loss}")

In [66]:
@torch.no_grad()
def sample(model, n_samples, noise_scheduler, return_intermediate=False): 
    """
    Sample from the model
    
    Args:
        model: DDPM
        n_samples: int
        noise_scheduler: NoiseScheduler
        return_intermediate: bool
    Returns:
        torch.Tensor, samples from the model [n_samples, n_dim]

    If `return_intermediate` is `False`,
            torch.Tensor, samples from the model [n_samples, n_dim]
    Else
        the function returns all the intermediate steps in the diffusion process as well 
        Return: [[n_samples, n_dim]] x n_steps
        Optionally implement return_intermediate=True, will aid in visualizing the intermediate steps
    """   
    print(f"[DDPM MODEL] Sampling from the model : {model.__class__.__name__}")
    device = next(model.parameters())
    # Dimensions for the model
    nDim = model.time_embed.weight.shape[1]
    # Start with pure Gaussian Noise at the last timestep
    xT = torch.randn(n_samples, nDim).to(device)
    intermediates = [xT.clone() if return_intermediate else None]

    for timestep in reversed(range(noise_scheduler.num_timesteps)):
        timeTensor = torch.full((n_samples,), timestep, dtype=torch.long).to(device)
        predictedNoise = model(xT, timeTensor)
        print(f"{timestep} : {predictedNoise}")

In [67]:
# Load the albatross dataset
moonsDataX, moonDatay = dataset.load_dataset("moons")
moonDataset = torch.utils.data.TensorDataset(moonsDataX, moonDatay.unsqueeze(1))
moonDataloader = torch.utils.data.DataLoader(moonDataset, batch_size=64, shuffle=True)

In [68]:
# Create and train the model with the given parameters
model = DDPM(n_dim=2, n_steps=200)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
train(model, tempNoiseScheduler, moonDataloader, optimizer, 1, "test_moons")

[DDPM MODEL] Training the model : test_moons


Epoch : 1: 100%|██████████| 125/125 [00:00<00:00, 915.75it/s, Loss=0.553]

Model saved with loss 0.5527198910713196





In [69]:
# Sample from the model  and visualise the results
samples = sample(model, 1000, tempNoiseScheduler, return_intermediate=True)

[DDPM MODEL] Sampling from the model : DDPM


RuntimeError: Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got torch.FloatTensor instead (while checking arguments for embedding)