- Full Name: **[Full Name]**
- Student ID: **[Student ID]**

# 🧠 Homework 4: Vision Transformer (ViT)

## 📌 Objective
In this assignment, you will **implement a Vision Transformer (ViT) from scratch** and gain a deep understanding of its core components. You will:

- 🛠️ Complete all core blocks of the Vision Transformer model.
- 🧪 Train the ViT on the CIFAR-10 image classification dataset.
- 👁️ Visualize attention maps to interpret model behavior.
- 📊 Analyze the effect of hyperparameters (e.g., patch size, depth, heads) on model performance.

---

## 📚 Learning Goals
By the end of this assignment, you should be able to:
- Explain the role of patch embeddings, self-attention, and MLP layers in ViT.
- Implement Transformer blocks without relying on high-level libraries like Hugging Face Transformers.
- Visualize self-attention maps to understand how ViT focuses on different parts of an image.
- Experiment with architectural design choices and evaluate their effects.

---

## 🧪 Dataset
We will use the **CIFAR-10** dataset, which consists of 60,000 32×32 color images in 10 classes, with 6,000 images per class. The dataset is split into 50,000 training and 10,000 test images.
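Since every CIFAR-10 image is 32×32, the patch size alone fixes how many tokens the Transformer processes. The short sketch below (the helper name `num_tokens` is illustrative, not part of the template) counts the patches, the tokens including the CLS token, and the raw values per flattened patch (`channels * patch_height * patch_width`, as in the ViT cell) for a few square patch sizes. Each head's attention matrix has roughly tokens² entries, so smaller patches become expensive quickly.

```python
def num_tokens(image_size: int, patch_size: int, channels: int = 3):
    """Return (patches, tokens incl. CLS, values per flattened patch) for a square image."""
    assert image_size % patch_size == 0, "patch size must divide the image size"
    patches = (image_size // patch_size) ** 2
    return patches, patches + 1, channels * patch_size ** 2


for p in (2, 4, 8, 16):
    patches, tokens, patch_dim = num_tokens(32, p)
    print(f"patch {p:2d}: {patches:3d} patches, {tokens:3d} tokens, "
          f"patch_dim {patch_dim:3d}, ~{tokens ** 2} attention scores per head")
```

With the default `patch_size=4`, a CIFAR-10 image becomes 64 patches of 48 values each, plus the CLS token.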

---

## 🎯 Evaluation Criteria

Your final submission will be evaluated based on:
- ✅ Correct implementation of all model components.
- ✅ Accuracy of the model on the CIFAR-10 test set.
- ✅ Insightfulness of attention visualizations.
- ✅ Clarity of code and documentation.
- ✅ Quality of hyperparameter analysis and discussion.





In [None]:
# ===============================
# 1.1 Setup & Imports
# ===============================
!pip install einops torch torchvision matplotlib seaborn --quiet

In [None]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as T
import matplotlib.pyplot as plt
import seaborn as sns
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using', device)


In [None]:

# ===============================
# 1.2 Data Loading & Visualization
# ===============================
transform = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2470, 0.2435, 0.2616])
])

## TODO: Load the CIFAR-10 training and test splits (apply `transform` to both)
train_ds = ...
test_ds = ...
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_ds, batch_size=64, shuffle=False)


In [None]:

## TODO: Visualize 8 training images with their class names (undo the normalization before plotting)
classes = train_ds.classes
images, labels = next(iter(train_loader))
fig, axes = plt.subplots(1, 8, figsize=(15, 2))
for i in range(8):
    ...
plt.show()

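The loader returns normalized tensors in channel-first `(C, H, W)` layout, while `plt.imshow` expects `(H, W, C)` values in `[0, 1]`. A minimal sketch of the conversion, assuming the same mean and std as the `transform` above (`to_displayable` is a hypothetical helper name):

```python
import torch

CIFAR10_MEAN = torch.tensor([0.4914, 0.4822, 0.4465])
CIFAR10_STD = torch.tensor([0.2470, 0.2435, 0.2616])


def to_displayable(img: torch.Tensor) -> torch.Tensor:
    """Undo T.Normalize on a (C, H, W) tensor; return (H, W, C) values clipped to [0, 1]."""
    img = img * CIFAR10_STD[:, None, None] + CIFAR10_MEAN[:, None, None]
    return img.clamp(0, 1).permute(1, 2, 0)


# Round trip: normalizing and then un-normalizing recovers the original pixels.
original = torch.rand(3, 32, 32)
normalized = (original - CIFAR10_MEAN[:, None, None]) / CIFAR10_STD[:, None, None]
restored = to_displayable(normalized)
print(restored.shape)  # torch.Size([32, 32, 3])
```

Inside the plotting loop, each panel could then use `axes[i].imshow(to_displayable(images[i]))`, `axes[i].set_title(classes[labels[i].item()])` and `axes[i].axis('off')`.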

In [None]:
# ===============================
# 1.3 Helper: pair
# ===============================
def pair(t):
    return t if isinstance(t, tuple) else (t, t)


In [None]:
# ===============================
# 1.4 FeedForward
# ===============================
## TODO: Complete the FeedForward module: LayerNorm -> Linear -> GELU -> Dropout -> Linear -> Dropout
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            ...
        )
    def forward(self, x):
        return self.net(x)

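A minimal sketch of the layer order named in the TODO, assuming `hidden_dim` is the expanded MLP width (`mlp_dim` in the ViT constructor). The class is named `FeedForwardSketch` so it does not overwrite the cell above. Because every layer acts on the last dimension only, the block transforms each token independently, which the check at the end illustrates.

```python
import torch
import torch.nn as nn


class FeedForwardSketch(nn.Module):
    """Pre-norm MLP block: LayerNorm -> Linear -> GELU -> Dropout -> Linear -> Dropout."""

    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),  # expand to the hidden (MLP) width
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),  # project back to the embedding width
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)


ff = FeedForwardSketch(dim=128, hidden_dim=256, dropout=0.1).eval()  # eval(): dropout off
tokens = torch.randn(2, 65, 128)  # (batch, tokens, dim)
out = ff(tokens)
print(out.shape)  # torch.Size([2, 65, 128])

# Changing token 0 leaves the outputs for all other tokens untouched.
perturbed = tokens.clone()
perturbed[:, 0] += 1.0
print(torch.allclose(ff(perturbed)[:, 1:], out[:, 1:]))  # True
```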

In [None]:

# ===============================
# 1.5 Attention
# ===============================
## TODO: Complete the Attention module
class Attention(nn.Module):
    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = heads * dim_head
        self.heads = heads
        self.scale = dim_head ** -0.5
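        ## LayerNorm over the embedding dimension (applied before attention)
        self.norm = ...
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        ## softmax over the last (key) dimension
        self.attend = ...
        self.dropout = nn.Dropout(dropout)
        ## output projection from inner_dim back to dim, followed by dropout
        self.to_out = ...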

    def forward(self, x):
        x = self.norm(x)
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        q = rearrange(q, 'b n (h d) -> b h n d', h=self.heads)
        k = rearrange(k, 'b n (h d) -> b h n d', h=self.heads)
        v = rearrange(v, 'b n (h d) -> b h n d', h=self.heads)

        ## attention scores: q @ k^T, scaled by self.scale
        dots = ...
        ## softmax over the key dimension
        attn = ...
        ## dropout on the attention weights
        attn = ...
        ## weighted sum of the values: attn @ v
        out = ...
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

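The core of the module is scaled dot-product attention computed independently for each head. A functional sketch in plain PyTorch, under the assumption that `w_qkv` plays the role of `self.to_qkv.weight.T` (the function name is hypothetical, and LayerNorm, dropout and the output projection are left out). The `reshape`/`transpose` pairs reproduce the `rearrange` patterns used in the cell above.

```python
import torch


def multi_head_attention(x, w_qkv, heads):
    """Multi-head self-attention core.

    x: (b, n, dim); w_qkv: (dim, 3 * heads * dim_head). Returns (out, attn).
    """
    b, n, _ = x.shape
    q, k, v = (x @ w_qkv).chunk(3, dim=-1)  # each (b, n, heads * dim_head)
    # 'b n (h d) -> b h n d'
    q, k, v = (t.reshape(b, n, heads, -1).transpose(1, 2) for t in (q, k, v))
    scale = q.shape[-1] ** -0.5  # 1 / sqrt(dim_head)
    dots = torch.matmul(q, k.transpose(-1, -2)) * scale  # (b, h, n, n)
    attn = dots.softmax(dim=-1)  # each query's weights sum to 1
    out = torch.matmul(attn, v)  # (b, h, n, dim_head)
    out = out.transpose(1, 2).reshape(b, n, -1)  # 'b h n d -> b n (h d)'
    return out, attn


torch.manual_seed(0)
x = torch.randn(2, 65, 128)
w_qkv = torch.randn(128, 3 * 8 * 64) * 128 ** -0.5
out, attn = multi_head_attention(x, w_qkv, heads=8)
print(out.shape, attn.shape)  # torch.Size([2, 65, 512]) torch.Size([2, 8, 65, 65])
```

On PyTorch 2.x the result can be cross-checked against `torch.nn.functional.scaled_dot_product_attention`, whose default scale is also `1 / sqrt(dim_head)`.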

In [None]:

# ===============================
# 1.6 Transformer Block
# ===============================
## TODO: Complete the Transformer module
class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
        super().__init__()
        self.layers = nn.ModuleList([
            ## Attention + FeedForward
            nn.ModuleList([
                ..., 
                ...
            ]) for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        for attn, ff in self.layers:
            ## TODO: standard residual forward pass: add the output of attn,
            ## then the output of ff, back onto x (each sub-block applies its own LayerNorm)
            ...
        return self.norm(x)

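Because `Attention` applies `self.norm` before projecting to q, k and v, and the FeedForward recipe starts with a LayerNorm, both sub-blocks are pre-norm. The usual forward pass then adds each sub-block's output back onto its input and applies one final LayerNorm. A minimal sketch with a hypothetical helper, `pre_norm_encoder_forward`, written against any pair of callables so it can be checked without the full model:

```python
import torch
import torch.nn as nn


def pre_norm_encoder_forward(x, layers, final_norm):
    """Residual forward pass over (attn, ff) pairs that normalize their own inputs."""
    for attn, ff in layers:
        x = attn(x) + x  # residual connection around self-attention
        x = ff(x) + x  # residual connection around the MLP
    return final_norm(x)


# Sanity check: if every sub-block outputs zeros, the residual path passes x
# through unchanged and only the final LayerNorm is applied.
def zero_block(t):
    return torch.zeros_like(t)


x = torch.randn(2, 65, 128)
final_norm = nn.LayerNorm(128)
y = pre_norm_encoder_forward(x, [(zero_block, zero_block)] * 6, final_norm)
print(torch.allclose(y, final_norm(x)))  # True
```

The residual form lets gradients reach early layers directly, which is one reason deep pre-norm stacks train stably.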

In [None]:

# ===============================
# 1.7 ViT Model
# ===============================
class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes,
                 dim, depth, heads, mlp_dim,
                 pool='cls', channels=3,
                 dim_head=64, dropout=0., emb_dropout=0.):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)
        assert image_height % patch_height == 0 and image_width % patch_width == 0, \
            'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)',
                      p1=patch_height, p2=patch_width),
            nn.Linear(patch_dim, dim)
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        assert pool in {'cls', 'mean'}, "pool must be 'cls' (CLS token) or 'mean' (mean pooling)"
        self.pool = pool

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    ## TODO: Complete the forward function
    def forward(self, img):
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b=b)
        x = torch.cat((cls_tokens, x), dim=1)
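        ## add the positional embedding for the n + 1 positions (CLS token + n patches)
        x += ...
        x = self.dropout(x)

        ## run the Transformer encoder
        x = ...

        ## pool: the CLS token's output if pool == 'cls', otherwise the mean over all tokens
        x = ... if self.pool == 'cls' else ...
        ## map the pooled vector to class logits
        return ...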

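To see what `to_patch_embedding` and the forward pass do to tensor shapes, here is a plain-PyTorch sketch (no einops) of the same steps. `patchify` and `pool_tokens` are hypothetical helper names; `patchify` is meant to reproduce `Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)')`, and the zero CLS token and random positional embedding are placeholders for shape checking only.

```python
import torch


def patchify(img: torch.Tensor, p: int) -> torch.Tensor:
    """Split (b, c, H, W) into (b, (H/p)*(W/p), p*p*c) flattened patches."""
    b, c, H, W = img.shape
    h, w = H // p, W // p
    x = img.reshape(b, c, h, p, w, p)  # split each spatial axis into (grid, within-patch)
    x = x.permute(0, 2, 4, 3, 5, 1)  # (b, h, w, p1, p2, c)
    return x.reshape(b, h * w, p * p * c)


def pool_tokens(x: torch.Tensor, pool: str = "cls") -> torch.Tensor:
    """Reduce (b, n + 1, dim) encoder output to (b, dim)."""
    return x[:, 0] if pool == "cls" else x.mean(dim=1)


img = torch.randn(2, 3, 32, 32)
patches = patchify(img, 4)
print(patches.shape)  # torch.Size([2, 64, 48])

dim = 128
cls_token = torch.zeros(1, 1, dim)
pos_embedding = torch.randn(1, patches.shape[1] + 1, dim)
tokens = torch.nn.Linear(48, dim)(patches)  # linear patch embedding
tokens = torch.cat((cls_token.expand(2, -1, -1), tokens), dim=1)  # prepend CLS
tokens = tokens + pos_embedding[:, : tokens.shape[1]]
print(tokens.shape, pool_tokens(tokens).shape)  # torch.Size([2, 65, 128]) torch.Size([2, 128])
```

In the patch ordering, the grid column `w` varies fastest, so `patches[:, 1]` is the top-row patch covering columns 4 to 7.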

In [None]:

# ===============================
# 1.8 Training Loop & Visualization
# ===============================
## TODO: Complete the training loop
def train(model, epochs=10, lr=3e-4):
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()

    model.to(device)
    losses = []
    for epoch in range(epochs):
        model.train()
        total_loss = 0
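        for imgs, labels in train_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            ## TODO: forward pass, loss, backward pass and optimizer step;
            ## add the batch loss (as a Python float) to total_loss
            ...
        avg_loss = total_loss / len(train_loader)
        losses.append(avg_loss)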
        print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")

    plt.plot(losses)
    plt.title("Training Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.show()

    # Evaluate
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for imgs, labels in test_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            preds = model(imgs).argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    print(f"Test Accuracy: {100 * correct / total:.2f}%")

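A minimal sketch of the per-batch work the training loop leaves open, assuming the standard PyTorch pattern of `zero_grad`, forward, loss, `backward`, `step`. The function name `train_step` and the toy linear model are illustrative stand-ins, not part of the assignment; the toy run only checks that repeated steps on one fixed batch drive the loss down.

```python
import torch
import torch.nn as nn


def train_step(model, imgs, labels, opt, loss_fn):
    """Run one optimization step and return the batch loss as a Python float."""
    opt.zero_grad()  # clear gradients from the previous step
    logits = model(imgs)  # (batch, num_classes)
    loss = loss_fn(logits, labels)
    loss.backward()
    opt.step()
    return loss.item()


# Toy stand-in model and a fixed random batch, just to exercise the step.
torch.manual_seed(0)
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))
toy_opt = torch.optim.Adam(toy_model.parameters(), lr=1e-3)
toy_loss_fn = nn.CrossEntropyLoss()
toy_imgs, toy_labels = torch.randn(64, 3, 32, 32), torch.randint(0, 10, (64,))
toy_losses = [train_step(toy_model, toy_imgs, toy_labels, toy_opt, toy_loss_fn) for _ in range(20)]
print(f"{toy_losses[0]:.3f} -> {toy_losses[-1]:.3f}")
```

Inside `train`, the loop body could then reduce to `total_loss += train_step(model, imgs, labels, opt, loss_fn)` after moving the batch to `device`.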


In [None]:

# ===============================
# Run and enjoy!
# ===============================
model = ViT(
    image_size=32, patch_size=4, num_classes=10,
    dim=128, depth=6, heads=8, mlp_dim=256,
    pool='cls', channels=3, dim_head=64,
    dropout=0.1, emb_dropout=0.1
)

train(model, epochs=20)

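The template does not yet expose attention weights, so visualizing them needs a small change. One option, assumed here: inside `Attention.forward`, keep a detached copy of the post-softmax weights on the module (for example `self.last_attn = attn.detach()`, where `last_attn` is a hypothetical attribute name), run one batch through the trained model in `eval()` mode, and read the stored tensor back from each block. The hypothetical helper below then turns the CLS token's row of those weights into an image-shaped map over the patch grid (8×8 for `patch_size=4`). Random softmax weights stand in for real ones.

```python
import torch


def cls_attention_map(attn: torch.Tensor, average_heads: bool = True) -> torch.Tensor:
    """Turn attention weights (b, heads, 1 + g*g, 1 + g*g) into CLS-to-patch maps.

    Returns (b, g, g) if average_heads else (b, heads, g, g).
    """
    cls_to_patches = attn[:, :, 0, 1:]  # CLS query row, patch keys only
    g = int(round(cls_to_patches.shape[-1] ** 0.5))  # patches form a g x g grid
    maps = cls_to_patches.reshape(*cls_to_patches.shape[:2], g, g)
    return maps.mean(dim=1) if average_heads else maps


attn = torch.randn(2, 8, 65, 65).softmax(dim=-1)  # stand-in for real attention weights
print(cls_attention_map(attn).shape)  # torch.Size([2, 8, 8])
print(cls_attention_map(attn, average_heads=False).shape)  # torch.Size([2, 8, 8, 8])
```

Each map can be drawn with `plt.imshow(maps[0].cpu())` next to the input image; the per-head maps (`average_heads=False`) from several layers are a natural starting point for the per-head bonus.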

---

## 🧠 Bonus (Optional)
- 🧩 Implement a variant of ViT using sinusoidal positional embeddings.
- 🎨 Visualize per-head attention maps across different layers.
- 🔁 Try training with different patch sizes and analyze effects.
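For the sinusoidal-embedding bonus, one possible starting point is the fixed 1D sine/cosine table from the original Transformer, applied here to the flattened patch order with the CLS token at position 0. This is a sketch under that assumption (2D variants that encode row and column separately also exist and are not shown); `sinusoidal_embedding` is a hypothetical helper name.

```python
import math

import torch


def sinusoidal_embedding(n_positions: int, dim: int, base: float = 10000.0) -> torch.Tensor:
    """Fixed table: pe[pos, 2i] = sin(pos / base**(2i/dim)), pe[pos, 2i+1] = cos(same angle)."""
    assert dim % 2 == 0, "dim must be even"
    pos = torch.arange(n_positions, dtype=torch.float32).unsqueeze(1)  # (n, 1)
    two_i = torch.arange(0, dim, 2, dtype=torch.float32)  # 0, 2, 4, ...
    inv_freq = torch.exp(-math.log(base) * two_i / dim)  # base ** (-2i / dim)
    pe = torch.zeros(n_positions, dim)
    pe[:, 0::2] = torch.sin(pos * inv_freq)
    pe[:, 1::2] = torch.cos(pos * inv_freq)
    return pe


pe = sinusoidal_embedding(65, 128)  # 64 patches + CLS, dim = 128
print(pe.shape)  # torch.Size([65, 128])
```

In `ViT.__init__`, the learned `nn.Parameter` could be swapped for `self.register_buffer('pos_embedding', sinusoidal_embedding(num_patches + 1, dim).unsqueeze(0))`, which keeps the same `(1, n + 1, dim)` shape but excludes the table from training.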

---