<a href="https://colab.research.google.com/github/OneFineStarstuff/OneFineStarstuff/blob/main/Knowledge_Distillation_for_Model_Compression.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch.utils.data import DataLoader

# Define the DistillationLoss class
class DistillationLoss(nn.Module):
    """Temperature-softened KL-divergence loss for knowledge distillation.

    Teacher and student logits are both divided by ``temperature`` before
    being normalized over the last dimension. The result is
    ``KL(teacher || student)`` as computed by ``F.kl_div`` with
    ``reduction="batchmean"``, i.e. the summed divergence divided by the size
    of the first dimension of the inputs.

    Args:
        temperature: Positive softening factor applied to both sets of
            logits. Larger values produce flatter distributions.

    Raises:
        ValueError: If ``temperature`` is not strictly positive.
    """

    def __init__(self, temperature=2.0):
        super().__init__()
        # A zero or negative temperature would divide the logits by zero or
        # flip their ordering; reject it up front instead of yielding NaNs.
        if not temperature > 0:
            raise ValueError(f"temperature must be positive, got {temperature!r}")
        self.temperature = temperature

    def forward(self, teacher_logits, student_logits):
        """Return the distillation loss between teacher and student logits.

        Args:
            teacher_logits: Teacher outputs; classes on the last dimension.
            student_logits: Student outputs with the same shape.

        Returns:
            A scalar tensor holding the batch-mean KL divergence.
        """
        teacher_probs = F.softmax(teacher_logits / self.temperature, dim=-1)
        # log_softmax stays finite where softmax(...).log() would underflow to
        # log(0) = -inf and turn the whole loss into inf/NaN.
        student_log_probs = F.log_softmax(student_logits / self.temperature, dim=-1)
        # NOTE(review): the loss is not multiplied by temperature ** 2 as in
        # Hinton et al.; callers mixing it with a hard-label loss may want to
        # apply that scaling themselves.
        return F.kl_div(student_log_probs, teacher_probs, reduction="batchmean")

# Initialize teacher and student models and the tokenizer
teacher_model = AutoModelForCausalLM.from_pretrained("gpt2")
student_model = AutoModelForCausalLM.from_pretrained("distilgpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# The GPT-2 tokenizer has no padding token; reuse EOS so sentences of
# different lengths can be padded into one rectangular batch.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Example training data (dataloader)
texts = [
    "Example sentence one",
    "Example sentence two",
    # Add more training data here
]

# Tokenize everything in one call with padding so every sample has the same
# length (the default collate function can only stack equal-shaped tensors),
# and keep the attention mask that marks real tokens versus padding.
encoded = tokenizer(texts, return_tensors="pt", padding=True)
data = [
    {"input_ids": ids, "attention_mask": mask}
    for ids, mask in zip(encoded["input_ids"], encoded["attention_mask"])
]

train_dataloader = DataLoader(data, batch_size=2, shuffle=True)
optimizer = Adam(student_model.parameters(), lr=5e-5)

# Instantiate the distillation loss function
distillation_loss_fn = DistillationLoss()

# Put the teacher in inference mode (disables dropout) and freeze its
# parameters; eval() alone does not stop gradients from being tracked.
teacher_model.eval()
teacher_model.requires_grad_(False)

# Training loop: distill the teacher's next-token distributions into the
# student by minimizing distillation_loss_fn on every batch.
for epoch in range(3):  # Adjust number of epochs as needed
    student_model.train()
    for batch in train_dataloader:
        input_ids = batch["input_ids"].to(student_model.device)

        # Use the attention mask when the batch carries one (padded batches);
        # batches without it are processed exactly as before.
        attention_mask = batch.get("attention_mask")
        if attention_mask is not None:
            attention_mask = attention_mask.to(student_model.device)

        # Forward pass for student model
        student_logits = student_model(
            input_ids=input_ids, attention_mask=attention_mask
        ).logits

        # Forward pass for teacher model (no gradients)
        with torch.no_grad():
            teacher_logits = teacher_model(
                input_ids=input_ids, attention_mask=attention_mask
            ).logits

        if attention_mask is not None:
            # At padded positions, make the target equal to the student's own
            # (detached) output so those positions add zero divergence and
            # essentially zero gradient, while the loss normalization stays
            # the same as in the unmasked case.
            keep = attention_mask.bool().unsqueeze(-1)
            teacher_logits = torch.where(keep, teacher_logits, student_logits.detach())

        # Compute distillation loss
        loss = distillation_loss_fn(teacher_logits, student_logits)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(f"Epoch: {epoch + 1}, Loss: {loss.item()}")

print("Training complete!")