<a href="https://colab.research.google.com/github/JAZ201107/PyTorch-DL/blob/main/DeiT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

![](https://noblecatt-1304922865.cos.ap-singapore.myqcloud.com/202412101618419.png)


In [None]:
class SoftDistillationLoss(nn.Module):
    """Soft-label knowledge-distillation loss.

    Blends the cross-entropy between the student logits and the ground-truth
    labels with a temperature-scaled KL divergence between the student and
    teacher class distributions.

    Args:
        teacher: Module that maps ``inputs`` to class logits. It is only run
            under ``torch.no_grad()`` and is switched to eval mode on every
            call.
        alpha: Weight of the distillation term; ``1 - alpha`` weights the
            ground-truth cross-entropy.
        tau: Softmax temperature applied to both student and teacher logits.
    """

    def __init__(self, teacher, alpha=0.5, tau=0.2):
        super().__init__()

        self.ce = nn.CrossEntropyLoss()
        self.kl = nn.KLDivLoss(reduction="batchmean")

        self.teacher = teacher
        self.alpha = alpha
        self.tau = tau

    def forward(self, inputs, outputs, labels):
        """Return the blended loss.

        Args:
            inputs: Batch fed to the teacher.
            outputs: Student logits. Either a single tensor, or a
                ``(class_logits, distillation_logits)`` pair; with a pair the
                first element is scored against ``labels`` and the second
                against the teacher.
            labels: Ground-truth targets for the cross-entropy term.
        """
        if isinstance(outputs, (tuple, list)):
            cls_outputs, dist_outputs = outputs
        else:
            cls_outputs = dist_outputs = outputs

        base_loss = self.ce(cls_outputs, labels)

        # The teacher is a registered submodule, so calling .train() on this
        # loss (or a parent) flips it back to training mode; force eval here so
        # dropout / batch-norm do not perturb the soft targets.
        self.teacher.eval()
        with torch.no_grad():
            teacher_outputs = self.teacher(inputs)

        # KLDivLoss expects log-probabilities as input and probabilities as
        # target; the tau**2 factor rescales the temperature-softened term.
        soft_loss = self.kl(
            F.log_softmax(dist_outputs / self.tau, dim=1),
            F.softmax(teacher_outputs / self.tau, dim=1),
        ) * (self.tau**2)

        return (1 - self.alpha) * base_loss + self.alpha * soft_loss

![](https://noblecatt-1304922865.cos.ap-singapore.myqcloud.com/202412101618723.png)


In [None]:
class HardDistillationLoss(nn.Module):
    """Hard-label knowledge-distillation loss.

    Averages two cross-entropy terms: student logits against the ground-truth
    labels, and student logits against the teacher's arg-max prediction.

    Args:
        teacher: Module that maps ``inputs`` to class logits. It is run in
            eval mode under ``torch.no_grad()``.
    """

    def __init__(self, teacher):
        super().__init__()

        self.teacher = teacher
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, inputs, outputs, labels):
        """Return the mean of the ground-truth and teacher cross-entropies.

        Args:
            inputs: Batch fed to the teacher.
            outputs: Student logits. Either a single tensor, or a
                ``(class_logits, distillation_logits)`` pair; with a pair the
                first element is scored against ``labels`` and the second
                against the teacher's hard labels.
            labels: Ground-truth targets for the cross-entropy term.
        """
        if isinstance(outputs, (tuple, list)):
            cls_outputs, dist_outputs = outputs
        else:
            cls_outputs = dist_outputs = outputs

        base_loss = self.criterion(cls_outputs, labels)

        self.teacher.eval()
        with torch.no_grad():
            teacher_outputs = self.teacher(inputs)

        # Hard distillation uses the teacher's predicted class index as the
        # target. Passing the raw logits instead would make CrossEntropyLoss
        # interpret them as class probabilities.
        teacher_labels = torch.argmax(teacher_outputs, dim=-1)
        teacher_loss = self.criterion(dist_outputs, teacher_labels)

        return (base_loss + teacher_loss) / 2

![](https://noblecatt-1304922865.cos.ap-singapore.myqcloud.com/202412101531507.png)


In [None]:
class PatchEmbedding(nn.Module):
    """Turn an image batch into a token sequence with class/distillation tokens.

    A strided convolution projects each ``patch_size`` x ``patch_size`` patch
    to an ``emb_size`` vector. A learnable class token and a learnable
    distillation token are prepended (in that order), and a learnable position
    embedding is added to every token.
    """

    def __init__(self, in_channels=3, patch_size=16, emb_size=768, img_size=224):
        super().__init__()

        self.patch_size = patch_size
        num_patches = (img_size // patch_size) ** 2

        # Conv output is (B, emb_size, H', W'); flattening from dim 2 yields
        # (B, emb_size, H' * W'), i.e. one column per patch.
        patchify = nn.Conv2d(
            in_channels, emb_size, kernel_size=patch_size, stride=patch_size
        )
        self.projection = nn.Sequential(patchify, nn.Flatten(start_dim=2))

        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))
        self.dis_token = nn.Parameter(torch.randn(1, 1, emb_size))

        # Two extra rows cover the class and distillation tokens.
        self.position = nn.Parameter(torch.randn(num_patches + 2, emb_size))

    def forward(self, x):
        """Embed a 4-D batch ``x`` and return the token sequence."""
        batch, _, _, _ = x.shape

        # (B, emb_size, num_patches) -> (B, num_patches, emb_size)
        patches = self.projection(x).transpose(1, 2)

        prefix = [
            token.expand(batch, -1, -1)
            for token in (self.cls_token, self.dis_token)
        ]
        tokens = torch.cat((*prefix, patches), dim=1)

        return tokens + self.position

In [None]:
class ClassificationHead(nn.Module):
    """Two linear heads reading the class token and the distillation token.

    Expects the class token at sequence index 0 and the distillation token at
    sequence index 1, which is the order in which ``PatchEmbedding`` prepends
    them.

    Args:
        emb_size: Dimension of each input token.
        n_class: Number of output classes for both heads.
    """

    def __init__(self, emb_size=768, n_class=1000):
        super().__init__()

        self.head = nn.Linear(emb_size, n_class)
        self.dist_head = nn.Linear(emb_size, n_class)

    def forward(self, x):
        """Return ``(class_logits, dist_logits)`` in training, their mean in eval."""
        # Index 1 is the distillation token; index -1 would pick up the last
        # image patch instead.
        x, x_dist = x[:, 0, :], x[:, 1, :]
        # Classification head
        x_head = self.head(x)
        # Distillation head
        x_dist_head = self.dist_head(x_dist)

        if self.training:
            return x_head, x_dist_head
        else:
            return (x_head + x_dist_head) / 2

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    A single linear layer produces queries, keys and values; attention is
    computed independently per head, the heads are concatenated, and a final
    linear layer mixes them.
    """

    def __init__(self, embed_size, num_heads=8, dropout=0.1):
        super().__init__()

        self.embed_size = embed_size
        self.num_heads = num_heads
        assert (
            embed_size % num_heads == 0
        ), "Embedding dimension must be  a multiple of number of heads"
        self.head_dim = embed_size // num_heads

        # One fused projection for q, k and v (hence the factor of 3).
        self.qkv_lin = nn.Linear(embed_size, 3 * embed_size)
        self.attn_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(embed_size, embed_size)

    def forward(self, x):
        """Apply self-attention to a 3-D input and return a same-shaped tensor."""
        batch, seq_len, _ = x.shape

        # (B, N, 3E) -> (B, N, 3, heads, head_dim) -> (3, B, heads, N, head_dim)
        fused = self.qkv_lin(x).reshape(
            batch, seq_len, 3, self.num_heads, self.head_dim
        )
        queries, keys, values = fused.permute(2, 0, 3, 1, 4).unbind(0)

        scale = self.head_dim**-0.5
        scores = (queries @ keys.transpose(-2, -1)) * scale
        probs = self.attn_drop(F.softmax(scores, dim=-1))

        # (B, heads, N, head_dim) -> (B, N, heads, head_dim) -> (B, N, E)
        merged = (probs @ values).transpose(1, 2).flatten(2)

        return self.projection(merged)

In [None]:
class FeedForwardBlock(nn.Module):
    """Position-wise MLP: expand, GELU, dropout, project back.

    Note that ``ln1`` and ``ln2`` are linear layers, not layer norms.
    """

    def __init__(self, embed_size, expansion_factor=4, dropout=0.1):
        super().__init__()

        hidden_size = embed_size * expansion_factor

        self.ln1 = nn.Linear(embed_size, hidden_size)
        self.act = nn.GELU()
        self.dropout = nn.Dropout(dropout)
        self.ln2 = nn.Linear(hidden_size, embed_size)

    def forward(self, x):
        """Run the MLP over the last dimension of ``x``."""
        hidden = self.act(self.ln1(x))
        return self.ln2(self.dropout(hidden))

In [None]:
class TransformerEncoderBlock(nn.Module):
    """Pre-norm encoder block: attention and MLP, each with a residual.

    Each sub-layer is applied to a layer-normalised copy of its input, passed
    through dropout, and added back to the un-normalised input.
    """

    def __init__(self, embed_size, num_heads, expansion_factor, dropout):
        super().__init__()

        self.ln1 = nn.LayerNorm(embed_size)
        self.attn = MultiHeadAttention(embed_size, num_heads, dropout)
        self.ln2 = nn.LayerNorm(embed_size)
        self.ff = FeedForwardBlock(embed_size, expansion_factor, dropout)

        # Shared dropout applied to the output of both sub-layers.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the attention and feed-forward sub-layers in sequence."""
        attended = self.dropout(self.attn(self.ln1(x)))
        x = attended + x

        transformed = self.dropout(self.ff(self.ln2(x)))
        return transformed + x

In [None]:
class ViTTransformer(nn.Module):
    """Vision transformer: patch embedding, encoder stack, classification head.

    The constructor arguments are forwarded to ``PatchEmbedding``,
    ``TransformerEncoderBlock`` (repeated ``num_layers`` times) and
    ``ClassificationHead``.
    """

    def __init__(
        self,
        in_channels,
        patch_size,
        emb_size,
        img_size,
        num_layers,
        num_heads,
        expansion_factor,
        dropout,
        n_class,
    ):
        super().__init__()

        self.patch_embed = PatchEmbedding(in_channels, patch_size, emb_size, img_size)

        encoder_layers = []
        for _ in range(num_layers):
            encoder_layers.append(
                TransformerEncoderBlock(emb_size, num_heads, expansion_factor, dropout)
            )
        self.transformer = nn.Sequential(*encoder_layers)

        self.classification_head = ClassificationHead(emb_size, n_class)

    def forward(self, x):
        """Embed ``x``, run the encoder stack, and return the head's output."""
        tokens = self.patch_embed(x)
        encoded = self.transformer(tokens)
        return self.classification_head(encoded)

In [None]:
# Set up the teacher/student pair and the distillation loss for training.
import timm

# Use a convolutional network as the teacher. "res" is not a registered timm
# model name, so a concrete ResNet is named here.
# NOTE(review): "resnet50" is an assumption about the intended teacher; swap in
# another name from timm.list_models() if a different one was meant.
teacher_model = timm.create_model("resnet50", pretrained=True)
# The teacher only supplies targets: keep it in inference mode and frozen so an
# optimizer built over the criterion's parameters cannot update it.
teacher_model.eval()
teacher_model.requires_grad_(False)

student_model = ViTTransformer(
    in_channels=3,
    patch_size=16,
    emb_size=768,
    img_size=224,
    num_layers=12,
    num_heads=12,
    expansion_factor=4,
    dropout=0.1,
    n_class=1000,
)

criterion = HardDistillationLoss(teacher_model)