In [None]:
from typing import Optional, List

import torch
import torch.nn as nn
import torch.nn.functional as F

from stable_diffusion_model import SpatialTransformer, TimeStepEmbSequantial,Upsample, Downsample

import numpy as numpy
import math

In [None]:
class ResBlock(nn.Module):
    """Residual block conditioned on a time-step embedding.

    Pipeline: GroupNorm -> SiLU -> 3x3 conv, add a per-channel projection of the
    time embedding (broadcast over height/width), GroupNorm -> SiLU -> 3x3 conv,
    then sum with a skip path computed from the block *input*.

    Note: ``nn.GroupNorm(num_groups=32, ...)`` requires both ``channels`` and the
    resolved ``out_channels`` to be divisible by 32.
    """

    def __init__(self,
                 channels: int,
                 d_t_embedding: int,
                 out_channels: Optional[int] = None,
                 ):
        """
        channels: number of input feature-map channels.
        d_t_embedding: size of the time-step embedding vector given to ``forward``.
        out_channels: number of output channels; defaults to ``channels``.
        """
        super().__init__()
        self.channels = channels
        self.d_t_embedding = d_t_embedding
        # Resolve the default *before* building layers so that None never reaches
        # nn.Conv2d / nn.Linear.
        self.out_channels = channels if out_channels is None else out_channels
        out_channels = self.out_channels

        self.in_layers = nn.Sequential(
            nn.GroupNorm(num_groups=32, num_channels=channels),
            nn.SiLU(),
            nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1),
        )

        # Projects the time embedding to one additive bias per output channel.
        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            nn.Linear(d_t_embedding, out_channels),
        )

        # The second stage sees the out_channels-wide map produced by in_layers,
        # so its GroupNorm must normalize out_channels, not channels.
        self.out_layers = nn.Sequential(
            nn.GroupNorm(num_groups=32, num_channels=out_channels),
            nn.SiLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
        )

        # Skip path: identity when channel counts match, otherwise a 1x1 conv.
        if out_channels == channels:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Conv2d(channels, out_channels, kernel_size=1, stride=1)

    def forward(self, x: torch.Tensor, t_emb: torch.Tensor) -> torch.Tensor:
        """
        x: feature map of shape (batch, channels, height, width).
        t_emb: time embedding of shape (batch, d_t_embedding).
        returns: tensor of shape (batch, out_channels, height, width).
        """
        h = self.in_layers(x)
        # Broadcast the per-channel time bias over the spatial dimensions.
        h = h + self.emb_layers(t_emb)[:, :, None, None]
        h = self.out_layers(h)
        # The residual connection uses the block input, not the transformed features.
        return self.shortcut(x) + h

In [None]:
class UnetModel(nn.Module):
    """Time- and condition-aware UNet: encoder, bottleneck, decoder with skips.

    Every encoder block's output is stored; the decoder pops them in reverse
    order and concatenates each onto its input along the channel axis.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 channels: int,
                 num_blocks: int,
                 atten_level: List[int],
                 channel_multiplier: List[int],
                 num_heads: int,
                 tf_layers: int,
                 emb_dim: int,
                 clip_dim: int,
                 ):
        """
        in_channels: channels of the input feature map.
        out_channels: channels of the output feature map.
        channels: base channel count; level i uses channels * channel_multiplier[i].
        num_blocks: ResBlocks per encoder level (the decoder uses num_blocks + 1).
        atten_level: indices of levels that get a SpatialTransformer.
        channel_multiplier: per-level multipliers; its length is the number of levels.
        num_heads, tf_layers, emb_dim, clip_dim: forwarded to SpatialTransformer.
        """
        super().__init__()
        self.channels = channels
        # Use the argument directly; no self.channel_multiplier attribute exists.
        levels = len(channel_multiplier)

        d_time_embedding = 4 * channels

        # MLP over the sinusoidal embedding. It must output d_time_embedding
        # features because every ResBlock is built with d_t_embedding=d_time_embedding.
        self.time_emb = nn.Sequential(
            nn.Linear(channels, d_time_embedding),
            nn.SiLU(),
            nn.Linear(d_time_embedding, d_time_embedding),
        )

        self.downsample = nn.ModuleList([])
        self.downsample.append(TimeStepEmbSequantial(
            nn.Conv2d(in_channels, channels, kernel_size=3, stride=1, padding=1)
        ))

        # Channel count of every encoder output, consumed in reverse by the decoder.
        input_block_channels = [channels]
        channels_list = [channels * m for m in channel_multiplier]

        # Encoder: num_blocks ResBlocks per level, each one its own skip source.
        for i in range(levels):
            for _ in range(num_blocks):
                layers = [ResBlock(channels, d_t_embedding=d_time_embedding, out_channels=channels_list[i])]
                channels = channels_list[i]
                if i in atten_level:
                    layers.append(SpatialTransformer(channels, num_heads, tf_layers, emb_dim))
                self.downsample.append(TimeStepEmbSequantial(*layers))
                input_block_channels.append(channels)

            # Downsample at all levels except the last one.
            if i != levels - 1:
                self.downsample.append(TimeStepEmbSequantial(Downsample(channels)))
                input_block_channels.append(channels)

        # Bottleneck.
        # NOTE(review): this SpatialTransformer call passes clip_dim while the
        # encoder/decoder calls do not; confirm against SpatialTransformer's signature.
        self.middle_block = TimeStepEmbSequantial(
            ResBlock(channels, d_time_embedding),
            SpatialTransformer(channels, num_heads, tf_layers, emb_dim, clip_dim),
            ResBlock(channels, d_time_embedding),
        )

        # Decoder: num_blocks + 1 blocks per level, so the number of pops equals
        # the number of encoder outputs recorded above.
        self.upsample = nn.ModuleList([])
        for i in reversed(range(levels)):
            for j in range(num_blocks + 1):
                layers = [ResBlock(channels + input_block_channels.pop(), d_t_embedding=d_time_embedding, out_channels=channels_list[i])]
                channels = channels_list[i]

                if i in atten_level:
                    layers.append(SpatialTransformer(channels, num_heads, tf_layers, emb_dim))

                # Upsample after the last block of every level except the top one.
                if i != 0 and j == num_blocks:
                    layers.append(Upsample(channels))

                self.upsample.append(TimeStepEmbSequantial(*layers))

        self.out = nn.Sequential(
            nn.GroupNorm(num_groups=32, num_channels=channels),
            nn.SiLU(),
            nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1),
        )

    def time_step_embedding(self, time_step: torch.Tensor, max_period: int = 10000) -> torch.Tensor:
        """Sinusoidal embedding of time steps.

        time_step: 1-D tensor of shape (batch,).
        max_period: controls the lowest frequency.
        returns: (batch, self.channels) tensor laid out as [cos(t*f_k)..., sin(t*f_k)...]
            with f_k = max_period ** (-k / half), k = 0 .. half-1. An odd channel
            count is zero-padded by one column.
        """
        half = self.channels // 2  # half of the columns are cosine, half sine
        # The division by `half` belongs inside the exponent.
        frequencies = torch.exp(
            -math.log(max_period)
            * torch.arange(start=0, end=half, dtype=torch.float32, device=time_step.device)
            / half
        )
        args = time_step[:, None].float() * frequencies[None]
        # torch.cat expects a sequence of tensors.
        output = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if self.channels % 2:
            # Keep the width equal to self.channels, the input size of time_emb.
            output = F.pad(output, (0, 1))
        return output

    def forward(self, x: torch.Tensor, time_steps: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        """Run the UNet.

        x: input feature map (batch, in_channels, height, width). Height and width
            presumably must be divisible by 2 ** (levels - 1) so that skip shapes
            match after down/upsampling; TODO confirm for the intended inputs.
        time_steps: (batch,) tensor of time steps.
        cond: conditioning tensor forwarded to every block.
        returns: (batch, out_channels, height, width).
        """
        x_input_block = []

        embedding = self.time_step_embedding(time_steps)
        embedding = self.time_emb(embedding)

        # Encoder: record each block's output as a skip connection.
        for module in self.downsample:
            x = module(x, embedding, cond)
            x_input_block.append(x)

        # Bottleneck.
        x = self.middle_block(x, embedding, cond)

        # Decoder: concatenate the matching skip along channels before each block.
        for module in self.upsample:
            x = torch.cat([x, x_input_block.pop()], dim=1)
            x = module(x, embedding, cond)

        return self.out(x)

In [None]:
class TimeStepEmbSequantial(nn.Sequential):
    """nn.Sequential that routes the extra inputs to the layers that need them.

    ResBlock layers are called as layer(x, t_emb), SpatialTransformer layers as
    layer(x, cond), and every other layer as layer(x).
    """

    def __init__(self, *layers: nn.Module):
        """
        layers: modules to run in order, exactly as for nn.Sequential.
        """
        # Forward the modules to nn.Sequential; the previous signature
        # (x, t_emb, cond=None) discarded them and built an empty container.
        super().__init__(*layers)

    def forward(self, x: torch.Tensor, t_emb: torch.Tensor, cond: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        x: feature map passed through the layers in order.
        t_emb: time embedding given to ResBlock layers.
        cond: conditioning given to SpatialTransformer layers (may be None).
        returns: output of the last layer.
        """
        for layer in self:
            if isinstance(layer, ResBlock):
                x = layer(x, t_emb)
            elif isinstance(layer, SpatialTransformer):
                x = layer(x, cond)
            else:
                x = layer(x)
        return x

In [None]:
class Upsample(nn.Module):
    """Double the spatial resolution of a feature map.

    Nearest-neighbour 2x interpolation is followed by a 3x3 convolution with
    stride 1 and padding 1, which keeps the upsampled size and the channel count.
    """

    def __init__(self, 
                 channels: int,
                 ):
        """
        channels: number of input (and output) channels.
        """
        super().__init__()
        self.channels = channels
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (batch, channels, H, W) -> returns (batch, channels, 2H, 2W).
        """
        return self.conv(F.interpolate(x, scale_factor=2, mode='nearest'))

In [None]:
class Downsample(nn.Module):
    """Halve the spatial resolution with a strided 3x3 convolution."""

    def __init__(self, 
                 channels: int,
                 ):
        """
        channels: number of input (and output) channels.
        """
        super().__init__()
        self.channels = channels

        # kernel_size is a required argument of nn.Conv2d (it was missing, so
        # construction raised TypeError). 3x3 / stride 2 / padding 1 maps H -> ceil(H/2).
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, stride=2, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Downsample via the strided convolution alone, without interpolation.

        x: (batch, channels, H, W) -> returns (batch, channels, ceil(H/2), ceil(W/2)).
        """
        output = self.conv(x)
        return output