In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# tiny fake vocab
vocab_size = 8

# tiny fake Transformer-like model
class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(16, 32)
        self.linear2 = nn.Linear(32, vocab_size)

    def forward(self, x):
        h = F.relu(self.linear1(x))
        return self.linear2(h)  # logits [batch, seq, vocab]


In [2]:
teacher = TinyModel()
student = TinyModel()


In [3]:
batch = 1
seq_len = 5
embed_dim = 16

x = torch.randn(batch, seq_len, embed_dim)


In [4]:
teacher_logits = teacher(x)


In [5]:
student_logits = student(x)


In [9]:
print(student_logits.shape)
print(teacher_logits.shape)


torch.Size([1, 5, 8])
torch.Size([1, 5, 8])


In [11]:
# Temperature for knowledge distillation
T = 3.0  # Common values are 3-5

In [12]:
p_teacher = F.softmax(teacher_logits / T, dim=-1)
p_student_log = F.log_softmax(student_logits / T, dim=-1)


In [13]:
kl = F.kl_div(p_student_log, p_teacher, reduction="batchmean") * (T*T)


In [14]:
print(f"KL Divergence Loss: {kl.item():.4f}")

KL Divergence Loss: 0.1758
