In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from transformers import (
    BertForSequenceClassification,
    BertTokenizer,
    DataCollatorWithPadding,
    get_linear_schedule_with_warmup
)
from datasets import load_dataset
from gensim.models import Word2Vec
import numpy as np
from sklearn.metrics import accuracy_score, f1_score
import torch.nn.utils.prune as prune
import torch.profiler

# 1) Define RNN student (no direct Word2Vec arg)
class RNNStudent(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, output_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.rnn = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.attention = nn.Linear(hidden_dim, 1)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, text):
        embedded = self.embedding(text)
        output, (hidden, _) = self.rnn(embedded)
        weights = torch.softmax(self.attention(output).squeeze(-1), dim=1).unsqueeze(-1)
        context = torch.sum(output * weights, dim=1)
        return self.fc(context)

# 2) Distillation loss (unchanged)
def distillation_loss(student_logits, teacher_logits, labels, T=3.0, alpha=0.4, beta=0.2):
    soft  = nn.KLDivLoss(reduction='batchmean')(
                nn.functional.log_softmax(student_logits / T, dim=1),
                nn.functional.softmax(teacher_logits / T, dim=1)
            ) * (T * T)
    hard  = nn.CrossEntropyLoss()(student_logits, labels)
    seq_l = nn.CrossEntropyLoss()(student_logits, teacher_logits.argmax(dim=1))
    return alpha * soft + (1 - alpha - beta) * hard + beta * seq_l

# Fine-Tune Teacher Model with Multiple Learning Rates
def fine_tune_teacher(teacher, train_loader, val_loader, epochs=3):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    teacher.train()
    teacher.to(device)

    # Define parameter groups for multiple learning rates
    param_groups = [
        {"params": [p for n, p in teacher.named_parameters() if "classifier" not in n], "lr": 1e-5},  # Lower LR for BERT layers
        {"params": [p for n, p in teacher.named_parameters() if "classifier" in n], "lr": 5e-5}    # Higher LR for classifier head
    ]
    optimizer = optim.Adam(param_groups)
    criterion = nn.CrossEntropyLoss()
    
    # Linear scheduler with warmup
    total_steps = len(train_loader) * epochs
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)

    for epoch in range(epochs):
        running_loss = 0.0
        teacher.train()
        for batch in train_loader:
            texts = batch['input_ids'].to(device)
            labels = batch['labels'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            optimizer.zero_grad()
            outputs = teacher(input_ids=texts, attention_mask=attention_mask).logits
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            scheduler.step()
            running_loss += loss.item()
        print(f"[Teacher] Epoch {epoch+1}, Train Loss: {running_loss/len(train_loader)}")

        # Validation
        teacher.eval()
        preds, true_labels = [], []
        val_loss = 0.0
        with torch.no_grad():
            for batch in val_loader:
                texts = batch['input_ids'].to(device)
                labels = batch['labels'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                outputs = teacher(input_ids=texts, attention_mask=attention_mask).logits
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                preds.extend(torch.argmax(outputs, dim=1).cpu().numpy())
                true_labels.extend(labels.cpu().numpy())
        accuracy = accuracy_score(true_labels, preds)
        f1 = f1_score(true_labels, preds, average='weighted')
        print(f"[Teacher] Epoch {epoch+1}, Val Loss: {val_loss/len(val_loader)}, Accuracy: {accuracy:.4f}, F1 Score: {f1:.4f}")

# Training Function for Knowledge Distillation
def train_kd(teacher, student, train_loader, val_loader, epochs=6, prune_amount=0.3):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    teacher.eval()
    student.train()
    optimizer = optim.Adam(student.parameters(), lr=0.0005)
    student.to(device)
    teacher.to(device)

    # Apply pruning to student
    for name, module in student.named_modules():
        if isinstance(module, nn.LSTM):
            prune.l1_unstructured(module, name='weight_ih_l0', amount=prune_amount)
            prune.l1_unstructured(module, name='weight_hh_l0', amount=prune_amount)
        elif isinstance(module, nn.Linear) and module != student.attention:
            prune.l1_unstructured(module, name='weight', amount=prune_amount)

    for epoch in range(epochs):
        running_loss = 0.0
        student.train()
        for batch in train_loader:
            texts = batch['input_ids'].to(device)
            labels = batch['labels'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            optimizer.zero_grad()
            with torch.no_grad():
                teacher_outputs = teacher(input_ids=texts, attention_mask=attention_mask).logits
            student_outputs = student(texts)
            loss = distillation_loss(student_outputs, teacher_outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f"[KD] Epoch {epoch+1}, Train Loss: {running_loss/len(train_loader)}")

        # Validation
        student.eval()
        preds, true_labels = [], []
        val_loss = 0.0
        with torch.no_grad():
            for batch in val_loader:
                texts = batch['input_ids'].to(device)
                labels = batch['labels'].to(device)
                outputs = student(texts)
                loss = nn.CrossEntropyLoss()(outputs, labels)
                val_loss += loss.item()
                preds.extend(torch.argmax(outputs, dim=1).cpu().numpy())
                true_labels.extend(labels.cpu().numpy())
        accuracy = accuracy_score(true_labels, preds)
        f1 = f1_score(true_labels, preds, average='weighted')
        print(f"[KD] Epoch {epoch+1}, Val Loss: {val_loss/len(val_loader)}, Accuracy: {accuracy:.4f}, F1 Score: {f1:.4f}")

    # Efficiency Metrics
    student.eval()
    with torch.no_grad():
        for batch in val_loader:
            texts = batch['input_ids'].to(device)
            with torch.profiler.profile(record_shapes=True) as prof:
                student(texts)
            break
    print("Student Model Efficiency Metrics:")
    print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))

    param_count = sum(p.numel() for p in student.parameters() if p.requires_grad)
    print(f"Student Model Parameter Count: {param_count}")

# Load and Preprocess Data
# Load & tokenize IMDB
dataset   = load_dataset("imdb")
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

def tokenize_fn(examples):
    return tokenizer(
        examples["text"], padding="max_length",
        truncation=True, max_length=128
    )

tokenized = (
    dataset
    .map(tokenize_fn, batched=True)
    .rename_column("label", "labels")
    .remove_columns("text")
)
tokenized.set_format("torch")
split = tokenized["train"].train_test_split(test_size=0.2)
train_ds, val_ds = split["train"], split["test"]
test_ds = tokenized["test"]

collator    = DataCollatorWithPadding(tokenizer)
train_loader = DataLoader(train_ds, batch_size=32, shuffle=True,  collate_fn=collator)
val_loader   = DataLoader(val_ds,   batch_size=32, shuffle=False, collate_fn=collator)
test_loader  = DataLoader(test_ds,  batch_size=32, shuffle=False, collate_fn=collator)

# 3) Train Word2Vec on raw-words
sentences = [text.split() for text in dataset['train']['text']]
w2v = Word2Vec(sentences, vector_size=100, window=5, min_count=1, workers=4)

# 4) Build hybrid embedding matrix
vocab_size = tokenizer.vocab_size   # 30522
embed_dim  = 100

# start random, then overwrite known tokens
emb_matrix = np.random.normal(size=(vocab_size, embed_dim)).astype(np.float32)
for token, idx in tokenizer.vocab.items():
    if token in w2v.wv:
        emb_matrix[idx] = w2v.wv[token]

# 5) Initialize models
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
teacher = BertForSequenceClassification.from_pretrained("bert-base-uncased").to(device)
student = RNNStudent(vocab_size, embed_dim, hidden_dim=128, output_dim=2).to(device)

# copy in the hybrid embeddings
student.embedding.weight.data.copy_(torch.from_numpy(emb_matrix))
# Fine-Tune Teacher and Perform Knowledge Distillation
fine_tune_teacher(teacher, train_loader, val_loader, epochs=3)
train_kd(teacher, student, train_loader, val_loader, epochs=6)

# Final Test Evaluation
student.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
student.to(device)
preds, true_labels = [], []
with torch.no_grad():
    for batch in test_loader:
        texts = batch['input_ids'].to(device)
        labels = batch['labels'].to(device)
        outputs = student(texts)
        preds.extend(torch.argmax(outputs, dim=1).cpu().numpy())
        true_labels.extend(labels.cpu().numpy())
accuracy = accuracy_score(true_labels, preds)
f1 = f1_score(true_labels, preds, average='weighted')
print(f"[KD] Final Test Accuracy: {accuracy:.4f}, F1 Score: {f1:.4f}")

Map: 100%|██████████| 25000/25000 [01:15<00:00, 331.02 examples/s]
Map: 100%|██████████| 25000/25000 [01:04<00:00, 386.01 examples/s]
Map: 100%|██████████| 50000/50000 [02:10<00:00, 382.23 examples/s]
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


[Teacher] Epoch 1, Train Loss: 0.3566975662469864
[Teacher] Epoch 1, Val Loss: 0.29004416553078183, Accuracy: 0.8748, F1 Score: 0.8747
[Teacher] Epoch 2, Train Loss: 0.23664038577079774
[Teacher] Epoch 2, Val Loss: 0.3073555787989668, Accuracy: 0.8740, F1 Score: 0.8738
[Teacher] Epoch 3, Train Loss: 0.17687230843007565
[Teacher] Epoch 3, Val Loss: 0.32172115856579914, Accuracy: 0.8730, F1 Score: 0.8729
[KD] Epoch 1, Train Loss: 0.8141248406410218
[KD] Epoch 1, Val Loss: 0.4488598212694666, Accuracy: 0.7910, F1 Score: 0.7909
[KD] Epoch 2, Train Loss: 0.48520623807907104
[KD] Epoch 2, Val Loss: 0.4178249608179566, Accuracy: 0.8138, F1 Score: 0.8128
[KD] Epoch 3, Train Loss: 0.36743490434885023
[KD] Epoch 3, Val Loss: 0.372931222864397, Accuracy: 0.8368, F1 Score: 0.8367
[KD] Epoch 4, Train Loss: 0.2849605684518814
[KD] Epoch 4, Val Loss: 0.37012320614544447, Accuracy: 0.8434, F1 Score: 0.8434
[KD] Epoch 5, Train Loss: 0.2238520857155323
[KD] Epoch 5, Val Loss: 0.39473469009634793, Accura