In [7]:
import torch
from torch import nn
from torch.nn import functional as f
import math

Building Encoder and Decoder

In [14]:
# dir(nn)

In [8]:
## The encoder's job is to reduce the dimensionality of the data

Building Self-Attention

In [21]:
class selfAttention(nn.Module):
    """Multi-head self-attention over a sequence of embeddings.

    A single fused linear layer produces the query, key and value projections
    (in that order along the last dimension); the result of the attention is
    passed through a final output projection.

    Args:
        n_heads: Number of attention heads; must divide ``d_embed`` evenly.
        d_embed: Embedding size of every token.
        in_proj_bias: Whether the fused Q/K/V projection uses a bias.
        out_proj_bias: Whether the output projection uses a bias.

    Raises:
        ValueError: If ``d_embed`` is not divisible by ``n_heads``.
    """

    def __init__(self,n_heads:int,d_embed:int, in_proj_bias=True ,out_proj_bias=True):
        super().__init__()
        if d_embed % n_heads != 0:
            raise ValueError(
                f"d_embed ({d_embed}) must be divisible by n_heads ({n_heads})"
            )
        # One matrix producing Q, K and V at once: d_embed -> 3 * d_embed.
        self.in_proj=nn.Linear(d_embed,3*d_embed,bias=in_proj_bias)
        self.out_proj=nn.Linear(d_embed,d_embed,bias=out_proj_bias)
        self.heads=n_heads
        self.d_head=d_embed//n_heads

    def forward(self, x, casual_mask=False):
        """Apply self-attention.

        Args:
            x: Tensor of shape (batch_size, seq_len, d_embed).
            casual_mask: When True, apply a causal mask so that position ``i``
                cannot attend to any position ``j > i``.

        Returns:
            Tensor of shape (batch_size, seq_len, d_embed).
        """
        input_shape=x.shape
        batch_size,seq_len,d_embed=input_shape

        # Shape used to split the embedding dimension into heads.
        per_head_shape=(batch_size,seq_len,self.heads,self.d_head)

        # (B, S, D) -> (B, S, 3*D) -> three tensors of shape (B, S, D),
        # unpacked in the conventional query, key, value order.
        q,k,v=self.in_proj(x).chunk(3,dim=-1)

        # (B, S, D) -> (B, S, H, D/H) -> (B, H, S, D/H)
        q=q.view(per_head_shape).transpose(1,2)
        k=k.view(per_head_shape).transpose(1,2)
        v=v.view(per_head_shape).transpose(1,2)

        # (B, H, S, D/H) @ (B, H, D/H, S) -> (B, H, S, S)
        weights=q@k.transpose(-1,-2)

        if casual_mask:
            # True strictly above the main diagonal, i.e. for "future" positions.
            mask=torch.ones_like(weights,dtype=torch.bool).triu(1)
            # masked_fill is out-of-place: the result must be kept, otherwise
            # the mask has no effect. -inf becomes 0 after the softmax.
            weights=weights.masked_fill(mask,-torch.inf)

        weights=weights/math.sqrt(self.d_head)
        weights=f.softmax(weights,dim=-1)

        # (B, H, S, S) @ (B, H, S, D/H) -> (B, H, S, D/H)
        output=weights@v

        # (B, H, S, D/H) -> (B, S, H, D/H); transpose is also out-of-place and
        # must be assigned so the heads are merged back in the right order.
        output=output.transpose(1,2)

        # (B, S, H, D/H) -> (B, S, D)
        output=output.reshape(input_shape)

        # (B, S, D) -> (B, S, D)
        return self.out_proj(output)


Building the VAE_residual Block

VAE Self-Attention Block

In [22]:
class VAE_AttentionBlock(nn.Module):
    """Single-head self-attention over the spatial positions of a feature map.

    Every pixel of the (height x width) grid is treated as one token whose
    embedding is its channel vector. The input is group-normalised, attended,
    and added back to the original input as a residual connection.

    Args:
        channels: Number of feature-map channels; must be divisible by 32
            because of the 32-group GroupNorm.
    """

    def __init__(self, channels):
        super().__init__()
        self.groupNorm= nn.GroupNorm(32, channels)
        self.attention=selfAttention(1,channels)

    def forward(self, x :torch.Tensor)-> torch.Tensor:
        """Run the attention block.

        Args:
            x: Tensor of shape (batch_size, channels, height, width).

        Returns:
            Tensor with the same shape as ``x``.
        """
        residue=x

        # Normalise before attention; the layer was created in __init__ for this.
        x=self.groupNorm(x)

        # `shape` is an attribute, not a method.
        n,c,h,w=x.shape

        # (B, C, H, W) -> (B, C, H*W)
        x=x.reshape(n,c,h*w)

        # (B, C, H*W) -> (B, H*W, C): one token per pixel, C features each.
        x=x.transpose(-1,-2)

        # Self-attention without a mask: every pixel attends to every pixel.
        x=self.attention(x)

        # (B, H*W, C) -> (B, C, H*W) -> (B, C, H, W)
        x=x.transpose(-1,-2)
        x=x.reshape(n,c,h,w)

        # Residual connection; return the result instead of discarding it.
        return x+residue



Residual Block


In [23]:
class VAE_ResidualBlock(nn.Module):
    """Two (GroupNorm -> SiLU -> 3x3 Conv) stages with a residual shortcut.

    The spatial size is preserved (3x3 convolutions with padding 1). When the
    channel count changes, the shortcut uses a 1x1 convolution so it can be
    added to the main branch; otherwise it is the identity.

    Args:
        in_channels: Channels of the input; must be divisible by 32.
        out_channels: Channels of the output; must be divisible by 32.
    """

    def __init__(self,in_channels,out_channels):
        super().__init__()
        # GroupNorm and the padded 3x3 convolutions keep height and width.
        self.groupNorm_1=nn.GroupNorm(32,in_channels)
        self.conv_1= nn.Conv2d(in_channels,out_channels,kernel_size=3,padding=1)
        self.groupNorm_2=nn.GroupNorm(32,out_channels)
        self.conv_2=nn.Conv2d(out_channels,out_channels,kernel_size=3,padding=1)

        if in_channels== out_channels:
            self.residual_layer=nn.Identity()
        else:
            # 1x1 convolution so the shortcut has out_channels channels.
            self.residual_layer=nn.Conv2d(in_channels,out_channels,kernel_size=1,padding=0)

    def forward(self,x:torch.Tensor)->torch.Tensor:
        """Apply the residual block.

        Args:
            x: Tensor of shape (batch_size, in_channels, height, width).

        Returns:
            Tensor of shape (batch_size, out_channels, height, width).
        """
        residue = x

        x=self.groupNorm_1(x)
        # Assignment, not comparison: the activation must actually be applied.
        x=f.silu(x)
        x=self.conv_1(x)

        x=self.groupNorm_2(x)
        x=f.silu(x)
        x=self.conv_2(x)

        return x + self.residual_layer(residue)
            
        



In [16]:
# from decoder import VAE_attentionBlock, VAE_residualBlock

# Building the encoder Block

┌──────────────────────────────────────────────────────────────────┐
│  Input:  x ∈ ℝ^(B, 3, H, W)                                     │
└──────────────────────────────────────────────────────────────────┘
                   │
                   ▼
          Conv2D(3 → 128, 3×3, stride=1, pad=1)
                   │  (same H×W, 128 channels)
                   ▼
          ┌────────────────────────────────────────────────────────┐
          │  ResidualBlock(128 → 128)                             │
          └────────────────────────────────────────────────────────┘
                   │  (same H×W, 128 channels)
                   ▼
          ┌────────────────────────────────────────────────────────┐
          │  ResidualBlock(128 → 128)                             │
          └────────────────────────────────────────────────────────┘
                   │  (same H×W, 128 channels)
                   ▼
          Conv2D(128 → 128, 3×3, stride=2, pad=0)
                   │  (H/2 × W/2, 128 channels; note: input is zero-padded by 1 on the right and bottom before this conv)
                   ▼
          ┌────────────────────────────────────────────────────────┐
          │  ResidualBlock(128 → 256)                             │
          └────────────────────────────────────────────────────────┘
                   │  (H/2 × W/2, 256 channels)
                   ▼
          ┌────────────────────────────────────────────────────────┐
          │  ResidualBlock(256 → 256)                             │
          └────────────────────────────────────────────────────────┘
                   │  (H/2 × W/2, 256 channels)
                   ▼
          Conv2D(256 → 256, 3×3, stride=2, pad=0)
                   │  (H/4 × W/4, 256 channels; note: input is zero-padded by 1 on the right and bottom before this conv)
                   ▼
          ┌────────────────────────────────────────────────────────┐
          │  ResidualBlock(256 → 512)                             │
          └────────────────────────────────────────────────────────┘
                   │  (H/4 × W/4, 512 channels)
                   ▼
          ┌────────────────────────────────────────────────────────┐
          │  ResidualBlock(512 → 512)                             │
          └────────────────────────────────────────────────────────┘
                   │  (H/4 × W/4, 512 channels)
                   ▼
          Conv2D(512 → 512, 3×3, stride=2, pad=0)
                   │  (H/8 × W/8, 512 channels; note: input is zero-padded by 1 on the right and bottom before this conv)
                   ▼
          ┌────────────────────────────────────────────────────────┐
          │  ResidualBlock(512 → 512)                             │
          └────────────────────────────────────────────────────────┘
                   │  (H/8 × W/8, 512 channels)
                   ▼
          ┌────────────────────────────────────────────────────────┐
          │  ResidualBlock(512 → 512)                             │
          └────────────────────────────────────────────────────────┘
                   │  (H/8 × W/8, 512 channels)
                   ▼
          ┌────────────────────────────────────────────────────────┐
          │  ResidualBlock(512 → 512)                             │
          └────────────────────────────────────────────────────────┘
                   │  (H/8 × W/8, 512 channels)
                   ▼
               AttentionBlock(512 → 512)
                   │  (H/8 × W/8, 512 channels)
                   ▼
          ┌────────────────────────────────────────────────────────┐
          │  ResidualBlock(512 → 512)                             │
          └────────────────────────────────────────────────────────┘
                   │  (H/8 × W/8, 512 channels)
                   ▼
               GroupNorm(32, 512)
                   │
                   ▼
                     SILU
                   │
                   ▼
          Conv2D(512 → 8, 3×3, pad=1)
                   │  (H/8 × W/8, 8 channels)
                   ▼
          Conv2D(8 → 8, 1×1, pad=0)
                   │  (H/8 × W/8, 8 channels)
                   ▼
   ┌────────────────────────────────────────────────────────────────┐
   │   Split along channel → mean (4 ch) and log_variance (4 ch)   │
   └────────────────────────────────────────────────────────────────┘
                   │
                   ▼
   Clamp log_variance to [-30, 20], then variance = exp(log_variance)
                   │
                   ▼
          stdev = sqrt(variance)
                   │
                   ▼
          z = mean + stdev * noise
                   │
                   ▼
         Scale z by 0.18215 → (output)


In [24]:
class VAE_Encoder(nn.Sequential):
    """VAE encoder: image -> sampled, scaled latent.

    The layers are passed to ``nn.Sequential.__init__`` so that they are
    registered as sub-modules and can be iterated in ``forward``. Three
    stride-2 convolutions reduce height and width by a factor of 8, and the
    last convolutions produce 8 channels that are split into a 4-channel mean
    and a 4-channel log-variance.
    """

    def __init__(self):
        super().__init__(
            # (B, 3, H, W) -> (B, 128, H, W)
            nn.Conv2d(3,128,kernel_size=3,stride=1,padding=1),
            # Residual blocks keep the spatial size.
            VAE_ResidualBlock(128,128),
            VAE_ResidualBlock(128,128),

            # (B, 128, H, W) -> (B, 128, H/2, W/2); padding is added in forward().
            nn.Conv2d(128,128,kernel_size=3,stride=2,padding=0),
            # (B, 128, H/2, W/2) -> (B, 256, H/2, W/2)
            VAE_ResidualBlock(128,256),
            VAE_ResidualBlock(256,256),

            # (B, 256, H/2, W/2) -> (B, 256, H/4, W/4); padding is added in forward().
            nn.Conv2d(256,256,kernel_size=3,stride=2,padding=0),
            # (B, 256, H/4, W/4) -> (B, 512, H/4, W/4)
            VAE_ResidualBlock(256,512),
            VAE_ResidualBlock(512,512),

            # (B, 512, H/4, W/4) -> (B, 512, H/8, W/8); padding is added in forward().
            nn.Conv2d(512,512,kernel_size=3,stride=2, padding=0),
            VAE_ResidualBlock(512,512),
            VAE_ResidualBlock(512,512),
            VAE_ResidualBlock(512,512),

            # Attention over the H/8 x W/8 positions; shape is unchanged.
            VAE_AttentionBlock(512),

            VAE_ResidualBlock(512,512),
            nn.GroupNorm(32,512),
            nn.SiLU(),

            # (B, 512, H/8, W/8) -> (B, 8, H/8, W/8)
            nn.Conv2d(512,8,kernel_size=3,padding=1),
            # (B, 8, H/8, W/8) -> (B, 8, H/8, W/8)
            nn.Conv2d(8,8,kernel_size=1,padding=0),
        )

    def forward(self,x: torch.Tensor,noise: torch.Tensor) -> torch.Tensor:
        """Encode an image batch and sample a latent.

        Args:
            x: Tensor of shape (batch_size, 3, height, width).
            noise: Tensor of shape (batch_size, 4, height/8, width/8) used as
                the random draw for the sampling step.

        Returns:
            Scaled latent of shape (batch_size, 4, height/8, width/8).
        """
        for module in self:
            if getattr(module, 'stride',None)==(2,2):
                # The stride-2 convolutions have padding=0, so pad one zero
                # column on the right and one zero row at the bottom.
                # Order: (padding_left, padding_right, padding_top, padding_bottom)
                x=f.pad(x,(0,1,0,1))
            x=module(x)

        # (B, 8, H/8, W/8) -> two tensors of shape (B, 4, H/8, W/8)
        mean,log_variance=torch.chunk(x,2,dim=1)
        # Keep the log-variance in a bounded range before exponentiating.
        log_variance=torch.clamp(log_variance,-30,20)

        variance=log_variance.exp()
        stdev=variance.sqrt()

        # Turn a N(0, 1) draw into a N(mean, variance) draw.
        x=mean+stdev*noise

        # Scale the latent by a constant (undone by the decoder).
        x*=0.18215

        return x

# Building the Decoder Block

In [25]:

class VAE_Decoder(nn.Sequential):
    """VAE decoder: scaled latent -> image.

    Takes a latent of shape (Batch_Size, 4, Height / 8, Width / 8) and, through
    three nearest-neighbour upsampling stages, produces a tensor of shape
    (Batch_Size, 3, Height, Width).
    """

    def __init__(self):
        super().__init__(
            # (Batch_Size, 4, Height / 8, Width / 8) -> (Batch_Size, 4, Height / 8, Width / 8)
            nn.Conv2d(4, 4, kernel_size=1, padding=0),

            # (Batch_Size, 4, Height / 8, Width / 8) -> (Batch_Size, 512, Height / 8, Width / 8)
            nn.Conv2d(4, 512, kernel_size=3, padding=1),

            # (Batch_Size, 512, Height / 8, Width / 8) -> (Batch_Size, 512, Height / 8, Width / 8)
            VAE_ResidualBlock(512, 512),

            # (Batch_Size, 512, Height / 8, Width / 8) -> (Batch_Size, 512, Height / 8, Width / 8)
            VAE_AttentionBlock(512),

            # Four residual blocks, each keeping (Batch_Size, 512, Height / 8, Width / 8)
            VAE_ResidualBlock(512, 512),
            VAE_ResidualBlock(512, 512),
            VAE_ResidualBlock(512, 512),
            VAE_ResidualBlock(512, 512),

            # Repeats the rows and columns of the data by scale_factor (like when you resize an image by doubling its size).
            # (Batch_Size, 512, Height / 8, Width / 8) -> (Batch_Size, 512, Height / 4, Width / 4)
            nn.Upsample(scale_factor=2),

            # (Batch_Size, 512, Height / 4, Width / 4) -> (Batch_Size, 512, Height / 4, Width / 4)
            nn.Conv2d(512, 512, kernel_size=3, padding=1),

            # Three residual blocks, each keeping (Batch_Size, 512, Height / 4, Width / 4)
            VAE_ResidualBlock(512, 512),
            VAE_ResidualBlock(512, 512),
            VAE_ResidualBlock(512, 512),

            # (Batch_Size, 512, Height / 4, Width / 4) -> (Batch_Size, 512, Height / 2, Width / 2)
            nn.Upsample(scale_factor=2),

            # (Batch_Size, 512, Height / 2, Width / 2) -> (Batch_Size, 512, Height / 2, Width / 2)
            nn.Conv2d(512, 512, kernel_size=3, padding=1),

            # (Batch_Size, 512, Height / 2, Width / 2) -> (Batch_Size, 256, Height / 2, Width / 2)
            VAE_ResidualBlock(512, 256),

            # Two residual blocks, each keeping (Batch_Size, 256, Height / 2, Width / 2)
            VAE_ResidualBlock(256, 256),
            VAE_ResidualBlock(256, 256),

            # (Batch_Size, 256, Height / 2, Width / 2) -> (Batch_Size, 256, Height, Width)
            nn.Upsample(scale_factor=2),

            # (Batch_Size, 256, Height, Width) -> (Batch_Size, 256, Height, Width)
            nn.Conv2d(256, 256, kernel_size=3, padding=1),

            # (Batch_Size, 256, Height, Width) -> (Batch_Size, 128, Height, Width)
            VAE_ResidualBlock(256, 128),

            # Two residual blocks, each keeping (Batch_Size, 128, Height, Width)
            VAE_ResidualBlock(128, 128),
            VAE_ResidualBlock(128, 128),

            # (Batch_Size, 128, Height, Width) -> (Batch_Size, 128, Height, Width)
            nn.GroupNorm(32, 128),

            # (Batch_Size, 128, Height, Width) -> (Batch_Size, 128, Height, Width)
            nn.SiLU(),

            # (Batch_Size, 128, Height, Width) -> (Batch_Size, 3, Height, Width)
            nn.Conv2d(128, 3, kernel_size=3, padding=1),
        )

    def forward(self, x):
        """Decode a latent into an image tensor.

        Args:
            x: Tensor of shape (Batch_Size, 4, Height / 8, Width / 8).

        Returns:
            Tensor of shape (Batch_Size, 3, Height, Width).
        """
        # Remove the scaling added by the Encoder. Done out-of-place so the
        # caller's latent tensor is not modified (an in-place division would
        # rescale it again on every call).
        x = x / 0.18215

        for module in self:
            x = module(x)

        # (Batch_Size, 3, Height, Width)
        return x