In [1]:
from torch import nn
import torch
import math
import numpy as np
class SinsoidalPositionalEmbeddings(nn.Module):
    """Sinusoidal embedding of (diffusion) timesteps.

    Every timestep value is multiplied by ``dim // 2`` geometrically spaced
    frequencies ``exp(-i * ln(10000) / (dim // 2 - 1))`` for
    ``i = 0 .. dim // 2 - 1``; the sines and cosines of those products are
    concatenated along the last axis.

    Args:
        dim: Requested embedding width. The output carries
            ``2 * (dim // 2)`` features, so an odd ``dim`` produces
            ``dim - 1`` features.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, timestep):
        """Embed ``timestep``.

        Args:
            timestep: Tensor with at least one dimension. A new axis is
                inserted after the first one and broadcast against the
                frequency vector, so ``(B,)`` gives ``(B, F)`` and
                ``(B, 1)`` gives ``(B, 1, F)`` with ``F = 2 * (dim // 2)``.

        Returns:
            Tensor with ``sin`` features in the first half of the last axis
            and ``cos`` features in the second half.
        """
        # Build the frequencies on the same device as the input.
        device = timestep.device

        half_dim = self.dim // 2

        # Log-spacing between consecutive frequencies. The denominator is
        # clamped to 1 because for dim in (2, 3) half_dim - 1 is 0 and the
        # unguarded division raised ZeroDivisionError; with a single
        # frequency the spacing is irrelevant (index 0 always maps to 1.0).
        log_step = math.log(10000) / max(half_dim - 1, 1)

        # frequencies[i] = exp(-i * log_step) = 10000 ** (-i / (half_dim - 1))
        frequencies = torch.exp(torch.arange(half_dim, device=device) * -log_step)

        # Outer product of timesteps and frequencies via broadcasting.
        angles = timestep[:, None] * frequencies[None, :]

        return torch.cat((angles.sin(), angles.cos()), dim=-1)

        

In [2]:
from torch import nn
import torch
import numpy as np
from torchsummary import summary
import math


class Block(nn.Module):
    """Two-convolution U-Net block conditioned on a time embedding.

    Data flow: ``conv1 -> ReLU -> BatchNorm -> (+ projected time embedding)
    -> conv2 -> ReLU -> BatchNorm -> transform``.

    Args:
        in_ch: Channel count of the block's nominal input. When ``up`` is
            true the first convolution expects ``2 * in_ch`` channels,
            because the caller concatenates a skip tensor onto the input.
        out_ch: Channel count produced by the block.
        time_emb_dim: Size of the last axis of the time embedding ``t``.
        up: Selects the "up" variant (doubled input channels, transposed
            convolution in ``transform``) instead of the "down" variant.
        scale_img: Only used when ``up`` is true. If true, ``transform`` is
            just the stride-2 transposed convolution, which doubles the
            spatial size. If false, a 2x2 max-pool follows it, so the
            spatial size is unchanged overall.

    Note:
        The "down" variant's ``transform`` is a 3x3 convolution with
        padding 1, so it keeps the spatial size as well.
    """

    def __init__(self, in_ch, out_ch, time_emb_dim, up = False, scale_img = False):
        super().__init__()
        # Projects the shared time embedding to one value per output channel.
        self.time_mlp = nn.Linear(time_emb_dim, out_ch)

        if up:
            # Input is the concatenation of the running tensor and a skip tensor.
            self.conv1 = nn.Conv2d(2*in_ch, out_ch, kernel_size=3, padding=1)
            if scale_img:
                self.transform = nn.Sequential(
                    nn.ConvTranspose2d(out_ch, out_ch, kernel_size=4, stride=2, padding=1),
                )
            else:
                # Upsample by 2, then pool by 2: net spatial size is preserved.
                self.transform = nn.Sequential(
                    nn.ConvTranspose2d(out_ch, out_ch, kernel_size=4, stride=2, padding=1),
                    nn.MaxPool2d(2, stride=2)
                )
        else:
            self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
            self.transform = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)

        # Layers shared by both variants.
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)
        self.bnorm1 = nn.BatchNorm2d(out_ch)
        self.bnorm2 = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU()

        # Not used by forward(); kept so the attribute remains available.
        self.down_scale = nn.MaxPool2d(2, stride=2)

    def forward(self, x, t):
        """Apply the block.

        Args:
            x: Image-like tensor ``(B, C, H, W)`` where ``C`` matches the
                input channels of ``conv1``.
            t: Time embedding whose last axis has ``time_emb_dim`` entries,
                e.g. ``(B, time_emb_dim)`` or ``(B, 1, time_emb_dim)``.

        Returns:
            Tensor with ``out_ch`` channels after ``transform``.
        """
        # First convolution.
        h = self.bnorm1(self.relu(self.conv1(x)))

        # Project the time embedding to out_ch values per sample.
        time_emb = self.relu(self.time_mlp(t))

        # Shape it as (B, out_ch, 1, 1) so it broadcasts over H and W.
        # reshape() handles both (B, out_ch) and (B, 1, out_ch); the previous
        # fixed 4-axis permute only accepted the 3-D form.
        time_emb = time_emb.reshape(time_emb.shape[0], -1, 1, 1)

        # Inject the time information into every spatial position.
        h = h + time_emb

        # Second convolution, then the variant-specific transform.
        h = self.bnorm2(self.relu(self.conv2(h)))
        return self.transform(h)


In [3]:
## * UNet

class Diffusion_Unet(nn.Module):
    """Small time-conditioned U-Net built from ``Block`` modules.

    The defaults reproduce the original hard-coded configuration
    (single-channel input and output, 32-dimensional time embedding).

    Args:
        image_channels: Channels of the input image.
        down_block_channels: Channel widths along the "down" path; block
            ``i`` maps ``down_block_channels[i]`` to
            ``down_block_channels[i + 1]``.
        up_channels: Channel widths along the "up" path; block ``i`` maps
            ``up_channels[i]`` to ``up_channels[i + 1]``. Because each up
            block consumes its input concatenated with a skip tensor,
            ``up_channels[i]`` must equal ``down_block_channels[-1 - i]``
            for every up block, and there cannot be more up blocks than
            down blocks.
        time_emb_dim: Width of the time embedding shared by all blocks.
        out_dim: Channels of the output tensor.

    Raises:
        ValueError: If ``up_channels`` cannot be paired with the skip
            tensors produced by ``down_block_channels``.
    """

    def __init__(
        self,
        image_channels=1,
        down_block_channels=(32, 64, 128, 256),
        up_channels=(256, 128, 64, 32),
        time_emb_dim=32,
        out_dim=1,
    ):
        super().__init__()

        down_block_channels = tuple(down_block_channels)
        up_channels = tuple(up_channels)
        n_downs = len(down_block_channels) - 1
        n_ups = len(up_channels) - 1

        # Each up block pops one skip tensor and expects 2 * up_channels[i]
        # input channels, so the widths have to mirror the down path.
        if n_ups > n_downs or any(
            up_channels[i] != down_block_channels[-1 - i] for i in range(n_ups)
        ):
            raise ValueError(
                "up_channels must mirror down_block_channels for the skip "
                f"connections; got down={down_block_channels}, up={up_channels}"
            )

        self.time_emb_dim = time_emb_dim

        # Time embedding: sinusoidal features -> linear layer -> ReLU.
        self.time_mlp = nn.Sequential(
            SinsoidalPositionalEmbeddings(self.time_emb_dim),
            nn.Linear(self.time_emb_dim, self.time_emb_dim),
            nn.ReLU(),
        )

        # Initial convolution lifting the image to the first feature width.
        self.conv0 = nn.Conv2d(
            image_channels, down_block_channels[0], kernel_size=3, padding=1
        )

        # "Down" path.
        self.downs = nn.ModuleList(
            [
                Block(
                    in_ch=down_block_channels[i],
                    out_ch=down_block_channels[i + 1],
                    time_emb_dim=self.time_emb_dim,
                )
                for i in range(n_downs)
            ]
        )

        # "Up" path (scale_img=False: each block keeps the spatial size).
        self.ups = nn.ModuleList(
            [
                Block(
                    in_ch=up_channels[i],
                    out_ch=up_channels[i + 1],
                    time_emb_dim=self.time_emb_dim,
                    up=True,
                    scale_img=False,
                )
                for i in range(n_ups)
            ]
        )

        # Output convolution. The attribute name (including its spelling) is
        # kept as-is because it is part of the module's state_dict keys.
        self.outout = nn.Conv2d(up_channels[-1], out_dim, kernel_size=3, padding=1)

    def forward(self, x, t):
        """Run the network on an image batch ``x`` at timesteps ``t``.

        Args:
            x: Image batch with ``image_channels`` channels.
            t: Timesteps, one entry per sample; passed straight to the
                sinusoidal embedding.

        Returns:
            Tensor with ``out_dim`` channels.
        """
        # Embed the timesteps once; every block projects this embedding itself.
        t = self.time_mlp(t)

        x = self.conv0(x)

        # Store the OUTPUT of every down block as a skip tensor.
        skips = []
        for down in self.downs:
            x = down(x, t)
            skips.append(x)

        # Consume the skip tensors in reverse order, concatenating on channels.
        for up in self.ups:
            x = torch.cat((x, skips.pop()), dim=1)
            x = up(x, t)

        return self.outout(x)

In [4]:
# Smoke test: push one random batch through a freshly initialised model and
# display the shape of the result.
image = torch.randn(4, 1, 28, 28)  # 4 random single-channel 28x28 inputs
time_stamp = torch.randn(4, 1)  # one random (float) timestep value per sample
model = Diffusion_Unet()
model(image, time_stamp).shape  # last expression -> shown as the cell output

torch.Size([4, 1, 28, 28])

In [5]:
import torch
from torchsummary import summary

# Use the GPU when one is available; otherwise stay on the CPU so the cell
# also runs on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = Diffusion_Unet()
model = model.to(device)
print("Num params: ", sum(p.numel() for p in model.parameters()))

# The layer summary is a best-effort diagnostic: keep the notebook running if
# torchsummary fails, but report what went wrong instead of hiding it.
try:
    summary(
        model=model,
        input_size=[(1, 28, 28), (1,)],  # (image, timestep) shapes without batch axis
        device=device,
        batch_size=1,
    )
except Exception as e:
    print(f"torchsummary did not complete ({type(e).__name__}: {e})")
    

Num params:  3275649
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
SinsoidalPositionalEmbeddings-1                 [1, 1, 32]               0
            Linear-2                 [1, 1, 32]           1,056
              ReLU-3                 [1, 1, 32]               0
            Conv2d-4            [1, 32, 28, 28]             320
            Conv2d-5            [1, 64, 28, 28]          18,496
              ReLU-6            [1, 64, 28, 28]               0
       BatchNorm2d-7            [1, 64, 28, 28]             128
            Linear-8                 [1, 1, 64]           2,112
              ReLU-9                 [1, 1, 64]               0
           Conv2d-10            [1, 64, 28, 28]          36,928
             ReLU-11            [1, 64, 28, 28]               0
      BatchNorm2d-12            [1, 64, 28, 28]             128
           Conv2d-13            [1, 64, 28, 28]          36,928
       