In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from transformers import BertForSequenceClassification, BertTokenizer, DataCollatorWithPadding
from transformers.optimization import get_linear_schedule_with_warmup
from datasets import load_dataset
import numpy as np
from sklearn.metrics import accuracy_score, f1_score
import torch.nn.utils.prune as prune
import torch.profiler
try:
    from gensim.models import Word2Vec
except ImportError:
    raise ImportError("Please install gensim: pip install gensim")

# Required dependencies: pip install gensim torch transformers datasets numpy scikit-learn

# Define RNN Student Model with Attention and Word2Vec Embeddings
class RNNStudent(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, output_dim, word2vec_model=None, tokenizer=None):
        super(RNNStudent, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        if word2vec_model and tokenizer:
            # Initialize embedding matrix with random values
            embedding_matrix = torch.randn(vocab_size, embed_dim) * 0.01
            # Copy Word2Vec vectors for matching tokens, ensuring writability
            for token, idx in tokenizer.vocab.items():
                if token in word2vec_model.wv:
                    embedding_matrix[idx] = torch.from_numpy(word2vec_model.wv[token].copy())
            self.embedding.weight.data.copy_(embedding_matrix)
        self.rnn = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.attention = nn.Linear(hidden_dim, 1)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, text):
        embedded = self.embedding(text)
        output, (hidden, _) = self.rnn(embedded)
        attention_weights = torch.softmax(self.attention(output).squeeze(-1), dim=1).unsqueeze(-1)
        context = torch.sum(output * attention_weights, dim=1)
        return self.fc(context)

# Hybrid Distillation Loss
def distillation_loss(student_logits, teacher_logits, labels, T=3.0, alpha=0.4, beta=0.2):
    soft_loss = nn.KLDivLoss(reduction='batchmean')(
        nn.functional.log_softmax(student_logits / T, dim=1),
        nn.functional.softmax(teacher_logits / T, dim=1)
    ) * (T * T)
    hard_loss = nn.CrossEntropyLoss()(student_logits, labels)
    sequence_loss = nn.CrossEntropyLoss()(student_logits, teacher_logits.argmax(dim=1))
    return alpha * soft_loss + (1 - alpha - beta) * hard_loss + beta * sequence_loss

# Fine-Tune Teacher Model with Multiple Learning Rates
def fine_tune_teacher(teacher, train_loader, val_loader, epochs=3):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    teacher.train()
    teacher.to(device)
    param_groups = [
        {"params": [p for n, p in teacher.named_parameters() if "classifier" not in n], "lr": 1e-5},
        {"params": [p for n, p in teacher.named_parameters() if "classifier" in n], "lr": 5e-5}
    ]
    optimizer = optim.Adam(param_groups)
    criterion = nn.CrossEntropyLoss()
    total_steps = len(train_loader) * epochs
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)
    for epoch in range(epochs):
        running_loss = 0.0
        teacher.train()
        for batch in train_loader:
            texts = batch['input_ids'].to(device)
            labels = batch['labels'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            optimizer.zero_grad()
            outputs = teacher(input_ids=texts, attention_mask=attention_mask).logits
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            scheduler.step()
            running_loss += loss.item()
        print(f"[Teacher] Epoch {epoch+1}, Train Loss: {running_loss/len(train_loader)}")
        teacher.eval()
        preds, true_labels = [], []
        val_loss = 0.0
        with torch.no_grad():
            for batch in val_loader:
                texts = batch['input_ids'].to(device)
                labels = batch['labels'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                outputs = teacher(input_ids=texts, attention_mask=attention_mask).logits
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                preds.extend(torch.argmax(outputs, dim=1).cpu().numpy())
                true_labels.extend(labels.cpu().numpy())
        accuracy = accuracy_score(true_labels, preds)
        f1 = f1_score(true_labels, preds, average='weighted')
        print(f"[Teacher] Epoch {epoch+1}, Val Loss: {val_loss/len(val_loader)}, Accuracy: {accuracy:.4f}, F1 Score: {f1:.4f}")

# Training Function for Knowledge Distillation
def train_kd(teacher, student, train_loader, val_loader, epochs=6, prune_amount=0.3):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    teacher.eval()
    student.train()
    optimizer = optim.Adam(student.parameters(), lr=0.0005)
    student.to(device)
    teacher.to(device)
    for name, module in student.named_modules():
        if isinstance(module, nn.LSTM):
            prune.l1_unstructured(module, name='weight_ih_l0', amount=prune_amount)
            prune.l1_unstructured(module, name='weight_hh_l0', amount=prune_amount)
        elif isinstance(module, nn.Linear) and module != student.attention:
            prune.l1_unstructured(module, name='weight', amount=prune_amount)
    for epoch in range(epochs):
        running_loss = 0.0
        student.train()
        for batch in train_loader:
            texts = batch['input_ids'].to(device)
            labels = batch['labels'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            optimizer.zero_grad()
            with torch.no_grad():
                teacher_outputs = teacher(input_ids=texts, attention_mask=attention_mask).logits
            student_outputs = student(texts)
            loss = distillation_loss(student_outputs, teacher_outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f"[KD] Epoch {epoch+1}, Train Loss: {running_loss/len(train_loader)}")
        student.eval()
        preds, true_labels = [], []
        val_loss = 0.0
        with torch.no_grad():
            for batch in val_loader:
                texts = batch['input_ids'].to(device)
                labels = batch['labels'].to(device)
                outputs = student(texts)
                loss = nn.CrossEntropyLoss()(outputs, labels)
                val_loss += loss.item()
                preds.extend(torch.argmax(outputs, dim=1).cpu().numpy())
                true_labels.extend(labels.cpu().numpy())
        accuracy = accuracy_score(true_labels, preds)
        f1 = f1_score(true_labels, preds, average='weighted')
        print(f"[KD] Epoch {epoch+1}, Val Loss: {val_loss/len(val_loader)}, Accuracy: {accuracy:.4f}, F1 Score: {f1:.4f}")
    student.eval()
    with torch.no_grad():
        for batch in val_loader:
            texts = batch['input_ids'].to(device)
            with torch.profiler.profile(record_shapes=True) as prof:
                student(texts)
            break
    print("Student Model Efficiency Metrics:")
    print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
    param_count = sum(p.numel() for p in student.parameters() if p.requires_grad)
    print(f"Student Model Parameter Count: {param_count}")

# Load and Preprocess Data
dataset = load_dataset("imdb")
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=128)

tokenized = dataset.map(tokenize_function, batched=True)
tokenized = tokenized.rename_column("label", "labels")
tokenized = tokenized.remove_columns(["text"])
tokenized.set_format("torch")

# Split train into train and validation
train_val_split = tokenized["train"].train_test_split(test_size=0.2)
train_dataset = train_val_split["train"]
val_dataset = train_val_split["test"]
test_dataset = tokenized["test"]

# Train Word2Vec on IMDB
sentences = [text.split() for text in dataset['train']['text']]
word2vec_model = Word2Vec(sentences, vector_size=100, window=5, min_count=1, workers=4)

# DataLoader
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=data_collator)
val_loader = DataLoader(val_dataset, batch_size=32, collate_fn=data_collator)
test_loader = DataLoader(test_dataset, batch_size=32, collate_fn=data_collator)

# Initialize Models
# Note: The classifier weights are randomly initialized and will be fine-tuned in fine_tune_teacher
teacher = BertForSequenceClassification.from_pretrained("bert-base-uncased")
student = RNNStudent(
    vocab_size=tokenizer.vocab_size,
    embed_dim=100,
    hidden_dim=128,
    output_dim=2,
    word2vec_model=word2vec_model,
    tokenizer=tokenizer
)

# Fine-Tune Teacher and Perform Knowledge Distillation
fine_tune_teacher(teacher, train_loader, val_loader, epochs=3)
train_kd(teacher, student, train_loader, val_loader, epochs=6)

# Final Test Evaluation
student.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
student.to(device)
preds, true_labels = [], []
with torch.no_grad():
    for batch in test_loader:
        texts = batch['input_ids'].to(device)
        labels = batch['labels'].to(device)
        outputs = student(texts)
        preds.extend(torch.argmax(outputs, dim=1).cpu().numpy())
        true_labels.extend(labels.cpu().numpy())
accuracy = accuracy_score(true_labels, preds)
f1 = f1_score(true_labels, preds, average='weighted')
print(f"[KD] Final Test Accuracy: {accuracy:.4f}, F1 Score: {f1:.4f}")

  from .autonotebook import tqdm as notebook_tqdm
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


[Teacher] Epoch 1, Train Loss: 0.3615740340590477
[Teacher] Epoch 1, Val Loss: 0.31023692498636096, Accuracy: 0.8686, F1 Score: 0.8684
[Teacher] Epoch 2, Train Loss: 0.23698944787979126
[Teacher] Epoch 2, Val Loss: 0.2948019550342089, Accuracy: 0.8814, F1 Score: 0.8814
[Teacher] Epoch 3, Train Loss: 0.1772795939952135
[Teacher] Epoch 3, Val Loss: 0.31792469564706655, Accuracy: 0.8828, F1 Score: 0.8828
[KD] Epoch 1, Train Loss: 0.8123866650819779
[KD] Epoch 1, Val Loss: 0.4458315135187404, Accuracy: 0.7912, F1 Score: 0.7903
[KD] Epoch 2, Train Loss: 0.4788409783601761
[KD] Epoch 2, Val Loss: 0.3927853171043335, Accuracy: 0.8258, F1 Score: 0.8258
[KD] Epoch 3, Train Loss: 0.34607676577568053
[KD] Epoch 3, Val Loss: 0.41091740515771186, Accuracy: 0.8318, F1 Score: 0.8312
[KD] Epoch 4, Train Loss: 0.2508574266910553
[KD] Epoch 4, Val Loss: 0.37675645983048306, Accuracy: 0.8456, F1 Score: 0.8455
[KD] Epoch 5, Train Loss: 0.18469736734628678
[KD] Epoch 5, Val Loss: 0.3902667104533524, Accura

In [2]:
torch.device("cuda" if torch.cuda.is_available() else "cpu")

device(type='cuda')

In [1]:
import torch

print("CUDA Available:", torch.cuda.is_available())
print("CUDA Device Count:", torch.cuda.device_count())
print("Device Name:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "No GPU found")


CUDA Available: True
CUDA Device Count: 1
Device Name: NVIDIA GeForce RTX 4060 Ti
