<a href="https://colab.research.google.com/github/OneFineStarstuff/OneFineStarstuff/blob/main/Knowledge_Distillation_Example.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn.functional as F
from torch.optim import Adam
from transformers import AutoModelForSequenceClassification, AutoTokenizer

class DistillationTrainer:
    """Train a student model to imitate a teacher via knowledge distillation.

    The training loss blends two terms:

    * a KL-divergence between the temperature-softened teacher and student
      output distributions, and
    * the ordinary cross-entropy of the student against the hard labels.

    Only the student's parameters are given to the optimizer. Both models are
    called as ``model(**input_data)`` and must return an object that exposes
    a ``.logits`` tensor.
    """

    def __init__(self, teacher_model, student_model, temperature=3.0, alpha=0.5, learning_rate=1e-4):
        """Store the models and hyper-parameters and build the optimizer.

        Args:
            teacher_model: Model whose logits provide the soft targets.
            student_model: Model being trained.
            temperature: Divisor applied to both sets of logits before the
                softmax; must be strictly positive.
            alpha: Weight of the distillation term; the hard-label
                cross-entropy term is weighted by ``1 - alpha``.
            learning_rate: Learning rate of the Adam optimizer over the
                student's parameters.

        Raises:
            ValueError: If ``temperature`` is not strictly positive.
        """
        # The logits are divided by the temperature, so zero would yield
        # inf/nan losses silently and a negative value would invert them.
        if temperature <= 0:
            raise ValueError(f"temperature must be > 0, got {temperature!r}")

        self.teacher = teacher_model
        self.student = student_model
        self.temperature = temperature
        self.alpha = alpha
        self.optimizer = Adam(self.student.parameters(), lr=learning_rate)

    def distillation_loss(self, student_logits, teacher_logits, labels):
        """Return the weighted sum of the soft-target and hard-label losses.

        Args:
            student_logits: Raw student outputs, classes on the last axis.
            teacher_logits: Raw teacher outputs, same shape as
                ``student_logits``.
            labels: Target class indices for the cross-entropy term.

        Returns:
            A scalar tensor:
            ``alpha * soft_loss + (1 - alpha) * hard_loss``.
        """
        teacher_probs = F.softmax(teacher_logits / self.temperature, dim=-1)
        # F.kl_div expects its first argument in log-space and (with the
        # default log_target=False) its second as plain probabilities.
        student_log_probs = F.log_softmax(student_logits / self.temperature, dim=-1)
        # Multiplying by T**2 is the usual distillation correction: it
        # offsets the 1/T**2 shrinkage of soft-target gradients.
        soft_loss = F.kl_div(
            student_log_probs, teacher_probs, reduction="batchmean"
        ) * (self.temperature ** 2)
        hard_loss = F.cross_entropy(student_logits, labels)
        return self.alpha * soft_loss + (1 - self.alpha) * hard_loss

    def train_step(self, input_data, labels):
        """Run one optimization step on a single batch.

        Args:
            input_data: Mapping of keyword arguments forwarded unchanged to
                both models.
            labels: Target class indices for the batch.

        Returns:
            The batch loss as a Python float.
        """
        self.student.train()
        self.teacher.eval()

        # The teacher only supplies targets, so no graph is built for it.
        with torch.no_grad():
            teacher_logits = self.teacher(**input_data).logits

        student_logits = self.student(**input_data).logits
        loss = self.distillation_loss(student_logits, teacher_logits, labels)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.item()

    def train(self, dataloader, epochs):
        """Train the student for ``epochs`` passes over ``dataloader``.

        Each batch must be a mapping with ``"input_ids"``,
        ``"attention_mask"`` and ``"labels"`` entries. The average loss of
        every epoch is printed.

        Args:
            dataloader: Iterable of batches. It does not need to support
                ``len()``; batches are counted as they are consumed.
            epochs: Number of passes over ``dataloader``.

        Returns:
            A list with the average loss of each epoch, in order. An epoch
            that yielded no batches contributes ``nan``.
        """
        epoch_losses = []
        for epoch in range(epochs):
            total_loss = 0.0
            num_batches = 0
            for batch in dataloader:
                input_data = {
                    "input_ids": batch["input_ids"],
                    "attention_mask": batch["attention_mask"],
                }
                total_loss += self.train_step(input_data, batch["labels"])
                num_batches += 1

            # Counting consumed batches avoids len(dataloader), which is
            # undefined for some iterables and zero for an empty loader.
            avg_loss = total_loss / num_batches if num_batches else float("nan")
            print(f"Epoch [{epoch + 1}/{epochs}], Loss: {avg_loss:.4f}")
            epoch_losses.append(avg_loss)

        return epoch_losses

# Helper function to pad sequences
def pad_sequences(tokenizer, texts, max_length):
    """Tokenize ``texts`` with fixed-length padding and return the result.

    The tokenizer is called with ``padding="max_length"``, truncation
    enabled, the given ``max_length`` and ``return_tensors="pt"``; whatever
    it returns is passed straight back to the caller.
    """
    encode_options = {
        "padding": "max_length",
        "truncation": True,
        "max_length": max_length,
        "return_tensors": "pt",
    }
    encoded = tokenizer(texts, **encode_options)
    return encoded

# Example usage
if __name__ == "__main__":
    # Teacher and student are two separate instances loaded from the same
    # checkpoint, each with a 2-way classification head.
    model_name = "bert-base-uncased"
    teacher_model, student_model = (
        AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)
        for _ in range(2)
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Toy corpus: two sentences with one label each.
    texts = ["Sample text for training", "Another sample text"]
    labels = torch.tensor([1, 0])

    # Tokenize every text to the same fixed length.
    max_length = 10
    padded_inputs = pad_sequences(tokenizer, texts, max_length)

    # One dict per example, carrying the keys DistillationTrainer.train reads
    # from each batch.
    field_names = ("input_ids", "attention_mask", "labels")
    columns = (padded_inputs["input_ids"], padded_inputs["attention_mask"], labels)
    data = [dict(zip(field_names, row)) for row in zip(*columns)]

    dataloader = torch.utils.data.DataLoader(dataset=data, batch_size=2)

    epochs = 10

    trainer = DistillationTrainer(
        teacher_model=teacher_model,
        student_model=student_model,
        temperature=3.0,
        alpha=0.5,
        learning_rate=1e-4,
    )
    trainer.train(dataloader, epochs=epochs)