# Installs

In [2]:
!pip install einops ptflops



# Imports

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from ptflops import get_model_complexity_info
import gc
import math
from einops import rearrange

In [4]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cpu'

# Utils

In [4]:
def clear_model_from_memory(model: nn.Module): # clearing memory
    del model
    gc.collect()
    torch.cuda.empty_cache()

In [5]:
def count_parameters_with_commas(model: nn.Module) -> str: # no. of params
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    formatted_params = f"{total_params:,}"
    return formatted_params

In [6]:
def model_size_mb(model: nn.Module): # params size in mb
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    size_in_bytes = total_params * 4
    size_in_mb = size_in_bytes / (1024 ** 2)
    return size_in_mb

In [5]:
def count_model_flops(model: nn.Module, input_size=(3, 224, 224), print_results=True):
    try:
        macs, params = get_model_complexity_info(
            model, input_size, as_strings=False, print_per_layer_stat=False, verbose=False
        )

        gflops = macs / 1e9  # Convert MACs to GFLOPs

        if print_results:
            print(f'Computational complexity: {gflops:.3f} GFLOPs')
            print(f'Number of parameters: {params / 1e6:.3f} M')

        return gflops, params
    except Exception as e:
        print(f"An error occurred: {str(e)}")
        return None, None

# Model 1 (4 Layers)

In [8]:
class PatchEmbedding(nn.Module):
    def __init__(self, in_channels=3, patch_size=16, emb_dim=1024):
        super(PatchEmbedding, self).__init__()
        self.patch_size = patch_size
        self.proj = nn.Conv2d(in_channels, emb_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x)  # (B, emb_dim, H/patch_size, W/patch_size)
        x = rearrange(x, 'b c h w -> b (h w) c')  # (B, num_patches, emb_dim)
        return x

class TransformerBlock(nn.Module):
    def __init__(self, emb_dim, num_heads, mlp_ratio=4., dropout=0.1):
        super(TransformerBlock, self).__init__()
        self.norm1 = nn.LayerNorm(emb_dim)
        self.attn = nn.MultiheadAttention(emb_dim, num_heads, dropout=dropout)
        self.norm2 = nn.LayerNorm(emb_dim)
        self.mlp = nn.Sequential(
            nn.Linear(emb_dim, int(mlp_ratio * emb_dim)),
            nn.GELU(),
            nn.Linear(int(mlp_ratio * emb_dim), emb_dim),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = x + self.dropout(self.attn(self.norm1(x), self.norm1(x), self.norm1(x))[0])
        x = x + self.dropout(self.mlp(self.norm2(x)))
        return x

class ProgressiveUpsamplingTransformer(nn.Module):
    def __init__(self, in_channels=3, emb_dim=768, num_heads=12, num_layers=4, num_classes=30):
        super(ProgressiveUpsamplingTransformer, self).__init__()
        self.patch_size = 16
        self.emb_dim = emb_dim
        self.patch_embed = PatchEmbedding(in_channels, self.patch_size, emb_dim)

        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(emb_dim, num_heads) for _ in range(num_layers)
        ])
        self.final_proj = nn.Conv2d(emb_dim, num_classes, kernel_size=1)

    def forward(self, x):
        B, C, H, W = x.shape
        x = self.patch_embed(x)  # (B, num_patches, emb_dim)
        H_p, W_p = H // self.patch_size, W // self.patch_size
        x = rearrange(x, 'b (h w) c -> b c h w', h=H_p, w=W_p)  # (B, emb_dim, H_p, W_p)

        for block in self.transformer_blocks:
            x = rearrange(x, 'b c h w -> b (h w) c')  # (B, num_patches, emb_dim)
            x = block(x)  # (B, num_patches, emb_dim)
            x = rearrange(x, 'b (h w) c -> b c h w', h=H_p, w=W_p)  # (B, emb_dim, H_p, W_p)
            H_p, W_p = H_p * 2, W_p * 2
            x = F.interpolate(x, size=(H_p, W_p), mode='bilinear', align_corners=False)  # Upsample

        x = self.final_proj(x)  # (B, num_classes, H, W)
        x = F.interpolate(x, size=(H, W), mode='bilinear', align_corners=False)  # Ensure the output size is same as input
        return x

In [9]:
# Example usage
model1 = ProgressiveUpsamplingTransformer().to(device)
input_tensor = torch.randn(1, 3, 224, 224).to(device)  # Example input tensor
output = model1(input_tensor)
print(output.shape)

torch.Size([1, 30, 224, 224])


# Model 2 (12 Layers Total)

In [6]:
class PatchEmbedding(nn.Module):
    def __init__(self, in_channels=3, patch_size=16, emb_dim=768):
        super(PatchEmbedding, self).__init__()
        self.patch_size = patch_size
        self.proj = nn.Conv2d(in_channels, emb_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x)  # (B, emb_dim, H/patch_size, W/patch_size)
        x = rearrange(x, 'b c h w -> b (h w) c')  # (B, num_patches, emb_dim)
        return x

class TransformerBlock(nn.Module):
    def __init__(self, emb_dim, num_heads, mlp_ratio=4., dropout=0.1):
        super(TransformerBlock, self).__init__()
        self.norm1 = nn.LayerNorm(emb_dim)
        self.attn = nn.MultiheadAttention(emb_dim, num_heads, dropout=dropout)
        self.norm2 = nn.LayerNorm(emb_dim)
        self.mlp = nn.Sequential(
            nn.Linear(emb_dim, int(mlp_ratio * emb_dim)),
            nn.GELU(),
            nn.Linear(int(mlp_ratio * emb_dim), emb_dim),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.norm1(x)
        x = x + self.dropout(self.attn(x, x, x)[0])
        x = self.norm2(x)
        x = x + self.dropout(self.mlp(x))
        return x

class UpsamplingTransformer(nn.Module):
    def __init__(self, in_channels=3, emb_dim=768, num_heads=12, num_classes=21):
        super(UpsamplingTransformer, self).__init__()
        self.patch_size = 16
        self.emb_dim = emb_dim
        self.patch_embed = PatchEmbedding(in_channels, self.patch_size, emb_dim)

        self.encoder_blocks = nn.ModuleList([
            TransformerBlock(emb_dim, num_heads) for _ in range(4)
        ])
        self.upsampling_blocks = nn.ModuleList([
            TransformerBlock(emb_dim, num_heads) for _ in range(4)
        ])
        self.additional_blocks = nn.ModuleList([
            TransformerBlock(emb_dim, num_heads) for _ in range(4)
        ])
        self.final_proj = nn.Conv2d(emb_dim, num_classes, kernel_size=1)

    def forward(self, x):
        B, C, H, W = x.shape
        x = self.patch_embed(x)  # (B, num_patches, emb_dim)
        H_p, W_p = H // self.patch_size, W // self.patch_size
        x = rearrange(x, 'b (h w) c -> b c h w', h=H_p, w=W_p)  # (B, emb_dim, H_p, W_p)

        for block in self.encoder_blocks:
            x = rearrange(x, 'b c h w -> b (h w) c')  # (B, num_patches, emb_dim)
            x = block(x)  # (B, num_patches, emb_dim)
            x = rearrange(x, 'b (h w) c -> b c h w', h=H_p, w=W_p)  # (B, emb_dim, H_p, W_p)

        for block in self.upsampling_blocks:
            H_p, W_p = H_p * 2, W_p * 2
            x = F.interpolate(x, size=(H_p, W_p), mode='bilinear', align_corners=False)  # Upsample
            x = rearrange(x, 'b c h w -> b (h w) c')  # (B, num_patches, emb_dim)
            x = block(x)  # (B, num_patches, emb_dim)
            x = rearrange(x, 'b (h w) c -> b c h w', h=H_p, w=W_p)  # (B, emb_dim, H_p, W_p)

        for block in self.additional_blocks:
            x = rearrange(x, 'b c h w -> b (h w) c')  # (B, num_patches, emb_dim)
            x = block(x)  # (B, num_patches, emb_dim)
            x = rearrange(x, 'b (h w) c -> b c h w', h=H_p, w=W_p)  # (B, emb_dim, H_p, W_p)

        x = self.final_proj(x)  # (B, num_classes, H_p, W_p)
        x = F.interpolate(x, size=(H, W), mode='bilinear', align_corners=False)  # Ensure the output size is same as input
        return x

In [None]:
# Example usage
model2 = UpsamplingTransformer().to(device)
input_tensor = torch.randn(1, 3, 224, 224).to(device)  # Example input tensor
output = model2(input_tensor)
print(output.shape)

# Model Size

In [10]:
count_parameters_with_commas(model1)

'28,965,150'

In [11]:
model_size_mb(model1)

110.49327850341797

In [None]:
count_parameters_with_commas(model2)

In [None]:
model_size_mb(model2)

# FLOPS

In [18]:
count_model_flops(model1)

Computational complexity: 119.524 GFLOPs
Number of parameters: 28.965 M


(119.524113968, 28965150)

In [19]:
clear_model_from_memory(model1)

In [None]:
count_model_flops(model2)