<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Vision_Transformer_(ViT).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torchvision.models as models

class VisionTransformer(nn.Module):
    """Minimal Vision Transformer (ViT) image classifier.

    The image is cut into non-overlapping ``patch_size x patch_size`` tiles,
    each tile is flattened and linearly projected to ``dim`` features, a
    learnable positional embedding is added, the resulting sequence is run
    through a stack of transformer encoder layers, and the mean of the
    encoded patch tokens is mapped to class logits.

    Args:
        image_size: Side length of the (square) images the model is built
            for; fixes the number of positional embeddings.
        patch_size: Side length of each square patch.
        num_classes: Number of output logits.
        dim: Token embedding width (``d_model`` of the encoder).
        depth: Number of encoder layers.
        heads: Number of attention heads per layer.
        mlp_dim: Hidden width of each layer's feed-forward network.
        channels: Number of input image channels.
    """

    def __init__(
        self,
        image_size,
        patch_size,
        num_classes,
        dim,
        depth=6,
        heads=8,
        mlp_dim=2048,
        channels=3,
    ):
        super().__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.channels = channels
        self.num_patches = (image_size // patch_size) ** 2
        patch_dim = patch_size * patch_size * channels

        self.linear_proj = nn.Linear(patch_dim, dim)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=dim, nhead=heads, dim_feedforward=mlp_dim
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=depth)

        self.mlp_head = nn.Linear(dim, num_classes)

        # Without positional information the encoder followed by mean pooling
        # is invariant to patch order, i.e. blind to spatial layout.
        self.pos_embedding = nn.Parameter(
            torch.randn(1, self.num_patches, dim) * 0.02
        )

    def _to_patches(self, x):
        """Split ``(B, C, H, W)`` images into ``(B, N, P*P*C)`` flat patches.

        Patches are ordered row-major over the image grid; inside a patch
        the values are laid out as (row, column, channel).

        Raises:
            ValueError: If ``x`` is not 4-D, has the wrong channel count, or
                its height/width is not a multiple of ``patch_size``.
        """
        if x.dim() != 4:
            raise ValueError(
                f"expected a 4-D (batch, channels, height, width) tensor, "
                f"got {x.dim()} dimension(s)"
            )
        batch_size, channels, height, width = x.shape
        patch = self.patch_size
        if channels != self.channels:
            raise ValueError(
                f"expected {self.channels} channel(s), got {channels}"
            )
        if height % patch != 0 or width % patch != 0:
            raise ValueError(
                f"image size {height}x{width} is not divisible by "
                f"patch size {patch}"
            )
        rows, cols = height // patch, width // patch
        # A plain view() of the NCHW buffer would slice contiguous memory
        # (runs of pixels from one channel), not spatial tiles; split both
        # spatial axes and move the grid axes in front instead.
        x = x.reshape(batch_size, channels, rows, patch, cols, patch)
        x = x.permute(0, 2, 4, 3, 5, 1)  # (B, rows, cols, patch, patch, C)
        return x.reshape(batch_size, rows * cols, patch * patch * channels)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``.

        Args:
            x: Image batch of shape ``(batch, channels, height, width)``.

        Raises:
            ValueError: If the input shape is invalid (see ``_to_patches``)
                or yields a different number of patches than the model's
                positional embedding was built for.
        """
        x = self._to_patches(x)
        if x.size(1) != self.num_patches:
            raise ValueError(
                f"input yields {x.size(1)} patches but the model was built "
                f"for {self.num_patches}"
            )
        x = self.linear_proj(x) + self.pos_embedding
        x = x.permute(1, 0, 2)  # encoder default layout: (seq_len, batch, dim)
        x = self.transformer(x)
        x = x.mean(dim=0)  # average over the patch tokens
        return self.mlp_head(x)

# Demo: instantiate the model and push one random batch through it.
vit = VisionTransformer(
    image_size=224,
    patch_size=16,
    num_classes=1000,
    dim=768,
)
input_data = torch.randn((32, 3, 224, 224))  # 32 random 3x224x224 inputs
output = vit(input_data)

# One row of 1000 class logits per input: torch.Size([32, 1000])
print(output.size())