# 图片扩散模型

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

class SimpleUNet(nn.Module):
    """Small three-level U-Net used as the noise-prediction (denoising) network.

    Encoder: three 3x3 convolutions (the last two with stride 2) producing
    64, 128 and 256 feature maps. A 256-wide time embedding is passed through
    a two-layer MLP and added at every spatial position of the deepest
    feature map. Decoder: two stride-2 transposed convolutions and a final
    3x3 convolution, each fed the previous output concatenated with the
    matching encoder feature map (skip connection).

    Args:
        in_channels: number of image channels; defaults to 1 because MNIST is
            grayscale. The output has the same number of channels.

    NOTE(review): the skip concatenations need encoder and decoder maps of
    equal spatial size, which holds when H and W are divisible by 4
    (e.g. 32x32); other sizes may fail inside ``torch.cat``.
    """

    def __init__(self, in_channels=1):
        super().__init__()
        # Encoder (registration order kept so seeded initialisation matches).
        self.enc1 = nn.Conv2d(in_channels, 64, kernel_size=3, padding=1)
        self.enc2 = nn.Conv2d(64, 128, kernel_size=3, padding=1, stride=2)
        self.enc3 = nn.Conv2d(128, 256, kernel_size=3, padding=1, stride=2)

        # MLP applied to the 256-d time embedding before it is injected.
        self.time_mlp = nn.Sequential(
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
        )

        # Decoder; input widths are doubled by the skip concatenations.
        self.dec3 = nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1)
        self.dec2 = nn.ConvTranspose2d(256, 64, kernel_size=4, stride=2, padding=1)
        self.dec1 = nn.Conv2d(128, in_channels, kernel_size=3, padding=1)

    def forward(self, x, time_emb):
        """Predict noise for images ``x`` given their time embeddings.

        Args:
            x: image batch of shape (B, in_channels, H, W).
            time_emb: time embeddings reshapeable to (B, 256).

        Returns:
            Tensor shaped like ``x``.
        """
        # Encoder path; every stage's activations are kept for the skips.
        skip1 = F.relu(self.enc1(x))
        skip2 = F.relu(self.enc2(skip1))
        bottleneck = F.relu(self.enc3(skip2))

        # Project the embedding and broadcast it over the spatial grid
        # (same sum as explicitly repeating it to the bottleneck's H x W).
        t_feat = self.time_mlp(time_emb).view(-1, 256, 1, 1)
        bottleneck = bottleneck + t_feat

        # Decoder path, concatenating skips along the channel axis.
        up2 = F.relu(self.dec3(bottleneck))
        up1 = F.relu(self.dec2(torch.cat((up2, skip2), dim=1)))
        return self.dec1(torch.cat((up1, skip1), dim=1))

class DiffusionModel:
    """DDPM-style diffusion wrapper: noise schedule, training loop and sampling.

    The forward (noising) process uses a linear beta schedule from 1e-4 to
    0.02. A SimpleUNet is trained with an MSE loss to predict the Gaussian
    noise added at a random timestep. ``sample`` runs the reverse process
    starting from pure noise.

    Args:
        timesteps: number of diffusion steps T.
    """

    def __init__(self, timesteps=1000):
        self.timesteps = timesteps
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Linear noise schedule, computed on CPU and then moved to the model's
        # device. Timestep index tensors live on self.device, and indexing a
        # CPU tensor with CUDA indices raises a RuntimeError.
        beta = torch.linspace(1e-4, 0.02, timesteps)
        alpha = 1. - beta
        alpha_bar = torch.cumprod(alpha, dim=0)
        self.beta = beta.to(self.device)            # beta_t
        self.alpha = alpha.to(self.device)          # alpha_t = 1 - beta_t
        self.alpha_bar = alpha_bar.to(self.device)  # prod_{s<=t} alpha_s

        # Denoising network and its optimizer.
        self.denoise_model = SimpleUNet().to(self.device)
        self.optimizer = optim.Adam(self.denoise_model.parameters(), lr=1e-4)

    def time_embedding(self, timesteps, dim=256):
        """Return sinusoidal embeddings for a 1-D tensor of timesteps.

        Args:
            timesteps: 1-D integer tensor of diffusion steps on self.device.
            dim: embedding width; the default network's time MLP expects 256.

        Returns:
            Float tensor of shape (len(timesteps), 2 * (dim // 2)): cosine
            features followed by sine features at geometrically spaced
            frequencies 10000 ** (-k / half), k = 0 .. half - 1.
        """
        half = dim // 2
        freqs = torch.exp(
            -torch.log(torch.tensor(10000.0)) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(self.device)
        args = timesteps[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        return embedding

    def add_noise(self, x_0, t):
        """Sample x_t ~ q(x_t | x_0) in closed form.

        x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps, eps ~ N(0, I)

        Args:
            x_0: clean images, shape (B, C, H, W).
            t: 1-D long tensor with one timestep per image.

        Returns:
            Tuple (x_t, eps) of the noisy images and the noise that was added.
        """
        eps = torch.randn_like(x_0)
        # Index on the schedule's own device regardless of where t lives.
        alpha_bar_t = self.alpha_bar[t.to(self.alpha_bar.device)].to(self.device)
        # (B,) -> (B, 1, 1, 1) so the per-sample coefficients broadcast.
        signal_scale = torch.sqrt(alpha_bar_t)[:, None, None, None]
        noise_scale = torch.sqrt(1 - alpha_bar_t)[:, None, None, None]
        noisy = signal_scale * x_0 + noise_scale * eps
        return noisy, eps

    def train_step(self, x_0):
        """Run one optimisation step on a batch of clean images.

        Returns:
            The batch's MSE loss between predicted and true noise, as a float.
        """
        batch_size = x_0.shape[0]

        # One uniformly random timestep per image.
        t = torch.randint(0, self.timesteps, (batch_size,)).to(self.device)

        noisy_images, noise = self.add_noise(x_0, t)

        # Predict the added noise from the noisy image and its timestep.
        time_emb = self.time_embedding(t)
        predicted_noise = self.denoise_model(noisy_images, time_emb)

        loss = F.mse_loss(predicted_noise, noise)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.item()

    def train(self, dataloader, epochs, sample_every=5):
        """Train the denoiser for ``epochs`` passes over ``dataloader``.

        Args:
            dataloader: iterable of ``(images, labels)`` batches; labels are ignored.
            epochs: number of epochs.
            sample_every: save a sample image via generate_samples after every
                ``sample_every`` epochs; 0 or a negative value disables it.
        """
        for epoch in range(epochs):
            # Ensure training mode at the start of every epoch.
            self.denoise_model.train()
            total_loss = 0.0
            num_batches = 0
            with tqdm(dataloader, desc=f'Epoch {epoch+1}/{epochs}') as pbar:
                for images, _ in pbar:
                    images = images.to(self.device)
                    loss = self.train_step(images)
                    total_loss += loss
                    num_batches += 1
                    pbar.set_postfix({'loss': loss})

            # max(..., 1) avoids ZeroDivisionError on an empty dataloader.
            avg_loss = total_loss / max(num_batches, 1)
            print(f'Epoch {epoch+1}, Average Loss: {avg_loss:.4f}')

            if sample_every > 0 and (epoch + 1) % sample_every == 0:
                self.generate_samples(epoch + 1)

    @torch.no_grad()
    def sample(self, batch_size=1, img_size=32):
        """Generate images by running the reverse diffusion process.

        Starts from x_T ~ N(0, I). For t = T-1, ..., 0 it applies
            x_{t-1} = (x_t - (1 - alpha_t) / sqrt(1 - alpha_bar_t) * eps_theta)
                      / sqrt(alpha_t) + sqrt(beta_t) * z,
        with z ~ N(0, I) for t > 0 and z = 0 at the last step.

        The network runs in eval mode, and its previous train/eval mode is
        restored afterwards.

        Returns:
            Tensor of shape (batch_size, 1, img_size, img_size) on self.device.
        """
        was_training = self.denoise_model.training
        self.denoise_model.eval()
        try:
            x = torch.randn(batch_size, 1, img_size, img_size).to(self.device)

            for t in reversed(range(self.timesteps)):
                t_batch = torch.full((batch_size,), t, dtype=torch.long).to(self.device)
                time_emb = self.time_embedding(t_batch)
                predicted_noise = self.denoise_model(x, time_emb)

                alpha_t = self.alpha[t].to(self.device)
                alpha_bar_t = self.alpha_bar[t].to(self.device)
                beta_t = self.beta[t].to(self.device)

                # No fresh noise is injected at the final step.
                noise = torch.randn_like(x) if t > 0 else 0

                x = (1 / torch.sqrt(alpha_t)) * (x - ((1 - alpha_t) / torch.sqrt(1 - alpha_bar_t)) * predicted_noise) + \
                    torch.sqrt(beta_t) * noise

            return x
        finally:
            self.denoise_model.train(was_training)

    def generate_samples(self, epoch, num_samples=5):
        """Sample ``num_samples`` images and save them in one row to
        ``samples_epoch_{epoch}.png``."""
        samples = self.sample(batch_size=num_samples).cpu()

        # squeeze=False keeps `axes` 2-D even when num_samples == 1, where
        # plt.subplots would otherwise return a single, non-indexable Axes.
        fig, axes = plt.subplots(1, num_samples, figsize=(15, 3), squeeze=False)
        for ax, img in zip(axes[0], samples):
            ax.imshow(img.squeeze(), cmap='gray')
            ax.axis('off')

        fig.suptitle(f'Generated Samples at Epoch {epoch}')
        fig.savefig(f'samples_epoch_{epoch}.png')
        plt.close(fig)

def load_mnist_data(batch_size=64):
    """Build a shuffled DataLoader over the MNIST training split.

    Each image is resized to 32 pixels, converted to a tensor in [0, 1], and
    normalised with mean 0.5 / std 0.5, which maps it to [-1, 1]. The dataset
    is stored under ``./data`` and downloaded if missing.

    Args:
        batch_size: number of images per batch.

    Returns:
        A ``DataLoader`` yielding ``(images, labels)`` batches.
    """
    preprocess = transforms.Compose(
        [
            transforms.Resize(32),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ]
    )
    train_set = datasets.MNIST(
        root='./data',
        train=True,
        download=True,
        transform=preprocess,
    )
    return DataLoader(train_set, batch_size=batch_size, shuffle=True)

def main():
    """Train a DiffusionModel on MNIST, save the denoiser weights to
    ``diffusion_model.pth``, and save 10 generated samples in one row to
    ``final_samples.png``."""
    # Seed torch's RNG for reproducibility.
    torch.manual_seed(42)

    # Build the model and the data pipeline.
    diffusion = DiffusionModel(timesteps=1000)
    dataloader = load_mnist_data(batch_size=128)

    print("开始训练...")
    diffusion.train(dataloader, epochs=2)

    # Persist only the denoising network's parameters.
    torch.save(diffusion.denoise_model.state_dict(), 'diffusion_model.pth')

    print("生成最终样本...")
    final_images = diffusion.sample(batch_size=10).cpu()

    # Lay the samples out in a single row and write them to disk.
    plt.figure(figsize=(20, 4))
    for position, image in enumerate(final_images, start=1):
        plt.subplot(1, 10, position)
        plt.imshow(image.squeeze(), cmap='gray')
        plt.axis('off')
    plt.savefig('final_samples.png')
    plt.close()


if __name__ == "__main__":
    main()

开始训练...


Epoch 1/2: 100%|██████████| 469/469 [07:18<00:00,  1.07it/s, loss=0.0609]


Epoch 1, Average Loss: 0.1818


Epoch 2/2: 100%|██████████| 469/469 [08:15<00:00,  1.06s/it, loss=0.0519]


Epoch 2, Average Loss: 0.0572
生成最终样本...
