In [None]:
!pip install transformers datasets torch

In [2]:
import torch.nn as nn

In [None]:
# From a new Colab cell:
!pip install --upgrade datasets transformers huggingface_hub gcsfs fsspec

In [6]:
import os
import torch
import torch.nn as nn
from transformers import (
    AutoTokenizer,
    AutoModelForQuestionAnswering,
    Trainer,
    TrainingArguments,
    pipeline
)
from datasets import load_dataset

# 1. Load the SQuAD dataset
dataset = load_dataset("squad")

# 2. Define teacher and student models
teacher_name = "bert-base-uncased"
student_name = "distilbert-base-uncased"

# Initialize tokenizer and models
tokenizer = AutoTokenizer.from_pretrained(teacher_name)
teacher_model = AutoModelForQuestionAnswering.from_pretrained(teacher_name)
student_model = AutoModelForQuestionAnswering.from_pretrained(student_name)

# 3. Preprocessing function for question answering
max_length = 384
doc_stride = 128

def prepare_features(examples):
    tokenized = tokenizer(
        examples['question'], examples['context'],
        truncation='only_second', max_length=max_length,
        stride=doc_stride, return_overflowing_tokens=True,
        return_offsets_mapping=True, padding='max_length'
    )
    sample_mapping = tokenized.pop('overflow_to_sample_mapping')
    offset_mapping = tokenized.pop('offset_mapping')

    start_positions = []
    end_positions = []

    for i, offsets in enumerate(offset_mapping):
        input_ids = tokenized['input_ids'][i]
        sequence_ids = tokenized.sequence_ids(i)
        sample_index = sample_mapping[i]
        answers = examples['answers'][sample_index]

        if len(answers['answer_start']) == 0:
            start_positions.append(0)
            end_positions.append(0)
        else:
            start_char = answers['answer_start'][0]
            end_char   = start_char + len(answers['text'][0])
            # find token start/end
            token_start_index = 0
            while sequence_ids[token_start_index] != 1:
                token_start_index += 1
            token_end_index = len(input_ids) - 1
            while sequence_ids[token_end_index] != 1:
                token_end_index -= 1

            # if answer not fully in span
            if not (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char):
                start_positions.append(0)
                end_positions.append(0)
            else:
                while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:
                    token_start_index += 1
                start_positions.append(token_start_index - 1)
                while offsets[token_end_index][1] >= end_char:
                    token_end_index -= 1
                end_positions.append(token_end_index + 1)

    tokenized['start_positions'] = start_positions
    tokenized['end_positions']   = end_positions
    return tokenized

# 4. Tokenize the dataset
tokenized_datasets = dataset.map(
    prepare_features,
    batched=True,
    remove_columns=dataset['train'].column_names
)

# ───────────────────────────────────────────────────────────────────────────────
# 5) Tokenize the entire dataset
# ───────────────────────────────────────────────────────────────────────────────
tokenized_datasets = dataset.map(
    prepare_features,
    batched=True,
    remove_columns=dataset["train"].column_names
)


Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of DistilBertForQuestionAnswering were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [7]:
# 4. Custom Trainer for distillation
class DistillationTrainer(Trainer):
    def __init__(self, teacher_model, alpha=0.5, temperature=2.0, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.teacher_model = teacher_model
        self.alpha = alpha
        self.temperature = temperature
        self.kl_loss = nn.KLDivLoss(reduction='batchmean')

    def compute_loss(self, model, inputs, return_outputs=False):
        labels = {
            'start_positions': inputs.pop('start_positions'),
            'end_positions': inputs.pop('end_positions')
        }
        # Student forward pass
        outputs_student = model(**inputs, **labels)
        loss_ce = outputs_student.loss

        # Teacher forward (no gradients)
        with torch.no_grad():
            outputs_teacher = self.teacher_model(**inputs)

        t = self.temperature
        # Distillation losses
        loss_start = self.kl_loss(
            nn.functional.log_softmax(outputs_student.start_logits / t, dim=-1),
            nn.functional.softmax(outputs_teacher.start_logits / t, dim=-1)
        ) * (t * t)
        loss_end = self.kl_loss(
            nn.functional.log_softmax(outputs_student.end_logits / t, dim=-1),
            nn.functional.softmax(outputs_teacher.end_logits / t, dim=-1)
        ) * (t * t)
        loss_distill = (loss_start + loss_end) / 2

        # Combine CE and distillation losses
        loss = self.alpha * loss_distill + (1 - self.alpha) * loss_ce
        return (loss, outputs_student) if return_outputs else loss

In [8]:
# 5. Set up training arguments (without evaluation_strategy for compatibility)
training_args = TrainingArguments(
    output_dir='./models_distilled',
    learning_rate=3e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=2,
    save_total_limit=2,
    logging_steps=100,
    report_to=[]
)

In [9]:
# 6. Initialize the distillation trainer
trainer = DistillationTrainer(
    teacher_model=teacher_model,
    model=student_model,
    args=training_args,
    train_dataset=tokenized_datasets['train'],
    eval_dataset=tokenized_datasets['validation'],
    tokenizer=tokenizer
)

  super().__init__(*args, **kwargs)


In [10]:
# 4. Custom Trainer for distillation
class DistillationTrainer(Trainer):
    def __init__(self, teacher_model, alpha=0.5, temperature=2.0, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.teacher_model = teacher_model
        self.alpha = alpha
        self.temperature = temperature
        # Ensure teacher is on the same device as student
        self.teacher_model.to(self.args.device)
        self.kl_loss = nn.KLDivLoss(reduction='batchmean')

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        labels = {
            'start_positions': inputs.pop('start_positions'),
            'end_positions': inputs.pop('end_positions')
        }
        # Student forward pass
        outputs_student = model(**inputs, **labels)
        loss_ce = outputs_student.loss

        # Teacher forward (no gradients)
        with torch.no_grad():
            outputs_teacher = self.teacher_model(**inputs)

        t = self.temperature
        # Distillation losses
        loss_start = self.kl_loss(
            nn.functional.log_softmax(outputs_student.start_logits / t, dim=-1),
            nn.functional.softmax(outputs_teacher.start_logits / t, dim=-1)
        ) * (t * t)
        loss_end = self.kl_loss(
            nn.functional.log_softmax(outputs_student.end_logits / t, dim=-1),
            nn.functional.softmax(outputs_teacher.end_logits / t, dim=-1)
        ) * (t * t)
        loss_distill = (loss_start + loss_end) / 2

        # Combine CE and distillation losses
        loss = self.alpha * loss_distill + (1 - self.alpha) * loss_ce
        return (loss, outputs_student) if return_outputs else loss

# 5. Set up training arguments (without evaluation_strategy for compatibility)
training_args = TrainingArguments(
    output_dir='./models_distilled',
    learning_rate=3e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=1,
    save_total_limit=2,
    logging_steps=100,
    report_to=[]
)

# 6. Initialize the distillation trainer
trainer = DistillationTrainer(
    teacher_model=teacher_model,
    model=student_model,
    args=training_args,
    train_dataset=tokenized_datasets['train'],
    eval_dataset=tokenized_datasets['validation'],
    tokenizer=tokenizer
)

# 7. Train and save the distilled student model
trainer.train()
trainer.save_model('./distilled_student_model')

# 8. Print compression statistics
def print_compression_stats(teacher, student):
    teacher_params = sum(p.numel() for p in teacher.parameters())
    student_params = sum(p.numel() for p in student.parameters())
    print(f"Teacher parameters: {teacher_params}")
    print(f"Student parameters: {student_params}")
    print(f"Compression ratio (teacher/student): {teacher_params / student_params:.2f}x")

print_compression_stats(teacher_model, student_model)

  super().__init__(*args, **kwargs)
W0530 20:12:38.562000 907 torch/_inductor/utils.py:1137] [0/0] Not enough SMs to use max_autotune_gemm mode


Step,Training Loss
100,2.401
200,1.7383
300,1.5626
400,1.3962
500,1.3563
600,1.2861
700,1.248
800,1.2466
900,1.2046
1000,1.1538


Teacher parameters: 108893186
Student parameters: 66364418
Compression ratio (teacher/student): 1.64x


In [12]:
def answer_question(question, context):
    inputs = tokenizer(question, context, return_tensors='pt')
    # remove token_type_ids if present
    inputs = {k: v.to(trainer.args.device) for k, v in inputs.items() if k != 'token_type_ids'}
    outputs = student_model(**inputs)
    start_idx = torch.argmax(outputs.start_logits, dim=-1).item()
    end_idx = torch.argmax(outputs.end_logits, dim=-1).item() + 1
    tokens = inputs['input_ids'][0][start_idx:end_idx]
    answer = tokenizer.decode(tokens, skip_special_tokens=True)
    # average of start/end probabilities as score
    prob_start = torch.softmax(outputs.start_logits, dim=-1)[0][start_idx]
    prob_end = torch.softmax(outputs.end_logits, dim=-1)[0][end_idx-1]
    score = ((prob_start + prob_end) / 2).item()
    return answer, score

# 9. Example QA
context = "The Transformer architecture was introduced in the paper Attention is All You Need in 2017."
question = "When was the Transformer architecture introduced?"
ans, scr = answer_question(question, context)
print("Answer:", ans)
print("Score:", scr)

Answer: 2017
Score: 0.8541181087493896
