In [1]:
import torch 
import torch.nn as nn

In [2]:

import torch 
import torch.nn as nn
class LinearScehduler:
    """Linear beta noise schedule for a DDPM-style diffusion process.

    Precomputes, for every timestep index in ``[0, num_timestamps)``:

    * ``betas``                         -- linearly spaced from ``beta_start`` to ``beta_end``
    * ``alphas``                        -- ``1 - beta_t``
    * ``alpha_cumilative_product``      -- cumulative product of ``alphas`` (alpha-bar)
    * ``alpha_sqroot_cumilative_prod``  -- ``sqrt(alpha_bar_t)``
    * ``one_minus_alpha_squareroot``    -- ``sqrt(1 - alpha_bar_t)``
    """

    def __init__(self, num_timestamps, beta_start, beta_end):
        self.num_timestamps = num_timestamps
        self.beta_start = beta_start
        self.beta_end = beta_end

        # One beta per timestep, increasing linearly from beta_start to beta_end.
        self.betas = torch.linspace(beta_start, beta_end, num_timestamps)
        self.alphas = 1. - self.betas
        self.alpha_cumilative_product = torch.cumprod(self.alphas, dim=0)
        self.alpha_sqroot_cumilative_prod = torch.sqrt(self.alpha_cumilative_product)
        self.one_minus_alpha_squareroot = torch.sqrt(1. - self.alpha_cumilative_product)

    def add_noise(self, original_image, noise, t):
        """Forward diffusion: noise a batch of images to timestep ``t``.

        Computes ``sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise``.

        :param original_image: clean images; the first dimension is the batch
        :param noise: noise tensor with the same shape as ``original_image``
        :param t: 1D tensor of integer timestep indices, one per batch element
        :return: noised images, same shape as ``original_image``
        """
        device = original_image.device
        batch_size = original_image.shape[0]
        # (B, 1, 1, ...) so the per-sample coefficients broadcast over all
        # non-batch dimensions of the image.
        coeff_shape = (batch_size,) + (1,) * (original_image.dim() - 1)

        # Move the schedule tables to the input's device before indexing so
        # the scheduler also works when the images live on an accelerator.
        alpha_sqrt_cum_prod = self.alpha_sqroot_cumilative_prod.to(device)[t].reshape(coeff_shape)
        one_minus_alphs_sqrt = self.one_minus_alpha_squareroot.to(device)[t].reshape(coeff_shape)

        return alpha_sqrt_cum_prod * original_image + one_minus_alphs_sqrt * noise

    def reverse_process(self, xt, noise_predicted, t):
        """Reverse diffusion: take one denoising step from ``x_t`` to ``x_{t-1}``.

        :param xt: noisy sample at timestep ``t``
        :param noise_predicted: noise predicted for ``xt``
        :param t: a single timestep index (int or 0-dim tensor); it is used in
            a plain ``if t == 0`` test, so it must not be a batch of timesteps
        :return: tuple ``(sample, x0)`` where ``sample`` is the posterior mean
            at ``t == 0`` and ``mean + sigma * z`` otherwise, and ``x0`` is the
            predicted clean image clamped to ``[-1, 1]``
        """
        device = xt.device
        beta_t = self.betas.to(device)[t]
        alpha_t = self.alphas.to(device)[t]
        alpha_bar = self.alpha_cumilative_product.to(device)
        sqrt_alpha_bar_t = self.alpha_sqroot_cumilative_prod.to(device)[t]
        sqrt_one_minus_alpha_bar_t = self.one_minus_alpha_squareroot.to(device)[t]

        # Invert the forward process to estimate the clean image.
        x0 = (xt - sqrt_one_minus_alpha_bar_t * noise_predicted) / sqrt_alpha_bar_t
        x0 = torch.clamp(x0, -1., 1.)

        # Posterior mean: (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_t).
        # The noise term is scaled by sqrt(1 - alpha_bar_t), not sqrt(alpha_bar_t).
        mean = xt - (beta_t * noise_predicted) / sqrt_one_minus_alpha_bar_t
        mean = mean / torch.sqrt(alpha_t)

        if t == 0:
            # Final step: no noise is added.
            return mean, x0

        # Posterior variance: beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t).
        variance = beta_t * (1. - alpha_bar[t - 1]) / (1. - alpha_bar[t])
        sigma = variance ** 0.5
        z = torch.randn(xt.shape).to(xt.device)
        # Sample from N(mean, sigma^2) via the reparameterization trick.
        return mean + sigma * z, x0

    
        
        

In [3]:

import torch
import torch.nn as nn

class Swish(nn.Module):
    """Swish activation: each element is scaled by its own sigmoid."""

    def forward(self, x):
        # The sigmoid acts as a smooth gate in (0, 1) applied to the input.
        gate = torch.sigmoid(x)
        return gate * x

class TimeEmbed(nn.Module):
    """Sinusoidal timestep embedding followed by a Linear layer and Swish.

    The timestep ``t`` is expanded into ``t_embed_dim // 2`` sine and
    ``t_embed_dim // 2`` cosine features, concatenated, and then passed
    through one ``nn.Linear(t_embed_dim, t_embed_dim)`` and a Swish
    activation.

    :param t_embed_dim: size of the embedding; must be even so the sine and
        cosine halves add up to exactly ``t_embed_dim`` features
    :raises ValueError: if ``t_embed_dim`` is odd
    """

    def __init__(self, t_embed_dim):
        super().__init__()
        if t_embed_dim % 2 != 0:
            # With an odd size the sin/cos concat has t_embed_dim - 1 features
            # and the Linear layer below would fail on every forward call.
            raise ValueError(
                f"t_embed_dim must be even, got {t_embed_dim}"
            )
        self.t_embed_dim = t_embed_dim
        self.fc = nn.Linear(t_embed_dim, t_embed_dim)
        self.swish = Swish()

    def forward(self, t):
        """Embed a batch of timesteps.

        :param t: 1D tensor of timesteps, shape ``(B,)``
        :return: tensor of shape ``(B, t_embed_dim)``
        """
        half_dim = self.t_embed_dim // 2

        # Frequencies 10000^(i / half_dim) for i in [0, half_dim),
        # i.e. 10000^(2i / t_embed_dim).
        factor = 10000 ** (torch.arange(
            start=0, end=half_dim, dtype=torch.float32, device=t.device
        ) / half_dim)

        # (B, 1) / (half_dim,) -> (B, half_dim)
        t_emb = t[:, None] / factor
        # Sine and cosine halves side by side -> (B, t_embed_dim)
        t_emb = torch.cat([torch.sin(t_emb), torch.cos(t_emb)], dim=-1)

        # Learned projection with a Swish non-linearity.
        t_emb = self.swish(self.fc(t_emb))
        return t_emb


In [4]:

import torch
import torch.nn as nn

class DownBlock(nn.Module):
    """
    Encoder stage: time-conditioned residual block, self-attention over the
    spatial positions, then a strided convolution that downsamples.
    """
    def __init__(self, n_groups, in_channels, out_channels, num_heads):
        super().__init__()

        def norm_act_conv(channels_in):
            # GroupNorm -> Swish -> 3x3 conv producing `out_channels` maps.
            return nn.Sequential(
                nn.GroupNorm(n_groups, channels_in),
                Swish(),
                nn.Conv2d(channels_in, out_channels, kernel_size=3, padding=1),
            )

        self.Block1 = norm_act_conv(in_channels)
        self.Block2 = norm_act_conv(out_channels)

        # Normalisation and self-attention applied on the flattened feature map.
        self.attention_norm = nn.GroupNorm(n_groups, out_channels)
        self.attention = nn.MultiheadAttention(out_channels, num_heads, batch_first=True)

        # 1x1 conv so the block input can be added to the `out_channels` features.
        self.linear_layer_input = nn.Conv2d(in_channels, out_channels, kernel_size=1)

        # The time embedding is expected to already have `out_channels` features.
        self.time_proj = nn.Linear(out_channels, out_channels)

        # Kernel 4 / stride 2 / padding 1: halves even spatial sizes.
        self.down_sampling = nn.Conv2d(out_channels, out_channels, kernel_size=4, padding=1, stride=2)

    def forward(self, x, t_emb):
        # --- Residual part, conditioned on the time embedding ---
        hidden = self.Block1(x)
        # (B, C) -> (B, C, 1, 1) so it broadcasts over height and width.
        hidden = hidden + self.time_proj(t_emb)[:, :, None, None]
        residual_out = self.Block2(hidden) + self.linear_layer_input(x)

        # --- Self-attention over the H*W positions ---
        b, c, height, width = residual_out.shape
        tokens = residual_out.reshape(b, c, height * width)
        tokens = self.attention_norm(tokens).transpose(1, 2)  # (B, H*W, C)
        attended, _ = self.attention(tokens, tokens, tokens)
        attended = attended.transpose(1, 2).reshape(b, c, height, width)

        # --- Residual add around attention, then downsample ---
        return self.down_sampling(attended + residual_out)


In [5]:
# Example usage: push a random batch through TimeEmbed and one DownBlock.
batch_size = 8
height, width = 64, 64
in_channels = 64
out_channels = 256
time_embed_dim = 256  # DownBlock.time_proj is Linear(out_channels, out_channels), so this must equal `out_channels`
n_groups = 32
num_heads = 8
num_timestamps = 1000  # illustrative number of diffusion steps for this demo

# Inputs
x = torch.randn(batch_size, in_channels, height, width)
# Integer timestep indices in [0, num_timestamps); torch.randn would produce
# real-valued (and negative) values that are not valid step indices.
t = torch.randint(0, num_timestamps, (batch_size,))

# Time embedding generation
time_embedder = TimeEmbed(time_embed_dim)
t_emb = time_embedder(t)  # Shape: (B, time_embed_dim)

# DownBlock
down = DownBlock(n_groups, in_channels, out_channels, num_heads)
output = down(x, t_emb)
print(output.shape)  # Should print: (8, 256, 32, 32)


torch.Size([8, 256, 32, 32])


In [6]:
from torchinfo import summary

In [24]:
# Build the bottleneck block; input and output channel counts are both `out_channels`.
middle = MiddleBlock(out_channels, out_channels, n_groups, num_heads)


In [30]:
# NOTE: this rebinds `middle` from the MiddleBlock module to its output tensor,
# so the cell cannot be re-run without re-creating the module first.
middle = middle(output, t_emb)

In [7]:
# Per-layer summary of `down` (output shapes and parameter counts) for the example inputs.
summary(down, input_data=(x, t_emb))

Layer (type:depth-idx)                   Output Shape              Param #
DownBlock                                [8, 256, 32, 32]          --
├─Sequential: 1-1                        [8, 256, 64, 64]          --
│    └─GroupNorm: 2-1                    [8, 64, 64, 64]           128
│    └─Swish: 2-2                        [8, 64, 64, 64]           --
│    └─Conv2d: 2-3                       [8, 256, 64, 64]          147,712
├─Linear: 1-2                            [8, 256]                  65,792
├─Sequential: 1-3                        [8, 256, 64, 64]          --
│    └─GroupNorm: 2-4                    [8, 256, 64, 64]          512
│    └─Swish: 2-5                        [8, 256, 64, 64]          --
│    └─Conv2d: 2-6                       [8, 256, 64, 64]          590,080
├─Conv2d: 1-4                            [8, 256, 64, 64]          16,640
├─GroupNorm: 1-5                         [8, 256, 4096]            512
├─MultiheadAttention: 1-6                [8, 4096, 256]         

In [27]:

import torch 
import torch.nn as nn
class MiddleBlock(nn.Module):
    """Bottleneck stage: residual block -> self-attention -> residual block.

    Both residual blocks are conditioned on the time embedding. The first maps
    ``in_channels`` to ``out_channels``; the second maps ``out_channels`` to
    ``out_channels`` and has its own weights, so the block also works when
    ``in_channels != out_channels``.

    :param in_channels: channels of the input feature map
    :param out_channels: channels of the output feature map; the time
        embedding passed to ``forward`` must have this many features
    :param n_groups: number of groups for every GroupNorm
    :param num_heads: number of attention heads
    """

    def __init__(self, in_channels, out_channels, n_groups, num_heads):
        super().__init__()
        # First residual block (in_channels -> out_channels)
        self.Block1 = nn.Sequential(
            nn.GroupNorm(n_groups, in_channels),
            Swish(),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        )
        self.Block2 = nn.Sequential(
            nn.GroupNorm(n_groups, out_channels),
            Swish(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
        )

        # Attention block
        self.attention_norm = nn.GroupNorm(n_groups, out_channels)
        self.attention = nn.MultiheadAttention(out_channels, num_heads, batch_first=True)

        # Skip and time-embedding projections for the first residual block
        self.input_projection = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.time_projection = nn.Linear(out_channels, out_channels)

        # Second residual block (out_channels -> out_channels). Its input is
        # the attention output, which has out_channels channels, so it cannot
        # reuse the first block's layers that were built for in_channels.
        self.Block3 = nn.Sequential(
            nn.GroupNorm(n_groups, out_channels),
            Swish(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
        )
        self.Block4 = nn.Sequential(
            nn.GroupNorm(n_groups, out_channels),
            Swish(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
        )
        self.input_projection2 = nn.Conv2d(out_channels, out_channels, kernel_size=1)
        self.time_projection2 = nn.Linear(out_channels, out_channels)

    def forward(self, x, t_emb):
        """Run the bottleneck.

        :param x: feature map of shape ``(B, in_channels, H, W)``
        :param t_emb: time embedding of shape ``(B, out_channels)``
        :return: feature map of shape ``(B, out_channels, H, W)``
        """
        batch_size = t_emb.size(0)

        # First residual block
        hidden = self.Block1(x)
        # (B, C) -> (B, C, 1, 1) so the embedding broadcasts over H and W.
        hidden = hidden + self.time_projection(t_emb).view(batch_size, -1, 1, 1)
        hidden = self.Block2(hidden)
        out = hidden + self.input_projection(x)

        # Attention block: normalise, flatten to (B, H*W, C), attend, restore.
        batch, channel, h, w = out.shape
        tokens = self.attention_norm(out).reshape(batch, channel, h * w).transpose(1, 2)
        out_attn, _ = self.attention(tokens, tokens, tokens)
        out_attn = out_attn.transpose(1, 2).reshape(batch, channel, h, w)
        # Residual connection around attention, as in DownBlock.
        out = out + out_attn

        # Second residual block; its skip connection starts from this block's
        # own input rather than from the original `x`.
        hidden = self.Block3(out)
        hidden = hidden + self.time_projection2(t_emb).view(batch_size, -1, 1, 1)
        hidden = self.Block4(hidden)
        return hidden + self.input_projection2(out)



In [106]:
%%writefile 'Up_Sample.py'
import torch 
import torch.nn as nn
from Linear_scheduler import LinearScehduler
from Time_Embed import TimeEmbed, Swish

class UpBlock(nn.Module):
    """Decoder stage: upsample, concatenate the encoder skip, residual block, attention.

    Channel contract implied by the layers below:

    * ``x`` passed to ``forward`` must have ``in_channels // 2`` channels;
      the transposed convolution turns it into ``out_channels // 2`` channels
      at twice the spatial size.
    * after concatenating ``down_out`` the tensor must have ``in_channels``
      channels, i.e. ``down_out`` must have
      ``in_channels - out_channels // 2`` channels.

    NOTE(review): confirm this contract against the encoder outputs that are
    meant to be fed in as ``down_out``.

    :param n_groups: number of groups for every GroupNorm
    :param in_channels: channels of the concatenated tensor entering the residual block
    :param out_channels: channels of the output; the time embedding must have this many features
    :param num_heads: number of attention heads
    """

    def __init__(self, n_groups, in_channels, out_channels, num_heads):
        super().__init__()
        self.Block1 = nn.Sequential(
            nn.GroupNorm(n_groups, in_channels),
            Swish(),
            nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), padding=(1, 1)),
        )
        self.Block2 = nn.Sequential(
            nn.GroupNorm(n_groups, out_channels),
            Swish(),
            nn.Conv2d(out_channels, out_channels, kernel_size=(3, 3), padding=(1, 1)),
        )
        self.attention_norm = nn.LayerNorm(out_channels)
        self.attention = nn.MultiheadAttention(out_channels, num_heads, batch_first=True)

        # 1x1 conv for the skip connection. It consumes the same concatenated
        # tensor as Block1, so it must accept `in_channels` channels as well.
        self.linear_layer_input = nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1))

        # Linear projection for time embedding to match out_channels
        self.time_proj = nn.Linear(out_channels, out_channels)

        # Upsampling layer: kernel 4 / stride 2 / padding 1 doubles H and W.
        self.up_sample = nn.ConvTranspose2d(in_channels // 2, out_channels // 2, kernel_size=4, padding=1, stride=2)

    def forward(self, x, down_out, t):
        """Run the decoder stage.

        :param x: feature map from the previous stage, ``(B, in_channels // 2, H, W)``
        :param down_out: encoder skip feature map at the upsampled resolution
        :param t: time embedding of shape ``(B, out_channels)``
        :return: feature map of shape ``(B, out_channels, 2H, 2W)``
        """
        x = self.up_sample(x)
        # Concatenate skip connection along the channel axis
        x = torch.cat([x, down_out], dim=1)

        # Residual block
        block1 = self.Block1(x)

        # Project time embedding and add
        t_proj = self.time_proj(t)
        t_proj = t_proj[:, :, None, None]  # Broadcast over spatial dimensions
        block_time_sum = block1 + t_proj

        block2 = self.Block2(block_time_sum)
        skip_connection = self.linear_layer_input(x)
        out_residual = block2 + skip_connection

        # Attention block. Flatten the spatial dims and then transpose to get
        # (B, H*W, C); reshaping straight to (B, H*W, C) would reinterpret the
        # memory and mix channels with positions.
        batch, channel, h, w = out_residual.shape
        attn_input = out_residual.reshape(batch, channel, h * w).transpose(1, 2)
        attn_input = self.attention_norm(attn_input)  # LayerNorm over the channel axis
        out_attn, _ = self.attention(attn_input, attn_input, attn_input)
        out_attn = out_attn.transpose(1, 2).reshape(batch, channel, h, w)

        # Final output: residual connection around attention
        out_final = out_attn + out_residual
        return out_final


Writing Up_Sample.py


In [35]:
# `middle` is the MiddleBlock output tensor here (the name was rebound from the module to its result).
middle.shape

torch.Size([8, 256, 32, 32])

In [36]:
# Shape of the DownBlock output computed in the example-usage cell.
output.shape

torch.Size([8, 256, 32, 32])

In [90]:
# Shape of the time embedding; presumably the TimeEmbed output from the example-usage cell.
t_emb.shape

torch.Size([8, 256])