<a href="https://colab.research.google.com/github/OneFineStarstuff/OneFineStarstuff/blob/main/Knowledge_Distillation_for_Efficient_Deployment.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import GPT2Tokenizer, GPT2LMHeadModel

# Teacher / student checkpoints. Distilling "gpt2" into an identical "gpt2"
# copy gives the student nothing new to learn, so the student defaults to the
# lighter "distilgpt2". It uses the same tokenizer and vocabulary as "gpt2",
# so its logits line up index-for-index with the teacher's. Any other
# checkpoint compatible with the GPT-2 tokenizer can be swapped in here.
teacher_name = "gpt2"
student_name = "distilgpt2"

teacher_model = GPT2LMHeadModel.from_pretrained(teacher_name)
student_model = GPT2LMHeadModel.from_pretrained(student_name)

# GPT-2 ships without a pad token. Add one so the example texts can be padded
# into a single tensor. Then grow both models' embedding matrices so the new
# token id is valid for each of them (both end up with len(tokenizer) rows).
tokenizer = GPT2Tokenizer.from_pretrained(teacher_name)
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
teacher_model.resize_token_embeddings(len(tokenizer))
student_model.resize_token_embeddings(len(tokenizer))

# Example training data. padding=True pads every row to the longest text.
# The returned mapping holds 'input_ids' and an 'attention_mask'
# (1 = real token, 0 = padding), which should be passed to the models so that
# padded positions are ignored.
texts = ["Hello, how are you?", "This is an example text.", "Reinforcement learning with transformers."]
inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=50)

# Distillation loss: element-wise KL(teacher || student) terms.
# reduction='none' keeps the per-element terms so padded positions can be
# masked out before averaging. log_target=True lets the teacher's
# log-probabilities be passed directly, which is numerically more stable than
# passing softmax probabilities.
distillation_criterion = nn.KLDivLoss(reduction='none', log_target=True)

# Temperature for distillation. Values > 1 soften both distributions, so the
# student also learns from the teacher's ranking of less likely tokens.
temperature = 3.0
num_epochs = 3  # Simulating 3 epochs

# Keep both models on one device (GPU when available), and move the data there
# too. This is done before building the optimizer, so it sees the final params.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
teacher_model.to(device)
student_model.to(device)

# from_pretrained() returns models in eval mode. The teacher stays frozen:
# eval mode (no dropout) and no gradients. The student is switched to train
# mode so dropout is active while it learns.
teacher_model.eval()
teacher_model.requires_grad_(False)
student_model.train()

# Optimizer for the student model only.
optimizer = optim.Adam(student_model.parameters(), lr=1e-4)

# Knowledge distillation loop
for epoch in range(num_epochs):
    epoch_loss = 0.0
    num_steps = 0
    # Treat each tokenized row as one training example (batch size 1). Its
    # attention mask travels with it so padding can be ignored.
    for input_ids, attention_mask in zip(inputs['input_ids'], inputs['attention_mask']):
        input_ids = input_ids.unsqueeze(0).to(device)  # add batch dimension
        attention_mask = attention_mask.unsqueeze(0).to(device)

        # Teacher forward pass: it only produces targets, so no autograd graph is needed.
        with torch.no_grad():
            teacher_logits = teacher_model(input_ids, attention_mask=attention_mask).logits

        # Student forward pass (tracked by autograd).
        student_logits = student_model(input_ids, attention_mask=attention_mask).logits

        # Temperature-softened log-distributions over the vocabulary (last dim).
        student_log_probs = nn.functional.log_softmax(student_logits / temperature, dim=-1)
        teacher_log_probs = nn.functional.log_softmax(teacher_logits / temperature, dim=-1)

        # Sum the KL terms over the vocabulary -> one value per token position.
        per_token_kl = distillation_criterion(student_log_probs, teacher_log_probs).sum(dim=-1)

        # Average over real (non-padding) tokens only. clamp_min prevents a
        # division by zero for an all-padding row. Multiplying by T^2 keeps
        # gradient magnitudes comparable across temperatures, following
        # Hinton et al.'s distillation formulation.
        token_mask = attention_mask.to(per_token_kl.dtype)
        distillation_loss = (per_token_kl * token_mask).sum() / token_mask.sum().clamp_min(1.0)
        distillation_loss = distillation_loss * temperature ** 2

        # Backward pass and optimization step
        optimizer.zero_grad()
        distillation_loss.backward()
        optimizer.step()

        epoch_loss += distillation_loss.item()
        num_steps += 1

    print(f"Epoch {epoch + 1}/{num_epochs}: mean distillation loss {epoch_loss / max(num_steps, 1):.4f}")

print("Knowledge distillation completed.")