In [1]:
# 使用MNIST数据集
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import torchvision
import torchvision.transforms as transforms

import matplotlib.pyplot as plt
import numpy as np
from torch import nn, einsum
# 数据集
train_set=torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transforms.Compose([transforms.ToTensor()])
)
test_set=torchvision.datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transforms.Compose([transforms.ToTensor()])
)
batch_size = 512
train_loader=torch.utils.data.DataLoader(train_set,batch_size=batch_size,shuffle=True,num_workers=8,drop_last=True)
test_loader=torch.utils.data.DataLoader(test_set,batch_size=batch_size,shuffle=False,drop_last=True)



  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# 画图
def plot_images(images,labels):
    n_images=len(images)
    rows=int(np.sqrt(n_images))
    cols=int(np.sqrt(n_images))
    fig=plt.figure()
    for i in range(rows*cols):
        ax=fig.add_subplot(rows,cols,i+1)
        ax.imshow(images[i].view(28,28).cpu().numpy(),cmap='bone')
        ax.set_title(labels[i].item())
        ax.axis('off')
    plt.show()
images,labels=next(iter(train_loader))
# plot_images(images,labels)
# 如果val非None则返回val，否则(如果d为函数则返回d(),否则返回d)
from inspect import isfunction # inspect模块https://www.cnblogs.com/yaohong/p/8874154.html主要提供了四种用处：1.对是否是模块、框架、函数进行类型检查 2.获取源码 3.获取类或者函数的参数信息 4.解析堆栈
from functools import partial 
def exists(x):
    return x is not None

def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


# 加噪测试

In [3]:
t=1000
import numpy as np
# beta from 0.0001 to 0.02 in 1000 steps 
# exp from log(0.0001) to log(0.02) in 1000 steps

beta=torch.tensor(np.linspace(0.0001,0.02,t))
alpha=1-beta
# alpha 累乘
alpha_multiply=torch.cumprod(alpha,dim=-1)

# 创建和图片一样shape高斯噪声
noise=torch.randn_like(images)*torch.sqrt(1-alpha_multiply[-1])+torch.sqrt(alpha_multiply[-1])*images

# 网络定义

In [4]:
from einops import rearrange 

In [5]:

# plot_images(noise,labels)
class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x
# 上采样
def Upsample(dim):
    return nn.ConvTranspose2d(dim, dim, 4, 2, 1)

# 下采样
def Downsample(dim):
    return nn.Conv2d(dim, dim, 4, 2, 1)
import math
class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time.unsqueeze(1) * embeddings.unsqueeze(0)
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings


# PE测试

In [6]:
pos_emd=SinusoidalPositionEmbeddings(28)


In [7]:
t=torch.full((batch_size,),1).long()

In [8]:
pos_emd(t).shape

torch.Size([512, 28])

In [9]:
class Block(nn.Module):
    def __init__(self, dim, dim_out, groups = 8):
        super().__init__()
        self.proj = nn.Conv2d(dim, dim_out, 3, padding = 1)
        self.norm = nn.GroupNorm(groups, dim_out) #GN归一化 https://zhuanlan.zhihu.com/p/177853578
        self.act = nn.SiLU()

    def forward(self, x, scale_shift = None):
        x = self.proj(x)
        x = self.norm(x)

        if scale_shift is not None:
            scale, shift = scale_shift
            x = x * (scale + 1) + shift

        x = self.act(x)
        return x
class ConvNextBlock(nn.Module):
    """https://arxiv.org/abs/2201.03545"""

    def __init__(self, dim, dim_out, *, time_emb_dim=None, mult=2, norm=True):
        super().__init__()
        # 如果time_emb_dim存在则有mlp层
        self.mlp = (
            nn.Sequential(nn.GELU(), nn.Linear(time_emb_dim, dim))
            if exists(time_emb_dim)
            else None
        )

        self.ds_conv = nn.Conv2d(dim, dim, 7, padding=3, groups=dim)

        self.net = nn.Sequential(
            nn.GroupNorm(1, dim) if norm else nn.Identity(),
            nn.Conv2d(dim, dim_out * mult, 3, padding=1),
            nn.GELU(), # Gaussian Error Linear Unit
            nn.GroupNorm(1, dim_out * mult),
            nn.Conv2d(dim_out * mult, dim_out, 3, padding=1),
        )

        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb=None):
        h = self.ds_conv(x)

        if self.mlp is not None and time_emb is not None:
            assert time_emb is not None, "time embedding must be passed in"
            condition = self.mlp(time_emb)
            h = h + rearrange(condition, "b c -> b c 1 1")

        h = self.net(h)
        return h + self.res_conv(x)

In [10]:
# 先norm后fn
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.GroupNorm(1, dim)

    def forward(self, x):
        x = self.norm(x)
        return self.fn(x)

In [11]:
class Attention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=1)# qkv为一个元组，其中每一个元素的大小为torch.Size([b, hidden_dim, h, w])
        q, k, v = map(
            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
        ) # qkv中每个元素从torch.Size([b, hidden_dim, h, w])变为torch.Size([b, heads, dim_head, h*w])
        q = q * self.scale # q扩大dim_head**-0.5倍

        sim = einsum("b h d i, b h d j -> b h i j", q, k) # sim有torch.Size([b, heads, h*w, h*w])
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1) # attn有torch.Size([b, heads, h*w, h*w])

        out = einsum("b h i j, b h d j -> b h i d", attn, v) # [b, heads, h*w, h*w]和[b, heads, dim_head, h*w] 得 out为[b, heads, h*w, dim_head]
        out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w) # 得out为[b, hidden_dim, h, w]
        return self.to_out(out) # 得 [b, dim, h, w]

# 和class Attention几乎一致
class LinearAttention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)

        self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), 
                                    nn.GroupNorm(1, dim))

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=1) 
        q, k, v = map(
            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
        )

        q = q.softmax(dim=-2)
        k = k.softmax(dim=-1)

        q = q * self.scale
        context = torch.einsum("b h d n, b h e n -> b h d e", k, v)

        out = torch.einsum("b h d e, b h d n -> b h e n", context, q)
        out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w)
        return self.to_out(out)

In [12]:
class Unet(nn.Module):
    def __init__(
        self,
        dim, # 下例中，dim=image_size=28
        init_dim=None,# 默认为None，最终取dim // 3 * 2
        out_dim=None, # 默认为None，最终取channels
        dim_mults=(1,2,4,8),
        channels=3, # 通道数默认为3
        with_time_emb=True, # 是否使用embeddings
        resnet_block_groups=8, # 如果使用ResnetBlock，groups=resnet_block_groups
        use_convnext=True, # 是True使用ConvNextBlock，是Flase使用ResnetBlock
        convnext_mult=2, # 如果使用ConvNextBlock，mult=convnext_mult
    ):
        super().__init__()
        self.channels = channels
        
        init_dim = default(init_dim, dim // 3 * 2)
        self.init_conv = nn.Conv2d(channels, init_dim, 7, padding=3)
        
        dims = [init_dim, *map(lambda m: dim * m, dim_mults)] # 从头到尾dim组成的列表
        in_out = list(zip(dims[:-1], dims[1:])) # dim对组成的列表
        # 使用ConvNextBlock或ResnetBlock
        if use_convnext:
            block_klass = partial(ConvNextBlock, mult=convnext_mult)
        # else:
        #     block_klass = partial(ResnetBlock, groups=resnet_block_groups)
            
        # time embeddings
        if with_time_emb:
            time_dim = dim * 4
            self.time_mlp = nn.Sequential(
                SinusoidalPositionEmbeddings(dim),
                nn.Linear(dim, time_dim),
                nn.GELU(),
                nn.Linear(time_dim, time_dim),
            )
        else:
            time_dim = None
            self.time_mlp = None
            
        # layers
        self.downs = nn.ModuleList([]) # 初始化下采样网络列表
        self.ups = nn.ModuleList([]) # 初始化上采样网络列表
        num_resolutions = len(in_out) # dim对组成的列表的长度
        
        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1) # 是否到了最后一对
            
            self.downs.append(
                nn.ModuleList(
                [
                    block_klass(dim_in, dim_out, time_emb_dim=time_dim),
                    block_klass(dim_out, dim_out, time_emb_dim=time_dim),
                    Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                    Downsample(dim_out) if not is_last else nn.Identity(),
                ]
                )
            )
        
        mid_dim = dims[-1]
        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
        self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
        self.mid_block2 = block_klass(mid_dim, mid_dim,time_emb_dim=time_dim)
        
        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (num_resolutions - 1)

            self.ups.append(
                nn.ModuleList(
                    [
                        block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim),
                        block_klass(dim_in, dim_in, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_in, LinearAttention(dim_in))),
                        Upsample(dim_in) if not is_last else nn.Identity(),
                    ]
                )
            )
            
        out_dim = default(out_dim, channels)
        self.final_conv = nn.Sequential(
            block_klass(dim, dim), nn.Conv2d(dim, out_dim, 1)
        )
        
    def forward(self, x, time):
        x = self.init_conv(x)

        t = self.time_mlp(time) if exists(self.time_mlp) else None

        h = []

        # downsample
        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)
            h.append(x)
            x = downsample(x)

        # bottleneck
        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)

        # upsample
        for block1, block2, attn, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim=1)
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)
            x = upsample(x)

        return self.final_conv(x)

In [13]:
image_size = 28
channels = 1


In [14]:
model = Unet(
    dim=image_size,
    channels = channels,
    dim_mults=(1,2,4)
)

# Diffusion

In [15]:

class SimpleGaussianDiffusion(nn.Module):
    def __init__(self,T,denoise_net,device):
        super(SimpleGaussianDiffusion,self).__init__()
        self.beta=torch.tensor(np.linspace(0.0001,0.02,T)).repeat(batch_size,1).to(device)
        # print(self.beta.shape)
        self.device=device
        self.T=T
        self.denoise_net=denoise_net
        self.alpha=1-self.beta
        # positional embedding about t
        # self.t_emb=nn.Embedding(T,128)

        # alpha 累乘
        self.alpha_multiply=torch.cumprod(self.alpha,dim=-1)
        print(self.alpha_multiply.shape)
        self.alpha_multiply_prev=F.pad(self.alpha_multiply[...,:-1],(1,0),value=1.)
        print(self.alpha_multiply_prev.shape)
        self.posterior_variance = self.beta * (1. - self.alpha_multiply_prev) / (1. - self.alpha_multiply)
        
        print(self.posterior_variance.shape)
    def forward_add_noise(self,images,t,eps=None):
        # 获取alpha
        if eps is  None:
            eps=torch.randn_like(images)
        alpha_multiply=torch.gather(self.alpha_multiply,dim=1,index=t.unsqueeze(1)).float()
        # print(alpha_multiply.shape)
        # 创建和图片一样shape高斯噪声
        noise=eps*torch.sqrt(1-alpha_multiply).unsqueeze(1).unsqueeze(1) +torch.sqrt(alpha_multiply).unsqueeze(1).unsqueeze(1)*images
        return noise,eps
    
    def forward(self,x,t,eps=None):
        # t is a scalar
        # x shape is [batch_size,H,W,C]
        batch_size,C,H,W=x.shape
        # print(x.shape)
        # x=x.permute(0,3,1,2)
        # embedding t to [batch_size,1,1,128]
        # pe_t=pe_t.unsqueeze(0).unsqueeze(2).unsqueeze(2).expand(batch_size,128,H,W)
        noise,noise_std=self.forward_add_noise(x,t,eps)
        # print(noise.shape)
        # print(pe_t.shape)
        # noise=torch.cat([noise,pe_t],dim=1)
        # predited noise
        pred_noise=self.denoise_net(noise,t)

        return pred_noise
    @torch.no_grad()
    def p_sample(self,x,t):
        # print(t.shape)
        B=x.shape[0]
        t_=torch.full((B,),t).long().cuda()
        z=torch.randn_like(x)
        beta=torch.gather(self.beta,dim=1,index=t_.unsqueeze(1)).unsqueeze(1).unsqueeze(1).float()
        alpha=torch.gather(self.alpha,dim=1,index=t_.unsqueeze(1)).unsqueeze(1).unsqueeze(1).float()
        alpha_multiply=torch.gather(self.alpha_multiply,dim=1,index=t_.unsqueeze(1)).unsqueeze(1).unsqueeze(1).float()
        # print(alpha_multiply.shape,alpha.shape,beta.shape)
        # print(x.shape)
        # print(self.denoise_net(x,t).shape)
        x_t_1=1/(torch.sqrt(alpha))*(x-(beta/torch.sqrt(1-alpha_multiply))*self.denoise_net(x,t_))
        # t==0的地方，beta=0
        posterior_variance_t=torch.gather(self.posterior_variance,dim=1,index=t_.unsqueeze(1)).unsqueeze(1).unsqueeze(1).float()
        if t==0:
        
            return x_t_1
        else:
            return x_t_1+ torch.sqrt(beta) *z
    @torch.no_grad()
    def p_sample_loop(self,x):
        for i in reversed(range(0,self.T)):
            x=self.p_sample(x,i)
        return x
# test


In [16]:
device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')

diffusion=SimpleGaussianDiffusion(T=1000,denoise_net=model.to(device)
,device=device).to(device)


torch.Size([512, 1000])
torch.Size([512, 1000])
torch.Size([512, 1000])


# 训练

In [17]:
loss=nn.MSELoss()
optimizer=optim.Adam(model.parameters(),lr=1e-3)
# 训练
eps=None
for epoch in range(10):
    for images,labels in train_loader:
        images=images.to(device)
        labels=labels.to(device)
        optimizer.zero_grad()
        t = torch.randint(0, 1000, (batch_size,), device=device).long()
        # print(t[0])
        if eps is None:
            eps=torch.randn_like(images)
        pred_noise=diffusion(images,t.to(device),eps)
        l=loss(pred_noise,eps)
        l.backward()
        optimizer.step()
    print(f'epoch {epoch},loss={l.item()}')

# 采样

In [None]:
z=torch.randn(16,1,28,28).to(device)
t=torch.tensor([0]).to(device)


In [48]:
re_sample=diffusion.p_sample_loop(z)


In [None]:
plot_images(re_sample,torch.zeros(16))

# 要点：
* 去噪要逆向
* 训练时，噪声需要固定（？
