# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from transformers import BertForSequenceClassification, BertTokenizer, DataCollatorWithPadding
from datasets import load_dataset
from gensim.models import Word2Vec
import numpy as np
from sklearn.metrics import accuracy_score, f1_score
import torch.nn.utils.prune as prune
import torch.profiler

# RNN student model whose embedding table can be seeded from Word2Vec vectors
class RNNStudent(nn.Module):
    """LSTM classifier used as the distillation student.

    Token ids are embedded, fed through a single-layer ``nn.LSTM``
    (``batch_first=True``), and the last layer's final hidden state is
    projected to ``output_dim`` logits.

    Args:
        vocab_size: Number of rows in the embedding table.
        embed_dim: Width of each embedding vector (the LSTM input size).
        hidden_dim: LSTM hidden-state size.
        output_dim: Number of logits produced per example.
        word2vec_model: Optional object exposing ``wv.vectors`` (a 2-D array)
            and, when ``token_to_id`` is given, ``wv.key_to_index``.
        token_to_id: Optional mapping from token string to embedding row.
            When given, each token that also exists in the Word2Vec
            vocabulary has its vector copied into its own row; all other
            rows keep their random initialisation. When omitted, the
            Word2Vec matrix is copied wholesale and must therefore have
            exactly ``(vocab_size, embed_dim)`` entries.

    Raises:
        ValueError: If the pretrained vectors cannot fit the embedding table.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim, output_dim, word2vec_model=None, token_to_id=None):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        if word2vec_model is not None:
            self._load_pretrained(word2vec_model.wv, token_to_id)
        self.rnn = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def _load_pretrained(self, keyed_vectors, token_to_id):
        """Copy pretrained vectors into the embedding table, aligned by token when possible."""
        vectors = np.asarray(keyed_vectors.vectors)
        weight = self.embedding.weight.data
        if vectors.ndim != 2 or vectors.shape[1] != weight.shape[1]:
            raise ValueError(
                f"Pretrained vectors of shape {vectors.shape} do not match "
                f"embed_dim={weight.shape[1]}"
            )

        if token_to_id is None:
            # Without a token mapping the rows can only be taken as-is, which
            # is meaningful only if both tables have the same number of rows.
            if vectors.shape[0] != weight.shape[0]:
                raise ValueError(
                    f"Pretrained vocabulary has {vectors.shape[0]} rows but the "
                    f"embedding table has {weight.shape[0]}; pass token_to_id to "
                    "align the two vocabularies"
                )
            weight.copy_(torch.from_numpy(vectors))
            return

        key_to_index = keyed_vectors.key_to_index
        target_rows, source_rows = [], []
        for token, row in token_to_id.items():
            source = key_to_index.get(token)
            if source is not None and 0 <= row < weight.shape[0]:
                target_rows.append(row)
                source_rows.append(source)
        if target_rows:
            # Fancy indexing returns a fresh array, so from_numpy is safe here.
            picked = torch.from_numpy(vectors[source_rows]).to(weight.dtype)
            weight[target_rows] = picked

    def forward(self, text):
        """Return classification logits for a batch of token-id sequences."""
        embedded = self.embedding(text)
        _, (hidden, _) = self.rnn(embedded)
        # hidden[-1] is the final hidden state of the last LSTM layer.
        return self.fc(hidden[-1])

# Hybrid distillation loss: softened-teacher KL + ground-truth CE + teacher-argmax CE
def distillation_loss(student_logits, teacher_logits, labels, T=2.0, alpha=0.4, beta=0.2):
    """Combine three supervision signals for the student into one scalar loss.

    Args:
        student_logits: Student outputs; classes are on dimension 1.
        teacher_logits: Teacher outputs with the same layout.
        labels: Ground-truth class indices.
        T: Temperature applied to both sets of logits for the KL term.
        alpha: Weight of the temperature-softened KL term.
        beta: Weight of the cross-entropy against the teacher's argmax class.

    Returns:
        ``alpha * KL + (1 - alpha - beta) * CE(labels) + beta * CE(teacher argmax)``,
        where the KL term is multiplied by ``T * T``.
    """
    F = nn.functional

    # Soft targets: KL between the temperature-softened distributions.
    student_log_probs = F.log_softmax(student_logits / T, dim=1)
    teacher_probs = F.softmax(teacher_logits / T, dim=1)
    kd_term = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean') * (T * T)

    # Hard targets: the true labels, and the teacher's predicted class.
    label_term = F.cross_entropy(student_logits, labels)
    teacher_choice_term = F.cross_entropy(student_logits, teacher_logits.argmax(dim=1))

    return alpha * kd_term + (1 - alpha - beta) * label_term + beta * teacher_choice_term

# Fine-tune the teacher classifier and report validation metrics after every epoch
def fine_tune_teacher(teacher, train_loader, val_loader, epochs=3):
    """Fine-tune ``teacher`` with cross-entropy and print per-epoch metrics.

    Args:
        teacher: Model called as ``teacher(input_ids=..., attention_mask=...)``
            whose result exposes ``.logits``.
        train_loader: Iterable of batches with ``'input_ids'``, ``'labels'``
            and ``'attention_mask'`` entries; must support ``len()``.
        val_loader: Loader with the same batch layout, used for validation.
        epochs: Number of passes over ``train_loader``.

    Returns:
        None. The teacher is trained in place and moved to CUDA when available,
        otherwise to the CPU.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    teacher.train()
    optimizer = optim.Adam(teacher.parameters(), lr=2e-5)
    criterion = nn.CrossEntropyLoss()
    teacher.to(device)

    def forward_batch(batch):
        """Move one batch to ``device`` and return ``(logits, labels)``."""
        input_ids = batch['input_ids'].to(device)
        targets = batch['labels'].to(device)
        mask = batch['attention_mask'].to(device)
        logits = teacher(input_ids=input_ids, attention_mask=mask).logits
        return logits, targets

    for epoch_number in range(1, epochs + 1):
        # Training pass
        teacher.train()
        train_loss_total = 0.0
        for batch in train_loader:
            optimizer.zero_grad()
            logits, targets = forward_batch(batch)
            batch_loss = criterion(logits, targets)
            batch_loss.backward()
            optimizer.step()
            train_loss_total += batch_loss.item()
        print(f"[Teacher] Epoch {epoch_number}, Train Loss: {train_loss_total/len(train_loader)}")

        # Validation pass
        teacher.eval()
        val_loss_total = 0.0
        predicted, expected = [], []
        with torch.no_grad():
            for batch in val_loader:
                logits, targets = forward_batch(batch)
                val_loss_total += criterion(logits, targets).item()
                predicted.extend(logits.argmax(dim=1).cpu().numpy())
                expected.extend(targets.cpu().numpy())
        accuracy = accuracy_score(expected, predicted)
        f1 = f1_score(expected, predicted, average='weighted')
        print(f"[Teacher] Epoch {epoch_number}, Val Loss: {val_loss_total/len(val_loader)}, Accuracy: {accuracy:.4f}, F1 Score: {f1:.4f}")

# Knowledge-distillation training for the student, with L1 pruning of its LSTM weights
def train_kd(teacher, student, train_loader, val_loader, epochs=6, prune_amount=0.3):
    """Distil ``teacher`` into ``student`` and print accuracy and efficiency metrics.

    The student's LSTM weights are pruned once, before training, with
    ``prune.l1_unstructured``. Each epoch trains on ``train_loader`` with
    ``distillation_loss`` (T=3.0) and then validates on ``val_loader``.
    Afterwards one validation batch is profiled and parameter counts are
    printed.

    Args:
        teacher: Frozen model called as
            ``teacher(input_ids=..., attention_mask=...)`` exposing ``.logits``.
        student: Model called as ``student(input_ids)`` returning logits.
        train_loader: Iterable of batches with ``'input_ids'``, ``'labels'``
            and ``'attention_mask'`` entries; must support ``len()``.
        val_loader: Loader with ``'input_ids'`` and ``'labels'`` entries.
        epochs: Number of distillation epochs.
        prune_amount: ``amount`` argument forwarded to ``prune.l1_unstructured``.

    Returns:
        None. Both models are moved to CUDA when available, otherwise the CPU.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    teacher.eval()
    student.train()
    optimizer = optim.Adam(student.parameters(), lr=0.0005)  # Reduced LR for stability
    val_criterion = nn.CrossEntropyLoss()  # built once instead of once per validation batch
    student.to(device)
    teacher.to(device)

    # Apply pruning to the student's recurrent and input weights
    for module in student.modules():
        if isinstance(module, nn.LSTM):
            prune.l1_unstructured(module, name='weight_hh_l0', amount=prune_amount)
            prune.l1_unstructured(module, name='weight_ih_l0', amount=prune_amount)

    for epoch in range(epochs):
        running_loss = 0.0
        student.train()
        for batch in train_loader:
            texts = batch['input_ids'].to(device)
            labels = batch['labels'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            optimizer.zero_grad()
            # The teacher only supplies targets, so no gradients are needed for it.
            with torch.no_grad():
                teacher_outputs = teacher(input_ids=texts, attention_mask=attention_mask).logits
            student_outputs = student(texts)
            loss = distillation_loss(student_outputs, teacher_outputs, labels, T=3.0)  # Adjusted T
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f"[KD] Epoch {epoch+1}, Train Loss: {running_loss/len(train_loader)}")

        # Validation
        student.eval()
        preds, true_labels = [], []
        val_loss = 0.0
        with torch.no_grad():
            for batch in val_loader:
                texts = batch['input_ids'].to(device)
                labels = batch['labels'].to(device)
                outputs = student(texts)
                val_loss += val_criterion(outputs, labels).item()
                preds.extend(torch.argmax(outputs, dim=1).cpu().numpy())
                true_labels.extend(labels.cpu().numpy())
        accuracy = accuracy_score(true_labels, preds)
        f1 = f1_score(true_labels, preds, average='weighted')
        print(f"[KD] Epoch {epoch+1}, Val Loss: {val_loss/len(val_loader)}, Accuracy: {accuracy:.4f}, F1 Score: {f1:.4f}")

    # Efficiency Metrics: profile a single validation batch, if there is one.
    student.eval()
    first_batch = next(iter(val_loader), None)
    if first_batch is None:
        # Nothing to profile; avoids using an unbound profiler object below.
        print("Student Model Efficiency Metrics: skipped (validation loader is empty)")
    else:
        texts = first_batch['input_ids'].to(device)
        with torch.no_grad():
            with torch.profiler.profile(record_shapes=True) as prof:
                student(texts)
        print("Student Model Efficiency Metrics:")
        print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))

    # Model Size
    param_count = sum(p.numel() for p in student.parameters() if p.requires_grad)
    print(f"Student Model Parameter Count: {param_count}")
    # The count above is dense; also report how many weights the pruning
    # masks zero out so the figure is not mistaken for the pruned size.
    masked_count = sum(
        int((mask == 0).sum())
        for buffer_name, mask in student.named_buffers()
        if buffer_name.endswith("_mask")
    )
    print(f"Student Model Pruned (Masked) Weight Count: {masked_count}")

# Load IMDB and tokenize it with the BERT tokenizer
dataset = load_dataset("imdb")
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")


def tokenize_function(examples):
    """Tokenize the "text" field of a batch, padded/truncated to 128 tokens."""
    return tokenizer(
        examples["text"],
        padding="max_length",
        truncation=True,
        max_length=128,
    )


# Tokenize, rename the target column to "labels", and drop the raw text.
tokenized = (
    dataset.map(tokenize_function, batched=True)
    .rename_column("label", "labels")
    .remove_columns(["text"])
)
tokenized.set_format("torch")

# Hold out 20% of the training split for validation
train_val_split = tokenized["train"].train_test_split(test_size=0.2)
train_dataset, val_dataset = train_val_split["train"], train_val_split["test"]
test_dataset = tokenized["test"]

# Train Word2Vec on IMDB using the same tokenizer that produces the student's
# input ids, so Word2Vec tokens can be matched to embedding rows by token string.
EMBED_DIM = 100
sentences = [tokenizer.tokenize(text) for text in dataset['train']['text']]
word2vec_model = Word2Vec(sentences, vector_size=EMBED_DIM, window=5, min_count=1, workers=4)

# DataLoader
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=data_collator)
val_loader = DataLoader(val_dataset, batch_size=32, collate_fn=data_collator)
test_loader = DataLoader(test_dataset, batch_size=32, collate_fn=data_collator)

# Initialize Models
teacher = BertForSequenceClassification.from_pretrained("bert-base-uncased")
student = RNNStudent(vocab_size=tokenizer.vocab_size, embed_dim=EMBED_DIM, hidden_dim=128, output_dim=2)

# Seed the student's embedding table from Word2Vec. Word2Vec rows follow its
# own vocabulary order, not tokenizer ids, so each vector is copied into the
# row of the matching token id; tokens Word2Vec never saw keep their random
# initialisation.
w2v_index = word2vec_model.wv.key_to_index
aligned_rows = [
    (token_id, w2v_index[token])
    for token, token_id in tokenizer.get_vocab().items()
    if token in w2v_index and token_id < tokenizer.vocab_size
]
if aligned_rows:
    token_ids, w2v_rows = (list(column) for column in zip(*aligned_rows))
    with torch.no_grad():
        student.embedding.weight[token_ids] = torch.from_numpy(word2vec_model.wv.vectors[w2v_rows])

# Fine-Tune Teacher and Perform Knowledge Distillation
fine_tune_teacher(teacher, train_loader, val_loader, epochs=3)
train_kd(teacher, student, train_loader, val_loader, epochs=6, prune_amount=0.3)

# Final Test Evaluation
# No module-level `device` exists (it is only a local inside the training
# functions), so evaluate on whichever device the student's weights live on.
device = next(student.parameters()).device
student.eval()
preds, true_labels = [], []
with torch.no_grad():
    for batch in test_loader:
        texts = batch['input_ids'].to(device)
        outputs = student(texts)
        preds.extend(torch.argmax(outputs, dim=1).cpu().numpy())
        # Labels are only compared on the host, so they never need to leave it.
        true_labels.extend(batch['labels'].cpu().numpy())
accuracy = accuracy_score(true_labels, preds)
f1 = f1_score(true_labels, preds, average='weighted')
print(f"[KD] Final Test Accuracy: {accuracy:.4f}, F1 Score: {f1:.4f}")