### 前言：本代码是根据DDPM论文和网上相关参考资料，从零搭建自己的DDPM模型

参考资料：

[DDPM原论文](https://arxiv.org/pdf/2006.11239)

[图科学实验室](https://mp.weixin.qq.com/s/Rj51LTCjbuX_bn7iALqMIg)

[大白话AI](https://www.bilibili.com/video/BV1tz4y1h7q1/?spm_id_from=333.337.search-card.all.click&vd_source=1a02178b1644ddc9b579739c3c1616b4)

[手写AI](https://www.bilibili.com/video/BV1BN41117NJ?spm_id_from=333.788.videopod.sections&vd_source=1a02178b1644ddc9b579739c3c1616b4)

# 导入相关模块

In [1]:
import torch
from torch.utils.data import Dataset
from torchvision.transforms import Compose, ToTensor, Lambda, Resize
from torch.utils.data import DataLoader
from natsort import natsorted
import glob
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
from torchvision.datasets.mnist import MNIST, FashionMNIST
from torchvision.utils import make_grid
from einops import rearrange
import imageio
import yaml
import os
from torch.optim import Adam
from tqdm import tqdm

# 自定义数据集类

In [2]:
class MyCustomDataset(Dataset):
    """Dataset over the ``*.png`` files located directly inside a directory.

    Every item is the image stored at the corresponding path (paths are
    ordered with ``natsorted``), opened with PIL, converted to RGB and, when
    a transform was supplied, passed through that transform.
    """

    def __init__(self, root, transform=None, mode="train"):
        """Collect the image paths.

        Args:
            root: Directory holding the images; only ``root/*.png`` is
                matched (no recursion into sub-directories).
            transform: Optional callable applied to each loaded PIL image.
            mode: Accepted for interface compatibility ("train" / "test");
                it is not used by this class.
        """
        png_pattern = root + "/*.png"
        self.img_paths = natsorted(glob.glob(png_pattern))
        self.transform = transform

    def __len__(self):
        """Return how many image paths were found."""
        return len(self.img_paths)

    def __getitem__(self, index):
        """Load image ``index`` as RGB and apply the optional transform."""
        # Force RGB so every sample has three channels regardless of how
        # the file was saved.
        image = Image.open(self.img_paths[index]).convert('RGB')
        if self.transform is None:
            return image
        return self.transform(image)

# MNIST和FashionMNIST数据集类

In [3]:
class MNISTDataset(Dataset):
    """Thin wrapper around torchvision's MNIST / FashionMNIST datasets.

    Args:
        dataset_name (str): 'MNIST' or 'FashionMNIST' (compared
            case-insensitively).
        root (str): Directory forwarded to torchvision as the dataset root.
        train (bool): Forwarded as torchvision's ``train`` flag.
        transform (callable, optional): Forwarded as torchvision's
            ``transform``.
        download (bool): Forwarded as torchvision's ``download`` flag.

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported names.
    """

    def __init__(self, dataset_name='MNIST', root='./datasets', train=True, transform=None, download=True):
        super().__init__()

        self.dataset_name = dataset_name.lower()
        self.train = train
        self.transform = transform

        # Pick the torchvision class first, then build it in one place.
        if self.dataset_name == 'mnist':
            dataset_cls = MNIST
        elif self.dataset_name == 'fashionmnist':
            dataset_cls = FashionMNIST
        else:
            raise ValueError("Unsupported dataset_name. Choose either 'MNIST' or 'FashionMNIST'.")

        self.dataset = dataset_cls(
            root=root,
            train=self.train,
            transform=self.transform,
            download=download,
        )

    def __len__(self):
        """Return the length of the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Return item ``index`` of the wrapped dataset unchanged.

        The training loop in this file reads element ``[0]`` of each batch
        as the image, i.e. the wrapped dataset is expected to yield
        tuples whose first entry is the image.
        """
        return self.dataset[index]


# 定义前向加噪过程和反向去噪过程

前向加噪过程
$$
x_t = \sqrt{1 - \bar{\alpha}_t} \cdot \epsilon + \sqrt{\bar{\alpha}_t} \cdot x_0
$$

反向去噪过程
$$
P(X_{t-1} \mid X_t, X_0) \sim \mathcal{N} \left( \frac{1}{\sqrt{\alpha_t}} \left( X_t - \frac{\beta_t}{\sqrt{1 - \bar{\alpha}_t}} \epsilon \right), \frac{\beta_t (1 - \bar{\alpha}_{t-1})}{1 - \bar{\alpha}_t} \right)
$$

In [4]:
# 噪声调度器
class LinearNoiseScheduler:
    """DDPM noise schedule with linearly spaced betas.

    Precomputes ``betas``, ``alphas = 1 - betas``, the cumulative product of
    the alphas and the two square roots needed by the forward process:

        x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps

    Both public methods accept the timestep either as a single value
    (Python int or 0-dim / 1-element tensor, shared by the whole batch) or
    as a tensor with one timestep per batch element.
    """

    def __init__(self, num_timesteps, beta_start, beta_end):
        """
        Args:
            num_timesteps: Number of discrete diffusion steps.
            beta_start: Beta value at the first step.
            beta_end: Beta value at the last step.
        """
        self.num_timesteps = num_timesteps
        self.beta_start = beta_start
        self.beta_end = beta_end

        self.betas = torch.linspace(beta_start, beta_end, num_timesteps)
        self.alphas = 1. - self.betas
        self.alpha_cum_prod = torch.cumprod(self.alphas, dim=0)
        self.sqrt_alpha_cum_prod = torch.sqrt(self.alpha_cum_prod)
        self.sqrt_one_minus_alpha_cum_prod = torch.sqrt(1 - self.alpha_cum_prod)

    @staticmethod
    def _as_timesteps(t, ref):
        """Return ``t`` as a flat int64 tensor on the device of ``ref``."""
        return torch.as_tensor(t, device=ref.device).long().reshape(-1)

    @staticmethod
    def _gather(values, t, ref):
        """Look up ``values[t]`` and shape it as (N, 1, ..., 1) so that it
        broadcasts against ``ref`` (N is 1 for a shared timestep)."""
        picked = values.to(ref.device)[t]
        return picked.reshape(-1, *([1] * (ref.dim() - 1)))

    def add_noise(self, original, noise, t):
        """Apply the forward (noising) process.

        Args:
            original: Clean images, shape (batch_size, ...).
            noise: Noise tensor with the same shape as ``original``.
            t: Timestep(s); one per batch element, or a single value that
                is applied to the whole batch.

        Returns:
            The noised images, same shape as ``original``.
        """
        t = self._as_timesteps(t, original)
        signal_scale = self._gather(self.sqrt_alpha_cum_prod, t, original)
        noise_scale = self._gather(self.sqrt_one_minus_alpha_cum_prod, t, original)
        return signal_scale * original + noise_scale * noise

    def sample_prev_timestep(self, xt, noise_pred, t):
        """Take one reverse (denoising) step from x_t to x_{t-1}.

        Args:
            xt: Current images, shape (batch_size, ...).
            noise_pred: Predicted noise, same shape as ``xt``.
            t: Current timestep(s); a single value or one per batch element.

        Returns:
            The posterior mean for entries whose timestep is 0, and the
            mean plus ``sigma * z`` (z standard normal) for all others.
        """
        t = self._as_timesteps(t, xt)

        # Posterior mean: (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_t)
        beta_t = self._gather(self.betas, t, xt)
        mean = xt - (beta_t * noise_pred) / self._gather(self.sqrt_one_minus_alpha_cum_prod, t, xt)
        mean = mean / torch.sqrt(self._gather(self.alphas, t, xt))

        at_first_step = t == 0
        if bool(at_first_step.all()):
            # Nothing in the batch needs noise; return before drawing any
            # random numbers so the global RNG state is left untouched.
            return mean

        # Index t - 1 would wrap around to the last entry for t == 0, so it
        # is clamped; those entries are zeroed out through the mask below.
        prev_t = (t - 1).clamp(min=0)
        variance = (1 - self._gather(self.alpha_cum_prod, prev_t, xt)) / (1.0 - self._gather(self.alpha_cum_prod, t, xt))
        variance = variance * beta_t
        sigma = variance ** 0.5
        sigma = sigma * (~at_first_step).reshape(sigma.shape).to(sigma.dtype)

        z = torch.randn(xt.shape).to(xt.device)
        return mean + sigma * z

# 定义用来预测噪声的UNet网络

### 时间步正弦位置嵌入

In [5]:
# 将时间步进行正弦位置嵌入.将离散的时间步转换为连续的、高维的向量表示
def get_time_embedding(time_steps, temb_dim):
    """Sinusoidal embedding of discrete timesteps.

    Args:
        time_steps: 1-D tensor of timesteps, shape (batch_size,).
        temb_dim: Size of the embedding; must be even.

    Returns:
        Tensor of shape (batch_size, temb_dim): the first half holds
        ``sin(t / factor)``, the second half ``cos(t / factor)``, with
        ``factor_i = 10000 ** (i / (temb_dim // 2))``.
    """
    assert temb_dim % 2 == 0, "time embedding dimension must be divisible by 2"

    half_dim = temb_dim // 2

    # One frequency divisor per sin/cos pair, created on the input's device.
    exponent = torch.arange(
        start=0, end=half_dim, dtype=torch.float32, device=time_steps.device
    ) / half_dim
    factor = 10000 ** exponent

    # (B,) -> (B, 1) -> (B, half_dim), then scale every column.
    angles = time_steps[:, None].expand(-1, half_dim) / factor
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

### 下采样模块

In [6]:
# 下采样模块
class DownBlock(nn.Module):
    """Encoder stage of the UNet.

    Runs ``num_layers`` rounds of (time-conditioned ResNet block followed by
    residual self-attention over the spatial positions) and finishes with an
    optional stride-2 convolution.
    """

    def __init__(self, in_channels, out_channels, t_emb_dim,
                 down_sample=True, num_heads=4, num_layers=1):
        """
        Args:
            in_channels: Channels of the incoming feature map.
            out_channels: Channels produced by every layer of this block.
            t_emb_dim: Size of the time-embedding vector.
            down_sample: If True, end with a 4x4 stride-2 convolution;
                otherwise the final step is an identity.
            num_heads: Number of attention heads.
            num_layers: Number of ResNet + attention rounds.
        """
        super().__init__()
        self.num_layers = num_layers
        self.down_sample = down_sample

        # Only the first layer sees in_channels; later ones see out_channels.
        layer_inputs = [in_channels if i == 0 else out_channels for i in range(num_layers)]

        # First half of each ResNet block: norm -> SiLU -> 3x3 conv.
        self.resnet_conv_first = nn.ModuleList(
            [self._norm_act_conv(ch, out_channels) for ch in layer_inputs]
        )

        # Projects the time embedding to one bias value per channel.
        self.t_emb_layers = nn.ModuleList(
            [nn.Sequential(nn.SiLU(), nn.Linear(t_emb_dim, out_channels)) for _ in layer_inputs]
        )

        # Second half of each ResNet block.
        self.resnet_conv_second = nn.ModuleList(
            [self._norm_act_conv(out_channels, out_channels) for _ in layer_inputs]
        )

        # Self-attention: a GroupNorm followed by multi-head attention.
        self.attention_norms = nn.ModuleList(
            [nn.GroupNorm(8, out_channels) for _ in layer_inputs]
        )
        self.attentions = nn.ModuleList(
            [nn.MultiheadAttention(out_channels, num_heads, batch_first=True) for _ in layer_inputs]
        )

        # 1x1 conv so the ResNet skip path matches out_channels.
        self.residual_input_conv = nn.ModuleList(
            [nn.Conv2d(ch, out_channels, kernel_size=1) for ch in layer_inputs]
        )

        if self.down_sample:
            self.down_sample_conv = nn.Conv2d(out_channels, out_channels, 4, 2, 1)
        else:
            self.down_sample_conv = nn.Identity()

    @staticmethod
    def _norm_act_conv(in_ch, out_ch):
        """GroupNorm(8) -> SiLU -> 3x3 convolution that keeps the spatial size."""
        return nn.Sequential(
            nn.GroupNorm(8, in_ch),
            nn.SiLU(),
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1),
        )

    def _resnet(self, i, x, t_emb):
        """ResNet block ``i``: two conv stages with the projected time
        embedding added in between, plus a 1x1-conv skip from the input."""
        hidden = self.resnet_conv_first[i](x)
        hidden = hidden + self.t_emb_layers[i](t_emb)[:, :, None, None]
        hidden = self.resnet_conv_second[i](hidden)
        return hidden + self.residual_input_conv[i](x)

    def _self_attention(self, i, x):
        """Attention output of layer ``i`` for ``x``, shaped like ``x``."""
        batch, chans, height, width = x.shape
        # Flatten the spatial grid, normalise, then put tokens on axis 1
        # because the attention modules were built with batch_first=True.
        tokens = self.attention_norms[i](x.reshape(batch, chans, height * width))
        tokens = tokens.transpose(1, 2)
        attended, _ = self.attentions[i](tokens, tokens, tokens)
        return attended.transpose(1, 2).reshape(batch, chans, height, width)

    def forward(self, x, t_emb):
        """Apply all layers and the optional downsampling.

        Args:
            x: Feature map, shape (B, in_channels, H, W).
            t_emb: Time embedding whose last dimension is ``t_emb_dim``.

        Returns:
            Feature map with ``out_channels`` channels.
        """
        out = x
        for i in range(self.num_layers):
            out = self._resnet(i, out, t_emb)
            out = out + self._self_attention(i, out)
        return self.down_sample_conv(out)

### 颈部模块

In [7]:
class MidBlock(nn.Module):
    """Bottleneck stage of the UNet.

    Layout: one ResNet block, then ``num_layers`` rounds of (residual
    self-attention followed by another ResNet block). There are therefore
    ``num_layers + 1`` ResNet blocks and ``num_layers`` attention layers.
    """

    def __init__(self, in_channels, out_channels, t_emb_dim, num_heads=4, num_layers=1):
        """
        Args:
            in_channels: Channels of the incoming feature map.
            out_channels: Channels produced by every ResNet block.
            t_emb_dim: Size of the time-embedding vector.
            num_heads: Number of attention heads.
            num_layers: Number of attention + ResNet rounds.
        """
        super().__init__()
        self.num_layers = num_layers

        # Input width of each of the num_layers + 1 ResNet blocks.
        resnet_inputs = [in_channels if i == 0 else out_channels for i in range(num_layers + 1)]

        # First half of each ResNet block: norm -> SiLU -> 3x3 conv.
        self.resnet_conv_first = nn.ModuleList(
            [self._norm_act_conv(ch, out_channels) for ch in resnet_inputs]
        )

        # Projects the time embedding to one bias value per channel.
        self.t_emb_layers = nn.ModuleList(
            [nn.Sequential(nn.SiLU(), nn.Linear(t_emb_dim, out_channels)) for _ in resnet_inputs]
        )

        # Second half of each ResNet block.
        self.resnet_conv_second = nn.ModuleList(
            [self._norm_act_conv(out_channels, out_channels) for _ in resnet_inputs]
        )

        # Self-attention layers (one fewer than the ResNet blocks).
        self.attention_norms = nn.ModuleList(
            [nn.GroupNorm(8, out_channels) for _ in range(num_layers)]
        )
        self.attentions = nn.ModuleList(
            [nn.MultiheadAttention(out_channels, num_heads, batch_first=True) for _ in range(num_layers)]
        )

        # 1x1 conv so the ResNet skip path matches out_channels.
        self.residual_input_conv = nn.ModuleList(
            [nn.Conv2d(ch, out_channels, kernel_size=1) for ch in resnet_inputs]
        )

    @staticmethod
    def _norm_act_conv(in_ch, out_ch):
        """GroupNorm(8) -> SiLU -> 3x3 convolution that keeps the spatial size."""
        return nn.Sequential(
            nn.GroupNorm(8, in_ch),
            nn.SiLU(),
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1),
        )

    def _resnet(self, i, x, t_emb):
        """ResNet block ``i``: two conv stages with the projected time
        embedding added in between, plus a 1x1-conv skip from the input."""
        hidden = self.resnet_conv_first[i](x)
        hidden = hidden + self.t_emb_layers[i](t_emb)[:, :, None, None]
        hidden = self.resnet_conv_second[i](hidden)
        return hidden + self.residual_input_conv[i](x)

    def _self_attention(self, i, x):
        """Attention output of layer ``i`` for ``x``, shaped like ``x``."""
        batch, chans, height, width = x.shape
        # Flatten the spatial grid, normalise, then put tokens on axis 1
        # because the attention modules were built with batch_first=True.
        tokens = self.attention_norms[i](x.reshape(batch, chans, height * width))
        tokens = tokens.transpose(1, 2)
        attended, _ = self.attentions[i](tokens, tokens, tokens)
        return attended.transpose(1, 2).reshape(batch, chans, height, width)

    def forward(self, x, t_emb):
        """Apply the leading ResNet block and all attention/ResNet rounds.

        Args:
            x: Feature map, shape (B, in_channels, H, W).
            t_emb: Time embedding whose last dimension is ``t_emb_dim``.

        Returns:
            Feature map of shape (B, out_channels, H, W).
        """
        out = self._resnet(0, x, t_emb)
        for i in range(self.num_layers):
            out = out + self._self_attention(i, out)
            out = self._resnet(i + 1, out, t_emb)
        return out

### 上采样模块

In [8]:
class UpBlock(nn.Module):
    """Decoder stage of the UNet.

    Optionally upsamples the incoming feature map with a transposed
    convolution, concatenates the matching encoder feature map along the
    channel axis, then runs ``num_layers`` rounds of (time-conditioned
    ResNet block followed by residual self-attention).
    """

    def __init__(self, in_channels, out_channels, t_emb_dim, up_sample=True, num_heads=4, num_layers=1):
        """
        Args:
            in_channels: Channels AFTER concatenation with the skip
                connection; the upsampling layer itself works on
                ``in_channels // 2`` channels.
            out_channels: Channels produced by every layer of this block.
            t_emb_dim: Size of the time-embedding vector.
            up_sample: If True, start with a 4x4 stride-2 transposed
                convolution; otherwise that step is an identity.
            num_heads: Number of attention heads.
            num_layers: Number of ResNet + attention rounds.
        """
        super().__init__()
        self.num_layers = num_layers
        self.up_sample = up_sample

        # Only the first layer sees in_channels; later ones see out_channels.
        layer_inputs = [in_channels if i == 0 else out_channels for i in range(num_layers)]

        # First half of each ResNet block: norm -> SiLU -> 3x3 conv.
        self.resnet_conv_first = nn.ModuleList(
            [self._norm_act_conv(ch, out_channels) for ch in layer_inputs]
        )

        # Projects the time embedding to one bias value per channel.
        self.t_emb_layers = nn.ModuleList(
            [nn.Sequential(nn.SiLU(), nn.Linear(t_emb_dim, out_channels)) for _ in layer_inputs]
        )

        # Second half of each ResNet block.
        self.resnet_conv_second = nn.ModuleList(
            [self._norm_act_conv(out_channels, out_channels) for _ in layer_inputs]
        )

        # Self-attention: a GroupNorm followed by multi-head attention.
        self.attention_norms = nn.ModuleList(
            [nn.GroupNorm(8, out_channels) for _ in layer_inputs]
        )
        self.attentions = nn.ModuleList(
            [nn.MultiheadAttention(out_channels, num_heads, batch_first=True) for _ in layer_inputs]
        )

        # 1x1 conv so the ResNet skip path matches out_channels.
        self.residual_input_conv = nn.ModuleList(
            [nn.Conv2d(ch, out_channels, kernel_size=1) for ch in layer_inputs]
        )

        if self.up_sample:
            skipless_channels = in_channels // 2
            self.up_sample_conv = nn.ConvTranspose2d(skipless_channels, skipless_channels, 4, 2, 1)
        else:
            self.up_sample_conv = nn.Identity()

    @staticmethod
    def _norm_act_conv(in_ch, out_ch):
        """GroupNorm(8) -> SiLU -> 3x3 convolution that keeps the spatial size."""
        return nn.Sequential(
            nn.GroupNorm(8, in_ch),
            nn.SiLU(),
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1),
        )

    def _resnet(self, i, x, t_emb):
        """ResNet block ``i``: two conv stages with the projected time
        embedding added in between, plus a 1x1-conv skip from the input."""
        hidden = self.resnet_conv_first[i](x)
        hidden = hidden + self.t_emb_layers[i](t_emb)[:, :, None, None]
        hidden = self.resnet_conv_second[i](hidden)
        return hidden + self.residual_input_conv[i](x)

    def _self_attention(self, i, x):
        """Attention output of layer ``i`` for ``x``, shaped like ``x``."""
        batch, chans, height, width = x.shape
        # Flatten the spatial grid, normalise, then put tokens on axis 1
        # because the attention modules were built with batch_first=True.
        tokens = self.attention_norms[i](x.reshape(batch, chans, height * width))
        tokens = tokens.transpose(1, 2)
        attended, _ = self.attentions[i](tokens, tokens, tokens)
        return attended.transpose(1, 2).reshape(batch, chans, height, width)

    def forward(self, x, out_down, t_emb):
        """Upsample, merge the skip connection and apply all layers.

        Args:
            x: Feature map coming from the previous (deeper) stage.
            out_down: Encoder feature map to concatenate on the channel
                axis; must match the (upsampled) ``x`` spatially.
            t_emb: Time embedding whose last dimension is ``t_emb_dim``.

        Returns:
            Feature map with ``out_channels`` channels.
        """
        out = torch.cat([self.up_sample_conv(x), out_down], dim=1)
        for i in range(self.num_layers):
            out = self._resnet(i, out, t_emb)
            out = out + self._self_attention(i, out)
        return out

### UNet整体结构

In [9]:
class Unet(nn.Module):
    """Noise-prediction UNet: conv stem, DownBlocks, MidBlocks, UpBlocks
    with skip connections, and a norm/SiLU/conv output head.

    ``model_config`` keys read here: ``im_channels``, ``down_channels``,
    ``mid_channels``, ``time_emb_dim``, ``down_sample``,
    ``num_down_layers``, ``num_mid_layers``, ``num_up_layers``.
    """

    def __init__(self, model_config):
        super().__init__()
        im_channels = model_config['im_channels']
        self.down_channels = model_config['down_channels']
        self.mid_channels = model_config['mid_channels']
        self.t_emb_dim = model_config['time_emb_dim']
        self.down_sample = model_config['down_sample']
        self.num_down_layers = model_config['num_down_layers']
        self.num_mid_layers = model_config['num_mid_layers']
        self.num_up_layers = model_config['num_up_layers']

        # Width of the last UpBlock's output and of the output head.
        head_channels = 16

        # The bottleneck must start at the encoder's last width and end at
        # the width the first UpBlock expects; one down_sample flag is
        # needed per DownBlock.
        assert self.mid_channels[0] == self.down_channels[-1], \
            "mid_channels[0] must equal down_channels[-1]"
        assert self.mid_channels[-1] == self.down_channels[-2], \
            "mid_channels[-1] must equal down_channels[-2]"
        assert len(self.down_sample) == len(self.down_channels) - 1, \
            "down_sample needs one entry per DownBlock"

        # Time-embedding MLP applied to the sinusoidal embedding.
        self.t_proj = nn.Sequential(
            nn.Linear(self.t_emb_dim, self.t_emb_dim),
            nn.SiLU(),
            nn.Linear(self.t_emb_dim, self.t_emb_dim)
        )

        self.up_sample = list(reversed(self.down_sample))

        # Maps the image channels to the first encoder width.
        self.conv_in = nn.Conv2d(im_channels, self.down_channels[0], kernel_size=3, padding=(1, 1))

        num_stages = len(self.down_channels) - 1

        # Encoder.
        self.downs = nn.ModuleList([])
        for i in range(num_stages):
            self.downs.append(DownBlock(self.down_channels[i], self.down_channels[i+1], self.t_emb_dim,
                                        down_sample=self.down_sample[i], num_layers=self.num_down_layers))

        # Bottleneck.
        self.mids = nn.ModuleList([])
        for i in range(len(self.mid_channels)-1):
            self.mids.append(MidBlock(self.mid_channels[i], self.mid_channels[i+1], self.t_emb_dim,
                                      num_layers=self.num_mid_layers))

        # Decoder, built from the deepest stage outwards. Each UpBlock
        # receives twice the skip width because of the concatenation.
        self.ups = nn.ModuleList([])
        for i in reversed(range(num_stages)):
            up_out_channels = self.down_channels[i-1] if i != 0 else head_channels
            self.ups.append(UpBlock(self.down_channels[i] * 2, up_out_channels,
                                    self.t_emb_dim, up_sample=self.down_sample[i], num_layers=self.num_up_layers))

        # Output head.
        self.norm_out = nn.GroupNorm(8, head_channels)
        self.conv_out = nn.Conv2d(head_channels, im_channels, kernel_size=3, padding=1)

    def forward(self, x, t):
        """Predict the noise contained in ``x`` at timestep(s) ``t``.

        Args:
            x: Noisy images, shape (B, im_channels, H, W).
            t: Timestep(s): an int, a 0-dim tensor, a 1-element tensor
                (shared by the batch) or a tensor of shape (B,). It is
                moved to ``x``'s device before use.

        Returns:
            Tensor of shape (B, im_channels, H, W).
        """
        # 1) Input convolution.
        out = self.conv_in(x)

        # 2) Time embedding. Flattening turns a scalar timestep into shape
        #    (1,), which then broadcasts over the batch inside the blocks.
        t = torch.as_tensor(t, device=x.device).long().reshape(-1)
        t_emb = self.t_proj(get_time_embedding(t, self.t_emb_dim))

        # 3) Encoder; remember each stage's input for the skip connections.
        skips = []
        for down in self.downs:
            skips.append(out)
            out = down(out, t_emb)

        # 4) Bottleneck.
        for mid in self.mids:
            out = mid(out, t_emb)

        # 5) Decoder; skips are consumed in reverse order.
        for up in self.ups:
            out = up(out, skips.pop(), t_emb)

        # 6) Output head.
        out = nn.functional.silu(self.norm_out(out))
        return self.conv_out(out)

# 定义用于训练的函数

In [10]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def train(config_path):
    """Train the noise-prediction UNet described by a YAML config file.

    The config must provide the sections ``diffusion_params``,
    ``dataset_params``, ``model_params`` and ``train_params``. After every
    epoch whose mean loss is the lowest seen during this call, the model
    weights are written to ``<task_name>/ddpm_<datasets_name>.pth``; if that
    file already exists when training starts, it is loaded first.

    Args:
        config_path: Path to the YAML configuration file.

    Raises:
        yaml.YAMLError: If the config file cannot be parsed (the error is
            printed and then re-raised).
        ValueError: If ``datasets_name`` is not 'MNIST', 'FashionMNIST' or
            'Custom'.
    """
    with open(config_path, 'r', encoding='utf-8') as file:
        try:
            config = yaml.safe_load(file)
        except yaml.YAMLError as exc:
            # Without a parsed config nothing below can run, so report the
            # problem and propagate it instead of continuing.
            print(exc)
            raise

    diffusion_config = config['diffusion_params']
    dataset_config = config['dataset_params']
    model_config = config['model_params']
    train_config = config['train_params']

    datasets_name = dataset_config['datasets_name']
    img_size = model_config['im_size']
    num_timesteps = diffusion_config['num_timesteps']

    # Only an explicit false value in the config disables training.
    if train_config['train_flag'] == False:
        print("train_flag为False,不需要训练")
        return

    # Noise scheduler for the forward process.
    scheduler = LinearNoiseScheduler(num_timesteps=num_timesteps,
                                     beta_start=diffusion_config['beta_start'],
                                     beta_end=diffusion_config['beta_end'])

    # Dataset. Both pipelines end by mapping pixels from [0, 1] to [-1, 1].
    if datasets_name in ('MNIST', 'FashionMNIST'):
        transform = Compose([
            ToTensor(),
            Lambda(lambda x: (x - 0.5) * 2),
        ])
        dataset = MNISTDataset(dataset_name=datasets_name, root=dataset_config['im_path'], transform=transform)
        # These batches are indexed with [0] to get the images.
        batch_has_labels = True
    elif datasets_name == 'Custom':
        transform = Compose([
            Resize((img_size, img_size)),
            ToTensor(),
            Lambda(lambda x: (x - 0.5) * 2),
        ])
        dataset = MyCustomDataset(root=dataset_config['im_path'], transform=transform)
        batch_has_labels = False
    else:
        raise ValueError("Unsupported dataset_name. Choose either 'MNIST' or 'FashionMNIST' or 'Custom'.")

    loader = DataLoader(dataset, batch_size=train_config['batch_size'], shuffle=True)

    # Model.
    model = Unet(model_config).to(device)
    model.train()

    # Output directory (created together with any missing parents).
    os.makedirs(train_config['task_name'], exist_ok=True)
    checkpoint_path = os.path.join(train_config['task_name'], f'ddpm_{datasets_name}.pth')

    # Resume from an existing checkpoint if there is one.
    if os.path.exists(checkpoint_path):
        print('Loading checkpoint as found one')
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))

    num_epochs = train_config['num_epochs']
    optimizer = Adam(model.parameters(), lr=train_config['lr'])
    criterion = torch.nn.MSELoss()

    # NOTE: this restarts at +inf on every call, so the first epoch of a
    # resumed run overwrites the stored checkpoint regardless of its loss.
    best_loss = float("inf")

    for epoch_idx in range(num_epochs):
        losses = []
        for batch in tqdm(loader, leave=False, desc=f"Epoch {epoch_idx + 1}/{num_epochs}", colour="#005500"):
            im = batch[0] if batch_has_labels else batch
            im = im.float().to(device)

            # Target noise and one random timestep per image.
            noise = torch.randn_like(im)
            t = torch.randint(0, num_timesteps, (im.shape[0],)).to(device)

            # Forward-diffuse the batch and let the model predict the noise.
            noisy_im = scheduler.add_noise(im, noise, t)
            noise_pred = model(noisy_im, t)

            loss = criterion(noise_pred, noise)
            losses.append(loss.item())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        epoch_loss = np.mean(losses)
        log_string = f"Loss at epoch {epoch_idx + 1}: {epoch_loss:.3f}"

        # Keep only the weights of the best epoch so far.
        if best_loss > epoch_loss:
            best_loss = epoch_loss
            torch.save(model.state_dict(), checkpoint_path)
            log_string += " --> Best model ever (stored)"

        print(log_string)

    print('Done Training ...')

# 定义用于推理的函数

In [11]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def _batch_to_grid_frame(xt, channels, grid_size, pad_img):
    """Convert a batch with values in [-1, 1] into one uint8 grid image.

    Returns an array of shape (grid_size*H, cols*W) for single-channel
    batches and (grid_size*H, cols*W, C) otherwise. ``pad_img`` (or None)
    holds the black tiles appended to fill the grid.
    """
    pixels = (((xt + 1) / 2).clamp(0, 1) * 255).byte().cpu().numpy()

    if channels == 1:
        tiles = np.squeeze(pixels, axis=1)            # (B, H, W)
        pattern = '(b1 b2) h w -> (b1 h) (b2 w)'
    else:
        tiles = np.transpose(pixels, (0, 2, 3, 1))    # (B, H, W, C)
        pattern = '(b1 b2) h w c -> (b1 h) (b2 w) c'

    if pad_img is not None:
        tiles = np.concatenate([tiles, pad_img], axis=0)

    return rearrange(tiles, pattern, b1=grid_size)


def sample(model, scheduler, train_config, model_config, diffusion_config, frames_per_gif=100):
    """Run the full reverse diffusion and save its progress as a GIF.

    Starting from Gaussian noise, the model and scheduler are applied for
    every timestep from ``num_timesteps - 1`` down to 0. Selected
    intermediate batches are rendered as image grids and written to
    ``<task_name>/samples/ddpm_mydata.gif``.

    Args:
        model: Noise-prediction network, called as ``model(xt, t)``.
        scheduler: Object providing ``sample_prev_timestep(xt, noise_pred, t)``.
        train_config: Dict with ``num_samples`` and ``task_name``.
        model_config: Dict with ``im_channels`` and ``im_size``.
        diffusion_config: Dict with ``num_timesteps``.
        frames_per_gif: Upper bound on the number of evenly spaced
            timesteps captured as frames; the last frame is additionally
            repeated ``frames_per_gif // 3`` times so the GIF lingers on it.

    Returns:
        The final denoised batch, shape (num_samples, C, H, W).

    Raises:
        ValueError: If ``num_timesteps`` is smaller than 1.
    """
    num_timesteps = diffusion_config['num_timesteps']
    num_samples = train_config['num_samples']
    c, h, w = model_config['im_channels'], model_config['im_size'], model_config['im_size']

    if num_timesteps < 1:
        raise ValueError("diffusion_config['num_timesteps'] must be at least 1")

    # Timesteps at which a frame is captured. A set of plain ints makes the
    # per-step membership test O(1); step 0 (the final result) is always
    # included.
    frame_steps = {int(step) for step in np.linspace(0, num_timesteps - 1, frames_per_gif, dtype=int)}
    frame_steps.add(0)

    # Grid layout is the same for every frame, so compute it once.
    grid_size = int(np.sqrt(num_samples))
    if grid_size ** 2 != num_samples:
        grid_size = int(np.ceil(np.sqrt(num_samples)))

    # Black tiles that fill the grid when num_samples is not a square.
    num_padding = grid_size ** 2 - num_samples
    if num_padding > 0:
        pad_shape = (num_padding, h, w) if c == 1 else (num_padding, h, w, c)
        pad_img = np.zeros(pad_shape, dtype=np.uint8)
    else:
        pad_img = None

    # Start from pure noise.
    xt = torch.randn((num_samples, c, h, w)).to(device)

    frames = []
    # reversed(range(...)) has no len(), so tell tqdm the total explicitly.
    for i in tqdm(reversed(range(num_timesteps)), total=num_timesteps,
                  desc="generating images progess", leave=False, colour="#005500"):
        t = torch.as_tensor(i).to(device)

        # Predict the noise, then step from x_t to x_{t-1}.
        noise_pred = model(xt, t.unsqueeze(0))
        xt = scheduler.sample_prev_timestep(xt, noise_pred, t)

        if i in frame_steps:
            frames.append(_batch_to_grid_frame(xt, c, grid_size, pad_img))

    gif_path = os.path.join(train_config['task_name'], 'samples', "ddpm_mydata.gif")
    os.makedirs(os.path.dirname(gif_path), exist_ok=True)

    with imageio.get_writer(gif_path, mode="I") as writer:
        for frame in frames:
            writer.append_data(frame)
        # Hold the final image for a while at the end of the animation.
        for _ in range(frames_per_gif // 3):
            writer.append_data(frames[-1])

    print(f'结果保存在:{gif_path}')

    # xt now holds the result of the last denoising step (timestep 0).
    return xt

In [12]:
def infer(config_path):
    """Load the trained model named in a YAML config and generate samples.

    The weights are read from ``<task_name>/ddpm_<datasets_name>.pth`` and
    the reverse diffusion is run through ``sample`` with gradients
    disabled.

    Args:
        config_path: Path to the YAML configuration file.

    Returns:
        Whatever ``sample`` returns (the final denoised batch).

    Raises:
        yaml.YAMLError: If the config file cannot be parsed (the error is
            printed and then re-raised).
    """
    with open(config_path, 'r', encoding='utf-8') as file:
        try:
            config = yaml.safe_load(file)
        except yaml.YAMLError as exc:
            # Nothing below can run without a config; report and propagate
            # rather than failing later on an undefined variable.
            print(exc)
            raise

    diffusion_config = config['diffusion_params']
    model_config = config['model_params']
    train_config = config['train_params']
    dataset_config = config['dataset_params']

    datasets_name = dataset_config['datasets_name']

    # Restore the trained weights.
    checkpoint_path = os.path.join(train_config['task_name'], f'ddpm_{datasets_name}.pth')
    model = Unet(model_config).to(device)
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.eval()

    scheduler = LinearNoiseScheduler(num_timesteps=diffusion_config['num_timesteps'],
                                     beta_start=diffusion_config['beta_start'],
                                     beta_end=diffusion_config['beta_end'])

    # No gradients are needed for sampling.
    with torch.no_grad():
        generated = sample(model, scheduler, train_config, model_config, diffusion_config)

    return generated

# 定义用于可视化图片的函数

In [13]:
def show_images(images, title=""):
    """Display a batch of images as a roughly square grid on a black figure.

    Args:
        images: ``torch.Tensor`` or NumPy array of shape (B, C, H, W) with
            C equal to 1 (grayscale) or 3 (RGB). Values are mapped from
            [-1, 1] to [0, 255]; anything outside that range is clipped.
        title: Text used as the figure title.

    Returns:
        None. An empty batch (B == 0) returns immediately without creating
        a figure.

    Raises:
        ValueError: If ``images`` is not 4-D or C is neither 1 nor 3.
    """
    # Move tensors to host memory as NumPy arrays.
    if isinstance(images, torch.Tensor):
        images = images.detach().cpu().numpy()

    # Reject anything that is not (B, 1, H, W) or (B, 3, H, W) up front so the
    # caller gets this error rather than an obscure transpose/imshow failure.
    if images.ndim != 4 or images.shape[1] not in (1, 3):
        raise ValueError("图像维度错误")

    # Channel-first -> layout expected by imshow: (B, H, W) or (B, H, W, 3).
    is_grayscale = images.shape[1] == 1
    if is_grayscale:
        images = images.squeeze(1)
    else:
        images = images.transpose(0, 2, 3, 1)

    num_images = len(images)
    if num_images == 0:
        # Nothing to draw; this also avoids dividing by rows == 0 below.
        return

    # Map pixel values from [-1, 1] to [0, 255] and convert to uint8.
    images = (images + 1) / 2 * 255
    images = np.clip(images, 0, 255).astype(np.uint8)

    # Near-square grid: floor(sqrt(B)) rows, enough columns to fit every image.
    rows = int(np.sqrt(num_images))
    cols = int(np.ceil(num_images / rows))

    # squeeze=False always yields a 2-D array of axes, so single-row,
    # single-column and 1x1 grids need no special handling.
    fig, axes = plt.subplots(rows, cols, figsize=(cols * 2, rows * 2),
                             facecolor='black', squeeze=False)
    fig.suptitle(title, fontsize=30, color='white')

    # ``axes.flat`` walks the grid row by row.
    for idx, ax in enumerate(axes.flat):
        ax.set_facecolor('black')
        ax.axis("off")
        if idx >= num_images:
            ax.remove()  # drop unused cells at the end of the grid
            continue
        if is_grayscale:
            ax.imshow(images[idx], cmap='gray', vmin=0, vmax=255)
        else:
            ax.imshow(images[idx])

    plt.tight_layout()
    plt.subplots_adjust(top=0.90)  # leave room for the title
    plt.show()

# 可视化数据集的第一个batch

In [14]:
# Helper that visualises the first batch of the configured dataset
def show_first_batch(config_path):
    """Print the shape of one batch from the configured dataset and plot it.

    Args:
        config_path: Path to the YAML configuration file. The function reads
            ``dataset_params`` (``datasets_name``, ``im_path``),
            ``model_params`` (``im_size``) and ``train_params``
            (``batch_size``).

    Raises:
        yaml.YAMLError: If the configuration file is not valid YAML.
        ValueError: If ``datasets_name`` is not 'MNIST', 'FashionMNIST' or
            'Custom'.
    """
    # Let a parse failure propagate instead of printing it and then failing
    # later on an unbound ``config``.
    with open(config_path, 'r') as file:
        config = yaml.safe_load(file)

    dataset_config = config['dataset_params']
    model_config = config['model_params']
    train_config = config['train_params']

    datasets_name = dataset_config['datasets_name']
    img_size = model_config['im_size']

    # Build the dataset; every branch scales pixels from [0, 1] to [-1, 1].
    if datasets_name in ('MNIST', 'FashionMNIST'):
        transform = Compose([
            ToTensor(),  # to tensor, values in [0, 1]
            Lambda(lambda x: (x - 0.5) * 2),  # [0, 1] -> [-1, 1]
        ])
        dataset = MNISTDataset(dataset_name=datasets_name, root=dataset_config['im_path'], transform=transform)
        # Batches from this dataset are indexed with [0] to get the images
        # (presumably (image, label) pairs -- confirm against MNISTDataset).
        images_at_index_zero = True
    elif datasets_name == 'Custom':
        transform = Compose([
            Resize((img_size, img_size)),
            ToTensor(),
            Lambda(lambda x: (x - 0.5) * 2),
        ])
        dataset = MyCustomDataset(root=dataset_config['im_path'], transform=transform)
        images_at_index_zero = False
    else:
        raise ValueError("Unsupported dataset_name. Choose either 'MNIST' or 'FashionMNIST' or 'Custom'.")

    loader = DataLoader(dataset, batch_size=train_config['batch_size'], shuffle=True)

    # Only the first batch is needed; an empty loader shows nothing.
    batch = next(iter(loader), None)
    if batch is None:
        return

    images = batch[0] if images_at_index_zero else batch
    print(images.shape)
    show_images(images, "Images in the first batch")

In [None]:
# Preview the first batch of the dataset configured in cfg.yaml.
config_path = 'cfg.yaml'
show_first_batch(config_path)

# 可视化前向加噪过程

In [16]:
def show_forward(config_path):
    """Visualise the forward (noising) process on one batch of the dataset.

    Shows the clean batch, then the same batch noised at the timesteps that
    correspond to 25%, 50%, 75% and 100% of the schedule length.

    Args:
        config_path: Path to the YAML configuration file. The function reads
            ``diffusion_params`` (``num_timesteps``, ``beta_start``,
            ``beta_end``), ``dataset_params`` (``datasets_name``,
            ``im_path``), ``model_params`` (``im_size``) and
            ``train_params`` (``batch_size``).

    Raises:
        yaml.YAMLError: If the configuration file is not valid YAML.
        ValueError: If ``datasets_name`` is not 'MNIST', 'FashionMNIST' or
            'Custom'.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Let a parse failure propagate instead of printing it and then failing
    # later on an unbound ``config``.
    with open(config_path, 'r') as file:
        config = yaml.safe_load(file)

    diffusion_config = config['diffusion_params']
    dataset_config = config['dataset_params']
    model_config = config['model_params']
    train_config = config['train_params']

    datasets_name = dataset_config['datasets_name']
    img_size = model_config['im_size']

    # Build the dataset; every branch scales pixels from [0, 1] to [-1, 1].
    if datasets_name in ('MNIST', 'FashionMNIST'):
        transform = Compose([
            ToTensor(),  # to tensor, values in [0, 1]
            Lambda(lambda x: (x - 0.5) * 2),  # [0, 1] -> [-1, 1]
        ])
        dataset = MNISTDataset(dataset_name=datasets_name, root=dataset_config['im_path'], transform=transform)
        # Batches from this dataset are indexed with [0] to get the images
        # (presumably (image, label) pairs -- confirm against MNISTDataset).
        images_at_index_zero = True
    elif datasets_name == 'Custom':
        transform = Compose([
            Resize((img_size, img_size)),
            ToTensor(),
            Lambda(lambda x: (x - 0.5) * 2),
        ])
        dataset = MyCustomDataset(root=dataset_config['im_path'], transform=transform)
        images_at_index_zero = False
    else:
        raise ValueError("Unsupported dataset_name. Choose either 'MNIST' or 'FashionMNIST' or 'Custom'.")

    loader = DataLoader(dataset, batch_size=train_config['batch_size'], shuffle=True)

    # Take the schedule from the config so the preview uses the same
    # diffusion parameters that infer() passes to the scheduler.
    num_timesteps = diffusion_config['num_timesteps']
    scheduler = LinearNoiseScheduler(num_timesteps=num_timesteps,
                                     beta_start=diffusion_config['beta_start'],
                                     beta_end=diffusion_config['beta_end'])

    # Only the first batch is needed; an empty loader shows nothing.
    batch = next(iter(loader), None)
    if batch is None:
        return

    imgs = (batch[0] if images_at_index_zero else batch).to(device)
    n = imgs.shape[0]

    # Clean images first, for reference.
    show_images(imgs, "Original Images")

    # Fractions of the schedule at which the noised batch is displayed.
    for percent in (0.25, 0.5, 0.75, 1.0):
        # Convert the fraction to a 0-based timestep, clamped to the valid range.
        t_step = max(0, min(int(percent * num_timesteps) - 1, num_timesteps - 1))

        # The same timestep is applied to every image in the batch.
        t = torch.full((n,), t_step, device=device, dtype=torch.long)

        # Fresh Gaussian noise with the shape, dtype and device of the batch.
        eta = torch.randn_like(imgs)

        noisy_image = scheduler.add_noise(original=imgs, noise=eta, t=t)

        show_images(noisy_image, f"DDPM Noisy Images {int(percent * 100)}%")

In [None]:
# Visualise the forward noising process on one batch of the dataset configured in cfg.yaml.
config_path = 'cfg.yaml'
show_forward(config_path)

# 开始训练

In [None]:
# Start training with the settings in cfg.yaml.
config_path = "cfg.yaml"

train(config_path)

# 利用训练好的模型进行推理

In [None]:
# Run inference with the settings in cfg.yaml and display the returned
# result of the last denoising step.
config_path = "cfg.yaml"
generated = infer(config_path)
show_images(generated, "Final result")