In [None]:
# ===========================
# Student training (soft, 4epoch teacher) — patched
# ===========================
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

!pip install -q transformers datasets evaluate accelerate

# imports
import os, gc, warnings, numpy as np, torch, torch.nn.functional as F, pandas as pd
from datasets import Dataset, DatasetDict, Value
from transformers import (
    AutoTokenizer, AutoModelForSequenceClassification, AutoConfig,
    Trainer, TrainingArguments, DataCollatorWithPadding, set_seed
)
import evaluate
warnings.filterwarnings("ignore")
set_seed(42)

# Locations on the mounted Drive (edit DRIVE_BASE if the project lives elsewhere).
DRIVE_BASE = "/content/drive/MyDrive/Colab Notebooks/CodeMix"
train_csv, val_csv, test_csv = (
    os.path.join(DRIVE_BASE, f"{split_stem}.csv")
    for split_stem in ("train", "val", "test")
)

# Teacher output folder (4-epoch run) and the folder student runs are written to.
teacher_base_dir = os.path.join(DRIVE_BASE, "results_teacher_4epoch")
RESULTS_DIR = os.path.join(DRIVE_BASE, "results_students")
os.makedirs(RESULTS_DIR, exist_ok=True)

# Fail early: report the first split file that is missing, then check the
# teacher folder.
missing_split = next(
    (csv_path for csv_path in (train_csv, val_csv, test_csv) if not os.path.exists(csv_path)),
    None,
)
if missing_split is not None:
    raise FileNotFoundError(f"Missing split file: {missing_split}")
if not os.path.isdir(teacher_base_dir):
    raise FileNotFoundError(f"Teacher base folder not found: {teacher_base_dir}")
# detect teacher folder (prefer model/ then latest checkpoint)
def detect_teacher_folder(base_dir):
    """Locate the directory under ``base_dir`` that holds the teacher model.

    Search order:
      1. ``base_dir/model`` if it contains ``config.json``.
      2. ``base_dir`` itself if it contains ``config.json`` and a weights file.
      3. The most recently modified ``checkpoint*`` sub-directory that
         contains ``config.json`` and a weights file.

    A checkpoint folder must hold both the config and the weights. A
    partially written checkpoint (only one of the two) is skipped, so it
    cannot shadow an older, complete one.

    Args:
        base_dir: Path of an existing directory to search.

    Returns:
        The path of the selected directory, or ``None`` if nothing matched.
    """
    # Single-file weights plus the index files written for sharded weights.
    weight_names = (
        "pytorch_model.bin",
        "model.safetensors",
        "pytorch_model.bin.index.json",
        "model.safetensors.index.json",
    )

    def _is_complete(dir_path):
        # One directory listing per folder instead of one per candidate name.
        names = set(os.listdir(dir_path))
        return "config.json" in names and any(w in names for w in weight_names)

    model_dir = os.path.join(base_dir, "model")
    if os.path.isdir(model_dir) and "config.json" in os.listdir(model_dir):
        return model_dir

    if _is_complete(base_dir):
        return base_dir

    checkpoints = [
        os.path.join(base_dir, entry)
        for entry in os.listdir(base_dir)
        if entry.startswith("checkpoint")
    ]
    complete = [d for d in checkpoints if os.path.isdir(d) and _is_complete(d)]
    if not complete:
        return None
    # "Latest" is decided by directory modification time.
    return max(complete, key=os.path.getmtime)

# Resolve the teacher directory; abort with the folder listing if none is usable.
teacher_path = detect_teacher_folder(teacher_base_dir)
if teacher_path is None:
    raise FileNotFoundError(f"Could not find teacher model in {teacher_base_dir}. Contents: {os.listdir(teacher_base_dir)}")

print("Using teacher folder:", teacher_path)
teacher_files = os.listdir(teacher_path)
print("Files there:", teacher_files[:50])

# Read the three CSV splits into DataFrames.
train_df, val_df, test_df = (
    pd.read_csv(csv_path) for csv_path in (train_csv, val_csv, test_csv)
)
print("Loaded splits sizes:", *(len(frame) for frame in (train_df, val_df, test_df)))

# Wrap each DataFrame (with a fresh 0..n-1 index) as a Dataset split.
dataset = DatasetDict({
    split_name: Dataset.from_pandas(frame.reset_index(drop=True))
    for split_name, frame in (
        ("train", train_df),
        ("validation", val_df),
        ("test", test_df),
    )
})

# Student tokenizer; every example is truncated/padded to MAX_LEN tokens.
CHECKPOINT = "distilbert-base-multilingual-cased"
MAX_LEN = 64
tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)

def tokenize_batch(batch):
    """Tokenize the ``review`` entries of ``batch`` to a fixed length of MAX_LEN."""
    texts = batch["review"]
    tokenizer_options = {
        "truncation": True,
        "padding": "max_length",
        "max_length": MAX_LEN,
    }
    return tokenizer(texts, **tokenizer_options)

# Tokenize every split, dropping the original (train-split) columns, then
# re-attach the integer labels taken from the source DataFrames.
dataset = dataset.map(tokenize_batch, batched=True, remove_columns=dataset["train"].column_names)
for split_name, frame in (("train", train_df), ("validation", val_df), ("test", test_df)):
    dataset[split_name] = dataset[split_name].add_column(
        "label", frame["label"].astype(int).tolist()
    )

# Keep only the columns listed in keep_cols and store labels as int64.
keep_cols = ["input_ids", "attention_mask", "label"]
for split in list(dataset):
    split_ds = dataset[split]
    extra_cols = [col for col in split_ds.column_names if col not in keep_cols]
    if extra_cols:
        split_ds = split_ds.remove_columns(extra_cols)
    dataset[split] = split_ds.cast_column("label", Value("int64"))

dataset.set_format(type="torch")
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# Evaluation metrics: accuracy and macro-averaged F1.
accuracy_metric = evaluate.load("accuracy")
f1_metric = evaluate.load("f1")
def compute_metrics(eval_pred):
    """Return ``accuracy`` and ``macro_f1`` for a (predictions, labels) pair.

    When the predictions arrive as a tuple, only its first element is used as
    the logits.
    """
    raw_outputs, gold = eval_pred
    scores = raw_outputs[0] if isinstance(raw_outputs, tuple) else raw_outputs
    predicted = np.argmax(scores, axis=-1)
    acc_result = accuracy_metric.compute(predictions=predicted, references=gold)
    f1_result = f1_metric.compute(predictions=predicted, references=gold, average="macro")
    return {"accuracy": acc_result["accuracy"], "macro_f1": f1_result["f1"]}

# load teacher (local)
# The teacher is read from the folder detected earlier; local_files_only=True
# means nothing is fetched from the Hub. Hidden states and attentions are
# switched on in its config at load time.
# NOTE(review): the teacher receives batches tokenized with the student
# tokenizer (CHECKPOINT); this assumes both models share a vocabulary — confirm.
teacher = AutoModelForSequenceClassification.from_pretrained(
    teacher_path,
    local_files_only=True,
    output_hidden_states=True,
    output_attentions=True
)
# Inference mode only; the teacher is kept on the CPU.
teacher.eval()
teacher.to("cpu")
print("Teacher loaded on:", next(teacher.parameters()).device)

# ---- PATCHED: set student layers + run metadata ----
num_student_layers = 2   # change as desired
# NOTE(review): in the visible code `seed` is only used to build the run name;
# the RNG itself is seeded by the separate set_seed(42) call near the imports.
# Keep the two values in sync.
seed = 42
# Distillation objective handed to DistillTrainer (also part of the run name).
DISTILL_TYPE = "soft"

# student config
# Start from CHECKPOINT's configuration and override the label count and the
# number of layers; hidden states and attentions are enabled in the config.
student_config = AutoConfig.from_pretrained(
    CHECKPOINT,
    num_labels=2,
    num_hidden_layers=num_student_layers,
    output_hidden_states=True,
    output_attentions=True
)

# clear old
# Drop objects left over from a previous run of this cell so their memory can
# be reclaimed. pop() with a default is a no-op for names that do not exist,
# so no exception handling (and no bare `except`) is needed.
for n in ["trainer", "trainer_student", "student", "teacher_loaded"]:
    globals().pop(n, None)
gc.collect()
torch.cuda.empty_cache()

# DistillTrainer definition
class DistillTrainer(Trainer):
    """Trainer whose loss blends the student's own loss with a distillation term.

    The total loss is ``alpha * student_loss + (1 - alpha) * distill_loss``.
    ``distill_type`` selects the distillation term:

    * ``"baseline"``  - no distillation; the student's own loss is returned.
    * ``"soft"``      - temperature-scaled KL divergence between the student
      and teacher logits, multiplied by ``temperature ** 2``.
    * ``"hidden"``    - MSE between the last entries of ``hidden_states``.
    * ``"full"``      - ``"soft"`` plus ``"hidden"``.
    * ``"embedding"`` - MSE between the first entries of ``hidden_states``.
    * ``"attention"`` - MSE between the last attention maps after summing
      over dim 1 (presumably the head axis).
    """

    _DISTILL_TYPES = ("baseline", "soft", "hidden", "full", "embedding", "attention")
    _HIDDEN_STATE_TYPES = ("hidden", "full", "embedding")
    _ATTENTION_TYPES = ("attention",)

    def __init__(self, *args, teacher_model=None, distill_type="baseline", alpha=0.5, temperature=2.0, **kwargs):
        """Create the trainer.

        Args:
            *args, **kwargs: Forwarded to ``Trainer.__init__``.
            teacher_model: Model providing the distillation targets, or
                ``None`` to train on the student loss only.
            distill_type: One of the names listed in the class docstring.
            alpha: Weight of the student's own loss; the distillation term is
                weighted by ``1 - alpha``.
            temperature: Softening temperature for the ``"soft"`` term.

        Raises:
            ValueError: If ``distill_type`` is not a supported name.
        """
        import inspect  # local import keeps the class self-contained

        # Fail fast: an unrecognised type would otherwise add no distillation
        # term at all and silently train on alpha * student_loss.
        if distill_type not in self._DISTILL_TYPES:
            raise ValueError(
                f"Unknown distill_type {distill_type!r}; expected one of {self._DISTILL_TYPES}"
            )

        # Newer transformers releases name the tokenizer argument
        # `processing_class`. Translate the old keyword only when the
        # installed Trainer no longer accepts it.
        trainer_params = inspect.signature(Trainer.__init__).parameters
        if (
            "tokenizer" in kwargs
            and "tokenizer" not in trainer_params
            and "processing_class" in trainer_params
        ):
            kwargs["processing_class"] = kwargs.pop("tokenizer")

        super().__init__(*args, **kwargs)
        self.teacher = teacher_model
        self.distill_type = distill_type
        self.alpha = alpha
        self.temperature = temperature
        if self.teacher is not None:
            self.teacher.eval()

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None, **kwargs):
        """Compute the (optionally distilled) loss for one batch.

        Args:
            model: The student model being trained.
            inputs: Batch dict passed to the student as keyword arguments.
            return_outputs: When true, also return the student outputs.
            num_items_in_batch: Accepted for signature compatibility; unused.
            **kwargs: Accepted for signature compatibility; unused.

        Returns:
            The loss tensor, or ``(loss, student_outputs)`` when
            ``return_outputs`` is true.
        """
        distilling = self.teacher is not None and self.distill_type != "baseline"

        # Only ask for hidden states / attentions when the chosen objective
        # reads them. Extra outputs cost memory and are returned to the caller
        # together with the logits when return_outputs is true.
        want_hidden = distilling and self.distill_type in self._HIDDEN_STATE_TYPES
        want_attn = distilling and self.distill_type in self._ATTENTION_TYPES

        device = next(model.parameters()).device
        student_inputs = {
            k: (v.to(device) if isinstance(v, torch.Tensor) else v)
            for k, v in inputs.items()
        }
        student_outputs = model(
            **student_inputs,
            output_hidden_states=want_hidden,
            output_attentions=want_attn,
        )
        student_loss = student_outputs.loss

        if not distilling:
            return (student_loss, student_outputs) if return_outputs else student_loss

        # Feed the teacher on whatever device it lives on (not necessarily the
        # CPU). Labels are dropped because only the teacher's logits / states
        # are read, never its loss.
        teacher_device = next(self.teacher.parameters()).device
        teacher_inputs = {
            k: (v.detach().to(teacher_device) if isinstance(v, torch.Tensor) else v)
            for k, v in inputs.items()
            if k != "labels"
        }
        with torch.no_grad():
            teacher_outputs = self.teacher(
                **teacher_inputs,
                output_hidden_states=want_hidden,
                output_attentions=want_attn,
            )

        distill_loss = 0.0
        if self.distill_type in ("soft", "full"):
            t_logits = teacher_outputs.logits.to(device) / self.temperature
            s_logits = student_outputs.logits / self.temperature
            # The temperature ** 2 factor rescales the softened KL term.
            distill_loss += F.kl_div(
                F.log_softmax(s_logits, dim=-1),
                F.softmax(t_logits, dim=-1),
                reduction="batchmean"
            ) * (self.temperature ** 2)
        if self.distill_type in ("hidden", "full"):
            # assumes student and teacher hidden sizes match — TODO confirm
            distill_loss += F.mse_loss(
                student_outputs.hidden_states[-1],
                teacher_outputs.hidden_states[-1].to(device),
            )
        if self.distill_type == "embedding":
            distill_loss += F.mse_loss(
                student_outputs.hidden_states[0],
                teacher_outputs.hidden_states[0].to(device),
            )
        if self.distill_type == "attention":
            distill_loss += F.mse_loss(
                student_outputs.attentions[-1].sum(dim=1),
                teacher_outputs.attentions[-1].sum(dim=1).to(device)
            )

        loss = self.alpha * student_loss + (1.0 - self.alpha) * distill_loss
        return (loss, student_outputs) if return_outputs else loss

# Training args
def make_train_args(output_dir, **kwargs):
    """Build ``TrainingArguments`` with the evaluation-strategy keyword renamed.

    If ``TrainingArguments.__init__`` has a variable called
    ``evaluation_strategy``, an ``eval_strategy`` keyword is passed under that
    name; otherwise an ``evaluation_strategy`` keyword is passed as
    ``eval_strategy``. All other keywords are forwarded unchanged.
    """
    options = dict(kwargs)
    init_names = TrainingArguments.__init__.__code__.co_varnames
    if "evaluation_strategy" in init_names:
        given_name, accepted_name = "eval_strategy", "evaluation_strategy"
    else:
        given_name, accepted_name = "evaluation_strategy", "eval_strategy"
    if given_name in options:
        options[accepted_name] = options.pop(given_name)
    return TrainingArguments(output_dir=output_dir, **options)

# Hyperparameters for the student run.
# Batches hold 4 samples and gradients are accumulated over 2 steps.
PER_DEVICE_BATCH = 4
GRAD_ACCUM = 2
EPOCHS = 2
LR = 2e-5

# ---- PATCHED: run name includes 4epoch tag ----
# The run name encodes the distillation type, student depth and seed; it is
# also the name of the output sub-folder inside RESULTS_DIR.
run_name = f"student_{DISTILL_TYPE}_layers{num_student_layers}_seed{seed}_teacher4epoch"
output_dir = os.path.join(RESULTS_DIR, run_name)
# ------------------------------------------------

# Evaluate once per epoch, write no intermediate checkpoints
# (save_strategy="no"), disable external logging integrations, and enable
# fp16 only when CUDA is available.
train_args = make_train_args(
    output_dir=output_dir,
    eval_strategy="epoch",
    save_strategy="no",
    per_device_train_batch_size=PER_DEVICE_BATCH,
    per_device_eval_batch_size=PER_DEVICE_BATCH,
    gradient_accumulation_steps=GRAD_ACCUM,
    num_train_epochs=EPOCHS,
    learning_rate=LR,
    logging_steps=50,
    report_to="none",
    fp16=torch.cuda.is_available()
)

# Build the student on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): from_config builds the model from the configuration only, so
# the student presumably starts from randomly initialised weights. If it is
# meant to start from the pretrained CHECKPOINT weights, from_pretrained would
# be needed instead — confirm intent.
student = AutoModelForSequenceClassification.from_config(student_config).to(device)

# The teacher is only handed over when a distillation objective is selected.
# NOTE(review): newer transformers releases name the `tokenizer` argument
# `processing_class` — confirm against the installed version.
trainer = DistillTrainer(
    model=student,
    args=train_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    teacher_model=teacher if DISTILL_TYPE != "baseline" else None,
    distill_type=DISTILL_TYPE,
    alpha=0.5,
    temperature=2.0
)

# Run training (validation is evaluated according to train_args).
print("Starting student training:", run_name)
trainer.train()

# evaluate & save
# Final evaluation on the held-out test split.
res = trainer.evaluate(dataset["test"])
print("Student test results:", res)

# The trained student is written to the run's output folder.
save_dir = output_dir
trainer.save_model(save_dir)
print("Saved student ->", save_dir)

# cleanup
# Move the student off the accelerator and release cached memory.
trainer.model.to("cpu")
gc.collect(); torch.cuda.empty_cache()
