<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Transformer_Variants.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
from einops.layers.torch import Rearrange

# Define the SimpleViT class
class SimpleViT(nn.Module):
    """Minimal Vision Transformer (ViT) image classifier.

    The input image is cut into non-overlapping ``patch_size x patch_size``
    patches. Each patch is flattened and linearly projected to ``dim``. A
    learnable class token is prepended and learnable positional embeddings are
    added. The token sequence is then processed by ``depth`` transformer
    encoder layers. The final class-token representation goes through a
    LayerNorm + Linear classification head.

    Args:
        image_size: Height and width of the (square) input image in pixels.
        patch_size: Side length of each square patch; must divide ``image_size``.
        num_classes: Number of output logits.
        dim: Token embedding width (``d_model`` of the encoder). PyTorch's
            multi-head attention requires it to be divisible by ``heads``.
        depth: Number of transformer encoder layers.
        heads: Number of attention heads per encoder layer.
        mlp_dim: Hidden width of each encoder layer's feed-forward block.
        channels: Number of input image channels (default 3, e.g. RGB).

    Raises:
        AssertionError: If ``image_size`` is not divisible by ``patch_size``.
    """

    def __init__(
        self,
        image_size,
        patch_size,
        num_classes,
        dim,
        depth,
        heads,
        mlp_dim,
        channels=3,
    ):
        super().__init__()
        assert image_size % patch_size == 0, "Image size must be divisible by patch size."
        self.patch_size = patch_size
        self.num_patches = (image_size // patch_size) ** 2
        self.dim = dim

        self.to_patch_embedding = nn.Sequential(
            # (B, C, H, W) -> (B, num_patches, patch_size * patch_size * C)
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_size, p2=patch_size),
            # The flattened patch length depends on the channel count, so it
            # is derived from `channels` rather than assuming 3.
            nn.Linear(patch_size * patch_size * channels, dim),
        )

        # Learnable class token, plus one positional embedding per patch and
        # one for the class token.
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.pos_embedding = nn.Parameter(torch.randn(1, self.num_patches + 1, dim))

        # batch_first=True: tokens are laid out as (batch, seq, dim). With the
        # PyTorch default (batch_first=False) the encoder would interpret the
        # batch axis as the sequence axis and attend across different images.
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(dim, heads, mlp_dim, batch_first=True),
            num_layers=depth,
        )

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes),  # Classification head
        )

    def forward(self, img):
        """Classify a batch of images.

        Args:
            img: Tensor of shape ``(B, channels, image_size, image_size)``.

        Returns:
            Logits tensor of shape ``(B, num_classes)``.
        """
        x = self.to_patch_embedding(img)  # (B, num_patches, dim)
        b, n, _ = x.shape

        cls_tokens = self.cls_token.expand(b, -1, -1)  # (B, 1, dim), no copy
        x = torch.cat((cls_tokens, x), dim=1)  # (B, n + 1, dim)
        # Out-of-place add; the embedding broadcasts over the batch axis.
        x = x + self.pos_embedding[:, : n + 1]

        x = self.transformer(x)
        # Position 0 holds the class token, which summarizes the image.
        return self.mlp_head(x[:, 0])

# Example usage: push a batch of 64 random 3x32x32 images through a small ViT.
model = SimpleViT(
    image_size=32,
    patch_size=8,
    num_classes=10,
    dim=128,
    depth=6,
    heads=8,
    mlp_dim=256,
)
img = torch.randn(64, 3, 32, 32)  # (batch, channels, height, width)
output = model(img)

# 32 / 8 = 4 patches per side -> 16 patch tokens + 1 class token; one row of
# 10 logits per image.
print(output.shape)  # Expected: torch.Size([64, 10])