In [1]:
from torch import nn
import torch
import numpy as np
from Unet import UNet
from abc import abstractmethod

In [2]:
class EDM_base(nn.Module):
    """Shared sampling machinery for EDM-style diffusion models.

    Subclasses provide the design choices as hook methods: the time
    discretisation (``time_step``), the noise schedule sigma(t) and its
    derivative (``Schedule`` / ``Schedule_prm``), the signal scaling s(t) and
    its derivative (``Scaling`` / ``Scaling_prm``), and the preconditioning
    coefficients (``Skip_scaling``, ``Output_scaling``, ``Input_scaling``,
    ``Noise_cond``).  Subclasses must also set ``self.model`` to a callable
    taking ``(scaled_input, noise_condition)``.

    Args:
        param: configuration mapping stored as ``self.config``.
        N: number of sampling steps.
    """

    def __init__(self,param,N=1000):
        super().__init__()
        self.config =param

        self.N = N
        self.t = torch.arange(self.N+1)

    # ------------------------------------------------------------------
    # Hooks.  They raise instead of silently returning None so that a
    # subclass which forgot one fails with a clear message.
    # ------------------------------------------------------------------
    @abstractmethod
    def time_step(self,i):
        """Return the time t_i for step index ``i``."""
        raise NotImplementedError
    @abstractmethod
    def Schedule(self,t):
        """Return the noise level sigma(t)."""
        raise NotImplementedError
    @abstractmethod
    def Schedule_prm(self,t):
        """Return the time derivative sigma'(t)."""
        raise NotImplementedError
    @abstractmethod
    def Scaling(self,t):
        """Return the signal scale s(t)."""
        raise NotImplementedError
    @abstractmethod
    def Scaling_prm(self,t):
        """Return the time derivative s'(t)."""
        raise NotImplementedError
    @abstractmethod
    def Skip_scaling(self,eta):
        """Return the coefficient multiplying the noisy input in ``D_x_eta``."""
        raise NotImplementedError
    @abstractmethod
    def Output_scaling(self,eta):
        """Return the coefficient multiplying the network output in ``D_x_eta``."""
        raise NotImplementedError
    @abstractmethod
    def Input_scaling(self,eta):
        """Return the coefficient applied to the network input in ``D_x_eta``."""
        raise NotImplementedError
    @abstractmethod
    def Noise_cond(self,eta):
        """Return the noise-level conditioning value passed to the network."""
        raise NotImplementedError
    @abstractmethod
    def Noise_distribution(self,eta):
        """Return the distribution used to draw noise levels for training."""
        raise NotImplementedError
    @abstractmethod
    def Loss_weighting(self,eta):
        """Return the per-noise-level loss weight."""
        raise NotImplementedError

    def D_x_eta(self,x,eta):
        """Preconditioned denoiser.

        Computes ``Skip_scaling(eta) * x + Output_scaling(eta) * model(
        Input_scaling(eta) * x, Noise_cond(eta))``.

        Args:
            x: noisy input.
            eta: noise level sigma.
        """
        model_output = self.model(self.Input_scaling(eta) * x,self.Noise_cond(eta))
        return self.Skip_scaling(eta) * x + self.Output_scaling(eta)* model_output

    def _ode_slope(self,x,t):
        """Evaluate the probability-flow ODE right-hand side dx/dt at (x, t)."""
        scale = self.Scaling(t)
        sigma = self.Schedule(t)
        scale_prm = self.Scaling_prm(t)
        sigma_prm = self.Schedule_prm(t)
        slope = (scale_prm/scale + sigma_prm/sigma) * x
        # The denoiser is conditioned on sigma(t) itself, not on its derivative.
        return slope - sigma_prm * scale/sigma * self.D_x_eta(x/scale, sigma)

    def Deterministic_sampling(self,bs,size,channels=3):
        """Deterministic sampler using Heun's 2nd-order method.

        Args:
            bs: batch size.
            size: spatial height/width of the generated samples.
            channels: number of image channels (default 3, as before).

        Returns:
            Tensor of shape ``(bs, channels, size, size)``.
        """
        # Initial latents must be Gaussian (randn), not uniform (rand).
        z = torch.randn(bs,channels,size,size)
        t_start = self.time_step(0)
        x = self.Schedule(t_start) * self.Scaling(t_start) * z
        for i in range(self.N):
            t_cur = self.time_step(i)
            t_next = self.time_step(i+1)
            d_cur = self._ode_slope(x, t_cur)
            # Euler prediction; kept separate because the Heun correction
            # below has to restart from x (x_i), not from the prediction.
            x_euler = x + (t_next - t_cur) * d_cur
            if self.Schedule(t_next)!=0:
                d_next = self._ode_slope(x_euler, t_next)
                x = x + (t_next - t_cur) * (0.5*d_cur + 0.5*d_next)
            else:
                # sigma(t_next) == 0: the slope is undefined there, keep Euler.
                x = x_euler
        return x

    def stochastic_sampler(self,bs,size,
            S_churn  = 30,
            S_tmin = 0.01,
            S_tmax = 1,
            S_noise = 1.007,
            channels = 3):
        """Stochastic sampler with noise "churn" and a Heun correction.

        The update rule is written for sigma(t) = t and s(t) = 1 (the time
        value is used directly as the noise level).

        Args:
            bs: batch size.
            size: spatial height/width of the generated samples.
            S_churn: overall amount of stochasticity; 0 disables churn.
            S_tmin, S_tmax: noise-level range in which churn is applied.
            S_noise: standard deviation of the injected noise.
            channels: number of image channels (default 3, as before).

        Returns:
            Tensor of shape ``(bs, channels, size, size)``.
        """
        z = torch.randn(bs,channels,size,size)
        x = z * self.time_step(0)
        gamma_cap = 2 ** 0.5 - 1
        for i in range(self.N):
            t_cur = self.time_step(i)
            t_next = self.time_step(i+1)
            gamma = min(S_churn/self.N, gamma_cap) if S_tmin<=t_cur<=S_tmax else 0
            # eps ~ N(0, S_noise^2 I): scale standard normal noise by S_noise.
            eps = torch.randn_like(z) * S_noise
            t_hat = t_cur + gamma*t_cur
            # as_tensor lets this work whether time_step returns a Python
            # float or a tensor; clamp guards against tiny negative round-off.
            extra_std = torch.as_tensor(t_hat**2 - t_cur**2).clamp(min=0).sqrt()
            x_hat = x + extra_std*eps
            d_cur = (x_hat - self.D_x_eta(x_hat,t_hat))/t_hat
            x = x_hat + (t_next - t_hat)*d_cur
            if t_next!=0:
                # Parenthesised: the whole residual is divided by t_next.
                d_next = (x - self.D_x_eta(x,t_next))/t_next
                x = x_hat + (t_next - t_hat)*(0.5*d_cur + 0.5*d_next)
        return x

In [3]:
class VP_EDM(EDM_base):
    """Variance-preserving (VP) design choices for the EDM framework.

    With u(t) = 0.5 * beta_d * t**2 + beta_min * t the schedule is
    sigma(t) = sqrt(exp(u) - 1) and the scaling is s(t) = exp(-u / 2).

    Args:
        model: network called as ``model(scaled_input, noise_condition)``.
        N: number of sampling steps.
        para: configuration with keys ``beta_d``, ``beta_min``, ``eps_s``,
            ``eps_t`` and ``M``.
    """

    # NOTE(review): 10e-3 == 1e-2 and 10e-5 == 1e-4.  If 10^-3 / 10^-5 were
    # intended for eps_s / eps_t these should be 1e-3 / 1e-5 — confirm.
    def __init__(self,
                     model,
                 N=1000,
                 para=dict(beta_d=19.9,
                            beta_min =0.1,
                            eps_s=10e-3,
                            eps_t=10e-5,
                            M=1000)):
        # Copy so instances never share (or mutate) the default dict.
        super().__init__(dict(para), N)
        self.model = model
        self.N = N

    def _exponent(self,t):
        """Return u(t) = 0.5 * beta_d * t**2 + beta_min * t as a tensor."""
        t = torch.as_tensor(t)
        return 0.5*self.config['beta_d']*t**2 + self.config['beta_min']*t

    def _exponent_prm(self,t):
        """Return u'(t) = beta_d * t + beta_min as a tensor."""
        return self.config['beta_d']*torch.as_tensor(t) + self.config['beta_min']

    def time_step(self,i):
        """Return t_i = 1 + i/(N-1) * (eps_s - 1) for i < N, and 0 for i >= N.

        ``i`` may be an int or an integer tensor; a tensor is returned.
        """
        i = torch.as_tensor(i)
        # max(..., 1) avoids a division by zero when N == 1.
        t = 1 + i/max(self.N-1, 1)*(self.config['eps_s']-1)
        # The samplers query index N; defining t_N = 0 makes their
        # "sigma(t_{i+1}) != 0" guard fire on the final step.
        return torch.where(i < self.N, t, torch.zeros_like(t))

    def Schedule(self,t):
        """sigma(t) = sqrt(exp(u(t)) - 1)."""
        return torch.sqrt(torch.expm1(self._exponent(t)))

    def Schedule_prm(self,t):
        """sigma'(t) = exp(u) * u' / (2 * sqrt(exp(u) - 1))."""
        u = self._exponent(t)
        return torch.exp(u) * self._exponent_prm(t) / (2.0*torch.sqrt(torch.expm1(u)))

    def Scaling(self,t):
        """s(t) = 1 / sqrt(exp(u(t)))."""
        return torch.exp(-0.5*self._exponent(t))

    def Scaling_prm(self,t):
        """s'(t) = -u'(t) / (2 * sqrt(exp(u(t))))."""
        return -0.5*self._exponent_prm(t)*torch.exp(-0.5*self._exponent(t))

    def Skip_scaling(self,eta):
        """Skip coefficient: always 1."""
        return 1

    def Output_scaling(self,eta):
        """Output coefficient: -eta."""
        return -eta

    def Input_scaling(self,eta):
        """Input coefficient: 1 / sqrt(eta**2 + 1)."""
        return  1.0/torch.sqrt(torch.as_tensor(eta)**2 + 1)

    def Noise_cond(self,eta):
        """Noise conditioning (M - 1) * t, where t solves sigma(t) = eta.

        This is a deterministic function of ``eta``: the inverse of
        ``Schedule``, not a reciprocal and not a random draw.  The result has
        the shape of ``eta``.
        NOTE(review): confirm the network accepts a conditioning tensor of
        that shape.
        """
        eta = torch.as_tensor(eta)
        beta_d = self.config['beta_d']
        beta_min = self.config['beta_min']
        # Solve 0.5*beta_d*t^2 + beta_min*t = log(1 + eta^2) for t >= 0.
        t = (torch.sqrt(beta_min**2 + 2*beta_d*torch.log1p(eta**2)) - beta_min)/beta_d
        return (self.config['M'] - 1) * t

    def Noise_distribution(self,eta):
        """Return Uniform(eps_t, 1); ``eta`` is unused."""
        return torch.distributions.Uniform(self.config['eps_t'],1)

    def Loss_weighting(self,eta):
        """Loss weight 1 / eta**2."""
        return 1.0/eta**2


In [4]:
class VE_EDM(EDM_base):
    """Variance-exploding (VE) design choices for the EDM framework.

    Uses sigma(t) = sqrt(t) and s(t) = 1.

    Args:
        model: network called as ``model(scaled_input, noise_condition)``.
        N: number of sampling steps.
        param: configuration with keys ``sigma_min`` and ``sigma_max``.
    """

    def __init__(self,model,
                 N=1000,param = dict(sigma_min = 0.02,sigma_max =100)):
        # Copy so instances never share (or mutate) the default dict.
        super().__init__(dict(param), N)
        # D_x_eta reads self.model, so it has to be stored here.
        self.model = model
        self.N=N

    def time_step(self,i):
        """Return t_i = sigma_max^2 * (sigma_min^2 / sigma_max^2)^(i/(N-1)).

        Returns 0 for i >= N.  ``i`` may be an int or an integer tensor; a
        tensor is returned.
        """
        i = torch.as_tensor(i)
        var_max = self.config['sigma_max']**2
        var_min = self.config['sigma_min']**2
        # max(..., 1) avoids a division by zero when N == 1.
        t = var_max * (var_min / var_max) ** (i/max(self.N-1, 1))
        # The samplers query index N; t_N = 0 marks the end of the trajectory.
        return torch.where(i < self.N, t, torch.zeros_like(t))

    def Schedule(self,t):
        """sigma(t) = sqrt(t)."""
        return torch.sqrt(torch.as_tensor(t))

    def Schedule_prm(self,t):
        """sigma'(t) = 1 / (2 * sqrt(t))."""
        return 1/(2*torch.sqrt(torch.as_tensor(t)))

    def Scaling(self,t):
        """s(t) = 1."""
        return 1

    def Scaling_prm(self,t):
        """s'(t) = 0."""
        return 0

    def Skip_scaling(self,eta):
        """Skip coefficient: always 1."""
        return 1

    def Output_scaling(self,eta):
        """Output coefficient: eta."""
        return eta

    def Input_scaling(self,eta):
        """Input coefficient: always 1."""
        return 1

    def Noise_cond(self,eta):
        """Noise conditioning log(0.5 * eta), a deterministic function of eta."""
        return torch.log(0.5*torch.as_tensor(eta))

    def Noise_distribution(self,eta):
        """Return Uniform(log(sigma_min), log(sigma_max)); ``eta`` is unused."""
        # torch.log needs tensors; the config values are plain Python numbers.
        low = torch.log(torch.tensor(float(self.config['sigma_min'])))
        high = torch.log(torch.tensor(float(self.config['sigma_max'])))
        return torch.distributions.Uniform(low, high)

    def Loss_weighting(self,eta):
        """Loss weight 1 / eta**2."""
        return 1/eta**2
    

In [None]:
class EDM(EDM_base):
    """EDM design choices: sigma(t) = t, s(t) = 1, sigma_data preconditioning.

    Args:
        model: network called as ``model(scaled_input, noise_condition)``.
        N: number of sampling steps.
        param: configuration with keys ``sigma_min``, ``sigma_max``,
            ``sigma_data``, ``phi`` (the exponent used by ``time_step``),
            ``P_mean`` and ``P_std``.
    """

    def __init__(self,model,N=1000,
                 param= dict(sigma_min =0.002,sigma_max=80,
                            sigma_data=0.5,phi = 7,
                            P_mean = -1.2,P_std = 1.2)):
        # Copy so instances never share (or mutate) the default dict.
        super().__init__(dict(param), N)
        self.model = model
        self.N=N

    def time_step(self,i):
        """Return the i-th noise level, interpolated in sigma**(1/phi) space.

        t_i = (sigma_max^(1/phi) + i/(N-1) * (sigma_min^(1/phi)
        - sigma_max^(1/phi)))^phi for i < N, and 0 for i >= N.  ``i`` may be
        an int or an integer tensor; a tensor is returned.
        """
        ro = self.config['phi']
        i = torch.as_tensor(i)
        root_max = self.config['sigma_max']**(1.0/ro)
        root_min = self.config['sigma_min']**(1.0/ro)
        # max(..., 1) avoids a division by zero when N == 1.
        t = (root_max + i/max(self.N-1, 1) * (root_min - root_max))**ro
        # Past the last index the interpolation overshoots (it can even go
        # negative for small N); the trajectory ends at t_N = 0 instead.
        return torch.where(i < self.N, t, torch.zeros_like(t))

    def Schedule(self,t):
        """sigma(t) = t."""
        return t

    def Schedule_prm(self,t):
        """sigma'(t) = 1."""
        return 1

    def Scaling(self,t):
        """s(t) = 1."""
        return 1

    def Scaling_prm(self,t):
        """s'(t) = 0."""
        return 0

    def Skip_scaling(self,eta):
        """Skip coefficient sigma_data^2 / (eta^2 + sigma_data^2)."""
        sd = self.config['sigma_data']
        return sd**2 / (eta**2 + sd**2)

    def Output_scaling(self,eta):
        """Output coefficient eta * sigma_data / sqrt(sigma_data^2 + eta^2)."""
        sd = self.config['sigma_data']
        return eta * sd/torch.sqrt(sd**2 + torch.as_tensor(eta)**2)

    def Input_scaling(self,eta):
        """Input coefficient 1 / sqrt(eta^2 + sigma_data^2)."""
        sd = self.config['sigma_data']
        return 1/torch.sqrt(torch.as_tensor(eta)**2 + sd**2)

    def Noise_cond(self,eta):
        """Noise conditioning log(eta) / 4, a deterministic function of eta."""
        return 1/4 * torch.log(torch.as_tensor(eta))

    def Noise_distribution(self,eta):
        """Return Normal(P_mean, P_std); ``eta`` is unused."""
        # torch.distributions.Normal takes ``loc`` / ``scale``.
        return torch.distributions.Normal(loc=self.config['P_mean'],scale=self.config['P_std'])

    def Loss_weighting(self,eta):
        """Loss weight (eta^2 + sigma_data^2) / (eta * sigma_data)^2."""
        return (  eta **2 + self.config['sigma_data']**2  )/(  eta * self.config['sigma_data'] )**2

# Train

In [None]:
img_size = 128  # for quick experiments this can be changed to 64
batch_size = 32  # lower to 16 if GPU memory is insufficient
embedding_size = 128
channels = [1, 1, 2, 2, 4, 4]
blocks = 2  # lower to 1 if GPU memory is insufficient

In [None]:
# Preprocessing pipeline: Resize([224,224]) followed by ToTensor().
# NOTE(review): `Compose` and `transforms` are not imported in any cell shown
# here (presumably torchvision.transforms) — confirm the import exists.
# NOTE(review): the resize target is hard-coded and ignores `img_size`.
pipe = Compose([transforms.Resize([224,224]),
               transforms.ToTensor(),
               ])

In [None]:
# The network must exist before the optimizer can collect its parameters, so
# it is built first in this cell; the original order raised NameError on
# Restart & Run All.
model = UNet(in_channels=3, out_channels=3)

# Adam optimiser with a cosine-annealed learning rate.
# NOTE(review): weight_decay=0.99 is unusually large for Adam — confirm it is
# intended and not a typo for a small L2 coefficient.
optimizer = torch.optim.Adam(model.parameters(),lr=0.001,weight_decay=0.99)
learn_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,T_max=10)

In [None]:
# Shuffled mini-batch loader over the training dataset.
# NOTE(review): `data` (the dataset) is not defined in any cell shown here —
# it has to be created before this cell runs.
# DataLoader is referenced through `torch.utils.data` because the name is not
# imported in the import cell; the batch size comes from the config cell.
dataloader = torch.utils.data.DataLoader(data,
                        batch_size=batch_size,
                        shuffle=True,
                        pin_memory=True)

In [None]:
# Wrap the network in the VP diffusion model, using only 2 discretisation steps.
diff_model = VP_EDM(model, N=2)

In [None]:
# Denoising training loop: for every batch pick a random noise level, corrupt
# the clean images with Gaussian noise of that level, and regress the
# denoiser output back onto the clean images.
epoches = 20
for epoch in range(epoches):

    # `batch` (not `data`) so the dataset variable is not overwritten.
    for batch in dataloader :
        image,_= batch
        bs=image.shape[0]
        optimizer.zero_grad()
        # One discretisation index per sample, shaped to broadcast over C,H,W.
        n= torch.randint(0,diff_model.N,(bs,1,1,1))
        t = diff_model.time_step(n)
        eta = diff_model.Schedule(t)
        # The denoiser has to see a noisy input; feeding it the clean image
        # would make the target trivial.
        noisy = image + eta * torch.randn_like(image)
        denoised = diff_model.D_x_eta(noisy,eta)
        # Weighted squared error between the denoised and the clean image.
        loss = (diff_model.Loss_weighting(eta) * (denoised - image)**2).mean()

        loss.backward()
        optimizer.step()
    # Step the LR schedule once per epoch: with T_max=10, stepping per batch
    # would sweep a full cosine cycle every 20 batches.
    learn_scheduler.step()
    

In [5]:
class ADM(nn.Module):
    """Single class selecting one of several named parameterisations.

    Args:
        name: key into ``parameter`` ('VP', 'VE', 'DDIM' or 'EDM').
        N: number of discretisation steps.
        score_model: the network, stored as ``self.model``.

    Raises:
        KeyError: if ``name`` is not a key of ``parameter``.
    """

    # Default hyper-parameters for each named parameterisation.
    parameter={'VP':dict(beta_d=19.9,beta_min =0.1,eps_s=10e-3,eps_t=10e-5,M=1000),
               'VE':dict(sigma_min=0.02,sigma_max = 100),
               'DDIM':dict(M=1000,j_0=8,C_1=0.001,C_2=0.008),
               'EDM':dict(sigma_min = 0.002,sigma_max=80,sigma_data=0.5,
                         phi = 7,P_mean = -1.2,P_std=1.2)}

    # TODO: the original cell left every value of this table blank (a syntax
    # error); None placeholders keep the name and keys until it is filled in.
    compute = {'VP': None,
            'VE': None,
            'DDIM': None,
            'EDM': None}

    def __init__(self,name,
                 N,
                score_model,
                ):
        super().__init__()
        if name not in self.parameter:
            raise KeyError(f"unknown parameterisation {name!r}; expected one of {sorted(self.parameter)}")
        self.name = name
        # Class attributes are reached through self; copy so that one
        # instance cannot mutate the shared defaults.
        self.config = dict(self.parameter[name])
        self.N =N
        self.model = score_model

    def time_step(self,i):
        """Return the time for step index ``i`` (requires ``i < N``).

        Only the selected schedule is evaluated, because each config holds
        only the keys of its own parameterisation.

        Raises:
            NotImplementedError: for 'DDIM', whose formula was left blank in
                the original cell.
        """
        assert (i<self.N)
        cfg = self.config
        # max(..., 1) avoids a division by zero when N == 1.
        frac = i/max(self.N-1, 1)
        if self.name == 'VP':
            return 1+frac*(cfg['eps_s']-1)
        if self.name == 'VE':
            return cfg['sigma_max']**2 *(cfg['sigma_min']/cfg['sigma_max'])**(2*frac)
        if self.name == 'EDM':
            # Same discretisation as EDM.time_step in this notebook.
            ro = cfg['phi']
            root_max = cfg['sigma_max']**(1.0/ro)
            root_min = cfg['sigma_min']**(1.0/ro)
            return (root_max + frac*(root_min - root_max))**ro
        raise NotImplementedError(f"time_step is not implemented for {self.name!r}")

    


In [6]:
# Scratch dictionary with a single entry.
tet = {'a': 5}

In [8]:
# Look up key 'a' in the scratch dict `tet`; the cell displays the value.
tet['a']

5