Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved.
Portions of this notebook consist of AI-generated content.

Permission is hereby granted, free of charge, to any person obtaining a copy

of this software and associated documentation files (the "Software"), to deal

in the Software without restriction, including without limitation the rights

to use, copy, modify, merge, publish, distribute, sublicense, and/or sell

copies of the Software, and to permit persons to whom the Software is

furnished to do so, subject to the following conditions:



The above copyright notice and this permission notice shall be included in all

copies or substantial portions of the Software.



THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR

IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,

FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE

AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER

LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,

OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE

SOFTWARE.

In [1]:
import sys, platform, torch, torchvision
# Print the interpreter / library versions and whether torch can see a
# CUDA-API device, so the run environment is recorded with the notebook.
cuda_ok = torch.cuda.is_available()
for label, value in (
    ("Python:", sys.version.split()[0]),
    ("PyTorch:", torch.__version__),
    ("Torchvision:", torchvision.__version__),
    ("CUDA available:", cuda_ok),
):
    print(label, value)
if cuda_ok:
    # Only query the device name when a device is actually present.
    print("CUDA device:", torch.cuda.get_device_name(0))

Python: 3.10.19
PyTorch: 2.10.0.dev20251101+rocm7.0
Torchvision: 0.25.0.dev20251102+rocm7.0
CUDA available: True
CUDA device: AMD Radeon AI PRO R9700


In [2]:
import os
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader

import torchvision
from torchvision import transforms
from torchvision.utils import make_grid

import matplotlib.pyplot as plt
from tqdm import tqdm


In [3]:
# Module-level compute device; later cells read this global directly.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# NOTE(review): this is assigned after torch.cuda.is_available() has already
# been queried on the line above, so it may come too late to change which GPU
# this process uses -- confirm, and if pinning to GPU 3 is intended, set the
# variable before torch first touches CUDA.
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
print(f"Using device: {device}")

# Single source of hyper-parameters for the dataset, diffusion schedule,
# U-Net and training loop.
config = {
    # Data location and the side length used by the resize / centre-crop.
    "dataset": {
        "root": "~/data/flowers102",
        "image_size": 64,
    },
    # Arguments of LinearNoiseScheduler.
    "diffusion": {
        "num_timesteps": 1000,
        "beta_start": 1e-4,
        "beta_end": 2e-2,
    },
    # Arguments read by UNet; `down_sample` needs one entry per consecutive
    # pair in `down_channels` (UNet asserts this).
    "model": {
        "im_channels": 3,  # RGB
        "im_size": 64,
        "down_channels": [128, 256, 512, 512],
        "mid_channels": [512, 512, 512],
        "down_sample": [True, True, False],
        "time_emb_dim": 256,
        "num_down_layers": 3,
        "num_mid_layers": 3,
        "num_up_layers": 3,
        "num_heads": 4,
    },
    # Optimisation, checkpoint and sampling settings; `task_name` doubles as
    # the output directory.
    "train": {
        "task_name": "flowers_ddpm",
        "batch_size": 8,
        "num_epochs": 2000,
        "lr": 1e-4,
        "ckpt_name": "ddpm_flowers102.ckpt",
        "num_samples": 4,
        "num_grid_rows": 2,
    },
}

# Make sure the run directory and its `samples` sub-directory exist.
output_root = Path(config["train"]["task_name"])
for directory in (output_root, output_root / "samples"):
    directory.mkdir(parents=True, exist_ok=True)


Using device: cuda


In [4]:
class LinearNoiseScheduler:
    """DDPM noise schedule with linearly spaced betas.

    The constructor builds ``num_timesteps`` betas between ``beta_start`` and
    ``beta_end`` and caches the derived tables (alphas, their cumulative
    product and the two square-root tables). Each method moves the tables it
    needs to the device of its input tensor before indexing them.
    """

    def __init__(self, num_timesteps, beta_start, beta_end):
        self.num_timesteps = num_timesteps
        self.beta_start = beta_start
        self.beta_end = beta_end

        # beta_t, alpha_t = 1 - beta_t, alpha_bar_t = prod_{s<=t} alpha_s
        self.betas = torch.linspace(beta_start, beta_end, num_timesteps)
        self.alphas = 1.0 - self.betas
        self.alpha_cum_prod = torch.cumprod(self.alphas, dim=0)

        # Coefficients of x0 and of the noise in q(x_t | x_0).
        self.sqrt_alpha_cum_prod = torch.sqrt(self.alpha_cum_prod)
        self.sqrt_one_minus_alpha_cum_prod = torch.sqrt(1.0 - self.alpha_cum_prod)

    def add_noise(self, x0, noise, t):
        """Forward process q(x_t | x_0).

        x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise

        x0:    (B, C, H, W) clean images
        noise: (B, C, H, W) noise to mix in
        t:     (B,) timestep index of each sample
        Returns the noised batch, same shape as ``x0``.
        """
        batch = x0.shape[0]
        # One coefficient per sample, padded with singleton axes so it
        # broadcasts over every non-batch dimension of x0.
        coeff_shape = (batch,) + (1,) * (x0.dim() - 1)

        signal_scale = self.sqrt_alpha_cum_prod.to(x0.device)[t].reshape(coeff_shape)
        noise_scale = self.sqrt_one_minus_alpha_cum_prod.to(x0.device)[t].reshape(coeff_shape)

        return signal_scale * x0 + noise_scale * noise

    def sample_prev_timestep(self, xt, noise_pred, t_scalar):
        """Reverse process p_theta(x_{t-1} | x_t) for one shared timestep.

        xt:         (B, C, H, W) current noisy batch
        noise_pred: (B, C, H, W) noise predicted by the model
        t_scalar:   int or one-element tensor; the same t is used for the
                    whole batch
        Returns ``(x_{t-1}, x0_pred)`` where ``x0_pred`` is the x0 estimate
        clamped to [-1, 1]. At t == 0 the mean is returned without noise.
        """
        step = int(t_scalar.item()) if isinstance(t_scalar, torch.Tensor) else int(t_scalar)

        dev = xt.device
        alpha_bars = self.alpha_cum_prod.to(dev)
        beta_t = self.betas.to(dev)[step]
        alpha_t = self.alphas.to(dev)[step]
        alpha_bar_t = alpha_bars[step]
        sqrt_one_minus_t = self.sqrt_one_minus_alpha_cum_prod.to(dev)[step]

        # Invert the forward equation to estimate x0, then clip to image range.
        x0_pred = ((xt - sqrt_one_minus_t * noise_pred) / torch.sqrt(alpha_bar_t)).clamp(-1.0, 1.0)

        # Posterior mean expressed through the predicted noise.
        mean = (xt - beta_t * noise_pred / sqrt_one_minus_t) / torch.sqrt(alpha_t)
        if step == 0:
            return mean, x0_pred

        # Posterior variance: (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t) * beta_t
        variance = (1.0 - alpha_bars[step - 1]) / (1.0 - alpha_bar_t) * beta_t
        std = torch.sqrt(variance)
        fresh_noise = torch.randn_like(xt)
        return mean + std * fresh_noise, x0_pred


In [None]:
def get_time_embedding(time_steps, temb_dim):
    """Sinusoidal timestep embedding (Transformer-style positional encoding).

    time_steps: (B,) tensor of timestep indices
    temb_dim:   embedding width; must be even
    return:     (B, temb_dim) -- first half sines, second half cosines
    """
    assert temb_dim % 2 == 0, "time embedding dimension must be divisible by 2"

    n_freqs = temb_dim // 2
    # Divisors 10000^(k / n_freqs) for k = 0 .. n_freqs - 1.
    divisors = 10000 ** (
        torch.arange(n_freqs, dtype=torch.float32, device=time_steps.device) / n_freqs
    )

    # (B, 1) / (1, n_freqs) broadcasts to (B, n_freqs).
    angles = time_steps.float()[:, None] / divisors[None, :]

    return torch.cat((angles.sin(), angles.cos()), dim=-1)


In [6]:
class DownBlock(nn.Module):
    """Encoder stage: ``num_layers`` x (ResNet block -> self-attention), then
    an optional stride-2 convolution that halves the spatial size.

    Channel counts go through ``GroupNorm(8, ...)`` and ``out_channels`` is
    the attention embedding width, so they must suit 8 groups / ``num_heads``.
    """

    def __init__(self, in_channels, out_channels, t_emb_dim,
                 down_sample=True, num_heads=4, num_layers=1):
        super().__init__()
        self.num_layers = num_layers
        self.down_sample = down_sample

        # Only the first layer consumes `in_channels`; the rest are out -> out.
        layer_inputs = [in_channels if idx == 0 else out_channels
                        for idx in range(num_layers)]

        # First half of each ResNet block: norm -> SiLU -> 3x3 conv.
        self.resnet_conv_first = nn.ModuleList([
            nn.Sequential(
                nn.GroupNorm(8, ch),
                nn.SiLU(),
                nn.Conv2d(ch, out_channels, kernel_size=3, stride=1, padding=1),
            )
            for ch in layer_inputs
        ])

        # Projects the time embedding to a per-channel bias for each layer.
        self.t_emb_layers = nn.ModuleList([
            nn.Sequential(nn.SiLU(), nn.Linear(t_emb_dim, out_channels))
            for _ in layer_inputs
        ])

        # Second half of each ResNet block.
        self.resnet_conv_second = nn.ModuleList([
            nn.Sequential(
                nn.GroupNorm(8, out_channels),
                nn.SiLU(),
                nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            )
            for _ in layer_inputs
        ])

        self.attention_norms = nn.ModuleList([
            nn.GroupNorm(8, out_channels) for _ in layer_inputs
        ])

        self.attentions = nn.ModuleList([
            nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
            for _ in layer_inputs
        ])

        # 1x1 conv on the residual path so it matches `out_channels`.
        self.residual_input_conv = nn.ModuleList([
            nn.Conv2d(ch, out_channels, kernel_size=1) for ch in layer_inputs
        ])

        if self.down_sample:
            self.down_sample_conv = nn.Conv2d(
                out_channels, out_channels, kernel_size=4, stride=2, padding=1
            )
        else:
            self.down_sample_conv = nn.Identity()

    def forward(self, x, t_emb):
        """x: (B, C_in, H, W); t_emb: (B, t_emb_dim). Returns the stage output."""
        out = x
        stages = zip(
            self.resnet_conv_first,
            self.t_emb_layers,
            self.resnet_conv_second,
            self.residual_input_conv,
            self.attention_norms,
            self.attentions,
        )
        for conv1, time_proj, conv2, shortcut, attn_norm, attn in stages:
            # ResNet block with the time embedding added between the convs.
            hidden = conv1(out)
            hidden = hidden + time_proj(t_emb)[:, :, None, None]
            hidden = conv2(hidden)
            out = hidden + shortcut(out)

            # Self-attention over the H*W spatial positions.
            batch, channels, height, width = out.shape
            tokens = attn_norm(out.view(batch, channels, height * width))
            tokens = tokens.transpose(1, 2)  # (B, HW, C)
            attended, _ = attn(tokens, tokens, tokens)
            out = out + attended.transpose(1, 2).view(batch, channels, height, width)

        return self.down_sample_conv(out)


In [7]:
class MidBlock(nn.Module):
    """Bottleneck stage: one ResNet block, then ``num_layers`` x
    (self-attention -> ResNet block). Spatial size is unchanged.

    Holds ``num_layers + 1`` ResNet blocks and ``num_layers`` attention layers.
    """

    def __init__(self, in_channels, out_channels, t_emb_dim,
                 num_heads=4, num_layers=1):
        super().__init__()
        self.num_layers = num_layers

        # One more ResNet block than attention layers; only the first one
        # consumes `in_channels`.
        resnet_inputs = [in_channels if idx == 0 else out_channels
                         for idx in range(num_layers + 1)]

        self.resnet_conv_first = nn.ModuleList([
            nn.Sequential(
                nn.GroupNorm(8, ch),
                nn.SiLU(),
                nn.Conv2d(ch, out_channels, kernel_size=3, stride=1, padding=1),
            )
            for ch in resnet_inputs
        ])

        # Per-block projection of the time embedding to a channel bias.
        self.t_emb_layers = nn.ModuleList([
            nn.Sequential(nn.SiLU(), nn.Linear(t_emb_dim, out_channels))
            for _ in resnet_inputs
        ])

        self.resnet_conv_second = nn.ModuleList([
            nn.Sequential(
                nn.GroupNorm(8, out_channels),
                nn.SiLU(),
                nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            )
            for _ in resnet_inputs
        ])

        self.attention_norms = nn.ModuleList([
            nn.GroupNorm(8, out_channels) for _ in range(num_layers)
        ])

        self.attentions = nn.ModuleList([
            nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
            for _ in range(num_layers)
        ])

        # 1x1 conv so the residual path matches `out_channels`.
        self.residual_input_conv = nn.ModuleList([
            nn.Conv2d(ch, out_channels, kernel_size=1) for ch in resnet_inputs
        ])

    def _resnet(self, idx, feat, t_emb):
        """Apply the idx-th time-conditioned ResNet block to ``feat``."""
        hidden = self.resnet_conv_first[idx](feat)
        hidden = hidden + self.t_emb_layers[idx](t_emb)[:, :, None, None]
        hidden = self.resnet_conv_second[idx](hidden)
        return hidden + self.residual_input_conv[idx](feat)

    def forward(self, x, t_emb):
        """x: (B, C_in, H, W); t_emb: (B, t_emb_dim). Returns (B, C_out, H, W)."""
        out = self._resnet(0, x, t_emb)

        for idx in range(self.num_layers):
            # Self-attention over the H*W spatial positions.
            batch, channels, height, width = out.shape
            tokens = self.attention_norms[idx](out.view(batch, channels, height * width))
            tokens = tokens.transpose(1, 2)  # (B, HW, C)
            attended, _ = self.attentions[idx](tokens, tokens, tokens)
            out = out + attended.transpose(1, 2).view(batch, channels, height, width)

            out = self._resnet(idx + 1, out, t_emb)

        return out


In [None]:
class UpBlock(nn.Module):
    """Decoder stage: optional 2x transposed-conv upsample, concatenation with
    the skip feature, then ``num_layers`` x (ResNet block -> self-attention).
    """

    def __init__(
        self,
        in_main_channels,
        skip_channels,
        out_channels,
        t_emb_dim,
        up_sample=True,
        num_heads=4,
        num_layers=1,
    ):
        super().__init__()
        self.num_layers = num_layers
        self.up_sample = up_sample

        # Upsampling keeps the channel count; only H and W change.
        if self.up_sample:
            self.up_sample_conv = nn.ConvTranspose2d(
                in_main_channels, in_main_channels,
                kernel_size=4, stride=2, padding=1,
            )
        else:
            self.up_sample_conv = nn.Identity()

        # The first layer sees the concatenated main + skip channels; later
        # layers are out -> out.
        in_after_concat = in_main_channels + skip_channels
        layer_inputs = [in_after_concat if idx == 0 else out_channels
                        for idx in range(num_layers)]

        self.resnet_conv_first = nn.ModuleList([
            nn.Sequential(
                nn.GroupNorm(8, ch),
                nn.SiLU(),
                nn.Conv2d(ch, out_channels, kernel_size=3, stride=1, padding=1),
            )
            for ch in layer_inputs
        ])

        # Per-layer projection of the time embedding to a channel bias.
        self.t_emb_layers = nn.ModuleList([
            nn.Sequential(nn.SiLU(), nn.Linear(t_emb_dim, out_channels))
            for _ in layer_inputs
        ])

        self.resnet_conv_second = nn.ModuleList([
            nn.Sequential(
                nn.GroupNorm(8, out_channels),
                nn.SiLU(),
                nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            )
            for _ in layer_inputs
        ])

        self.attention_norms = nn.ModuleList([
            nn.GroupNorm(8, out_channels) for _ in layer_inputs
        ])

        self.attentions = nn.ModuleList([
            nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
            for _ in layer_inputs
        ])

        # 1x1 conv so the residual path matches `out_channels`.
        self.residual_input_conv = nn.ModuleList([
            nn.Conv2d(ch, out_channels, kernel_size=1) for ch in layer_inputs
        ])

    def forward(self, x, out_down, t_emb):
        """
        x: feature map on the up path (B, C_main, H, W)
        out_down: skip-connection feature (B, C_skip, H or 2H, W or 2W); its
                  spatial size must equal that of ``x`` after upsampling
        t_emb: (B, t_emb_dim) time embedding
        """
        upsampled = self.up_sample_conv(x)
        # Concatenate the skip feature: (B, C_main + C_skip, H, W).
        out = torch.cat([upsampled, out_down], dim=1)

        stages = zip(
            self.resnet_conv_first,
            self.t_emb_layers,
            self.resnet_conv_second,
            self.residual_input_conv,
            self.attention_norms,
            self.attentions,
        )
        for conv1, time_proj, conv2, shortcut, attn_norm, attn in stages:
            # ResNet block with the time embedding added between the convs.
            hidden = conv1(out)
            hidden = hidden + time_proj(t_emb)[:, :, None, None]
            hidden = conv2(hidden)
            out = hidden + shortcut(out)

            # Self-attention over the H*W spatial positions.
            batch, channels, height, width = out.shape
            tokens = attn_norm(out.view(batch, channels, height * width))
            tokens = tokens.transpose(1, 2)  # (B, HW, C)
            attended, _ = attn(tokens, tokens, tokens)
            out = out + attended.transpose(1, 2).view(batch, channels, height, width)

        return out


In [9]:
class UNet(nn.Module):
    """Time-conditioned U-Net mapping a noisy image batch and its timesteps
    to a tensor of the same shape (used as the noise prediction).

    Keys read from ``model_config``: ``im_channels``, ``down_channels``,
    ``mid_channels``, ``down_sample``, ``time_emb_dim``, ``num_down_layers``,
    ``num_mid_layers``, ``num_up_layers`` and ``num_heads``.
    """

    def __init__(self, model_config):
        super().__init__()

        im_channels = model_config["im_channels"]
        self.down_channels = model_config["down_channels"]
        self.mid_channels = model_config["mid_channels"]
        self.down_sample = model_config["down_sample"]
        self.t_emb_dim = model_config["time_emb_dim"]
        self.num_down_layers = model_config["num_down_layers"]
        self.num_mid_layers = model_config["num_mid_layers"]
        self.num_up_layers = model_config["num_up_layers"]
        num_heads = model_config["num_heads"]

        # The bottleneck must start and end at the deepest encoder width, and
        # there is one down-sample flag per encoder stage.
        assert self.mid_channels[0] == self.down_channels[-1]
        assert self.mid_channels[-1] == self.down_channels[-1]
        assert len(self.down_sample) == len(self.down_channels) - 1

        # MLP applied to the sinusoidal time embedding.
        self.t_proj = nn.Sequential(
            nn.Linear(self.t_emb_dim, self.t_emb_dim),
            nn.SiLU(),
            nn.Linear(self.t_emb_dim, self.t_emb_dim),
        )

        # Lift the image to the first encoder width.
        self.conv_in = nn.Conv2d(im_channels, self.down_channels[0], kernel_size=3, padding=1)

        # Encoder: one DownBlock per consecutive pair of `down_channels`.
        down_pairs = zip(self.down_channels[:-1], self.down_channels[1:])
        self.downs = nn.ModuleList([
            DownBlock(
                in_channels=c_in,
                out_channels=c_out,
                t_emb_dim=self.t_emb_dim,
                down_sample=self.down_sample[level],
                num_heads=num_heads,
                num_layers=self.num_down_layers,
            )
            for level, (c_in, c_out) in enumerate(down_pairs)
        ])

        # Bottleneck: one MidBlock per consecutive pair of `mid_channels`.
        mid_pairs = zip(self.mid_channels[:-1], self.mid_channels[1:])
        self.mids = nn.ModuleList([
            MidBlock(
                in_channels=c_in,
                out_channels=c_out,
                t_emb_dim=self.t_emb_dim,
                num_heads=num_heads,
                num_layers=self.num_mid_layers,
            )
            for c_in, c_out in mid_pairs
        ])

        # Decoder: walk the encoder levels deepest-first. Each UpBlock takes
        # the skip saved at the *input* of the matching DownBlock, hence
        # `down_channels[level]` skip channels.
        self.ups = nn.ModuleList()
        main_channels = self.mid_channels[-1]
        for level in range(len(self.down_channels) - 2, -1, -1):
            skip_ch = self.down_channels[level]
            # The shallowest stage narrows to 32 channels before the output conv.
            out_ch = skip_ch if level > 0 else 32
            self.ups.append(
                UpBlock(
                    in_main_channels=main_channels,
                    skip_channels=skip_ch,
                    out_channels=out_ch,
                    t_emb_dim=self.t_emb_dim,
                    up_sample=self.down_sample[level],
                    num_heads=num_heads,
                    num_layers=self.num_up_layers,
                )
            )
            main_channels = out_ch

        self.final_channels = main_channels
        self.norm_out = nn.GroupNorm(8, self.final_channels)
        self.conv_out = nn.Conv2d(self.final_channels, im_channels, kernel_size=3, padding=1)

    def forward(self, x, t):
        """x: (B, C, H, W) noisy images; t: (B,) timestep per image.

        Returns a tensor with the same shape as ``x``.
        """
        batch = x.shape[0]

        out = self.conv_in(x)  # -> (B, down_channels[0], H, W)

        # Sinusoidal embedding of the timesteps followed by the MLP.
        t_emb = self.t_proj(get_time_embedding(t.view(batch), self.t_emb_dim))

        # Encoder: remember each stage's input for the skip connections.
        skips = []
        for down in self.downs:
            skips.append(out)
            out = down(out, t_emb)

        for mid in self.mids:
            out = mid(out, t_emb)

        # Decoder: consume the skips in reverse (deepest first).
        for up in self.ups:
            out = up(out, skips.pop(), t_emb)

        # Output head: norm -> SiLU -> conv back to image channels.
        return self.conv_out(F.silu(self.norm_out(out)))


In [None]:
@torch.no_grad()
def debug_one_step_denoise(model, scheduler, train_loader, t_value=500, n_samples=6):
    """Plot real images, their noised version at ``t_value`` and the x0
    estimate recovered from a single model prediction.

    model:        noise-prediction network called as ``model(x_t, t)``
    scheduler:    provides ``add_noise`` and the alpha-bar tables
    train_loader: loader yielding ``(images, labels)``; one batch is drawn
    t_value:      timestep applied to every sample
    n_samples:    number of images taken from the batch (also images per row)

    Uses the module-level ``device``. Leaves the model in eval mode.
    """
    model.eval()

    first_batch, _ = next(iter(train_loader))
    imgs = first_batch[:n_samples].to(device)

    # Same timestep for every image in the batch.
    t = torch.full((imgs.size(0),), t_value, device=device, dtype=torch.long)

    # Forward process, then the model's noise estimate.
    noise = torch.randn_like(imgs)
    x_t = scheduler.add_noise(imgs, noise, t)
    noise_pred = model(x_t, t)

    # Invert the forward equation to estimate x0 and clip to image range.
    alpha_bar_t = scheduler.alpha_cum_prod.to(device)[t_value]
    sqrt_one_minus_t = scheduler.sqrt_one_minus_alpha_cum_prod.to(device)[t_value]
    x0_pred = ((x_t - sqrt_one_minus_t * noise_pred) / torch.sqrt(alpha_bar_t)).clamp(-1.0, 1.0)

    def to_unit_range(batch):
        # [-1, 1] -> [0, 1] for display.
        return (torch.clamp(batch, -1, 1) + 1) / 2

    panels = (
        ("Real x0", imgs),
        (f"x_t (t={t_value})", x_t),
        ("Predicted x0 (denoised)", x0_pred),
    )

    plt.figure(figsize=(14, 4))
    for column, (title, batch) in enumerate(panels, start=1):
        grid = make_grid(to_unit_range(batch), nrow=n_samples)
        plt.subplot(1, 3, column)
        plt.title(title)
        plt.imshow(grid.permute(1, 2, 0).cpu().numpy())
        plt.axis("off")

    plt.show()

In [11]:
from torchvision.datasets import Flowers102

image_size = config["dataset"]["image_size"]

# Preprocessing: resize, centre-crop to a square of side `image_size`,
# convert to a tensor and rescale the values.
transform = transforms.Compose(
    [
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        # Map from [0, 1] to [-1, 1]: (x - 0.5) / 0.5 per channel.
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)

# Flowers102 "train" split, downloaded to the configured root if needed.
train_dataset = Flowers102(
    config["dataset"]["root"],
    split="train",
    download=True,
    transform=transform,
)

# Shuffled mini-batches loaded by four worker processes into pinned memory.
loader_options = dict(
    batch_size=config["train"]["batch_size"],
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)
train_loader = DataLoader(train_dataset, **loader_options)

# Notebook display: (number of images, number of batches per epoch).
len(train_dataset), len(train_loader)


(1020, 128)

In [12]:
# Shorthand handles for the three config sections used from here on.
diff_cfg, model_cfg, train_cfg = (
    config[section] for section in ("diffusion", "model", "train")
)

# Noise schedule shared by training and sampling.
scheduler = LinearNoiseScheduler(
    diff_cfg["num_timesteps"],
    diff_cfg["beta_start"],
    diff_cfg["beta_end"],
)

# Noise-prediction network, moved to the selected device.
model = UNet(model_cfg)
model.to(device)

# Adam on all model parameters; the loss is the MSE between noise tensors.
optimizer = Adam(model.parameters(), lr=train_cfg["lr"])
criterion = nn.MSELoss()

# The checkpoint lives inside the task directory.
ckpt_path = Path(train_cfg["task_name"]) / train_cfg["ckpt_name"]

# Resume from saved weights when a checkpoint is present. Only the model
# state is restored here, not the optimizer state.
if ckpt_path.exists():
    print(f"Loading checkpoint from {ckpt_path}")
    model.load_state_dict(torch.load(ckpt_path, map_location=device))


In [None]:
# Train the noise-prediction model: for every batch, noise the images at a
# random timestep and regress the model output onto the noise that was added.
num_epochs = train_cfg["num_epochs"]
epoch_losses = []  # mean training loss of each completed epoch

# Per-epoch losses are appended to this text log.
log_path = Path(train_cfg["task_name"]) / "train_log.txt"

# The checkpoint is written to a sibling temp file and then renamed over the
# real one, so an interrupt in the middle of a save cannot leave a truncated
# checkpoint behind.
ckpt_tmp_path = ckpt_path.with_name(ckpt_path.name + ".tmp")

# `with` guarantees the log is closed even if training raises or is interrupted.
with open(log_path, "a", encoding="utf-8") as log_file:
    for epoch in range(num_epochs):
        model.train()
        losses = []

        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
        for imgs, _ in pbar:
            imgs = imgs.to(device)

            # 1. sample gaussian noise
            noise = torch.randn_like(imgs)

            # 2. sample timesteps t ~ Uniform({0,...,T-1})
            B = imgs.shape[0]
            t = torch.randint(
                low=0,
                high=diff_cfg["num_timesteps"],
                size=(B,),
                device=device,
                dtype=torch.long,
            )

            # 3. forward process: get x_t
            noisy_imgs = scheduler.add_noise(imgs, noise, t)

            # 4. use the U-Net to predict the noise
            noise_pred = model(noisy_imgs, t)

            # 5. loss = MSE(noise_pred, noise_real)
            loss = criterion(noise_pred, noise)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Read the scalar once; each .item() forces a device sync.
            loss_value = loss.item()
            losses.append(loss_value)
            pbar.set_postfix({"loss": f"{loss_value:.4f}"})

        avg_loss = sum(losses) / len(losses)
        epoch_losses.append(avg_loss)
        print(f"Epoch {epoch+1}/{num_epochs} - loss: {avg_loss:.4f}")
        # Flush after every epoch so the log is current if the run stops early.
        log_file.write(f"Epoch {epoch+1}/{num_epochs} - loss: {avg_loss:.4f}\n")
        log_file.flush()

        # Save the checkpoint each epoch (write to temp file, then rename).
        torch.save(model.state_dict(), ckpt_tmp_path)
        ckpt_tmp_path.replace(ckpt_path)

print("Training log saved to:", log_path)

Epoch 1/2000:  12%|██████▉                                                | 16/128 [00:36<04:11,  2.25s/it, loss=0.7443]

In [None]:
# Visualise the one-step x0 reconstruction at t = 99, 199, ..., 999; each
# call draws a fresh batch from the training loader.
for t_value in range(99, 1000, 100):
    debug_one_step_denoise(model, scheduler, train_loader, t_value=t_value, n_samples=6)

In [None]:
@torch.no_grad()
def sample_flowers(model, scheduler, num_samples=16, num_grid_rows=4, save_name="sample.png"):
    """Generate images by running the full reverse diffusion chain, then
    show them and save them as an image grid.

    model:         noise-prediction network called as ``model(x_t, t)``
    scheduler:     noise scheduler; its ``sample_prev_timestep`` performs
                   each reverse step and ``num_timesteps`` sets the chain length
    num_samples:   number of images to generate
    num_grid_rows: forwarded to ``make_grid`` as ``nrow``, i.e. the number of
                   images placed in each row of the grid
    save_name:     file name written under ``<task_name>/samples/``

    Reads the module-level ``config`` (image channels / size, task name) and
    ``device``. Leaves the model in eval mode. Returns nothing.
    """
    model.eval()

    # x_T: start from pure Gaussian noise.
    im_channels = config["model"]["im_channels"]
    im_size = config["model"]["im_size"]
    xt = torch.randn(num_samples, im_channels, im_size, im_size, device=device)

    # Walk t = T-1 ... 0; the whole batch shares the same timestep.
    for t in reversed(range(scheduler.num_timesteps)):
        t_batch = torch.full((xt.shape[0],), t, device=device, dtype=torch.long)
        noise_pred = model(xt, t_batch)

        # One reverse step, delegated to the scheduler so sampling and the
        # scheduler cannot drift apart. The x0 estimate it also returns is
        # not needed here.
        xt, _ = scheduler.sample_prev_timestep(xt, noise_pred, t)

    ims = torch.clamp(xt, -1.0, 1.0).cpu()
    ims = (ims + 1.0) / 2.0  # [-1, 1] -> [0, 1]

    grid = make_grid(ims, nrow=num_grid_rows)
    np_grid = grid.permute(1, 2, 0).numpy()

    plt.figure(figsize=(6, 6))
    plt.imshow(np_grid)
    plt.axis("off")
    plt.show()

    out_path = Path(config["train"]["task_name"]) / "samples" / save_name
    torchvision.utils.save_image(grid, out_path)
    print("Saved samples to", out_path)

In [None]:
# Reload the latest checkpoint from disk (if any) before sampling, so the
# samples reflect the saved weights.
if ckpt_path.exists():
    model.load_state_dict(torch.load(ckpt_path, map_location=device))
    print("Loaded trained weights from", ckpt_path)
else:
    print("Warning: checkpoint not found, sampling from randomly initialized model")

# Take the sample count and grid layout from the shared config rather than
# repeating the literals here.
sample_flowers(
    model,
    scheduler,
    num_samples=config["train"]["num_samples"],
    num_grid_rows=config["train"]["num_grid_rows"],
    save_name="flowers_epoch_last.png",
)
