# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from transformers import RobertaTokenizer, RobertaForSequenceClassification, BertForSequenceClassification, DataCollatorWithPadding
from datasets import load_dataset
import numpy as np
import random
import math
import warnings
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, classification_report
import matplotlib.pyplot as plt
import seaborn as sns

# Install a global "ignore" filter so that no warning of any category is shown
# for the rest of the run.
# NOTE(review): this also hides warnings that may point at real problems.
warnings.filterwarnings('ignore')

# Seed for reproducibility
def set_reproducible_seed(seed=42):
    """Seed every random number generator used by this script.

    Seeds Python's ``random``, NumPy, and the torch CPU/CUDA generators with
    the same value, then switches cuDNN to deterministic mode with
    auto-tuning (benchmark) disabled.

    Args:
        seed: Integer seed applied to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_rng in seeders:
        seed_rng(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True


set_reproducible_seed()

# Config
class Config:
    """Hyperparameters and runtime settings shared across the script."""

    # Tokenization and batching
    max_len: int = 128       # maximum sequence length passed to the tokenizer
    batch_size: int = 8      # batch size of every DataLoader
    num_classes: int = 5     # number of output labels for teacher and student

    # Optimization
    epochs: int = 10         # default number of training epochs
    lr: float = 2e-4         # default AdamW learning rate
    patience: int = 4        # default early-stopping patience, in epochs

    # Distillation settings (same values as fixed_distillation_loss's defaults)
    temperature: float = 3.0
    alpha: float = 0.7
    beta: float = 0.2

    # Run on the GPU when one is available, otherwise on the CPU
    device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load dataset
print("Loading PubMed pcb_med20 dataset...")

# String label -> integer class id; its keys are also the only labels kept.
label_map = {"BACKGROUND": 0, "OBJECTIVE": 1, "METHODS": 2, "RESULTS": 3, "CONCLUSIONS": 4}
valid_labels = set(label_map)

# Tab-separated files read as two columns: label, then text.
split_files = {
    "train": "./data/train.txt",
    "validation": "./data/dev.txt",
    "test": "./data/test.txt",
}
dataset = load_dataset(
    "csv",
    data_files=split_files,
    delimiter="\t",
    column_names=["label", "text"],
)

# Drop every row whose first column is not one of the five known labels.
dataset = dataset.filter(lambda row: row["label"] in valid_labels)

def encode_label(example):
    """Replace the example's string label with its integer id from ``label_map``.

    The example is modified in place and also returned, as ``Dataset.map``
    expects.
    """
    label_name = example["label"]
    example["label"] = label_map[label_name]
    return example

# Convert the string labels of every split to integer ids.
dataset = dataset.map(encode_label)

# Choose tokenizer and student model architecture
student_arch = "roberta"  # or "bert"

if student_arch != "bert":
    # Anything other than "bert" selects the RoBERTa tokenizer and model class.
    StudentModel = RobertaForSequenceClassification
    tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
else:
    from transformers import BertTokenizer
    StudentModel = BertForSequenceClassification
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

def preprocess(examples):
    """Tokenize a batch of examples, truncating to ``Config.max_len`` tokens.

    No padding is applied here: the DataLoaders in this file collate with
    ``DataCollatorWithPadding``, which pads each batch only to the length of
    its longest sequence. Padding every example to ``max_len`` up front would
    make that dynamic padding a no-op and spend compute on pad tokens.

    Args:
        examples: Batch mapping with a "text" entry holding the input strings.

    Returns:
        The tokenizer output for the batch (unpadded, truncated).
    """
    return tokenizer(examples["text"], truncation=True, max_length=Config.max_len)

# Tokenize the three splits in order: train, validation, test.
tokenized_train, tokenized_validation, tokenized_test = (
    dataset[split_name].map(preprocess, batched=True)
    for split_name in ("train", "validation", "test")
)

# Expose only the model inputs and the label, as torch tensors.
for tokenized_split in (tokenized_train, tokenized_validation, tokenized_test):
    tokenized_split.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])

# One padding collator is shared by all loaders; only the training loader shuffles.
padding_collator = DataCollatorWithPadding(tokenizer)
train_loader = DataLoader(
    tokenized_train,
    batch_size=Config.batch_size,
    shuffle=True,
    collate_fn=padding_collator,
)
val_loader = DataLoader(
    tokenized_validation,
    batch_size=Config.batch_size,
    collate_fn=padding_collator,
)
test_loader = DataLoader(
    tokenized_test,
    batch_size=Config.batch_size,
    collate_fn=padding_collator,
)

# Load teacher model: a RoBERTa classifier whose weights come from a local checkpoint.
teacher_model_path = r'G:\ML\Lightweight BERT with Knowledge Distillation for Low-Resource Text Classification\best_teacher_model2.pt'
teacher = RobertaForSequenceClassification.from_pretrained("roberta-base", num_labels=Config.num_classes)
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted source.
teacher.load_state_dict(
    torch.load(teacher_model_path, map_location=Config.device)
)
# The teacher is only used for inference, so it is put in eval mode.
teacher = teacher.to(Config.device).eval()

# Initialize student model from the pretrained checkpoint matching the chosen architecture.
student_checkpoint = "roberta-base" if student_arch != "bert" else "bert-base-uncased"
student = StudentModel.from_pretrained(student_checkpoint, num_labels=Config.num_classes)
student.to(Config.device)

# Distillation loss - fixed
def fixed_distillation_loss(student_logits, teacher_logits, labels, temperature=3.0,
                            alpha=0.7, beta=0.2, model=None, l2_lambda=1e-5,
                            epoch=0, max_epochs=25):
    """Weighted distillation loss with an epoch-dependent weight schedule.

    The result is the sum of four terms:

    * ``alpha'`` times the KL divergence between the temperature-softened
      teacher distribution (target) and the student's softened distribution,
    * ``beta'`` times the cross-entropy of the student against the teacher's
      argmax predictions (label smoothing 0.02),
    * ``gamma'`` times the cross-entropy of the student against the true
      labels (label smoothing 0.05),
    * ``l2_lambda`` times the sum of squared parameters of ``model``.

    With ``progress = min(epoch / max_epochs, 1)``, the weights are
    ``alpha' = alpha * (1 - 0.1 * progress)``,
    ``beta' = beta * (1 - 0.2 * progress)`` and
    ``gamma' = max(0.3, 1 - alpha' - beta')``.

    Args:
        student_logits: Student logits with classes along dim 1.
        teacher_logits: Teacher logits with the same layout.
        labels: Ground-truth class indices.
        temperature: Softmax temperature for the soft-target term; must be > 0.
        alpha: Base weight of the soft-target (KL) term.
        beta: Base weight of the teacher-prediction term.
        model: Optional module whose parameters are L2-regularized.
        l2_lambda: Weight of the L2 term.
        epoch: Current epoch index (0-based).
        max_epochs: Epoch count over which the weights are annealed. A
            non-positive value disables annealing (progress is 0).

    Returns:
        Scalar loss tensor.

    Raises:
        ValueError: If ``temperature`` is not positive.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")

    # Guard against max_epochs == 0, which used to raise ZeroDivisionError.
    progress = min(epoch / max_epochs, 1.0) if max_epochs > 0 else 0.0
    alpha_adaptive = alpha * (1 - progress * 0.1)
    beta_adaptive = beta * (1 - progress * 0.2)
    # The hard-label weight never drops below 0.3, so the weights need not sum to 1.
    gamma_adaptive = max(0.3, 1 - alpha_adaptive - beta_adaptive)

    F = nn.functional
    # NOTE(review): the soft term is not multiplied by temperature**2 (as in
    # Hinton et al.); left as-is because alpha was presumably tuned for this form.
    soft_loss = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=1),
        F.softmax(teacher_logits / temperature, dim=1),
        reduction='batchmean',
    )
    hard_loss = F.cross_entropy(student_logits, labels, label_smoothing=0.05)
    teacher_preds = teacher_logits.argmax(dim=1)
    sequence_loss = F.cross_entropy(student_logits, teacher_preds, label_smoothing=0.02)

    total_loss = (alpha_adaptive * soft_loss +
                  beta_adaptive * sequence_loss +
                  gamma_adaptive * hard_loss)

    # Skip the parameter sweep entirely when the L2 term cannot contribute.
    if model is not None and l2_lambda:
        # Sum of squares directly, rather than norm(...)**2, avoids a needless sqrt.
        l2_reg = sum(param.pow(2).sum() for param in model.parameters())
        total_loss = total_loss + l2_lambda * l2_reg
    return total_loss

# Training function
def _unpack_batch(batch):
    """Return ``(input_ids, attention_mask, labels)`` moved to ``Config.device``.

    ``DataCollatorWithPadding`` renames a ``"label"`` entry to ``"labels"``,
    so both keys are accepted.
    """
    label_key = "labels" if "labels" in batch else "label"
    return (
        batch["input_ids"].to(Config.device),
        batch["attention_mask"].to(Config.device),
        batch[label_key].to(Config.device),
    )


def train(student, teacher, train_loader, val_loader, epochs=Config.epochs, lr=Config.lr, patience=Config.patience):
    """Distill ``teacher`` into ``student`` with early stopping on validation accuracy.

    Each epoch trains the student on ``train_loader`` with
    ``fixed_distillation_loss`` and then evaluates it on ``val_loader``. The
    weights from the epoch with the highest validation accuracy are restored
    before returning.

    Args:
        student: Model being trained; updated in place.
        teacher: Model providing target logits; never updated.
        train_loader: Iterable of training batches.
        val_loader: Iterable of validation batches.
        epochs: Maximum number of epochs.
        lr: AdamW learning rate.
        patience: Number of consecutive epochs without a new best validation
            accuracy after which training stops.

    Returns:
        ``student``, loaded with its best-validation weights when at least one
        epoch improved on an accuracy of 0.
    """
    optimizer = optim.AdamW(student.parameters(), lr=lr)
    best_val_acc = 0
    patience_counter = 0
    best_model = None

    # The teacher only supplies targets; make sure dropout is disabled.
    teacher.eval()

    # Use the configured distillation settings instead of the loss's own defaults.
    loss_kwargs = dict(
        temperature=Config.temperature,
        alpha=Config.alpha,
        beta=Config.beta,
        model=student,
        max_epochs=epochs,
    )

    for epoch in range(epochs):
        student.train()
        train_loss = 0
        all_preds = []
        all_labels = []
        for batch in train_loader:
            input_ids, attention_mask, labels = _unpack_batch(batch)

            optimizer.zero_grad()
            with torch.no_grad():
                teacher_logits = teacher(input_ids, attention_mask=attention_mask).logits

            student_logits = student(input_ids, attention_mask=attention_mask).logits
            loss = fixed_distillation_loss(student_logits, teacher_logits, labels,
                                           epoch=epoch, **loss_kwargs)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            all_preds.extend(student_logits.argmax(dim=1).cpu().tolist())
            all_labels.extend(labels.cpu().tolist())
        train_acc = accuracy_score(all_labels, all_preds)

        # Validation
        student.eval()
        val_loss = 0
        val_preds = []
        val_labels = []
        with torch.no_grad():
            for batch in val_loader:
                input_ids, attention_mask, labels = _unpack_batch(batch)
                teacher_logits = teacher(input_ids, attention_mask=attention_mask).logits
                student_logits = student(input_ids, attention_mask=attention_mask).logits
                loss = fixed_distillation_loss(student_logits, teacher_logits, labels,
                                               epoch=epoch, **loss_kwargs)
                val_loss += loss.item()
                val_preds.extend(student_logits.argmax(dim=1).cpu().tolist())
                val_labels.extend(labels.cpu().tolist())
        val_acc = accuracy_score(val_labels, val_preds)

        # max(..., 1) keeps the averages defined even for an empty loader.
        avg_train_loss = train_loss / max(len(train_loader), 1)
        avg_val_loss = val_loss / max(len(val_loader), 1)
        print(f"Epoch {epoch + 1}/{epochs} - Train Loss: {avg_train_loss:.4f} Train Acc: {train_acc:.4f} Val Loss: {avg_val_loss:.4f} Val Acc: {val_acc:.4f}")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            patience_counter = 0
            # state_dict() returns references to the live parameters, so clone
            # them; otherwise later optimizer steps overwrite the "best" weights.
            best_model = {
                name: tensor.detach().clone()
                for name, tensor in student.state_dict().items()
            }
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("Early stopping due to no improvement.")
                break

    if best_model is not None:
        student.load_state_dict(best_model)
    return student

# Testing function
def test(student, test_loader):
    """Evaluate ``student`` on ``test_loader`` and report its accuracy.

    Args:
        student: Model to evaluate; switched to eval mode.
        test_loader: Iterable of test batches.

    Returns:
        The accuracy over all test examples (also printed).
    """
    student.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for batch in test_loader:
            input_ids = batch["input_ids"].to(Config.device)
            attention_mask = batch["attention_mask"].to(Config.device)
            # DataCollatorWithPadding renames "label" to "labels"; accept either key.
            label_key = "labels" if "labels" in batch else "label"
            labels = batch[label_key]
            logits = student(input_ids, attention_mask=attention_mask).logits
            all_preds.extend(logits.argmax(dim=1).cpu().tolist())
            # Labels are only compared on the host, so they need no device transfer.
            all_labels.extend(labels.cpu().tolist())
    accuracy = accuracy_score(all_labels, all_preds)
    print(f"Test Accuracy: {accuracy:.4f}")
    return accuracy

if __name__ == "__main__":
    # Distill, evaluate on the test split, then save the student's weights.
    trained_student = train(
        student,
        teacher,
        train_loader=train_loader,
        val_loader=val_loader,
    )
    test(trained_student, test_loader=test_loader)
    torch.save(trained_student.state_dict(), "final_student_model.pt")
    print("Student model saved.")
