# MP5: Training Your Diffusion Model!

## Setup environment

In [None]:
# Import essential modules. Feel free to add whatever you need.
import matplotlib.pyplot as plt
import torch # added
from torch import optim # added
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from torchvision import transforms

## Visualization helper function

In [None]:
def visualize_images_with_titles(images: torch.Tensor, column_names: list[str]):
    """
    Visualize images as a grid and title the columns with the provided names.

    Images fill the grid row by row: image ``i`` is drawn in row
    ``i // len(column_names)`` and column ``i % len(column_names)``. Only the
    first row receives titles.

    Args:
        images: (N, C, H, W) tensor of images, where N is (number of rows * number of columns)
        column_names: Non-empty list of column names for the titles.

    Raises:
        AssertionError: if ``column_names`` is empty or N is not a multiple of
            ``len(column_names)``.

    Example usage:
    visualize_images_with_titles(torch.randn(16, 1, 32, 32), ['1', '2', '3', '4'])
    """
    num_images, num_columns = images.shape[0], len(column_names)
    assert num_columns > 0, 'At least one column name is required.'
    assert num_images % num_columns == 0, 'Number of images must be a multiple of the number of columns.'

    num_rows = num_images // num_columns
    # squeeze=False keeps `axes` a 2-D array for every grid size; for a 1x1
    # grid plt.subplots would otherwise return a bare Axes without `.flat`.
    fig, axes = plt.subplots(
        num_rows, num_columns, figsize=(num_columns * 1, num_rows * 1), squeeze=False
    )

    for i, ax in enumerate(axes.flat):
        # CHW -> HWC for imshow; detach so tensors that require grad also work.
        img = images[i].permute(1, 2, 0).detach().cpu().numpy()
        ax.imshow(img, cmap='gray')
        ax.axis('off')
        if i < num_columns:
            ax.set_title(column_names[i])

    plt.tight_layout()
    plt.show()


# Part 1: Training a Single-step Denoising UNet


## Implementing Simple and Composed Ops

In [None]:
class Conv(nn.Module):
    """3x3 convolution (stride 1, padding 1) -> BatchNorm -> GELU.

    The padding keeps the spatial size unchanged; only the channel count
    goes from ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        # Submodules are registered in the same order as before so that
        # weight initialisation draws from the RNG identically.
        self.conv = nn.Conv2d(in_channels, out_channels, 3, stride=1, padding=1)
        self.bn = nn.BatchNorm2d(out_channels)
        self.gelu = nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        features = self.conv(x)
        normalized = self.bn(features)
        return self.gelu(normalized)


class DownConv(nn.Module):
    """Strided 3x3 convolution -> BatchNorm -> GELU that halves H and W.

    With kernel 3, stride 2 and padding 1, a 28x28 input becomes 14x14 and a
    14x14 input becomes 7x7.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        # Registration order kept (conv, bn, gelu) for identical initialisation.
        self.conv = nn.Conv2d(in_channels, out_channels, 3, stride=2, padding=1)
        self.bn = nn.BatchNorm2d(out_channels)
        self.gelu = nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        downsampled = self.conv(x)
        return self.gelu(self.bn(downsampled))


class UpConv(nn.Module):
    """Transposed 3x3 convolution -> BatchNorm -> GELU that doubles H and W.

    Stride 2, padding 1 and output_padding 1 give an output exactly twice
    the input size (7 -> 14, 14 -> 28).
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        # Registration order kept (conv, bn, gelu) for identical initialisation.
        self.conv = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=1,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.gelu = nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        upsampled = self.conv(x)
        return self.gelu(self.bn(upsampled))


class Flatten(nn.Module):
    """UNet bottleneck "flatten": 7x7 average pooling.

    Despite the name, no reshape happens. A 7x7 feature map is averaged down
    to 1x1 and the channel dimension is left unchanged.
    """

    def __init__(self):
        super().__init__()
        self.avg_pool = nn.AvgPool2d(7)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        pooled = self.avg_pool(x)
        return pooled


class Unflatten(nn.Module):
    """Inverse of `Flatten`: expand a 1x1 map back to 7x7.

    Uses a 7x7 transposed convolution (stride 1, no padding), then
    BatchNorm and GELU. The channel count stays ``in_channels``.
    """

    def __init__(self, in_channels: int):
        super().__init__()
        channels = in_channels
        # Registration order kept (conv, bn, gelu) for identical initialisation.
        self.conv = nn.ConvTranspose2d(channels, channels, kernel_size=7, stride=1, padding=0)
        self.bn = nn.BatchNorm2d(channels)
        self.gelu = nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        expanded = self.conv(x)
        return self.gelu(self.bn(expanded))


class ConvBlock(nn.Module):
    """Two stacked `Conv` layers.

    The first maps ``in_channels`` -> ``out_channels`` and the second keeps
    ``out_channels``. Spatial size is preserved.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.conv1 = Conv(in_channels, out_channels)
        self.conv2 = Conv(out_channels, out_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv2(self.conv1(x))


class DownBlock(nn.Module):
    """Encoder stage: `DownConv` (halves H and W) followed by a `ConvBlock`."""

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.downconv = DownConv(in_channels, out_channels)
        self.convblock = ConvBlock(out_channels, out_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.convblock(self.downconv(x))


class UpBlock(nn.Module):
    """Decoder stage: `UpConv` (doubles H and W) followed by a `ConvBlock`.

    ``in_channels`` must already include any skip connection that the
    caller concatenated onto the input.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.upconv = UpConv(in_channels, out_channels)
        self.convblock = ConvBlock(out_channels, out_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.convblock(self.upconv(x))

## Implementing Unconditional UNet

In [None]:
class UnconditionalUNet(nn.Module):
    """Small UNet for 28x28 images used as a single-step denoiser.

    - Encoder: `Conv` then two `DownBlock`s (28 -> 14 -> 7 spatially; D -> 2D -> 4D channels).
    - Bottleneck: `Flatten` pools 7x7 -> 1x1 and `Unflatten` restores 7x7.
    - Decoder: two `UpBlock`s (7 -> 14 -> 28). Each takes the previous output
      concatenated with the matching encoder skip.
    - Output: a 1x1 conv maps the last concatenation back to ``in_channels``.

    D is ``num_hiddens``.
    """

    def __init__(
        self,
        in_channels: int,
        num_hiddens: int,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.num_hiddens = num_hiddens
        d = num_hiddens

        # Submodules are created in the original order so that parameter
        # initialisation consumes the RNG identically.
        self.initial_conv = Conv(in_channels, d)
        self.down1 = DownBlock(d, 2 * d)
        self.down2 = DownBlock(2 * d, 4 * d)
        self.flatten = Flatten()
        self.unflatten = Unflatten(4 * d)
        self.up1 = UpBlock(8 * d, 2 * d)  # bottleneck (4D) ++ down2 skip (4D)
        self.up2 = UpBlock(4 * d, d)      # up1 output (2D) ++ down1 skip (2D)
        self.final_conv = nn.Conv2d(2 * d, in_channels, kernel_size=1)  # up2 (D) ++ initial skip (D)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        assert x.shape[-2:] == (28, 28), "Expect input shape to be (28, 28)."
        skip0 = self.initial_conv(x)
        skip1 = self.down1(skip0)
        skip2 = self.down2(skip1)

        hidden = self.unflatten(self.flatten(skip2))
        hidden = self.up1(torch.cat([hidden, skip2], dim=1))
        hidden = self.up2(torch.cat([hidden, skip1], dim=1))
        return self.final_conv(torch.cat([hidden, skip0], dim=1))

## Visualizing the noising process

In [None]:
def add_noise(images, sigma):
    """Return ``images`` corrupted by i.i.d. Gaussian noise of std ``sigma``.

    Args:
        images: Clean images tensor of shape (N, C, H, W); any shape works.
        sigma: Standard deviation of the added noise.

    Returns:
        A new tensor of the same shape: ``images + sigma * N(0, 1)``.
        The input is not modified.
    """
    return images + sigma * torch.randn_like(images)


# Preview: the first `num_examples` training digits, each corrupted at every
# noise level in `sigmas` (one row per digit, one column per noise level).
dataset = MNIST(root="data", download=True, transform=ToTensor(), train=True)
sigmas = [0.0, 0.2, 0.4, 0.5, 0.6, 0.8, 1.0]
num_examples = 5

n_levels = len(sigmas)
plt.figure(figsize=(15, 7))
for row in range(num_examples):
    digit = dataset[row][0].unsqueeze(0)  # add a leading batch dimension
    for col, level in enumerate(sigmas):
        corrupted = add_noise(digit, level)
        plt.subplot(num_examples, n_levels, row * n_levels + col + 1)
        plt.imshow(corrupted[0, 0].numpy(), cmap='gray')
        if row == 0:
            plt.title(f"σ = {level}")
        plt.axis('off')

plt.tight_layout()
plt.suptitle("Varying noise levels on MNIST digits", y=1.02)
plt.show()

## Training a Single-Step Unconditional UNet

- Plot the loss curve
- Sample results on the test set

In [None]:
def train_unconditional_unet(num_epochs=5, sigma=0.5, batch_size=256, num_hiddens=128, lr=1e-4):
    """Train a single-step denoising UNet on MNIST at a fixed noise level.

    The model sees ``add_noise(clean, sigma)`` and is trained with MSE to
    reproduce the clean image. After the first and the last epoch, ten test
    digits are shown as rows of (clean, noisy, denoised) triplets. After
    training, the per-batch loss curve is plotted on a log scale.

    Args:
        num_epochs: Number of passes over the MNIST training set.
        sigma: Noise standard deviation used for training and for the previews.
        batch_size: Batch size for the train and test loaders.
        num_hiddens: Base channel width of the UNet (default 128, as before).
        lr: Adam learning rate (default 1e-4, as before).

    Returns:
        The trained UnconditionalUNet.
    """
    # Data loaders
    train_dataset = MNIST(root='data', download=True, transform=ToTensor(), train=True)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    test_dataset = MNIST(root='data', download=True, transform=ToTensor(), train=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = UnconditionalUNet(in_channels=1, num_hiddens=num_hiddens).to(device)

    optimizer = optim.Adam(model.parameters(), lr=lr)

    batch_losses = []

    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0.0

        for batch_idx, (clean_images, _) in enumerate(train_loader):
            clean_images = clean_images.to(device)

            noisy_images = add_noise(clean_images, sigma)
            denoised_images = model(noisy_images)
            loss = F.mse_loss(denoised_images, clean_images)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_losses.append(loss.item())
            epoch_loss += loss.item()

            if batch_idx % 100 == 0:
                print(f'Epoch: {epoch+1}/{num_epochs}, Batch: {batch_idx}/{len(train_loader)}, Loss: {loss.item():.6f}')

        avg_epoch_loss = epoch_loss / len(train_loader)
        print(f'Epoch: {epoch+1}/{num_epochs}, Average Loss: {avg_epoch_loss:.6f}')

        if epoch == 0 or epoch == num_epochs - 1:
            model.eval()
            with torch.no_grad():
                clean_images, _ = next(iter(test_loader))
                clean_images = clean_images[:10].to(device)
                noisy_images = add_noise(clean_images, sigma)
                denoised_images = model(noisy_images)

            # visualize_images_with_titles fills its grid row by row. Interleave
            # the three tensors so each row is one (clean, noisy, denoised)
            # triplet; a plain dim-0 concat would spread each group across rows.
            comparison = torch.stack(
                [clean_images, noisy_images, denoised_images], dim=1
            ).flatten(0, 1)

            column_names = ['Input', f'Noisy (σ={sigma})', 'Output']
            visualize_images_with_titles(comparison.cpu(), column_names)

    # Per-batch training loss, log scale.
    plt.figure(figsize=(10, 5))
    plt.plot(batch_losses)
    plt.title('Training Losses')
    plt.xlabel('Step')
    plt.ylabel('Loss')
    plt.yscale('log')
    plt.grid(True)
    plt.show()

    return model

In [None]:
# Train the Part 1 single-step denoiser at sigma = 0.5 for five epochs.
unconditional_model = train_unconditional_unet(sigma=0.5, num_epochs=5)

## Out-of-Distribution Testing

In [None]:
def test_out_of_distribution(model, sigmas=(0.0, 0.2, 0.4, 0.5, 0.6, 0.8, 1.0), num_samples=10):
    """Test the model on different noise levels.

    Draws one random batch of MNIST test digits and corrupts it at every level
    in ``sigmas`` with `add_noise`. For each level, plots one row of noisy
    inputs followed by one row of model outputs.

    Args:
        model: Denoiser called as ``model(noisy_images)``; it is put in eval mode.
        sigmas: Noise standard deviations to evaluate. Any sequence works; the
            default is a tuple to avoid a shared mutable default.
        num_samples: Number of test digits per row (default 10, as before).
    """
    sigmas = list(sigmas)
    test_dataset = MNIST(root='data', download=True, transform=ToTensor(), train=False)
    test_loader = DataLoader(test_dataset, batch_size=num_samples, shuffle=True)
    clean_images, _ = next(iter(test_loader))

    device = next(model.parameters()).device
    clean_images = clean_images.to(device)

    model.eval()
    num_samples = len(clean_images)
    num_sigmas = len(sigmas)

    # squeeze=False keeps `axes` 2-D even when num_samples == 1.
    _, axes = plt.subplots(
        num_sigmas * 2, num_samples,
        figsize=(num_samples * 1.2, num_sigmas * 2.4),
        squeeze=False,
    )
    for i, sigma in enumerate(sigmas):
        with torch.no_grad():
            noisy_images = add_noise(clean_images, sigma)
            denoised_images = model(noisy_images)

        for j in range(num_samples):
            ax_noisy = axes[i * 2, j]
            ax_noisy.imshow(noisy_images[j, 0].cpu().numpy(), cmap='gray')
            ax_noisy.axis('off')
            if j == 0:
                ax_noisy.set_title(f"Noisy (σ={sigma})")

            ax_denoised = axes[i * 2 + 1, j]
            ax_denoised.imshow(denoised_images[j, 0].cpu().numpy(), cmap='gray')
            ax_denoised.axis('off')
            if j == 0:
                ax_denoised.set_title(f"Denoised (σ={sigma})")

    plt.tight_layout()
    plt.show()

# Evaluate the sigma=0.5 denoiser on unseen noise levels.
test_out_of_distribution(model=unconditional_model)

# Part 2: Training a Diffusion Model

## Implementing a Time-conditioned UNet

In [None]:
class FCBlock(nn.Module):
    """Two-layer MLP (Linear -> GELU -> Linear) used to embed conditioning vectors.

    Maps the last dimension from ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        # Same registration order as before: linear1, gelu, linear2.
        self.linear1 = nn.Linear(in_channels, out_channels)
        self.gelu = nn.GELU()
        self.linear2 = nn.Linear(out_channels, out_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = self.linear1(x)
        hidden = self.gelu(hidden)
        return self.linear2(hidden)


class TimeConditionalUNet(nn.Module):
    """UNet for 28x28 images conditioned on a normalized timestep.

    The architecture matches `UnconditionalUNet`. Two `FCBlock`s embed the
    scalar ``t``: one embedding is added to the unflattened bottleneck (4D
    channels) and the other to the output of ``up1`` (2D channels). D is
    ``num_hiddens``. ``num_classes`` is accepted for signature parity with
    the class-conditioned model but is not used here.
    """

    def __init__(
        self,
        in_channels: int,
        num_classes: int,
        num_hiddens: int,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.num_hiddens = num_hiddens
        d = num_hiddens

        # Construction order matches the original so initialisation is identical.
        self.initial_conv = Conv(in_channels, d)

        self.down1 = DownBlock(d, 2 * d)
        self.down2 = DownBlock(2 * d, 4 * d)

        self.flatten = Flatten()
        self.unflatten = Unflatten(4 * d)
        self.up1 = UpBlock(8 * d, 2 * d)
        self.up2 = UpBlock(4 * d, d)

        self.final_conv = nn.Conv2d(2 * d, in_channels, kernel_size=1)

        self.fc1_t = FCBlock(1, 4 * d)
        self.fc2_t = FCBlock(1, 2 * d)

    def forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            x: (N, C, H, W) input tensor.
            t: (N,) normalized time tensor.

        Returns:
            (N, C, H, W) output tensor.
        """
        assert x.shape[-2:] == (28, 28), "Expect input shape to be (28, 28)."
        t = t.view(-1, 1)
        d = self.num_hiddens

        skip0 = self.initial_conv(x)
        skip1 = self.down1(skip0)
        skip2 = self.down2(skip1)

        # Time embeddings, shaped to broadcast over the spatial dimensions.
        t_bottleneck = self.fc1_t(t).view(-1, 4 * d, 1, 1)
        t_up1 = self.fc2_t(t).view(-1, 2 * d, 1, 1)

        hidden = self.unflatten(self.flatten(skip2)) + t_bottleneck
        hidden = self.up1(torch.cat([hidden, skip2], dim=1)) + t_up1
        hidden = self.up2(torch.cat([hidden, skip1], dim=1))
        return self.final_conv(torch.cat([hidden, skip0], dim=1))

## Implementing DDPM Forward and Inverse Process for Time-conditioned Denoising

In [None]:
def ddpm_schedule(beta1: float, beta2: float, num_ts: int) -> dict:
    """Precompute the linear DDPM noise schedule.

    Arguments:
        beta1: First beta of the linear schedule.
        beta2: Last beta of the linear schedule.
        num_ts: Number of timesteps (length of every returned tensor).

    Returns:
        dict with non-trainable ``nn.Parameter`` tensors, in this key order:
            betas: ``linspace(beta1, beta2, num_ts)``.
            alphas: ``1 - betas``.
            alpha_bars: running product of ``alphas``.
    """
    assert beta1 < beta2 < 1.0, "Expect beta1 < beta2 < 1.0."

    betas = torch.linspace(beta1, beta2, num_ts)
    alphas = 1 - betas
    raw = {
        'betas': betas,
        'alphas': alphas,
        'alpha_bars': alphas.cumprod(dim=0),
    }
    # Wrapped as frozen Parameters so nn.ParameterDict can hold them and
    # .to(device) moves them with the owning module.
    return {key: nn.Parameter(value, requires_grad=False) for key, value in raw.items()}

In [None]:
def ddpm_forward(
    unet: TimeConditionalUNet,
    ddpm_schedule: dict,
    x_0: torch.Tensor,
    num_ts: int,
) -> torch.Tensor:
    """Algorithm 1 of the DDPM paper: one noise-prediction training loss.

    Draws one timestep per image uniformly from [1, num_ts] and noises ``x_0``
    in closed form. Returns the MSE between the UNet's noise prediction and
    the true noise.

    Args:
        unet: TimeConditionalUNet, called as ``unet(x_t, t_normalized)``; set to train mode.
        ddpm_schedule: dict with an 'alpha_bars' tensor of length num_ts.
        x_0: (N, C, H, W) clean input tensor.
        num_ts: int, number of timesteps.
    Returns:
        (,) diffusion loss.
    """
    unet.train()

    n = x_0.shape[0]
    t = torch.randint(1, num_ts + 1, (n,), device=x_0.device)
    noise = torch.randn_like(x_0)

    # alpha_bars is 0-indexed, timesteps are 1-indexed.
    abar = ddpm_schedule['alpha_bars'][t - 1].view(-1, 1, 1, 1)
    x_t = torch.sqrt(abar) * x_0 + torch.sqrt(1 - abar) * noise

    # Map t in [1, num_ts] to [0, 1] before feeding it to the UNet.
    predicted = unet(x_t, (t - 1) / (num_ts - 1))
    return F.mse_loss(predicted, noise)

In [None]:
@torch.inference_mode()
def ddpm_sample(
    unet: "TimeConditionalUNet",
    ddpm_schedule: dict,
    img_wh: tuple[int, int],
    num_ts: int,
    seed: int = 0,
    batch_size: int = 10,
):
    """Algorithm 2 of the DDPM paper: ancestral sampling from pure noise.

    Args:
        unet: Time-conditioned noise predictor, called as ``unet(x_t, t_normalized)``;
            put in eval mode.
        ddpm_schedule: Mapping with 'betas', 'alphas' and 'alpha_bars'
            tensors of length num_ts.
        img_wh: Spatial size of the generated images (passed as H, W to torch.randn).
        num_ts: Number of diffusion timesteps.
        seed: Seed passed to ``torch.manual_seed`` (global RNG).
        batch_size: Number of images to draw (default 10, as before).

    Returns:
        A pair ``(final_samples, all_samples)``:
        - final_samples: (batch_size, C, H, W) tensor on the UNet's device,
          clamped to [-1, 1].
        - all_samples: CPU tensor (batch_size, num_frames, C, H, W). Frame 0
          is the initial noise, followed by roughly ten intermediate states
          and the final sample.
    """
    unet.eval()
    torch.manual_seed(seed)
    device = next(unet.parameters()).device

    betas = ddpm_schedule['betas']
    alphas = ddpm_schedule['alphas']
    alpha_bars = ddpm_schedule['alpha_bars']

    x_t = torch.randn(batch_size, unet.in_channels, *img_wh, device=device)
    all_samples = [x_t.cpu()]

    # Cache about ten frames; max(..., 1) avoids a modulo-by-zero when num_ts < 10.
    cache_every = max(num_ts // 10, 1)

    for t in range(num_ts, 0, -1):
        # Same normalisation as training: t in [1, num_ts] -> [0, 1].
        t_normalized = torch.full((batch_size,), (t - 1) / (num_ts - 1), device=device)
        z = torch.randn_like(x_t) if t > 1 else torch.zeros_like(x_t)
        predicted_noise = unet(x_t, t_normalized)  # inference_mode already disables autograd

        alpha_t = alphas[t - 1]
        alpha_bar_t = alpha_bars[t - 1]
        beta_t = betas[t - 1]

        # Estimate of the clean image, clamped to the [-1, 1] data range.
        x_0_pred = (x_t - torch.sqrt(1 - alpha_bar_t) * predicted_noise) / torch.sqrt(alpha_bar_t)
        x_0_pred = torch.clamp(x_0_pred, -1.0, 1.0)

        if t > 1:
            # Posterior mean of x_{t-1} given x_t and x_0_pred, plus noise with variance beta_t.
            alpha_bar_t_minus_1 = alpha_bars[t - 2]
            coef1 = torch.sqrt(alpha_bar_t_minus_1) * beta_t / (1 - alpha_bar_t)
            coef2 = torch.sqrt(alpha_t) * (1 - alpha_bar_t_minus_1) / (1 - alpha_bar_t)
            x_t = coef1 * x_0_pred + coef2 * x_t + torch.sqrt(beta_t) * z
        else:
            x_t = x_0_pred

        if t % cache_every == 0 or t == 1:
            all_samples.append(x_t.cpu())

    final_samples = x_t
    all_samples = torch.stack(all_samples, dim=1)

    return final_samples, all_samples

In [None]:
class DDPM(nn.Module):
    """Time-conditioned DDPM: a TimeConditionalUNet plus its noise schedule.

    ``forward`` returns the training loss for a batch of clean images.
    ``sample`` runs the reverse process from pure noise. The schedule is
    stored in a ParameterDict of frozen parameters, so ``.to(device)`` moves
    it along with the UNet.
    """

    # Bind the time-conditioned helpers when the class is defined. Later
    # notebook cells redefine `ddpm_forward` / `ddpm_sample` with
    # class-conditioned signatures. A global lookup at call time would make
    # existing instances of this class call those incompatible versions.
    _ddpm_forward = staticmethod(ddpm_forward)
    _ddpm_sample = staticmethod(ddpm_sample)

    def __init__(
        self,
        unet: TimeConditionalUNet,
        betas: tuple[float, float] = (1e-4, 0.02),
        num_ts: int = 300,
        p_uncond: float = 0.1,
    ):
        super().__init__()
        self.unet = unet
        self.num_ts = num_ts
        self.p_uncond = p_uncond  # stored but unused by the time-conditioned helpers

        self.ddpm_schedule = nn.ParameterDict(ddpm_schedule(betas[0], betas[1], num_ts))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (N, C, H, W) input tensor.

        Returns:
            (,) diffusion loss.
        """
        return self._ddpm_forward(
            self.unet, self.ddpm_schedule, x, self.num_ts
        )

    @torch.inference_mode()
    def sample(
        self,
        img_wh: tuple[int, int],
        seed: int = 0,
    ):
        """Return ``(final_samples, all_samples)`` from the time-conditioned sampler."""
        return self._ddpm_sample(
            self.unet, self.ddpm_schedule, img_wh, self.num_ts, seed
        )

## Training the Time-conditioned UNet

- Plot the loss curve
- Sample results on the test set

In [None]:
def _plot_time_conditioned_grid(ddpm, epoch, num_rows=4):
    """Plot and save a num_rows x 10 grid of samples, using one full sample() batch per row."""
    num_cols = 10  # the time-conditioned ddpm_sample draws 10 images per call
    fig, axs = plt.subplots(num_rows, num_cols, figsize=(10, num_rows), facecolor='black', squeeze=False)
    fig.patch.set_facecolor('black')

    for row in range(num_rows):
        # Use every image of the batch instead of sampling a whole batch per
        # cell and keeping only its first image (10x less sampling work).
        samples, _ = ddpm.sample(img_wh=(28, 28), seed=epoch + row * 100)
        samples = (samples + 1) / 2  # [-1, 1] -> [0, 1] for display
        for col in range(num_cols):
            ax = axs[row, col]
            ax.imshow(samples[col, 0].cpu().numpy(), cmap='gray')
            ax.axis('off')
            if row == 0:
                ax.set_title(f"Sample {col+1}", color='white')

    plt.suptitle(f"Epoch {epoch+1}", color='white', fontsize=16)
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    plt.savefig(f"time_conditioned_samples_epoch_{epoch+1}.png", facecolor='black')
    plt.show()


def _plot_time_conditioned_process(ddpm, epoch, max_frames=8):
    """Plot and save evenly spaced cached frames of one sampling trajectory, from noise to the final image."""
    _, animation = ddpm.sample(img_wh=(28, 28), seed=epoch)
    animation = (animation + 1) / 2

    num_frames = animation.shape[1]
    frames_to_show = min(max_frames, num_frames)
    # Evenly spaced indices that always include the first (noise) and last
    # (final) cached frames, so the panel titled "Final" is the final state.
    indices = torch.linspace(0, num_frames - 1, frames_to_show).round().long().tolist()

    plt.figure(figsize=(12, 3), facecolor='black')
    for i, idx in enumerate(indices):
        plt.subplot(1, frames_to_show, i + 1)
        plt.imshow(animation[0, idx, 0].cpu().numpy(), cmap='gray')
        plt.axis('off')
        if i == 0:
            plt.title("Noise", color='white')
        elif i == frames_to_show - 1:
            plt.title("Final", color='white')
        else:
            plt.title(f"Step {idx}", color='white')
    plt.suptitle(f'Sampling process after epoch {epoch+1}', color='white')
    plt.tight_layout()
    plt.savefig(f"time_conditioned_process_epoch_{epoch+1}.png", facecolor='black')
    plt.show()


def train_time_conditioned_unet(num_epochs=20, batch_size=128):
    """Train a time-conditioned DDPM on MNIST and return the DDPM wrapper.

    - Images are normalised with Normalize(0.5, 0.5).
    - Training uses Adam (lr 2e-4) with an exponential decay that reaches
      0.1x the initial rate after ``num_epochs`` scheduler steps.
    - After the first epoch, every fifth epoch and the last epoch, a sample
      grid and a sampling-trajectory strip are plotted and saved as PNGs.
    - The per-batch loss curve is plotted at the end.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(0.5, 0.5)
    ])
    train_dataset = MNIST(root='data', download=True, transform=transform, train=True)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = TimeConditionalUNet(in_channels=1, num_classes=10, num_hiddens=64).to(device)

    ddpm = DDPM(model, betas=(1e-4, 0.02), num_ts=300).to(device)

    # Sanity check that the output changes with t. Eval mode and no_grad keep
    # these random probe inputs from updating BatchNorm running statistics.
    print("Testing time embedding...")
    model.eval()
    with torch.no_grad():
        for t in torch.linspace(0, 1, 5).to(device):
            t_batch = torch.ones(2, device=device) * t
            test_input = torch.randn(2, 1, 28, 28).to(device)
            test_out = model(test_input, t_batch)
            print(f"Time {t.item():.2f}: Output range: {test_out.min().item():.2f} to {test_out.max().item():.2f}")

    optimizer = optim.Adam(ddpm.parameters(), lr=2e-4)  # 1e-3 led to slow convergence

    gamma = 0.1 ** (1.0 / max(num_epochs, 1))
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma)

    batch_losses = []

    for epoch in range(num_epochs):
        ddpm.train()
        epoch_loss = 0.0

        for batch_idx, (clean_images, _) in enumerate(train_loader):
            clean_images = clean_images.to(device)
            loss = ddpm(clean_images)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_losses.append(loss.item())
            epoch_loss += loss.item()

            if batch_idx % 100 == 0:
                print(f'Epoch: {epoch+1}/{num_epochs}, Batch: {batch_idx}/{len(train_loader)}, Loss: {loss.item():.6f}')

        scheduler.step()

        avg_epoch_loss = epoch_loss / len(train_loader)
        print(f'Epoch: {epoch+1}/{num_epochs}, Average Loss: {avg_epoch_loss:.6f}')

        # For num_epochs=20 this is epochs 0, 4, 9, 14, 19 (as before); the
        # last epoch is always included for other lengths.
        if epoch == 0 or (epoch + 1) % 5 == 0 or epoch == num_epochs - 1:
            ddpm.eval()
            _plot_time_conditioned_grid(ddpm, epoch)
            _plot_time_conditioned_process(ddpm, epoch)

    plt.figure(figsize=(10, 5))
    plt.plot(batch_losses)
    plt.title('Training Loss')
    plt.xlabel('Step')
    plt.ylabel('Loss')
    plt.yscale('log')
    plt.grid(True)
    plt.show()

    return ddpm

# Train the time-conditioned DDPM (Part 2) for 20 epochs with batches of 128.
tmodel = train_time_conditioned_unet(batch_size=128, num_epochs=20)

In [None]:
def create_time_conditioned_diffusion_gif(model, sample_index=0, output_dir="time_diffusion_gifs", epochs=(0, 4, 9, 14, 19)):
    """
    Create a GIF showing the diffusion process (noise to image) for time-conditioned generation.

    For each value in ``epochs``, the *current* model is sampled with
    ``seed=epoch``. Image ``sample_index`` of that batch has its cached
    frames written to ``sample_{sample_index}_process_epoch_{epoch+1}.gif``
    in ``output_dir``. The GIFs therefore use the model's present weights;
    ``epochs`` only chooses the seeds and file names.

    Args:
        model: Time-conditioned DDPM wrapper whose ``sample(img_wh=..., seed=...)``
            returns ``(final, frames)`` with frames shaped (N, T, C, H, W).
        sample_index: Which image of the sampled batch to animate.
        output_dir: Destination directory (created if missing).
        epochs: 0-based epoch numbers used as seeds and, plus one, in file names.
    """
    # Imported here: none of these names is imported at the top of the file,
    # and imageio is installed by a later `pip install` cell, so a top-level
    # import could fail before that cell runs.
    import os

    import imageio
    import numpy as np

    os.makedirs(output_dir, exist_ok=True)
    for epoch in epochs:
        print(f"Creating diffusion process GIF for epoch {epoch+1}")
        _, animation = model.sample(img_wh=(28, 28), seed=epoch)
        animation = (animation + 1) / 2  # [-1, 1] -> [0, 1] for display
        frames = []
        num_frames = animation.shape[1]

        for t in range(num_frames):
            fig, ax = plt.subplots(figsize=(3, 3))
            ax.imshow(animation[sample_index, t, 0].cpu().numpy(), cmap='gray')

            if t == 0:
                stage = "Noise"
            elif t == num_frames - 1:
                stage = "Final"
            else:
                stage = f"Step {t}"

            ax.set_title(f"Sample {sample_index}: {stage}")
            ax.axis('off')

            # Rasterise the figure and copy the RGBA pixels before closing it.
            fig.canvas.draw()
            frame = np.array(fig.canvas.renderer.buffer_rgba())
            plt.close(fig)

            frames.append(frame)
        gif_path = os.path.join(output_dir, f"sample_{sample_index}_process_epoch_{epoch+1}.gif")
        # Intended as 0.1 s per frame. NOTE(review): newer imageio releases
        # interpret GIF `duration` in milliseconds; confirm against the
        # installed version.
        imageio.mimsave(gif_path, frames, duration=0.1)
        print(f"Saved GIF to {gif_path}")

In [None]:
# Notebook shell command: installs imageio, which create_time_conditioned_diffusion_gif
# uses to write GIFs. tqdm is not referenced by the code shown in this notebook.
!pip install imageio tqdm

In [None]:
# Write noise-to-digit GIFs for the first image of each sampled batch.
create_time_conditioned_diffusion_gif(model=tmodel, sample_index=0)


## Implementing a Class-conditioned UNet

In [None]:
class ClassConditionalUNet(nn.Module):
    """UNet for 28x28 images conditioned on a normalized timestep and a class label.

    The backbone matches `UnconditionalUNet`. At the bottleneck and after
    ``up1``, the features are modulated as ``class_embedding * h + time_embedding``.
    A mask of 0 zeroes the one-hot class vector, which gives the
    unconditional branch used by classifier-free guidance.
    """

    def __init__(
        self,
        in_channels: int,
        num_classes: int,
        num_hiddens: int,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.num_hiddens = num_hiddens
        self.num_classes = num_classes
        d = num_hiddens

        # Construction order matches the original so initialisation is identical.
        self.initial_conv = Conv(in_channels, d)
        self.down1 = DownBlock(d, 2 * d)
        self.down2 = DownBlock(2 * d, 4 * d)
        self.flatten = Flatten()  # 7x7 -> 1x1 bottleneck
        self.unflatten = Unflatten(4 * d)

        self.up1 = UpBlock(8 * d, 2 * d)
        self.up2 = UpBlock(4 * d, d)

        self.final_conv = nn.Conv2d(2 * d, in_channels, kernel_size=1)

        self.fc1_t = FCBlock(1, 4 * d)
        self.fc1_c = FCBlock(num_classes, 4 * d)
        self.fc2_t = FCBlock(1, 2 * d)
        self.fc2_c = FCBlock(num_classes, 2 * d)

    def forward(
        self,
        x: torch.Tensor,
        c: torch.Tensor,
        t: torch.Tensor,
        mask: torch.Tensor = None,
    ) -> torch.Tensor:
        """
        Args:
            x: (N, C, H, W) input tensor.
            c: (N,) int64 condition tensor.
            t: (N,) normalized time tensor.
            mask: (N,) mask tensor. If not None, mask out condition when mask == 0.

        Returns:
            (N, C, H, W) output tensor.
        """
        assert x.shape[-2:] == (28, 28), "Expect input shape to be (28, 28)."

        t = t.view(-1, 1)
        cond = F.one_hot(c, num_classes=self.num_classes).float()
        if mask is not None:
            cond = cond * mask.view(-1, 1)  # zero rows drop the class condition

        d = self.num_hiddens
        skip0 = self.initial_conv(x)
        skip1 = self.down1(skip0)
        skip2 = self.down2(skip1)
        hidden = self.unflatten(self.flatten(skip2))

        # Conditioning vectors, shaped to broadcast over the spatial dimensions.
        t_shift1 = self.fc1_t(t).view(-1, 4 * d, 1, 1)
        c_scale1 = self.fc1_c(cond).view(-1, 4 * d, 1, 1)
        t_shift2 = self.fc2_t(t).view(-1, 2 * d, 1, 1)
        c_scale2 = self.fc2_c(cond).view(-1, 2 * d, 1, 1)

        hidden = c_scale1 * hidden + t_shift1
        hidden = self.up1(torch.cat([hidden, skip2], dim=1))
        hidden = c_scale2 * hidden + t_shift2
        hidden = self.up2(torch.cat([hidden, skip1], dim=1))
        return self.final_conv(torch.cat([hidden, skip0], dim=1))

In [None]:
def ddpm_forward(
    unet: ClassConditionalUNet,
    ddpm_schedule: dict,
    x_0: torch.Tensor,
    c: torch.Tensor,
    p_uncond: float,
    num_ts: int,
) -> torch.Tensor:
    """Algorithm 3 of the DDPM paper - Class-conditioned training.

    Each image's condition is kept with probability ``1 - p_uncond``; a mask
    value of 0 tells the UNet to ignore the label. A timestep is drawn
    uniformly from [1, num_ts] and ``x_0`` is noised in closed form. The loss
    is the MSE between predicted and true noise.

    Args:
        unet: ClassConditionalUNet, called as ``unet(x_t, c, t_normalized, mask)``; set to train mode.
        ddpm_schedule: dict with an 'alpha_bars' tensor of length num_ts.
        x_0: (N, C, H, W) input tensor (clean image).
        c: (N,) int64 condition tensor (class labels).
        p_uncond: float, probability of unconditioning the condition.
        num_ts: int, number of timesteps.

    Returns:
        (,) diffusion loss.
    """
    unet.train()

    dev = x_0.device
    n = x_0.shape[0]
    # Random draws happen in the same order as before: mask, timestep, noise.
    keep_prob = torch.ones(n, device=dev) * (1 - p_uncond)
    mask = torch.bernoulli(keep_prob)
    t = torch.randint(1, num_ts + 1, (n,), device=dev)
    noise = torch.randn_like(x_0)

    abar = ddpm_schedule['alpha_bars'][t - 1].view(-1, 1, 1, 1)
    x_t = torch.sqrt(abar) * x_0 + torch.sqrt(1 - abar) * noise

    predicted = unet(x_t, c, (t - 1) / (num_ts - 1), mask)
    return F.mse_loss(predicted, noise)

In [None]:
@torch.inference_mode()
def ddpm_sample(
    unet: "ClassConditionalUNet",
    ddpm_schedule: dict,
    c: torch.Tensor,
    img_wh: tuple[int, int],
    num_ts: int,
    guidance_scale: float = 5.0,
    seed: int = 0,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Algorithm 4 of the DDPM paper - Class-conditioned sampling with classifier-free guidance.

    Args:
        unet: ClassConditionalUNet, called as ``unet(x_t, c, t_normalized, mask)``;
            put in eval mode.
        ddpm_schedule: dict with 'betas', 'alphas' and 'alpha_bars' tensors of length num_ts.
        c: (N,) int64 condition tensor - class labels. N sets the batch size.
        img_wh: (H, W) output image size (passed as H, W to torch.randn).
        num_ts: int, number of timesteps.
        guidance_scale: float, classifier-free guidance scale (γ).
        seed: int, random seed passed to ``torch.manual_seed`` (global RNG).

    Returns:
        (N, C, H, W) final sample, clamped to [-1, 1].
        (N, T_animation, C, H, W) CPU cache: the initial noise, roughly ten
        intermediate states and the final sample.
    """
    unet.eval()
    torch.manual_seed(seed)

    device = next(unet.parameters()).device
    batch_size = c.shape[0]

    x_t = torch.randn(batch_size, unet.in_channels, *img_wh, device=device)
    betas = ddpm_schedule['betas']
    alphas = ddpm_schedule['alphas']
    alpha_bars = ddpm_schedule['alpha_bars']

    all_samples = [x_t.cpu()]

    # Cache about ten frames; max(..., 1) avoids a modulo-by-zero when num_ts < 10.
    cache_every = max(num_ts // 10, 1)

    # Loop-invariant masks: 0 drops the class condition, 1 keeps it.
    uncond_mask = torch.zeros(batch_size, device=device)
    cond_mask = torch.ones(batch_size, device=device)

    for t in range(num_ts, 0, -1):
        t_normalized = torch.full((batch_size,), (t - 1) / (num_ts - 1), device=device)
        z = torch.randn_like(x_t) if t > 1 else torch.zeros_like(x_t)

        epsilon_u = unet(x_t, c, t_normalized, uncond_mask)
        epsilon_c = unet(x_t, c, t_normalized, cond_mask)
        # Classifier-free guidance: move from the unconditional estimate toward
        # (and, for guidance_scale > 1, past) the conditional one.
        epsilon = epsilon_u + guidance_scale * (epsilon_c - epsilon_u)

        alpha_t = alphas[t - 1]
        alpha_bar_t = alpha_bars[t - 1]
        beta_t = betas[t - 1]

        x_0_pred = (x_t - torch.sqrt(1 - alpha_bar_t) * epsilon) / torch.sqrt(alpha_bar_t)
        x_0_pred = torch.clamp(x_0_pred, -1.0, 1.0)

        if t > 1:
            # Posterior mean of x_{t-1} given x_t and x_0_pred, plus noise with variance beta_t.
            alpha_bar_t_minus_1 = alpha_bars[t - 2]
            coef1 = torch.sqrt(alpha_bar_t_minus_1) * beta_t / (1 - alpha_bar_t)
            coef2 = torch.sqrt(alpha_t) * (1 - alpha_bar_t_minus_1) / (1 - alpha_bar_t)
            x_t = coef1 * x_0_pred + coef2 * x_t + torch.sqrt(beta_t) * z
        else:
            x_t = x_0_pred

        if t % cache_every == 0 or t == 1:
            all_samples.append(x_t.cpu())

    final_samples = x_t
    all_samples = torch.stack(all_samples, dim=1)

    return final_samples, all_samples

In [None]:
class DDPM(nn.Module):
    """Class-conditioned DDPM: a ClassConditionalUNet plus its noise schedule.

    ``forward(x, c)`` returns the training loss, dropping the condition with
    probability ``p_uncond``. ``sample(c, ...)`` runs classifier-free-guided
    sampling. The schedule lives in a ParameterDict of frozen parameters, so
    ``.to(device)`` moves it with the UNet.
    """

    # Bind the class-conditioned helpers when the class is defined, rather
    # than looking them up globally at call time. Earlier cells define
    # time-conditioned functions with the same names and incompatible
    # signatures; re-running those cells would otherwise break this class.
    _ddpm_forward = staticmethod(ddpm_forward)
    _ddpm_sample = staticmethod(ddpm_sample)

    def __init__(
        self,
        unet: ClassConditionalUNet,
        betas: tuple[float, float] = (1e-4, 0.02),
        num_ts: int = 300,
        p_uncond: float = 0.1,
    ):
        super().__init__()
        self.unet = unet
        self.betas = betas
        self.num_ts = num_ts
        self.p_uncond = p_uncond
        self.ddpm_schedule = nn.ParameterDict(ddpm_schedule(betas[0], betas[1], num_ts))

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (N, C, H, W) input tensor.
            c: (N,) int64 condition tensor.

        Returns:
            (,) diffusion loss.
        """
        return self._ddpm_forward(
            self.unet, self.ddpm_schedule, x, c, self.p_uncond, self.num_ts
        )

    @torch.inference_mode()
    def sample(
        self,
        c: torch.Tensor,
        img_wh: tuple[int, int],
        guidance_scale: float = 5.0,
        seed: int = 0,
    ):
        """Return ``(final_samples, all_samples)`` for the labels in ``c``."""
        return self._ddpm_sample(
            self.unet, self.ddpm_schedule, c, img_wh, self.num_ts, guidance_scale, seed
        )

In [None]:
def visualize_class_conditioned_samples(model, epoch, num_rows=4, guidance_scale=5.0, seed=None):
    """Plot, save and return a grid of class-conditioned samples, one column per digit 0-9.

    Each cell is a single-image ``model.sample`` call with seed
    ``seed + row * 100 + digit``, so the grid is reproducible for a given
    seed. The figure is saved as ``class_conditioned_samples_epoch_{epoch}.png``.

    Args:
        model: Class-conditioned DDPM wrapper exposing
            ``sample(c, img_wh=..., guidance_scale=..., seed=...)``.
        epoch: Label used in the title and file name; also the default seed.
        num_rows: Number of samples drawn per digit.
        guidance_scale: Classifier-free guidance scale forwarded to ``sample``.
        seed: Base seed; defaults to ``epoch``.

    Returns:
        The matplotlib Figure containing the grid.
    """
    device = next(model.parameters()).device
    if seed is None:
        seed = epoch

    # squeeze=False keeps `axs` 2-D even when num_rows == 1.
    fig, axs = plt.subplots(num_rows, 10, figsize=(10, num_rows), facecolor='black', squeeze=False)
    fig.patch.set_facecolor('black')

    for digit in range(10):
        c = torch.tensor([digit], device=device)  # same label for the whole column
        for row in range(num_rows):
            samples, _ = model.sample(
                c,
                img_wh=(28, 28),
                guidance_scale=guidance_scale,
                seed=seed + row * 100 + digit,
            )
            sample = (samples[0, 0].cpu().numpy() + 1) / 2  # [-1, 1] -> [0, 1]

            axs[row, digit].imshow(sample, cmap='gray')
            axs[row, digit].axis('off')

            if row == 0:
                axs[row, digit].set_title(str(digit), color='white')

    plt.suptitle(f"Epoch {epoch}", color='white', fontsize=16)
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    plt.savefig(f"class_conditioned_samples_epoch_{epoch}.png", facecolor='black')
    plt.show()

    return fig

## Training the Class-conditioned UNet

- Plot the loss curve
- Sample results on the test set

In [None]:
def train_class_conditioned_unet(num_epochs=20, batch_size=128, guidance_scale=5.0, vis_every=5):
    """Train a class-conditioned UNet with the DDPM objective on MNIST.

    Images are normalized from [0, 1] to [-1, 1]. The learning rate decays
    exponentially once per epoch, so it has shrunk by 10x after ``num_epochs``
    epochs. A grid of class-conditioned samples is visualized after the first
    epoch, after every ``vis_every``-th epoch, and after the final epoch. The
    per-batch loss curve is plotted at the end.

    Args:
        num_epochs: Number of passes over the MNIST training set.
        batch_size: Mini-batch size.
        guidance_scale: Guidance scale used when sampling the visualization
            grids.
        vis_every: Visualize after every ``vis_every``-th epoch (1-based).
            ``0`` disables the periodic grids; the first and last epochs are
            still shown. The default reproduces the original schedule
            (epochs 1, 5, 10, 15, 20) for 20 epochs.

    Returns:
        The trained ``DDPM`` wrapper around the UNet.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(0.5, 0.5),  # [0, 1] -> [-1, 1]
    ])
    train_dataset = MNIST(root='data', download=True, transform=transform, train=True)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = ClassConditionalUNet(in_channels=1, num_classes=10, num_hiddens=64).to(device)

    ddpm = DDPM(model, betas=(1e-4, 0.02), num_ts=300, p_uncond=0.1).to(device)
    optimizer = optim.Adam(ddpm.parameters(), lr=1e-3)

    # gamma ** num_epochs == 0.1, i.e. the LR ends 10x lower than it started.
    gamma = 0.1 ** (1.0 / num_epochs)
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma)

    batch_losses = []

    for epoch in range(num_epochs):
        ddpm.train()
        epoch_loss = 0.0

        for batch_idx, (clean_images, labels) in enumerate(train_loader):
            clean_images = clean_images.to(device)
            labels = labels.to(device)

            loss = ddpm(clean_images, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # One host sync per batch instead of two.
            loss_value = loss.item()
            batch_losses.append(loss_value)
            epoch_loss += loss_value

            if batch_idx % 100 == 0:
                print(f'Epoch: {epoch+1}/{num_epochs}, Batch: {batch_idx}/{len(train_loader)}, Loss: {loss_value:.6f}')

        scheduler.step()

        avg_epoch_loss = epoch_loss / len(train_loader)
        print(f'Epoch: {epoch+1}/{num_epochs}, Average Loss: {avg_epoch_loss:.6f}')

        is_first = epoch == 0
        is_periodic = vis_every > 0 and (epoch + 1) % vis_every == 0
        is_last = epoch == num_epochs - 1
        if is_first or is_periodic or is_last:
            ddpm.eval()
            with torch.no_grad():
                fig = visualize_class_conditioned_samples(ddpm, epoch+1, num_rows=4, guidance_scale=guidance_scale)
            plt.close(fig)  # free the figure; it has already been shown and saved

    plt.figure(figsize=(10, 5))
    plt.plot(batch_losses)
    plt.title('Training Loss')
    plt.xlabel('Step')
    plt.ylabel('Loss')
    plt.yscale('log')
    plt.grid(True)
    plt.show()

    return ddpm

In [None]:
# Train for 20 epochs with batch size 128; sample grids use guidance scale 5.0.
# The returned DDPM wrapper is kept as `model` for the GIF cells below.
model = train_class_conditioned_unet(20, 128, 5.0)

In [None]:
!pip install imageio tqdm

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
from torchvision import transforms
import imageio
import os
from tqdm import tqdm

def create_generation_gif(model, num_epochs=20, output_dir="generation_gifs", digits=tuple(range(10))):
    """Save GIFs of the final sample for each digit across a sequence of seeds.

    NOTE: ``model`` is a single, already-trained model, so the "epochs" looped
    over here are NOT training epochs: frame ``i`` is simply the sample drawn
    with ``seed=i``. The "Epoch" wording in titles and messages is kept for
    consistency with the rest of the notebook.

    Writes one ``digit_{d}_evolution.gif`` per digit (500 ms per frame) and an
    ``all_digits_evolution.gif`` grid (1000 ms per frame) into ``output_dir``.
    Both loop forever.

    Args:
        model: Object exposing ``parameters()`` and
            ``sample(c, img_wh=..., guidance_scale=..., seed=...)`` returning
            ``(samples, animation)``. ``animation[0, -1, 0]`` is used as the
            final image.
        num_epochs: Number of seeds, i.e. frames per GIF.
        output_dir: Directory for the GIFs (created if missing).
        digits: Class labels to render.

    Returns:
        Dict mapping each digit to its list of rendered RGBA frames.
    """
    import math

    # imageio's v3 API with the Pillow plugin takes GIF `duration` in
    # milliseconds. The legacy `mimsave(..., duration=0.5)` call ends up as
    # ~0 ms frames on current imageio.
    import imageio.v3 as iio

    def _render(fig):
        # Rasterize the figure to an (H, W, 4) uint8 array and release it.
        fig.canvas.draw()
        frame = np.array(fig.canvas.buffer_rgba())
        plt.close(fig)
        return frame

    def _save_gif(path, frames, frame_ms):
        # Drop the (opaque) alpha channel: GIF has no real alpha, and RGBA
        # input can trigger odd transparency when Pillow palettizes it.
        iio.imwrite(path, np.stack(frames)[..., :3], duration=frame_ms, loop=0)

    device = next(model.parameters()).device
    os.makedirs(output_dir, exist_ok=True)
    digits = list(digits)
    frames_by_digit = {digit: [] for digit in digits}
    if not digits or num_epochs <= 0:
        return frames_by_digit  # nothing to render

    for epoch in range(num_epochs):
        print(f"Processing epoch {epoch+1}/{num_epochs}")
        for digit in digits:
            c = torch.tensor([digit], device=device)
            _, animation = model.sample(c, img_wh=(28, 28), guidance_scale=5.0, seed=epoch)
            final_frame = ((animation[0, -1, 0] + 1) / 2).cpu().numpy()  # [-1, 1] -> [0, 1]
            fig, ax = plt.subplots(figsize=(3, 3))
            ax.imshow(final_frame, cmap='gray')
            ax.set_title(f"Digit {digit}, Epoch {epoch+1}")
            ax.axis('off')
            frames_by_digit[digit].append(_render(fig))

    for digit, frames in frames_by_digit.items():
        print(f"Creating GIF for digit {digit}")
        gif_path = os.path.join(output_dir, f"digit_{digit}_evolution.gif")
        _save_gif(gif_path, frames, frame_ms=500)
        print(f"Saved GIF to {gif_path}")

    print("Creating combined GIF with all digits")
    # Up to 5 columns. For 10 digits this is the original 2x5, 10x5-inch grid.
    ncols = min(5, len(digits))
    nrows = math.ceil(len(digits) / ncols)
    combined_frames = []
    for epoch in range(num_epochs):
        fig, axes = plt.subplots(nrows, ncols, figsize=(2 * ncols, 2.5 * nrows), squeeze=False)
        axes = axes.flatten()
        for ax, digit in zip(axes, digits):
            ax.imshow(frames_by_digit[digit][epoch])
            ax.set_title(f"Digit {digit}")
        for ax in axes:
            ax.axis('off')  # also hides any unused grid cells

        fig.suptitle(f"Epoch {epoch+1}")
        fig.tight_layout()
        combined_frames.append(_render(fig))

    combined_gif_path = os.path.join(output_dir, "all_digits_evolution.gif")
    _save_gif(combined_gif_path, combined_frames, frame_ms=1000)
    print(f"Saved combined GIF to {combined_gif_path}")

    return frames_by_digit

In [None]:
def create_diffusion_process_gif(model, digit=5, output_dir="diffusion_gifs", epochs=(0, 4, 9, 14, 19)):
    """Save GIFs of the full sampling trajectory (noise -> image) for one digit.

    NOTE: ``model`` is a single trained model. Each entry of ``epochs`` is used
    only as the sampling seed; the "epoch" wording is kept for the file names
    and messages.

    Writes one ``digit_{digit}_process_epoch_{e+1}.gif`` per entry of
    ``epochs`` into ``output_dir``. Each GIF plays at 100 ms per frame and
    loops forever.

    Args:
        model: Object exposing ``parameters()`` and
            ``sample(c, img_wh=..., guidance_scale=..., seed=...)`` returning
            ``(samples, animation)``. ``animation[0, t, 0]`` is the image at
            frame ``t``.
        digit: Class label to generate.
        output_dir: Directory for the GIFs (created if missing).
        epochs: Seeds to sample with (one GIF each).
    """
    # imageio's v3 API with the Pillow plugin takes GIF `duration` in
    # milliseconds. The legacy `mimsave(..., duration=0.1)` call ends up as
    # ~0 ms frames on current imageio.
    import imageio.v3 as iio

    device = next(model.parameters()).device
    os.makedirs(output_dir, exist_ok=True)
    c = torch.tensor([digit], device=device)
    for epoch in epochs:
        print(f"Creating diffusion process GIF for epoch {epoch+1}")
        _, animation = model.sample(c, img_wh=(28, 28), guidance_scale=5.0, seed=epoch)
        # Copy the whole trajectory to the host once, mapped [-1, 1] -> [0, 1].
        trajectory = ((animation[0, :, 0] + 1) / 2).cpu().numpy()

        frames = []
        num_frames = len(trajectory)

        for t, image in enumerate(trajectory):
            fig, ax = plt.subplots(figsize=(3, 3))
            ax.imshow(image, cmap='gray')

            if t == 0:
                stage = "Noise"
            elif t == num_frames - 1:
                stage = "Final"
            else:
                stage = f"Step {t}"

            ax.set_title(f"Digit {digit}: {stage}")
            ax.axis('off')
            fig.canvas.draw()
            frame = np.array(fig.canvas.buffer_rgba())
            plt.close(fig)

            frames.append(frame)

        gif_path = os.path.join(output_dir, f"digit_{digit}_process_epoch_{epoch+1}.gif")
        # 100 ms per frame. Alpha is dropped because GIF has no real alpha channel.
        iio.imwrite(gif_path, np.stack(frames)[..., :3], duration=100, loop=0)
        print(f"Saved GIF to {gif_path}")

In [None]:
# Bind the result so Jupyter doesn't echo the repr of every rendered frame
# array as the cell output; the frames stay available for inspection.
frames_by_digit = create_generation_gif(model, num_epochs=20)