In [1]:
import torch
import torch.nn as nn



In [2]:
import math
class Swish(nn.Module):
    def __init__(self):
        super().__init__()
    def forward(self,x):
        return x * torch.sigmoid(x)
class TimestampEmbedding(nn.Module):
    def __init__(self,T,hid_dim,out_dim):
        '''hid_dim must be even'''
        assert hid_dim % 2 == 0
        super().__init__()
        self.T = T
        self.hid_dim = hid_dim
        self.out_dim = out_dim
        emb = torch.zeros((T,hid_dim))
        t = torch.arange(0,T).float()
        pos = torch.exp(- torch.arange(0,hid_dim,2).float() * math.log(10000) / hid_dim)
        pos = pos[None,:] * t[:,None]
        emb[:,0::2] = torch.sin(pos)
        emb[:,1::2] = torch.cos(pos)
        assert list(emb.shape) == [T,hid_dim], emb.shape
        self.emb_layer = nn.Embedding.from_pretrained(emb)
        self.linear_projection = nn.Sequential(
            nn.Linear(hid_dim,hid_dim),
            Swish(),
            nn.Linear(hid_dim,out_dim),
        )
    def forward(self,t):
        ts = self.emb_layer(t)
        return self.linear_projection(ts)


In [3]:

# test timeembedding
T = 1000
t_hid = 256
t_out = 64
temb = TimestampEmbedding(T,t_hid,t_out)
temb(torch.tensor([100]))

tensor([[-1.0124e-01, -1.8524e-01, -7.5046e-02,  1.2864e-01, -1.2727e-01,
          1.2264e-01,  9.7239e-02, -5.3639e-02,  3.0916e-01,  3.1185e-02,
          5.8693e-02, -9.3891e-02,  9.7069e-02, -2.3483e-02, -4.5807e-02,
          1.5918e-04, -8.8223e-02,  2.3478e-01,  4.4465e-02, -4.0098e-02,
         -2.4875e-02,  2.7891e-02,  1.0597e-01, -1.6713e-01,  9.6536e-02,
          1.0362e-02,  4.8702e-02,  2.2652e-02, -1.5555e-02,  4.8726e-02,
         -7.6452e-02, -8.5779e-02, -1.3418e-01,  1.0486e-01, -2.1896e-02,
         -1.8625e-01,  1.1470e-01, -1.4839e-01, -1.7249e-01,  1.1505e-01,
         -2.0230e-02, -1.9516e-01, -2.3811e-02, -2.8151e-01,  6.4190e-02,
          1.0873e-01,  1.7719e-01, -8.4682e-02, -7.2578e-02, -1.8599e-01,
         -1.5161e-01, -2.7676e-01,  3.4338e-01,  2.8707e-01,  1.5990e-01,
         -9.9030e-02, -6.5463e-02, -1.1371e-01, -3.7894e-02,  3.2099e-01,
          1.4604e-01, -3.5411e-03,  1.0791e-01,  6.8421e-02]],
       grad_fn=<AddmmBackward0>)

In [4]:
class AttentionBlock(nn.Module):
    '''
    Self Attention block
    x.shape = [B,C,H,W] -> group_norm. q,k,v = linear projection of x,
    out = linear projection of softmax(q * k^T / d^0.05) * v
    '''
    def __init__(self,in_channel,num_heads,num_groups=4,hid_dim=None):
        super().__init__()
        self.groupnorm = nn.GroupNorm(num_groups,in_channel)
        if hid_dim is None:
            self.hid_dim = in_channel
        else :
            self.hid_dim = hid_dim
        self.qkv_projection = nn.Conv2d(in_channel,3*num_heads*self.hid_dim,1,1,0)
        self.out_projection = nn.Conv2d(num_heads*self.hid_dim,in_channel,1,1,0)
        self.num_heads = num_heads
    def forward(self,x):
        B,_,H,W = x.shape
        qkv = self.qkv_projection(x)
        q,k,v = torch.chunk(qkv,3,dim=1)
        q = q.reshape((B*self.num_heads,self.hid_dim,H*W)).permute(0,2,1)
        k_t = k.reshape((B*self.num_heads,self.hid_dim,H*W))
        v = v.reshape((B*self.num_heads,self.hid_dim,H*W)).permute(0,2,1)
        attention_weight = nn.functional.softmax(torch.bmm(q,k_t) * self.hid_dim ** -0.5 ,dim=-1)
        out = torch.bmm(attention_weight,v)
        assert list(out.shape) == [B*self.num_heads,H*W,self.hid_dim]
        out = out.reshape((B,self.num_heads,H*W,self.hid_dim)).permute(0,1,3,2)
        out = out.reshape((B,self.num_heads*self.hid_dim,H,W))
        return self.out_projection(out) + x

In [5]:
class ResBlock(nn.Module):
    def __init__(self,in_channel,out_channel,t_out,dropout=0.5,num_group=4):
        super().__init__()
        self.block1 = nn.Sequential(
            nn.GroupNorm(num_group,in_channel),
            Swish(),
            nn.Conv2d(in_channel,out_channel,3,1,1),
        )
        self.tsEmb_projection = nn.Sequential(
            Swish(),
            nn.Linear(t_out,out_channel),
        )
        self.block2 = nn.Sequential(
            nn.GroupNorm(num_group,out_channel),
            Swish(),
            nn.Dropout(dropout),
            nn.Conv2d(out_channel,out_channel,3,1,1),
        )
        if in_channel != out_channel :
            self.in2out = nn.Conv2d(in_channel,out_channel,3,1,1)
        else :
            self.in2out = nn.Identity()
    def forward(self,x,tsEmb):
        h = self.block1(x)
        tsEmb = self.tsEmb_projection(tsEmb)[:,:,None,None]
        h += tsEmb
        h = self.block2(h)
        h += self.in2out(x)
        return h


In [6]:
class DownBlock(nn.Module):
    def __init__(self,in_channel,out_channel,t_out,num_heads=4,dropout=0.5,num_group=4,has_attention=True):
        super().__init__()
        self.resblock = ResBlock(in_channel,out_channel,t_out,dropout,num_group)
        if has_attention:
            self.attention = AttentionBlock(out_channel,num_heads,num_group)
        else:
            self.attention = nn.Identity()
    def forward(self,x,tEmb):
        x = self.resblock(x,tEmb)
        x = self.attention(x)
        return x
class UpBlock(nn.Module):
    def __init__(self,in_channel,out_channel,t_out,num_heads=4,dropout=0.5,num_group=4,has_attention=True):
        super().__init__()
        self.resblock = ResBlock(in_channel,out_channel,t_out,dropout,num_group)
        if has_attention:
            self.attention = AttentionBlock(out_channel,num_heads,num_group)
        else:
            self.attention = nn.Identity()
    def forward(self,x,tEmb):
        x = self.resblock(x,tEmb)
        x = self.attention(x)
        return x
class UpSample(nn.Module):
    def __init__(self,in_channel):
        super().__init__()
        self.upsample = nn.Upsample(scale_factor=2,mode='bilinear',align_corners=False)
        self.conv = nn.Conv2d(in_channel,in_channel//2,3,1,1)
    def forward(self,x,tEmb):
        _ = tEmb
        x = self.conv(self.upsample(x))
        return x
class DownSample(nn.Module):
    def __init__(self,in_channel):
        super().__init__()
        self.conv = nn.Conv2d(in_channel,in_channel,3,2,1)
    def forward(self,x,tEmb):
        _ = tEmb
        return self.conv(x)
class MiddleBlock(nn.Module):
    def __init__(self,in_channel,out_channel,t_out,num_heads=4,drop_out=0.5,num_group=4):
        super().__init__()
        self.resblock1 = ResBlock(in_channel,in_channel,t_out,drop_out,num_group)
        self.resblock2 = ResBlock(in_channel,out_channel,t_out,drop_out,num_group)
        self.attention = AttentionBlock(in_channel,num_heads,num_group)
    def forward(self,x,tEmb):
        x = self.resblock1(x,tEmb)
        x = self.attention(x)
        x = self.resblock2(x,tEmb)
        return x

In [7]:
#write a Unet using the above blocks

class Conv(nn.Module):
    def __init__(self,in_channel,out_channel,kernel,stride,padding) -> None:
        super().__init__()
        self.conv = nn.Conv2d(in_channel,out_channel,kernel,stride,padding)
    def forward(self,x,tEmb):
        _ = tEmb
        return self.conv(x)
class Unet(nn.Module):
    def __init__(self,config:dict
                 ):
        super().__init__()
        steps = config['steps']
        t_hid = config['t_hid']
        t_out = config['t_out']
        in_channel = config['in_channel']
        self.timestamp_embedding = TimestampEmbedding(steps,t_hid,t_out)
        input_process = []
        for ch in config['input_process']:
            if in_channel % config['num_group'] != 0:
                input_process.append(Conv(in_channel,ch,3,1,1))
            else :
                input_process.append(DownBlock(in_channel,ch,t_out,has_attention=False))
            in_channel = ch
        self.input_process = nn.ModuleList(input_process)
        downs = []
        for down_ch in config['down_blocks']:
            downs.append(DownSample(in_channel))
            for ch in down_ch:
                downs.append(DownBlock(in_channel,ch,t_out,num_heads=config["num_heads"]))
                in_channel = ch
        self.downs = nn.ModuleList(downs)
        middle_blocks = [DownSample(in_channel)]
        for mid_ch in config['middle_blocks']:
            middle_blocks.append(MiddleBlock(in_channel,mid_ch,t_out))
            in_channel = mid_ch
        middle_blocks.append(UpSample(in_channel))
        self.middle_blocks = nn.ModuleList(middle_blocks)
        up_blocks = []
        for up_ch in config['up_blocks']:
            for ch in up_ch:
                up_blocks.append(UpBlock(in_channel,ch,t_out,num_heads=config["num_heads"]))
                in_channel = ch
            up_blocks.append(UpSample(in_channel))
        self.up_blocks = nn.ModuleList(up_blocks)
        output_process = []
        for ch in config['output_process']:
            output_process.append(DownBlock(in_channel,ch,t_out))
            in_channel = ch
        output_process.append(Conv(in_channel,config['out_channel'],3,1,1))
        self.output_process = nn.ModuleList(output_process)
    def forward(self,x,t):
        tsEmb = self.timestamp_embedding(t)
        for input_process in self.input_process:
            x = input_process(x,tsEmb)
        downs = []
        for down in self.downs:
            if isinstance(down,DownSample):
                downs.append(x)
            x = down(x,tsEmb)
        h = x
        for middle in self.middle_blocks:
            h = middle(h,tsEmb)
        x = torch.cat([x,h],dim=1)
        for up in self.up_blocks:
            x = up(x,tsEmb)
            if isinstance(up,UpSample):
                down = downs.pop()
                x = torch.cat([x,down],dim=1)
        for output in self.output_process:
            x = output(x,tsEmb)
        return x

In [8]:
def extract(origin,index,length):
    res = torch.gather(origin,index=index,dim=0)
    B = index.shape[0]
    return res.reshape([B] + [1] * (length-1))
class DDPM(nn.Module):
    def __init__(self,Config):
        super().__init__()
        self.T = Config['steps']
        beta_1 = Config['beta_1']
        beta_T = Config['beta_t']
        self.unet = Unet(Config)
        betas = torch.linspace(beta_1,beta_T,self.T)
        alphas = 1 - betas
        alphas_bar = torch.cumprod(alphas,dim=0)
        self.register_buffer("betas",betas)
        self.register_buffer("alphas",alphas)
        self.register_buffer("alphas_bar",alphas_bar)
        self.register_buffer("sqrt_alphas_bar",torch.sqrt(alphas_bar))
        self.register_buffer("sqrt_oneminus_alphas_bar",torch.sqrt(1-alphas_bar))
        self.register_buffer("sqrt_recip_alphas", torch.sqrt(1/alphas))
        self.register_buffer("noise_coeff",self.betas/self.sqrt_oneminus_alphas_bar)
    def forward(self,x):
        device = next(self.parameters()).device
        length = len(x.shape)
        noise = torch.randn_like(x).to(device)
        t = torch.randint(self.T,(x.shape[0],)).to(device)
        x_t =  extract(self.sqrt_alphas_bar,t,length) * x + noise * extract(self.sqrt_oneminus_alphas_bar,t,length)
        predicted_noise = self.unet(x_t,t)
        loss = nn.functional.mse_loss(predicted_noise,noise)
        return loss
    def gen_noise_img(self,x_0,t):
        assert t < self.T
        device = next(self.parameters()).device
        length = len(x_0.shape)
        noise = torch.randn_like(x_0).to(device)
        t = torch.full((x_0.shape[0],),t).to(device)
        return extract(self.sqrt_alphas_bar,t,length) * x_0 + noise * extract(self.sqrt_oneminus_alphas_bar,t,length),noise
    @torch.no_grad()
    def sample(self,x,return_progress=False,step=None):
        assert not return_progress or step != None,"step must be given when return_progress is True"
        B = x.shape[0]
        device = next(self.parameters()).device
        length = len(x.shape)
        x_t = x
        res = x
        for t in reversed(range(self.T)):
            timesteps = torch.full((B,),t)
            if t > 0 :
                noise = torch.randn_like(x).to(device)
            else :
                noise = 0
            timesteps = timesteps.to(device)
            predicted_noise = self.unet(x_t,timesteps)
            mean =  extract(self.sqrt_recip_alphas,timesteps,length) * (x_t - extract(self.noise_coeff,timesteps,length) * predicted_noise)
            var = noise * torch.sqrt(extract(self.betas,timesteps,length))
            x_t = mean + var
            if return_progress and t % step == 0:
                res = torch.cat([res,x_t],dim=0)
        if return_progress:
            return res.clamp(-1,1)
        return  x_t.clamp(-1,1) 
# test DDPM

In [9]:
Config = {
    "steps" : 1000,
    "beta_1" : 1e-4,
    "beta_t" : 0.02,
    "t_hid" : 256,
    "t_out" : 64,
    "in_channel" : 3,
    'out_channel' : 3,
    "num_heads" : 2,
    "dropout" : 0.5,
    "num_group" : 32,
    "input_process" : [
        64,64
    ],
    "output_process" : [
        64,64
    ],
    "down_blocks" : [
        [128,128],
        [256,256],
        [512,512],
    ] ,
    "middle_blocks" :[
        1024,1024
    ],
    "up_blocks" : [
        [512,512],
        [256,256],
        [128,128],
    ]
}
import torchvision
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from torch.utils.data import dataloader
from torchvision.utils import make_grid
# Define the transformation to apply to the dataset
W,H = 32,32
transform = transforms.Compose([
    transforms.Resize((W,H)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5,0.5,0.5))
])

# Download the CIFAR10 dataset
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
# trainset = torchvision.datasets.CelebA(root='./data', split='train', download=True, transform=transform)
# trainset = torchvision.datasets.ImageFolder("dataset",transform=transform)
batch_size = 10 
trainloader = dataloader.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
device = torch.device('cuda')


Files already downloaded and verified


In [10]:

# model = DDPM(Config).to(device)
# # print(f"running on {device}")
# x,_ = next(iter(trainloader))
# img = make_grid(x)
# plt.imshow(img.permute(1,2,0).detach().to("cpu").numpy() / 2 + 0.5)
# x = x.to(device)
# nosie = model.gen_noise_img(x,500)
# noise2 = model.gen_noise_img(x,999)
# nosie_gen = torch.randn_like(x)
# print(nosie.mean())
# print(nosie.var())
# print(noise2.mean())
# print(noise2.var())
# print(nosie_gen.mean())
# print(nosie_gen.var())


In [11]:
import lightning as L
class DDPM_lightning(L.LightningModule):
    def __init__(self,Config):
        super().__init__()
        self.model = DDPM(Config)
    def forward(self,x):
        return self.model(x)
    def training_step(self,batch,batch_idx):
        x,_ = batch
        loss = self.model(x)
        self.log('train_loss',loss,prog_bar=True)
        return loss
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(),lr=2e-5)
    def train_dataloader(self):
        return trainloader
    def sample(self,return_progress=False,step=None):
        device =  torch.device("cuda")
        x = torch.randn((4,3,32,32)).to(device)
        return self.model.sample(x,return_progress,step)
# CKPT_PATH = r"lightning_logs\version_2\checkpoints\epoch=3-step=20000.ckpt"
# model = DDPM_lightning.load_from_checkpoint(CKPT_PATH,Config=Config)
# trainer = L.Trainer(max_epochs=4)
# model = DDPM_lightning(Config)
# trainer.fit(model,trainloader)
# trainer.predict(model)

In [34]:
from torchvision.utils import make_grid
CKPT_PATH = r"lightning_logs\colab\v2epoch=7-step=4000.ckpt"
device = torch.device("cuda")
model = DDPM_lightning.load_from_checkpoint(CKPT_PATH,Config=Config)
model = model.to(device)


In [41]:
x_0,_= next(iter(trainloader))
x_0 = x_0.to(device)
t = 1
noisy_img,noise = model.model.gen_noise_img(x_0,t)

noise2 = torch.randn_like(x_0)
predicted_noise = model.model.unet(noisy_img,torch.full((10,),t).to(device))

loss = nn.functional.mse_loss(predicted_noise,noise)
loss2 = nn.functional.mse_loss(noise2,noise)
print(loss)
# img = make_grid(model.sample(return_progress=True,step=100)).detach().to("cpu").permute((1,2,0)).numpy() / 2 + 0.5
# plt.imshow(img)


tensor(0.6017, device='cuda:0', grad_fn=<MseLossBackward0>)


In [13]:
print(loss2)

tensor(1.9908, device='cuda:0')


In [14]:
print(model.model.noise_coeff)

tensor([0.0100, 0.0081, 0.0074, 0.0070, 0.0068, 0.0067, 0.0066, 0.0065, 0.0065,
        0.0064, 0.0064, 0.0064, 0.0064, 0.0063, 0.0063, 0.0063, 0.0063, 0.0063,
        0.0063, 0.0063, 0.0063, 0.0063, 0.0063, 0.0063, 0.0063, 0.0063, 0.0063,
        0.0063, 0.0063, 0.0063, 0.0063, 0.0063, 0.0063, 0.0063, 0.0063, 0.0063,
        0.0063, 0.0063, 0.0063, 0.0063, 0.0063, 0.0063, 0.0063, 0.0063, 0.0063,
        0.0063, 0.0063, 0.0063, 0.0063, 0.0063, 0.0063, 0.0063, 0.0063, 0.0063,
        0.0063, 0.0063, 0.0063, 0.0063, 0.0063, 0.0063, 0.0063, 0.0063, 0.0063,
        0.0064, 0.0064, 0.0064, 0.0064, 0.0064, 0.0064, 0.0064, 0.0064, 0.0064,
        0.0064, 0.0064, 0.0064, 0.0064, 0.0064, 0.0064, 0.0064, 0.0064, 0.0064,
        0.0064, 0.0064, 0.0064, 0.0064, 0.0064, 0.0064, 0.0064, 0.0064, 0.0064,
        0.0064, 0.0064, 0.0064, 0.0064, 0.0064, 0.0064, 0.0064, 0.0065, 0.0065,
        0.0065, 0.0065, 0.0065, 0.0065, 0.0065, 0.0065, 0.0065, 0.0065, 0.0065,
        0.0065, 0.0065, 0.0065, 0.0065, 

In [15]:
%reload_ext tensorboard
%tensorboard --logdir=lightning_logs/

Launching TensorBoard...