# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from transformers import BertForSequenceClassification, BertTokenizer, DataCollatorWithPadding
from transformers.optimization import get_linear_schedule_with_warmup
from datasets import load_dataset
from gensim.models import Word2Vec
from torch.nn.utils import prune
from torch.profiler import profile, record_function, ProfilerActivity
import numpy as np
from tqdm import tqdm

# Seed the PyTorch and NumPy global RNGs once, up front, so runs are repeatable.
_SEED = 42
for _seed_rng in (torch.manual_seed, np.random.seed):
    _seed_rng(_SEED)

# Configuration: every hyperparameter and runtime setting lives on this one class.
class Config:
    """Namespace of class-level settings, read directly as ``Config.<name>``."""

    # --- runtime -----------------------------------------------------------
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # --- data --------------------------------------------------------------
    batch_size = 32
    max_len = 128  # tokens per review after truncation/padding

    # --- training schedule -------------------------------------------------
    teacher_epochs = 3
    student_epochs = 10
    warmup_steps = 0.1  # a FRACTION of total optimiser steps (10%), not a step count
    patience = 5  # epochs without validation-loss improvement before stopping

    # --- learning rates per parameter group --------------------------------
    teacher_lr = dict(bert=1e-5, classifier=5e-5)
    student_lr = dict(
        embedding=1e-4,   # Word2Vec-initialised table: small updates
        rnn=5e-4,         # LSTM layers: higher rate
        attention=3e-4,   # attention scorer: moderate rate
        classifier=1e-3,  # task-specific head: highest rate
    )

    # --- knowledge distillation -------------------------------------------
    T = 3.2  # softmax temperature for the soft-target loss
    alpha = 0.4  # weight of the soft (KL) loss
    beta = 0.2  # weight of the hidden-state loss; the hard loss gets 1 - alpha - beta

    # --- student architecture ---------------------------------------------
    embed_dim = 100  # Word2Vec vector size
    hidden_dim = 256  # LSTM hidden size
    dropout = 0.3

    # --- compression -------------------------------------------------------
    prune_amount = 0.3  # fraction of entries zeroed in each pruned tensor

# Load and preprocess IMDB dataset
def load_imdb_data(val_fraction=0.1, seed=42):
    """Tokenize IMDB and return train / validation / test splits plus the tokenizer.

    The validation set is carved out of the IMDB *train* split, so early stopping
    and best-checkpoint selection never see the held-out test split.

    Args:
        val_fraction: Part of the original train split reserved for validation
            (a float fraction or an int example count, as accepted by
            ``Dataset.train_test_split``).
        seed: Seed for the deterministic train/validation split.

    Returns:
        ``(train_dataset, val_dataset, test_dataset, tokenizer)``. Each dataset is
        formatted to yield torch tensors for "input_ids", "attention_mask" and
        "label", padded/truncated to ``Config.max_len`` tokens.
    """
    dataset = load_dataset("imdb")
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

    def preprocess(batch):
        # Batched tokenization: one tokenizer call per chunk of reviews rather than
        # one per review. Fixed-length padding keeps every row max_len long.
        encoding = tokenizer(
            batch["text"],
            max_length=Config.max_len,
            truncation=True,
            padding="max_length",
        )
        return {
            "input_ids": encoding["input_ids"],
            "attention_mask": encoding["attention_mask"],
            "label": batch["label"],
        }

    # Hold out part of train for validation. Validating on the test split (as
    # before) leaks test data into early stopping and model selection.
    split = dataset["train"].train_test_split(test_size=val_fraction, seed=seed)
    columns = ["input_ids", "attention_mask", "label"]

    processed = []
    for raw in (split["train"], split["test"], dataset["test"]):
        tokenized = raw.map(preprocess, batched=True)
        tokenized.set_format("torch", columns=columns)
        processed.append(tokenized)
    train_dataset, val_dataset, test_dataset = processed

    return train_dataset, val_dataset, test_dataset, tokenizer

# Train Word2Vec embeddings aligned with the BERT vocabulary
def train_word2vec(dataset, tokenizer):
    """Train Word2Vec on BERT token ids and re-index its vectors by token id.

    The student looks embeddings up by BERT token id and copies
    ``model.wv.vectors`` row-for-row into its embedding table. Row ``i`` of the
    returned ``model.wv.vectors`` is therefore the vector for token id ``i``, for
    ``tokenizer.vocab_size`` rows in total. The previous version trained on
    decoded whole words, whose row order (by frequency) had no relation to token
    ids.

    Ids that never occur in the corpus (and special tokens) keep a small random
    vector. The padding id, when the tokenizer defines one, gets zeros.

    Note: ``model.wv`` is replaced by the re-indexed vectors. The returned model is
    meant for reading embeddings, not for further Word2Vec training.

    Args:
        dataset: Rows with an "input_ids" sequence of BERT token ids.
        tokenizer: The tokenizer that produced those ids.

    Returns:
        The trained ``Word2Vec`` model with ``wv`` re-indexed as described.
    """
    from gensim.models import KeyedVectors

    special_ids = set(tokenizer.all_special_ids)
    # Each "word" is the string form of a token id, so Word2Vec keys map back to
    # embedding rows via int(key). Special tokens ([PAD], [CLS], ...) are dropped
    # so padding does not dominate the co-occurrence statistics.
    sentences = [
        [str(token_id) for token_id in torch.as_tensor(example["input_ids"]).tolist()
         if token_id not in special_ids]
        for example in dataset
    ]
    model = Word2Vec(sentences, vector_size=Config.embed_dim, window=5, min_count=1, workers=4, seed=42)

    trained = model.wv.vectors
    rng = np.random.default_rng(42)
    # Unseen ids start at roughly the same scale as the trained vectors.
    aligned = rng.normal(
        0.0, float(trained.std()), size=(tokenizer.vocab_size, Config.embed_dim)
    ).astype(np.float32)
    for key, row in model.wv.key_to_index.items():
        aligned[int(key)] = trained[row]
    if tokenizer.pad_token_id is not None:
        aligned[tokenizer.pad_token_id] = 0.0

    reindexed = KeyedVectors(vector_size=Config.embed_dim)
    reindexed.add_vectors([str(i) for i in range(tokenizer.vocab_size)], aligned)
    model.wv = reindexed
    return model

# Student RNN Model
class RNNStudent(nn.Module):
    """Two-layer LSTM classifier with attention pooling, used as the KD student.

    ``forward`` returns ``(logits, seq_states)``:
      * ``logits`` has shape ``[batch, num_classes]``;
      * ``seq_states`` holds the per-token LSTM states projected to
        ``teacher_dim``, shape ``[batch, seq_len, teacher_dim]``.

    The projection lets the states be compared element-wise with the teacher's
    last hidden state. ``teacher_dim`` defaults to 768, the bert-base hidden size.
    The ``seq_proj`` parameters must be included in the optimiser to be trained.
    """

    def __init__(self, word2vec_model, vocab_size, embed_dim, hidden_dim, num_classes=2, teacher_dim=768):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        pretrained = torch.as_tensor(word2vec_model.wv.vectors, dtype=torch.float32)
        if tuple(pretrained.shape) == (vocab_size, embed_dim):
            # Row i is used as the embedding of token id i.
            with torch.no_grad():
                self.embedding.weight.copy_(pretrained)
        else:
            import warnings

            warnings.warn(
                f"Word2Vec matrix has shape {tuple(pretrained.shape)}, expected "
                f"{(vocab_size, embed_dim)}; its rows cannot be aligned with token ids, "
                "so the embedding keeps its random initialisation.",
                stacklevel=2,
            )
        self.embedding.weight.requires_grad = True  # allow fine-tuning

        # 2-layer LSTM with dropout between layers
        self.rnn = nn.LSTM(embed_dim, hidden_dim, num_layers=2, batch_first=True, dropout=Config.dropout)
        # Scores each time step for attention pooling
        self.attention = nn.Linear(hidden_dim, 1)
        self.fc = nn.Linear(hidden_dim, num_classes)
        self.dropout = nn.Dropout(Config.dropout)
        # Maps LSTM states to the teacher's hidden width for the sequence loss
        self.seq_proj = nn.Linear(hidden_dim, teacher_dim)

    def forward(self, input_ids, attention_mask=None):
        embedded = self.embedding(input_ids)  # [batch, seq_len, embed_dim]
        # NOTE(review): assumes right padding (the BertTokenizer default). A
        # unidirectional LSTM's states at real tokens then do not depend on the
        # trailing pads, so masking the attention below is sufficient.
        rnn_output, _ = self.rnn(embedded)  # [batch, seq_len, hidden_dim]

        scores = self.attention(rnn_output).squeeze(-1)  # [batch, seq_len]
        if attention_mask is not None:
            # Keep padding out of the pooled context. finfo.min (not -inf) avoids
            # NaN if a row were ever fully masked.
            scores = scores.masked_fill(attention_mask == 0, torch.finfo(scores.dtype).min)
        attn_weights = torch.softmax(scores, dim=-1)  # [batch, seq_len]
        context = torch.bmm(attn_weights.unsqueeze(1), rnn_output).squeeze(1)  # [batch, hidden_dim]

        logits = self.fc(self.dropout(context))  # [batch, num_classes]
        return logits, self.seq_proj(rnn_output)

# Apply pruning to student model
def apply_pruning(model, amount=None):
    """Apply L1 unstructured pruning to every ``nn.Linear`` / ``nn.LSTM`` parameter.

    Uses ``torch.nn.utils.prune`` re-parametrisation: each pruned tensor ``p``
    becomes a trainable ``p_orig`` times a fixed ``p_mask`` buffer. The zeroed
    entries therefore stay zero during later training.

    Args:
        model: Module whose Linear and LSTM submodules are pruned in place.
        amount: Fraction (float in [0, 1]) or absolute count (int) of entries
            to zero in each tensor. Defaults to ``Config.prune_amount``.
    """
    if amount is None:
        amount = Config.prune_amount
    for module in model.modules():
        if not isinstance(module, (nn.Linear, nn.LSTM)):
            continue
        # Snapshot the names first: pruning re-registers parameters as we go.
        # named_parameters() skips None entries (e.g. Linear(bias=False)).
        param_names = [name for name, _ in module.named_parameters(recurse=False)]
        for param_name in param_names:
            # On a repeated call the trainable tensor is "<name>_orig". Prune the
            # public name so torch combines masks instead of nesting "_orig_orig".
            if param_name.endswith("_orig"):
                param_name = param_name[: -len("_orig")]
            prune.l1_unstructured(module, name=param_name, amount=amount)

# Compute hybrid loss
def compute_kd_loss(student_logits, teacher_logits, student_seq, teacher_seq, labels, *, T=None, alpha=None, beta=None):
    """Return the hybrid knowledge-distillation loss.

    ``loss = alpha * soft + beta * seq + (1 - alpha - beta) * hard``, where:
      * soft: KL(teacher || student) on temperature-``T`` softened logits,
        multiplied by ``T**2`` to keep gradient magnitudes comparable across
        temperatures;
      * seq: mean-squared error between student and teacher per-token states;
      * hard: cross-entropy of ``student_logits`` against ``labels``.

    ``T``, ``alpha`` and ``beta`` default to ``Config.T``, ``Config.alpha`` and
    ``Config.beta``.

    Raises:
        ValueError: if ``student_seq`` and ``teacher_seq`` have different shapes,
            e.g. 256-d LSTM states vs 768-d BERT states. Project the student
            states to the teacher width first.
    """
    T = Config.T if T is None else T
    alpha = Config.alpha if alpha is None else alpha
    beta = Config.beta if beta is None else beta

    if student_seq.shape != teacher_seq.shape:
        raise ValueError(
            f"student_seq shape {tuple(student_seq.shape)} does not match teacher_seq "
            f"shape {tuple(teacher_seq.shape)}; project the student states to the "
            "teacher hidden size before computing the sequence loss"
        )

    # Soft loss (KL divergence on softened distributions), averaged per example
    soft_loss = nn.functional.kl_div(
        nn.functional.log_softmax(student_logits / T, dim=-1),
        nn.functional.softmax(teacher_logits / T, dim=-1),
        reduction="batchmean",
    ) * (T ** 2)

    # Sequence-level loss (MSE on hidden states)
    seq_loss = nn.functional.mse_loss(student_seq, teacher_seq)

    # Hard loss (cross-entropy against gold labels)
    hard_loss = nn.functional.cross_entropy(student_logits, labels)

    return alpha * soft_loss + beta * seq_loss + (1 - alpha - beta) * hard_loss

# Evaluate model
def evaluate(model, dataloader, device, is_teacher=False):
    """Evaluate ``model`` on ``dataloader`` without updating its weights.

    Args:
        model: When ``is_teacher`` is True, a model called with ``labels=`` whose
            output exposes ``.loss`` and ``.logits``. Otherwise, a module
            returning ``(logits, extra)``.
        dataloader: Iterable of dict batches with "input_ids", "attention_mask"
            and "labels" entries.
        device: Device the batch tensors are moved to.
        is_teacher: Selects the calling convention described above.

    Returns:
        ``(mean loss per sample, accuracy, macro-averaged F1)``.

    Raises:
        ValueError: if ``dataloader`` yields no samples.

    The model's train/eval mode is restored on exit, so callers that evaluate
    between training epochs keep training with dropout enabled.
    """
    was_training = model.training
    model.eval()
    total_loss, total_correct, total_samples = 0.0, 0, 0
    all_preds, all_labels = [], []

    try:
        with torch.no_grad():
            for batch in dataloader:
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                # Labels already arrive as a tensor; torch.stack() on a tensor raises TypeError.
                labels = torch.as_tensor(batch["labels"]).to(device)
                if is_teacher:
                    outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
                    loss, logits = outputs.loss, outputs.logits
                else:
                    logits, _ = model(input_ids, attention_mask)
                    loss = nn.functional.cross_entropy(logits, labels)

                batch_size = labels.size(0)
                # Weight by batch size so a short final batch is not over-counted.
                total_loss += loss.item() * batch_size
                preds = torch.argmax(logits, dim=-1)
                total_correct += (preds == labels).sum().item()
                total_samples += batch_size
                all_preds.append(preds.cpu())
                all_labels.append(labels.cpu())
    finally:
        model.train(was_training)

    if total_samples == 0:
        raise ValueError("evaluate() received an empty dataloader")

    accuracy = total_correct / total_samples
    f1 = _macro_f1(torch.cat(all_preds).numpy(), torch.cat(all_labels).numpy())
    return total_loss / total_samples, accuracy, f1


def _macro_f1(preds, labels):
    """Return the unweighted mean of per-class F1 over classes seen in ``preds`` or ``labels``."""
    scores = []
    for cls in np.union1d(preds, labels):
        tp = np.sum((preds == cls) & (labels == cls))
        fp = np.sum((preds == cls) & (labels != cls))
        fn = np.sum((preds != cls) & (labels == cls))
        # cls occurs in preds or labels, so the denominator is positive.
        scores.append(2 * tp / (2 * tp + fp + fn))
    return float(np.mean(scores))

# Fine-tune teacher (BERT)
def fine_tune_teacher(teacher, train_dataloader, val_dataloader, device):
    """Fine-tune the BERT teacher, printing train loss and validation metrics per epoch.

    Uses AdamW with separate learning rates for the encoder (``teacher.bert``) and
    the classification head (``teacher.classifier``). A linear schedule warms up
    over the first ``Config.warmup_steps`` fraction of all steps.
    """
    optimizer = optim.AdamW([
        {"params": teacher.bert.parameters(), "lr": Config.teacher_lr["bert"]},
        {"params": teacher.classifier.parameters(), "lr": Config.teacher_lr["classifier"]},
    ])
    total_steps = len(train_dataloader) * Config.teacher_epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(Config.warmup_steps * total_steps),  # warmup_steps is a fraction
        num_training_steps=total_steps,
    )

    for epoch in range(Config.teacher_epochs):
        # evaluate() puts the model in eval mode; re-enable dropout every epoch.
        teacher.train()
        total_loss = 0.0
        for batch in tqdm(train_dataloader, desc=f"[Teacher] Epoch {epoch+1}"):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            # Labels already arrive as a tensor; torch.stack() on a tensor raises TypeError.
            labels = torch.as_tensor(batch["labels"]).to(device)

            teacher.zero_grad()
            outputs = teacher(input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            total_loss += loss.item()

            loss.backward()
            optimizer.step()
            scheduler.step()

        avg_train_loss = total_loss / len(train_dataloader)
        val_loss, val_acc, val_f1 = evaluate(teacher, val_dataloader, device, is_teacher=True)
        print(f"[Teacher] Epoch {epoch+1}, Train Loss: {avg_train_loss}")
        print(f"[Teacher] Epoch {epoch+1}, Val Loss: {val_loss}, Accuracy: {val_acc:.4f}, F1 Score: {val_f1:.4f}")

# Train student with KD and different learning rates
def train_student(teacher, student, train_dataloader, val_dataloader, device, profile_steps=20):
    """Distil ``teacher`` into ``student``, with early stopping on validation loss.

    Each step runs the frozen teacher once to get its logits and last hidden
    state, then optimises the student on ``compute_kd_loss``.

    Learning rates come from ``Config.student_lr``, one per student component.
    Trainable student parameters outside those components are trained at the
    classifier rate. The checkpoint with the lowest validation loss is saved to
    ``best_student.pt``.

    Only the first ``profile_steps`` optimisation steps are profiled, which keeps
    profiler memory and overhead bounded. The summary table is printed after
    training. Pass ``profile_steps=0`` to disable profiling.
    """
    teacher.eval()

    component_lrs = [
        (student.embedding, Config.student_lr["embedding"]),
        (student.rnn, Config.student_lr["rnn"]),
        (student.attention, Config.student_lr["attention"]),
        (student.fc, Config.student_lr["classifier"]),
    ]
    param_groups = [{"params": list(module.parameters()), "lr": lr} for module, lr in component_lrs]
    grouped_ids = {id(p) for group in param_groups for p in group["params"]}
    # Any other student parameters (e.g. an extra projection head) would otherwise
    # never be updated; train them at the classifier rate.
    leftover = [p for p in student.parameters() if id(p) not in grouped_ids]
    if leftover:
        param_groups.append({"params": leftover, "lr": Config.student_lr["classifier"]})
    optimizer = optim.AdamW(param_groups)
    total_steps = len(train_dataloader) * Config.student_epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(Config.warmup_steps * total_steps),  # warmup_steps is a fraction
        num_training_steps=total_steps,
    )

    # Only request CUDA profiling (and sort by CUDA time) when training on a GPU.
    on_cuda = torch.device(device).type == "cuda"
    activities = [ProfilerActivity.CPU] + ([ProfilerActivity.CUDA] if on_cuda else [])
    prof = profile(activities=activities, record_shapes=True) if profile_steps > 0 else None
    if prof is not None:
        prof.start()
    step = 0

    best_val_loss = float("inf")
    counter = 0
    try:
        for epoch in range(Config.student_epochs):
            # evaluate() puts the student in eval mode; re-enable dropout every epoch.
            student.train()
            total_loss = 0.0
            for batch in tqdm(train_dataloader, desc=f"[KD] Epoch {epoch+1}"):
                with record_function("model_training"):
                    input_ids = batch["input_ids"].to(device)
                    attention_mask = batch["attention_mask"].to(device)
                    # Labels already arrive as a tensor; torch.stack() on a tensor raises TypeError.
                    labels = torch.as_tensor(batch["labels"]).to(device)

                    optimizer.zero_grad()
                    with torch.no_grad():
                        # One teacher pass yields both the logits and the final
                        # encoder layer (hidden_states[-1]). A second teacher.bert()
                        # call would recompute the same tensor.
                        teacher_outputs = teacher(
                            input_ids, attention_mask=attention_mask, output_hidden_states=True
                        )
                        teacher_logits = teacher_outputs.logits
                        teacher_seq = teacher_outputs.hidden_states[-1]

                    student_logits, student_seq = student(input_ids, attention_mask)
                    loss = compute_kd_loss(student_logits, teacher_logits, student_seq, teacher_seq, labels)
                    total_loss += loss.item()

                    loss.backward()
                    optimizer.step()
                    scheduler.step()

                step += 1
                if prof is not None and step == profile_steps:
                    prof.stop()

            avg_train_loss = total_loss / len(train_dataloader)
            val_loss, val_acc, val_f1 = evaluate(student, val_dataloader, device)
            print(f"[KD] Epoch {epoch+1}, Train Loss: {avg_train_loss}")
            print(f"[KD] Epoch {epoch+1}, Val Loss: {val_loss}, Accuracy: {val_acc:.4f}, F1 Score: {val_f1:.4f}")

            # Early stopping on validation loss; keep the best checkpoint on disk.
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                counter = 0
                torch.save(student.state_dict(), "best_student.pt")
            else:
                counter += 1
            if counter >= Config.patience:
                print("Early stopping triggered")
                break
    finally:
        # Training may end (or fail) before profile_steps steps have run.
        if prof is not None and step < profile_steps:
            prof.stop()

    if prof is not None:
        sort_key = "cuda_time_total" if on_cuda else "cpu_time_total"
        print(prof.key_averages().table(sort_by=sort_key, row_limit=10))

# Count model parameters
def count_parameters(model):
    """Return the total element count of ``model``'s parameters that have ``requires_grad`` set.

    Every element of each such tensor is counted, zeros included.
    """
    trainable = filter(lambda param: param.requires_grad, model.parameters())
    return sum(map(torch.Tensor.numel, trainable))

# Main execution
def main():
    """Run the pipeline end to end.

    Stages: data preparation, Word2Vec training, student pruning, teacher
    fine-tuning, knowledge distillation, and final test evaluation of the best
    student checkpoint.
    """
    # Load data
    train_dataset, val_dataset, test_dataset, tokenizer = load_imdb_data()
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
    train_dataloader = DataLoader(train_dataset, batch_size=Config.batch_size, shuffle=True, collate_fn=data_collator)
    val_dataloader = DataLoader(val_dataset, batch_size=Config.batch_size, collate_fn=data_collator)
    test_dataloader = DataLoader(test_dataset, batch_size=Config.batch_size, collate_fn=data_collator)

    # Train Word2Vec on the training split only
    word2vec_model = train_word2vec(train_dataset, tokenizer)

    # Initialize models
    teacher = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2).to(Config.device)
    student = RNNStudent(word2vec_model, tokenizer.vocab_size, Config.embed_dim, Config.hidden_dim).to(Config.device)

    # Prune the student before training; the masks keep pruned entries at zero
    apply_pruning(student)

    # Fine-tune teacher
    fine_tune_teacher(teacher, train_dataloader, val_dataloader, Config.device)

    # Train student
    train_student(teacher, student, train_dataloader, val_dataloader, Config.device)

    # Restore the lowest-validation-loss checkpoint onto the configured device
    student.load_state_dict(torch.load("best_student.pt", map_location=Config.device))

    # Evaluate on test set
    test_loss, test_acc, test_f1 = evaluate(student, test_dataloader, Config.device)
    print(f"[KD] Final Test Accuracy: {test_acc:.4f}, F1 Score: {test_f1:.4f}")
    print(f"[KD] Final Test Loss: {test_loss:.4f}")

    # Print efficiency metrics
    print("Student Model Efficiency Metrics:")
    print(f"Parameter Count: {count_parameters(student)}")
    # Pruning keeps dense "<name>_orig" parameters plus "<name>_mask" buffers, so
    # the count above is the dense size. Report how many masked entries survive.
    masks = [buf for name, buf in student.named_buffers() if name.endswith("_mask")]
    if masks:
        kept = sum(int(mask.sum().item()) for mask in masks)
        total = sum(mask.numel() for mask in masks)
        print(f"Unpruned Weights (in pruned tensors): {kept}/{total} ({kept / total:.1%})")

# Script entry point: importing this module defines everything but trains nothing.
if "__main__" == __name__:
    main()