# Data-Efficient Image Transformer (DeiT) from Scratch

This notebook implements **DeiT (Data-efficient Image Transformers)**. Standard Vision Transformers (ViT) usually need very large pre-training datasets (like JFT-300M) to outperform CNNs. DeiT introduces a training strategy that lets Transformers train effectively on ImageNet alone, relying in large part on distillation from a strong "Teacher" network. Here we apply the same idea at toy scale to MNIST.

### Key Concepts:
1.  **Teacher-Student Distillation:** We use a pre-trained strong CNN (ResNet50) as a "Teacher".
2.  **The Distillation Token:** Unlike a standard ViT which has one class token `[CLS]`, DeiT adds a second token `[DIST]`. 
    * The `[CLS]` token learns from the true labels (Ground Truth).
    * The `[DIST]` token learns to mimic the Teacher's predictions.
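3.  **Hard vs Soft Distillation:** DeiT describes two variants. *Hard* distillation trains the `[DIST]` head with cross-entropy against the Teacher's top-1 prediction. *Soft* distillation matches the Teacher's full, temperature-softened output distribution. We implement **Soft Distillation** using KL divergence, minimizing the difference between the Student's and Teacher's output distributions.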
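For comparison, here is a minimal sketch of the hard variant, assuming the same two-head student used later in this notebook and the equal 0.5/0.5 weighting of the two cross-entropy terms described in the DeiT paper. `hard_distillation_loss` is a hypothetical helper; the notebook itself only trains with the soft loss.

```python
import torch
import torch.nn.functional as F

def hard_distillation_loss(student_cls_logits, student_dist_logits, teacher_logits, labels):
    """Hypothetical hard-label distillation: the teacher's argmax acts as a second label."""
    teacher_labels = teacher_logits.argmax(dim=1)                     # (B,) teacher's top-1 class
    loss_cls = F.cross_entropy(student_cls_logits, labels)            # [CLS] head vs ground truth
    loss_dist = F.cross_entropy(student_dist_logits, teacher_labels)  # [DIST] head vs teacher decision
    return 0.5 * loss_cls + 0.5 * loss_dist

# Usage with random logits for a batch of 8 and 10 classes
torch.manual_seed(0)
cls_logits, dist_logits, teacher_logits = (torch.randn(8, 10) for _ in range(3))
labels = torch.randint(0, 10, (8,))
loss = hard_distillation_loss(cls_logits, dist_logits, teacher_logits, labels)
print(f"hard distillation loss: {loss.item():.4f}")
```

Unlike the soft loss, this version needs no temperature and ignores how confident the teacher is; it only uses the teacher's final decision.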



## 1. Imports and Configuration

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import DataLoader, Subset
from torchvision import transforms, datasets, models
import numpy as np
import matplotlib.pyplot as plt

# Set device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

In [None]:
# Hyperparameters
BATCH_SIZE = 64
IMG_SIZE = 28
PATCH_SIZE = 7
CHANNELS = 3          # We repeat MNIST channels to 3 to match ResNet input expectations
CLASSES = 10

# Model Params
EMBED_DIM = 32
ATTENTION_HEADS = 4
TRANSFORMER_LAYERS = 4

# Training Params
EPOCHS_STUDENT = 5
LR_STUDENT = 0.001
LR_TEACHER = 0.001

# Distillation Params
TEMPERATURE = 2       # Softens the probability distributions (higher = softer)
ALPHA = 0.5           # Weight of the distillation (KL) term; the CE term gets (1 - ALPHA)
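To see what `TEMPERATURE` does, the sketch below divides a made-up logit vector (an assumption, not real teacher output) by several temperatures before the softmax. As the temperature grows, the top probability shrinks and the entropy rises, exposing more of the relative scores of the non-top classes.

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([4.0, 1.0, 0.0])  # made-up teacher logits for 3 classes
results = []
for T in (1.0, 2.0, 5.0):
    probs = F.softmax(logits / T, dim=0)
    entropy = -(probs * probs.log()).sum().item()  # higher entropy = softer distribution
    results.append((T, probs.max().item(), entropy))
    print(f"T={T}: probs={[round(p, 3) for p in probs.tolist()]}, entropy={entropy:.3f}")
```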

## 2. Data Preparation
We use MNIST for this demonstration. 

**Note:** ResNet (our teacher) expects 3-channel images. Since MNIST is grayscale (1 channel), we use a lambda transform to repeat the channel 3 times.

In [None]:
tfm = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda t: t.repeat(3, 1, 1)), # Convert 1x28x28 -> 3x28x28
])

train_full = datasets.MNIST('./data', train=True, download=True, transform=tfm)
test_dataset = datasets.MNIST('./data', train=False, download=True, transform=tfm)

# Use a subset for faster demonstration if needed, currently using 100% of data
n = int(1.0 * len(train_full))
subset_idx = np.random.permutation(len(train_full))[:n]
train_dataset = Subset(train_full, subset_idx)

train_dl = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
val_dl = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)  # no drop_last: evaluate on every test image
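The channel-repeat trick from the note above can be checked on a dummy tensor, here a random stand-in for one MNIST image after `ToTensor()`. The sketch also shows `expand` as an alternative that builds a broadcast view instead of copying. Either works in this notebook, because the DataLoader's default collate stacks samples into a fresh batch tensor anyway.

```python
import torch

gray = torch.rand(1, 28, 28)           # stand-in for one MNIST image after ToTensor()
rgb_copy = gray.repeat(3, 1, 1)        # what the Lambda transform does: copies the channel 3 times
rgb_view = gray.expand(3, -1, -1)      # alternative: a broadcast view that allocates no new memory

print(rgb_copy.shape, rgb_view.shape)
print("channels identical:", torch.equal(rgb_copy[0], rgb_copy[2]))
print("view shares storage with input:", rgb_view.data_ptr() == gray.data_ptr())
```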

## 3. The Teacher Model (ResNet50)

In Knowledge Distillation, the Teacher must be a strong model. We use a **ResNet50** pre-trained on ImageNet.

Since ResNet50 outputs 1000 classes (ImageNet), we replace the final layer with a 10-class layer (MNIST). We then train only this new head for a few epochs while the pretrained backbone stays frozen (a linear probe), so the Teacher can actually classify digits. The Student will later try to "clone" this knowledge.

In [None]:
# Load Pre-trained ResNet50
teacher = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)

# Modify the final layer for MNIST (10 classes)
teacher.fc = nn.Linear(teacher.fc.in_features, CLASSES)
teacher.to(device)

# Freeze all layers except the last classification layer for efficiency
for param in teacher.parameters():
    param.requires_grad = False
for param in teacher.fc.parameters():
    param.requires_grad = True

print("Teacher model ready.")

In [None]:
# Fine-tune the Teacher
optimizer_teacher = torch.optim.Adam(teacher.fc.parameters(), lr=LR_TEACHER)
criterion_teacher = nn.CrossEntropyLoss()

EPOCHS_TEACHER = 3  # A few epochs are enough for the teacher to learn MNIST

print("Start Teacher Fine-tuning...")
for epoch in range(EPOCHS_TEACHER):
    teacher.train()
    running_loss = 0.0
    for inputs, labels in train_dl:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer_teacher.zero_grad()
        outputs = teacher(inputs)
        loss = criterion_teacher(outputs, labels)
        loss.backward()
        optimizer_teacher.step()
        running_loss += loss.item()

    print(f"  Teacher Epoch {epoch+1}, Avg Loss: {running_loss / len(train_dl):.4f}")

# Evaluate Teacher
teacher.eval()
correct = 0
total = 0
with torch.no_grad():
    for inputs, labels in val_dl:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = teacher(inputs)
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f"Teacher Accuracy: {100 * correct / total:.2f}%")

## 4. The Student Model (DeiT)

This is the standard Vision Transformer architecture with one crucial modification:

**The Input Sequence:**
Instead of `[Class Token, Patch 1, Patch 2, ...]`, DeiT uses:
`[Class Token, Distillation Token, Patch 1, Patch 2, ...]`

Both tokens interact with the image patches via Self-Attention, but they feed into separate heads at the end.

In [None]:
class PatchEmbed(nn.Module):
    """Splits image into patches and embeds them."""
    def __init__(self, channels=CHANNELS, embed_dim=EMBED_DIM, patch_size=PATCH_SIZE, img_size=IMG_SIZE):
        super().__init__()
        # We use a Convolution to handle splitting and embedding simultaneously
        self.proj = nn.Conv2d(in_channels=channels, out_channels=embed_dim, kernel_size=patch_size, stride=patch_size)
        self.n_patches = (img_size // patch_size) ** 2

    def forward(self, x):
        # x: (Batch, Channels, Height, Width)
        x = self.proj(x)      # (Batch, Embed_Dim, H', W')
        x = x.flatten(2)      # (Batch, Embed_Dim, N_Patches)
        x = x.transpose(1, 2) # (Batch, N_Patches, Embed_Dim)
        return x
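The comment in `PatchEmbed` says the convolution splits and embeds in one step. With `kernel_size == stride == patch_size`, the convolution is equivalent to cutting non-overlapping patches with `F.unfold`, flattening each one, and applying a single shared linear map. The sketch compares both paths using random weights and sizes that match this notebook's config (28x28 images, 7x7 patches, 32-dim embeddings).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
B, C, H, W, P, D = 2, 3, 28, 28, 7, 32   # batch, channels, height, width, patch size, embed dim
x = torch.randn(B, C, H, W)
conv = nn.Conv2d(C, D, kernel_size=P, stride=P)

# Path 1: the strided convolution used by PatchEmbed
conv_tokens = conv(x).flatten(2).transpose(1, 2)                  # (B, 16, D)

# Path 2: cut P x P patches, flatten each to C*P*P values, apply one shared linear map
patches = F.unfold(x, kernel_size=P, stride=P).transpose(1, 2)    # (B, 16, C*P*P)
linear_tokens = patches @ conv.weight.reshape(D, -1).T + conv.bias  # (B, 16, D)

max_diff = (conv_tokens - linear_tokens).abs().max().item()
print(conv_tokens.shape, f"max abs difference: {max_diff:.2e}")
```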

In [None]:
class DeiT(nn.Module):
    def __init__(self, embed_dim=EMBED_DIM, attention_heads=ATTENTION_HEADS, layers=TRANSFORMER_LAYERS, classes=CLASSES):
        super().__init__()
        
        self.patch_embed = PatchEmbed()
        
        # 1. Class Token (Standard ViT)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        
        # 2. Distillation Token (DeiT Specific)
        self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        
        # Position Embedding: N_Patches + 2 tokens (CLS + DIST)
        self.n_patches = self.patch_embed.n_patches
        self.pos_embed = nn.Parameter(torch.zeros(1, self.n_patches + 2, embed_dim))
        
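        # Transformer Encoder Blocks, configured ViT-style: pre-norm, GELU, MLP width 4x embed_dim
        # (PyTorch's defaults would be post-norm, ReLU and dim_feedforward=2048)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=attention_heads, dim_feedforward=4 * embed_dim,
            activation="gelu", norm_first=True, batch_first=True,
        )
        self.transformer_blocks = nn.TransformerEncoder(encoder_layer, num_layers=layers)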
        
        self.layernorm = nn.LayerNorm(embed_dim)
        
        # Two separate heads
        self.head_cls = nn.Linear(embed_dim, classes)     # For Ground Truth
        self.head_dist = nn.Linear(embed_dim, classes)    # For Teacher Prediction

    def forward(self, x):
        B = x.shape[0]
        
        # Embed Patches
        x = self.patch_embed(x)
        
        # Expand Tokens to batch size
        cls_tokens = self.cls_token.expand(B, -1, -1)
        dist_tokens = self.dist_token.expand(B, -1, -1)
        
        # Concatenate: [CLS, DIST, Patches]
        x = torch.cat((cls_tokens, dist_tokens, x), dim=1)
        
        # Add Position Embeddings
        x = x + self.pos_embed
        
        # Transformer Pass
        x = self.transformer_blocks(x)
        x = self.layernorm(x)
        
        # Extract specific tokens
        # Index 0 is CLS, Index 1 is DIST
        out_cls = x[:, 0]
        out_dist = x[:, 1]
        
        # Pass through respective heads
        logits_cls = self.head_cls(out_cls)
        logits_dist = self.head_dist(out_dist)
        
        return logits_cls, logits_dist

## 5. Distillation Loss Function

The total loss is a combination of two losses:
1.  **Student Loss (CrossEntropy):** Does the `[CLS]` token predict the correct digit?
2.  **Distillation Loss (KL Divergence):** Does the `[DIST]` token output the same probability distribution as the Teacher?

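$$ \mathcal{L} = \alpha\,\tau^{2}\,\mathrm{KL}\big(\psi(Z_t/\tau)\,\|\,\psi(Z_{\text{dist}}/\tau)\big) + (1-\alpha)\,\mathrm{CE}\big(\psi(Z_{\text{cls}}),\, y\big) $$

where $\psi$ is the softmax, $Z_{\text{cls}}$ and $Z_{\text{dist}}$ are the Student's two head outputs, $Z_t$ are the Teacher's logits, $y$ is the true label, $\tau$ is `TEMPERATURE` and $\alpha$ is `ALPHA`.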
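The KL direction matters. `F.kl_div(input, target)` expects `input` as log-probabilities (the Student) and `target` as probabilities (the Teacher), and computes KL(target || input), that is KL(Teacher || Student). The sketch checks this numerically on random logits (the batch size, class count and seed are arbitrary choices).

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
tau = 2.0
student_logits = torch.randn(4, 10)
teacher_logits = torch.randn(4, 10)

log_p_s = F.log_softmax(student_logits / tau, dim=1)   # student log-probs (input to kl_div)
p_t = F.softmax(teacher_logits / tau, dim=1)           # teacher probs (target of kl_div)

via_kl_div = F.kl_div(log_p_s, p_t, reduction="batchmean")
kl_teacher_student = (p_t * (p_t.log() - log_p_s)).sum(dim=1).mean()            # KL(teacher || student)
kl_student_teacher = (log_p_s.exp() * (log_p_s - p_t.log())).sum(dim=1).mean()  # KL(student || teacher)
print(f"F.kl_div: {via_kl_div.item():.5f}  "
      f"KL(t||s): {kl_teacher_student.item():.5f}  KL(s||t): {kl_student_teacher.item():.5f}")
```

`reduction="batchmean"` divides the summed divergence by the batch size, which is why the manual version sums over classes and then averages over the batch.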



In [None]:
def distillation_loss(student_cls_logits, student_dist_logits, teacher_logits, labels, alpha=ALPHA, temperature=TEMPERATURE):
    # 1. Standard Classification Loss (Ground Truth)
    loss_ce = F.cross_entropy(student_cls_logits, labels)
    
    # 2. Distillation Loss (Teacher Knowledge)
    # We soften the logits by dividing by temperature to reveal "dark knowledge" (relationships between classes)
    distillation = F.kl_div(
        F.log_softmax(student_dist_logits / temperature, dim=1),
        F.softmax(teacher_logits / temperature, dim=1),
        reduction='batchmean'
    ) * (temperature ** 2)
    
    # Combine losses
    total_loss = (alpha * distillation) + ((1 - alpha) * loss_ce)
    return total_loss
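The `* (temperature ** 2)` factor follows Hinton et al.'s knowledge distillation: softening both distributions shrinks the KL gradient roughly by 1/τ², and multiplying by τ² keeps it on a comparable scale to the cross-entropy term as τ changes. A sketch with made-up 2-class logits and a uniform student (`soft_kl` is a hypothetical helper mirroring the KL term above) compares gradient norms with and without the factor.

```python
import torch
import torch.nn.functional as F

def soft_kl(student_logits, teacher_logits, tau, scale_by_tau_sq):
    """Same KL term as distillation_loss, with the tau**2 factor optional."""
    loss = F.kl_div(F.log_softmax(student_logits / tau, dim=1),
                    F.softmax(teacher_logits / tau, dim=1),
                    reduction="batchmean")
    return loss * tau ** 2 if scale_by_tau_sq else loss

teacher_logits = torch.tensor([[2.0, 0.0]])  # made-up 2-class teacher logits
grad_norms = {}
for tau in (1.0, 2.0, 4.0, 8.0):
    norms = []
    for scale in (False, True):
        student_logits = torch.zeros(1, 2, requires_grad=True)  # uniform student
        soft_kl(student_logits, teacher_logits, tau, scale).backward()
        norms.append(student_logits.grad.norm().item())
    grad_norms[tau] = tuple(norms)
    print(f"tau={tau}: grad norm unscaled={norms[0]:.4f}, tau^2-scaled={norms[1]:.4f}")
```

Without the factor the gradient collapses as τ grows; with it, the gradient stays roughly constant, so `ALPHA` keeps a similar meaning across temperatures.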

## 6. Training the Student

In [None]:
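student = DeiT().to(device)
optimizer_student = torch.optim.Adam(student.parameters(), lr=LR_STUDENT)

# Keep the teacher in eval mode so its BatchNorm layers use fixed running statistics
teacher.eval()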

print("Training Student (DeiT)...")
for epoch in range(EPOCHS_STUDENT):
    student.train()
    running_loss = 0.0
    
    for inputs, labels in train_dl:
        inputs, labels = inputs.to(device), labels.to(device)
        
        # 1. Get Teacher Predictions (No Gradient needed for Teacher)
        with torch.no_grad():
            teacher_logits = teacher(inputs)
            
        # 2. Get Student Predictions
        cls_logits, dist_logits = student(inputs)
        
        # 3. Calculate Loss
        loss = distillation_loss(cls_logits, dist_logits, teacher_logits, labels)
        
        # 4. Backprop
        optimizer_student.zero_grad()
        loss.backward()
        optimizer_student.step()
        
        running_loss += loss.item()
        
    print(f"Epoch {epoch+1}/{EPOCHS_STUDENT}, Loss: {running_loss/len(train_dl):.4f}")

## 7. Evaluation

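At inference, DeiT combines the predictions of the `[CLS]` and `[DIST]` heads rather than using either one alone. Here we average their logits; the DeiT paper describes this late fusion as adding the two heads' softmax outputs.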
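Averaging logits and averaging softmax outputs usually pick the same class, but not always once there are more than two classes. The sketch uses hand-picked 3-class logits (an assumption chosen to expose the difference) and two hypothetical helpers, `fuse_logits` and `fuse_softmax`.

```python
import torch
import torch.nn.functional as F

def fuse_logits(cls_logits, dist_logits):
    """Late fusion by averaging raw logits (what the evaluation cell does)."""
    return (cls_logits + dist_logits) / 2

def fuse_softmax(cls_logits, dist_logits):
    """Late fusion by averaging per-head probabilities."""
    return (F.softmax(cls_logits, dim=1) + F.softmax(dist_logits, dim=1)) / 2

cls_logits = torch.tensor([[2.0, 5.0, 0.0]])    # [CLS] head strongly prefers class 1
dist_logits = torch.tensor([[2.0, -10.0, 0.0]])  # [DIST] head strongly rejects class 1

pred_logit_avg = fuse_logits(cls_logits, dist_logits).argmax(dim=1)
pred_prob_avg = fuse_softmax(cls_logits, dist_logits).argmax(dim=1)
print("logit average ->", pred_logit_avg.item(), "| softmax average ->", pred_prob_avg.item())
```

Logit averaging lets a very confident "no" from one head outweigh the other head, while softmax averaging caps each head's vote at probability 1.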

In [None]:
student.eval()
correct = 0
total = 0
samples = []

with torch.no_grad():
    for inputs, labels in val_dl:
        inputs, labels = inputs.to(device), labels.to(device)
        
        cls_logits, dist_logits = student(inputs)
        
        # Inference Strategy: Average both heads
        avg_logits = (cls_logits + dist_logits) / 2
        predictions = avg_logits.argmax(dim=1)
        
        correct += (predictions == labels).sum().item()
        total += labels.size(0)
        
        # Save a few for visualization
        if len(samples) < 5:
            samples.append((inputs.cpu(), predictions.cpu(), labels.cpu()))

acc = 100 * correct / total
print(f"\nStudent (DeiT) Test Accuracy: {acc:.2f}%")

In [None]:
# Visualization
fig, axs = plt.subplots(1, len(samples), figsize=(15, 3))
for i, (img, pred, true) in enumerate(samples):
    # Convert Tensor (3, 28, 28) back to Numpy (28, 28, 3)
    img_np = img[0].permute(1, 2, 0).numpy()
    
    axs[i].imshow(img_np, cmap='gray' if CHANNELS==1 else None)
    axs[i].set_title(f"Pred: {pred[0].item()} | True: {true[0].item()}")
    axs[i].axis('off')
plt.show()