**Initial Setup**

In [1]:
import pandas as pd
import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.amp import GradScaler, autocast
from transformers import AutoModelForCausalLM, AutoTokenizer

**Load Models**

In [2]:
# Load Teacher and Student Models
#teacher_model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
teacher_model_name = "Jesujuwon/distilgpt2-squad"
#student_model_name = "Locutusque/TinyMistral-248M"
student_model_name = "tniranjan/finetuned_tinystories_33M_pretrained_tinystories_ta"

teacher_model = AutoModelForCausalLM.from_pretrained(teacher_model_name)
student_model = AutoModelForCausalLM.from_pretrained(student_model_name)

teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_model_name)
student_tokenizer = AutoTokenizer.from_pretrained(student_model_name)

**Dataset Processing**

In [3]:
# Preprocessing function
def preprocess_batch(batch, tokenizer, max_length=256):
    """
    Preprocesses a batch of examples into tokenized format.
    """
    # Extract questions and contexts from the batch
    questions = [example["question"] for example in batch]
    contexts = [example["context"] for example in batch]
    
    # Tokenize context and question
    inputs = tokenizer(
        questions,
        contexts,
        truncation=True,
        padding="max_length",
        max_length=max_length,
        return_tensors="pt"
    )
    
    # Extract answer text (use the first answer for simplicity)
    answer_texts = [
        example["answers"]["text"][0] if len(example["answers"]["text"]) > 0 else "" 
        for example in batch
    ]
    
    # Tokenize answers
    labels = tokenizer(
        answer_texts,
        truncation=True,
        padding="max_length",
        max_length=max_length,
        return_tensors="pt"
    )["input_ids"]
    
    # Add labels to inputs
    inputs["labels"] = labels
    
    return inputs

**Load Dataset**

In [4]:
# Load SQuAD1.1 Dataset
splits = {'train': 'plain_text/train-00000-of-00001.parquet', 'validation': 'plain_text/validation-00000-of-00001.parquet'}
train_df = pd.read_parquet("hf://datasets/rajpurkar/squad/" + splits["train"])
validation_df = pd.read_parquet("hf://datasets/rajpurkar/squad/" + splits["validation"])

**Process Dataset**

In [5]:
# Convert DataFrame to a list of dictionaries for batch processing
train_data = train_df.to_dict(orient="records")
validation_data = validation_df.to_dict(orient="records")

# Create a PyTorch Dataset
class QADataset(torch.utils.data.Dataset):
    def __init__(self, input_ids, attention_mask, labels):
        self.input_ids = input_ids
        self.attention_mask = attention_mask
        self.labels = labels

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return {
            "input_ids": self.input_ids[idx],
            "attention_mask": self.attention_mask[idx],
            "labels": self.labels[idx]
        }

# Preprocess data in batches
def process_dataset(data, tokenizer):
    processed_data = []
    batch_size = 32

    for i in range(0, len(data), batch_size):
        batch = data[i:i+batch_size]
        processed_batch = preprocess_batch(batch, tokenizer)
        processed_data.append(processed_batch)

    # Combine all processed batches into a single dataset
    input_ids = torch.cat([batch["input_ids"] for batch in processed_data])
    attention_mask = torch.cat([batch["attention_mask"] for batch in processed_data])
    labels = torch.cat([batch["labels"] for batch in processed_data])

    dataset = QADataset(input_ids, attention_mask, labels)

    return dataset

train_dataset = process_dataset(train_data, student_tokenizer)
validation_dataset = process_dataset(validation_data, student_tokenizer)

train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
val_loader = DataLoader(validation_dataset, batch_size=4, shuffle=True)

**Teacher Vocab is much larger than Students so need to configure a projection layer so both tensors will align**

In [6]:
# Define a projection layer to align teacher logits with student logits
class TeacherLogitProjection(nn.Module):
    def __init__(self, teacher_vocab_size, student_vocab_size):
        super(TeacherLogitProjection, self).__init__()
        self.projection = nn.Linear(teacher_vocab_size, student_vocab_size)

    def forward(self, teacher_logits):
        return self.projection(teacher_logits)


In [7]:
# Define Distillation Loss
class DistillationLoss(nn.Module):
    def __init__(self, alpha=0.5, temperature=2.0):
        super(DistillationLoss, self).__init__()
        self.alpha = alpha  # Weight for hard vs soft targets
        self.temperature = temperature  # Temperature for softening logits
        self.ce_loss = nn.CrossEntropyLoss()

    def forward(self, student_logits, teacher_logits, labels):
        # Softened teacher logits
        teacher_probs = nn.functional.softmax(teacher_logits / self.temperature, dim=-1)
        student_probs = nn.functional.log_softmax(student_logits / self.temperature, dim=-1)

        # KL Divergence loss (soft targets)
        distillation_loss = nn.functional.kl_div(student_probs, teacher_probs, reduction="batchmean") * (self.temperature ** 2)

        # Cross-entropy loss (hard targets)
        ce_loss = self.ce_loss(student_logits.view(-1, student_logits.size(-1)), labels.view(-1))

        return self.alpha * ce_loss + (1 - self.alpha) * distillation_loss

**Knowledge Distillation Function**

In [8]:

def train_student_with_distillation(teacher_model, student_model, train_loader, val_loader, epochs):
    teacher_device = "cpu"  # Move teacher model to CPU
    #student_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    student_device = "cuda"  # Move student model to CPU
    teacher_model.to(teacher_device)
    student_model.to(student_device)

    optimizer = torch.optim.AdamW(student_model.parameters(), lr=5e-5)
    distillation_criterion = DistillationLoss(alpha=0.5, temperature=2.0)
    scaler = GradScaler('cuda')  # Mixed precision scaler to reduce VRAM use

    # Add projection layer for teacher logits
    teacher_vocab_size = teacher_model.config.vocab_size
    student_vocab_size = student_model.config.vocab_size
    projection_layer = TeacherLogitProjection(teacher_vocab_size, student_vocab_size).to(teacher_device)
    
    for epoch in range(epochs):
        student_model.train()
        teacher_model.eval()

        epoch_loss = 0

        for batch in train_loader:
            input_ids = batch["input_ids"].to(student_device)
            attention_mask = batch["attention_mask"].to(student_device)
            labels = batch["labels"].to(student_device)

            optimizer.zero_grad()

            # Mixed precision context
            with autocast('cuda'):
                with torch.no_grad():
                    teacher_outputs = teacher_model(input_ids=input_ids.to(teacher_device), attention_mask=attention_mask.to(teacher_device))
                    teacher_logits = projection_layer(teacher_outputs.logits)
                    projection_layer.to(student_device)

                student_outputs = student_model(input_ids=input_ids, attention_mask=attention_mask)
                student_logits = student_outputs.logits

                loss = distillation_criterion(student_logits, teacher_logits, labels)

            scaler.scale(loss).backward()  # Scale gradients
            scaler.step(optimizer)  # Step optimizer
            scaler.update()  # Update scaler

            epoch_loss += loss.item()

        print(f"Epoch {epoch + 1}, Loss: {epoch_loss / len(train_loader)}")


In [9]:
# Run Training
train_student_with_distillation(
    teacher_model,
    student_model,
    train_loader,
    val_loader,
    epochs=3
)

OutOfMemoryError: CUDA out of memory. Tried to allocate 9.36 GiB. GPU 0 has a total capacity of 11.60 GiB of which 8.31 GiB is free. Including non-PyTorch memory, this process has 482.00 MiB memory in use. Of the allocated memory 278.11 MiB is allocated by PyTorch, and 11.89 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)