In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm  # For progress bars

# Use GPU if available, else CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using: {device}")

Using: cpu


In [5]:
class PatchEmbedding(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.n_patches = (img_size // patch_size) ** 2  # Number of patches

        # Use Conv2d to split image into patches
        self.proj = nn.Conv2d(
            in_channels=in_channels,       # Input channels (3 for RGB)
            out_channels=embed_dim,        # Output dimension (D)
            kernel_size=patch_size,        # Patch size (e.g., 16x16)
            stride=patch_size,             # Non-overlapping patches
        )

    def forward(self, x):
        # Input: (B, C, H, W) → (Batch, Channels, Height, Width)
        x = self.proj(x)  # (B, D, H/P, W/P) → (B, 768, 14, 14) if img_size=224, patch_size=16
        x = x.flatten(2)  # Flatten patches into sequence → (B, D, N_patches)
        x = x.transpose(1, 2)  # (B, N_patches, D) → Now it's like a sequence!
        return x

In [2]:
# Define transformations
transform = transforms.Compose([
    transforms.Resize(224),                 # Resize to 224x224
    transforms.ToTensor(),                  # Convert to tensor
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Normalize to [-1, 1]
])

# Load CIFAR-10 datasets
train_data = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_data = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# Create dataloaders (batches of 64 images)
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(test_data, batch_size=64, shuffle=False)

100%|██████████| 170M/170M [00:10<00:00, 15.9MB/s]


In [3]:
class PositionalEncoding(nn.Module):
    def __init__(self, n_patches, embed_dim):
        super().__init__()
        # Learnable parameter: (1, N+1, D) (+1 for [CLS] token)
        self.pos_embed = nn.Parameter(torch.randn(1, n_patches + 1, embed_dim))

    def forward(self, x):
        return x + self.pos_embed  # Add position info to patches

In [4]:
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim=768, n_heads=12):
        super().__init__()
        self.embed_dim = embed_dim
        self.n_heads = n_heads
        self.head_dim = embed_dim // n_heads  # 768 / 12 = 64

        # Linear layers to compute Q, K, V in one go
        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.proj = nn.Linear(embed_dim, embed_dim)  # Final projection

    def forward(self, x):
        B, N, D = x.shape
        # Compute Q, K, V (each B, N, D) → split into heads
        qkv = self.qkv(x).reshape(B, N, 3, self.n_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # Each (B, n_heads, N, head_dim)

        # Attention scores (B, n_heads, N, N)
        attn = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attn = attn.softmax(dim=-1)

        # Weighted sum of values
        x = (attn @ v).transpose(1, 2).reshape(B, N, D)
        x = self.proj(x)  # Final linear layer
        return x

In [6]:
class TransformerBlock(nn.Module):
    def __init__(self, embed_dim=768, n_heads=12, mlp_ratio=4.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadAttention(embed_dim, n_heads)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, int(embed_dim * mlp_ratio)),  # Expand
            nn.GELU(),                                         # Activation
            nn.Linear(int(embed_dim * mlp_ratio), embed_dim),  # Compress
        )

    def forward(self, x):
        # 1. LayerNorm → MSA → Residual
        x = x + self.attn(self.norm1(x))
        # 2. LayerNorm → MLP → Residual
        x = x + self.mlp(self.norm2(x))
        return x

In [7]:
class VisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3, n_classes=1000,
                 embed_dim=768, depth=12, n_heads=12):
        super().__init__()
        # 1. Split image into patches
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)

        # 2. [CLS] token (for classification)
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))

        # 3. Positional embeddings
        self.pos_embed = PositionalEncoding(self.patch_embed.n_patches, embed_dim)

        # 4. Stack of Transformer blocks
        self.blocks = nn.Sequential(*[
            TransformerBlock(embed_dim, n_heads)
            for _ in range(depth)
        ])

        # 5. Final normalization & classification head
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, n_classes)

    def forward(self, x):
        B = x.shape[0]  # Batch size

        # 1. Patch embedding (B, N, D)
        x = self.patch_embed(x)

        # 2. Add [CLS] token (B, 1, D)
        cls_token = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls_token, x], dim=1)  # (B, N+1, D)

        # 3. Add positional embeddings
        x = self.pos_embed(x)

        # 4. Transformer blocks
        x = self.blocks(x)

        # 5. Take [CLS] token and classify
        x = self.norm(x[:, 0])  # Only [CLS] token (B, D)
        x = self.head(x)         # (B, n_classes)
        return x

In [9]:
if __name__ == "__main__":
    vit = VisionTransformer(img_size=224, patch_size=16, n_classes=10)
    img = torch.randn(1, 3, 224, 224)  # Fake image
    out = vit(img)
    print(out.shape)  # (1, 10) → 10-class prediction

torch.Size([1, 10])


In [8]:
model = VisionTransformer(
    img_size=224,       # Resized CIFAR-10 images
    patch_size=16,       # Smaller patches for smaller images
    in_channels=3,       # RGB
    n_classes=10,        # CIFAR-10 has 10 classes
    embed_dim=768,       # From original ViT
    depth=6,             # Fewer layers for faster training (original: 12)
    n_heads=12,          # As in original ViT
).to(device)

In [10]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)  # Lower LR for stability

In [11]:
def train(model, dataloader, criterion, optimizer, epochs=5):
    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")

        for images, labels in progress_bar:
            images, labels = images.to(device), labels.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Update progress
            running_loss += loss.item()
            progress_bar.set_postfix(loss=running_loss / len(dataloader))

        print(f"Epoch {epoch+1}, Loss: {running_loss / len(dataloader):.4f}")

In [None]:
# Train for 5 epochs
train(model, train_loader, criterion, optimizer, epochs=1)

Epoch 1/1:   5%|▍         | 37/782 [40:48<13:54:58, 67.25s/it, loss=0.128]

In [None]:
def test(model, dataloader):
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in tqdm(dataloader, desc="Testing"):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f"Test Accuracy: {accuracy:.2f}%")

In [None]:
# Evaluate
test(model, test_loader)