#### Implementation of DeiT (Data-efficient Image Transformer) from scratch

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
#Patch Embedding

class PatchEmbedding(nn.Module):
    """Turn an image batch into a sequence of patch embeddings.

    A single convolution whose kernel size and stride both equal
    ``patch_size`` cuts the image into non-overlapping patches and projects
    each patch to an ``embedding_dim``-sized vector in one step.

    Args:
        in_channels: Number of channels of the input images.
        embedding_dim: Number of output channels, i.e. the size of each
            patch embedding.
        patch_size: Kernel size and stride of the convolution.
    """

    def __init__(self, in_channels, embedding_dim, patch_size):
        super().__init__()
        self.conv2d = nn.Conv2d(
            in_channels=in_channels,
            out_channels=embedding_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Embed ``x`` and return the patches as a sequence.

        For a ``(batch, in_channels, H, W)`` input the result has shape
        ``(batch, num_patches, embedding_dim)``.
        """
        feature_map = self.conv2d(x)
        # Collapse the spatial grid into one "patch" axis ...
        patches = torch.flatten(feature_map, start_dim=2)
        # ... and move it in front of the embedding axis.
        return patches.permute(0, 2, 1)

In [None]:
class DEIT(nn.Module):
    """DeiT-style vision transformer with a class head and a distillation head.

    The image is converted to patch embeddings, a learnable class token and a
    learnable distillation token are prepended, learnable positional
    embeddings are added, and the sequence is run through a stack of
    ``nn.TransformerEncoderLayer`` blocks followed by a ``LayerNorm``. The
    final states of the two special tokens each feed their own linear
    classifier.

    Args:
        img_size: Side length of the (square) input images; used to size the
            positional embedding as ``(img_size // patch_size) ** 2 + 2``.
        patch_size: Side length of each patch.
        in_channels: Number of channels of the input images.
        num_classes: Output size of both classification heads.
        embedding_dim: Token embedding size (transformer ``d_model``).
        depth: Number of transformer encoder layers.
        ff_dim: Hidden size of each encoder layer's feed-forward network.
        dropout: Dropout probability used after the positional embedding and
            inside the encoder layers.
        n_head: Number of attention heads per encoder layer.
    """

    # NOTE(review): the default ff_dim of 3071 looks like a typo for 3072
    # (4 * 768). It is kept as-is so the constructor defaults do not change;
    # confirm and correct it deliberately if that is the intent.
    def __init__(
        self,
        img_size,
        patch_size=16,
        in_channels=3,
        num_classes=1000,
        embedding_dim=768,
        depth=12,
        ff_dim=3071,
        dropout=0.1,
        n_head=8,
    ):
        super().__init__()

        # PatchEmbedding's signature is (in_channels, embedding_dim,
        # patch_size); the first two arguments must be given in that order.
        self.patch_embed = PatchEmbedding(in_channels, embedding_dim, patch_size)
        num_patches = (img_size // patch_size) ** 2

        # Learnable class and distillation tokens
        self.cls_token = nn.Parameter(torch.randn(1, 1, embedding_dim) * 0.02)
        self.dist_token = nn.Parameter(torch.randn(1, 1, embedding_dim) * 0.02)

        # Positional embedding: one slot per patch plus the two special tokens
        self.pos_embed = nn.Parameter(
            torch.randn(1, num_patches + 2, embedding_dim) * 0.02
        )
        self.dropout = nn.Dropout(dropout)

        # Transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embedding_dim,
            nhead=n_head,
            dim_feedforward=ff_dim,
            dropout=dropout,
            activation='gelu',
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=depth)

        self.norm = nn.LayerNorm(embedding_dim)

        # Two heads: one for the class token, one for the distillation token
        self.head_cls = nn.Linear(embedding_dim, num_classes)
        self.head_dist = nn.Linear(embedding_dim, num_classes)

    def forward(self, x):
        """Return ``(cls_out, dist_out)`` logits for an image batch.

        Args:
            x: Input passed to the patch embedding; its output is treated as
                a ``(batch, num_patches, embedding_dim)`` sequence.

        Returns:
            A tuple ``(cls_out, dist_out)``: the class-head logits computed
            from token 0 and the distillation-head logits computed from
            token 1, each of shape ``(batch, num_classes)``.
        """
        x = self.patch_embed(x)

        batch_size = x.size(0)
        cls_token = self.cls_token.expand(batch_size, -1, -1)
        dist_token = self.dist_token.expand(batch_size, -1, -1)

        # Sequence layout: [CLS, DIST, patch_0, ..., patch_{N-1}]
        x = torch.cat([cls_token, dist_token, x], dim=1)

        # Add positional embedding (broadcast over the batch dimension)
        z = x + self.pos_embed

        z = self.dropout(z)
        z = self.transformer(z)

        z = self.norm(z)

        # Index the token axis (dim 1): position 0 is CLS, position 1 is DIST.
        cls_out = self.head_cls(z[:, 0])
        dist_out = self.head_dist(z[:, 1])
        return cls_out, dist_out

In [3]:
#Function to calculate total loss
def deit_loss(cls_out, dist_out, teacher_probs, labels, T = 1.0):
    """Compute the combined classification + distillation loss.

    Args:
        cls_out: Class-head logits, ``(batch, num_classes)``.
        dist_out: Distillation-head logits, ``(batch, num_classes)``.
        teacher_probs: Teacher distribution as probabilities (not
            log-probabilities), same shape as ``dist_out``. This function
            does not apply ``T`` to the teacher.
            NOTE(review): for ``T != 1`` the caller presumably needs to
            soften the teacher with the same temperature - confirm.
        labels: Ground-truth targets accepted by ``F.cross_entropy``.
        T: Temperature dividing the distillation logits; the KL term is
            multiplied by ``T * T``.

    Returns:
        Tuple ``(total_loss, ce_loss, kl_loss)`` where
        ``total_loss = ce_loss + kl_loss``.
    """
    ce_loss = F.cross_entropy(cls_out, labels)

    # Distillation: KL(teacher || student). F.kl_div expects the input as
    # log-probabilities and (by default) the target as probabilities.
    student_log_probs = F.log_softmax(dist_out / T, dim=-1)
    kl_loss = F.kl_div(
        student_log_probs,
        teacher_probs,
        reduction='batchmean',
    ) * (T * T)
    return ce_loss + kl_loss, ce_loss, kl_loss

In [None]:
#Define Teacher Model
import torchvision.models as models

# ResNet-50 with ImageNet weights, used as the distillation teacher.
teacher = models.resnet50(weights='IMAGENET1K_V1')
teacher.eval()  # put the teacher in evaluation mode

# Freeze the teacher: none of its parameters should receive gradients.
teacher.requires_grad_(False)
teacher.to(device)

In [None]:
def eval(model, loader, metrics, t_model):
    model.eval()
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            teacher_prob = torch.softmax(t_model(images), dim=-1)
            cls_out, dist_out  = model(images)
            loss, ce_loss, dist_loss = deit_loss(cls_out, dist_out, teacher_prob, labels)

            