# In [3]:
# ===========================
# Student training (baseline, patched for 4-epoch teacher)
# Saves results with "_teacher4epoch" in directory name
# ===========================
# Mount Google Drive at /content/drive; the data and model paths used below
# all live under that mount point. force_remount=True asks Colab to mount
# again even when a mount is already present.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

# Notebook shell magic (only valid inside IPython/Colab, not in a plain .py
# file): quietly install the Hugging Face packages imported below.
!pip install -q transformers datasets evaluate accelerate

# imports
import os, gc, warnings, numpy as np, torch, pandas as pd
import torch.nn.functional as F
from datasets import Dataset, DatasetDict, Value
from transformers import (
    AutoTokenizer, AutoModelForSequenceClassification, AutoConfig,
    Trainer, TrainingArguments, DataCollatorWithPadding, set_seed
)
import evaluate
# Suppress every Python warning from here on (deprecation notices included).
warnings.filterwarnings("ignore")
# Seed through the transformers helper; project_seed(42) is also called
# further down with the same value.
set_seed(42)

import os, random

def project_seed(seed: int = 42):
    """Seed the RNGs used by this script and pin cuDNN to deterministic mode.

    Args:
        seed: Value given to the Python, NumPy and PyTorch generators and
            exported as the ``PYTHONHASHSEED`` environment variable.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)

    # CPU-side generators: stdlib random, NumPy global state, PyTorch.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # CUDA generators are only touched when a GPU is actually present.
    cuda_present = torch.cuda.is_available()
    if cuda_present:
        torch.cuda.manual_seed_all(seed)

    # Deterministic kernels on, autotuner off.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

project_seed(42)

# ---------- Paths ----------
# Everything (CSV splits, teacher folder, student results) lives under one
# Drive directory.
DRIVE_BASE = "/content/drive/MyDrive/Colab Notebooks/CodeMix"
train_csv, val_csv, test_csv = (
    os.path.join(DRIVE_BASE, csv_name)
    for csv_name in ("train.csv", "val.csv", "test.csv")
)

# use new teacher directory
teacher_base_dir = os.path.join(DRIVE_BASE, "results_teacher_4epoch")
RESULTS_DIR = os.path.join(DRIVE_BASE, "results_students")
os.makedirs(RESULTS_DIR, exist_ok=True)

# sanity checks: stop at the first missing split, then require the teacher dir
p = next(
    (path for path in (train_csv, val_csv, test_csv) if not os.path.exists(path)),
    None,
)
if p is not None:
    raise FileNotFoundError(f"Missing split file: {p}")
if not os.path.isdir(teacher_base_dir):
    raise FileNotFoundError(f"Teacher folder not found: {teacher_base_dir}")

# ---------- Teacher detection ----------
def detect_teacher_folder(base_dir):
    """Locate the directory holding the saved teacher model under *base_dir*.

    Candidates are tried in this order:

    1. ``<base_dir>/model`` when it contains ``config.json``;
    2. ``base_dir`` itself when it contains ``config.json``;
    3. the latest ``checkpoint*`` sub-directory that contains ``config.json``.

    Args:
        base_dir: Existing directory to search.

    Returns:
        The path of the chosen directory, or ``None`` when no candidate holds
        a ``config.json``.

    Raises:
        FileNotFoundError: If ``base_dir`` does not exist (from ``os.listdir``).
    """

    def has_config(folder):
        return os.path.isfile(os.path.join(folder, "config.json"))

    def recency(ckpt_dir):
        # Rank by the trailing step number ("checkpoint-1500" -> 1500).
        # Directory mtime is only a tie-breaker because copying or syncing a
        # folder can rewrite it.
        suffix = os.path.basename(ckpt_dir).rsplit("-", 1)[-1]
        step = int(suffix) if suffix.isdecimal() else -1
        return step, os.path.getmtime(ckpt_dir)

    for candidate in (os.path.join(base_dir, "model"), base_dir):
        if has_config(candidate):
            return candidate

    ckpt_paths = (
        os.path.join(base_dir, entry)
        for entry in os.listdir(base_dir)
        if entry.startswith("checkpoint")
    )
    # Require config.json here too, matching the two branches above, so an
    # incomplete checkpoint folder is never returned.
    usable = [path for path in ckpt_paths if os.path.isdir(path) and has_config(path)]
    return max(usable, key=recency, default=None)

teacher_path = detect_teacher_folder(teacher_base_dir)
if teacher_path is None:
    raise FileNotFoundError(f"Could not locate teacher model in {teacher_base_dir}")
print("Using teacher from:", teacher_path)

# ---------- Load dataset ----------
# Read the three CSV splits in one pass; the DataFrames are kept because the
# labels are re-attached from them after tokenisation.
train_df, val_df, test_df = (
    pd.read_csv(csv_path) for csv_path in (train_csv, val_csv, test_csv)
)
print("Loaded splits:", *(len(frame) for frame in (train_df, val_df, test_df)))

# Wrap each frame as a Dataset with a fresh 0..n-1 index.
split_frames = {"train": train_df, "validation": val_df, "test": test_df}
dataset = DatasetDict({
    split_name: Dataset.from_pandas(frame.reset_index(drop=True))
    for split_name, frame in split_frames.items()
})

# Tokenizer checkpoint and the fixed sequence length used by tokenize_batch.
MAX_LEN = 64
CHECKPOINT = "distilbert-base-multilingual-cased"
tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)

def tokenize_batch(batch):
    """Tokenize the ``review`` column of *batch* to exactly ``MAX_LEN`` tokens.

    Sequences are truncated and padded to ``MAX_LEN`` using the module-level
    ``tokenizer``; the tokenizer's output is returned unchanged.
    """
    reviews = batch["review"]
    encoded = tokenizer(
        reviews,
        truncation=True,
        padding="max_length",
        max_length=MAX_LEN,
    )
    return encoded

# Tokenize every split; all columns present in the train split are dropped
# from the mapped result (the label is re-attached just below).
raw_columns = dataset["train"].column_names
dataset = dataset.map(tokenize_batch, batched=True, remove_columns=raw_columns)

# Re-attach integer labels from the source DataFrames. This is positional, so
# it relies on the mapped rows being in the same order as the DataFrame rows.
for split_name, frame in (("train", train_df), ("validation", val_df), ("test", test_df)):
    label_values = frame["label"].astype(int).tolist()
    dataset[split_name] = dataset[split_name].add_column("label", label_values)

# keep expected cols
keep_cols = ["input_ids", "attention_mask", "label"]
for split in dataset.keys():
    extras = [col for col in dataset[split].column_names if col not in keep_cols]
    trimmed = dataset[split].remove_columns(extras) if extras else dataset[split]
    dataset[split] = trimmed.cast_column("label", Value("int64"))

# Hand torch tensors to the Trainer; the collator pads each batch.
dataset.set_format(type="torch")
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# ---------- Metrics ----------
accuracy_metric = evaluate.load("accuracy")
f1_metric = evaluate.load("f1")


def compute_metrics(eval_pred):
    """Return accuracy and macro-F1 for one evaluation pass.

    Args:
        eval_pred: A ``(predictions, labels)`` pair. When the predictions are
            a tuple, only its first element is used as the logits.

    Returns:
        A dict with the keys ``"accuracy"`` and ``"macro_f1"``.
    """
    scores, references = eval_pred
    if isinstance(scores, tuple):
        scores = scores[0]
    predicted_ids = np.argmax(scores, axis=-1)

    acc_result = accuracy_metric.compute(predictions=predicted_ids, references=references)
    f1_result = f1_metric.compute(predictions=predicted_ids, references=references, average="macro")
    return {"accuracy": acc_result["accuracy"], "macro_f1": f1_result["f1"]}

# ---------- Student Config ----------
seed = 42
num_student_layers = 2   # adjust layer count here

# Start from the checkpoint's configuration and override the head size, the
# depth, and the flags that make the model also return hidden states and
# attention maps.
student_config_overrides = dict(
    num_labels=2,
    num_hidden_layers=num_student_layers,
    output_hidden_states=True,
    output_attentions=True,
)
student_config = AutoConfig.from_pretrained(CHECKPOINT, **student_config_overrides)

# baseline = no teacher guidance
DISTILL_TYPE = "baseline"

# clear any old globals left behind by a previous run of this cell
for stale_name in ("trainer", "student", "teacher_loaded"):
    globals().pop(stale_name, None)
gc.collect()
torch.cuda.empty_cache()

# ---------- DistillTrainer ----------
class DistillTrainer(Trainer):
    """``Trainer`` subclass that records a ``distill_type`` tag.

    ``compute_loss`` returns the model's own ``outputs.loss``; the stored
    ``distill_type`` is not consulted when computing it.
    """

    def __init__(self, *args, distill_type="baseline", **kwargs):
        """Forward all arguments to ``Trainer`` and remember ``distill_type``."""
        super().__init__(*args, **kwargs)
        self.distill_type = distill_type

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Run *model* on *inputs* and return its loss.

        Returns ``(loss, outputs)`` when ``return_outputs`` is true, otherwise
        just the loss. Extra keyword arguments are accepted and ignored.
        """
        model_outputs = model(**inputs)
        model_loss = model_outputs.loss
        if return_outputs:
            return (model_loss, model_outputs)
        return model_loss

# ---------- Training Args ----------
# Samples per device per forward pass; used for both the train and eval batch size.
PER_DEVICE_BATCH = 4
# Passed as gradient_accumulation_steps below.
GRAD_ACCUM = 2
# Number of passes over the training split.
EPOCHS = 2
# Learning rate handed to TrainingArguments.
LR = 2e-5

# patched run name (with teacher4epoch marker)
# NOTE: `seed` only labels the run here; it is not passed to the training
# arguments created below.
run_name = f"student_{DISTILL_TYPE}_layers{num_student_layers}_seed{seed}_teacher4epoch"
# Student outputs go to <RESULTS_DIR>/<run_name>.
output_dir = os.path.join(RESULTS_DIR, run_name)

def make_train_args(output_dir, **kwargs):
    """Build ``TrainingArguments`` across the eval-strategy keyword rename.

    The evaluation schedule keyword is called ``eval_strategy`` in newer
    transformers releases and ``evaluation_strategy`` in older ones. Callers
    may pass either spelling; it is translated to the one the installed
    ``TrainingArguments`` accepts. ``eval_strategy`` is preferred when both
    are accepted, so the deprecated spelling is never used unnecessarily.

    Args:
        output_dir: Directory forwarded as ``output_dir``.
        **kwargs: Any other ``TrainingArguments`` keyword arguments.

    Returns:
        The constructed ``TrainingArguments`` instance.
    """
    import inspect  # local import: only needed for this compatibility probe

    # Inspect the real parameter list; co_varnames would also list locals.
    accepted = inspect.signature(TrainingArguments.__init__).parameters
    if "eval_strategy" in accepted:
        wanted, alias = "eval_strategy", "evaluation_strategy"
    else:
        wanted, alias = "evaluation_strategy", "eval_strategy"

    ta_kwargs = dict(kwargs)
    if alias in ta_kwargs:
        ta_kwargs[wanted] = ta_kwargs.pop(alias)
    return TrainingArguments(output_dir=output_dir, **ta_kwargs)

import inspect

# Evaluate once per epoch, never write intermediate checkpoints, and use
# mixed precision only when a GPU is available.
train_args = make_train_args(
    output_dir=output_dir,
    eval_strategy="epoch",
    save_strategy="no",
    per_device_train_batch_size=PER_DEVICE_BATCH,
    per_device_eval_batch_size=PER_DEVICE_BATCH,
    gradient_accumulation_steps=GRAD_ACCUM,
    num_train_epochs=EPOCHS,
    learning_rate=LR,
    logging_steps=50,
    report_to="none",
    fp16=torch.cuda.is_available()
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): from_config instantiates the model from the configuration
# only, so the student starts from randomly initialised weights. Confirm this
# is the intended baseline.
student = AutoModelForSequenceClassification.from_config(student_config).to(device)


def _keep_logits_only(logits, labels):
    """Reduce the model outputs to the classification logits before caching.

    The student config makes the model return hidden states and attention
    maps as well. Without this hook the Trainer would accumulate all of them
    for the whole evaluation set, although compute_metrics only reads the
    logits.
    """
    return logits[0] if isinstance(logits, tuple) else logits


# Newer transformers releases call the Trainer's tokenizer argument
# `processing_class`; pass the tokenizer under whichever name is accepted.
_tokenizer_kwarg = (
    "processing_class"
    if "processing_class" in inspect.signature(Trainer.__init__).parameters
    else "tokenizer"
)

trainer = DistillTrainer(
    model=student,
    args=train_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    preprocess_logits_for_metrics=_keep_logits_only,
    distill_type=DISTILL_TYPE,
    **{_tokenizer_kwarg: tokenizer},
)

print("Starting baseline training | layers =", num_student_layers)
trainer.train()

# ---------- Evaluate & Save ----------
res = trainer.evaluate(dataset["test"])
print("Baseline student test results:", res)

trainer.save_model(output_dir)
print("Saved baseline student ->", output_dir)

# cleanup: move the model off the GPU and release cached memory
trainer.model.to("cpu")
gc.collect()
torch.cuda.empty_cache()


# [notebook output] KeyboardInterrupt: