<font size="5">**This is a Jupyter notebook for Kaggle that trains a class-conditional DDPM with classifier-free guidance (CFG), EMA weights, and DDIM sampling.**</font>

In [None]:
%%writefile submodules.py

import os
import torch
import torch.nn as nn
import torch.nn.functional as F

class DoubleConv(nn.Module):
    """Two (3x3 conv -> GroupNorm(1)) stages with a GELU between them.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the block.
        mid_channels: channels between the two convs; any falsy value
            (``None``/0) means ``out_channels``.
        residual: if True, return ``gelu(x + block(x))``; this requires
            ``in_channels == out_channels`` for the addition to be valid.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None, residual=False):
        super().__init__()
        self.residual = residual
        mid_channels = mid_channels or out_channels
        stages = [
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(1, mid_channels),
            nn.GELU(),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(1, out_channels),
        ]
        # Sequential indices 0..4 define the checkpoint keys (double_conv.N.*).
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        out = self.double_conv(x)
        return F.gelu(x + out) if self.residual else out
class Down(nn.Module):
    """Downsampling stage: 2x max-pool, two DoubleConvs, then a per-channel time bias.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the stage.
        emb_dim: width of the time embedding ``t`` given to ``forward``.
    """

    def __init__(self, in_channels, out_channels, emb_dim=256):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, in_channels, residual=True),
            DoubleConv(in_channels, out_channels),
        )
        # Projects the time embedding to one additive bias per output channel.
        self.emb_layer = nn.Sequential(
            nn.SiLU(),
            nn.Linear(emb_dim, out_channels),
        )

    def forward(self, x, t):
        """x: (B, in_channels, H, W); t: (B, emb_dim).

        Returns a (B, out_channels, H // 2, W // 2) tensor.
        """
        x = self.maxpool_conv(x)
        # (B, C, 1, 1) broadcasts over H and W; no need to materialise a
        # full (B, C, H, W) copy with .repeat().
        emb = self.emb_layer(t)[:, :, None, None]
        return x + emb

class Up(nn.Module):
    """Upsampling stage: 2x bilinear upsample, concat with a skip tensor,
    two DoubleConvs, then a per-channel time bias.

    Args:
        in_channels: channels AFTER concatenating ``skip_x`` with the upsampled
            ``x`` (i.e. skip channels + x channels).
        out_channels: channels produced by the stage.
        emb_dim: width of the time embedding ``t`` given to ``forward``.
    """

    def __init__(self, in_channels, out_channels, emb_dim=256):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.conv = nn.Sequential(
            DoubleConv(in_channels, in_channels, residual=True),
            DoubleConv(in_channels, out_channels, in_channels // 2),
        )
        # Projects the time embedding to one additive bias per output channel.
        self.emb_layer = nn.Sequential(
            nn.SiLU(),
            nn.Linear(emb_dim, out_channels),
        )

    def forward(self, x, skip_x, t):
        """x: low-res features; skip_x: encoder features at 2x x's resolution;
        t: (B, emb_dim). Returns (B, out_channels, H_skip, W_skip)."""
        x = self.up(x)
        x = torch.cat([skip_x, x], dim=1)
        x = self.conv(x)
        # Broadcast the (B, C, 1, 1) bias instead of .repeat()-ing it to full size.
        emb = self.emb_layer(t)[:, :, None, None]
        return x + emb
    
class SelfAttention(nn.Module):
    """Transformer-style self-attention over the spatial positions of a feature map.

    Each pixel is a token of width ``channels``. The block applies LayerNorm ->
    4-head attention with a residual, then LayerNorm -> MLP with a residual.

    Args:
        channels: feature channels (must be divisible by the 4 attention heads).
        size: nominal spatial resolution. It is kept for reference only;
            ``forward`` reads H and W from its input, so any resolution works.
    """

    def __init__(self, channels, size):
        super(SelfAttention, self).__init__()
        self.channels = channels
        self.size = size
        self.mha = nn.MultiheadAttention(channels, 4, batch_first=True)
        self.ln = nn.LayerNorm([channels])
        self.ff_self = nn.Sequential(
            nn.LayerNorm([channels]),
            nn.Linear(channels, channels),
            nn.GELU(),
            nn.Linear(channels, channels),
        )

    def forward(self, x):
        """x: (B, channels, H, W) -> tensor of the same shape."""
        b, c, h, w = x.shape
        # (B, C, H, W) -> (B, H*W, C). The shape comes from the input itself.
        # The old view(-1, C, size*size) silently folded mismatched
        # resolutions into a wrong batch dimension.
        tokens = x.reshape(b, c, h * w).transpose(1, 2)
        normed = self.ln(tokens)
        attn_value, _ = self.mha(normed, normed, normed)
        attn_value = attn_value + tokens
        attn_value = self.ff_self(attn_value) + attn_value
        return attn_value.transpose(1, 2).reshape(b, c, h, w)
    
    

class EMA:
    """Exponential moving average (EMA) of a model's parameters.

    For the first ``step_start_ema`` calls to ``step_ema``, the EMA model is
    overwritten with the live weights. After that, each call applies
    ``ema = beta * ema + (1 - beta) * current`` parameter-wise.

    Args:
        beta: decay factor; closer to 1 means a slower-moving average.
    """

    def __init__(self, beta):
        super().__init__()
        self.beta = beta
        self.step = 0  # number of step_ema calls so far

    @torch.no_grad()
    def update_model_average(self, ma_model, current_model):
        """Blend ``current_model``'s parameters into ``ma_model`` in place."""
        weight = 1.0 - self.beta
        for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
            # ma + (1 - beta) * (cur - ma) == beta * ma + (1 - beta) * cur.
            # Done in place, with no new tensor per parameter per step.
            ma_params.lerp_(current_params, weight)

    def update_average(self, old, new):
        """Return ``beta * old + (1 - beta) * new`` (``new`` if ``old`` is None)."""
        if old is None:
            return new
        return old * self.beta + (1 - self.beta) * new

    def step_ema(self, ema_model, model, step_start_ema=2000):
        """Advance the EMA by one optimizer step.

        Before ``step_start_ema`` steps, copy the weights; afterwards, average them.
        """
        if self.step < step_start_ema:
            self.reset_parameters(ema_model, model)
        else:
            self.update_model_average(ema_model, model)
        self.step += 1

    def reset_parameters(self, ema_model, model):
        """Copy ``model``'s full state dict into ``ema_model``."""
        ema_model.load_state_dict(model.state_dict())

In [None]:
%%writefile UNet.py
import os
import torch
import torch.nn as nn
from submodules import DoubleConv
from submodules import Down
from submodules import SelfAttention
from submodules import Up


class UNet(nn.Module):
    """Unconditional noise-prediction UNet with time embedding and self-attention.

    The SelfAttention layers are declared for 32/16/8/16/32/64 resolutions.
    Inputs are therefore expected to be 64x64, with three 2x downsamplings.

    Args:
        c_in: input image channels.
        c_out: output channels (predicted noise).
        time_dim: width of the sinusoidal time embedding.
        device: kept for backward compatibility. ``pos_encoding`` now builds
            its tensors on the timestep tensor's device instead.
    """

    def __init__(self, c_in=3, c_out=3, time_dim=256, device="cuda"):
        super().__init__()
        self.device = device
        self.time_dim = time_dim
        self.inc = DoubleConv(c_in, 64)

        self.down1 = Down(64, 128)  # input_channel, output_channel
        self.sa1 = SelfAttention(128, 32)  # channel, img resolution

        self.down2 = Down(128, 256)
        self.sa2 = SelfAttention(256, 16)

        self.down3 = Down(256, 256)
        self.sa3 = SelfAttention(256, 8)

        # bottleneck
        self.bot1 = DoubleConv(256, 512)
        self.bot2 = DoubleConv(512, 512)
        self.bot3 = DoubleConv(512, 256)

        self.up1 = Up(512, 128)
        self.sa4 = SelfAttention(128, 16)

        self.up2 = Up(256, 64)
        self.sa5 = SelfAttention(64, 32)

        self.up3 = Up(128, 64)
        self.sa6 = SelfAttention(64, 64)

        self.outc = nn.Conv2d(64, c_out, kernel_size=1)

    def pos_encoding(self, t, channels):
        """Sinusoidal embedding of timesteps.

        Args:
            t: (B, 1) float tensor (``forward`` passes ``t.unsqueeze(-1)``).
            channels: embedding width (even).
        Returns:
            (B, channels) tensor: sin terms in the first half, cos in the second.
        """
        # Use t's device, not self.device: the default "cuda" breaks on CPU
        # and can disagree with where the module actually runs.
        inv_freq = 1.0 / (
            10000 ** (torch.arange(0, channels, 2, device=t.device).float() / channels)
        )
        angles = t * inv_freq  # (B, 1) * (channels/2,) -> (B, channels/2)
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

    def forward(self, x, t):
        """x: (B, c_in, 64, 64) images; t: (B,) timesteps. Returns (B, c_out, 64, 64)."""
        t = t.unsqueeze(-1).type(torch.float)
        t = self.pos_encoding(t, self.time_dim)

        x1 = self.inc(x)

        x2 = self.down1(x1, t)
        x2 = self.sa1(x2)

        x3 = self.down2(x2, t)
        x3 = self.sa2(x3)

        x4 = self.down3(x3, t)
        x4 = self.sa3(x4)

        x4 = self.bot1(x4)
        x4 = self.bot2(x4)
        x4 = self.bot3(x4)

        x = self.up1(x4, x3, t)
        x = self.sa4(x)

        x = self.up2(x, x2, t)
        x = self.sa5(x)

        x = self.up3(x, x1, t)
        x = self.sa6(x)

        output = self.outc(x)
        return output
    
        
        
        

In [None]:
%%writefile UNet_conditional.py
import os
import torch
import torch.nn as nn
from submodules import DoubleConv
from submodules import Down
from submodules import SelfAttention
from submodules import Up


class UNet_conditional(nn.Module):
    """Class-conditional noise-prediction UNet with time embedding and self-attention.

    The SelfAttention layers are declared for 32/16/8/16/32/64 resolutions.
    Inputs are therefore expected to be 64x64.

    Args:
        c_in: input image channels.
        c_out: output channels (predicted noise).
        time_dim: width of the sinusoidal time embedding.
        num_classes: if given, an ``nn.Embedding(num_classes, time_dim)`` is
            created. Its vectors are added to the time embedding when
            ``forward`` receives labels ``y``.
        device: kept for backward compatibility. ``pos_encoding`` now builds
            its tensors on the timestep tensor's device instead.
    """

    def __init__(self, c_in=3, c_out=3, time_dim=256, num_classes=None, device="cuda"):
        super().__init__()
        self.device = device
        self.time_dim = time_dim
        self.inc = DoubleConv(c_in, 64)

        self.down1 = Down(64, 128)  # input_channel, output_channel
        self.sa1 = SelfAttention(128, 32)  # channel, img resolution

        self.down2 = Down(128, 256)
        self.sa2 = SelfAttention(256, 16)

        self.down3 = Down(256, 256)
        self.sa3 = SelfAttention(256, 8)

        # bottleneck
        self.bot1 = DoubleConv(256, 512)
        self.bot2 = DoubleConv(512, 512)
        self.bot3 = DoubleConv(512, 256)

        self.up1 = Up(512, 128)
        self.sa4 = SelfAttention(128, 16)

        self.up2 = Up(256, 64)
        self.sa5 = SelfAttention(64, 32)

        self.up3 = Up(128, 64)
        self.sa6 = SelfAttention(64, 64)

        self.outc = nn.Conv2d(64, c_out, kernel_size=1)

        if num_classes is not None:
            self.label_emb = nn.Embedding(num_classes, time_dim)

    def pos_encoding(self, t, channels):
        """Sinusoidal embedding of timesteps.

        Args:
            t: (B, 1) float tensor (``forward`` passes ``t.unsqueeze(-1)``).
            channels: embedding width (even).
        Returns:
            (B, channels) tensor: sin terms in the first half, cos in the second.
        """
        # Use t's device, not self.device: the default "cuda" breaks on CPU
        # and can disagree with where the module actually runs.
        inv_freq = 1.0 / (
            10000 ** (torch.arange(0, channels, 2, device=t.device).float() / channels)
        )
        angles = t * inv_freq  # (B, 1) * (channels/2,) -> (B, channels/2)
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

    def forward(self, x, t, y=None):
        """Predict noise.

        Args:
            x: (B, c_in, 64, 64) images.
            t: (B,) timesteps.
            y: (B,) integer class labels, or None to skip label conditioning.
        Returns:
            (B, c_out, 64, 64) tensor.
        """
        t = t.unsqueeze(-1).type(torch.float)
        t = self.pos_encoding(t, self.time_dim)
        if y is not None:
            # Out-of-place add keeps the embedding graph simple for autograd.
            t = t + self.label_emb(y)

        x1 = self.inc(x)

        x2 = self.down1(x1, t)
        x2 = self.sa1(x2)

        x3 = self.down2(x2, t)
        x3 = self.sa2(x3)

        x4 = self.down3(x3, t)
        x4 = self.sa3(x4)

        x4 = self.bot1(x4)
        x4 = self.bot2(x4)
        x4 = self.bot3(x4)

        x = self.up1(x4, x3, t)
        x = self.sa4(x)

        x = self.up2(x, x2, t)
        x = self.sa5(x)

        x = self.up3(x, x1, t)
        x = self.sa6(x)

        output = self.outc(x)
        return output
    
        
        

In [None]:

!rm -r /kaggle/working/utils.py


In [None]:
%%writefile utils.py

import os
import torch
import torchvision
from PIL import Image
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
from torch import optim
from UNet import UNet
import torch.nn as nn
import logging


def plot_images(images):
    """Display a (B, C, H, W) batch as a single horizontal strip with matplotlib."""
    plt.figure(figsize=(32, 32))
    # Concatenate the B images along the width axis -> (C, H, B*W).
    strip = torch.cat(list(images.cpu()), dim=-1)
    # imshow wants channels last.
    plt.imshow(strip.permute(1, 2, 0).cpu())
    plt.show()

def save_images(images, path, **kwargs):
    """Tile ``images`` into a grid with torchvision and write it to ``path`` via PIL.

    Extra keyword arguments are forwarded to ``torchvision.utils.make_grid``.
    Presumably ``images`` holds uint8 pixels, since PIL's ``fromarray`` needs an
    integer array for RGB output.
    """
    grid = torchvision.utils.make_grid(images, **kwargs)
    # Channel-first grid -> (H, W, C) NumPy array on the CPU for PIL.
    pixels = grid.permute(1, 2, 0).to("cpu").numpy()
    Image.fromarray(pixels).save(path)
    
def get_data(args):
    """Return a shuffled DataLoader over the ImageFolder at ``args.dataset_path``.

    Reads ``args.img_size``, ``args.batch_size``, ``args.dataset_path`` and the
    optional ``args.num_workers`` (default 0). The image pipeline is:
    1. Resize to ``int(1.25 * img_size)`` (80 for img_size=64).
    2. Random-resized-crop to ``img_size`` with scale (0.8, 1.0).
    3. Convert to a tensor.
    4. Normalise with mean/std 0.5 per channel.
    """
    transforms = torchvision.transforms.Compose([
        # Leave a margin above img_size for the random crop (80 px at 64).
        torchvision.transforms.Resize(int(args.img_size * 1.25)),
        torchvision.transforms.RandomResizedCrop(args.img_size, scale=(0.8, 1.0)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    dataset = torchvision.datasets.ImageFolder(args.dataset_path, transform=transforms)

    return DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=getattr(args, "num_workers", 0),
    )

def setup_logging(run_name):
    """Ensure ``models/<run_name>`` and ``results/<run_name>`` directories exist.

    Creates only directories; it does not configure the ``logging`` module.
    It is safe to call repeatedly.
    """
    for root in ("models", "results"):
        os.makedirs(root, exist_ok=True)
        os.makedirs(os.path.join(root, run_name), exist_ok=True)



In [None]:
%cd /kaggle/working/
!rm ddpm.py
!rm -r /kaggle/working/results/DDPM_conditional

In [None]:
%%writefile ddpm.py


import os
import copy
import numpy as np
import torch
import torch.nn as nn
from matplotlib import pyplot as plt
from torch import optim
from tqdm import tqdm
import logging
from torch.utils.tensorboard import SummaryWriter

from utils import plot_images
from utils import save_images
from utils import get_data
from utils import setup_logging

# from UNet import UNet
from UNet_conditional import UNet_conditional
from submodules import EMA

from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs
import torch.nn.functional as F

# The keyword is `datefmt`. basicConfig raises ValueError on unrecognised
# keywords (Python 3.8+) when the root logger has no handlers yet.
logging.basicConfig(format="%(asctime)s-%(levelname)s: %(message)s", level=logging.INFO, datefmt="%I:%M:%S")

from collections import OrderedDict

def extract(v, t, x_shape):
    """Gather ``v[t]`` per sample and reshape it to broadcast against ``x_shape``.

    Args:
        v: 1-D tensor of per-timestep coefficients, shape (T,).
        t: 1-D long tensor of timestep indices, shape (B,).
        x_shape: shape of the tensor to broadcast against, e.g. (B, C, H, W).
    Returns:
        Float tensor of shape (B, 1, ..., 1) with ``len(x_shape)`` dims,
        e.g. (B, 1, 1, 1) for a (B, C, H, W) image batch.
    """
    picked = v.gather(0, t).float()
    trailing_ones = [1] * (len(x_shape) - 1)
    return picked.view([t.shape[0]] + trailing_ones)

class Diffusion:
    """DDPM forward (noising) process with a linear beta schedule, plus a DDIM
    sampler with classifier-free guidance (CFG).

    Timestep convention: ``alpha_hat`` and the model are indexed
    0..noise_steps-1. The DDIM loop in ``sample`` uses 1-based timesteps and
    reads them from ``alphas_bar_prev_whole`` (``alpha_hat`` with a leading 1).
    It passes ``t - 1`` to the model.
    """

    def __init__(self, noise_steps=1000, beta_start=1e-4, beta_end=0.02, img_size=64, device="cuda"):
        self.noise_steps = noise_steps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.img_size = img_size
        self.device = device
        self.beta = self.prepare_noise_schedule().to(device)  # shape: (noise_steps,)
        self.alpha = 1. - self.beta  # shape: (noise_steps,)
        # alpha_hat[k] = alpha[0] * alpha[1] * ... * alpha[k]
        self.alpha_hat = torch.cumprod(self.alpha, dim=0)
        # Coefficients for recovering x_0 from (x_t, eps), used by the DDIM sampler.
        self.sqrt_recip_alphas_bar = torch.sqrt(1. / self.alpha_hat)
        self.sqrt_recipm1_alphas_bar = torch.sqrt(1. / self.alpha_hat - 1)
        # alpha_hat shifted right with a leading 1: entry k is alpha_hat[k - 1], entry 0 is 1.
        self.alphas_bar_prev_whole = F.pad(self.alpha_hat, [1, 0], value=1)
        # DDIM eta: scales sigma_t in `sample`. 0 gives deterministic updates.
        self.ddim_eta = 0

    def prepare_noise_schedule(self):
        """Return the linear beta schedule of length ``noise_steps`` (CPU tensor)."""
        return torch.linspace(self.beta_start, self.beta_end, self.noise_steps)

    def noise_images(self, x, t):
        """Draw x_t from x_0 in closed form: sqrt(a_hat) * x + sqrt(1 - a_hat) * noise.

        Args:
            x: clean images as a 4-D batch (B, C, H, W).
            t: (B,) integer timesteps indexing ``alpha_hat``.
        Returns:
            ``(x_t, noise)``: the noised batch and the Gaussian noise used, both shaped like ``x``.
        """
        sqrt_alpha_hat = torch.sqrt(self.alpha_hat[t])[:, None, None, None]
        sqrt_one_minus_alpha_hat = torch.sqrt(1 - self.alpha_hat[t])[:, None, None, None]
        noise = torch.randn_like(x)
        return sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * noise, noise

    def sample_timesteps(self, n):
        """Return ``n`` random training timesteps, uniform in [1, noise_steps - 1] (CPU tensor)."""
        return torch.randint(low=1, high=self.noise_steps, size=(n,))

    def predict_xstart_from_eps(self, x_t, t, eps):
        """Recover x_0 = x_t / sqrt(a_hat_t) - sqrt(1 / a_hat_t - 1) * eps."""
        return (
            extract(self.sqrt_recip_alphas_bar, t, x_t.shape) * x_t -
            extract(self.sqrt_recipm1_alphas_bar, t, x_t.shape) * eps
        )

    def sample(self, model, n, labels, cfg_scale=3, sample_steps=5):
        """Generate ``n`` images with DDIM, optionally with classifier-free guidance.

        Args:
            model: noise predictor called as ``model(x_t, t, labels)``.
                ``labels=None`` requests the unconditional prediction.
            n: number of images to generate.
            labels: (n,) class labels (or None).
            cfg_scale: guidance weight. When > 0, the prediction is
                ``lerp(eps_uncond, eps_cond, cfg_scale)``.
            sample_steps: stride between visited timesteps. About
                ``noise_steps / sample_steps`` model evaluations are made
                (twice that with CFG).
        Returns:
            (n, 3, img_size, img_size) uint8 tensor with values in [0, 255].
        Raises:
            ValueError: if ``sample_steps`` is not in [1, noise_steps].
        """
        if not 1 <= sample_steps <= self.noise_steps:
            raise ValueError(f"sample_steps must be in [1, {self.noise_steps}], got {sample_steps}")
        logging.info(f"sampling {n} new images... ")

        # 1-based DDIM timesteps t and their predecessors t - sample_steps.
        t_seq = torch.arange(sample_steps, self.noise_steps + 1, sample_steps)
        t_prev_seq = t_seq - sample_steps

        was_training = model.training
        model.eval()
        try:
            with torch.no_grad():
                x_t = torch.randn((n, 3, self.img_size, self.img_size), device=self.device)
                steps = zip(reversed(t_seq.tolist()), reversed(t_prev_seq.tolist()))
                for i, j in tqdm(steps, total=len(t_seq), desc='Inference'):
                    t = x_t.new_full((x_t.shape[0],), i, dtype=torch.long)
                    prev_t = x_t.new_full((x_t.shape[0],), j, dtype=torch.long)
                    alpha_cumprod_t = extract(self.alphas_bar_prev_whole, t, x_t.shape)
                    alpha_cumprod_t_prev = extract(self.alphas_bar_prev_whole, prev_t, x_t.shape)

                    # The model is trained on 0-based timesteps, hence t - 1.
                    eps = model(x_t, t - 1, labels)
                    if cfg_scale > 0:
                        # CFG must be applied BEFORE eps is used for x_0 and the
                        # direction term. Otherwise guidance has no effect.
                        uncond_eps = model(x_t, t - 1, None)
                        eps = torch.lerp(uncond_eps, eps, cfg_scale)

                    x_0 = self.predict_xstart_from_eps(x_t, t - 1, eps)

                    sigma_t = self.ddim_eta * torch.sqrt(
                        (1 - alpha_cumprod_t_prev) / (1 - alpha_cumprod_t)
                        * (1 - alpha_cumprod_t / alpha_cumprod_t_prev))
                    pred_dir_xt = torch.sqrt(1 - alpha_cumprod_t_prev - sigma_t ** 2) * eps

                    x_t = torch.sqrt(alpha_cumprod_t_prev) * x_0 + pred_dir_xt
                    if self.ddim_eta > 0 and j > 0:
                        # sigma_t is a standard deviation: scale the noise by
                        # sigma_t, not sigma_t ** 2.
                        x_t = x_t + sigma_t * torch.randn_like(x_t)
        finally:
            # Restore the caller's mode (e.g. keep an EMA model in eval mode).
            model.train(was_training)

        x = (x_t.clamp(-1, 1) + 1) / 2  # [-1, 1] -> [0, 1]
        x = (x * 255).type(torch.uint8)  # [0, 1] -> [0, 255]
        return x
    
def _strip_module_prefix(state_dict):
    """Return a copy of ``state_dict`` with a leading 'module.' (DDP wrapper) removed from keys."""
    prefix = "module."
    return OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in state_dict.items()
    )


def _sample_and_save(args, accelerator, diffusion, model, ema_model, optimizer, epoch):
    """Sample one image per class from the live and EMA models, then save
    sample grids and checkpoints (file writes on the main process only)."""
    labels = torch.arange(args.num_classes).long().to(accelerator.device)

    # gather_for_metrics is a collective: every process must call it.
    sampled_images = diffusion.sample(model, n=len(labels), labels=labels)
    sampled_images = accelerator.gather_for_metrics(sampled_images)
    ema_sampled_images = diffusion.sample(ema_model, n=len(labels), labels=labels)
    ema_sampled_images = accelerator.gather_for_metrics(ema_sampled_images)

    if not accelerator.is_main_process:
        return  # avoid N processes racing to write the same files
    save_images(sampled_images, os.path.join("results", args.run_name, f"{epoch}.jpg"))
    save_images(ema_sampled_images, os.path.join("results", args.run_name, f"{epoch}_ema.jpg"))
    torch.save(model.state_dict(), os.path.join("models", args.run_name, f"{epoch}.ckpts"))
    torch.save(ema_model.state_dict(), os.path.join("models", args.run_name, "ema_ckpt.pt"))
    torch.save(optimizer.state_dict(), os.path.join("models", args.run_name, f"optim_{epoch}.pt"))


def train(args):
    """Train a class-conditional UNet noise predictor with CFG label dropout and EMA.

    Reads from ``args``:
    - ``run_name``, ``num_classes``, ``epochs``, ``lr``, ``img_size``.
    - ``lr_drop``: StepLR step size, counted in epochs.
    - ``resumes``: checkpoint path or None.
    - Data settings used by ``get_data``.
    - Optional ``sample_start_epoch`` (default 200) and ``sample_every``
      (default 10): they control when samples and checkpoints are written.

    The compute device comes from the Accelerator; ``args.device`` is not used.
    Sets ``args.accelerator``.
    """
    setup_logging(args.run_name)

    # find_unused_parameters=True: label_emb gets no gradient on the ~10% of
    # steps where labels are dropped for classifier-free guidance.
    kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    accelerator = Accelerator(kwargs_handlers=[kwargs])
    device = accelerator.device

    dataloader = get_data(args)
    model = UNet_conditional(num_classes=args.num_classes, device=device)

    if args.resumes:
        # map_location="cpu": otherwise every rank loads onto the saving GPU.
        ckpts = torch.load(args.resumes, map_location="cpu")
        # Strip 'module.' only where present. Single-process runs save
        # unprefixed keys, which a blind k[7:] would corrupt.
        model.load_state_dict(_strip_module_prefix(ckpts))
        print("resume epoch success!\n")

    optimizer = optim.AdamW(model.parameters(), lr=args.lr)
    mse = nn.MSELoss()
    diffusion = Diffusion(img_size=args.img_size, device=device)

    model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
    args.accelerator = accelerator
    # Length of the prepared (per-process) loader, so global_step is contiguous.
    steps_per_epoch = len(dataloader)
    logger = SummaryWriter(os.path.join("runs", args.run_name)) if accelerator.is_main_process else None

    ema = EMA(beta=0.995)
    ema_model = copy.deepcopy(model).eval().requires_grad_(False)

    # Stepped once per EPOCH below; lr_drop is therefore an epoch count.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_drop, gamma=0.1)

    sample_start_epoch = getattr(args, "sample_start_epoch", 200)
    sample_every = getattr(args, "sample_every", 10)

    for epoch in range(args.epochs):
        logging.info(f"starting epoch {epoch}: ")
        pbar = tqdm(dataloader, disable=not accelerator.is_local_main_process)
        for i, (images, labels) in enumerate(pbar):
            images = images.to(device)
            labels = labels.to(device)

            t = diffusion.sample_timesteps(images.shape[0]).to(device)
            x_t, noise = diffusion.noise_images(images, t)
            if np.random.random() < 0.1:
                labels = None  # CFG: train unconditionally ~10% of the time
            predicted_noise = model(x_t, t, labels)
            loss = mse(noise, predicted_noise)

            optimizer.zero_grad()
            accelerator.backward(loss)
            optimizer.step()
            ema.step_ema(ema_model, model)

            pbar.set_postfix(MSE=loss.item())
            if logger is not None:
                logger.add_scalar("MSE", loss.item(), global_step=epoch * steps_per_epoch + i)

        # Per-epoch LR decay. Stepping per batch collapsed the LR after
        # lr_drop iterations.
        scheduler.step()

        if epoch >= sample_start_epoch and (epoch + 1) % sample_every == 0:
            _sample_and_save(args, accelerator, diffusion, model, ema_model, optimizer, epoch)

    if logger is not None:
        logger.close()
            
    
def launch():
    """Assemble the run configuration and call ``train``.

    Only ``-f/--file`` is parsed from the command line, and its value is
    unused. Presumably it lets the script tolerate the ``-f <connection file>``
    argument Jupyter kernels receive. Every other setting is fixed below;
    edit the values here.
    """
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument("-f", "--file", default="file")
    args = cli.parse_args()

    settings = {
        "resumes": None,  # checkpoint path to resume from, or None
        "run_name": "DDPM_conditional",
        "num_classes": 10,  # number of classes (sub-folders) in the dataset
        "epochs": 500,
        "batch_size": 8,
        "img_size": 64,
        "dataset_path": r"/kaggle/input/ddpm-nwpu-instance/content/train_nwpu",  # dataset root
        "device": "cuda",
        "lr": 2e-4,
        "lr_drop": 450,  # StepLR step_size passed to train()
    }
    for key, value in settings.items():
        setattr(args, key, value)

    train(args)
    
if __name__ == "__main__":
    # Script entry point (e.g. `accelerate launch --multi_gpu ddpm.py`).
    launch()

In [None]:
# import os
# os.environ['TORCH_DISTRIBUTED_DEBUG'] = 'INFO'

In [None]:
!accelerate launch --multi_gpu ddpm.py  