In [13]:
# Install necessary libraries from Hugging Face and PyTorch
!pip install transformers datasets accelerate bitsandbytes torch



In [14]:
# Log in to your Hugging Face account.
# notebook_login() renders an interactive widget in the cell output for the login.
# NOTE(review): presumably needed because the teacher checkpoint loaded later is
# access-restricted on the Hub — confirm.
from huggingface_hub import notebook_login
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [15]:
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification, BitsAndBytesConfig
from datasets import load_dataset

In [16]:
# --- 1. Load Tokenizer and Dataset ---
# The student checkpoint name also selects the tokenizer used for every batch.
student_model_name = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(student_model_name)

# IMDB dataset, fetched through the `datasets` library.
imdb = load_dataset("imdb")


def preprocess_function(examples):
    """Tokenize the ``"text"`` field of a batch of examples.

    Truncation is enabled; no padding is requested here, so sequences keep
    their individual lengths.
    """
    review_texts = examples["text"]
    return tokenizer(review_texts, truncation=True)


# Step 1: tokenize every split in batches.
tokenized_imdb = imdb.map(preprocess_function, batched=True)

# Step 2: draw fixed-seed shuffled subsets from the already-tokenized splits.
train_dataset = tokenized_imdb["train"].shuffle(seed=42).select(range(10_000))
eval_dataset = tokenized_imdb["test"].shuffle(seed=42).select(range(1_000))

print("Dataset loaded and tokenized.")

Map:   0%|          | 0/25000 [00:00<?, ? examples/s]

Dataset loaded and tokenized.


In [17]:
# --- 2. Load Student Model ---
# Sequence-classification model with two output labels (positive / negative sentiment),
# built from the same checkpoint name as the tokenizer.
student_model = AutoModelForSequenceClassification.from_pretrained(student_model_name, num_labels=2)
print("Student model loaded.")

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Student model loaded.


In [18]:
# --- 3. Load Teacher Model ---
teacher_model_name = "google/gemma-3-4b-it"

# Gemma 3 checkpoints are published in bfloat16, and running them in float16 can
# overflow to inf/NaN, which then shows up downstream as a NaN loss. Prefer
# bfloat16 whenever the GPU supports it natively (compute capability >= 8.0) and
# keep float16 as the fallback for older GPUs.
_has_native_bf16 = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8
teacher_dtype = torch.bfloat16 if _has_native_bf16 else torch.float16

# Pro-Tip: Load the large teacher model in 4-bit to make it fit in Colab's memory.
# This is a practical application of quantization!
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=teacher_dtype,
)

teacher_model = AutoModelForSequenceClassification.from_pretrained(
    teacher_model_name,
    num_labels=2,
    quantization_config=quantization_config,
    torch_dtype=teacher_dtype,  # dtype for the layers that are NOT quantized (norms, embeddings, head)
    device_map="auto",  # Automatically map model layers to available devices (GPU)
)

# NOTE: loading this checkpoint with a sequence-classification head reports the
# head (`score.weight`) as newly initialized, i.e. the teacher has not been trained
# for this 2-class task. Its logits only become a useful distillation target after
# the teacher itself has been fine-tuned on the task.
print("Teacher model loaded in 4-bit precision.")

config.json:   0%|          | 0.00/855 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/90.6k [00:00<?, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.96G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/3.64G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Some weights of Gemma3ForSequenceClassification were not initialized from the model checkpoint at google/gemma-3-4b-it and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Teacher model loaded in 4-bit precision.


In [19]:
from transformers import Trainer, TrainingArguments
import torch.nn.functional as F
import torch

# This is the core of your project's logic, where you customize the training process.
class DistillationTrainer(Trainer):
    """Trainer whose loss mixes the student's own loss with a distillation term.

    The combined loss is::

        alpha * student_loss + (1 - alpha) * T**2 * KL(teacher_T || student_T)

    where ``*_T`` are the softmax distributions of the logits divided by the
    temperature ``T``.

    The teacher is called with the same batch as the student (minus ``labels``),
    so the ``input_ids`` must be meaningful to the teacher as well: teacher and
    student have to share a tokenizer vocabulary, and both must emit logits with
    the same number of classes.

    Args:
        teacher_model: Frozen model providing the soft targets. When ``None`` the
            trainer falls back to the plain student loss.
        alpha: Weight of the student's own loss, in ``[0, 1]``.
        temperature: Softmax temperature (> 0) applied to both sets of logits.
        *args, **kwargs: Forwarded unchanged to ``Trainer.__init__``.

    Raises:
        ValueError: If ``alpha`` is outside ``[0, 1]`` or ``temperature`` is not positive.
    """

    def __init__(self, *args, teacher_model=None, alpha=0.5, temperature=2.0, **kwargs):
        super().__init__(*args, **kwargs)
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f"alpha must be in [0, 1], got {alpha}")
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature}")

        self.teacher = teacher_model
        self.alpha = alpha
        self.temperature = temperature

        if self.teacher is not None:
            # The teacher is only ever used for inference.
            self.teacher.eval()
            # A teacher loaded with `device_map=...` and/or bitsandbytes quantization
            # has already been placed on its device(s); calling `.to()` on such a
            # model is unsupported or at best a no-op, so only move plain models.
            already_placed = (
                getattr(self.teacher, "hf_device_map", None) is not None
                or getattr(self.teacher, "is_quantized", False)
            )
            if not already_placed:
                self.teacher.to(self.model.device)

    # **kwargs absorbs extra arguments the base Trainer may pass (e.g. num_items_in_batch).
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Return the combined student + distillation loss for one batch.

        Args:
            model: The student model being trained.
            inputs: Mapping of keyword arguments for the models, including ``labels``.
            return_outputs: When true, also return the student's outputs.

        Returns:
            ``loss`` or ``(loss, student_outputs)`` depending on ``return_outputs``.
        """
        # --- Standard student loss against the true labels ---
        outputs_student = model(**inputs)
        student_loss = outputs_student.loss

        # Without a teacher there is nothing to distil from.
        if self.teacher is None:
            return (student_loss, outputs_student) if return_outputs else student_loss

        # --- Distillation loss ---
        # The teacher only needs to produce logits; dropping `labels` avoids making
        # it compute a loss that is never used.
        teacher_inputs = {key: value for key, value in inputs.items() if key != "labels"}
        with torch.no_grad():
            outputs_teacher = self.teacher(**teacher_inputs)

        # Do the softmax / KL arithmetic in float32: the teacher may emit half-precision
        # logits, and it may live on a different device than the student.
        student_logits = outputs_student.logits.float()
        teacher_logits = outputs_teacher.logits.to(device=student_logits.device, dtype=torch.float32)

        temperature = self.temperature
        # T**2 keeps the gradient magnitude of the soft-target term comparable
        # across different temperatures.
        distillation_loss = F.kl_div(
            input=F.log_softmax(student_logits / temperature, dim=-1),
            target=F.softmax(teacher_logits / temperature, dim=-1),
            reduction="batchmean",
        ) * (temperature ** 2)

        # Weighted sum of the hard-label loss and the soft-target loss.
        loss = self.alpha * student_loss + (1.0 - self.alpha) * distillation_loss
        return (loss, outputs_student) if return_outputs else loss

# --- Training arguments for a quick proof-of-concept run ---
training_args = TrainingArguments(
    output_dir="distilled_model_checkpoint",
    # Hard cap on the number of training steps; a positive max_steps takes
    # precedence over num_train_epochs, so the run stops after 200 steps.
    max_steps=200,
    num_train_epochs=1,
    # Small batches to stay within memory.
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    # Mixed precision for speed.
    fp16=True,
    logging_steps=50,
    save_strategy="epoch",
)

# --- Instantiate and Run the Trainer ---
trainer = DistillationTrainer(
    model=student_model,
    args=training_args,
    # Subsets created in the dataset cell
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    # `tokenizer=` is deprecated on Trainer (it triggers a FutureWarning pointing at
    # `super().__init__`); `processing_class=` is the replacement and still gives
    # the Trainer what it needs to pad batches.
    processing_class=tokenizer,
    teacher_model=teacher_model,
)

print("Starting distillation training...")
trainer.train()
print("Proof-of-concept training complete.")

  super().__init__(*args, **kwargs)


Starting distillation training...


  | |_| | '_ \/ _` / _` |  _/ -_)


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize?ref=models
wandb: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mgjaswink[0m ([33mgjaswin[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Step,Training Loss
50,0.0
100,0.0
150,0.0
200,0.0


Proof-of-concept training complete.


In [20]:
!pip install evaluate

Collecting evaluate
  Downloading evaluate-0.4.6-py3-none-any.whl.metadata (9.5 kB)
Downloading evaluate-0.4.6-py3-none-any.whl (84 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m4.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: evaluate
Successfully installed evaluate-0.4.6


In [21]:
import numpy as np
import evaluate  # The new Hugging Face library for metrics

# 1. Accuracy metric, loaded through the `evaluate` library
metric = evaluate.load("accuracy")


# 2. Callback the Trainer invokes to turn predictions into metrics
def compute_metrics(eval_pred):
    """Compute accuracy from a ``(logits, labels)`` pair.

    The predicted class is the arg-max over the last axis of the logits.
    """
    logits, references = eval_pred
    predicted_labels = np.argmax(logits, axis=-1)
    return metric.compute(references=references, predictions=predicted_labels)

# --- Instantiate the Trainer again, this time with a metrics callback ---
trainer = DistillationTrainer(
    model=student_model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    # `processing_class=` replaces the deprecated `tokenizer=` keyword of Trainer.
    processing_class=tokenizer,
    teacher_model=teacher_model,
    compute_metrics=compute_metrics,  # adds `eval_accuracy` to the evaluation results
)

# --- Evaluate on the held-out subset ---
print("Evaluating the distilled student model...")
evaluation_results = trainer.evaluate()

# Both keys are present: `eval_accuracy` from compute_metrics, `eval_loss` from the Trainer.
print("\n--- Evaluation Results ---")
print(f"Accuracy: {evaluation_results['eval_accuracy']:.4f}")
print(f"Loss: {evaluation_results['eval_loss']:.4f}")

Downloading builder script: 0.00B [00:00, ?B/s]

  super().__init__(*args, **kwargs)


Evaluating the distilled student model...



--- Evaluation Results ---
Accuracy: 0.5120
Loss: nan


In [22]:
# Force an upgrade of all necessary libraries to their latest versions
!pip install -U transformers datasets accelerate bitsandbytes torch

Collecting datasets
  Downloading datasets-4.3.0-py3-none-any.whl.metadata (18 kB)
Collecting torch
  Downloading torch-2.9.0-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (30 kB)
Collecting pyarrow>=21.0.0 (from datasets)
  Downloading pyarrow-22.0.0-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (3.2 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.8.93 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl.metadata (1.7 kB)
Collecting nvidia-cuda-runtime-cu12==12.8.90 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.7 kB)
Collecting nvidia-cuda-cupti-cu12==12.8.90 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.7 kB)
Collecting nvidia-cublas-cu12==12.8.4.1 (from torch)
  Downloading nvidia_cublas_cu12-12.8.4.1-py3-none-manylinux_2_27_x86_64.whl.metadata (1.7 kB)
Collecting nvi

In [23]:
from transformers import TrainingArguments

# --- Define Training Arguments for a more STABLE and effective run ---
training_args = TrainingArguments(
    output_dir="distilled_model_final_checkpoint",
    num_train_epochs=3,              # Train for 3 full epochs to allow for better learning
    learning_rate=2e-5,              # A small learning rate, standard for fine-tuning
    per_device_train_batch_size=8,   # Slightly larger batch size if memory allows
    per_device_eval_batch_size=8,
    weight_decay=0.01,               # Standard regularization to limit overfitting
    fp16=True,                       # Mixed precision for speed
    logging_strategy="epoch",        # Log metrics at the end of each epoch
    eval_strategy="epoch",           # Evaluate at the end of each epoch
    save_strategy="epoch",           # Must match eval_strategy for load_best_model_at_end
    load_best_model_at_end=True,     # Reload the best checkpoint once training finishes
    # By default "best" is decided on the evaluation loss, which here is the combined
    # distillation loss and is useless for ranking checkpoints whenever it is NaN.
    # Rank checkpoints by the task metric returned by compute_metrics instead.
    metric_for_best_model="accuracy",
    greater_is_better=True,
    report_to="wandb",               # Log results to Weights & Biases
)

# --- Re-instantiate the Trainer with the new arguments ---
trainer = DistillationTrainer(
    model=student_model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    # `processing_class=` replaces the deprecated `tokenizer=` keyword of Trainer.
    processing_class=tokenizer,
    teacher_model=teacher_model,
    compute_metrics=compute_metrics,
)

# --- Start the full training run ---
print("Starting full distillation training...")
trainer.train()
print("Full training complete.")

  super().__init__(*args, **kwargs)


Starting full distillation training...


Epoch,Training Loss,Validation Loss,Accuracy
1,0.0,,0.512
2,0.0,,0.512
3,0.0,,0.512


Full training complete.


In [29]:
# Archive the final checkpoint directory for download.
# Step 3750 = 3 epochs x 1250 steps per epoch (10,000 training examples at a
# per-device batch size of 8, assuming a single device); the number must be
# updated if any of those settings change.
!zip -r distilled_model_final_checkpoint.zip distilled_model_final_checkpoint/checkpoint-3750


  adding: distilled_model_final_checkpoint/checkpoint-3750/ (stored 0%)
  adding: distilled_model_final_checkpoint/checkpoint-3750/model.safetensors (deflated 8%)
  adding: distilled_model_final_checkpoint/checkpoint-3750/scaler.pt (deflated 64%)
  adding: distilled_model_final_checkpoint/checkpoint-3750/tokenizer_config.json (deflated 75%)
  adding: distilled_model_final_checkpoint/checkpoint-3750/training_args.bin (deflated 53%)
  adding: distilled_model_final_checkpoint/checkpoint-3750/optimizer.pt (deflated 100%)
  adding: distilled_model_final_checkpoint/checkpoint-3750/trainer_state.json (deflated 70%)
  adding: distilled_model_final_checkpoint/checkpoint-3750/config.json (deflated 45%)
  adding: distilled_model_final_checkpoint/checkpoint-3750/scheduler.pt (deflated 61%)
  adding: distilled_model_final_checkpoint/checkpoint-3750/special_tokens_map.json (deflated 42%)
  adding: distilled_model_final_checkpoint/checkpoint-3750/rng_state.pth (deflated 26%)
  adding: distilled_model

In [30]:
# Send the zipped checkpoint to the browser as a file download.
# NOTE(review): `google.colab` is presumably only importable inside a Colab runtime — confirm
# before running this cell elsewhere.
from google.colab import files
files.download("distilled_model_final_checkpoint.zip")

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>