<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Transformers_for_Vision_(Vision_Transformers_ViT).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ViT(nn.Module):
    """Vision Transformer (ViT) image classifier.

    The image is cut into non-overlapping ``patch_size`` x ``patch_size``
    patches by a strided convolution, each patch is projected to
    ``embed_dim`` features, a learnable [CLS] token is prepended, learned
    position embeddings are added, and the token sequence is processed by a
    stack of ``nn.TransformerEncoderLayer`` blocks. The final [CLS] token is
    normalized and linearly mapped to ``num_classes`` logits.

    Args:
        img_size: Side length of the (square) input images the position
            embedding is sized for.
        patch_size: Side length of each square patch; also the kernel size
            and stride of the patch projection.
        in_channels: Number of channels in the input images.
        num_classes: Number of output logits.
        embed_dim: Token embedding width (``d_model`` of the encoder).
        num_heads: Number of attention heads per encoder layer.
        num_layers: Number of stacked encoder layers.
        dropout: Dropout probability applied to the embedded tokens and
            inside every encoder layer. The default (0.1) matches the value
            that used to be hard-coded.
    """

    def __init__(
        self,
        img_size: int = 224,
        patch_size: int = 16,
        in_channels: int = 3,
        num_classes: int = 1000,
        embed_dim: int = 768,
        num_heads: int = 12,
        num_layers: int = 12,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        # Patches per image; the conv below yields the same count per side
        # (img_size // patch_size) because kernel size equals stride.
        self.num_patches = (img_size // patch_size) ** 2

        # Patch embedding: one conv step per patch, embed_dim output channels.
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # One learned position vector per patch plus one for the [CLS] token.
        self.pos_embedding = nn.Parameter(torch.zeros(1, self.num_patches + 1, embed_dim))

        # batch_first=True is essential: forward() feeds (batch, tokens, dim)
        # tensors. With PyTorch's default (batch_first=False) the encoder
        # would read the batch axis as the sequence axis and attend across
        # different images instead of across the patches of one image.
        #
        # NOTE: nn.TransformerEncoder deep-copies the layer it is given, so
        # this attribute is only a template whose own weights are not used in
        # forward(). It is kept as an attribute so the ``encoder_layer.*``
        # keys of existing checkpoints continue to load.
        self.encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dropout=dropout,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers)
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, num_classes)
        )
        self.dropout = nn.Dropout(p=dropout)

        # Initialize embeddings
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.pos_embedding, std=0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Classify a batch of images.

        Args:
            x: Image batch of shape ``(batch, in_channels, H, W)``. ``H`` and
                ``W`` must produce ``num_patches`` patches, otherwise adding
                the position embedding fails with a shape error.

        Returns:
            Logits of shape ``(batch, num_classes)``.
        """
        B = x.shape[0]
        # (B, C, H, W) -> (B, embed_dim, H/ps, W/ps) -> (B, num_patches, embed_dim)
        x = self.proj(x).flatten(2).transpose(1, 2)
        # expand() broadcasts the single learned token over the batch without copying.
        cls_tokens = self.cls_token.expand(B, -1, -1)  # (B, 1, embed_dim)
        x = torch.cat((cls_tokens, x), dim=1)  # (B, num_patches + 1, embed_dim)
        x = self.dropout(x + self.pos_embedding)  # Add position embeddings
        x = self.transformer(x)  # Batch-first Transformer encoder
        x = self.mlp_head(x[:, 0])  # Classification head on [CLS] token
        return x

# Smoke test: instantiate the default ViT and run one random image through it.
model = ViT()
dummy_input = torch.randn((1, 3, 224, 224))  # one fake 3-channel 224x224 image
output = model(dummy_input)
print(output.shape)  # expected: torch.Size([1, 1000]) -> (batch, num_classes)