#### Implementation of DEIT from scratch

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from sentry_sdk.utils import epoch


if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [2]:
#Patch Embedding

class PatchEmbedding(nn.Module):
    def __init__(self, in_channels, embedding_dim, patch_size):
        super(PatchEmbedding, self).__init__()
        self.conv2d = nn.Conv2d(
            in_channels, embedding_dim, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
         x = self.conv2d(x)
         x = x.flatten(2)
         return x.transpose(1, 2)

In [3]:
class DEIT(nn.Module):
    def __init__(self, img_size, patch_size, in_channels, num_classes, embedding_dim, depth, ff_dim, dropout , n_head):
        super(DEIT, self).__init__()

        self.patch_embed = PatchEmbedding(in_channels, embedding_dim, patch_size)
        num_patches = (img_size // patch_size) ** 2

        #CLS TOKEN
        self.cls_token = nn.Parameter(torch.randn(1, 1, embedding_dim) * 0.02)
        self.dist_token = nn.Parameter(torch.randn(1, 1, embedding_dim) * 0.02)

        #Postional Embedding
        self.pos_embed  = nn.Parameter(torch.randn(1, num_patches + 2, embedding_dim) * 0.02)
        self.dropout = nn.Dropout(dropout)

        #Transformer  encoder

        encoder_layer = nn.TransformerEncoderLayer(embedding_dim, n_head, ff_dim, dropout, activation= 'gelu', batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=depth)

        self.norm = nn.LayerNorm(embedding_dim)

        #Two Head DEiT
        self.head_cls = nn.Linear(embedding_dim, num_classes)
        self.head_dist = nn.Linear(embedding_dim, num_classes)

    def forward(self, x):
        print("\n---- FORWARD DEBUG ----")
        print("input:", x.shape)  # <-- print batch first

        # PATCH EMBEDDING
        x = self.patch_embed(x)
        print("after patch_embed:", x.shape)

        B = x.size(0)
        cls_token = self.cls_token.expand(B, -1, -1)
        dist_token = self.dist_token.expand(B, -1, -1)
        print("cls_token:", cls_token.shape)
        print("dist_token:", dist_token.shape)

        # CONCAT TOKENS
        x = torch.cat([cls_token, dist_token, x], dim=1)
        print("after concat:", x.shape)

        # POSITIONAL EMBEDDING
        pos = self.pos_embed[:, :x.size(1), :]
        print("pos_embed:", pos.shape)

        x = x + pos
        print("after +pos:", x.shape)

        # DROPOUT
        x = self.dropout(x)

        # TRANSFORMER
        x = self.transformer(x)
        print("after transformer:", x.shape)

        # LAYER NORM
        x = self.norm(x)
        print("after norm:", x.shape)

        cls_out = self.head_cls(x[:, 0])
        dist_out = self.head_dist(x[:, 1])

        print("cls_out:", cls_out.shape)
        print("dist_out:", dist_out.shape)
        print("-----------------------\n")

        return cls_out, dist_out


In [4]:
#Function to calculate total loss
def deit_loss(cls_out, dist_out, teacher_probs, labels, T = 1.0):
    ce_loss = F.cross_entropy(cls_out, labels)

    #Distillation
    k1_loss = torch.nn.functional.kl_div(
        F.log_softmax(dist_out/ T, dim=-1),
        teacher_probs,
        reduction ='batchmean'
    ) * (T * T)
    return ce_loss + k1_loss, ce_loss, k1_loss

In [5]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
transform = transforms.Compose([
    transforms.Resize((32,32)),
    transforms.ToTensor()
])

train_data = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
test_data  = datasets.CIFAR10(root="./data", train=False, download=True, transform=transform)

train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
test_loader  = DataLoader(test_data, batch_size=128)


In [6]:
# #Define Teacher Model
# import torchvision.models as models
#
# teacher = models.resnet50(weights=None)
# teacher.fc = nn.Linear(2048, 10)
# teacher = teacher.to(device)
#
# epochs = 10
# optimizer = torch.optim.Adam(teacher.parameters(), lr = 3e-4)
#
# for epoch in range(epochs):
#     for images, labels in train_loader:
#         images, labels = images.to(device), labels.to(device)
#         optimizer.zero_grad()
#         logits = teacher(images)
#         loss = F.cross_entropy(logits, labels)
#         loss.backward()
#         optimizer.step()


In [7]:
import torchvision.models as model
teacher = model.resnet50(weights=None)
teacher.fc = nn.Linear(2048, 10)

teacher.load_state_dict(torch.load("teacher_cifar10.pth"))
teacher.eval()

for p in teacher.parameters():
    p.requires_grad = False
teacher.to(device)


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [8]:
def eval(model, loader, t_model):
    model.eval()
    teacher.eval()

    total_loss = 0
    total_ce_loss = 0
    total_dist_loss = 0
    correct = 0
    total = 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        teacher_prob = torch.softmax(t_model(images), dim=-1)
        cls_out, dist_out  = model(images)
        loss, ce_loss, dist_loss = deit_loss(cls_out, dist_out, teacher_prob, labels)
        total_loss += loss.item()
        total_ce_loss += ce_loss.item()
        total_dist_loss += dist_loss.item()

        preds = cls_out.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    average_loss = total_loss / len(loader)
    average_ce_loss = total_ce_loss / len(loader)
    average_dist_loss = total_dist_loss / len(loader)
    accuracy = correct / total

    print(f'Eval Loss = {average_loss}, CE = {average_ce_loss}, Dist = {average_dist_loss}, Accuracy = {accuracy}')
    return average_loss, accuracy
def training_func(model, t_model, optimizer, scheduler, train_loader, valid_loader, epochs):
    history = {'train_loss': [], 'ce_loss': [], 'dist_loss': [], 'train_accuracy': [], 'validation_accuracy': [] }
    t_model.eval()
    for e in range(epochs):
        model.train()
        total_loss = 0
        total_ce_loss = 0
        total_dist_loss = 0
        correct = 0
        total = 0

        for image, labels in train_loader:
            images, labels = image.to(device), labels.to(device)
            with torch.no_grad():
                teacher_prob = torch.softmax(t_model(images), dim=-1)
                optimizer.zero_grad()
                cls_out, dist_out = model(images)
                loss, ce_loss, dist_loss = deit_loss(cls_out, dist_out, teacher_prob, labels)
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                total_ce_loss += ce_loss.item()
                total_dist_loss += dist_loss.item()

                preds = cls_out.argmax(dim=1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)

            epoch_loss = total_loss / len(train_loader)
            epoch_ce_loss = total_ce_loss / len(train_loader)
            epoch_dist_loss = total_dist_loss / len(train_loader)
            train_accuracy = correct / total

            _, val_accuracy = eval(model, valid_loader, t_model)

            if scheduler is not None:
                scheduler.step()

            history['train_loss'].append(epoch_loss)
            history['train_ce_loss'].append(epoch_ce_loss)
            history['train_dist_loss'].append(epoch_dist_loss)
            history['train_accuracy'].append(train_accuracy)
            history['validation_loss'].append(epoch_loss)


            print(f'Epoch {e + 1}/ {epochs}, '
                  f'Loss: {epoch_loss:.4f}, '
                  f'Dist: {epoch_ce_loss:.4f}, Train Accuracy: {train_accuracy:.4f}, Validation Accuracy: {val_accuracy:.4f}')
    return history


In [9]:
student = DEIT(img_size= 32, patch_size= 4, num_classes= 10, embedding_dim= 192, depth= 4, n_head= 3, ff_dim= 384, in_channels=3 , dropout= 0.1).to(device)

images = torch.randn(2,3,32,32).to(device)
cls, dist = student(images)
loss = (cls.mean() + dist.mean())
loss.backward()

print(student.head_cls.weight.grad is not None)
print(student.head_dist.weight.grad is not None)



---- FORWARD DEBUG ----
input: torch.Size([2, 3, 32, 32])
after patch_embed: torch.Size([2, 64, 192])
cls_token: torch.Size([2, 1, 192])
dist_token: torch.Size([2, 1, 192])
after concat: torch.Size([2, 66, 192])
pos_embed: torch.Size([1, 66, 192])
after +pos: torch.Size([2, 66, 192])
after transformer: torch.Size([2, 66, 192])
after norm: torch.Size([2, 66, 192])
cls_out: torch.Size([2, 10])
dist_out: torch.Size([2, 10])
-----------------------

True
True




In [12]:
images, labels = next(iter(train_loader))
images = images.to(device)
labels = labels.to(device)

cls_out, dist_out = student(images)



---- FORWARD DEBUG ----
input: torch.Size([64, 3, 32, 32])
after patch_embed: torch.Size([64, 64, 192])
cls_token: torch.Size([64, 1, 192])
dist_token: torch.Size([64, 1, 192])
after concat: torch.Size([64, 66, 192])
pos_embed: torch.Size([1, 66, 192])
after +pos: torch.Size([64, 66, 192])
after transformer: torch.Size([64, 66, 192])
after norm: torch.Size([64, 66, 192])
cls_out: torch.Size([64, 10])
dist_out: torch.Size([64, 10])
-----------------------



In [13]:
cls_out, dist_out = student(images)

print("cls_out:", cls_out.requires_grad)
print("dist_out:", dist_out.requires_grad)

teacher_prob = torch.softmax(teacher(images), dim=-1)
loss, ce_loss, dist_loss = deit_loss(cls_out, dist_out, teacher_prob, labels)

print("loss:", loss.requires_grad)
print("ce_loss:", ce_loss.requires_grad)
print("dist_loss:", dist_loss.requires_grad)



---- FORWARD DEBUG ----
input: torch.Size([64, 3, 32, 32])
after patch_embed: torch.Size([64, 64, 192])
cls_token: torch.Size([64, 1, 192])
dist_token: torch.Size([64, 1, 192])
after concat: torch.Size([64, 66, 192])
pos_embed: torch.Size([1, 66, 192])
after +pos: torch.Size([64, 66, 192])
after transformer: torch.Size([64, 66, 192])
after norm: torch.Size([64, 66, 192])
cls_out: torch.Size([64, 10])
dist_out: torch.Size([64, 10])
-----------------------

cls_out: True
dist_out: True
loss: True
ce_loss: True
dist_loss: True


In [14]:

optimizer = torch.optim.AdamW(student.parameters(), lr= 3e-4)
training_func(student, teacher, optimizer, scheduler= None, train_loader= train_loader, valid_loader= test_loader, epochs= 10)
eval(student, test_loader, teacher)


---- FORWARD DEBUG ----
input: torch.Size([64, 3, 32, 32])
after patch_embed: torch.Size([64, 64, 192])
cls_token: torch.Size([64, 1, 192])
dist_token: torch.Size([64, 1, 192])
after concat: torch.Size([64, 66, 192])
pos_embed: torch.Size([1, 66, 192])
after +pos: torch.Size([64, 66, 192])
after transformer: torch.Size([64, 66, 192])
after norm: torch.Size([64, 66, 192])
cls_out: torch.Size([64, 10])
dist_out: torch.Size([64, 10])
-----------------------



RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn