# Assignment 2: Diffusion Model for Image Generation

## Meta Instructions
1. Environment: Install PyTorch >= 1.7.0 and torchvision >= 0.8.0 to avoid issues caused by package versions.

2. Finish the coding tasks according to the instructions in this file. **Only** change the code within the regions marked by “Your code starts here” and “Your code ends here”. Do not modify anything else, including comments. This task does not require a GPU to meet the basic requirements.

3. Use the trained diffusion model to generate an image that contains the string of digits in your StuID (e.g., if your StuID is A0123456J, generate 0123456) and save it as a single jpg file. You are encouraged to train the diffusion model yourself. If you cannot find the computing resources, you can use this pretrained model to generate the image locally: https://drive.google.com/file/d/1-9dozojlZxdpkei5lyxFeKofKYq4rQIf/view?usp=sharing

4. Submission: submit a zip file named "StuID.zip" (e.g., "A0123456J.zip") to Canvas **Assignments -> Assignment2**. Note that the StuID is **NOT** your NUSNET ID. The zip file should **only** include "StuID_Assignment_2.ipynb", "StuID_Assignment_2.pdf" and "StuID_Assignment_2.jpg". The submission deadline is **23:59 on Feb 21**.


### Optional Puzzles for Extra Credit

You can get full marks for this assignment by correctly completing and submitting the code and PDF report according to the above requirements. However, if you want to earn 1-2 bonus points toward your total assignment grade (not exceeding the maximum score), you can solve one of the following optional puzzles and include it in your final code and report:

1. Puzzle 1: How to generate a two-digit image using an end-to-end diffusion model? (e.g., generate 21 directly)
2. Puzzle 2: Besides simply sampling the noise from a Gaussian distribution, what are some alternative sampling strategies that could improve the generation process? (you can refer to https://arxiv.org/pdf/2501.09732)


### For any questions, please take one of the following actions, in order of priority:
1. Search for similar questions on Slack (https://app.slack.com/client/T088V95D8LC/C088L557RK8).
2. Post a new question on Slack if it has not already been answered.
3. For private inquiries, send an e-mail to Pengfei Zhou (e1374451@u.nus.edu) and Xiangyan Liu (e0950125@u.nus.edu) with the subject starting with "CS5260 2025 Spring".


# Assignment 2: Diffusion Model for Image Generation

Diffusion models, including the Denoising Diffusion Probabilistic Model (DDPM), are a powerful class of generative models that simulate the gradual process of diffusing data into noise and then denoising it back into coherent samples.

Inspired by the diffusion process observed in physics, DDPM operates in two phases: a forward phase where data is incrementally noised until it becomes indistinguishable from random noise, and a reverse phase where this process is inverted, gradually denoising to generate new data samples that closely mimic the original data distribution.

This approach allows DDPM to produce high-quality, diverse outputs across domains such as images, audio, and text, without the adversarial training complexities associated with other generative models.

**References:**
- Denoising Diffusion Probabilistic Models (DDPM) paper: [DDPM](https://arxiv.org/abs/2006.11239)
- Blog about Diffusion models: [DDPM Blog](https://lilianweng.github.io/posts/2021-07-11-diffusion-models/)

This assignment requires you to implement a simple DDPM-based diffusion model for generating handwritten digit images (32×32×1, since MNIST is grayscale and the model below uses `in_channels=1`).

## Import Packages

Import the necessary packages and create the output directory.

In [1]:
from typing import Dict, Tuple
from tqdm import tqdm
import os
os.makedirs('./data/diffusion_outputs10', exist_ok=True)

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
from torchvision import models, transforms
from torchvision.datasets import MNIST
from torchvision.utils import save_image, make_grid
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter
import numpy as np

## Create dataset

Load the MNIST dataset and apply the following preprocessing:

- Resize images to image_size×image_size (you can use the torchvision.transforms.Resize function).
- Normalize pixel values to the range [-1, 1].

In [2]:
def create_mnist_dataloaders(batch_size, image_size=32, num_workers=4):

    ################################
    # Your code starts here
    ################################
    preprocess = transforms.Compose([
        transforms.Resize((image_size,image_size)),
        transforms.ToTensor(),
        # map [0, 1] to [-1, 1] via (x - 0.5) / 0.5
        transforms.Normalize((0.5,), (0.5,))
    ])



    ################################
    # Your code ends here
    ################################

    train_dataset = torchvision.datasets.MNIST(
        root="./mnist_data",
        train=True,
        download=True,
        transform=preprocess
    )

    test_dataset = torchvision.datasets.MNIST(
        root="./mnist_data",
        train=False,
        download=True,
        transform=preprocess
    )

    return DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers),\
           DataLoader(test_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
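
To see why `Normalize((0.5,), (0.5,))` yields the range [-1, 1]: `ToTensor` scales 8-bit pixels to [0, 1], and the normalization then computes (x - 0.5) / 0.5. The plain-Python sketch below reproduces that arithmetic without torchvision so it can be checked by hand. The helper names `normalize_pixel` and `denormalize_pixel` are illustrative and not part of the assignment code.

```python
def normalize_pixel(p: int) -> float:
    """Map an 8-bit pixel value (0..255) to [-1, 1]."""
    x = p / 255.0            # what ToTensor does: scale to [0, 1]
    return (x - 0.5) / 0.5   # what Normalize((0.5,), (0.5,)) does


def denormalize_pixel(v: float) -> float:
    """Invert the normalization: map [-1, 1] back to [0, 1]."""
    return v * 0.5 + 0.5


if __name__ == "__main__":
    for p in (0, 128, 255):
        print(p, "->", round(normalize_pixel(p), 4))
```

The inverse mapping is what one would apply before viewing or saving generated samples, since the model is trained on values in [-1, 1].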


## Model Architecture: U-Net

Build a neural network to predict noise. You can use a U-Net or any convolutional architecture of your choice. The network should take as input the noisy image along with the time-step information and output the predicted noise. Make sure that:

- The network handles MNIST images (e.g., 32×32).

- The network has a bottleneck structure with multiple downsampling and upsampling layers.

- The network includes residual connections.



In [3]:
class ResidualConvBlock(nn.Module):
    def __init__(
        self, in_channels: int, out_channels: int, is_res: bool = False
    ) -> None:
        super().__init__()
        '''
        standard ResNet style convolutional block
        '''
        self.same_channels = in_channels==out_channels
        self.is_res = is_res
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.GELU(),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.GELU(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:

        ################################
        # Your code starts here
        ################################
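        x1 = self.conv1(x)
        x2 = self.conv2(x1)
        if not self.is_res:
            return x2
        # residual path: add the input when the channel counts match,
        # otherwise add the output of the first convolution
        out = x + x2 if self.same_channels else x1 + x2
        # divide by ~sqrt(2) to keep the variance of the sum roughly constant
        return out / 1.414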


        ################################
        # Your code ends here
        ################################

class UnetDown(nn.Module):

    ################################
    # Your code starts here
    ################################
    def __init__(self, in_channels, out_channels):
        super(UnetDown, self).__init__()
        layers = [ResidualConvBlock(in_channels, out_channels), nn.MaxPool2d(2)]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


    ################################
    # Your code ends here
    ################################

class UnetUp(nn.Module):

    ################################
    # Your code starts here
    ################################
    def __init__(self, in_channels, out_channels):
        super(UnetUp, self).__init__()
        layers = [
            nn.ConvTranspose2d(in_channels, out_channels, 2, 2),
            ResidualConvBlock(out_channels, out_channels),
            ResidualConvBlock(out_channels, out_channels),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x, skip):
        x = torch.cat((x, skip), 1)
        x = self.model(x)
        return x


    ################################
    # Your code ends here
    ################################


class EmbedFC(nn.Module):
    def __init__(self, input_dim, emb_dim):
        super(EmbedFC, self).__init__()
        '''
        generic one layer FC NN for embedding things
        '''
        self.input_dim = input_dim
        layers = [
            nn.Linear(input_dim, emb_dim),
            nn.GELU(),
            nn.Linear(emb_dim, emb_dim),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        x = x.view(-1, self.input_dim)
        return self.model(x)


class ContextUnet(nn.Module):
    def __init__(self, in_channels, n_feat = 256, n_classes=10):
        super(ContextUnet, self).__init__()

        self.in_channels = in_channels
        self.n_feat = n_feat
        self.n_classes = n_classes

        self.init_conv = ResidualConvBlock(in_channels, n_feat, is_res=True)

        self.down1 = UnetDown(n_feat, n_feat)
        self.down2 = UnetDown(n_feat, 2 * n_feat)

        self.to_vec = nn.Sequential(nn.AvgPool2d(7), nn.GELU())

        self.timeembed1 = EmbedFC(1, 2*n_feat)
        self.timeembed2 = EmbedFC(1, 1*n_feat)
        self.contextembed1 = EmbedFC(n_classes, 2*n_feat)
        self.contextembed2 = EmbedFC(n_classes, 1*n_feat)

        self.up0 = nn.Sequential(
            # nn.ConvTranspose2d(6 * n_feat, 2 * n_feat, 7, 7), # when concat temb and cemb end up w 6*n_feat
            nn.ConvTranspose2d(2 * n_feat, 2 * n_feat, 8, 8), # otherwise just have 2*n_feat
            nn.GroupNorm(8, 2 * n_feat),
            nn.ReLU(),
        )

        self.up1 = UnetUp(4 * n_feat, n_feat)
        self.up2 = UnetUp(2 * n_feat, n_feat)
        self.out = nn.Sequential(
            nn.Conv2d(2 * n_feat, n_feat, 3, 1, 1),
            nn.GroupNorm(8, n_feat),
            nn.ReLU(),
            nn.Conv2d(n_feat, self.in_channels, 3, 1, 1),
        )

    def forward(self, x, c, t, context_mask):
        # x is (noisy) image, c is context label, t is timestep,
        # context_mask says which samples to block the context on

        x = self.init_conv(x)
        down1 = self.down1(x)
        down2 = self.down2(down1)
        hiddenvec = self.to_vec(down2)

        # convert context to one hot embedding
        c = nn.functional.one_hot(c, num_classes=self.n_classes).type(torch.float)

        # mask out context if context_mask == 1
        context_mask = context_mask[:, None]
        context_mask = context_mask.repeat(1,self.n_classes)
        context_mask = (-1*(1-context_mask)) # need to flip 0 <-> 1
        c = c * context_mask

        # embed context, time step
        cemb1 = self.contextembed1(c).view(-1, self.n_feat * 2, 1, 1)
        temb1 = self.timeembed1(t).view(-1, self.n_feat * 2, 1, 1)
        cemb2 = self.contextembed2(c).view(-1, self.n_feat, 1, 1)
        temb2 = self.timeembed2(t).view(-1, self.n_feat, 1, 1)

        # could concatenate the context embedding here instead of adaGN
        # hiddenvec = torch.cat((hiddenvec, temb1, cemb1), 1)

        up1 = self.up0(hiddenvec)
        #up2 = self.up1(up1, down2) # if want to avoid add and multiply embeddings
        up2 = self.up1(cemb1*up1+ temb1, down2)  # add and multiply embeddings
        up3 = self.up2(cemb2*up2+ temb2, down1)
        out = self.out(torch.cat((up3, x), 1))
        return out
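
As a sanity check on the architecture, the spatial size of a 32×32 input can be traced through `ContextUnet` with the standard output-size formulas for pooling and transposed convolution (no padding). This is a plain-Python sketch. The function names are illustrative, and the dictionary keys refer to the layers defined above.

```python
def pool_out(size: int, kernel: int, stride: int) -> int:
    """Output size of MaxPool2d/AvgPool2d with no padding (floor mode)."""
    return (size - kernel) // stride + 1


def conv_transpose_out(size: int, kernel: int, stride: int) -> int:
    """Output size of ConvTranspose2d with no padding or output padding."""
    return (size - 1) * stride + kernel


def trace_unet_sizes(image_size: int = 32) -> dict:
    """Follow the spatial size through the layers of ContextUnet."""
    sizes = {"input": image_size}
    sizes["down1"] = pool_out(sizes["input"], 2, 2)            # MaxPool2d(2)
    sizes["down2"] = pool_out(sizes["down1"], 2, 2)            # MaxPool2d(2)
    sizes["to_vec"] = pool_out(sizes["down2"], 7, 7)           # AvgPool2d(7)
    sizes["up0"] = conv_transpose_out(sizes["to_vec"], 8, 8)   # ConvTranspose2d(.., 8, 8)
    sizes["up1"] = conv_transpose_out(sizes["up0"], 2, 2)      # ConvTranspose2d(.., 2, 2)
    sizes["up2"] = conv_transpose_out(sizes["up1"], 2, 2)      # ConvTranspose2d(.., 2, 2)
    return sizes


if __name__ == "__main__":
    for name, size in trace_unet_sizes(32).items():
        print(f"{name}: {size}x{size}")
```

The trace shows that the skip connections line up: `up0` matches `down2`, `up1` matches `down1`, and `up2` matches the input resolution. Note that with an 8×8 input, `AvgPool2d(7)` produces a single value per channel that averages only the top-left 7×7 window.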

## The Diffusion Model

According to DDPM, define a noise schedule (e.g., a linear schedule) and implement the forward diffusion process. Given an original image, generate a noisy image at a randomly chosen time step t.

Implement the noise addition process and reverse process.
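
For reference, the quantities that the code cells in this section compute can be written as follows, in the notation of the DDPM paper, with $\epsilon_\theta$ denoting the noise-prediction network. The product starts at $s = 0$ because the schedule in the code has $T + 1$ entries indexed from 0, and the sampling noise scale $\sqrt{\beta_t}$ is the choice made in the code.

```latex
\begin{aligned}
\alpha_t &= 1 - \beta_t, \qquad \bar{\alpha}_t = \prod_{s=0}^{t} \alpha_s \\
x_t &= \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1 - \bar{\alpha}_t}\, \epsilon,
  \qquad \epsilon \sim \mathcal{N}(0, I) \\
x_{t-1} &= \frac{1}{\sqrt{\alpha_t}} \left( x_t
  - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}}\, \epsilon_\theta(x_t, t) \right)
  + \sqrt{\beta_t}\, z,
  \qquad z \sim \mathcal{N}(0, I) \text{ for } t > 1,\; z = 0 \text{ for } t = 1
\end{aligned}
```

The second line is the noise addition (forward) process and the third is one step of the reverse process. The coefficients correspond to the buffers `sqrtab`, `sqrtmab`, `oneover_sqrta`, `mab_over_sqrtmab`, and `sqrt_beta_t`.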

In [7]:

def ddpm_schedules(beta1, beta2, T):
    """
    Returns pre-computed schedules for DDPM training and sampling.
    """
    assert 0 < beta1 < beta2 < 1.0, "beta1 and beta2 must be in (0, 1) with beta1 < beta2"

    beta_t = (beta2 - beta1) * torch.arange(0, T + 1, dtype=torch.float32) / T + beta1
    sqrt_beta_t = torch.sqrt(beta_t)
    alpha_t = 1 - beta_t
    log_alpha_t = torch.log(alpha_t)
    alphabar_t = torch.cumsum(log_alpha_t, dim=0).exp()

    sqrtab = torch.sqrt(alphabar_t)
    oneover_sqrta = 1 / torch.sqrt(alpha_t)

    sqrtmab = torch.sqrt(1 - alphabar_t)
    mab_over_sqrtmab_inv = (1 - alpha_t) / sqrtmab

    return {
        "alpha_t": alpha_t,  # \alpha_t
        "oneover_sqrta": oneover_sqrta,  # 1/\sqrt{\alpha_t}
        "sqrt_beta_t": sqrt_beta_t,  # \sqrt{\beta_t}
        "alphabar_t": alphabar_t,  # \bar{\alpha_t}
        "sqrtab": sqrtab,  # \sqrt{\bar{\alpha_t}}
        "sqrtmab": sqrtmab,  # \sqrt{1-\bar{\alpha_t}}
        "mab_over_sqrtmab": mab_over_sqrtmab_inv,  # (1-\alpha_t)/\sqrt{1-\bar{\alpha_t}}
    }
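
To get a feel for the schedule, the same quantities can be computed in plain Python. This sketch mirrors `ddpm_schedules` (a linear `beta_t` with `T + 1` entries, and `alphabar_t` as the exponential of a cumulative log-sum). The function name `linear_schedule` is illustrative.

```python
import math


def linear_schedule(beta1: float, beta2: float, T: int):
    """Return (beta, alphabar) lists of length T + 1 for a linear schedule."""
    beta = [(beta2 - beta1) * t / T + beta1 for t in range(T + 1)]
    alphabar, log_sum = [], 0.0
    for b in beta:
        log_sum += math.log(1.0 - b)        # cumulative sum of log(alpha_t)
        alphabar.append(math.exp(log_sum))  # alphabar_t = product of alpha_s up to t
    return beta, alphabar


if __name__ == "__main__":
    beta, alphabar = linear_schedule(1e-4, 0.02, 400)
    for t in (0, 100, 200, 400):
        signal = math.sqrt(alphabar[t])        # weight on x_0 (sqrtab)
        noise = math.sqrt(1.0 - alphabar[t])   # weight on eps (sqrtmab)
        print(f"t={t:3d}  sqrtab={signal:.3f}  sqrtmab={noise:.3f}")
```

Because `sqrtab**2 + sqrtmab**2 == 1` at every step, the forward process trades signal for noise while keeping the overall variance fixed for unit-variance data. With the settings used later in this notebook (`betas=(1e-4, 0.02)`, `n_T=400`), `alphabar_t` at the final step is small, so `x_T` is close to pure Gaussian noise.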

In [9]:
class SimpleDiffusion(nn.Module):
    def __init__(self, nn_model, betas, n_T, device, drop_prob=0.1):
        super(SimpleDiffusion, self).__init__()
        self.nn_model = nn_model.to(device)

        # register_buffer allows accessing dictionary produced by ddpm_schedules
        # e.g. can access self.sqrtab later
        for k, v in ddpm_schedules(betas[0], betas[1], n_T).items():
            self.register_buffer(k, v)

        self.n_T = n_T
        self.device = device
        self.drop_prob = drop_prob
        self.loss_mse = nn.MSELoss()

    def forward(self, x, c):
        """
        this method is used in training, so samples t and noise randomly
        """

        _ts = torch.randint(1, self.n_T+1, (x.shape[0],)).to(self.device)  # t ~ Uniform{1, ..., n_T}
        noise = torch.randn_like(x)  # eps ~ N(0, 1)


        x_t = (
        ################################
        # Your code starts here
        ################################

        self.sqrtab[_ts,None,None,None]*x + self.sqrtmab[_ts,None,None,None]*noise

        ################################
        # Your code ends here
        ################################
        )  # This is the x_t, which is sqrt(alphabar) x_0 + sqrt(1-alphabar) * eps
        # We should predict the "error term" from this x_t. Loss is what we return.

        # dropout context with some probability
        context_mask = torch.bernoulli(torch.zeros_like(c)+self.drop_prob).to(self.device)

        # return MSE between added noise, and our predicted noise
        return self.loss_mse(noise, self.nn_model(x_t, c, _ts / self.n_T, context_mask))

    def sample(self, n_sample, size, device, guide_w = 0.0):
        # we follow the guidance sampling scheme described in 'Classifier-Free Diffusion Guidance'
        # to make the fwd passes efficient, we concat two versions of the dataset,
        # one with context_mask=0 and the other context_mask=1
        # we then mix the outputs with the guidance scale, w
        # where w>0 means more guidance

        x_i = torch.randn(n_sample, *size).to(device)  # x_T ~ N(0, 1), sample initial noise
        c_i = torch.arange(0,10).to(device) # context for us just cycles through the MNIST labels
        c_i = c_i.repeat(int(n_sample/c_i.shape[0]))

        # don't drop context at test time
        context_mask = torch.zeros_like(c_i).to(device)

        # double the batch
        c_i = c_i.repeat(2)
        context_mask = context_mask.repeat(2)
        context_mask[n_sample:] = 1. # makes second half of batch context free

        x_i_store = [] # keep track of generated steps in case want to plot something
        print()
        for i in range(self.n_T, 0, -1):
            print(f'sampling timestep {i}',end='\r')
            t_is = torch.tensor([i / self.n_T]).to(device)
            t_is = t_is.repeat(n_sample,1,1,1)

            # double batch
            x_i = x_i.repeat(2,1,1,1)
            t_is = t_is.repeat(2,1,1,1)

            z = torch.randn(n_sample, *size).to(device) if i > 1 else 0

            # split predictions and compute weighting
            eps = self.nn_model(x_i, c_i, t_is, context_mask)
            eps1 = eps[:n_sample]
            eps2 = eps[n_sample:]
            eps = (1+guide_w)*eps1 - guide_w*eps2
            x_i = x_i[:n_sample]

            ################################
            # Your code starts here
            ################################
            x_i = (self.oneover_sqrta[i] * (x_i - eps * self.mab_over_sqrtmab[i]) + self.sqrt_beta_t[i] * z)


            ################################
            # Your code ends here
            ################################

            if i%20==0 or i==self.n_T or i<8:
                x_i_store.append(x_i.detach().cpu().numpy())

        x_i_store = np.array(x_i_store)
        return x_i, x_i_store
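
The guidance step in `sample` combines the conditional prediction (`eps1`, context kept) and the unconditional prediction (`eps2`, context masked). A scalar sketch of that mixing rule is shown below. The function name `guided_eps` is illustrative, and real predictions are tensors rather than floats.

```python
def guided_eps(eps_cond: float, eps_uncond: float, guide_w: float) -> float:
    """Classifier-free guidance: extrapolate from the unconditional
    prediction towards (and past) the conditional one."""
    return (1.0 + guide_w) * eps_cond - guide_w * eps_uncond


if __name__ == "__main__":
    for w in (0.0, 0.5, 2.0):
        print(w, guided_eps(1.0, 0.5, w))
```

The same expression can be written as `eps_uncond + (1 + guide_w) * (eps_cond - eps_uncond)`, which makes the behavior easier to read: `guide_w = 0` returns the conditional prediction unchanged, and larger values push further along the direction in which conditioning changes the prediction.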

## Trainer

This section contains the training loop and the main function.

You should complete the inference code that uses the trained model to generate desired images.
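
One small piece of the inference code is turning a student ID into the list of digit labels to condition on. A minimal sketch follows, assuming the ID has the form shown in the instructions (letters around a run of digits). The helper name `digits_from_student_id` is hypothetical.

```python
def digits_from_student_id(student_id: str) -> list:
    """Keep only the ASCII digits of an ID such as 'A0123456J', in order."""
    return [int(ch) for ch in student_id if ch in "0123456789"]


if __name__ == "__main__":
    print(digits_from_student_id("A0123456J"))
```

Each resulting integer can then be used as the class label `c` for one generated digit.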

In [None]:
def set_seed(seed=42):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def train_mnist():

    imagesize = 32
    # hardcoding these here
    n_epoch = 20
    batch_size = 256
    n_T = 400 # 500
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    n_classes = 10
    n_feat = 128 # 128 ok, 256 better (but slower)
    lrate = 1e-4
    save_model = False
    save_dir = './data/diffusion_outputs10/'
    ws_test = [0.0, 0.5, 2.0] # strength of generative guidance

    ddpm = SimpleDiffusion(nn_model=ContextUnet(in_channels=1, n_feat=n_feat, n_classes=n_classes), betas=(1e-4, 0.02), n_T=n_T, device=device, drop_prob=0.1)
    ddpm.to(device)

    # optionally load a model
    # ddpm.load_state_dict(torch.load("./data/diffusion_outputs/ddpm_unet01_mnist_9.pth"))

    dataloader, _ = create_mnist_dataloaders(batch_size=batch_size, image_size=imagesize, num_workers=4)
    optim = torch.optim.Adam(ddpm.parameters(), lr=lrate)

    for ep in range(n_epoch):
        print(f'epoch {ep}')
        ddpm.train()

        # linear lrate decay
        optim.param_groups[0]['lr'] = lrate*(1-ep/n_epoch)

        pbar = tqdm(dataloader)
        loss_ema = None
        for x, c in pbar:
            optim.zero_grad()
            x = x.to(device)
            c = c.to(device)
            loss = ddpm(x, c)
            loss.backward()
            if loss_ema is None:
                loss_ema = loss.item()
            else:
                loss_ema = 0.95 * loss_ema + 0.05 * loss.item()
            pbar.set_description(f"loss: {loss_ema:.4f}")
            optim.step()

        # for eval, save an image of currently generated samples (top rows)
        # followed by real images (bottom rows)
        ddpm.eval()
        with torch.no_grad():
            n_sample = 4*n_classes
            for w_i, w in enumerate(ws_test):
                x_gen, x_gen_store = ddpm.sample(n_sample, (1, 32, 32), device, guide_w=w)

                # append some real images at bottom, order by class also
                x_real = torch.Tensor(x_gen.shape).to(device)
                for k in range(n_classes):
                    for j in range(int(n_sample/n_classes)):
                        try:
                            idx = torch.squeeze((c == k).nonzero())[j]
                        except IndexError:  # fewer than j+1 samples of class k in the last batch
                            idx = 0
                        x_real[k+(j*n_classes)] = x[idx]

                x_all = torch.cat([x_gen, x_real])
                grid = make_grid(x_all*-1 + 1, nrow=10)
                save_image(grid, save_dir + f"image_ep{ep}_w{w}.png")
                print('saved image at ' + save_dir + f"image_ep{ep}_w{w}.png")

                if ep%5==0 or ep == int(n_epoch-1):
                    # create gif of images evolving over time, based on x_gen_store
                    fig, axs = plt.subplots(nrows=int(n_sample/n_classes), ncols=n_classes,sharex=True,sharey=True,figsize=(8,3))
                    def animate_diff(i, x_gen_store):
                        print(f'gif animating frame {i} of {x_gen_store.shape[0]}', end='\r')
                        plots = []
                        for row in range(int(n_sample/n_classes)):
                            for col in range(n_classes):
                                axs[row, col].clear()
                                axs[row, col].set_xticks([])
                                axs[row, col].set_yticks([])
                                # plots.append(axs[row, col].imshow(x_gen_store[i,(row*n_classes)+col,0],cmap='gray'))
                                plots.append(axs[row, col].imshow(-x_gen_store[i,(row*n_classes)+col,0],cmap='gray',vmin=(-x_gen_store[i]).min(), vmax=(-x_gen_store[i]).max()))
                        return plots
                    ani = FuncAnimation(fig, animate_diff, fargs=[x_gen_store],  interval=200, blit=False, repeat=True, frames=x_gen_store.shape[0])
                    ani.save(save_dir + f"gif_ep{ep}_w{w}.gif", dpi=100, writer=PillowWriter(fps=5))
                    print('saved image at ' + save_dir + f"gif_ep{ep}_w{w}.gif")
        # optionally save model
        if save_model and ep == int(n_epoch-1):
            torch.save(ddpm.state_dict(), save_dir + f"model_{ep}.pth")
            print('saved model at ' + save_dir + f"model_{ep}.pth")

def generate_samples(model_path, save_dir, n_samples=40, image_size=(1, 32, 32), device="cuda"):
    device = torch.device(device)
    ddpm = SimpleDiffusion(nn_model=ContextUnet(in_channels=1, n_feat=128, n_classes=10),betas=(1e-4, 0.02),n_T=400,device=device,drop_prob=0.1)
    ddpm.load_state_dict(torch.load(model_path, map_location=device))
    ddpm.to(device)
    ddpm.eval()

    student_id = "0297375"
    digits = [int(d) for d in student_id]
    n_digits = len(digits)

    samples = []
    with torch.no_grad(): # do not compute grad when not training
        for digit in digits:
            # same loop as SimpleDiffusion.sample(), but run for one digit at a time
            c_i = torch.tensor([digit]).to(device)
            context_mask = torch.zeros_like(c_i).to(device)
            c_i = c_i.repeat(2)
            context_mask = context_mask.repeat(2)
            context_mask[1:] = 1.  # makes second half of batch context free
            x_i = torch.randn(1, *image_size).to(device)

            for i in range(ddpm.n_T, 0, -1):
                print(f'sampling timestep {i} for digit {digit}', end='\r')
                t_is = torch.tensor([i / ddpm.n_T]).to(device)
                t_is = t_is.repeat(1, 1, 1, 1)

                # Double batch
                x_i_double = x_i.repeat(2, 1, 1, 1)
                t_is = t_is.repeat(2, 1, 1, 1)

                # Noise prediction
                z = torch.randn(1, *image_size).to(device) if i > 1 else 0
                eps = ddpm.nn_model(x_i_double, c_i, t_is, context_mask)
                eps1 = eps[:1]
                eps2 = eps[1:]
                guide_w = 0.5  # guidance using 0.5, 0 and 2 also works
                eps = (1 + guide_w) * eps1 - guide_w * eps2

                x_i = ( ddpm.oneover_sqrta[i] * (x_i - eps * ddpm.mab_over_sqrtmab[i]) + ddpm.sqrt_beta_t[i] * z )

            samples.append(x_i)

    x_gen = torch.cat(samples, dim=0)
    grid = make_grid(x_gen * -1 + 1, nrow=n_digits)
    # the assignment asks for a single jpg file
    save_image(grid, os.path.join(save_dir, "generated_samples.jpg"))

if __name__ == "__main__":

    set_seed()
    #train_mnist()

    # Please use the generate function to generate your desired digit images.
    ################################
    # Your code starts here
    ################################

    generate_samples("./model_19.pth", "./data/diffusion_outputs10/", device="cpu")

    ################################
    # Your code ends here
    ################################
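
Two small bookkeeping rules in `train_mnist` are easy to check in isolation: the per-epoch linear learning-rate decay and the exponential moving average of the loss shown in the progress bar. The sketch below restates both in plain Python. The function names are illustrative.

```python
def linear_lr(base_lr: float, epoch: int, n_epoch: int) -> float:
    """Learning rate used for a given epoch: base_lr * (1 - epoch / n_epoch)."""
    return base_lr * (1.0 - epoch / n_epoch)


def update_ema(ema, value: float, decay: float = 0.95) -> float:
    """Exponential moving average; the first value initializes the average."""
    return value if ema is None else decay * ema + (1.0 - decay) * value


if __name__ == "__main__":
    for ep in (0, 10, 19):
        print(ep, linear_lr(1e-4, ep, 20))
```

With `n_epoch = 20`, the learning rate starts at `lrate` and falls to 5% of it in the last epoch. It never reaches zero because `ep` stops at `n_epoch - 1`.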


## Puzzle 1
How to generate a two-digit image using an end-to-end diffusion model?

* The key challenge is adapting the model to handle a 2-digit label space (00-99) instead of single digits (0-9)

In [None]:
class ContextUnet(nn.Module):
    def __init__(self, in_channels, n_feat = 256, n_classes=100):  # Changed n_classes to 100
        super(ContextUnet, self).__init__()

        self.in_channels = in_channels
        self.n_feat = n_feat
        self.n_classes = n_classes

        self.init_conv = ResidualConvBlock(in_channels, n_feat, is_res=True)

        self.down1 = UnetDown(n_feat, n_feat)
        self.down2 = UnetDown(n_feat, 2 * n_feat)

        self.to_vec = nn.Sequential(nn.AvgPool2d(7), nn.GELU())

        self.timeembed1 = EmbedFC(1, 2*n_feat)
        self.timeembed2 = EmbedFC(1, 1*n_feat)
        self.contextembed1 = EmbedFC(n_classes, 2*n_feat)
        self.contextembed2 = EmbedFC(n_classes, 1*n_feat)

        self.up0 = nn.Sequential(
            nn.ConvTranspose2d(2 * n_feat, 2 * n_feat, 8, 8),
            nn.GroupNorm(8, 2 * n_feat),
            nn.ReLU(),
        )

        self.up1 = UnetUp(4 * n_feat, n_feat)
        self.up2 = UnetUp(2 * n_feat, n_feat)
        self.out = nn.Sequential(
            nn.Conv2d(2 * n_feat, n_feat, 3, 1, 1),
            nn.GroupNorm(8, n_feat),
            nn.ReLU(),
            nn.Conv2d(n_feat, self.in_channels, 3, 1, 1),
        )

    def forward(self, x, c, t, context_mask):
        # Modified to handle two-digit numbers
        x = self.init_conv(x)
        down1 = self.down1(x)
        down2 = self.down2(down1)
        hiddenvec = self.to_vec(down2)

        # Convert context to one hot embedding for 100 classes
        c = nn.functional.one_hot(c, num_classes=self.n_classes).type(torch.float)

        # Mask out context if context_mask == 1
        context_mask = context_mask[:, None]
        context_mask = context_mask.repeat(1, self.n_classes)
        context_mask = (-1*(1-context_mask))
        c = c * context_mask

        # Embed context, time step
        cemb1 = self.contextembed1(c).view(-1, self.n_feat * 2, 1, 1)
        temb1 = self.timeembed1(t).view(-1, self.n_feat * 2, 1, 1)
        cemb2 = self.contextembed2(c).view(-1, self.n_feat, 1, 1)
        temb2 = self.timeembed2(t).view(-1, self.n_feat, 1, 1)

        up1 = self.up0(hiddenvec)
        up2 = self.up1(cemb1*up1+ temb1, down2)
        up3 = self.up2(cemb2*up2+ temb2, down1)
        out = self.out(torch.cat((up3, x), 1))
        return out
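
The model above only widens the label space. The source does not show how two-digit training examples are built. One plausible approach, assumed here and not taken from the source, is to pair two MNIST digits, place them side by side, and label the pair with a single class index from 0 to 99. The sketch below uses nested lists in place of image tensors, and all three helper names are hypothetical.

```python
def encode_label(tens: int, ones: int) -> int:
    """Map a pair of digits to one of 100 class indices (00..99)."""
    return 10 * tens + ones


def decode_label(label: int) -> tuple:
    """Recover (tens, ones) from a class index."""
    return divmod(label, 10)


def hconcat(left: list, right: list) -> list:
    """Place two images (lists of pixel rows) side by side."""
    assert len(left) == len(right), "images must have the same height"
    return [row_l + row_r for row_l, row_r in zip(left, right)]


if __name__ == "__main__":
    print(encode_label(2, 1), decode_label(21))
    print(hconcat([[1, 2], [3, 4]], [[5], [6]]))
```

How the concatenated image is resized or fed to the network is left open here.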

## Puzzle 2
Besides simply sampling the noise from a Gaussian distribution, what are some alternative sampling strategies that could improve the generation process?

* Noise selection: draw K random noises, compute a stability score for each as the cosine similarity between the noise and the noise recovered from its generated sample, and keep the noise with the highest score.

In [None]:
import torch.nn.functional as F

class NoiseSelectionDiffusion(SimpleDiffusion):
    def __init__(self, nn_model, betas, n_T, device, drop_prob=0.1):
        super().__init__(nn_model, betas, n_T, device, drop_prob)

    def noise_selection(self, K, c, size, device):
        # start below the minimum cosine similarity (-1) so a candidate is always kept
        best_noise, best_stability = None, float("-inf")
        for _ in range(K):
            noise = torch.randn(1, *size).to(device)
            stability = self.compute_inversion_stability(noise, c, size).item()
            if stability > best_stability:
                best_stability, best_noise = stability, noise
        return best_noise

    def compute_inversion_stability(self, noise, c, size):
        x_0 = self.sample_from_noise(noise, c, size)
        inverse_noise = self.inverse_noise(x_0, c)
        return F.cosine_similarity(noise.view(noise.shape[0], -1),
                                    inverse_noise.view(inverse_noise.shape[0], -1),
                                    dim=1)

    def sample_from_noise(self, noise, c, size):
        x_i = noise
        for i in range(self.n_T, 0, -1):
            t_is = torch.tensor([i / self.n_T]).to(self.device).repeat(noise.shape[0], 1, 1, 1)
            eps = self.nn_model(x_i, c, t_is, torch.zeros_like(c).to(self.device))
            z = torch.randn_like(x_i) if i > 1 else 0
            x_i = (self.oneover_sqrta[i] * (x_i - eps * self.mab_over_sqrtmab[i]) +
                    self.sqrt_beta_t[i] * z)
        return x_i

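    def inverse_noise(self, x_0, c):
        # NOTE: this re-noises x_0 with fresh Gaussian noise (the forward process),
        # so it is a stochastic stand-in rather than a deterministic inversion; c is unused
        x_t = x_0
        for i in range(1, self.n_T + 1):
            # single-step transition q(x_t | x_{t-1}); the cumulative sqrtab/sqrtmab
            # coefficients only apply when jumping directly from x_0 to x_t
            x_t = torch.sqrt(self.alpha_t[i]) * x_t + self.sqrt_beta_t[i] * torch.randn_like(x_t)
        return x_t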

def generate_samples(model_path, save_dir, K=100, device="cuda"):
    device = torch.device(device)
    ddpm = NoiseSelectionDiffusion(
        nn_model=ContextUnet(in_channels=1, n_feat=128, n_classes=10),
        betas=(1e-4, 0.02),
        n_T=400,
        device=device
    )
    #... all the same as above
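
The selection logic itself is independent of the diffusion model: draw K candidates, score each, keep the best. The plain-Python sketch below isolates that pattern with flat vectors and a toy score function standing in for the inversion-stability score. The names `cosine_similarity` and `select_best_noise` are illustrative, and the toy score is an assumption made only for demonstration.

```python
import math
import random


def cosine_similarity(a: list, b: list) -> float:
    """Cosine of the angle between two flat vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def select_best_noise(K: int, dim: int, score_fn, rng: random.Random) -> list:
    """Draw K Gaussian noise vectors and keep the one with the highest score."""
    best_noise, best_score = None, float("-inf")
    for _ in range(K):
        noise = [rng.gauss(0.0, 1.0) for _ in range(dim)]
        score = score_fn(noise)
        if score > best_score:
            best_noise, best_score = noise, score
    return best_noise


if __name__ == "__main__":
    rng = random.Random(0)
    reference = [1.0, 0.0, 0.0, 0.0]
    # toy score: similarity to a fixed direction (stand-in for inversion stability)
    best = select_best_noise(100, 4, lambda n: cosine_similarity(n, reference), rng)
    print(cosine_similarity(best, reference))
```

Starting `best_score` at negative infinity means the first candidate is always accepted, so the function never returns `None` for `K >= 1`. In the diffusion setting each score requires a full sampling pass, so the cost grows linearly with K.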
