<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Transformers_for_Vision_(Vision_Transformers_ViT).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn

class ViT(nn.Module):
    """Minimal Vision Transformer classifier that mean-pools patch tokens.

    The input image is cut into non-overlapping ``patch_size x patch_size``
    patches. Each patch is flattened and linearly projected to ``d_model``, and
    a learned positional embedding is added. The token sequence is processed by
    a stack of Transformer encoder layers, averaged over tokens, and fed to a
    linear classification head.

    Args:
        img_size: Height and width of the (square) input images, in pixels.
        patch_size: Side length of each square patch; must divide ``img_size``.
        num_classes: Number of output logits.
        d_model: Token embedding dimension.
        num_heads: Attention heads per encoder layer (PyTorch requires that it
            divides ``d_model``).
        num_layers: Number of encoder layers.
        in_channels: Number of input image channels (default 3).
        dim_feedforward: Hidden size of each encoder layer's feed-forward MLP.
        dropout: Dropout probability inside the encoder layers.

    Raises:
        ValueError: If ``img_size`` is not divisible by ``patch_size``.
    """

    def __init__(self, img_size, patch_size, num_classes, d_model, num_heads, num_layers,
                 in_channels=3, dim_feedforward=2048, dropout=0.1):
        super().__init__()
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if img_size % patch_size != 0:
            raise ValueError("Image size must be divisible by patch size")
        self.img_size = img_size
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.d_model = d_model
        self.num_patches = (img_size // patch_size) ** 2

        # Flattened patch (in_channels * patch_size**2 values) -> d_model.
        self.embedding = nn.Linear(in_channels * patch_size**2, d_model)
        # One learned position vector per patch; the leading 1 broadcasts over the batch.
        self.position_embeddings = nn.Parameter(torch.randn(1, self.num_patches, d_model))
        # Encoder-only stack: a ViT has no decoder. (nn.Transformer would treat the
        # third positional argument as num_encoder_layers and add 6 default
        # decoder layers on top.) The final LayerNorm matches the encoder norm
        # that nn.Transformer includes.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model,
            num_heads,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(
            encoder_layer, num_layers, norm=nn.LayerNorm(d_model)
        )
        self.classifier = nn.Linear(d_model, num_classes)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``.

        Args:
            x: Image batch of shape ``(batch, in_channels, img_size, img_size)``.

        Raises:
            ValueError: If ``x`` does not have that shape (apart from batch size).
        """
        # Without this check, an image whose H/W differ from img_size but yield the
        # same patch count would be reshaped silently into meaningless tokens.
        expected = (self.in_channels, self.img_size, self.img_size)
        if x.dim() != 4 or tuple(x.shape[1:]) != expected:
            raise ValueError(
                f"Expected input of shape (batch, {self.in_channels}, {self.img_size}, "
                f"{self.img_size}), got {tuple(x.shape)}"
            )

        p = self.patch_size
        # (B, C, H, W) -> (B, C, H/p, W/p, p, p): non-overlapping p x p windows.
        x = x.unfold(2, p, p).unfold(3, p, p)
        # -> (B, H/p, W/p, C, p, p) so each patch's values are contiguous in memory.
        x = x.permute(0, 2, 3, 1, 4, 5).contiguous()
        # -> (B, num_patches, C * p * p), with patches in row-major grid order.
        x = x.view(x.size(0), self.num_patches, -1)

        # Project patches and add positional embeddings.
        x = self.embedding(x) + self.position_embeddings  # (B, num_patches, d_model)

        # Self-attention over patch tokens.
        x = self.transformer(x)  # (B, num_patches, d_model)

        # No [CLS] token: the image representation is the mean of the patch tokens.
        x = x.mean(dim=1)  # (B, d_model)
        return self.classifier(x)

if __name__ == "__main__":
    # Demo: 32x32 three-channel inputs cut into 8x8 patches -> 16 tokens per image.
    model = ViT(img_size=32, patch_size=8, num_classes=10, d_model=64, num_heads=4, num_layers=6)
    # Disable dropout so repeated forward passes on the same input give the same logits.
    model.eval()
    img = torch.randn(1, 3, 32, 32)  # Example image batch: (batch, channels, height, width)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        logits = model(img)
    print("Logits:", logits)