In [None]:
!pip install peft datasets transformers evaluate

In [None]:
import os
# Optional: for accurate CUDA error traces during debugging
# os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import torch
import torch.nn as nn
from datasets import load_dataset
from transformers import (
    BertTokenizerFast,
    RobertaTokenizerFast,
    BertForSequenceClassification,
    RobertaForSequenceClassification,
    Trainer,
    TrainingArguments,
)
from peft import LoraConfig, get_peft_model

# 1) Load AG News and derive the label metadata from the dataset itself
raw = load_dataset("ag_news")
# Class names come from the "label" feature of the train split; the list
# position of each name is its integer class id.
labels = raw["train"].features["label"].names
num_labels = len(labels)
id2label = dict(enumerate(labels))

# 2) Tokenizers: the student (RoBERTa) and teacher (BERT) use different
#    tokenizers, so every text is encoded once per model.
student_tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
teacher_tokenizer = BertTokenizerFast.from_pretrained("bert-large-uncased")

# 3) Preprocess: return both student & teacher encodings + labels
def preprocess(examples):
    """Tokenize a batch of examples for both the student and the teacher.

    Args:
        examples: Batch mapping with a "text" entry (the raw strings) and a
            "label" entry (the class ids), as passed by a batched ``map``.

    Returns:
        A dict with the student encoding under "input_ids" /
        "attention_mask", the teacher encoding under "teacher_input_ids" /
        "teacher_attention_mask", and the untouched labels under "labels".
        Sequences are truncated to 512 tokens and left unpadded; padding is
        deferred to the collate step.
    """
    batch_texts = examples["text"]
    tokenize_kwargs = {"truncation": True, "max_length": 512}

    teacher_enc = teacher_tokenizer(batch_texts, **tokenize_kwargs)
    student_enc = student_tokenizer(batch_texts, **tokenize_kwargs)

    features = {}
    # Un-prefixed keys belong to the student: it is the model being trained.
    features["input_ids"] = student_enc["input_ids"]
    features["attention_mask"] = student_enc["attention_mask"]
    features["teacher_input_ids"] = teacher_enc["input_ids"]
    features["teacher_attention_mask"] = teacher_enc["attention_mask"]
    features["labels"] = examples["label"]
    return features

# Tokenize the train and test splits with identical settings. The original
# columns are dropped so only the fields produced by `preprocess` remain, and
# the on-disk cache is bypassed so the tokenization is always recomputed.
tokenized_train, tokenized_eval = (
    raw[split_name].map(
        preprocess,
        batched=True,
        load_from_cache_file=False,
        remove_columns=raw[split_name].column_names,
    )
    for split_name in ("train", "test")
)

# 4) Collate function: pad both student & teacher inputs
def collate_fn(examples):
    """Pad a list of tokenized examples into one batch of tensors.

    Student and teacher sequences are padded independently, each with its
    own tokenizer, because the two encodings of the same text generally
    differ in length and padding token.

    Args:
        examples: List of per-example dicts holding "input_ids",
            "attention_mask", "teacher_input_ids", "teacher_attention_mask"
            and "labels".

    Returns:
        A dict containing the padded student fields, the padded teacher
        fields under "teacher_input_ids" / "teacher_attention_mask", and a
        ``torch.long`` tensor of labels under "labels".
    """

    def _pad_with(tokenizer, ids_key, mask_key):
        # Gather one field pair across the batch and pad it to a PyTorch batch.
        return tokenizer.pad(
            {
                "input_ids": [item[ids_key] for item in examples],
                "attention_mask": [item[mask_key] for item in examples],
            },
            return_tensors="pt",
        )

    student_padded = _pad_with(student_tokenizer, "input_ids", "attention_mask")
    teacher_padded = _pad_with(
        teacher_tokenizer, "teacher_input_ids", "teacher_attention_mask"
    )
    label_tensor = torch.tensor(
        [item["labels"] for item in examples], dtype=torch.long
    )

    # Start from the student fields, then attach the teacher fields and labels.
    batch = dict(student_padded)
    batch["teacher_input_ids"] = teacher_padded["input_ids"]
    batch["teacher_attention_mask"] = teacher_padded["attention_mask"]
    batch["labels"] = label_tensor
    return batch

# 5) Set device and load models
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Reverse label map, passed together with id2label so the model configs carry
# a consistent pair instead of keeping placeholder "LABEL_i" names in label2id.
label2id = {name: idx for idx, name in id2label.items()}

# Teacher: frozen and kept in eval mode, placed on `device`.
# NOTE(review): "bert-large-uncased" is the generic pretrained checkpoint and
# nothing in this file fine-tunes it, so the classification head created here
# is newly initialized. Its logits then carry no AG News knowledge; substitute
# a fine-tuned teacher checkpoint if meaningful soft targets are intended.
teacher = BertForSequenceClassification.from_pretrained(
    "bert-large-uncased",
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id,
).to(device)
teacher.eval()
teacher.requires_grad_(False)  # no gradients are ever needed for the teacher

# Student: RoBERTa-base wrapped with LoRA adapters, placed on `device`.
student_base = RobertaForSequenceClassification.from_pretrained(
    "roberta-base",
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id,
)
peft_cfg = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.1,
    target_modules=["query", "value"],  # adapt the attention query/value projections
    bias="none",
    task_type="SEQ_CLS",
)
student = get_peft_model(student_base, peft_cfg).to(device)

# Define a function to print trainable parameters
def print_trainable_parameters(model):
    """Return the number of trainable parameters in ``model``.

    Despite its name, this function does not print anything; callers are
    expected to format the returned count themselves.

    Args:
        model: Object exposing ``parameters()``, whose items provide
            ``requires_grad`` and ``numel()``.

    Returns:
        Total element count over the parameters that have
        ``requires_grad`` set.
    """
    trainable_count = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable_count += param.numel()
    return trainable_count

# Report the trainable parameter count of each model.
teacher_trainable = print_trainable_parameters(teacher)
print(f"teacher trainable parameters: {teacher_trainable}")

student_trainable = print_trainable_parameters(student)
print(f"student trainable parameters: {student_trainable}")

# 6) Custom Trainer for distillation
class DistillationTrainer(Trainer):
    """Trainer that combines hard-label loss with teacher distillation.

    The training loss is::

        alpha * CE(student_logits, labels)
            + (1 - alpha) * T**2 * KL(teacher_T || student_T)

    where ``teacher_T`` / ``student_T`` are the softmax distributions of the
    respective logits divided by the temperature ``T``.

    Batches must provide "input_ids" / "attention_mask" for the student,
    "teacher_input_ids" / "teacher_attention_mask" for the teacher, and
    "labels".
    """

    def __init__(self, teacher_model, alpha=0.7, temperature=2.0, *args, **kwargs):
        """Set up the trainer.

        Args:
            teacher_model: Model producing the soft targets. It is switched
                to eval mode here and only ever run without gradients.
            alpha: Weight of the hard-label cross-entropy term; the
                distillation term is weighted by ``1 - alpha``.
            temperature: Softmax temperature ``T`` applied to both models'
                logits for the distillation term.
            *args, **kwargs: Forwarded unchanged to ``Trainer``.

        Raises:
            ValueError: If ``alpha`` is outside [0, 1] or ``temperature`` is
                not strictly positive.
        """
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f"alpha must be in [0, 1], got {alpha}")
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature}")
        super().__init__(*args, **kwargs)
        self.teacher = teacher_model
        # Keep dropout disabled in the teacher so its soft targets are deterministic.
        self.teacher.eval()
        self.alpha   = alpha
        self.temp    = temperature
        self.kl_div  = nn.KLDivLoss(reduction="batchmean")

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Compute the combined cross-entropy + distillation loss.

        Args:
            model: The student model being trained.
            inputs: Batch dict; "labels" is popped from it.
            return_outputs: When true, also return the student's outputs.
            **kwargs: Accepted and ignored, so extra keyword arguments passed
                by the caller do not break this override.

        Returns:
            The scalar loss, or ``(loss, student_outputs)`` when
            ``return_outputs`` is true.
        """
        labels = inputs.pop("labels")

        # Use each model's own device rather than a module-level global, so
        # the class works wherever the student and teacher happen to live.
        student_device = next(model.parameters()).device
        teacher_device = next(self.teacher.parameters()).device

        # Student forward (gradients flow through this pass).
        s_out = model(
            input_ids=inputs["input_ids"].to(student_device),
            attention_mask=inputs["attention_mask"].to(student_device),
        )
        logits_s = s_out.logits

        # Teacher forward: no gradients are needed for the soft targets.
        with torch.no_grad():
            t_out = self.teacher(
                input_ids=inputs["teacher_input_ids"].to(teacher_device),
                attention_mask=inputs["teacher_attention_mask"].to(teacher_device),
            )
        logits_t = t_out.logits.to(logits_s.device)

        # Hard-label cross-entropy
        loss_ce = nn.functional.cross_entropy(logits_s, labels.to(logits_s.device))

        # Soft-label distillation. KLDivLoss takes log-probabilities as input
        # and probabilities as target, so the student goes through
        # log_softmax and the teacher through softmax.
        T       = self.temp
        log_p_s = nn.functional.log_softmax(logits_s / T, dim=-1)
        p_t     = nn.functional.softmax(logits_t / T, dim=-1)
        # Rescale by T^2 so this term's gradient magnitude stays comparable
        # to the hard-label term as the temperature changes.
        loss_kd = self.kl_div(log_p_s, p_t) * (T * T)

        loss = self.alpha * loss_ce + (1 - self.alpha) * loss_kd
        return (loss, s_out) if return_outputs else loss

# 7) TrainingArguments with unused columns preserved
def _accuracy_metric(eval_pred):
    """Return the fraction of examples whose arg-max prediction equals the label."""
    hits = eval_pred.predictions.argmax(-1) == eval_pred.label_ids
    return {"accuracy": hits.mean()}


training_args = TrainingArguments(
    # --- output / reporting ---
    output_dir="distill_results",
    report_to="none",
    # --- optimisation ---
    learning_rate=5e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=64,
    gradient_accumulation_steps=2,
    fp16=True,
    # max_steps, when positive, takes precedence over num_train_epochs, so
    # the run stops after 1600 optimizer steps.
    num_train_epochs=3,
    max_steps=1600,
    # --- step-based logging, evaluation and checkpointing ---
    logging_strategy="steps",
    logging_steps=100,                 # log train loss every 100 steps
    eval_strategy="steps",
    eval_steps=100,                    # evaluate (and report accuracy) every 100 steps
    save_strategy="steps",
    save_steps=500,                    # checkpoint every 500 steps
    save_total_limit=8,
    # --- best-model selection ---
    load_best_model_at_end=True,       # valid because eval and save both use "steps"
    metric_for_best_model="accuracy",
    greater_is_better=True,
    # Keep the teacher_* columns: they are consumed by the distillation loss,
    # not by the student's forward signature.
    remove_unused_columns=False,
)

# 8) Instantiate & run
trainer = DistillationTrainer(
    model=student,
    teacher_model=teacher,
    alpha=0.7,
    temperature=2.0,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_eval,
    data_collator=collate_fn,
    tokenizer=student_tokenizer,
    compute_metrics=_accuracy_metric,
)

trainer.train()
final_metrics = trainer.evaluate()
print("Final evaluation:", final_metrics)


In [None]:
import os
import torch
from torch.utils.data import DataLoader
import pandas as pd
from transformers import (
    RobertaForSequenceClassification,
    RobertaTokenizer,
    DataCollatorWithPadding,
)
from peft import PeftModel

# 1) Config
BASE_MODEL      = "roberta-base"
# The training cell writes its checkpoints under output_dir="distill_results",
# so the adapter lives in that directory rather than in the working directory.
# Point this at whichever checkpoint step you want to use for inference.
PEFT_CHECKPOINT = os.path.join("distill_results", "checkpoint-1600")
TEST_PICKLE     = "test_unlabelled.pkl"
OUTPUT_CSV      = "inference_output.csv"

# 2) Tokenizer & collator (student-only: the teacher is not needed at inference)
tokenizer     = RobertaTokenizer.from_pretrained(BASE_MODEL)
data_collator = DataCollatorWithPadding(tokenizer)

# 3) Load the base student and attach the trained LoRA adapter
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

base = RobertaForSequenceClassification.from_pretrained(
    BASE_MODEL,
    num_labels=4,  # must match the num_labels the student was trained with
)
model = PeftModel.from_pretrained(base, PEFT_CHECKPOINT).to(device)
model.eval()  # disable dropout for inference

# 4) Evaluation helper
def evaluate_model(inf_model, dataset, batch_size=8, collate_fn=None, device=None):
    """Run batched inference and return the predicted class index per example.

    The model is called as given; the caller is responsible for putting it
    in eval mode beforehand.

    Args:
        inf_model: Callable accepting ``input_ids`` and ``attention_mask``
            keyword tensors and returning an object with a ``logits``
            attribute.
        dataset: Dataset whose batches contain "input_ids" and
            "attention_mask".
        batch_size: Number of examples per batch.
        collate_fn: Optional collate function passed to the ``DataLoader``.
        device: Device to move each batch to. When omitted, the device of
            the model's first parameter is used (CPU if it has none).

    Returns:
        A 1-D CPU tensor of predicted class indices in dataset order; an
        empty ``torch.long`` tensor when the dataset yields no batches.
    """
    if device is None:
        # Follow the model instead of relying on a module-level global.
        first_param = next(iter(inf_model.parameters()), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    loader = DataLoader(dataset, batch_size=batch_size, collate_fn=collate_fn)
    batch_preds = []
    with torch.no_grad():
        for batch in loader:
            # Only input_ids & attention_mask are forwarded to the model.
            logits = inf_model(
                input_ids=batch["input_ids"].to(device),
                attention_mask=batch["attention_mask"].to(device),
            ).logits
            batch_preds.append(logits.argmax(dim=-1).cpu())

    if not batch_preds:
        # torch.cat rejects an empty list; return an explicit empty result.
        return torch.empty(0, dtype=torch.long)
    return torch.cat(batch_preds, dim=0)

# 5) Load & tokenize your test set
# NOTE(review): read_pickle unpickles arbitrary objects, which can execute
# code on load; only use files from a trusted source.
unlabelled = pd.read_pickle(TEST_PICKLE)
# Presumably the pickle holds a dataset object with a "text" column and a
# batched `.map` (e.g. a Hugging Face Dataset); confirm against how the file
# was produced. Padding is left to the collator at batch time.
test_ds = unlabelled.map(
    lambda batch: tokenizer(batch["text"], padding=False, truncation=True),
    remove_columns=["text"],
    batched=True,
)

# 6) Run inference
predictions = evaluate_model(
    model,
    test_ds,
    collate_fn=data_collator,
    batch_size=64,
)

# 7) Save: one row per example, with ID being the example's position in the test set
df = pd.DataFrame(
    {
        "ID": range(len(predictions)),
        "Label": predictions.tolist(),
    }
)
df.to_csv(OUTPUT_CSV, index=False)
print(f"Inference complete — wrote {OUTPUT_CSV}")
