In [None]:
import torch
from torch import nn,optim
from torch.utils.data import DataLoader
from torch.nn import functional as F
from torchvision import datasets,transforms
import matplotlib.pyplot as plt
import numpy as np
import math


device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')

transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,),(0.5,))
])

train_dataset=datasets.MNIST(root='./dataset',train=True,download=True,transform=transform)
train_loader=DataLoader(train_dataset,batch_size=128,shuffle=True)

In [2]:
T=300
beta_start=1e-4
beta_end=0.02
betas=torch.linspace(beta_start,beta_end,T)

alphas=1.-betas
alphas_cumprod=torch.cumprod(alphas,dim=0)
sqrt_alphas_cumprod=torch.sqrt(alphas_cumprod)
sqrt_1minusalphas_cumprod=torch.sqrt(1-alphas_cumprod)

print(f'sqrt_alphas_cumprod.shape:{sqrt_alphas_cumprod.shape}')

def q_sample(x_start,t,noise=None):
    if noise is None:
        noises = torch.randn_like(x_start)
    sqrt_alpha=sqrt_alphas_cumprod[t].view(-1,1,1,1)
    sqrt_1minusalpha=sqrt_1minusalphas_cumprod[t].view(-1,1,1,1)
    return x_start*sqrt_alpha+noise*sqrt_1minusalpha

sqrt_alphas_cumprod.shape:torch.Size([300])


In [3]:
def get_timestep_embedding(t, dim):
    half_dim = dim // 2
    emb=math.log(10000)/(half_dim-1)
    emb=torch.exp(-emb*torch.arange(half_dim,device=t.device))
    emb=t.float()[:,None]*emb[None,:]
    emb=torch.cat([torch.sin(emb),torch.cos(emb)],dim=-1)
    if dim % 2 == 1:  # odd embedding_dim
        emb = F.pad(emb, (0,1))
    return emb


In [None]:
class UNet(nn.Module):
    def __init__(self, in_channels=1, out_channels=1, base_channels=32, time_emb_dim=128):
        super().__init__()
        # 时间步嵌入MLP
        self.time_mlp = nn.Sequential(
            nn.Linear(time_emb_dim, base_channels*4),
            nn.ReLU(),
            nn.Linear(base_channels*4, base_channels*4)
        )
        # 编码器
        self.enc1 = nn.Sequential(
            nn.Conv2d(in_channels, base_channels, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(base_channels, base_channels, 3, padding=1),
            nn.ReLU()
        )
        self.enc2 = nn.Sequential(
            nn.MaxPool2d(2),
            nn.Conv2d(base_channels, base_channels*2, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(base_channels*2, base_channels*2, 3, padding=1),
            nn.ReLU()
        )
        self.enc3 = nn.Sequential(
            nn.MaxPool2d(2),
            nn.Conv2d(base_channels*2, base_channels*4, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(base_channels*4, base_channels*4, 3, padding=1),
            nn.ReLU()
        )
        # 解码器
        self.up1 = nn.ConvTranspose2d(base_channels*4, base_channels*2, 2, stride=2)
        self.dec1 = nn.Sequential(
            nn.Conv2d(base_channels*4, base_channels*2, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(base_channels*2, base_channels*2, 3, padding=1),
            nn.ReLU()
        )
        self.up2 = nn.ConvTranspose2d(base_channels*2, base_channels, 2, stride=2)
        self.dec2 = nn.Sequential(
            nn.Conv2d(base_channels*2, base_channels, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(base_channels, base_channels, 3, padding=1),
            nn.ReLU()
        )
        self.out_conv = nn.Conv2d(base_channels, out_channels, 1)

        self.time_emb_dim = time_emb_dim

    def forward(self, x, t):
        # t: [batch] int64
        t_emb = get_timestep_embedding(t, self.time_emb_dim)  # [batch, time_emb_dim]
        t_emb = self.time_mlp(t_emb)  # [batch, base_channels*4]
        # 编码
        e1 = self.enc1(x)         # [B, C, 28, 28]
        e2 = self.enc2(e1)        # [B, 2C, 14, 14]
        e3 = self.enc3(e2)        # [B, 4C, 7, 7]
        # 加时间步嵌入到bottleneck
        t_emb = t_emb[:, :, None, None]  # [B, 4C, 1, 1]
        e3 = e3 + t_emb
        # 解码
        d1 = self.up1(e3)         # [B, 2C, 14, 14]
        d1 = torch.cat([d1, e2], dim=1)
        d1 = self.dec1(d1)
        d2 = self.up2(d1)         # [B, C, 28, 28]
        d2 = torch.cat([d2, e1], dim=1)
        d2 = self.dec2(d2)
        out = self.out_conv(d2)
        return out

model=UNet().to(device)
optimizer=optim.Adam(model.parameters(),lr=1e-3)

epochs=10

for epoch in range(epochs):
    model.train()
    for x,_ in train_loader:
        x=x.to(device)
        batch_size=x.size(0)
        t=torch.randint(0,T,(batch_size,),device=device,).long()
        noise=torch.randn_like(x)
        x_noisy=q_sample(x,t,noise)
        pred_noise=model(x_noisy,t)
        loss=F.mse_loss(pred_noise,noise)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(f"Epoch {epoch+1} Loss: {loss.item():.4f}")

torch.save(model.state_dict())

AttributeError: module 'torch.functional' has no attribute 'mse_loss'