# Original CvT-Model

<img src="./../CvT-Original.drawio.png?raw=1" alt="CvT-Modell mit Convolutional Embedding" title="CvT-Modell mit Convolutional Embedding" height="400" />

Alle Dimensionen sind ohne die Batch-Dimension angegeben.

## Input-Dimensions

**Dimensions:** $H_0 = 64px, \quad W_0 = 64px, \quad C_0 = 3$ \
**Output-Shape:** `(3, 64, 64)`

## Conv2d

Berechnung der Output-Dimensions:

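$ \text{kernel size}\ k = 7, \quad \text{stride}\ s = 4, \quad \text{padding}\ p = 3 $ \
$ H_i = \left\lfloor \frac{H_{i-1} + 2p - k}{s} \right\rfloor + 1, \quad W_i = \left\lfloor \frac{W_{i-1} + 2p - k}{s} \right\rfloor + 1 $ \
$ H_1 = \left\lfloor \frac{64 + 6 - 7}{4} \right\rfloor + 1 = \lfloor 15.75 \rfloor + 1 = 16 $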

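**Output-Dimensions:** $H_1 = 16px, \quad W_1 = 16px, \quad C_1 = 64$ \
**Output-Shape:** `(64, 16, 16)`

Hinweis: Die Implementierung unten (`CvTOriginal`) weicht in Stage 1 davon ab und nutzt $k = 5,\ s = 2,\ p = 2$, was bei $64 \times 64$ Input $32 \times 32$ Tokens ergibt.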

## Flatten

**Output-Dimensions:** $H_1 W_1 \times C_1 = (16 \cdot 16) \times 64 = 256 \times 64$ \
**Output-Shape:** `(256, 64)`

## Conv Projection



## Multi-Head Attention

Berechnung der Query-, Key- und Value-Matrizen:

$X \in \mathbb{R}^{H_1 W_1 \times C_1}$ \
$d_k$ ist die Dimension der Value-, Query- und Key-Vektoren \
$W^Q, W^K, W^V \in \mathbb{R}^{C_1 \times d_k}$ \
$Q = XW^Q, \quad K = XW^K, \quad V = XW^V$

$d_k = 64$ \
$Q, K, V \in \mathbb{R}^{256 \times 64}$

**Output-Dimensions:** $256 \times 64$ \
**Output-Shape:** `(256, 64)`

## MLP

Expansion factor: $e = 4$

1. **Step:** Linear ➔ GELU ➔ Dropout
   
   **Output-Dimensions:** $(256 \times 64) \cdot (64 \times 256) = 256 \times 256$, mit $64 \cdot e = 256$ \
   **Output-Shape:** `(256, 256)`

2. **Step:** Linear ➔ Dropout

   **Output-Dimensions:** $(256 \times 256) \cdot (256 \times 64) = 256 \times 64$ \
   **Output-Shape:** `(256, 64)`


# Imports

In [None]:
%pip install pytorch-lightning
%pip install torch torchvision
%pip install lightning
%pip install einops
%pip install dotenv

In [None]:
import os
import matplotlib.pyplot as plt
import numpy as np
from dotenv import load_dotenv
import torch
from einops import rearrange
import torch.nn as nn

# Detect the Paperspace environment by its working-directory prefix.
IS_PAPERSPACE = os.getcwd().startswith('/notebooks')
# Pick the .env location: in the working directory on Paperspace,
# otherwise one directory above it.
if IS_PAPERSPACE:
    dir_env = os.path.join(os.getcwd(), '.env')
else:
    dir_env = os.path.join(os.getcwd(), '..', '.env')
_ = load_dotenv(dotenv_path=dir_env)

# Modell

In [None]:
class ConvEmbed(nn.Module):
    """Convolutional token embedding: strided conv, flatten to tokens, LayerNorm.

    Args:
        in_channels: Channels of the incoming ``(B, C_in, H_in, W_in)`` feature map.
        out_channels: Embedding size ``C`` of each output token.
        kernel_size: Conv kernel size; padding is ``kernel_size // 2``.
        stride: Conv stride.

    ``forward`` returns ``(tokens, H, W)``: ``tokens`` has shape ``(B, H * W, C)``
    in row-major (h, w) order, and ``H``/``W`` are the conv output's spatial size.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super().__init__()
        pad = kernel_size // 2
        self.proj = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=pad,
        )
        self.norm = nn.LayerNorm(out_channels)

    def forward(self, x):
        feat = self.proj(x)
        batch, channels, height, width = feat.shape
        # (B, C, H, W) -> (B, H, W, C) -> (B, H*W, C): one token per spatial position.
        tokens = feat.permute(0, 2, 3, 1).reshape(batch, height * width, channels)
        # LayerNorm normalises each token over its C channels.
        return self.norm(tokens), height, width

class DWConv(nn.Module):
    """Depthwise 2-D convolution applied to a token sequence laid out on an H x W grid.

    Args:
        dim: Number of channels ``C`` (one filter group per channel).
        kernel_size: Square kernel size; padding ``kernel_size // 2`` keeps the
            H x W grid unchanged for odd kernel sizes.
    """

    def __init__(self, dim, kernel_size=3):
        super().__init__()
        self.dwconv = nn.Conv2d(
            dim, dim, kernel_size=kernel_size, padding=kernel_size // 2, groups=dim
        )

    def forward(self, x, H, W):
        """Map ``(B, H*W, C)`` tokens through the depthwise conv back to token layout."""
        batch, _, channels = x.shape
        # Tokens -> (B, C, H, W) feature map so the conv sees spatial neighbours.
        grid = x.transpose(1, 2).reshape(batch, channels, H, W)
        grid = self.dwconv(grid)
        # Feature map -> tokens again: (B, C, H*W) -> (B, H*W, C).
        return grid.flatten(2).transpose(1, 2)

class ConvAttention(nn.Module):
    """Multi-head self-attention over tokens, preceded by a depthwise conv.

    The tokens are first mixed spatially by ``DWConv`` on their H x W grid, then a
    single linear layer produces queries, keys and values for all heads.

    Args:
        dim: Token embedding size ``C``; must be divisible by ``heads``.
        heads: Number of attention heads.
        kernel_size: Kernel size of the depthwise convolution.
        attn_drop: Dropout probability on the attention weights.
        proj_drop: Dropout probability after the output projection.

    Raises:
        ValueError: If ``heads`` is not positive or does not divide ``dim``.
    """

    def __init__(self, dim, heads=4, kernel_size=3, attn_drop=0.1, proj_drop=0.1):
        super().__init__()
        # Fail early: otherwise the qkv reshape in forward() fails with an obscure error.
        if heads <= 0 or dim % heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by a positive number of heads ({heads})"
            )
        self.num_heads = heads
        self.head_dim = dim // heads
        self.scale = self.head_dim ** -0.5  # 1 / sqrt(d_k)

        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)
        self.dwconv = DWConv(dim, kernel_size)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, H, W):
        """Attend over ``(B, N, C)`` tokens with ``N == H * W``; returns ``(B, N, C)``."""
        B, N, C = x.shape
        qkv = self.qkv(self.dwconv(x, H, W)).reshape(B, N, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.unbind(2)  # each: (B, N, heads, head_dim)

        q = q.transpose(1, 2)  # (B, heads, N, head_dim)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        attn = (q @ k.transpose(-2, -1)) * self.scale  # (B, heads, N, N)
        # Normalise over the key axis so each query's weights sum to 1.
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # (B, heads, N, head_dim) -> (B, N, heads, head_dim) -> (B, N, C)
        out = (attn @ v).transpose(1, 2).reshape(B, N, C)
        out = self.proj(out)
        return self.proj_drop(out)

class TransformerBlock(nn.Module):
    """Pre-norm transformer block: conv attention and an MLP, each with a residual.

    Args:
        dim: Token embedding size.
        heads: Number of attention heads passed to ``ConvAttention``.
        mlp_ratio: Hidden width of the MLP relative to ``dim``.
        drop_path: Stochastic-depth probability. During training, each residual
            branch is dropped per sample with this probability, and kept samples
            are rescaled by ``1 / (1 - drop_path)``. It is the identity in eval mode.

    Raises:
        ValueError: If ``drop_path`` is not in ``[0, 1)``.
    """

    def __init__(self, dim, heads, mlp_ratio=4.0, drop_path=0.1):
        super().__init__()
        if not 0.0 <= drop_path < 1.0:
            raise ValueError(f"drop_path must be in [0, 1), got {drop_path}")
        self.drop_path = drop_path
        self.norm1 = nn.LayerNorm(dim)
        self.attn = ConvAttention(dim, heads)

        self.norm2 = nn.LayerNorm(dim)
        hidden = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(hidden, dim),
        )

    @staticmethod
    def _drop_path(x, drop_prob, training):
        """Randomly zero whole samples of ``x`` (stochastic depth) and rescale the rest."""
        if drop_prob == 0.0 or not training:
            return x
        keep_prob = 1.0 - drop_prob
        # One Bernoulli draw per sample, broadcast over all remaining dimensions.
        mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
        return x * mask / keep_prob

    def forward(self, x, H, W):
        x = x + self._drop_path(self.attn(self.norm1(x), H, W), self.drop_path, self.training)
        x = x + self._drop_path(self.mlp(self.norm2(x)), self.drop_path, self.training)
        return x

class CvTStage(nn.Module):
    """One CvT stage: a convolutional token embedding followed by transformer blocks.

    Args:
        in_ch: Channels of the incoming feature map.
        out_ch: Token embedding size used by the embedding and all blocks.
        kernel_size: Kernel size of the embedding conv.
        stride: Stride of the embedding conv.
        depth: Number of ``TransformerBlock``s.
        heads: Attention heads per block.

    ``forward`` returns ``(tokens, H, W)`` from the embedding, after the blocks.
    """

    def __init__(self, in_ch, out_ch, kernel_size, stride, depth, heads):
        super().__init__()
        self.embed = ConvEmbed(in_ch, out_ch, kernel_size, stride)
        stack = [TransformerBlock(out_ch, heads) for _ in range(depth)]
        self.blocks = nn.ModuleList(stack)

    def forward(self, x):
        tokens, height, width = self.embed(x)
        # Every block needs the grid size for its depthwise conv.
        for block in self.blocks:
            tokens = block(tokens, height, width)
        return tokens, height, width

class CvTOriginal(nn.Module):
    """Three-stage CvT classifier.

    Stages (embedding kernel / stride, depth, heads, channels):
        1: 5 / 2, 1 block,   1 head,   64 channels
        2: 3 / 2, 2 blocks,  3 heads, 192 channels
        3: 3 / 1, 12 blocks, 6 heads, 384 channels
    For a 64x64 input this yields 32x32, 16x16 and 16x16 token grids. The
    stage-3 tokens are mean-pooled and classified by LayerNorm + Linear.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes=200):
        super().__init__()
        self.num_classes = num_classes
        # TODO: check different kernel sizes and strides.
        self.stage1 = CvTStage(3, 64, kernel_size=5, stride=2, depth=1, heads=1)
        self.stage2 = CvTStage(64, 192, kernel_size=3, stride=2, depth=2, heads=3)
        self.stage3 = CvTStage(192, 384, kernel_size=3, stride=1, depth=12, heads=6)

        self.head = nn.Sequential(
            nn.LayerNorm(384),
            nn.Linear(384, num_classes),
        )

    def forward(self, x):
        """Return ``(B, num_classes)`` logits for an image batch ``x``."""
        feats = x
        # Stages 1 and 2 emit tokens; fold them back to (B, C, H, W) for the next conv embedding.
        for stage in (self.stage1, self.stage2):
            tokens, h, w = stage(feats)
            feats = rearrange(tokens, 'b (h w) c -> b c h w', h=h, w=w)
        tokens, _, _ = self.stage3(feats)
        # Global average pooling over the token axis, then classify.
        return self.head(tokens.mean(dim=1))

## Testing

In [None]:
model = CvTOriginal()

# Smoke-test the forward pass for a regular batch and for a single sample.
for batch_size in (8, 1):
    dummy_input = torch.randn(batch_size, 3, 64, 64)
    output = model(dummy_input)
    expected_shape = (batch_size, 200)
    assert output.shape == expected_shape, f"Expected output shape {expected_shape}, but got {output.shape}"
    print("Model output shape is as expected:", output.shape)


# Dataset

In [None]:
from models.processData import prepare_data_and_get_loaders

# Build train/val/test loaders from the Tiny-ImageNet-200 archive; the second
# argument is presumably the local extraction directory (see models.processData).
train_loader, val_loader, test_loader = prepare_data_and_get_loaders(
    "/datasets/tiny-imagenet-200/tiny-imagenet-200.zip",
    "data/tiny-imagenet-200",
)

### Testing

In [None]:
def imshow(img, mean=0.5, std=0.5):
    """Display a single channel-first ``(C, H, W)`` image tensor with matplotlib.

    The tensor is un-normalised as ``img * std + mean`` before display. The
    defaults reproduce the previous hard-coded ``img / 2 + 0.5``.
    NOTE(review): they assume the loader normalised with mean=std=0.5; confirm
    against ``models.processData`` and pass per-channel sequences otherwise.

    Args:
        img: Image tensor (CPU or GPU, with or without autograd history).
        mean: Scalar or per-channel sequence used for normalisation.
        std: Scalar or per-channel sequence used for normalisation.
    """
    # detach/cpu so tensors on GPU or requiring grad can be converted to numpy.
    npimg = img.detach().cpu().numpy()
    # Reshape to (C, 1, 1) so scalars and per-channel values both broadcast over H, W.
    mean = np.asarray(mean, dtype=np.float64).reshape(-1, 1, 1)
    std = np.asarray(std, dtype=np.float64).reshape(-1, 1, 1)
    npimg = npimg * std + mean
    # pyplot.imshow expects RGB floats in [0, 1]; clip explicitly instead of relying on its warning.
    plt.imshow(np.clip(np.transpose(npimg, (1, 2, 0)), 0.0, 1.0))
    plt.axis('off')
    plt.show()

image, label = train_loader.dataset[0]
imshow(image)

# Training

In [None]:
from models.trainModel import train_test_model

# Hand the model and all three loaders to the project's training routine
# (presumably trains on train/val and evaluates on test; see models.trainModel).
train_test_model(
    model,
    train_loader,
    val_loader,
    test_loader,
)