In [79]:
import torch
import torch.utils
import torch.utils.data
from tqdm.auto import tqdm
from torch import nn
import argparse
import torch.nn.functional as F
import utils
import dataset
import os
import matplotlib.pyplot as plt

In [80]:
class NoiseScheduler():
    """
    Noise scheduler for the DDPM model

    Args:
        num_timesteps: int, the number of timesteps
        type: str, the type of scheduler to use
        **kwargs: additional arguments for the scheduler

    This object sets up all the constants like alpha, beta, sigma, etc. required for the DDPM model
    
    """
    def __init__(self, num_timesteps=50, type="linear", **kwargs):

        self.num_timesteps = num_timesteps
        self.type = type

        if type == "linear":
            self.init_linear_schedule(**kwargs)
        else:
            raise NotImplementedError(f"{type} scheduler is not implemented") # change this if you implement additional schedulers


    def init_linear_schedule(self, beta_start, beta_end):
        """
        Precompute the linear schedule for beta, alpha, and other required quantities
        """
        self.betas = torch.linspace(beta_start, beta_end, self.num_timesteps)
        self.alphas = 1.0 - self.betas
        self.alphaBar = torch.cumprod(self.alphas, dim=0)
        self.alphaBarPrev = torch.cat([torch.tensor([1.0]), self.alphaBar[:-1]])
        self.sqrtAlphaBar = torch.sqrt(self.alphaBar)
        self.sqrtOneMinusAlphaBar = torch.sqrt(1.0 - self.alphaBar)
        self.logOneMinusAlphaBar = torch.log(1.0 - self.alphaBar)
        self.sqrtRecipAlphaBar = torch.sqrt(1.0 / self.alphaBar)
        self.sqrtRecipMinusOneAlphaBar = torch.sqrt(1.0 / self.alphaBar - 1)

    def __len__(self):
        return self.num_timesteps

In [81]:
tempNoiseScheduler = NoiseScheduler(50, type="linear", beta_start=0.0, beta_end=1.0)
print(tempNoiseScheduler.alphaBar)

tensor([1.0000e+00, 9.7959e-01, 9.3961e-01, 8.8208e-01, 8.1007e-01, 7.2741e-01,
        6.3834e-01, 5.4715e-01, 4.5782e-01, 3.7373e-01, 2.9746e-01, 2.3068e-01,
        1.7419e-01, 1.2798e-01, 9.1411e-02, 6.3428e-02, 4.2717e-02, 2.7897e-02,
        1.7649e-02, 1.0805e-02, 6.3951e-03, 3.6543e-03, 2.0136e-03, 1.0684e-03,
        5.4513e-04, 2.6700e-04, 1.2533e-04, 5.6269e-05, 2.4115e-05, 9.8430e-06,
        3.8167e-06, 1.4020e-06, 4.8642e-07, 1.5883e-07, 4.8622e-08, 1.3892e-08,
        3.6856e-09, 9.0261e-10, 2.0263e-10, 4.1352e-11, 7.5953e-12, 1.2400e-12,
        1.7715e-13, 2.1692e-14, 2.2135e-15, 1.8069e-16, 1.1063e-17, 4.5154e-19,
        9.2150e-21, 0.0000e+00])


In [82]:
class DDPM(nn.Module):
    def __init__(self, n_dim=3, n_steps=200):
        """
        Noise prediction network for the DDPM

        Args:
            n_dim: int, the dimensionality of the data
            n_steps: int, the number of steps in the diffusion process
        We have separate learnable modules for `time_embed` and `model`. `time_embed` can be learned or a fixed function as well

        """
        super(DDPM, self).__init__()
        self.time_embed = nn.Embedding(n_steps, n_dim)
        self.model = nn.Sequential(
            nn.Linear(n_dim, 256),
            nn.ReLU(),
            nn.Linear(256, n_dim)
        )

    def forward(self, x, t):
        """
        Args:
            x: torch.Tensor, the input data tensor [batch_size, n_dim]
            t: torch.Tensor, the timestep tensor [batch_size]

        Returns:
            torch.Tensor, the predicted noise tensor [batch_size, n_dim]
        """
        t_embed = self.time_embed(t)
        return self.model(x + t_embed)

In [83]:
testDDPM = DDPM()
testInput = torch.randn(10, 3)
testT = torch.randint(0, 200, (10,))
# print(testInput)
print(testT)
forwardOutput = testDDPM.forward(testInput, testT)
print(testDDPM.time_embed.weight.shape)
print(forwardOutput)

tensor([ 73,  73, 181, 199,  49,   6,  57, 196,  58, 181])
torch.Size([200, 3])
tensor([[-0.1705,  0.0770, -0.3311],
        [-0.2355, -0.0527, -0.2224],
        [-0.6369, -0.2312, -0.0967],
        [-0.3130, -0.1753, -0.9912],
        [-0.2521, -0.1186, -0.3647],
        [-0.2142,  0.0179, -0.1793],
        [-0.3449,  0.1044, -1.3320],
        [-0.2200, -0.0561, -0.3684],
        [-0.3767, -0.1778, -0.5338],
        [-0.7660, -0.2815, -0.0863]], grad_fn=<AddmmBackward0>)


In [84]:
def train(model : DDPM, noise_scheduler : NoiseScheduler, dataloader, optimizer, epochs, run_name):
    """
    Train the model and save the model and necessary plots

    Args:
        model: DDPM, model to train
        noise_scheduler: NoiseScheduler, scheduler for the noise
        dataloader: torch.utils.data.DataLoader, dataloader for the dataset
        optimizer: torch.optim.Optimizer, optimizer to use
        epochs: int, number of epochs to train the model
        run_name: str, path to save the model
    """
    prevLoss = float("inf")
    model.train()
    device = next(model.parameters()).device
    for epoch in range(epochs):
        tqdmDataloader = tqdm(dataloader, desc = f"Epoch : {epoch + 1}")
        for x, _ in tqdmDataloader:
            x = x.to(device)
            optimizer.zero_grad()

            # Fix the batch size and the number of timesteps
            batch_size = x.shape[0]
            t = torch.randint(0, noise_scheduler.num_timesteps, (batch_size,))

            noise = torch.randn_like(x)
            alphaBarT = noise_scheduler.alphaBar[t].unsqueeze(1).to(device)
            sqrtAlphaBarT = noise_scheduler.sqrtAlphaBar[t].unsqueeze(1).to(device)
            sqrtOneMinusAlphaBarT = noise_scheduler.sqrtOneMinusAlphaBar[t].unsqueeze(1).to(device)

            xTilde = sqrtAlphaBarT * x + sqrtOneMinusAlphaBarT * noise
            noisePred = model(xTilde, t)
            loss = F.mse_loss(noisePred, noise)

            # Compute the loss
            loss.backward()
            optimizer.step()

            tqdmDataloader.set_postfix({"Loss" : loss.item()})
        if loss.item() < prevLoss:
            # Save the model if the loss is less than the previous best loss
            prevLoss = loss.item()
            torch.save(model.state_dict(), "models/" + run_name + ".pth")
            print(f"Model saved with loss {loss}")
        else:
            print(f"Model not saved with loss {loss}")

In [85]:
# Load the albatross dataset
moonsDataX, moonDatay = dataset.load_dataset("moons")
moonDataset = torch.utils.data.TensorDataset(moonsDataX, moonDatay.unsqueeze(1))
moonDataloader = torch.utils.data.DataLoader(moonDataset, batch_size=64, shuffle=True)

In [86]:
# Create and train the model with the given parameters
model = DDPM(n_dim=2, n_steps=200)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
train(model, tempNoiseScheduler, moonDataloader, optimizer, 1, "test_moons")

Epoch : 1: 100%|██████████| 125/125 [00:00<00:00, 734.85it/s, Loss=0.516]

Model saved with loss 0.5161635875701904



