In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from transformers import BertForSequenceClassification, BertTokenizer,DataCollatorWithPadding
from datasets import load_dataset

# Define RNN Student Model
class RNNStudent(nn.Module):
    """Small LSTM classifier used as the distillation student.

    Pipeline: token ids -> embedding -> single batch-first LSTM -> linear head
    applied to the LSTM's final hidden state.

    Args:
        vocab_size: number of rows in the embedding table.
        embed_dim: size of each token embedding (LSTM input size).
        hidden_dim: LSTM hidden-state size (input size of the linear head).
        output_dim: number of logits produced per example.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim, output_dim):
        super().__init__()
        # Attribute names double as state_dict key prefixes; keep them stable.
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_dim)
        self.rnn = nn.LSTM(input_size=embed_dim, hidden_size=hidden_dim, batch_first=True)
        self.fc = nn.Linear(in_features=hidden_dim, out_features=output_dim)

    def forward(self, text):
        """Return one row of `output_dim` logits per sequence in `text`.

        `text` holds integer token ids laid out batch-first (the LSTM is built
        with `batch_first=True`).
        """
        token_vectors = self.embedding(text)
        # Per-step outputs and the cell state are not needed for classification.
        _, (final_hidden, _) = self.rnn(token_vectors)
        # Index -1 selects the last entry of the stacked hidden states.
        last_state = final_hidden[-1]
        return self.fc(last_state)

# Distillation Loss
def distillation_loss(student_logits, teacher_logits, labels, T=2.0, alpha=0.5):
    """Blend a temperature-softened KL term with ordinary cross-entropy.

    Args:
        student_logits: raw student scores; classes along dim 1.
        teacher_logits: raw teacher scores with the same layout.
        labels: ground-truth class indices for the hard-label term.
        T: temperature dividing both sets of logits before the softmax.
        alpha: weight of the soft (KL) term; the hard term gets `1 - alpha`.

    Returns:
        Scalar tensor `alpha * soft + (1 - alpha) * hard`.
    """
    F = nn.functional

    student_log_probs = F.log_softmax(student_logits / T, dim=1)
    teacher_probs = F.softmax(teacher_logits / T, dim=1)

    # The KL term is multiplied by T^2, as in the usual distillation recipe.
    soft_loss = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean') * (T * T)
    hard_loss = F.cross_entropy(student_logits, labels)

    return alpha * soft_loss + (1 - alpha) * hard_loss

# Training Function
def train_kd(teacher, student, train_loader, epochs=5):
    """Train `student` to imitate `teacher` with the distillation loss.

    The device is chosen here (CUDA when available, otherwise CPU). The teacher
    is put in eval mode and only queried under `no_grad`; only the student's
    parameters are optimised (Adam, lr=1e-3).

    Args:
        teacher: model called as `teacher(input_ids, attention_mask=...)` whose
            result exposes `.logits`.
        student: model called as `student(input_ids)` returning logits.
        train_loader: sized iterable of dict batches with the keys
            "input_ids", "attention_mask" and "labels".
        epochs: number of passes over `train_loader`.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    teacher.eval()
    teacher.to(device)
    student.train()
    student.to(device)

    optimizer = optim.Adam(student.parameters(), lr=1e-3)

    for epoch_no in range(1, epochs + 1):
        running_loss = 0.0
        for batch in train_loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            optimizer.zero_grad()

            # Teacher predictions are fixed targets, so no graph is recorded.
            with torch.no_grad():
                soft_targets = teacher(input_ids, attention_mask=attention_mask).logits

            predictions = student(input_ids)
            batch_loss = distillation_loss(predictions, soft_targets, labels)
            batch_loss.backward()
            optimizer.step()

            running_loss += batch_loss.item()

        # Mean of the per-batch losses for this epoch.
        avg_loss = running_loss / len(train_loader)
        print(f"Epoch {epoch_no:>2}, Loss: {avg_loss:.4f}")


# Load Data
dataset = load_dataset("imdb")
# NOTE(review): force_download=True re-fetches the files and overrides any
# cached copy on every run; drop it once the local cache is known to be good.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", force_download=True)

print(dataset)


def tokenize_function(examples):
    """Tokenize a batch of reviews, padding/truncating each one to 128 tokens."""
    return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=128)


# Tokenize every split, rename the target column to "labels" (the key train_kd
# reads from each batch) and drop the raw text so only tensors remain.
tokenized = dataset.map(tokenize_function, batched=True)
tokenized = tokenized.rename_column("label", "labels")
tokenized = tokenized.remove_columns(["text"])
tokenized.set_format("torch")

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
train_loader = DataLoader(tokenized["train"], batch_size=32, shuffle=True, collate_fn=data_collator)
# Fix: this line used `collator`, a name this cell never defines, so a fresh
# kernel raised NameError. Both loaders now share `data_collator`.
eval_loader = DataLoader(tokenized["test"], batch_size=32, shuffle=False, collate_fn=data_collator)

# Initialize Models
# NOTE(review): the teacher's classification head is freshly initialised and is
# not fine-tuned in this cell, so the student is distilled from an untrained head.
teacher = BertForSequenceClassification.from_pretrained("bert-base-uncased", force_download=True)
student = RNNStudent(vocab_size=tokenizer.vocab_size, embed_dim=100, hidden_dim=128, output_dim=2)
train_kd(teacher, student, train_loader)

In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from transformers import (
    BertForSequenceClassification,
    BertTokenizer,
    DataCollatorWithPadding,
)
from datasets import load_dataset

# 1) Define RNN student
class RNNStudent(nn.Module):
    """LSTM-based student network: embedding -> LSTM -> linear classifier."""

    def __init__(self, vocab_size, embed_dim, hidden_dim, output_dim):
        """Build the three sub-modules.

        vocab_size / embed_dim size the embedding table, hidden_dim sizes the
        LSTM state, and output_dim is the number of logits emitted.
        """
        nn.Module.__init__(self)
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.rnn = nn.LSTM(
            embed_dim,
            hidden_dim,
            batch_first=True,
        )
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, text):
        """Map batch-first token ids to classification logits."""
        # self.rnn returns (outputs, (h_n, c_n)); only h_n is used.
        h_n = self.rnn(self.embedding(text))[1][0]
        # Classify from the last entry of h_n.
        return self.fc(h_n[-1])

# 2) Distillation loss
def distillation_loss(student_logits, teacher_logits, labels, T=2.0, alpha=0.5):
    """Knowledge-distillation objective: weighted soft-target KL plus hard CE.

    Both logit tensors are divided by temperature `T` before being turned into
    (log-)probabilities over dim 1. The KL term is rescaled by `T * T` and
    mixed with cross-entropy on `labels` using weight `alpha`.
    """
    kl_criterion = nn.KLDivLoss(reduction="batchmean")
    ce_criterion = nn.CrossEntropyLoss()

    # KLDivLoss takes log-probabilities as input and probabilities as target.
    log_q_student = nn.functional.log_softmax(student_logits / T, dim=1)
    p_teacher = nn.functional.softmax(teacher_logits / T, dim=1)
    soft_loss = kl_criterion(log_q_student, p_teacher) * (T * T)

    hard_loss = ce_criterion(student_logits, labels)

    return alpha * soft_loss + (1 - alpha) * hard_loss

# 3) Teacher training / evaluation
def train_teacher(model, loader, optimizer, device, epochs=3):
    """Fine-tune `model` on labelled batches and print the mean loss per epoch.

    Args:
        model: module called as `model(input_ids=..., attention_mask=...,
            labels=...)` whose result exposes `.loss`.
        loader: iterable of dict batches containing "input_ids",
            "attention_mask" and "labels" tensors.
        optimizer: optimizer already bound to the parameters to update.
        device: device the batches (and the model) are placed on.
        epochs: number of passes over `loader`.
    """
    # Fix: batches are moved to `device` below, but the model never was, which
    # fails with a device mismatch whenever `device` is not where the model
    # already lives. This is a no-op when the model is already there.
    # NOTE(review): prefer moving the model before building `optimizer`.
    model.to(device)
    model.train()

    for epoch in range(epochs):
        total_loss = 0.0
        num_batches = 0
        for batch in loader:
            optimizer.zero_grad()
            inputs = {k: batch[k].to(device) for k in ("input_ids", "attention_mask", "labels")}
            loss = model(**inputs).loss
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            num_batches += 1

        # Counting batches avoids ZeroDivisionError on an empty loader and
        # works for iterables that do not implement __len__.
        if num_batches == 0:
            print(f"[Teacher] Epoch {epoch+1}, no batches to train on")
            continue
        print(f"[Teacher] Epoch {epoch+1}, Loss {total_loss/num_batches:.4f}")

def eval_teacher(model, loader, device):
    """Measure classification accuracy of `model` over `loader` and print it.

    Args:
        model: module called as `model(input_ids=..., attention_mask=...)`
            whose result exposes `.logits`.
        loader: iterable of dict batches containing "input_ids",
            "attention_mask" and "labels" tensors.
        device: device the batches (and the model) are placed on.

    Returns:
        Accuracy as a percentage (0-100), or NaN when `loader` yields no
        samples. The model is left in eval mode.
    """
    # Fix: keep the model on the same device as the batches moved below.
    model.to(device)
    model.eval()

    correct, total = 0, 0
    with torch.no_grad():
        for batch in loader:
            inputs = {k: batch[k].to(device) for k in ("input_ids", "attention_mask")}
            labels = batch["labels"].to(device)
            preds = model(**inputs).logits.argmax(dim=-1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    # Guard against ZeroDivisionError when nothing was evaluated.
    if total == 0:
        print("[Teacher] Accuracy: n/a (no evaluation samples)")
        return float("nan")

    accuracy = 100 * correct / total
    print(f"[Teacher] Accuracy: {accuracy:.2f}%")
    return accuracy

# 4) Distillation training
def train_kd(teacher, student, loader, device, epochs=5):
    """Distil `teacher` into `student` on `device`.

    The teacher runs in eval mode under `no_grad`; only the student's
    parameters are updated (Adam, lr=1e-3). After each epoch the mean
    per-batch distillation loss is printed.

    Args:
        teacher: model called as `teacher(input_ids, attention_mask=...)`
            whose result exposes `.logits`.
        student: model called as `student(input_ids)` returning logits.
        loader: sized iterable of dict batches with "input_ids",
            "attention_mask" and "labels".
        device: device for both models and all batches.
        epochs: number of passes over `loader`.
    """
    teacher.eval()
    teacher.to(device)
    student.train()
    student.to(device)
    optimizer = optim.Adam(student.parameters(), lr=1e-3)

    batch_keys = ("input_ids", "attention_mask", "labels")

    for epoch in range(epochs):
        total_loss = 0.0
        for batch in loader:
            ids, mask, targets = (batch[key].to(device) for key in batch_keys)

            optimizer.zero_grad()
            # The teacher only supplies targets; no gradients are tracked for it.
            with torch.no_grad():
                teacher_out = teacher(ids, attention_mask=mask).logits
            student_out = student(ids)

            step_loss = distillation_loss(student_out, teacher_out, targets)
            step_loss.backward()
            optimizer.step()
            total_loss += step_loss.item()

        print(f"[KD] Epoch {epoch+1}, Loss {total_loss/len(loader):.4f}")

# ──────────────────────────────────────────────────────────────────────────────
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 5) Load & preprocess data
dataset   = load_dataset("imdb")
# NOTE(review): force_download=True re-fetches the files and overrides any
# cached copy on every run; drop it once the local cache is known to be good.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", force_download=True)

def tokenize_fn(examples):
    """Tokenize a batch of reviews, padding/truncating each one to 128 tokens."""
    return tokenizer(
        examples["text"],
        padding="max_length",
        truncation=True,
        max_length=128
    )

# Tokenize every split, rename the target column to "labels" (the key the
# training/eval helpers read) and drop the raw text column.
tokenized = (
    dataset
    .map(tokenize_fn, batched=True)
    .rename_column("label", "labels")
    .remove_columns("text")
)
tokenized.set_format("torch")

collator   = DataCollatorWithPadding(tokenizer=tokenizer)
train_loader = DataLoader(tokenized["train"], batch_size=16, shuffle=True,  collate_fn=collator)
eval_loader  = DataLoader(tokenized["test"],  batch_size=32, shuffle=False, collate_fn=collator)

# 6) Initialize teacher & freeze its body
teacher = BertForSequenceClassification.from_pretrained(
    "bert-base-uncased", force_download=True
)
# Fix: train_teacher/eval_teacher move every batch to `device`, but the teacher
# itself was never moved, which breaks as soon as `device` is CUDA. Move it
# here, before any optimizer is built over its parameters.
teacher.to(device)
for param in teacher.bert.parameters():
    param.requires_grad = False

# only classifier head will train initially
optimizer_head = optim.Adam(teacher.classifier.parameters(), lr=2e-5)

# 7) Fine‑tune teacher
train_teacher(teacher, train_loader, optimizer_head, device, epochs=3)
eval_teacher(teacher, eval_loader, device)

# (Optional) Unfreeze last 2 layers and fine‑tune further:
for layer in teacher.bert.encoder.layer[-2:]:
    for param in layer.parameters():
        param.requires_grad = True
# A second optimizer covers everything trainable now: the head plus the two
# unfrozen encoder layers.
optimizer_full = optim.Adam(filter(lambda p: p.requires_grad, teacher.parameters()), lr=1e-5)
train_teacher(teacher, train_loader, optimizer_full, device, epochs=3)
eval_teacher(teacher, eval_loader, device)

# 8) Initialize student and distill
student = RNNStudent(
    vocab_size=tokenizer.vocab_size,
    embed_dim=100,
    hidden_dim=128,
    output_dim=2
)
train_kd(teacher, student, train_loader, device, epochs=5)


Map: 100%|██████████| 25000/25000 [01:04<00:00, 389.47 examples/s]
Map: 100%|██████████| 25000/25000 [01:02<00:00, 401.79 examples/s]
Map: 100%|██████████| 50000/50000 [02:07<00:00, 393.11 examples/s]
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


[Teacher] Epoch 1, Loss 0.6907
[Teacher] Epoch 2, Loss 0.6733
[Teacher] Epoch 3, Loss 0.6595
[Teacher] Accuracy: 63.14%
[Teacher] Epoch 1, Loss 0.3918
[Teacher] Epoch 2, Loss 0.3183
[Teacher] Accuracy: 87.30%
[KD] Epoch 1, Loss 0.7695
[KD] Epoch 2, Loss 0.5815
[KD] Epoch 3, Loss 0.2890
[KD] Epoch 4, Loss 0.2191
[KD] Epoch 5, Loss 0.1783


In [11]:
from sklearn.metrics import accuracy_score, f1_score
 # Evaluate student model
def eval_kd(student, eval_loader, epoch=5):
    """Evaluate the distilled student and print accuracy and weighted F1.

    Args:
        student: model called as `student(input_ids)` returning logits with
            classes along dim 1.
        eval_loader: iterable of dict batches containing "input_ids" and
            "labels" tensors.
        epoch: value used only for the log label, printed as `epoch + 1`
            (so the default of 5 is reported as "Epoch 6").

    The student is left in eval mode.
    """
    # Fix: the original read a notebook-global `device`, which is undefined if
    # only the first training cell ran (there `device` is local to train_kd).
    # Use the device the student's parameters already live on instead.
    first_param = next(student.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")

    student.eval()
    preds, true_labels = [], []
    with torch.no_grad():
        for batch in eval_loader:
            input_ids = batch['input_ids'].to(model_device)
            logits = student(input_ids)
            # Collect plain Python ints for the sklearn metric functions.
            preds.extend(torch.argmax(logits, dim=1).cpu().tolist())
            true_labels.extend(batch['labels'].cpu().tolist())

    accuracy = accuracy_score(true_labels, preds)
    f1 = f1_score(true_labels, preds, average='weighted')
    print(f"[KD] Epoch {epoch+1}, Accuracy: {accuracy:.4f}, F1 Score: {f1:.4f}")
        
# Score the distilled student on `eval_loader` using the default epoch label.
eval_kd(student, eval_loader)

[KD] Epoch 6, Accuracy: 0.8354, F1 Score: 0.8354


In [12]:
# Fix: `texts` only ever existed as a local variable inside eval_kd, so this
# cell raised NameError after Restart & Run All. Build the batch explicitly
# from the evaluation loader, on the device the student lives on.
texts = next(iter(eval_loader))["input_ids"].to(next(student.parameters()).device)

# Profile a single inference-style forward pass: eval mode, no autograd.
student.eval()
with torch.no_grad(), torch.profiler.profile() as prof:
    student(texts)
print(prof.key_averages().table(sort_by="cpu_time_total"))

-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                         Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
              aten::embedding        17.48%       9.538ms        63.24%      34.509ms      34.509ms             1  
                   aten::lstm        14.13%       7.710ms        35.46%      19.350ms      19.350ms             1  
           aten::index_select        23.48%      12.812ms        24.28%      13.250ms      13.250ms             1  
                aten::reshape        21.46%      11.711ms        21.47%      11.718ms      11.718ms             1  
       aten::mkldnn_rnn_layer        19.57%      10.681ms        19.74%      10.775ms      10.775ms             1  
             aten::contiguous         0.27%     147.900us         1.42% 