# Neural Graphics Ex1: Training Your Own Diffusion Model!

## Setup environment

In [None]:
# We recommend using these utils.
# https://google.github.io/mediapy/mediapy.html
# https://einops.rocks/
# !pip install mediapy einops --quiet

In [None]:
# Import essential modules. Feel free to add whatever you need.
import torch
import matplotlib.pyplot as plt
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import torchvision
from tqdm import tqdm
import os
from torchvision import transforms
from random import randint
import numpy as np

# Report the installed PyTorch version and whether a CUDA device is visible.
print(torch.__version__)
print(torch.cuda.is_available())
# NOTE(review): presumably a workaround for the "duplicate OpenMP runtime" abort seen on
# some setups — confirm it is still needed. Note that it is set after torch is imported.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

### Seed your work
To be able to reproduce your code, please use a random seed from this point onward.

In [None]:
def seed_everything(seed):
    """Seed every random number generator this notebook can draw from.

    Seeds Python's ``random`` module, NumPy and PyTorch (CPU and all CUDA
    devices) with the same value.

    Note: seeding ``random`` also makes ``randint``-based values (such as the
    run ids generated before training) repeat from one execution to the next.

    Args:
        seed (int): Seed shared by all generators.
    """
    import random  # local import: the file only imports ``randint`` by name

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Safe without a GPU: the call is silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)


YOUR_SEED = 180  # modify if you want
seed_everything(YOUR_SEED)

## 1. Basic Ops and UNet blocks
**Notations:**  
 * `Conv2D(kernel_size, stride, padding)` is `nn.Conv2d()`  
 * `BN` is `nn.BatchNorm2d()`  
 * `GELU` is `nn.GELU()`  
 * `ConvTranspose2D(kernel_size, stride, padding)` is `nn.ConvTranspose2d()`  
 * `AvgPool(kernel_size)` is `nn.AvgPool2d()`  
 * `Linear` is `nn.Linear()`  
 * `N`, `C`, `W` and `H` are batch size, number of channels, width and height respectively


### Basic Ops

In [None]:
class Conv(nn.Module):
    """
    Resolution-preserving convolution: only the channel count changes.
    Pipeline: nn.Conv2d(kernel 3, stride 1, padding 1) -> BN -> GELU.
    """

    def __init__(self, in_channels: int, out_channels: int):
        """
        Builds the three sub-layers.
        Args:
            in_channels (int): Channels of the incoming tensor
            out_channels (int): Channels of the produced tensor
        """
        super().__init__()
        # Attribute names double as state_dict keys, so they stay as they are.
        self.conv2d = nn.Conv2d(in_channels, out_channels, 3, stride=1, padding=1)
        self.bn = nn.BatchNorm2d(out_channels)
        self.gelu = nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (N, in_channels, H, W) input tensor.

        Returns:
            (N, out_channels, H, W) output tensor.
        """
        return self.gelu(self.bn(self.conv2d(x)))


class DownConv(nn.Module):
    """
        Strided convolution that halves the spatial resolution.
        Pipeline: Conv2D(kernel 3, stride 2, padding 1) -> BN -> GELU.
    """

    def __init__(self, in_channels: int, out_channels: int):
        """
        Builds the three sub-layers.
        Args:
            in_channels (int): Channels of the incoming tensor
            out_channels (int): Channels of the produced tensor
        """
        super().__init__()
        # Attribute names double as state_dict keys, so they stay as they are.
        self.conv2d = nn.Conv2d(in_channels, out_channels, 3, stride=2, padding=1)
        self.bn = nn.BatchNorm2d(out_channels)
        self.gelu = nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (N, in_channels, H, W) input tensor.

        Returns:
            (N, out_channels, H/2, W/2) output tensor.
        """
        return self.gelu(self.bn(self.conv2d(x)))


class UpConv(nn.Module):
    """
    Transposed convolution that doubles the spatial resolution.
    Pipeline as implemented: ConvTranspose2d(kernel 2, stride 2, padding 0)
    -> BN -> GELU.
    """

    def __init__(self, in_channels: int, out_channels: int):
        """
        Builds the three sub-layers.
        Args:
            in_channels (int): Channels of the incoming tensor
            out_channels (int): Channels of the produced tensor
        """
        super().__init__()
        # Attribute names double as state_dict keys, so they stay as they are.
        self.convTranspose2d = nn.ConvTranspose2d(in_channels, out_channels, 2,
                                                  stride=2, padding=0)
        self.bn = nn.BatchNorm2d(out_channels)
        self.gelu = nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (N, in_channels, H, W) input tensor.

        Returns:
            (N, out_channels, H*2, W*2) output tensor.
        """
        return self.gelu(self.bn(self.convTranspose2d(x)))


class Flatten(nn.Module):
    """
    Collapses a 7x7 feature map to 1x1 by average pooling.
    Pipeline: AvgPool(7) -> GELU.
    """

    def __init__(self):
        super().__init__()
        self.avgpool = nn.AvgPool2d(7)
        self.gelu = nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (N, C, 7, 7) input tensor.

        Returns:
            (N, C, 1, 1) output tensor.
        """
        pooled = self.avgpool(x)
        return self.gelu(pooled)


class Unflatten(nn.Module):
    """
      Expands (up-samples) a 1x1 feature map to 7x7 while keeping the
      channel count. Pipeline: ConvTranspose2D(kernel 7, stride 7, padding 0)
      -> BN -> GELU.
    """

    def __init__(self, in_channels: int):
        """
        Builds the three sub-layers.
        Args:
            in_channels (int): Channel count, used for both input and output
        """
        super().__init__()
        # Attribute names double as state_dict keys, so they stay as they are.
        self.convtranspose2d = nn.ConvTranspose2d(in_channels, in_channels, 7,
                                                  stride=7, padding=0)
        self.bn = nn.BatchNorm2d(in_channels)
        self.gelu = nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (N, in_channels, 1, 1) input tensor.

        Returns:
            (N, in_channels, 7, 7) output tensor.
        """
        return self.gelu(self.bn(self.convtranspose2d(x)))

### UNet Blocks

In [None]:
class ConvBlock(nn.Module):
    """
    A pair of stacked Conv layers.
    Input and output shapes match those of a single Conv.
    """

    def __init__(self, in_channels: int, out_channels: int):
        """
        Builds the two Conv layers.
        Args:
            in_channels (int): Channels of the incoming tensor
            out_channels (int): Channels produced by both Conv layers
        """
        super().__init__()
        self.conv1 = Conv(in_channels, out_channels)   # changes the channel count
        self.conv2 = Conv(out_channels, out_channels)  # keeps it

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (N, in_channels, H, W) input tensor.

        Returns:
            (N, out_channels, H, W) output tensor.
        """
        return self.conv2(self.conv1(x))


class DownBlock(nn.Module):
    """
    Down-sampling stage: a DownConv and then a ConvBlock.
    Input and output shapes match those of a single DownConv.
    """

    def __init__(self, in_channels: int, out_channels: int):
        """
        Builds the two stages.
        Args:
            in_channels (int): Channels of the incoming tensor
            out_channels (int): Channels of the produced tensor
        """
        super().__init__()
        self.conv1 = DownConv(in_channels, out_channels)    # halves H and W
        self.conv2 = ConvBlock(out_channels, out_channels)  # refines at the new size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (N, in_channels, H, W) input tensor.

        Returns:
            (N, out_channels, H/2, W/2) output tensor.
        """
        return self.conv2(self.conv1(x))


class UpBlock(nn.Module):
    """
    Up-sampling stage: an UpConv and then a ConvBlock.
    Input and output shapes match those of a single UpConv.
    """

    def __init__(self, in_channels: int, out_channels: int):
        """
        Builds the two stages.
        Args:
            in_channels (int): Channels of the incoming tensor
            out_channels (int): Channels of the produced tensor
        """
        super().__init__()
        self.conv1 = UpConv(in_channels, out_channels)      # doubles H and W
        self.conv2 = ConvBlock(out_channels, out_channels)  # refines at the new size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (N, in_channels, H, W) input tensor.

        Returns:
            (N, out_channels, H*2, W*2) output tensor.
        """
        return self.conv2(self.conv1(x))


class FCBlock(nn.Module):
    """
    Small MLP: an FC layer (Linear + GELU) and then a plain Linear layer.
    Input and output shapes match those of a single FC.
    """

    def __init__(self, in_channels: int, out_channels: int):
        """
        Builds the two layers.
        Args:
            in_channels (int): Number of input features
            out_channels (int): Number of output features
        """
        super().__init__()
        self.fc = FC(in_channels, out_channels)
        # No activation after this projection.
        self.linear = nn.Linear(out_channels, out_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (N, in_channels) input tensor.

        Returns:
            (N, out_channels) output tensor.
        """
        return self.linear(self.fc(x))


class FC(nn.Module):
    """
    Fully connected layer: nn.Linear and then GELU.
    """

    def __init__(self, in_channels: int, out_channels: int):
        """
        Builds the linear projection and its activation.
        Args:
            in_channels (int): Number of input features
            out_channels (int): Number of output features
        """
        super().__init__()
        self.fc = nn.Linear(in_channels, out_channels)
        self.gelu = nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (N, in_channels) input tensor.

        Returns:
            (N, out_channels) output tensor.
        """
        projected = self.fc(x)
        return self.gelu(projected)

## 2. Unconditional Diffusion Framework


### 2.1 UNet architecture

In [None]:
class DenoisingUNet(nn.Module):
    """
    Time-conditioned UNet that predicts the noise in a 28x28 image.

    Two down-sampling stages (28 -> 14 -> 7), a 1x1 bottleneck, and two
    up-sampling stages with skip connections. The normalized timestep is
    embedded by two FCBlocks and added after the bottleneck expansion and
    after the first up-sampling stage.
    """

    def __init__(
            self,
            in_channels: int,  # 1
            num_hiddens: int  # D
    ):
        """
        Args:
            in_channels (int): Channels of the input image; the predicted
                noise has the same number of channels.
            num_hiddens (int): Base hidden width D.
        """
        super().__init__()
        # t-tensor: the timestep is a single scalar per sample, so the embedding
        # input width is 1 regardless of how many channels the image has.
        self.t_fc4 = FCBlock(1, 2 * num_hiddens)
        self.t_fc2 = FCBlock(1, num_hiddens)

        # In
        self.conv_block_in = ConvBlock(in_channels, num_hiddens)  # (N, D, 28, 28)
        # Down
        self.down_block1 = DownBlock(num_hiddens, num_hiddens)  # (N, D, 14, 14)
        self.down_block2 = DownBlock(num_hiddens, 2 * num_hiddens)  # (N, 2 * D, 7, 7)
        self.flatten = Flatten()  # (N, 2 * D, 1, 1)
        # Up / with skip connections and t-tensor addition
        self.expend = Unflatten(2 * num_hiddens)  # (N, 2 * D, 7, 7)
        self.up_block2 = UpBlock(4 * num_hiddens, num_hiddens)  # in: [2 * D + 2 * D] -> (N, D, 14, 14)
        self.up_block1 = UpBlock(2 * num_hiddens, num_hiddens)  # in: [D + D] -> (N, D, 28, 28)
        # Out
        self.conv_block_out = ConvBlock(2 * num_hiddens, num_hiddens)  # in: [D + D] -> (N, D, 28, 28)
        # The predicted noise must have the same channel count as the input,
        # otherwise `e - e_hat` in the loss would silently broadcast.
        self.conv2d = nn.Conv2d(
            in_channels=num_hiddens,
            out_channels=in_channels,
            kernel_size=3,
            stride=1,
            padding=1
        )

    def forward(
            self,
            x: torch.Tensor,
            t: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            x: (N, C, H, W) input tensor.
            t: (N, 1) normalized time tensor.

        Returns:
            (N, C, H, W) output tensor.
        """
        assert x.shape[-2:] == (28, 28), "Expect input shape to be (28, 28)."
        # t-tensor: (N, K) -> (N, K, 1, 1) so it broadcasts over H and W
        t_layer4 = self.t_fc4(t).unsqueeze(-1).unsqueeze(-1)
        t_layer2 = self.t_fc2(t).unsqueeze(-1).unsqueeze(-1)

        # Down
        d_layer1 = self.conv_block_in(x)
        d_layer2 = self.down_block1(d_layer1)
        d_layer3 = self.down_block2(d_layer2)

        # Flatten/Extract
        flat = self.flatten(d_layer3)
        expend = self.expend(flat)

        expend = expend + t_layer4  # t-tensor addition layer 4

        # Up
        u_layer2 = self.up_block2(torch.cat((d_layer3, expend), dim=1))  # skip connection layer 3
        # Out-of-place add: avoids mutating an activation that autograd may still need.
        u_layer2 = u_layer2 + t_layer2  # t-tensor addition layer 2

        u_layer1 = self.up_block1(torch.cat((d_layer2, u_layer2), dim=1))  # skip connection layer 2
        output = self.conv_block_out(torch.cat((d_layer1, u_layer1), dim=1))  # skip connection layer 1
        output = self.conv2d(output)
        return output

### 2.2 DDPM Forward and Inverse Process


In [None]:
def ddpm_schedule(beta1: float, beta2: float, num_ts: int, device: str = 'cuda') -> dict:
    """Constants for DDPM training and sampling.

    Arguments:
        beta1: float, starting beta value.
        beta2: float, ending beta value.
        num_ts: int, number of timesteps.

    Returns:
        dict with keys:
            betas: linear schedule of betas from beta1 to beta2.
            alphas: 1 - betas.
            alpha_bars: cumulative product of alphas.
            device: cuda or cpu
    """
    assert beta1 < beta2 < 1.0, "Expect beta1 < beta2 < 1.0."
    beta_list = torch.linspace(beta1, beta2, num_ts, device=device)
    alpha_list = 1 - beta_list
    alpha_bar_list = torch.cumprod(alpha_list, dim=0)
    alpha_bar_list = torch.cat([torch.ones(1, device=device), alpha_bar_list])

    return {'beta_list': beta_list, 'alpha_list': alpha_list, 'alpha_bar_list': alpha_bar_list}

In [None]:
def ddpm_forward(
        unet: DenoisingUNet,
        ddpm_schedule: dict,
        x_0: torch.Tensor,
        num_ts: int,
) -> torch.Tensor:
    """Algorithm 1 of the DDPM paper (not including gradient step).

    Puts ``unet`` in training mode, noises ``x_0`` at a random timestep per
    sample and returns the mean squared error between the true and the
    predicted noise.

    Args:
        unet: DenoisingUNet
        ddpm_schedule: dict holding at least 'alpha_bar_list', indexable by
            t in {1, ..., num_ts}.
        x_0: (N, C, H, W) input tensor.
        num_ts: int, number of timesteps.
    Returns:
        (,) diffusion loss.
    """
    unet.train()
    device = x_0.device
    N = x_0.shape[0]

    t = torch.randint(low=1, high=num_ts + 1, size=(N,), device=device)  # t ~ Uniform({1, ..., T})

    # The schedule lives in a plain dict that `module.to(...)` does not move,
    # so bring it to the data's device before indexing (no-op if already there).
    alpha_bar_t = ddpm_schedule['alpha_bar_list'].to(device)[t].view(-1, 1, 1, 1)

    e = torch.randn_like(x_0)  # e ~ N(0, I)

    # Closed-form forward process: x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * e
    x_t = torch.sqrt(alpha_bar_t) * x_0 + torch.sqrt(1 - alpha_bar_t) * e

    t_normalized = t.view(-1, 1).float() / num_ts  # (N, 1), values in (0, 1]

    e_hat = unet(x_t, t_normalized)

    return ((e - e_hat) ** 2).mean()

In [None]:
@torch.inference_mode()
def ddpm_sample(
        unet: DenoisingUNet,
        ddpm_schedule: dict,
        img_wh: tuple[int, int],
        batch_size: int,
        num_ts: int
) -> torch.Tensor:
    """Algorithm 2 of the DDPM paper.

    Puts ``unet`` in eval mode and runs the reverse process from pure noise.
    Sampling happens on the device of the model's parameters; the schedule
    tensors are moved there if needed.

    Args:
        unet: DenoisingUNet
        ddpm_schedule: dict with 'beta_list', 'alpha_list' and 'alpha_bar_list'.
        img_wh: spatial size of the output; img_wh[0] and img_wh[1] become
            the last two dimensions of the sample, in that order.
        batch_size: int, number of images N to generate.
        num_ts: int, number of timesteps.

    Returns:
        (N, 1, img_wh[0], img_wh[1]) final sample (single-channel).
    """
    unet.eval()

    # Follow the model's device; fall back to the schedule's device for a
    # parameter-free model.
    first_param = next(unet.parameters(), None)
    device = ddpm_schedule['beta_list'].device if first_param is None else first_param.device

    betas = ddpm_schedule['beta_list'].to(device)
    alphas = ddpm_schedule['alpha_list'].to(device)
    alpha_bar = ddpm_schedule['alpha_bar_list'].to(device)

    x_t = torch.randn(batch_size, 1, img_wh[0], img_wh[1], device=device)  # x_t ~ N(0, I)

    for t in range(num_ts, 0, -1):

        # betas/alphas are 0-based (index t - 1); alpha_bar has a leading 1 (index t).
        b_t = betas[t - 1].view(1, 1, 1, 1)
        a_t = alphas[t - 1].view(1, 1, 1, 1)
        a_b_t = alpha_bar[t].view(1, 1, 1, 1)
        a_b_t_m_1 = alpha_bar[t - 1].view(1, 1, 1, 1)

        t_normalized = torch.full((batch_size, 1), t / num_ts, device=device)
        e_hat = unet(x_t, t_normalized)

        # No noise is added on the final step.
        if t > 1:
            z = torch.randn_like(x_t)
        else:
            z = torch.zeros_like(x_t)

        # Estimate of the clean image implied by the predicted noise.
        x_0_hat = (1 / torch.sqrt(a_b_t)) * (x_t - torch.sqrt(1 - a_b_t) * e_hat)

        # Posterior mean coefficients for x_0_hat and x_t, plus sigma_t = sqrt(beta_t).
        cof1 = (torch.sqrt(a_b_t_m_1) * b_t) / (1 - a_b_t)
        cof2 = (torch.sqrt(a_t) * (1 - a_b_t_m_1)) / (1 - a_b_t)
        cof3 = torch.sqrt(b_t)

        x_t = cof1 * x_0_hat + cof2 * x_t + cof3 * z

    return x_t


In [None]:
# Do Not Modify
class DDPM(nn.Module):
    """Thin wrapper that pairs a denoising UNet with a DDPM noise schedule.

    `forward` returns the training loss via `ddpm_forward`; `sample` generates
    images via `ddpm_sample`.
    """

    def __init__(
            self,
            unet: DenoisingUNet,
            betas: tuple[float, float] = (1e-4, 0.02),
            num_ts: int = 300,
            p_uncond: float = 0.1,
    ):
        """
        Args:
            unet: the denoising network.
            betas: (beta1, beta2) endpoints passed to `ddpm_schedule`.
            num_ts: number of diffusion timesteps.
            p_uncond: stored on the instance; not read by `forward` or
                `sample` of this class.
        """
        super().__init__()
        self.unet = unet
        self.betas = betas
        self.num_ts = num_ts
        self.p_uncond = p_uncond
        # Plain dict attribute: this is what forward() and sample() pass on.
        self.ddpm_schedule = ddpm_schedule(betas[0], betas[1], num_ts)

        # A second schedule is built and each tensor registered as a
        # non-persistent buffer (i.e. left out of the state_dict).
        for k, v in ddpm_schedule(betas[0], betas[1], num_ts).items():
            self.register_buffer(k, v, persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Computes the diffusion training loss for a batch by delegating to
        `ddpm_forward`.

        Args:
            x: (N, C, H, W) input tensor.

        Returns:
            (,) diffusion loss.
        """
        return ddpm_forward(
            self.unet, self.ddpm_schedule, x, self.num_ts
        )

    @torch.inference_mode()
    def sample(
            self,
            img_wh: tuple[int, int],
            batch_size: int
    ):
        """
        Generates `batch_size` images by delegating to `ddpm_sample`.

        Args:
            img_wh: spatial size forwarded to `ddpm_sample`.
            batch_size: number of images to generate.

        Returns:
            Whatever `ddpm_sample` returns for these arguments.
        """
        return ddpm_sample(
            self.unet, self.ddpm_schedule, img_wh, batch_size, self.num_ts
        )

### 2.3 Train your denoiser

In [None]:
# Hyper parameters - Modify if you wish
num_hidden = 128  # hidden width D of the UNet
batch_size = 64
num_epochs = 20
lr = 1e-3
img_wh = (28, 28)  # spatial size passed to the sampler
eval_batch_size = 20  # also the number of images sampled after each epoch
T = 300  # number of diffusion timesteps

# MNIST with pixel values rescaled by Normalize(mean=0.5, std=0.5), i.e. (x - 0.5) / 0.5.
train_data = MNIST(root='./data', train=True, download=True, transform=transforms.Compose([
    ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
]))
test_data = MNIST(root='./data', train=False, download=True, transform=transforms.Compose([
    ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
]))

train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
eval_loader = DataLoader(test_data, batch_size=eval_batch_size,
                         shuffle=True)  # Not useful now, but will be for evaluating class-conditioned denoiser (3.3)

# Init denoiser and DDPM wrapper
denosier_unet = DenoisingUNet(in_channels=1, num_hiddens=num_hidden)
ddpm = DDPM(denosier_unet, num_ts=T)

# Optimizer and device setup - Adam optimizer with exponential learning rate decay
optimizer = torch.optim.Adam(ddpm.parameters(), lr=lr)
# gamma is chosen so that num_epochs scheduler steps multiply the learning rate by 0.1 overall.
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=0.1 ** (1.0 / num_epochs))

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# NOTE(review): .to() moves parameters and registered buffers; the schedule dict stored
# on DDPM as a plain attribute stays on whatever device ddpm_schedule created it on.
ddpm.to(device)

# Loss history: one entry per optimizer step / one entry per epoch.
train_batch_losses = []
train_epoch_losses = []

# Each run writes into runs/<4-digit random id>/{plots,weights}.
run_id = f"{randint(0, 9999):04d}"
base_dir = os.path.join('runs', run_id)
plot_dir = os.path.join(base_dir, 'plots')
weight_dir = os.path.join(base_dir, 'weights')
os.makedirs(plot_dir, exist_ok=True)
os.makedirs(weight_dir, exist_ok=True)

tqdm.write(f"start training {run_id} on device: {device}")

total_iters = num_epochs * len(train_loader)  # single progress bar for the whole run

# Train for num_epochs. After every epoch: step the LR scheduler, draw samples, save a
# sample grid image and a checkpoint. The loss curves in the `finally` block are written
# whether training finishes, is interrupted with Ctrl-C, or raises.
try:
    pbar = tqdm(total=total_iters, desc='Training', dynamic_ncols=True)
    for epoch in range(num_epochs):
        ddpm.train()  # Set the model to training mode
        epoch_loss = 0.0
        for batch, (data, label) in enumerate(train_loader):  # label is unused by the unconditional model
            optimizer.zero_grad()
            data = data.to(device)
            loss = ddpm(data)
            loss.backward()  # Compute gradients
            optimizer.step()  # Update weights

            batch_loss = loss.item()
            train_batch_losses.append(batch_loss)
            epoch_loss += batch_loss

            pbar.set_description(f"Epoch [{epoch + 1}/{num_epochs}]")
            pbar.set_postfix(loss=f"{batch_loss:.4f}")
            pbar.update(1)

        # Mean of the per-batch losses for this epoch.
        avg_epoch_loss = epoch_loss / len(train_loader)
        train_epoch_losses.append(avg_epoch_loss)
        scheduler.step()  # one LR decay step per epoch

        # Qualitative check: sample eval_batch_size images from the current model.
        ddpm.eval()
        with torch.no_grad():
            samples = ddpm.sample(img_wh, eval_batch_size).cpu()
        # value_range=(-1, 1) matches the Normalize(0.5, 0.5) applied to the training data.
        grid = torchvision.utils.make_grid(samples, normalize=True, value_range=(-1, 1))
        plt.figure(figsize=(6, 6))
        plt.imshow(grid.permute(1, 2, 0).cpu(), cmap='gray')  # CHW -> HWC for matplotlib
        plt.axis('off')
        plt.title(f'Samples at Epoch {epoch + 1}')
        plt.savefig(os.path.join(plot_dir, f'samples_epoch_{epoch + 1:02d}.png'), dpi=300, bbox_inches='tight')
        plt.close()

        # One checkpoint per epoch; the zero-padded epoch number keeps file names sortable.
        torch.save(ddpm.state_dict(), os.path.join(weight_dir, f'ddpm_epoch_{epoch + 1:02d}.pt'))
except KeyboardInterrupt:
    tqdm.write("Training interrupted by user.")
finally:
    pbar.close()

    # --------- Plots ---------
    # Per-iteration loss curve.
    plt.figure(figsize=(8, 4))
    plt.plot(train_batch_losses, alpha=0.8, label='Batch loss')
    plt.xlabel('Iteration')
    plt.ylabel('MSE loss')
    plt.legend()
    plt.tight_layout()
    plt.savefig(os.path.join(plot_dir, 'batch_loss.png'), dpi=300)
    plt.close()

    # Per-epoch average loss curve.
    plt.figure(figsize=(8, 4))
    plt.plot(train_epoch_losses, linewidth=2, label='Epoch loss')
    plt.xlabel('Epoch')
    plt.ylabel('MSE loss')
    plt.legend()
    plt.tight_layout()
    plt.savefig(os.path.join(plot_dir, 'epoch_loss.png'), dpi=300)
    plt.close()


## 3. Implementing class-conditioned diffusion framework with CFG


### 3.1 Adding Class-Conditioning to UNet architecture

In [None]:
class ConditionalDenoisingUNet(nn.Module):
    """
    Class- and time-conditioned UNet that predicts the noise in a 28x28 image.

    Same layout as the unconditional UNet (28 -> 14 -> 7 -> 1x1 bottleneck and
    back, with skip connections). The class embedding scales the features and
    the time embedding is added to them, once after the bottleneck expansion
    and once after the first up-sampling stage.
    """

    def __init__(
            self,
            in_channels: int,
            num_classes: int,
            num_hiddens: int,
    ):
        """
        Args:
            in_channels (int): Channels of the input image; the predicted
                noise has the same number of channels.
            num_classes (int): Width of the class-condition vector.
            num_hiddens (int): Base hidden width D.
        """
        super().__init__()
        self.num_classes = num_classes
        # t-tensor: the timestep is a single scalar per sample, so the embedding
        # input width is 1 regardless of how many channels the image has.
        self.t_fc4 = FCBlock(1, 2 * num_hiddens)
        self.t_fc2 = FCBlock(1, num_hiddens)
        # c-tensor
        self.c_fc4 = FCBlock(num_classes, 2 * num_hiddens)
        self.c_fc2 = FCBlock(num_classes, num_hiddens)

        # In
        self.conv_block_in = ConvBlock(in_channels, num_hiddens)  # (N, D, 28, 28)
        # Down
        self.down_block1 = DownBlock(num_hiddens, num_hiddens)  # (N, D, 14, 14)
        self.down_block2 = DownBlock(num_hiddens, 2 * num_hiddens)  # (N, 2 * D, 7, 7)
        self.flatten = Flatten()  # (N, 2 * D, 1, 1)
        # Up / with skip connections, c-tensor scaling and t-tensor addition
        self.expend = Unflatten(2 * num_hiddens)  # (N, 2 * D, 7, 7)
        self.up_block2 = UpBlock(4 * num_hiddens, num_hiddens)  # in: [2 * D + 2 * D] -> (N, D, 14, 14)
        self.up_block1 = UpBlock(2 * num_hiddens, num_hiddens)  # in: [D + D] -> (N, D, 28, 28)
        # Out
        self.conv_block_out = ConvBlock(2 * num_hiddens, num_hiddens)  # in: [D + D] -> (N, D, 28, 28)
        # The predicted noise must have the same channel count as the input,
        # otherwise `e - e_hat` in the loss would silently broadcast.
        self.conv2d = nn.Conv2d(
            in_channels=num_hiddens,
            out_channels=in_channels,
            kernel_size=3,
            stride=1,
            padding=1
        )

    def forward(
            self,
            x: torch.Tensor,
            c: torch.Tensor,
            t: torch.Tensor,
            mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        Args:
            x: (N, C, H, W) input tensor.
            c: (N, num_classes) float condition tensor.
            t: (N, 1) normalized time tensor.
            mask: (N, 1) mask tensor. If not None, mask out condition when mask == 0.

        Returns:
            (N, C, H, W) output tensor.
        """
        assert x.shape[-2:] == (28, 28), "Expect input shape to be (28, 28)."
        if mask is not None:
            # Zero the condition of every sample whose mask entry is 0; the
            # (N, 1) mask broadcasts over the class dimension.
            c = c * (mask != 0).to(dtype=c.dtype)

        # t-tensor: (N, K) -> (N, K, 1, 1) so it broadcasts over H and W
        t_layer4 = self.t_fc4(t).unsqueeze(-1).unsqueeze(-1)
        t_layer2 = self.t_fc2(t).unsqueeze(-1).unsqueeze(-1)
        # c-tensor
        c_layer4 = self.c_fc4(c).unsqueeze(-1).unsqueeze(-1)
        c_layer2 = self.c_fc2(c).unsqueeze(-1).unsqueeze(-1)

        # Down
        d_layer1 = self.conv_block_in(x)
        d_layer2 = self.down_block1(d_layer1)
        d_layer3 = self.down_block2(d_layer2)

        # Flatten/Extract
        flat = self.flatten(d_layer3)
        expend = self.expend(flat)

        expend = c_layer4 * expend + t_layer4  # c-tensor multiplication, t-tensor addition layer 4

        # Up
        u_layer2 = self.up_block2(torch.cat((d_layer3, expend), dim=1))  # skip connection layer 3
        u_layer2 = c_layer2 * u_layer2 + t_layer2  # c-tensor multiplication, t-tensor addition layer 2

        u_layer1 = self.up_block1(torch.cat((d_layer2, u_layer2), dim=1))  # skip connection layer 2
        output = self.conv_block_out(torch.cat((d_layer1, u_layer1), dim=1))  # skip connection layer 1
        output = self.conv2d(output)
        return output

### 3.2 DDPM Forward and Inverse Process with CFG

In [None]:
def ddpm_forward(
        unet: ConditionalDenoisingUNet,
        ddpm_schedule: dict,
        x_0: torch.Tensor,
        c: torch.Tensor,
        p_uncond: float,
        num_ts: int,
) -> torch.Tensor:
    """Algorithm 3 (not including gradient step).

    Puts ``unet`` in training mode, noises ``x_0`` at a random timestep per
    sample, drops the class condition with probability ``p_uncond`` and
    returns the mean squared error between the true and the predicted noise.

    Args:
        unet: ConditionalDenoisingUNet
        ddpm_schedule: dict holding at least 'alpha_bar_list', indexable by
            t in {1, ..., num_ts}.
        x_0: (N, C, H, W) input tensor.
        c: (N,) int64 condition tensor.
        p_uncond: float, probability of unconditioning the condition.
        num_ts: int, number of timesteps.

    Returns:
        (,) diffusion loss.
    """
    unet.train()
    device = x_0.device
    N = x_0.shape[0]
    num_classes = unet.num_classes

    # One-hot encode the labels, then zero whole rows with probability p_uncond
    # so the network also learns the unconditional noise prediction.
    c_one_hot = torch.eye(num_classes, device=device)[c.to(device)]
    keep_condition = torch.rand(N, 1, device=device) > p_uncond
    c_one_hot = c_one_hot * keep_condition

    t = torch.randint(low=1, high=num_ts + 1, size=(N,), device=device)  # t ~ Uniform({1, ..., T})

    # The schedule lives in a plain dict that `module.to(...)` does not move,
    # so bring it to the data's device before indexing (no-op if already there).
    alpha_bar_t = ddpm_schedule['alpha_bar_list'].to(device)[t].view(-1, 1, 1, 1)

    e = torch.randn_like(x_0)  # e ~ N(0, I)

    # Closed-form forward process: x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * e
    x_t = torch.sqrt(alpha_bar_t) * x_0 + torch.sqrt(1 - alpha_bar_t) * e

    t_normalized = t.view(-1, 1).float() / num_ts  # (N, 1), values in (0, 1]

    e_hat = unet(x_t, c_one_hot, t_normalized)

    return ((e - e_hat) ** 2).mean()



In [None]:
@torch.inference_mode()
def ddpm_cfg_sample(
        unet: ConditionalDenoisingUNet,
        ddpm_schedule: dict,
        c: torch.Tensor,
        img_wh: tuple[int, int],
        num_ts: int,
        guidance_scale: float = 5.0
) -> torch.Tensor:
    """Algorithm 4: reverse diffusion with classifier-free guidance.

    Args:
        unet: ConditionalDenoisingUNet
        ddpm_schedule: dict with 'beta_list', 'alpha_list' and 'alpha_bar_list'.
        c: (N,) int64 condition tensor; its device is used for sampling.
        img_wh: spatial size of the output; img_wh[0] and img_wh[1] become
            the last two dimensions of the sample, in that order.
        num_ts: int, number of timesteps.
        guidance_scale: float, CFG scale.

    Returns:
        (N, 1, img_wh[0], img_wh[1]) final sample.
    """
    unet.eval()
    device = c.device
    n_samples = c.shape[0]

    # One-hot condition and the all-zero "no condition" input used for guidance.
    cond = torch.eye(unet.num_classes, device=device)[c]
    null_cond = torch.zeros_like(cond, device=device)

    beta_list, alpha_list, alpha_bar_list = (
        ddpm_schedule[key].to(device)
        for key in ('beta_list', 'alpha_list', 'alpha_bar_list')
    )

    x_t = torch.randn(n_samples, 1, img_wh[0], img_wh[1], device=device)  # x_T ~ N(0, I)

    for step in reversed(range(1, num_ts + 1)):
        # beta/alpha lists are 0-based; alpha_bar_list carries a leading entry for t = 0.
        beta = beta_list[step - 1].view(1, 1, 1, 1)
        alpha = alpha_list[step - 1].view(1, 1, 1, 1)
        abar = alpha_bar_list[step].view(1, 1, 1, 1)
        abar_prev = alpha_bar_list[step - 1].view(1, 1, 1, 1)

        time_in = torch.full((n_samples, 1), step / num_ts, device=device)
        eps_cond = unet(x_t, cond, time_in)
        eps_uncond = unet(x_t, null_cond, time_in)

        # Extrapolate from the unconditional towards the conditional prediction.
        eps = eps_uncond + guidance_scale * (eps_cond - eps_uncond)

        # No fresh noise on the final step.
        if step > 1:
            noise = torch.randn_like(x_t)
        else:
            noise = torch.zeros_like(x_t)

        x0_pred = (1 / torch.sqrt(abar)) * (x_t - torch.sqrt(1 - abar) * eps)

        w_x0 = (torch.sqrt(abar_prev) * beta) / (1 - abar)
        w_xt = (torch.sqrt(alpha) * (1 - abar_prev)) / (1 - abar)
        sigma = torch.sqrt(beta)

        x_t = w_x0 * x0_pred + w_xt * x_t + sigma * noise

    return x_t


In [None]:
# Do Not Modify
class DDPM(nn.Module):
    """Thin wrapper that pairs a class-conditional UNet with a DDPM noise schedule.

    `forward` returns the training loss via `ddpm_forward`; `sample` generates
    images with classifier-free guidance via `ddpm_cfg_sample`.
    """

    def __init__(
            self,
            unet: ConditionalDenoisingUNet,
            betas: tuple[float, float] = (1e-4, 0.02),
            num_ts: int = 300,
            p_uncond: float = 0.1,
    ):
        """
        Args:
            unet: the class-conditional denoising network.
            betas: (beta1, beta2) endpoints passed to `ddpm_schedule`.
            num_ts: number of diffusion timesteps.
            p_uncond: probability forwarded to `ddpm_forward` for dropping
                the class condition during training.
        """
        super().__init__()
        self.unet = unet
        self.betas = betas
        self.num_ts = num_ts
        self.p_uncond = p_uncond
        # Plain dict attribute (not registered as buffers).
        self.ddpm_schedule = ddpm_schedule(betas[0], betas[1], num_ts)

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        """
        Computes the diffusion training loss for a labelled batch by
        delegating to `ddpm_forward`.

        Args:
            x: (N, C, H, W) input tensor.
            c: (N,) int64 condition tensor.

        Returns:
            (,) diffusion loss.
        """
        return ddpm_forward(
            self.unet, self.ddpm_schedule, x, c, self.p_uncond, self.num_ts
        )

    @torch.inference_mode()
    def sample(
            self,
            c: torch.Tensor,
            img_wh: tuple[int, int],
            guidance_scale: float = 5.0
    ):
        """
        Generates one image per entry of `c` by delegating to `ddpm_cfg_sample`.

        Args:
            c: (N,) int64 condition tensor.
            img_wh: spatial size forwarded to `ddpm_cfg_sample`.
            guidance_scale: CFG scale forwarded to `ddpm_cfg_sample`.

        Returns:
            Whatever `ddpm_cfg_sample` returns for these arguments.
        """
        return ddpm_cfg_sample(
            self.unet, self.ddpm_schedule, c, img_wh, self.num_ts, guidance_scale
        )

### 3.3 Train your class-conditioned denoiser

In [None]:
# Hyper parameters - Modify if you wish
num_hidden = 128  # hidden width D of the UNet
batch_size = 64
num_epochs = 20
lr = 1e-3
img_wh = (28, 28)  # spatial size passed to the sampler
eval_batch_size = 20  # also the number of images sampled after each epoch
T = 300  # number of diffusion timesteps

# MNIST with pixel values rescaled by Normalize(mean=0.5, std=0.5), i.e. (x - 0.5) / 0.5.
train_data = MNIST(root='./data', train=True, download=True, transform=transforms.Compose([
    ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
]))
test_data = MNIST(root='./data', train=False, download=True, transform=transforms.Compose([
    ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
]))

train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
eval_loader = DataLoader(test_data, batch_size=eval_batch_size, shuffle=True)

# Init the class-conditional denoiser (10 digit classes) and its DDPM wrapper.
denosier_unet = ConditionalDenoisingUNet(
    in_channels=1,
    num_classes=10,
    num_hiddens=num_hidden
)
ddpm = DDPM(denosier_unet, num_ts=T, p_uncond=0.1)

# Adam with exponential LR decay; gamma is chosen so that num_epochs scheduler steps
# multiply the learning rate by 0.1 overall.
optimizer = torch.optim.Adam(ddpm.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.ExponentialLR(
    optimizer=optimizer, gamma=0.1 ** (1.0 / num_epochs)
)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# NOTE(review): .to() moves parameters and registered buffers; the schedule dict stored
# on DDPM as a plain attribute stays on whatever device ddpm_schedule created it on.
ddpm.to(device)

# Loss history: one entry per optimizer step / one entry per epoch.
train_batch_losses = []
train_epoch_losses = []

# Each run writes into runs/cond<4-digit random id>/{plots,weights}.
run_id = f"cond{randint(0, 9999):04d}"
base_dir = os.path.join('runs', run_id)
plot_dir = os.path.join(base_dir, 'plots')
weight_dir = os.path.join(base_dir, 'weights')
os.makedirs(plot_dir, exist_ok=True)
os.makedirs(weight_dir, exist_ok=True)

tqdm.write(f"start training {run_id} on device: {device}")

total_iters = num_epochs * len(train_loader)  # single progress bar for the whole run

# Train for num_epochs. After every epoch: step the LR scheduler, draw class-conditioned
# samples, save a labelled sample figure and a checkpoint. The loss curves in the
# `finally` block are written whether training finishes, is interrupted, or raises.
try:
    pbar = tqdm(total=total_iters, desc='Training', dynamic_ncols=True)
    for epoch in range(num_epochs):
        ddpm.train()  # Set the model to training mode
        epoch_loss = 0.0
        for batch, (data, labels) in enumerate(train_loader):
            optimizer.zero_grad()
            data = data.to(device)
            labels = labels.to(device)
            loss = ddpm(data, labels)
            loss.backward()  # Compute gradients
            optimizer.step()  # Update weights

            batch_loss = loss.item()
            train_batch_losses.append(batch_loss)
            epoch_loss += batch_loss

            pbar.set_description(f"Epoch [{epoch + 1}/{num_epochs}]")
            pbar.set_postfix(loss=f"{batch_loss:.4f}")
            pbar.update(1)

        # Mean of the per-batch losses for this epoch.
        avg_epoch_loss = epoch_loss / len(train_loader)
        train_epoch_losses.append(avg_epoch_loss)
        scheduler.step()  # one LR decay step per epoch

        ddpm.eval()
        with torch.no_grad():
            # sample eval_batch_size random class labels
            sample_labels = torch.randint(0, 10, (eval_batch_size,), device=device)
            samples = ddpm.sample(sample_labels, img_wh, guidance_scale=5.0).cpu()

        # undo the Normalize(0.5, 0.5) scaling: maps [-1, 1] back to [0, 1] (values are not clamped)
        samples = (samples + 1) / 2.0
        # Near-square subplot grid with room for eval_batch_size images.
        n_cols = int(np.ceil(eval_batch_size ** 0.5))
        n_rows = int(np.ceil(eval_batch_size / n_cols))

        fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 2, n_rows * 2))
        axes = axes.flatten()
        for idx, (img, lbl) in enumerate(zip(samples, sample_labels.cpu())):
            axes[idx].imshow(img.squeeze(0), cmap='gray')
            axes[idx].axis('off')
            axes[idx].set_title(str(lbl.item()), fontsize=8, pad=2)  # requested class shown as the subplot title

        # turn off any empty subplots
        for ax in axes[len(samples):]:
            ax.axis('off')

        plt.suptitle(f'Samples at Epoch {epoch + 1}', y=0.92)
        plt.tight_layout()
        plt.savefig(os.path.join(plot_dir, f'samples_epoch_{epoch + 1:02d}.png'),
                    dpi=300, bbox_inches='tight')
        plt.close()

        # One checkpoint per epoch; the zero-padded epoch number keeps file names sortable.
        torch.save(ddpm.state_dict(), os.path.join(
            weight_dir, f'ddpm_epoch_{epoch + 1:02d}.pt'))
except KeyboardInterrupt:
    tqdm.write("Training interrupted by user.")
finally:
    pbar.close()

    # ---------- Loss plots ----------
    # Per-iteration loss curve.
    plt.figure(figsize=(8, 4))
    plt.plot(train_batch_losses, alpha=0.8, label='Batch loss')
    plt.xlabel('Iteration')
    plt.ylabel('MSE loss')
    plt.legend()
    plt.tight_layout()
    plt.savefig(os.path.join(plot_dir, 'batch_loss.png'), dpi=300)
    plt.close()

    # Per-epoch average loss curve.
    plt.figure(figsize=(8, 4))
    plt.plot(train_epoch_losses, linewidth=2, label='Epoch loss')
    plt.xlabel('Epoch')
    plt.ylabel('MSE loss')
    plt.legend()
    plt.tight_layout()
    plt.savefig(os.path.join(plot_dir, 'epoch_loss.png'), dpi=300)
    plt.close()

### 3.4 Experiment with different guidance scales

In [None]:
@torch.inference_mode()
def experiment_guidance_scales(
    model_dir: str,
    img_wh: tuple[int, int] = (28, 28),
    num_ts: int = 300,
    num_classes: int = 10,
    num_samples_per_class: int = 1,
    guidance_scales: list[float] | None = None,
    output_dir: str = "guidance_experiment_outputs"
):
    """Sample every class at several CFG scales and save one image grid per scale.

    Loads the last checkpoint (by sorted file name) from ``<model_dir>/weights``
    into a ConditionalDenoisingUNet with 128 hidden channels, then for each
    guidance scale samples ``num_samples_per_class`` images per class and
    writes ``guidance_<scale>.png`` into ``output_dir``.

    Args:
        model_dir: run directory containing a ``weights`` sub-directory.
        img_wh: spatial size passed to the sampler.
        num_ts: number of diffusion timesteps the model was trained with.
        num_classes: number of classes of the conditional model.
        num_samples_per_class: images sampled per class and scale.
        guidance_scales: CFG scales to try; defaults to [0, 1, 5, 7, 10, 15].
        output_dir: directory for the saved grids (created if missing).

    Returns:
        dict mapping each guidance scale to ``(images, labels)``, where
        ``images`` is a CPU tensor of samples rescaled and clamped to [0, 1],
        ordered class by class, and ``labels`` is the matching list of ints.

    Raises:
        FileNotFoundError: if the weights directory contains no files.
    """
    if guidance_scales is None:
        # Built per call instead of using a shared mutable default argument.
        guidance_scales = [0, 1, 5, 7, 10, 15]

    os.makedirs(output_dir, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load model
    unet = ConditionalDenoisingUNet(in_channels=1, num_classes=num_classes, num_hiddens=128)
    ddpm_model = DDPM(unet, num_ts=num_ts)
    weight_dir = os.path.join(model_dir, "weights")
    checkpoints = sorted(os.listdir(weight_dir))
    if not checkpoints:
        raise FileNotFoundError(f"No checkpoints found in {weight_dir!r}.")
    latest_ckpt = checkpoints[-1]
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
    state_dict = torch.load(os.path.join(weight_dir, latest_ckpt), map_location=device)
    ddpm_model.load_state_dict(state_dict)
    ddpm_model = ddpm_model.to(device)
    ddpm_model.eval()

    all_samples = {}

    for g_scale in tqdm(guidance_scales, desc="Guidance scales"):
        images = []
        labels = []

        for digit in range(num_classes):
            class_batch = torch.full((num_samples_per_class,), digit, dtype=torch.long, device=device)
            sample = ddpm_model.sample(c=class_batch, img_wh=img_wh, guidance_scale=g_scale)
            # Map from [-1, 1] to [0, 1] and clip anything outside that range.
            sample = torch.clamp((sample + 1) / 2, 0.0, 1.0)
            images.append(sample)
            labels += [digit] * num_samples_per_class

        images = torch.cat(images, dim=0)

        # `images` is ordered class by class. For the grid, reorder to sample-major so
        # that each row holds one sample of every class (identical when there is a
        # single sample per class).
        grid_images = (
            images.reshape(num_classes, num_samples_per_class, *images.shape[1:])
            .transpose(0, 1)
            .reshape(-1, *images.shape[1:])
        )
        grid = torchvision.utils.make_grid(grid_images, nrow=num_classes, padding=2)

        plt.figure(figsize=(num_classes * 1.5, 2 * num_samples_per_class))
        plt.imshow(grid.permute(1, 2, 0).cpu())  # CHW -> HWC for matplotlib
        plt.axis("off")
        plt.title(f"Guidance scale = {g_scale}")
        plt.savefig(os.path.join(output_dir, f"guidance_{g_scale:.1f}.png"), dpi=300, bbox_inches='tight')
        plt.close()

        all_samples[g_scale] = (images.cpu(), labels)

    return all_samples


# Run the guidance-scale sweep and write the grids to ./guidance_grids.
# NOTE(review): "runs/cond8049" looks like the id of one specific earlier training run —
# point model_dir at your own run directory (the one that contains `weights/`).
samples = experiment_guidance_scales(
        model_dir="runs/cond8049",
        output_dir="guidance_grids"
    )
