In [None]:
# ✅ Improved Knowledge Distillation on IMDb
# Goal: Boost student accuracy closer to teacher's

!pip install transformers datasets -q

import torch
import torch.nn as nn
from datasets import load_dataset
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, f1_score
import torch.nn.functional as F

# ✅ Load and tokenize IMDb dataset (larger subset for better student learning)
# Hub identifiers are named once here so the two loading calls read clearly.
DATASET_NAME = "imdb"
TOKENIZER_CHECKPOINT = "distilbert-base-uncased"

# `dataset` holds the IMDb splits; `tokenizer` is the fast DistilBERT tokenizer
# that `tokenize` below applies to the raw review text.
dataset = load_dataset(DATASET_NAME)
tokenizer = DistilBertTokenizerFast.from_pretrained(TOKENIZER_CHECKPOINT)

def tokenize(example):
    """Tokenize the ``"text"`` field of ``example`` with the module-level tokenizer.

    Every sequence is truncated and padded to exactly 256 tokens, so all
    encoded rows have the same length.

    Args:
        example: Mapping with a ``"text"`` entry (a single string or a list of
            strings when used with ``dataset.map(..., batched=True)``).

    Returns:
        Whatever the tokenizer call returns for that text.
    """
    encoding_options = {
        "truncation": True,
        "padding": "max_length",
        "max_length": 256,
    }
    return tokenizer(example["text"], **encoding_options)

def _prepare_split(split, num_rows, seed=42):
    """Return a tokenized, torch-formatted random subset of ``split``.

    The seeded shuffle + select is the same one the script always used, but it
    now runs BEFORE tokenization, so only the ``num_rows`` rows that are kept
    get tokenized instead of every row of every split.

    Args:
        split: One split of the loaded dataset (e.g. ``dataset["train"]``).
        num_rows: Number of rows to keep after shuffling.
        seed: Seed passed to ``shuffle`` for a reproducible subset.

    Returns:
        The subset with the raw ``"text"`` column removed and its format set
        to ``"torch"``.
    """
    subset = split.shuffle(seed=seed).select(range(num_rows))
    subset = subset.map(tokenize, batched=True).remove_columns(["text"])
    subset.set_format("torch")
    return subset


train_data = _prepare_split(dataset["train"], 5000)
test_data = _prepare_split(dataset["test"], 500)

# Reshuffle the training batches on every pass (teacher fine-tuning and each
# student epoch); the seeded generator keeps the order reproducible.
train_loader = DataLoader(
    train_data,
    batch_size=16,
    shuffle=True,
    generator=torch.Generator().manual_seed(42),
)
# Evaluation order does not matter, so the test loader stays sequential.
test_loader = DataLoader(test_data, batch_size=16)

# ✅ Load pretrained DistilBERT as teacher
teacher = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased")
teacher.train()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
teacher.to(device)

optimizer = torch.optim.Adam(teacher.parameters(), lr=2e-5)
num_train_batches = len(train_loader)
print(f"Train loader: {num_train_batches}")

# Print progress every LOG_EVERY batches (plus the last one) instead of on
# every batch, and include the loss so the log is actually informative.
LOG_EVERY = 50

count = 0  # total number of optimizer steps taken so far
# Teacher fine-tuning
for epoch in range(1):
    for step, batch in enumerate(train_loader, start=1):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["label"].to(device)

        # Passing `labels` makes the model return the classification loss.
        outputs = teacher(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        count += 1

        if step % LOG_EVERY == 0 or step == num_train_batches:
            print(f"batch_count: {step}/{num_train_batches}, loss: {loss.item():.4f}")

# The teacher is only used for inference from here on.
teacher.eval()

# ✅ Enhanced student model (deeper & higher capacity MLP)
class StudentMLP(nn.Module):
    """Bag-of-embeddings MLP classifier used as the distillation student.

    Token embeddings are averaged over the sequence and passed through three
    linear layers (128 -> 256 -> 128 -> 2) with ReLU in between.

    Padding positions are excluded from the average. Inputs in this script are
    padded to a fixed length, so an unmasked mean would mix many pad
    embeddings into the representation of short texts.

    Args:
        pad_token_id: Token id treated as padding when ``forward`` is called
            without an ``attention_mask``. Defaults to 0. Pass ``None`` to
            average over every position (the previous behaviour).
    """

    def __init__(self, pad_token_id=0):
        super().__init__()
        self.pad_token_id = pad_token_id
        self.embedding = nn.Embedding(30522, 128)
        self.fc1 = nn.Linear(128, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 2)

    def forward(self, input_ids, attention_mask=None):
        """Return class logits for a batch of token-id sequences.

        Args:
            input_ids: Integer tensor indexed as (batch, sequence).
            attention_mask: Optional tensor of the same shape, non-zero where
                a token should contribute to the average. When omitted, the
                mask is derived from ``pad_token_id``.

        Returns:
            Tensor of logits with one row per input sequence and 2 columns.
        """
        if attention_mask is None:
            if self.pad_token_id is None:
                attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
            else:
                attention_mask = input_ids != self.pad_token_id

        embedded = self.embedding(input_ids)
        mask = attention_mask.unsqueeze(-1).to(embedded.dtype)

        # Masked mean over the sequence axis; the clamp avoids dividing by
        # zero for a row that consists of padding only.
        token_count = mask.sum(dim=1).clamp(min=1.0)
        x = (embedded * mask).sum(dim=1) / token_count

        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

student = StudentMLP().to(device)
optimizer_s = torch.optim.Adam(student.parameters(), lr=5e-4)

# ✅ Train student with knowledge distillation (multiple epochs)
temperature = 2.0
epochs = 3


def _distillation_loss(student_logits, teacher_logits, targets):
    """Return the equally weighted sum of the hard-label and soft-label losses.

    The hard term is cross-entropy against ``targets``. The soft term is the
    KL divergence between the temperature-softened student and teacher
    distributions, multiplied by ``temperature ** 2``.
    """
    hard_term = F.cross_entropy(student_logits, targets)
    student_log_probs = F.log_softmax(student_logits / temperature, dim=1)
    teacher_probs = F.softmax(teacher_logits / temperature, dim=1)
    soft_term = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean') * (temperature ** 2)
    return 0.5 * hard_term + 0.5 * soft_term


for epoch in range(epochs):
    student.train()
    for batch in train_loader:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["label"].to(device)

        # The teacher only supplies targets, so no gradients are tracked for it.
        with torch.no_grad():
            t_logits = teacher(input_ids=input_ids, attention_mask=attention_mask).logits

        loss = _distillation_loss(student(input_ids), t_logits, labels)

        optimizer_s.zero_grad()
        loss.backward()
        optimizer_s.step()

# ✅ Evaluate both models
# Put both networks in eval mode before collecting predictions.
teacher.eval()
student.eval()

all_student_preds, all_teacher_preds, all_labels = [], [], []

with torch.no_grad():
    for batch in test_loader:
        input_ids, attention_mask, labels = (
            batch[key].to(device) for key in ("input_ids", "attention_mask", "label")
        )

        # Predicted class = index of the largest logit in each row.
        teacher_preds = teacher(input_ids=input_ids, attention_mask=attention_mask).logits.argmax(dim=1)
        student_preds = student(input_ids).argmax(dim=1)

        all_teacher_preds += teacher_preds.tolist()
        all_student_preds += student_preds.tolist()
        all_labels += labels.tolist()

# Accuracy and F1 for each model against the same gold labels.
teacher_acc, student_acc = (
    accuracy_score(all_labels, preds) for preds in (all_teacher_preds, all_student_preds)
)
teacher_f1, student_f1 = (
    f1_score(all_labels, preds) for preds in (all_teacher_preds, all_student_preds)
)

# Student accuracy expressed as a percentage of the teacher's accuracy.
retention = 100 * (student_acc / teacher_acc)

print(f"Teacher Accuracy: {teacher_acc:.4f}, F1 Score: {teacher_f1:.4f}")
print(f"Student Accuracy: {student_acc:.4f}, F1 Score: {student_f1:.4f}")
print(f"Student retains {retention:.2f}% of teacher's accuracy")


[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/491.4 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m [32m481.3/491.4 kB[0m [31m25.9 MB/s[0m eta [36m0:00:01[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m [32m481.3/491.4 kB[0m [31m25.9 MB/s[0m eta [36m0:00:01[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m [32m481.3/491.4 kB[0m [31m25.9 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m491.4/491.4 kB[0m [31m3.9 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/116.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m6.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m193.6/193.6 kB[0m [31m7.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md:   0%|          | 0.00/7.81k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/21.0M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/20.5M [00:00<?, ?B/s]

unsupervised-00000-of-00001.parquet:   0%|          | 0.00/42.0M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/25000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/25000 [00:00<?, ? examples/s]

Generating unsupervised split:   0%|          | 0/50000 [00:00<?, ? examples/s]

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

Map:   0%|          | 0/25000 [00:00<?, ? examples/s]

Map:   0%|          | 0/25000 [00:00<?, ? examples/s]

Map:   0%|          | 0/50000 [00:00<?, ? examples/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Train loader: 313
batch_count: 0
batch_count: 1
batch_count: 2
batch_count: 3
batch_count: 4
batch_count: 5
batch_count: 6
batch_count: 7
batch_count: 8
batch_count: 9
batch_count: 10
batch_count: 11
batch_count: 12
batch_count: 13
batch_count: 14
batch_count: 15
batch_count: 16
batch_count: 17
batch_count: 18
batch_count: 19
batch_count: 20
batch_count: 21
batch_count: 22
batch_count: 23
batch_count: 24
batch_count: 25
batch_count: 26
batch_count: 27
batch_count: 28
batch_count: 29
batch_count: 30
batch_count: 31
batch_count: 32
batch_count: 33
batch_count: 34
batch_count: 35
batch_count: 36
batch_count: 37
batch_count: 38
batch_count: 39
batch_count: 40
batch_count: 41
batch_count: 42
batch_count: 43
batch_count: 44
batch_count: 45
batch_count: 46
batch_count: 47
batch_count: 48
batch_count: 49
batch_count: 50
batch_count: 51
batch_count: 52
batch_count: 53
batch_count: 54
batch_count: 55
batch_count: 56
batch_count: 57
batch_count: 58
batch_count: 59
batch_count: 60
batch_count: 61
