**Install the required packages**

In [1]:
#!pip install torch numpy matplotlib scikit-learn tqdm

**Import Libraries**

In [2]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.datasets import make_swiss_roll
from torch import nn
from tqdm import tqdm

**Device Setup**

In [3]:
# Pick the compute device once for the whole notebook: prefer a CUDA GPU, fall back to the CPU
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

**Data Generation**

In [4]:
# Draw a fresh 2D Swiss Roll batch (used both for training and for plotting)
def sample_batch(batch_size):
    """
    Draw `batch_size` points from a noisy Swiss Roll and flatten them to 2D.

    Args:
        batch_size (int): How many points to draw.

    Returns:
        np.ndarray: Array with one row per point and two columns
        (columns 0 and 2 of the 3D Swiss Roll), divided by 10.
    """
    # make_swiss_roll also returns a second value, which is not needed here
    points_3d, _unused = make_swiss_roll(n_samples=batch_size, noise=0.25)
    # Keep the first and third coordinates: this is the classic 2D spiral view
    spiral_2d = points_3d[:, [0, 2]]
    # Shrink the coordinates by a constant factor to keep values small for training
    return spiral_2d / 10.0

**MLP Model Definition**

In [5]:
# MLP Model Definition
class MLP(nn.Module):
    """Shared trunk followed by one small output head per diffusion time step."""

    def __init__(self, N=40, data_dim=2, hidden_dim=128):
        """
        Build the shared trunk and the per-step output heads.

        Args:
            N (int): Number of output heads in the tail (one is selected per time step).
            data_dim (int): Dimensionality of the input data.
            hidden_dim (int): Width of every hidden layer.
        """
        super().__init__()
        # Each head emits two values per data dimension (mean and log-variance)
        out_dim = data_dim * 2

        # Shared trunk: two Linear + ReLU stages applied to every input
        self.network_head = nn.Sequential(
            nn.Linear(data_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )

        # Tail: N independent heads, each Linear -> ReLU -> Linear
        self.network_tail = nn.ModuleList()
        for _head_idx in range(N):
            self.network_tail.append(
                nn.Sequential(
                    nn.Linear(hidden_dim, hidden_dim),
                    nn.ReLU(),
                    nn.Linear(hidden_dim, out_dim),
                )
            )

    def forward(self, x, t: int):
        """
        Run the shared trunk, then the head selected by `t`.

        Args:
            x (torch.Tensor): Input data tensor.
            t (int): Index of the tail head to use.

        Returns:
            torch.Tensor: Output of tail head `t`; the last dimension has size
            2 * data_dim.
        """
        features = self.network_head(x)
        step_head = self.network_tail[t]
        return step_head(features)


**Diffusion Model Definition**

In [6]:
# Diffusion Model Definition
class DiffusionModel(nn.Module):
    def __init__(self, model: nn.Module, n_steps=40, device=device, data_dim=2):
        """
        Initialize the Diffusion Model.

        Args:
            model (nn.Module): The neural network used for denoising. It is called as
                `model(xt, step_index)` and must return a tensor whose dim 1 can be
                split into two equal halves (mean and log-variance).
            n_steps (int): Number of diffusion steps.
            device (str): The device new samples are created on ('cpu' or 'cuda').
            data_dim (int): Dimensionality of the data points drawn by `sample`.
        """
        super().__init__()
        self.model = model
        self.device = device
        self.data_dim = data_dim

        # Noise schedule: a sigmoid ramp rescaled to the range [1e-5, 3e-1].
        # The tensors are created without an explicit device; only single elements
        # of them are used in the arithmetic below.
        betas = torch.sigmoid(torch.linspace(-18, 10, n_steps)) * (3e-1 - 1e-5) + 1e-5
        self.beta = betas
        self.alpha = 1. - betas
        self.alpha_bar = torch.cumprod(self.alpha, dim=0)
        self.n_steps = n_steps
        # Kept as a public attribute for backward compatibility.
        self.sigma2 = betas

    def _check_step(self, t):
        """Raise ValueError unless `t` is a valid 1-based step in [1, n_steps]."""
        if not 1 <= t <= self.n_steps:
            raise ValueError(f"t must be in [1, {self.n_steps}], got {t}")

    def forward_process(self, x0, t):
        """
        Perform the forward diffusion process.

        Draws x_t ~ q(x_t | x_0) and returns the parameters of the Gaussian
        posterior q(x_{t-1} | x_t, x_0) for that draw.

        Args:
            x0 (torch.Tensor): Initial data tensor.
            t (int): Time step, 1-based (1 <= t <= n_steps).

        Returns:
            tuple: (mu_posterior, sigma_posterior, xt)
                - mu_posterior (torch.Tensor): Posterior mean.
                - sigma_posterior (torch.Tensor): Posterior standard deviation
                  (0-dim tensor; it is exactly 0 at t=1).
                - xt (torch.Tensor): Data at time step t.

        Raises:
            ValueError: If `t` is outside [1, n_steps].
        """
        self._check_step(t)
        idx = t - 1  # Adjust indexing to start at 0
        beta_t = self.beta[idx]
        alpha_t = self.alpha[idx]
        alpha_bar_t = self.alpha_bar[idx]
        # Cumulative product up to the previous step; the empty product (first step) is 1.
        # Handled explicitly because idx - 1 == -1 would wrap around to the last entry.
        alpha_bar_prev = self.alpha_bar[idx - 1] if idx > 0 else torch.ones_like(alpha_bar_t)

        noise = torch.randn_like(x0)  # Gaussian noise for the forward draw
        xt = x0 * torch.sqrt(alpha_bar_t) + noise * torch.sqrt(1. - alpha_bar_t)

        # Closed-form Gaussian posterior q(x_{t-1} | x_t, x_0):
        #   mean = sqrt(abar_{t-1}) * beta_t / (1 - abar_t) * x0
        #        + sqrt(alpha_t) * (1 - abar_{t-1}) / (1 - abar_t) * xt
        #   var  = (1 - abar_{t-1}) * beta_t / (1 - abar_t)
        one_minus_alpha_bar = 1. - alpha_bar_t
        coef_x0 = torch.sqrt(alpha_bar_prev) * beta_t / one_minus_alpha_bar
        coef_xt = torch.sqrt(alpha_t) * (1. - alpha_bar_prev) / one_minus_alpha_bar
        mu_posterior = coef_x0 * x0 + coef_xt * xt
        sigma_posterior = torch.sqrt((1. - alpha_bar_prev) * beta_t / one_minus_alpha_bar)

        return mu_posterior, sigma_posterior, xt

    def reverse(self, xt, t):
        """
        Perform one step of the reverse diffusion process.

        Args:
            xt (torch.Tensor): Data tensor at time step t.
            t (int): Time step, 1-based (1 <= t <= n_steps).

        Returns:
            tuple: (mu, sigma, samples)
                - mu (torch.Tensor): Mean predicted by the model.
                - sigma (torch.Tensor): Standard deviation predicted by the model.
                - samples (torch.Tensor): mu + sigma * standard normal noise.
            At t=1 no model is applied and (None, None, xt) is returned.

        Raises:
            ValueError: If `t` is outside [1, n_steps].
        """
        self._check_step(t)
        idx = t - 1  # Adjust indexing to start at 0
        if idx == 0:
            return None, None, xt  # The first step is passed through unchanged
        # Split the network output along dim 1 into mean and log-variance
        mu, log_var = self.model(xt, idx).chunk(2, dim=1)
        sigma = torch.exp(log_var * 0.5)  # Standard deviation from log-variance
        return mu, sigma, mu + torch.randn_like(xt) * sigma

    @torch.no_grad()
    def sample(self, size):
        """
        Generate samples from the model.

        Runs without gradient tracking, so the returned tensors do not require
        grad and can be converted to numpy directly.

        Args:
            size (int): Number of samples to generate.

        Returns:
            list of torch.Tensor: n_steps + 1 tensors of shape (size, data_dim):
            the initial noise followed by the result of every reverse step.
        """
        noise = torch.randn((size, self.data_dim), device=self.device)  # Start with noise
        samples = [noise]
        for t in range(self.n_steps, 0, -1):
            _, _, x = self.reverse(samples[-1], t)  # One reverse step from the latest state
            samples.append(x)
        return samples

**Training Function**

In [7]:
# Training Function
def train(model, optimizer, nb_epochs=150000, batch_size=64000):
    """
    Train the diffusion model.

    Each iteration draws a fresh data batch and one random time step, then
    minimises the KL divergence between the forward-process posterior
    N(mu_posterior, sigma_posterior^2) and the model's reverse step N(mu, sigma^2).

    Args:
        model (DiffusionModel): The diffusion model to train. Its `n_steps` and
            `device` attributes define the time-step range and the batch device.
        optimizer (torch.optim.Optimizer): Optimizer for the model parameters.
        nb_epochs (int): Number of training iterations.
        batch_size (int): Size of each data batch.

    Raises:
        ValueError: If the model has fewer than 2 diffusion steps.
    """
    n_steps = model.n_steps
    if n_steps < 2:
        raise ValueError("training needs a model with at least 2 diffusion steps")

    for _ in tqdm(range(nb_epochs)):
        # Generate a batch of data on the device the model samples on
        x0 = torch.from_numpy(sample_batch(batch_size)).float().to(model.device)
        # Random time step in [2, n_steps]; t=1 is skipped because model.reverse
        # returns (None, None, xt) there, so there is nothing to fit.
        t = np.random.randint(2, n_steps + 1)
        mu_posterior, sigma_posterior, xt = model.forward_process(x0, t)  # Forward process
        mu, sigma, _ = model.reverse(xt, t)  # Reverse process

        # KL divergence between the two Gaussians, averaged over the batch
        loss = (torch.log(sigma) - torch.log(sigma_posterior) +
                (sigma_posterior ** 2 + (mu_posterior - mu) ** 2) / (2 * sigma ** 2) - 0.5).mean()

        optimizer.zero_grad()  # Zero the gradients
        loss.backward()  # Compute gradients
        optimizer.step()  # Update model parameters


**Plotting Function**

In [8]:
# Plotting Function
def plot(model):
    """
    Plot the diffusion process and generated samples.

    Top row: real data at t=0, t=T/2 and t=T of the forward process.
    Bottom row: model samples at the same three time steps.
    The figure is written to Imgs/diffusion_model.png (the directory is
    created if it does not exist) and then closed.

    Args:
        model (DiffusionModel): The trained diffusion model.
    """
    n_steps = model.n_steps
    half_steps = n_steps // 2

    plt.figure(figsize=(10, 6))  # Set up the figure
    x0 = sample_batch(5000)  # Generate a batch of data
    x0_tensor = torch.from_numpy(x0).float().to(model.device)
    x_half = model.forward_process(x0_tensor, half_steps)[-1].cpu().numpy()  # Data at t = T/2
    x_last = model.forward_process(x0_tensor, n_steps)[-1].cpu().numpy()  # Data at t = T
    forward_data = [x0, x_half, x_last]
    titles = [r'$t=0$', r'$t=\frac{T}{2}$', r'$t=T$']

    # Plot the data at different diffusion steps
    for i, (points, title) in enumerate(zip(forward_data, titles)):
        plt.subplot(2, 3, 1 + i)
        plt.scatter(points[:, 0], points[:, 1], alpha=.1, s=1)
        plt.xlim([-2, 2])
        plt.ylim([-2, 2])
        plt.gca().set_aspect('equal')
        if i == 0:
            plt.ylabel(r'$q(\mathbf{x}^{(0...T)})$', fontsize=17, rotation=0, labelpad=60)
        plt.title(title, fontsize=17)

    # Plot generated samples. Gradients are not needed for plotting, and a tensor
    # that requires grad cannot be converted with .numpy(), hence no_grad + detach.
    with torch.no_grad():
        samples = model.sample(5000)
    # samples[0] is the initial noise (t=T) and samples[-1] the final result (t=0),
    # so the state at time step t sits at index -(t + 1).
    for i, t in enumerate([0, half_steps, n_steps]):
        generated = samples[-(t + 1)].detach().cpu().numpy()
        plt.subplot(2, 3, 4 + i)
        plt.scatter(generated[:, 0], generated[:, 1], alpha=.1, s=1, c='r')
        plt.xlim([-2, 2])
        plt.ylim([-2, 2])
        plt.gca().set_aspect('equal')
        if i == 0:
            plt.ylabel(r'$p(\mathbf{x}^{(0...T)})$', fontsize=17, rotation=0, labelpad=60)

    # savefig does not create missing directories; make sure the target exists
    os.makedirs("Imgs", exist_ok=True)
    plt.savefig("Imgs/diffusion_model.png", bbox_inches='tight')  # Save the plot
    plt.close()


**Start the Training Process**

In [None]:
# Build the denoising network, wrap it in the diffusion process, then train and plot
model_mlp = MLP(N=40, data_dim=2, hidden_dim=128)  # One tail head per diffusion step
model_mlp = model_mlp.to(device)  # Move the network parameters to the chosen device
model = DiffusionModel(model=model_mlp, n_steps=40, device=device)  # Diffusion wrapper
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)  # Adam over all parameters

train(model, optimizer, nb_epochs=150000, batch_size=64000)  # Long-running training loop
plot(model)  # Save the comparison figure

  2%|▏         | 3624/150000 [01:05<38:13, 63.82it/s]