#### Implementation of DEIT from scratch

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from sentry_sdk.utils import epoch


if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [2]:
#Patch Embedding

class PatchEmbedding(nn.Module):
    def __init__(self, in_channels, embedding_dim, patch_size):
        super(PatchEmbedding, self).__init__()
        self.conv2d = nn.Conv2d(
            in_channels, embedding_dim, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
         x = self.conv2d(x)
         x = x.flatten(2)
         return x.transpose(1, 2)

In [3]:
class DEIT(nn.Module):
    def __init__(self, img_size, patch_size, in_channels, num_classes, embedding_dim, depth, ff_dim, dropout , n_head):
        super(DEIT, self).__init__()

        self.patch_embed = PatchEmbedding(in_channels, embedding_dim, patch_size)
        num_patches = (img_size // patch_size) ** 2

        #CLS TOKEN
        self.cls_token = nn.Parameter(torch.randn(1, 1, embedding_dim) * 0.02)
        self.dist_token = nn.Parameter(torch.randn(1, 1, embedding_dim) * 0.02)

        #Postional Embedding
        self.pos_embed  = nn.Parameter(torch.randn(1, num_patches + 2, embedding_dim) * 0.02)
        self.dropout = nn.Dropout(dropout)

        #Transformer  encoder

        encoder_layer = nn.TransformerEncoderLayer(embedding_dim, n_head, ff_dim, dropout, activation= 'gelu', batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=depth)

        self.norm = nn.LayerNorm(embedding_dim)

        #Two Head DEiT
        self.head_cls = nn.Linear(embedding_dim, num_classes)
        self.head_dist = nn.Linear(embedding_dim, num_classes)

    def forward(self, x):
        # PATCH EMBEDDING
        x = self.patch_embed(x)

        B = x.size(0)
        cls_token = self.cls_token.expand(B, -1, -1)
        dist_token = self.dist_token.expand(B, -1, -1)
        # CONCAT TOKENS
        x = torch.cat([cls_token, dist_token, x], dim=1)
        # POSITIONAL EMBEDDING
        pos = self.pos_embed[:, :x.size(1), :]
        x = x + pos
        # DROPOUT
        x = self.dropout(x)

        # TRANSFORMER
        x = self.transformer(x)
        x = self.norm(x)
        cls_out = self.head_cls(x[:, 0])
        dist_out = self.head_dist(x[:, 1])
        return cls_out, dist_out


In [4]:
def deit_loss(cls_out, dist_out, teacher_logits, labels, T=4.0, alpha=0.95):

    ce_loss = F.cross_entropy(cls_out, labels)
    student_log_probs = F.log_softmax(dist_out / T, dim=-1)
    teacher_probs = F.softmax(teacher_logits.detach() / T, dim=-1)

    dist_loss = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean') * (T * T)

    return alpha * dist_loss + (1 - alpha) * ce_loss, ce_loss, dist_loss

In [5]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
transform = transforms.Compose([
    transforms.Resize((32,32)),
    transforms.ToTensor()
])

train_data = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
test_data  = datasets.CIFAR10(root="./data", train=False, download=True, transform=transform)

train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
test_loader  = DataLoader(test_data, batch_size=128)


In [6]:
# #Define Teacher Model
# import torchvision.models as models
#
# teacher = models.resnet50(weights=None)
# teacher.fc = nn.Linear(2048, 10)
# teacher = teacher.to(device)
#
# epochs = 10
# optimizer = torch.optim.Adam(teacher.parameters(), lr = 3e-4)
#
# for epoch in range(epochs):
#     for images, labels in train_loader:
#         images, labels = images.to(device), labels.to(device)
#         optimizer.zero_grad()
#         logits = teacher(images)
#         loss = F.cross_entropy(logits, labels)
#         loss.backward()
#         optimizer.step()


In [7]:
import torchvision.models as models
teacher = models.resnet50(weights=None)
teacher.fc = nn.Linear(2048, 10)

teacher.load_state_dict(torch.load("teacher_cifar10.pth"))
teacher.eval()

for p in teacher.parameters():
    p.requires_grad = False
teacher.to(device)


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [8]:
def eval(model, loader, t_model):
    model.eval()
    teacher.eval()

    total_loss = 0
    total_ce_loss = 0
    total_dist_loss = 0
    correct = 0
    total = 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        teacher_prob = torch.softmax(t_model(images), dim=-1)
        teacher_prob = teacher_prob.to(device)
        cls_out, dist_out  = model(images)
        loss, ce_loss, dist_loss = deit_loss(cls_out, dist_out, teacher_prob, labels)
        total_loss += loss.item()
        total_ce_loss += ce_loss.item()
        total_dist_loss += dist_loss.item()

        preds = cls_out.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    average_loss = total_loss / len(loader)
    average_ce_loss = total_ce_loss / len(loader)
    average_dist_loss = total_dist_loss / len(loader)
    accuracy = correct / total

    print(f'Eval Loss = {average_loss}, CE = {average_ce_loss}, Dist = {average_dist_loss}, Accuracy = {accuracy}')
    return average_loss, accuracy
def training_func(model, t_model, optimizer, scheduler, train_loader, valid_loader, epochs):
    history = {'train_loss': [], 'ce_loss': [], 'dist_loss': [], 'train_accuracy': [],'validation_accuracy': [] }
    t_model.eval()
    for e in range(epochs):
        model.train()
        total_loss = 0
        total_ce_loss = 0
        total_dist_loss = 0
        correct = 0
        total = 0

        for image, labels in train_loader:
            images, labels = image.to(device), labels.to(device)
            with torch.no_grad():
                teacher_prob = t_model(images)
                teacher_prob = teacher_prob.to(device)
            optimizer.zero_grad()
            cls_out, dist_out = model(images)
            loss, ce_loss, dist_loss = deit_loss(cls_out, dist_out, teacher_prob, labels, T = 4, alpha = 0.9)
            loss.backward()
            optimizer.step()

            total_loss += loss.detach().item()
            total_ce_loss += ce_loss.item()
            total_dist_loss += dist_loss.item()

            preds = cls_out.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

        epoch_loss = total_loss / len(train_loader)
        epoch_ce_loss = total_ce_loss / len(train_loader)
        epoch_dist_loss = total_dist_loss / len(train_loader)
        train_accuracy = correct / total

        _, val_accuracy = eval(model, valid_loader, t_model)

        if scheduler is not None:
            scheduler.step()

        history['train_loss'].append(epoch_loss)
        history['ce_loss'].append(epoch_ce_loss)
        history['dist_loss'].append(epoch_dist_loss)
        history['train_accuracy'].append(train_accuracy)
        history['validation_accuracy'].append(val_accuracy)


        print(f'Epoch {e + 1}/ {epochs}, '
                  f'Loss: {epoch_loss:.4f}, '
                  f'Dist: {epoch_ce_loss:.4f}, Train Accuracy: {train_accuracy:.4f}, Validation Accuracy: {val_accuracy:.4f}')
    return history


In [9]:
student = DEIT(img_size= 32, patch_size= 4, num_classes= 10, embedding_dim= 192, depth= 4, n_head= 4, ff_dim= 384, in_channels=3 , dropout= 0.1).to(device)

images = torch.randn(2,3,32,32).to(device)
cls, dist = student(images)
loss = (cls.mean() + dist.mean())
loss.backward()

print(student.head_cls.weight.grad is not None)
print(student.head_dist.weight.grad is not None)


True
True


In [10]:
import math

class WarmupCosineScheduler(torch.optim.lr_scheduler._LRScheduler):
    def __init__(self, optimizer, warmup_epochs, max_epochs, warmup_start_lr=1e-6, eta_min=1e-5):
        self.warmup_epochs = warmup_epochs
        self.max_epochs = max_epochs
        self.warmup_start_lr = warmup_start_lr
        self.eta_min = eta_min
        super().__init__(optimizer)

    def get_lr(self):
        if self.last_epoch < self.warmup_epochs:

            return [self.warmup_start_lr + (base_lr - self.warmup_start_lr) * (self.last_epoch / self.warmup_epochs)
                    for base_lr in self.base_lrs]
        else:
            progress = (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs)
            return [self.eta_min + (base_lr - self.eta_min) * 0.5 * (1 + math.cos(math.pi * progress))
                    for base_lr in self.base_lrs]

warmup_epochs = 5
total_epochs = 50


In [11]:

optimizer = torch.optim.AdamW(
    student.parameters(),
    lr=5e-4,
    betas=(0.9, 0.999),
    weight_decay=0.05
)
scheduler = WarmupCosineScheduler(
    optimizer,
    warmup_epochs=warmup_epochs,
    max_epochs=total_epochs,
    warmup_start_lr=1e-6,
    eta_min=1e-5
)
training_func(student, teacher, optimizer, scheduler= scheduler, train_loader= train_loader, valid_loader= test_loader, epochs= total_epochs)
eval(student, test_loader, teacher)

Eval Loss = 0.23189290867576115, CE = 2.3428510804719562, Dist = 0.12078984732492061, Accuracy = 0.0626
Epoch 1/ 50, Loss: 4.9811, Dist: 2.3839, Train Accuracy: 0.0858, Validation Accuracy: 0.0626
Eval Loss = 2.882244399831265, CE = 1.5561893725696998, Dist = 2.9520367851740197, Accuracy = 0.4272
Epoch 2/ 50, Loss: 3.0129, Dist: 1.8148, Train Accuracy: 0.3150, Validation Accuracy: 0.4272
Eval Loss = 3.7102090407021437, CE = 1.2807800392561322, Dist = 3.838073769702187, Accuracy = 0.5304
Epoch 3/ 50, Loss: 1.9506, Dist: 1.4309, Train Accuracy: 0.4720, Validation Accuracy: 0.5304
Eval Loss = 3.6354277918610394, CE = 1.1834615484068665, Dist = 3.764478662345983, Accuracy = 0.5653
Epoch 4/ 50, Loss: 1.5480, Dist: 1.2610, Train Accuracy: 0.5407, Validation Accuracy: 0.5653
Eval Loss = 3.85116297987443, CE = 1.1155788770204857, Dist = 3.995141144040265, Accuracy = 0.5933
Epoch 5/ 50, Loss: 1.3497, Dist: 1.1652, Train Accuracy: 0.5785, Validation Accuracy: 0.5933
Eval Loss = 3.966380565981322

(4.754745791230021, 0.7676)

#### How to improve the current 76% accuracy

To match or beat the original DeIT-Tiny Google paper:

1. Zero-init cls_token, trunc_normal pos_embed
2. Add LayerScale (1e-6) to each block
3. Add stochastic depth (linear drop from 0.0 → 0.1–0.2)
4. Train 200–300 epochs with cosine scheduler + warm restarts
5. Use weight_decay=0.05 and AdamW
6. Switch to relative position bias or DeiT-III improvements
