# Neural Graphics Ex1: Training Your Own Diffusion Model!

## Setup environment

In [37]:
# We recommend using these utils.
# https://google.github.io/mediapy/mediapy.html
# https://einops.rocks/
# !pip install mediapy einops --quiet

In [38]:
# Import essential modules. Feel free to add whatever you need.
import torch
import matplotlib.pyplot as plt
from sympy.polys.polyoptions import Expand
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import torch.nn.functional as F
from tqdm import tqdm

print(torch.__version__)
print(torch.cuda.is_available())

2.7.1+cu118
True


### Seed your work
To be able to reproduce your code, please use a random seed from this point onward.

In [39]:
def seed_everything(seed):
    torch.cuda.manual_seed(seed)
    torch.manual_seed(seed)


YOUR_SEED = 180  # modify if you want
seed_everything(YOUR_SEED)

## 1. Basic Ops and UNet blocks
**Notations:**  
 * `Conv2D(kernel_size, stride, padding)` is `nn.Conv2d()`  
 * `BN` is `nn.BatchNorm2d()`  
 * `GELU` is `nn.GELU()`  
 * `ConvTranspose2D(kernel_size, stride, padding)` is `nn.ConvTranspose2d()`  
 * `AvgPool(kernel_size)` is `nn.AvgPool2d()`  
 * `Linear` is `nn.Linear()`  
 * `N`, `C`, `W` and `H` are batch size, channels num, weight and height respectively


### Basic Ops

In [40]:
class Conv(nn.Module):
    """
    A convolutional layer that doesn’t change the image
    resolution, only the channel dimension
    Applies nn.Conv2d(3, 1, 1) followed by BN and GELU.
    """

    def __init__(self, in_channels: int, out_channels: int):
        """
        Initializes the Conv layer
        Args:
            in_channels (int): The number of input channels
            out_channels (int): The number of output channels
        """
        super().__init__()
        self.conv2d = nn.Conv2d(kernel_size=3,
                                padding=1,
                                stride=1,
                                in_channels=in_channels,
                                out_channels=out_channels)
        self.bn = nn.BatchNorm2d(out_channels)
        self.gelu = nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (N, in_channels, H, W) input tensor.

        Returns:
            (N, out_channels, H, W) output tensor.
        """
        x = self.conv2d(x)
        x = self.bn(x)
        x = self.gelu(x)
        return x


class DownConv(nn.Module):
    """
        A convolutional layer down-samples the tensor by 2.
        The layer consists of Conv2D(3, 2, 1) followed by BN and GELU.
    """

    def __init__(self, in_channels: int, out_channels: int):
        """
        Initializes the DownConv layer
        Args:
            in_channels (int): The number of input channels
            out_channels (int): The number of output channels
        """
        super().__init__()
        self.conv2d = nn.Conv2d(kernel_size=3,
                                stride=2,
                                padding=1,
                                in_channels=in_channels,
                                out_channels=out_channels)
        self.bn = nn.BatchNorm2d(out_channels)
        self.gelu = nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (N, in_channels, H, W) input tensor.

        Returns:
            (N, out_channels, H/2, W/2) output tensor.
        """
        x = self.conv2d(x)
        x = self.bn(x)
        x = self.gelu(x)
        return x


class UpConv(nn.Module):
    """
    A convolutional layer that upsamples the tensor by 2.
    The layer consists of ConvTranspose2d(4, 2, 1) followed by
    BN and GELU.
    """

    def __init__(self, in_channels: int, out_channels: int):
        """
        Initializes the UpConv layer
        Args:
            in_channels (int): The number of input channels
            out_channels (int): The number of output channels
        """
        super().__init__()
        self.convTranspose2d = nn.ConvTranspose2d(kernel_size=4,
                                                  stride=2,
                                                  padding=1,
                                                  in_channels=in_channels,
                                                  out_channels=out_channels)
        self.bn = nn.BatchNorm2d(out_channels)
        self.gelu = nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (N, in_channels, H, W) input tensor.

        Returns:
            (N, out_channels, H*2, W*2) output tensor.
        """
        x = self.convTranspose2d(x)
        x = self.bn(x)
        x = self.gelu(x)
        return x


class Flatten(nn.Module):
    """
    Average pooling layer that flattens a 7x7 tensor into a 1x1 tensor.
    The layer consists of AvgPool followed by GELU.
    """

    def __init__(self):
        super().__init__()
        self.avgpool = nn.AvgPool2d(7)
        self.gelu = nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (N, C, 7, 7) input tensor.

        Returns:
            (N, C, 1, 1) output tensor.
        """
        x = self.avgpool(x)
        x = self.gelu(x)
        return x


class Unflatten(nn.Module):
    """
      Convolutional layer that expends/up-samples a 1x1 tensor into a
      7x7 tensor. The layer consists of ConvTranspose2D(7, 7, 0)
      followed by BN and GELU.
    """

    def __init__(self, in_channels: int):
        """
        Initializes Unflatten layer
        Args:
            in_channels (int): The number of input channels
        """
        super().__init__()
        self.convtranspose2d = nn.ConvTranspose2d(kernel_size=7,
                                                  stride=7,
                                                  padding=0,
                                                  in_channels=in_channels,
                                                  out_channels=in_channels)
        self.bn = nn.BatchNorm2d(in_channels)
        self.gelu = nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (N, in_channels, 1, 1) input tensor.

        Returns:
            (N, in_channels, 7, 7) output tensor.
        """
        x = self.convtranspose2d(x)
        x = self.bn(x)
        x = self.gelu(x)
        return x


class FC(nn.Module):
    """
    Fully connected layer, consisting of nn.linear followed by GELU.
    """

    def __init__(self, in_channels: int, out_channels: int):
        """
        Initializes the FC layer
        Args:
            in_channels (int): The number of input channels
            out_channels (int): The number of output channels
        """
        super().__init__()
        self.fc = nn.Linear(in_channels, out_channels)
        self.gelu = nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (N, in_channels) input tensor.

        Returns:
            (N, out_channels) output tensor.
        """
        x = self.fc(x)
        x = self.gelu(x)
        return x



### UNet Blocks

In [41]:
class ConvBlock(nn.Module):
    """
    Two consecutive Conv operations.
    Note that it has the same input and output shape as Conv.
    """

    def __init__(self, in_channels: int, out_channels: int):
        """
        Initializes ConvBlock
        Args:
            in_channels (int): The number of input channels
            out_channels (int): The number of output channels
        """
        super().__init__()
        self.conv1 = Conv(in_channels, out_channels)
        self.conv2 = Conv(out_channels, out_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (N, in_channels, H, W) input tensor.

        Returns:
            (N, out_channels, H, W) output tensor.
        """
        x = self.conv1(x)
        x = self.conv2(x)
        return x


class DownBlock(nn.Module):
    """
    DownConv followed by ConvBlock. Note that it has the same input and output
    shape as DownConv.
    """

    def __init__(self, in_channels: int, out_channels: int):
        """
        Initializes DownBlock
        Args:
            in_channels (int): The number of input channels
            out_channels (int): The number of output channels
        """
        super().__init__()
        self.conv1 = DownConv(in_channels, out_channels)
        self.conv2 = ConvBlock(out_channels, out_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (N, in_channels, H, W) input tensor.

        Returns:
            (N, out_channels, H/2, W/2) output tensor.
        """
        x = self.conv1(x)
        x = self.conv2(x)
        return x


class UpBlock(nn.Module):
    """
    UpConv followed by ConvBlock.
    Note that it has the same input and output shape as UpConv
    """

    def __init__(self, in_channels: int, out_channels: int):
        """
        Initializes UpBlock
        Args:
            in_channels (int): The number of input channels
            out_channels (int): The number of output channels
        """
        super().__init__()
        self.conv1 = UpConv(in_channels, out_channels)
        self.conv2 = ConvBlock(out_channels, out_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (N, in_channels, H, W) input tensor.

        Returns:
            (N, out_channels, H*2, W*2) output tensor.
        """
        x = self.conv1(x)
        x = self.conv2(x)
        return x


class FCBlock(nn.Module):
    """
    Fully-connected Block, consisting of FC layer followed by Linear layer. Note
    that it has the same input and output shape as FC.
    """

    def __init__(self, in_channels: int, out_channels: int):
        """
        Initializes FCBlock
        Args:
            in_channels (int): The number of input channels
            out_channels (int): The number of output channels
        """
        super().__init__()
        self.fc = FC(in_channels, out_channels)
        self.linear = nn.Linear(out_channels, out_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (N, in_channels) input tensor.

        Returns:
            (N, out_channels) output tensor.
        """
        x = self.fc(x)
        x = self.linear(x)
        return x

## 2. Unconditional Diffusion Framework


### 2.1 UNet architecture

In [42]:
class DenoisingUNet(nn.Module):
    def __init__(
            self,
            in_channels: int,  # 1
            num_hiddens: int  # D
    ):
        super().__init__()
        # t-tensor
        self.t_fc4 = FCBlock(in_channels, 2 * num_hiddens)
        self.t_fc2 = FCBlock(in_channels, num_hiddens)

        # In
        self.conv_block_in = ConvBlock(in_channels, num_hiddens)  # (N, D, 28, 28)
        # Down
        self.down_block1 = DownBlock(num_hiddens, num_hiddens)  # (N, D, 14, 14)
        self.down_block2 = DownBlock(num_hiddens, 2 * num_hiddens)  # (N, 2 * D, 7, 7)
        self.flatten = Flatten()  # (N, 2 * D, 1, 1)
        # Up / with skip connections and t-tensor addition
        self.expend = Unflatten(2 * num_hiddens)  # (N, [2 * D + 2 * D], 7, 7)
        self.up_block2 = UpBlock(4 * num_hiddens, num_hiddens)  # (N, [D + D], 14, 14)
        self.up_block1 = UpBlock(2 * num_hiddens, num_hiddens)  # (N, [D + D], 28, 28)
        # Out
        self.conv_block_out = ConvBlock(2 * num_hiddens, num_hiddens)  # (N, D, 28, 28)
        self.conv2d = Conv(num_hiddens, 1)  #  (N, 1, 28, 28)

    def forward(
            self,
            x: torch.Tensor,
            t: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            x: (N, C, H, W) input tensor.
            t: (N, 1) normalized time tensor.

        Returns:
            (N, C, H, W) output tensor.
        """
        assert x.shape[-2:] == (28, 28), "Expect input shape to be (28, 28)."
        # t-tensor
        t_layer4 = self.t_fc4(t).unsqueeze(-1).unsqueeze(-1)
        t_layer2 = self.t_fc2(t).unsqueeze(-1).unsqueeze(-1)

        # Down
        d_layer1 = self.conv_block_in(x)
        d_layer2 = self.down_block1(d_layer1)
        d_layer3 = self.down_block2(d_layer2)

        # Flatten/Extract
        flat = self.flatten(d_layer3)
        expend = self.expend(flat)

        expend += t_layer4  # t-tensor addition layer 4

        # Up
        u_layer2 = self.up_block2(torch.cat((d_layer3, expend), dim=1))  # skip connection layer 3
        u_layer2 += t_layer2  # t-tensor addition layer 2

        u_layer1 = self.up_block1(torch.cat((d_layer2, u_layer2), dim=1))  # skip connection layer 2
        output = self.conv_block_out(torch.cat((d_layer1, u_layer1), dim=1))  # skip connection layer 1
        output = self.conv2d(output)
        return output


data = torch.randn(10, 1, 28, 28)  # batch, channel, height, width
t = torch.randn(10, 1)
model = DenoisingUNet(data.shape[1], 10)
model.forward(x=data, t=t)


tensor([[[[ 0.5900,  0.5139,  0.5533,  ...,  0.1197,  0.5582,  0.2314],
          [ 0.8292,  0.6498,  0.7629,  ...,  0.2697,  0.2929,  0.5280],
          [ 0.4137,  0.1206,  0.3123,  ...,  0.7934,  0.1878,  0.1778],
          ...,
          [ 0.6513, -0.0946, -0.1580,  ..., -0.1355, -0.0608,  0.1265],
          [ 0.1745,  0.3268, -0.1534,  ...,  0.0405,  0.1543,  0.3352],
          [ 0.4307,  0.3598, -0.0025,  ..., -0.1144,  0.1708,  0.3015]]],


        [[[ 0.5589,  0.7514,  0.4481,  ...,  0.4048,  0.4661,  0.2583],
          [ 0.3091,  1.1313,  0.0553,  ...,  0.8931,  0.2321,  0.1722],
          [-0.0697,  0.1088,  0.2607,  ...,  0.6474,  0.2348,  0.0376],
          ...,
          [-0.1251,  0.0180,  0.5869,  ..., -0.1507,  1.1511,  0.4065],
          [ 0.5442,  0.4756,  0.4075,  ...,  1.2067,  0.4984,  0.0849],
          [ 1.0770, -0.1525,  0.0429,  ...,  0.7998,  0.5454,  0.5100]]],


        [[[ 0.4541,  0.2706, -0.1018,  ..., -0.1352,  0.6679, -0.1202],
          [ 0.4426,  0.584

### 2.2 DDPM Forward and Inverse Process


In [43]:
def ddpm_schedule(beta1: float, beta2: float, num_ts: int, device: str = 'cuda') -> dict:
    """Constants for DDPM training and sampling.

    Arguments:
        beta1: float, starting beta value.
        beta2: float, ending beta value.
        num_ts: int, number of timesteps.

    Returns:
        dict with keys:
            betas: linear schedule of betas from beta1 to beta2.
            alphas: 1 - betas.
            alpha_bars: cumulative product of alphas.
    """
    assert beta1 < beta2 < 1.0, "Expect beta1 < beta2 < 1.0."

    beta_seq = torch.linspace(beta1, beta2, num_ts, device=device)
    alpha_seq = 1.0 - beta_seq
    alpha_bar_seq = torch.cumprod(alpha_seq, dim=0)

    return {
        "beta_seq": beta_seq,
        "alpha_seq": alpha_seq,
        "alpha_bar_seq": alpha_bar_seq,
    }


In [44]:
def ddpm_forward(
        unet: DenoisingUNet,
        ddpm_schedule: dict,
        x_0: torch.Tensor,
        num_ts: int,
) -> torch.Tensor:
    """Algorithm 1 of the DDPM paper (not including gradient step).

    Args:
        unet: DenoisingUNet
        ddpm_schedule: dict
        x_0: (N, C, H, W) input tensor.
        num_ts: int, number of timesteps.
    Returns:
        (,) diffusion loss.
    """
    unet.train()
    t = torch.randint(low=0, high=num_ts, size=(x_0.shape[0],),
                      device=x_0.device)  # Uniform distribution for each sample
    epsilon = torch.randn_like(x_0)  # Normal distribution with mean 0 and variance 1

    alpha_bar_seq = ddpm_schedule["alpha_bar_seq"]
    alpha_bar_t = alpha_bar_seq[t].view(-1, 1, 1, 1)

    x_t = torch.sqrt(alpha_bar_t) * x_0 + torch.sqrt(1 - alpha_bar_t) * epsilon
    t_normalized = t.view(-1, 1).float() / num_ts  # normalize t to [0, 1]

    epsilon_roof = unet(x_t, t_normalized)

    return F.mse_loss(epsilon_roof, epsilon)



In [45]:
@torch.inference_mode()
def ddpm_sample(
        unet: DenoisingUNet,
        ddpm_schedule: dict,
        img_wh: tuple[int, int],
        batch_size: int,
        num_ts: int,
        device: str = 'cuda'
) -> torch.Tensor:
    """Algorithm 2 of the DDPM paper.

    Args:
        unet: DenoisingUNet
        ddpm_schedule: dict
        img_wh: (H, W) output image width and height.
        num_ts: int, number of timesteps.

    Returns:
        (N, C, H, W) final sample.
    """
    unet.eval()
    x_t = torch.randn(batch_size, 1, img_wh[0], img_wh[1],
                      device=device)  # Normal distribution with mean 0 and variance 1

    beta_seq = ddpm_schedule["beta_seq"]
    alpha_seq = ddpm_schedule["alpha_seq"]
    alpha_bar_seq = ddpm_schedule["alpha_bar_seq"]

    for t in reversed(range(1, num_ts)):
        t_idx = torch.full((batch_size,), t, device=device, dtype=torch.long)
        t_normalized = t_idx.float().unsqueeze(1) / num_ts  # (N, 1)

        # Schedule scalars for batch
        beta_t = beta_seq[t_idx].view(-1, 1, 1, 1)
        alpha_t = alpha_seq[t_idx].view(-1, 1, 1, 1)
        alpha_bar_t = alpha_bar_seq[t_idx].view(-1, 1, 1, 1)

        if t > 1:
            alpha_bar_prev = alpha_seq[t_idx - 1].view(-1, 1, 1, 1)
            z = torch.randn_like(x_t)
        else:
            alpha_bar_prev = torch.ones_like(alpha_bar_t)
            z = torch.zeros_like(x_t)

        epsilon_roof = unet(x_t, t_normalized)
        x_hat0 = (x_t - torch.sqrt(1 - alpha_bar_t) * epsilon_roof) / torch.sqrt(alpha_bar_t)

        x_t = (torch.sqrt(alpha_bar_prev) * beta_t / (1 - alpha_bar_t) * x_hat0 +
               torch.sqrt(alpha_t) * (1 - alpha_bar_prev) / (1 - alpha_bar_t) * x_t +
               torch.sqrt(beta_t) * z
               )

    return x_t


In [46]:
# Do Not Modify
class DDPM(nn.Module):
    def __init__(
            self,
            unet: DenoisingUNet,
            betas: tuple[float, float] = (1e-4, 0.02),
            num_ts: int = 300,
            p_uncond: float = 0.1,
    ):
        super().__init__()
        self.unet = unet
        self.betas = betas
        self.num_ts = num_ts
        self.p_uncond = p_uncond
        self.ddpm_schedule = ddpm_schedule(betas[0], betas[1], num_ts)

        for k, v in ddpm_schedule(betas[0], betas[1], num_ts).items():
            self.register_buffer(k, v, persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (N, C, H, W) input tensor.

        Returns:
            (,) diffusion loss.
        """
        return ddpm_forward(
            self.unet, self.ddpm_schedule, x, self.num_ts
        )

    @torch.inference_mode()
    def sample(
            self,
            img_wh: tuple[int, int],
            batch_size: int
    ):
        return ddpm_sample(
            self.unet, self.ddpm_schedule, img_wh, batch_size, self.num_ts
        )

### 2.3 Train your denoiser

In [None]:
import os
import random
import uuid
from pathlib import Path
from typing import Tuple

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from torchvision.utils import make_grid, save_image
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

# ──────────────────────────────── Hyper-parameters ──────────────────────────────
run_name = "uncond_ddpm_mnist"
num_hidden = 256
batch_size = 128
num_epochs = 30
lr = 2e-4
gamma = 0.1 ** (1.0 / num_epochs)  # exponential LR decay
img_wh: Tuple[int, int] = (28, 28)
eval_batch_size = 20
T = 300  # diffusion steps
checkpoint_dir = Path("checkpoints")
checkpoint_dir.mkdir(exist_ok=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ──────────────────────────────── Data loaders ──────────────────────────────────
train_data = MNIST(root="./data", train=True, download=True, transform=ToTensor())
test_data = MNIST(root="./data", train=False, download=True, transform=ToTensor())

train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, drop_last=True)
eval_loader = DataLoader(test_data, batch_size=eval_batch_size, shuffle=True, drop_last=True)

# ──────────────────────────────── Model & DDPM ──────────────────────────────────
denoiser = DenoisingUNet(in_channels=1, num_hiddens=num_hidden)
ddpm = DDPM(denoiser, num_ts=T).to(device)

optimizer = torch.optim.Adam(ddpm.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma)

# ──────────────────────────────── Tracking buffers ──────────────────────────────
batch_losses = []
epoch_losses = []

# ──────────────────────────────── Training loop ─────────────────────────────────
for epoch in range(num_epochs):
    ddpm.train()
    running = 0.0

    pbar = tqdm(enumerate(train_loader, 1),
                total=len(train_loader),
                desc=f"Epoch {epoch + 1:02d}/{num_epochs:02d}")

    for step, (data, _labels) in pbar:
        data = data.to(device)
        optimizer.zero_grad()

        loss = ddpm(data)  # forward & internally samples t
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        running += batch_loss
        batch_losses.append(batch_loss)

        pbar.set_postfix(loss=f"{batch_loss:.4f}")

    avg_epoch_loss = running / len(train_loader)
    epoch_losses.append(avg_epoch_loss)
    print(f"Epoch {epoch + 1:02d}: mean loss = {avg_epoch_loss:.4f}")

    scheduler.step()

    # ───────────────────────── evaluation ─────────────────────────
    ddpm.eval()
    with torch.no_grad():
        noise_batch = next(iter(eval_loader))[0].to(device)  # just to get batch size/shape
        samples = ddpm.sample(img_wh=img_wh,
                              batch_size=eval_batch_size,
                              device=device)

    # grid & save
    grid = make_grid(samples, nrow=int(eval_batch_size ** 0.5), normalize=True, value_range=(0, 1))
    save_path = Path(f"samples_epoch{epoch + 1:02d}.png")
    save_image(grid, save_path)
    print(f"Saved sample grid to {save_path}")

    # ───────────────────────── checkpoint ────────────────────────
    random_id = random.randint(1000, 9999)
    ckpt_name = f"{run_name}-{random_id:04d}-epoch{epoch + 1:02d}.pt"
    torch.save({
        "epoch": epoch + 1,
        "model_state": ddpm.state_dict(),
        "opt_state": optimizer.state_dict(),
        "scheduler_state": scheduler.state_dict(),
        "batch_losses": batch_losses,
        "epoch_losses": epoch_losses,
    }, checkpoint_dir / ckpt_name)
    print(f"Checkpoint saved as {ckpt_name}")

# ──────────────────────────────── Plot losses ───────────────────────────────────
plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)
plt.plot(batch_losses, label="batch loss")
plt.title("Batch loss")
plt.xlabel("iteration")
plt.ylabel("MSE")
plt.grid(True);
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(epoch_losses, marker="o", label="epoch loss")
plt.title("Epoch loss")
plt.xlabel("epoch")
plt.ylabel("MSE")
plt.grid(True);
plt.legend()
plt.savefig("loss_curves.png")
print("Loss curves saved to loss_curves.png")
plt.show()


Epoch 01/30:   0%|          | 0/468 [00:00<?, ?it/s]

## 3. Implementing class-conditioned diffusion framework with CFG


###3.1 Adding Class-Conditioning to UNet architecture

In [None]:
class ConditionalDenoisingUNet(nn.Module):
    def __init__(
            self,
            in_channels: int,
            num_classes: int,
            num_hiddens: int,
    ):
        super().__init__()
        # YOUR CODE HERE.

    def forward(
            self,
            x: torch.Tensor,
            c: torch.Tensor,
            t: torch.Tensor,
            mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        Args:
            x: (N, C, H, W) input tensor.
            c: (N, num_classes) float condition tensor.
            t: (N, 1) normalized time tensor.
            mask: (N, 1) mask tensor. If not None, mask out condition when mask == 0.

        Returns:
            (N, C, H, W) output tensor.
        """
        assert x.shape[-2:] == (28, 28), "Expect input shape to be (28, 28)."
        # YOUR CODE HERE.
        raise NotImplementedError()

###3.2 DDPM Forward and Inverse Process with CFG

In [None]:
def ddpm_forward(
        unet: ConditionalDenoisingUNet,
        ddpm_schedule: dict,
        x_0: torch.Tensor,
        c: torch.Tensor,
        p_uncond: float,
        num_ts: int,
) -> torch.Tensor:
    """Algorithm 3 (not including gradient step).

    Args:
        unet: ConditionalDenoisingUNet
        ddpm_schedule: dict
        x_0: (N, C, H, W) input tensor.
        c: (N,) int64 condition tensor.
        p_uncond: float, probability of unconditioning the condition.
        num_ts: int, number of timesteps.

    Returns:
        (,) diffusion loss.
    """
    unet.train()
    # YOUR CODE HERE.
    raise NotImplementedError()

In [None]:
@torch.inference_mode()
def ddpm_cfg_sample(
        unet: ConditionalDenoisingUNet,
        ddpm_schedule: dict,
        c: torch.Tensor,
        img_wh: tuple[int, int],
        num_ts: int,
        guidance_scale: float = 5.0
) -> torch.Tensor:
    """Algorithm 4.

    Args:
        unet: ConditionalDenoisingUNet
        ddpm_schedule: dict
        c: (N,) int64 condition tensor. Only for class-conditional
        img_wh: (H, W) output image width and height.
        num_ts: int, number of timesteps.
        guidance_scale: float, CFG scale.

    Returns:
        (N, C, H, W) final sample.
    """
    unet.eval()
    # YOUR CODE HERE.
    raise NotImplementedError()

In [None]:
# Do Not Modify
class DDPM(nn.Module):
    def __init__(
            self,
            unet: ConditionalDenoisingUNet,
            betas: tuple[float, float] = (1e-4, 0.02),
            num_ts: int = 300,
            p_uncond: float = 0.1,
    ):
        super().__init__()
        self.unet = unet
        self.betas = betas
        self.num_ts = num_ts
        self.p_uncond = p_uncond
        self.ddpm_schedule = ddpm_schedule(betas[0], betas[1], num_ts)

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (N, C, H, W) input tensor.
            c: (N,) int64 condition tensor.

        Returns:
            (,) diffusion loss.
        """
        return ddpm_forward(
            self.unet, self.ddpm_schedule, x, c, self.p_uncond, self.num_ts
        )

    @torch.inference_mode()
    def sample(
            self,
            c: torch.Tensor,
            img_wh: tuple[int, int],
            guidance_scale: float = 5.0
    ):
        return ddpm_cfg_sample(
            self.unet, self.ddpm_schedule, c, img_wh, self.num_ts, guidance_scale
        )

###3.3 Train your class-conditioned denoiser

In [None]:
# YOUR CODE HERE.

###3.4 Experiment with different guidance sacles

In [None]:
# YOUR CODE HERE.
