In [None]:
import torch, torchvision
from torch.utils import data
import numpy as np
import matplotlib.pyplot as plt
from torch import nn
from denoising_diffusion_pytorch import Unet
from typing import *

In [None]:
# MNIST training split. ToTensor() converts each image to a float tensor scaled to [0, 1].
# download=True fetches the files on a fresh machine and is a no-op when they already
# exist under `root`, so Restart & Run All works without a manual download step.
mnist_dataset = torchvision.datasets.MNIST(
    root='../dataset',
    train=True,
    download=True,
    transform=torchvision.transforms.ToTensor(),
)

In [None]:
# Shape of one sample (channels, height, width) and number of diffusion steps T.
shape = (1, 28, 28)
num_steps = 100
# Prefer CUDA, but fall back to CPU so the notebook still runs (slowly) without a GPU.
# The name `gpu` is kept because every later cell refers to it.
gpu = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Linear beta schedule: beta_t grows from 1e-4 to 0.02 over num_steps steps.
betas = torch.linspace(0.0001, 0.02, num_steps).to(gpu)
alphas = 1. - betas
# alpha_bar_t = prod_{s <= t} alpha_s (cumulative product over the steps).
alphas_bar = torch.cumprod(alphas, dim=0)
alphas_bar_sqrt = alphas_bar.sqrt()
# Despite its name, this is sqrt(1 - alpha_bar_t), not a function of beta_t alone.
betas_bar_sqrt = torch.sqrt(1. - alphas_bar)

In [None]:
def get_x_t(x_0: torch.Tensor, t: torch.Tensor, e_t: Optional[torch.Tensor]=None) -> torch.Tensor:
    """
    Sample x_t from the forward (noising) process in closed form.

        x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * e_t

    Args:
        x_0: clean batch; its first dimension is treated as the batch dimension.
        t: integer timestep indices into the schedule (expected in [0, num_steps)),
            one per sample, or a single index shared by the whole batch.
        e_t: noise with the same shape as x_0. When None it is drawn with
            torch.randn_like. Pass it explicitly when the caller needs the noise,
            e.g. as the regression target during training.

    Returns:
        The noised batch x_t, with the same shape as x_0.
    """
    if e_t is None:
        e_t = torch.randn_like(x_0)

    # Reshape the per-sample coefficients to (B, 1, ..., 1) so they broadcast over every
    # non-batch dimension of x_0, whatever its rank, not just (B, C, H, W) images.
    coef_shape = (-1,) + (1,) * (x_0.dim() - 1)
    mean_coef = alphas_bar_sqrt[t].reshape(coef_shape)
    noise_coef = betas_bar_sqrt[t].reshape(coef_shape)
    return x_0 * mean_coef + e_t * noise_coef


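`Unet` 参数：

* `dim` 为基础隐含层的通道数（第一层特征图的通道宽度）
* `dim_mults` 为各分辨率层级相对于 `dim` 的通道倍数：第 i 个层级的通道数为 `dim * dim_mults[i]`，层级数即 `len(dim_mults)`
* `channels` 为输入图像的通道数（MNIST 为灰度图，因此取 1）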

In [None]:
# Noise-prediction network for single-channel 28x28 images.
unet = Unet(
    channels=1,           # input image channels (grayscale)
    dim=16,               # base hidden channel width
    dim_mults=(1, 2, 4),  # per-level multipliers applied to `dim`
)

开始训练

In [None]:
# --- training hyper-parameters ---
batch_size = 10
epochs = 1000
lr = 1e-3

# Shuffled mini-batches of (image, label) pairs; the training loop ignores the labels.
loader = data.DataLoader(mnist_dataset, batch_size=batch_size, shuffle=True)
# Objective: mean squared error between predicted and injected noise.
mse = nn.MSELoss()
optimizer = torch.optim.Adam(unet.parameters(), lr=lr)

In [None]:
unet.to(gpu)
unet.train()

for epoch in range(epochs):
    sum_loss = 0.
    cnt = 0

    for x, _ in loader:
        x = x.to(gpu)
        # One random timestep per sample. Size it from the actual batch, not `batch_size`:
        # the last batch is shorter whenever len(dataset) % batch_size != 0, and a length
        # mismatch would break the coefficient broadcasting inside get_x_t.
        t = torch.randint(0, num_steps, size=(x.shape[0],), device=gpu)
        e_t = torch.randn_like(x)
        x_t = get_x_t(x, t, e_t)
        # The network is trained to predict the noise that was mixed into x_t.
        e_hat: torch.Tensor = unet(x_t, t)
        loss: torch.Tensor = mse(e_hat, e_t)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        sum_loss += loss.item()
        cnt += 1

    print(f"Epoch {epoch + 1}, loss {sum_loss / cnt}")
    # Checkpoint every epoch. The sampling cell loads ./unet.pkl, so it needs this file,
    # and this long loop may be interrupted before it finishes.
    torch.save(unet, "./unet.pkl")


In [None]:
# Load the trained network. The file holds a whole pickled nn.Module (not a state_dict),
# so weights_only=False is required on PyTorch >= 2.6, where torch.load defaults to
# weights_only=True. SECURITY: unpickling can execute arbitrary code, so only load
# checkpoints you produced yourself. map_location places the weights on the configured device.
unet = torch.load("./unet.pkl", map_location=gpu, weights_only=False)

@torch.no_grad()
def generate(num: int) -> torch.Tensor:
    """
    Draw `num` samples with DDPM ancestral sampling (Ho et al. 2020, Algorithm 2).

    Starting from Gaussian noise, apply for t = num_steps-1, ..., 0:

        x_{t-1} = (x_t - beta_t / sqrt(1 - alpha_bar_t) * e_hat) / sqrt(alpha_t) + sigma_t * z

    where sigma_t = sqrt(beta_t) and z ~ N(0, I). At the final step (t = 0), z = 0.

    Args:
        num: number of samples to generate.

    Returns:
        Tensor of shape (num, *shape) on `gpu`. Values are not clamped to the data range.
    """
    unet.eval()
    x_t = torch.randn((num, *shape), device=gpu)
    for step in reversed(range(num_steps)):
        # No noise is added on the last step, so the result is the predicted mean.
        z_t = torch.randn_like(x_t) if step > 0 else torch.zeros_like(x_t)
        t = torch.full((num,), step, dtype=torch.long, device=gpu)
        e_hat = unet(x_t, t)

        # Per-step scalars broadcast over the whole batch.
        alpha_t = alphas[step]
        beta_t = betas[step]
        mean = (x_t - beta_t / betas_bar_sqrt[step] * e_hat) / alpha_t.sqrt()
        # The noise scale is sigma_t = sqrt(beta_t). Scaling z by beta_t itself would give
        # variance beta_t**2 and inject far too little noise.
        x_t = mean + beta_t.sqrt() * z_t
    return x_t

In [None]:
# Generate a batch of samples and show them on a grid with n_cols images per row.
num_samples = 100
n_cols = 10
x = generate(num_samples)

# Copy to the host once instead of once per subplot.
images = x.cpu().numpy()
n_rows = -(-num_samples // n_cols)  # ceiling division
fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols, n_rows), squeeze=False)
for k, ax in enumerate(axes.flat):
    ax.axis('off')
    if k < num_samples:
        ax.imshow(images[k].squeeze(), cmap='gray')
plt.show()


In [None]:
from torchvision.utils import save_image


# Write the second generated sample to disk. Sampled values may fall outside [0, 1]
# (they are not clamped), so normalize=True asks torchvision to rescale them first.
chosen_sample = x[1]
save_image(chosen_sample, "1.png", normalize=True)