# CvT-Model with Convolutional Embedding

<img src="./../CvT-SimplifiedEmbedding.drawio.png?raw=1" height="400" />


# Imports

In [None]:
%pip install pytorch-lightning
%pip install torch torchvision
%pip install lightning
%pip install einops
%pip install timm
%pip install dotenv

In [None]:
import os
import matplotlib.pyplot as plt
import numpy as np
from dotenv import load_dotenv
import torch
from einops import rearrange
import torch.nn as nn
from timm.models.layers import DropPath, trunc_normal_

# A working directory under /notebooks is treated as a Paperspace session.
IS_PAPERSPACE = os.getcwd().startswith('/notebooks')

# On Paperspace the .env file is looked up in the working directory itself,
# everywhere else one directory above it.
if IS_PAPERSPACE:
    dir_env = os.path.join(os.getcwd(), '.env')
else:
    dir_env = os.path.join(os.getcwd(), '..', '.env')

# The result is bound to `_` so the cell does not echo it.
_ = load_dotenv(dotenv_path=dir_env)

# Model

In [None]:
class ConvEmbedding(nn.Module):
    """Turn an image batch into a sequence of layer-normalised tokens.

    A 2-D convolution with padding ``kernel_size // 2`` maps the input to
    ``out_channels`` feature maps; every spatial position of the result
    becomes one token.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super().__init__()
        half_kernel = kernel_size // 2
        self.proj = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=half_kernel,
        )
        self.norm = nn.LayerNorm(out_channels)

    def forward(self, x):
        """Map a ``(b, c, h, w)`` batch to ``(b, h' * w', out_channels)`` tokens."""
        feature_map = self.proj(x)
        # Flatten the spatial grid into a token axis and move channels last.
        tokens = rearrange(feature_map, 'b c h w -> b (h w) c')
        return self.norm(tokens)

class TransformerBlock(nn.Module):
    """Pre-norm Transformer block: multi-head self-attention followed by an MLP.

    Each sub-layer is preceded by a ``LayerNorm`` and wrapped in a residual
    connection whose branch passes through a shared ``DropPath`` module.
    Queries, keys and values come from bias-free linear projections of the
    normalised input.

    Args:
        dim: Size of the last (channel) axis of the input.
        num_heads: Number of attention heads; must be positive and divide
            ``dim`` evenly.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
        attn_drop: Dropout probability applied to the attention weights.
        proj_drop: Dropout probability applied after the attention output
            projection and after each linear layer of the MLP.
        drop_path: Drop probability passed to ``DropPath``.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide ``dim``.
    """

    def __init__(self, dim, num_heads=1, mlp_ratio=4.0, *, attn_drop=0.0, proj_drop=0.0, drop_path=0.1):
        super().__init__()
        # Fail at construction time instead of deep inside the head-splitting
        # rearrange in forward().
        if num_heads <= 0 or dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by a positive num_heads ({num_heads})"
            )

        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.mlp_ratio = mlp_ratio

        self.norm1 = nn.LayerNorm(dim)

        self.proj_q = nn.Linear(dim, dim, bias=False)
        self.proj_k = nn.Linear(dim, dim, bias=False)
        self.proj_v = nn.Linear(dim, dim, bias=False)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        # One DropPath module serves both residual branches.
        self.drop_path = DropPath(drop_path)

        self.norm2 = nn.LayerNorm(dim)
        hidden_dim = int(dim * mlp_ratio)
        # Layer order is kept fixed so the Sequential indices in the
        # state_dict stay the same.
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(proj_drop),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(proj_drop),
        )

    def forward(self, x):
        """Apply self-attention and the MLP, each with a residual connection.

        Args:
            x: Tensor of shape ``(batch, tokens, dim)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        normed = self.norm1(x)

        # Split the channel axis into heads: (b, t, h*d) -> (b, h, t, d).
        q = rearrange(self.proj_q(normed), 'b t (h d) -> b h t d', h=self.num_heads)
        k = rearrange(self.proj_k(normed), 'b t (h d) -> b h t d', h=self.num_heads)
        v = rearrange(self.proj_v(normed), 'b t (h d) -> b h t d', h=self.num_heads)

        # Scaled dot-product scores between every pair of tokens, per head.
        attn_score = torch.einsum('bhlk,bhtk->bhlt', q, k) * self.scale
        attn = nn.functional.softmax(attn_score, dim=-1)
        attn = self.attn_drop(attn)

        # Weighted sum of the values, then merge the heads back into channels.
        out = torch.einsum('bhlt,bhtv->bhlv', attn, v)
        out = rearrange(out, 'b h t d -> b t (h d)')

        out = self.proj_drop(self.proj(out))
        x = x + self.drop_path(out)

        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

class CvTStage(nn.Module):
    """A stack of ``depth`` Transformer blocks applied one after another.

    The input first passes through a dropout layer with probability 0.0 and
    is then fed through the blocks in order.
    """

    def __init__(self, out_ch, depth, num_heads):
        super().__init__()
        self.dropout = nn.Dropout(0.0)
        self.blocks = nn.ModuleList(
            TransformerBlock(out_ch, num_heads) for _ in range(depth)
        )

    def forward(self, x):
        """Return ``x`` after dropout and every block in sequence."""
        hidden = self.dropout(x)
        for block in self.blocks:
            hidden = block(hidden)
        return hidden


class CvTConvolutionalEmbedding(nn.Module):
    """Image classifier: convolutional embedding, three Transformer stages, linear head.

    A 3-channel input is embedded into 192-channel tokens by a 5x5, stride-2
    ``ConvEmbedding``. The tokens run through three ``CvTStage`` stacks of
    depth 1, 2 and 10 with 3 heads each, are layer-normalised, averaged over
    the token axis and projected to ``num_classes`` outputs.
    """

    def __init__(self, num_classes=200):
        super().__init__()
        width = 192  # channel width shared by embedding, stages, norm and head
        heads = 3

        self.num_classes = num_classes
        self.conv_embed = ConvEmbedding(3, width, kernel_size=5, stride=2)
        self.stage1 = CvTStage(width, depth=1, num_heads=heads)
        self.stage2 = CvTStage(width, depth=2, num_heads=heads)
        self.stage3 = CvTStage(width, depth=10, num_heads=heads)

        self.norm = nn.LayerNorm(width)
        self.head = nn.Linear(width, num_classes)

    def forward(self, x):
        """Return class scores of shape ``(batch, num_classes)`` for an image batch."""
        tokens = self.conv_embed(x)
        for stage in (self.stage1, self.stage2, self.stage3):
            tokens = stage(tokens)

        # Average over the token axis to get one feature vector per image.
        pooled = self.norm(tokens).mean(dim=1)
        return self.head(pooled)

## Testing

In [None]:
model = CvTConvolutionalEmbedding()

# Shape smoke test: feed random 64x64 RGB batches of size 8 and 1 through the
# model and check that one row of 200 outputs comes back per image.
for batch_size in (8, 1):
    dummy_input = torch.randn(batch_size, 3, 64, 64)
    output = model(dummy_input)

    assert output.shape == (batch_size, 200), f"Expected output shape ({batch_size}, 200), but got {output.shape}"
    print("Model output shape is as expected:", output.shape)


# Dataset

In [None]:
from models.processData import prepare_data_and_get_loaders

# Build the train/validation/test loaders from the Tiny ImageNet archive.
# NOTE(review): the second argument is presumably the extraction/target
# directory — confirm in models.processData.
train_loader, val_loader, test_loader = prepare_data_and_get_loaders(
    "/datasets/tiny-imagenet-200/tiny-imagenet-200.zip",
    "data/tiny-imagenet-200",
)

# Training

In [None]:
from models.trainModel import train_test_model

# The model class itself (not an instance) is handed over together with the
# three data loaders.
train_test_model(
    CvTConvolutionalEmbedding,
    train_loader,
    val_loader,
    test_loader,
)