定义Unet，基于时间的ScoreNet

In [2]:
import torch;
import torch.nn as nn;
import torch.nn.functional as F;
import numpy as np;
import functools;

In [3]:
class TimeEmbedding(nn.Module):

    "使用高斯随机特征进行 时间的embedding"

    def __init__(self,embed_dim,scale = 30.):
        super().__init__()
        self.w = nn.Parameter(torch.randn(embed_dim // 2)*scale,requires_grad=False)
    def forward(self,x):#传入的是时间t [32]
        x_proj = x[:,None] * self.w[None,:]*2*np.pi # None 是用来添加一个维度的，x变成了三维
        return torch.cat([torch.sin(x_proj),torch.cos(x_proj)],dim=-1) #[32,128]+[32,128]->[32,256]

MLP扩大维度

In [4]:
class Dense(nn.Module):
    def __init__(self,input_dim,output_dim):
        super().__init__()
        self.dense = nn.Linear(input_dim,output_dim)

    def forward(self,x):
        return self.dense(x)[...,None,None] #...就是前面所有维度的意思 ,最后返回的tensor应该是 以第一层为例，[1,32,1,1]

In [5]:
#定义ScoreNet

class ScoreNet(nn.Module):
    def __init__(self,marginal_prob_std,channels=[32,64,128,256],embed_dim = 256):
        super().__init__()
        #时间编码层
        self.embed = nn.Sequential(
            TimeEmbedding(embed_dim),
            nn.Linear(embed_dim,embed_dim)
        )
        #Unet编码层

        self.conv1 = nn.Conv2d(1,channels[0],3,stride=1,bias=False);
        self.dense1 = Dense(embed_dim,channels[0])
        self.gnorm1 = nn.GroupNorm(4,num_channels=channels[0])

        self.conv2  = nn.Conv2d(channels[0],channels[1],3,stride=2,bias=False)
        self.dense2 = Dense(embed_dim,channels[1])
        self.gnorm2 = nn.GroupNorm(32,num_channels=channels[1])

        self.conv3 = nn.Conv2d(channels[1],channels[2],3,stride=2,bias=False)
        self.dense3 = Dense(embed_dim,channels[2])
        self.gnorm3 = nn.GroupNorm(32,num_channels=channels[2])

        self.conv4 = nn.Conv2d(channels[2],channels[3],3,stride=2,bias=False)
        self.dense4 = Dense(embed_dim,channels[3])
        self.gnorm4 = nn.GroupNorm(32,num_channels=channels[3])

        #Unet解码器
        self.tconv4 = nn.ConvTranspose2d(channels[3],channels[2],3,stride=2,bias=False)
        self.dense5 = Dense(embed_dim,channels[2])
        self.tgnorm4 = nn.GroupNorm(32,num_channels=channels[2])

        self.tconv3 = nn.ConvTranspose2d(channels[2]+channels[2],channels[1],3,stride=2,bias=False,output_padding=1)
        self.dense6 = Dense(embed_dim,channels[1])
        self.tgnorm3 = nn.GroupNorm(32,num_channels=channels[1])

        self.tconv2 = nn.ConvTranspose2d(channels[1]+channels[1],channels[0],3,stride=2,bias=False,output_padding=1)
        self.dense7 = Dense(embed_dim,channels[0])
        self.tgnorm2 = nn.GroupNorm(32,num_channels=channels[0])

        #激活函数Swish
        self.act = lambda x:x*torch.sigmoid(x)

        self.tconv1 = nn.ConvTranspose2d(channels[0]+channels[0],1,3,stride=1)
        self.marginal_prob_std = marginal_prob_std

    def forward(self,x,t):
        embed = self.act(self.embed(t))
        print("经过时间embed之后，时间t的shape：",self.embed(t).shape) #torch.Size([32, 256])

        h1 = self.conv1(x)
        h1 += self.dense1(embed) #注入时间，后面也是一样的
        h1 =self.gnorm1(h1)
        h1 = self.act(h1)

        h2 = self.conv2(h1)
        h2 += self.dense2(embed)
        h2 = self.gnorm2(h2)
        h2 = self.act(h2)

        h3 = self.conv3(h2)
        h3 += self.dense3(embed)
        h3 = self.gnorm3(h3)
        h3 = self.act(h3)

        h4 = self.conv4(h3)
        h4 += self.dense4(embed)
        h4 = self.gnorm4(h4)
        h4 = self.act(h4)

#       解码器部分前向计算
        h = self.tconv4(h4)
        h+= self.dense5(embed)
        h = self.tgnorm4(h)
        h = self.act(h)
        # 按照channel那一个维度拼起来
        print("h:",h.shape)#torch.Size([32, 128, 5, 5]
        print(h3.shape)
        h = self.tconv3(torch.cat([h,h3],dim = 1))
        h += self.dense6(embed)
        h = self.tgnorm3(h)
        h = self.act(h)

        h =self.tconv2(torch.cat([h,h2],dim=1))
        h += self.dense7(embed)
        h = self.tgnorm2(h)
        h = self.act(h)

        h = self.tconv1(torch.cat([h,h1],dim=1))

        h = h/self.marginal_prob_std(t)[:,None,None,None] #SDE 里面对每个Unet的结果除以 2范数 平方的期望，希望我们预测分数可以逼近真实的分数
        return h

![来自内容的路径](Images/Images_for_ScoreBasedModel/20230322172639.png "相对路径演示")
定义损失函数，SDE,这里对SDE进行了简化，去掉了漂移系数，只保留了扩散系数。
https://zhuanlan.zhihu.com/p/399968951
![来自内容的路径](Images/Images_for_ScoreBasedModel/20230331004844.png "相对路径演示")

In [6]:
device = 'cuda' #cuda

def marginal_prob_std(t,sigma):
# 定义标准差，这里还把时间t给引入进来了，在分数模型2019中没有时间t，那个t只在郎之万采样的时候用了；这里的t带入就可以得到某一个噪声sigma下，不同t时刻的 扰动数据的 概率分布，也是近似于一个高斯分布；
    t = torch.tensor(t,device = device)
    return torch.sqrt((sigma**(2*t)-1.)/(2.*(np.log(sigma))))

def diffusion_coeff(t,sigma):
    return torch.tensor(sigma**t,device=device)
# 扩散系数
sigma = 25.0
# 构造无参函数：
marginal_prob_std_fn = functools.partial(marginal_prob_std,sigma = sigma)
diffusion_coeff_fn = functools.partial(diffusion_coeff,sigma = sigma)


![路径](Images/Images_for_ScoreBasedModel/20230327171159.png) "loss的定义"

In [7]:
def loss_fn(score_model,x,marginal_prob_std,eps = 1e-5):
    # 生成时间  t  ,一个时间t 对应 一张图片；
    random_t = torch.rand(x.shape[0],device=x.device)*(1 - eps)+eps
    print("random_t",random_t.shape,"random_t",random_t) #[32]
    # 基于参数重整化技巧添加扰动数据
    z =torch.randn_like(x)
    print("z",z.shape) #[32,1,28,28]
    #t时刻的标准差
    std = marginal_prob_std(random_t)

    #SDE这个加噪的方式，怎么也是参数重整化？仔细看上上个cell
    perturbed_x = x + z*std[:,None,None,None]

    # 传入扰动数据和时间t ，得到 分数score。
    score = score_model(perturbed_x,random_t)
    print("score.shape:",score.shape)#[32,1,28,28]
    print("std",std,std.shape)#[32]

# (score*std[:,None,None,None]+z)**2 分数乘以标准差+一个噪声z
    print(torch.sum((score*std[:,None,None,None]+z)**2,dim=(1,2,3)).shape)#[32]  可以看出来这个sum是把32个图像中每个图像的CHW维度数据全部加起来成一个值；然后再对32个图像取mean。
    print(torch.sum((score*std[:,None,None,None]+z)**2,dim=(1)).shape)
    print(torch.sum((score*std[:,None,None,None]+z)**2,dim=(2)).shape)
    print(torch.sum((score*std[:,None,None,None]+z)**2,dim=(3)).shape)
#torch.Size([32, 28, 28])   torch.Size([32, 1, 28])     torch.Size([32, 1, 28])     loss tensor(4290.4404, device='cuda:0', grad_fn=<MeanBackward0>)
    loss = torch.mean(torch.sum((score*std[:,None,None,None]+z)**2,dim=(1,2,3)))
    print("loss",loss)
    return loss;

编写EMA函数：对模型的权重进行指数平滑

In [8]:
from copy import deepcopy


#对权重 进行指数平滑
class EMA(nn.Module):
    def __init__(self,model,decay = 0.9999,device =None):
        super(EMA,self).__init__();#这是 python2 的写法，3的话不需要写EMA
        self.module = deepcopy(model)
        self.module.eval()
        self.decay = decay
        self.device = device
        if self.device is not None:
            self.module.to(device=device)
    def _update(self,model,update_fn):
        with torch.no_grad():
            #zip 把两个[] 组装成字典 {}
            for ema_v,model_v in zip(self.module.state_dict().values(),model.state_dict().values()):
                if self.device is not None:
                    model_v = model_v.to(device=device)
                ema_v.copy_(update_fn(ema_v,model_v))
    def update(self,model):
        self._update(model,update_fn=lambda e,m:self.decay*e+(1-self.decay)*m)
    def set(self,model):
        self._update(model,update_fn=lambda e,m:m)

基于Minist数据集的模型

In [9]:
import torch
import functools
from torch.optim import Adam
from torch.utils.data import DataLoader
import torchvision.transforms as transformers
from torchvision.datasets import MNIST
import tqdm

# 多机多卡一般用DDP-分布式训练
score_model = torch.nn.DataParallel(ScoreNet(marginal_prob_std=marginal_prob_std_fn))
#将cuda 0 设置成主GPU
device = torch.device("cuda:0")
score_model = score_model.to(device)

n_epechs = 1
batch_size = 32
lr = 1e-4

#根目录下没有数据集就自己下载
dataset = MNIST('.',True,transform=transformers.ToTensor(),download=True);
#开启4个进程去做
data_loader = DataLoader(dataset,batch_size,shuffle=True,num_workers=4)
#torch.Size([32, 1, 28, 28])
# for x,y in data_loader:
#     print(x.shape)

optimizer = Adam(score_model.parameters(),lr=lr)
tqdm_epochs = tqdm.tqdm(range(n_epechs))

ema = EMA(score_model)
for epoch in tqdm_epochs:
    avg_loss = 0.0;
    num_items = 0
    # x是图片，y是标签，也就是condition，条件，但是这里我们是无条件生成；
    for x,y in data_loader:
        x= x.to(device)
        loss = loss_fn(score_model,x,marginal_prob_std_fn)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        #均值平滑
        ema.update(score_model)

        avg_loss += loss.item() * x.shape[0]
        num_items += x.shape[0]
    print("Average ScoreMatching Loss:{:5f}".format(avg_loss/num_items))
    #保存模型参数
    torch.save(score_model.state_dict(),f"ckpt_{epoch}.pth")

  0%|          | 0/1 [00:00<?, ?it/s]

random_t torch.Size([32]) random_t tensor([0.6480, 0.1370, 0.3794, 0.7943, 0.2419, 0.5927, 0.0962, 0.9840, 0.5290,
        0.8001, 0.4745, 0.8542, 0.5952, 0.5344, 0.0647, 0.7283, 0.4906, 0.5260,
        0.2051, 0.6520, 0.7042, 0.9579, 0.7833, 0.4356, 0.0882, 0.7626, 0.4233,
        0.7621, 0.4400, 0.1801, 0.3459, 0.5212], device='cuda:0')
z torch.Size([32, 1, 28, 28])


  t = torch.tensor(t,device = device)


经过时间embed之后，时间t的shape： torch.Size([32, 256])
h: torch.Size([32, 128, 5, 5])
torch.Size([32, 128, 5, 5])
score.shape: torch.Size([32, 1, 28, 28])
std tensor([3.1485, 0.4688, 1.2772, 5.0659, 0.7627, 2.6262, 0.3649, 9.3513, 2.1269,
        5.1618, 1.7718, 6.1495, 2.6476, 2.1660, 0.2832, 4.0904, 1.8707, 2.1061,
        0.6529, 3.1898, 3.7820, 8.5952, 4.8887, 1.5526, 0.3446, 4.5717, 1.4880,
        4.5644, 1.5761, 0.5829, 1.1336, 2.0727], device='cuda:0') torch.Size([32])
torch.Size([32])
torch.Size([32, 28, 28])
torch.Size([32, 1, 28])
torch.Size([32, 1, 28])
loss tensor(6366.2500, device='cuda:0', grad_fn=<MeanBackward0>)
random_t torch.Size([32]) random_t tensor([0.0225, 0.1678, 0.4487, 0.5644, 0.6523, 0.1932, 0.6358, 0.2419, 0.9728,
        0.6065, 0.9646, 0.6019, 0.5155, 0.1150, 0.9685, 0.6436, 0.9147, 0.6898,
        0.1471, 0.7691, 0.6150, 0.6016, 0.6079, 0.1190, 0.0979, 0.4079, 0.1761,
        0.6075, 0.5110, 0.9552, 0.8422, 0.7709], device='cuda:0')
z torch.Size([32, 1, 28, 28])
经过

100%|██████████| 1/1 [00:28<00:00, 28.29s/it]

Average ScoreMatching Loss:305.608176





![路径](Images/Images_for_ScoreBasedModel/20230330233422.png) “采样“


In [1]:

num_steps = 500;

def euler_sampler(
        score_model,
        marginal_prob_std,
        batch_size=64,
        num_steps=num_steps,
        device='cuda',
        eps=1e-3,
):
    t = torch.ones(batch_size,device= device)
    init_x = torch.randn(batch_size,1,28,28,device=device)\
        * marginal_prob_std(t)[:,None,None,None]

    time_steps = torch.linspace(1.,eps,num_steps,device=device)
    # delta t
    step_size = time_steps[0] - time_steps[1]

    x = init_x
    with torch.no_grad():
        for time_step in tqdm.tqdm(time_steps):
            batch_time_step = torch.ones(batch_size,device=device)*time_step
            g = diffusion_coeff(batch_time_step)
            mean_x = x + (g**2)[:,None,None,None]*score_model(x,batch_time_step)*step_size
            x = mean_x + torch.sqrt(step_size) *g[:,None,None,None]*torch.randn_like(x)
    return mean_x


欧拉算法和郎之万采样直接的PC采样法：
![路径](Images/Images_for_ScoreBasedModel/20230331232719.png)

In [10]:
signal_to_noise_ratio = 0.16
num_steps = 500
def PC_sampler(
        score_model,
        marginal_prob_std,
        batch_size = 64,
        num_steps=num_steps,
        snr = signal_to_noise_ratio,
        device ="cuda",
        eps = 1e-3
):
    t = torch.ones(batch_size,device = device)
    init_x = torch.randn(batch_size,1,28,28,device=device)*marginal_prob_std(t)[:,None,None,None]
    
    time_steps = np.linspace(1.,eps,num_steps)
    step_size = time_steps[0]-time_steps[1]
    
    # 重复交替的进行欧拉数值求解：
    x = init_x
    with torch.no_grad():
        for time_step in tqdm.tqdm(time_steps):
            batch_time_step = torch.ones(batch_size,device=device)*time_step
            
            grad = score_model(x,batch_time_step)
            grad_norm = torch.norm(grad.reshape((grad.shape[0],-1),dim = -1)).mean()
            noise_norm = np.sqrt(np.prod(x.shape[1:]))
            # 步长，snr是为了保证每一步的信噪比是固定量级的。
            Langevin_step_size = 2*(snr*noise_norm/grad_norm)**2
            print(f"{Langevin_step_size=}")
            
            #做 10 步 矫正：
            for _ in range(10):
                x = x + Langevin_step_size*grad+torch.sqrt(2*Langevin_step_size)*torch.rand_like(x)
                qrad = score_model(x,batch_time_step)
                grad_norm = torch.norm(grad.reshape((grad.shape[0],-1),dim = -1)).mean()
                noise_norm = np.sqrt(np.prod(x.shape[1:]))
                Langevin_step_size = 2*(snr*noise_norm/grad_norm)**2
                print(f"{Langevin_step_size=}")
            
        # Predictor
        g = diffusion_coeff(batch_time_step)
        x_mean = x + (g**2)[:,None,None,None]*score_model(x,batch_time_step)*step_size
        x = x_mean +torch.sqrt(g**2*step_size)[:,None,None,None]*torch.randn_like(x)
    
    return x_mean
