In [None]:
# ----------------------------
# LoRA-on-Teacher (PEFT) — paste & run as one cell in Colab
# ----------------------------
# Mount Google Drive: the input CSVs, the teacher checkpoint and the adapter
# output directory configured below are all paths under /content/drive.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

# Install libs (peft + hf)
# NOTE: "!pip" is an IPython shell escape - valid inside a Colab/Jupyter cell
# only, not in a plain .py script.
!pip install -q transformers datasets evaluate accelerate peft

# Imports
import os, gc, math, pprint, inspect, json
import numpy as np, pandas as pd, torch
from datasets import Dataset, DatasetDict, Value
from transformers import (
    AutoTokenizer, AutoConfig, AutoModelForSequenceClassification,
    Trainer, TrainingArguments, DataCollatorWithPadding, set_seed
)
from peft import LoraConfig, get_peft_model, PeftModel
import evaluate
import warnings
warnings.filterwarnings("ignore")

# ---------- User knobs (edit if you want) ----------
DRIVE_BASE = "/content/drive/MyDrive/Colab Notebooks/HindiCodeMix"
# Train/val/test CSV splits written by the data-prep step.
SPLIT_BASE = os.path.join(DRIVE_BASE, "data_processed")
TRAIN_CSV = os.path.join(SPLIT_BASE, "train.csv")
VAL_CSV   = os.path.join(SPLIT_BASE, "val.csv")
TEST_CSV  = os.path.join(SPLIT_BASE, "test.csv")

# Where teacher full-finetuned model was saved by your teacher trainer
TEACHER_SAVE_DIR = os.path.join(DRIVE_BASE, "results_teacher_4epoch", "model")

# Base checkpoint id (fallback if TEACHER_SAVE_DIR is missing)
CHECKPOINT = "distilbert-base-multilingual-cased"

SEED = 42
set_seed(SEED)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# LoRA hyperparams (sane defaults)
LORA_R = 16
LORA_ALPHA = 64
LORA_DROPOUT = 0.05
LORA_BIAS = "lora_only"   # "none" | "all" | "lora_only"

# Training hyperparams
MAX_LEN = 64             # match teacher trainer
PER_DEVICE_BATCH = 8     # reduce if OOM; increase if VRAM allows
GRAD_ACCUM = 2
EPOCHS = 2
LR = 2e-4                # LoRA typical: 1e-4 .. 5e-4 for many tasks
FP16 = torch.cuda.is_available()

# Output folders are derived from the hyperparameters above, so their names
# always describe the run that produced them. With the defaults this yields
# ".../LoRA/HindiCodeMix/DropOut_0.05/lora_adapter_r16_a64_epoch2".
LORA_ROOT = "/content/drive/MyDrive/Colab Notebooks/LoRA/HindiCodeMix"
DRIVE_BASE_LoRA = os.path.join(LORA_ROOT, f"DropOut_{LORA_DROPOUT}")
# Output adapter folder (PEFT will save adapter files here)
LORA_OUTPUT_DIR = os.path.join(DRIVE_BASE_LoRA, f"lora_adapter_r{LORA_R}_a{LORA_ALPHA}_epoch{EPOCHS}")

# ---------- Load CSVs and build HF DatasetDict ----------
for p in [TRAIN_CSV, VAL_CSV, TEST_CSV]:
    if not os.path.exists(p):
        raise FileNotFoundError(f"CSV not found: {p} — run data-prep first")

train_df = pd.read_csv(TRAIN_CSV)
val_df   = pd.read_csv(VAL_CSV)
test_df  = pd.read_csv(TEST_CSV)
print("Sizes: train/val/test =", len(train_df), len(val_df), len(test_df))

# Tokenization reads "review" and label re-attachment reads "label" from every
# split. Check both up front so a malformed CSV fails here with a readable
# message rather than as a KeyError deep inside dataset.map().
REQUIRED_COLUMNS = ("review", "label")
for split_name, split_frame in (("train", train_df), ("val", val_df), ("test", test_df)):
    missing_cols = [c for c in REQUIRED_COLUMNS if c not in split_frame.columns]
    if missing_cols:
        raise ValueError(f"{split_name} CSV is missing required column(s): {missing_cols}")

# Create HF DatasetDict (like teacher trainer)
dataset = DatasetDict({
    "train": Dataset.from_pandas(train_df.reset_index(drop=True)),
    "validation": Dataset.from_pandas(val_df.reset_index(drop=True)),
    "test": Dataset.from_pandas(test_df.reset_index(drop=True)),
})

# ---------- Tokenizer (use same one saved with teacher) ----------
# Expected location: a "tokenizer" folder that sits next to the teacher's
# model directory (sibling of TEACHER_SAVE_DIR).
TOKENIZER_PATH = os.path.join(os.path.dirname(TEACHER_SAVE_DIR), "tokenizer")
print("Attempting to load tokenizer from:", TOKENIZER_PATH)
# local_files_only=True keeps the local load from contacting the Hub; any
# failure falls back to the public checkpoint's tokenizer.
try:
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH, use_fast=True, local_files_only=True)
except Exception as load_err:
    print("Local tokenizer load failed:", load_err)
    print("Falling back to HF checkpoint tokenizer:", CHECKPOINT)
    tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT, use_fast=True)
else:
    print("Loaded tokenizer from:", TOKENIZER_PATH)

# safety: padding="max_length" below needs a pad token.
if tokenizer.pad_token is None:
    # Reuse EOS when the tokenizer has one; mint "[PAD]" only as a last resort.
    # NOTE(review): a newly added token would also need
    # model.resize_token_embeddings(len(tokenizer)) once the model is loaded.
    fallback_pad = tokenizer.eos_token
    if fallback_pad is None:
        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
    else:
        tokenizer.pad_token = fallback_pad

def tokenize_fn(batch):
    """Tokenize a batched slice of the dataset into fixed-length encodings.

    Every value of the "review" column is coerced to ``str`` before tokenizing.
    A missing review (e.g. a NaN cell that ends up as ``None`` in the Arrow
    table) becomes "", and any non-string value is stringified, so the
    tokenizer does not raise. This parallels the ``.astype(str)`` applied to
    the test reviews when they are encoded for the final evaluation.

    Output is padded/truncated to exactly MAX_LEN tokens.
    """
    texts = ["" if t is None else str(t) for t in batch["review"]]
    return tokenizer(texts, truncation=True, padding="max_length", max_length=MAX_LEN)

# Drop every raw column; integer labels are re-attached explicitly afterwards.
dataset = dataset.map(tokenize_fn, batched=True, remove_columns=dataset["train"].column_names)

# Re-attach integer labels (consistent with teacher trainer). The tokenizing
# map above removed every raw column. This relies on .map() keeping row order,
# so the pandas label column lines up position-by-position with the rows.
# Each label column is then stored as int64.
split_frames = {"train": train_df, "validation": val_df, "test": test_df}
for split_name, frame in split_frames.items():
    labels_int = frame["label"].astype(int).tolist()
    dataset[split_name] = (
        dataset[split_name]
        .add_column("label", labels_int)
        .cast_column("label", Value("int64"))
    )

# Expose only model inputs + label as torch tensors; token_type_ids is kept
# only when the tokenizer actually produced that column.
if "token_type_ids" in dataset["train"].column_names:
    torch_columns = ["input_ids", "token_type_ids", "attention_mask", "label"]
else:
    torch_columns = ["input_ids", "attention_mask", "label"]
dataset.set_format(type="torch", columns=torch_columns)

# Rows are already padded to MAX_LEN, so the collator's padding is a no-op here.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# ---------- Load base model (prefer teacher checkpoint if present) ----------
if os.path.isdir(TEACHER_SAVE_DIR):
    base_model_source = TEACHER_SAVE_DIR
else:
    base_model_source = CHECKPOINT
print("Loading base model from:", base_model_source)
# Binary classification head. Attentions and hidden states are requested on
# every forward pass.
# NOTE(review): nothing in this cell reads them - confirm they are still wanted.
base_config = AutoConfig.from_pretrained(
    base_model_source,
    num_labels=2,
    output_attentions=True,
    output_hidden_states=True,
)
base_model = AutoModelForSequenceClassification.from_pretrained(base_model_source, config=base_config)
base_model = base_model.to(device)

# ---------- Helper: auto-detect candidate target modules for LoRA ----------
def detect_target_modules(model, substrings=None, max_matches=50):
    """
    Find LoRA-compatible layers whose name matches a projection pattern.

    A module qualifies when both of these hold:
      (a) the last component of its dotted name contains one of
          ``substrings`` (compared case-insensitively), and
      (b) it is a layer type LoRA can wrap: ``torch.nn.Linear`` or a class
          named ``Conv1D`` (GPT-2 style fused projections).
    Matching only on the leaf name and only on weight layers keeps
    containers, dropouts, and every child of e.g. a ``query_encoder`` block
    out of the result. PEFT raises on module types it cannot adapt.

    Args:
        model: any ``torch.nn.Module``.
        substrings: name fragments to look for. Defaults to common
            attention-projection names.
        max_matches: only limits how many names are printed.

    Returns:
        list[str]: fully-qualified module names, in ``named_modules()`` order,
        usable as ``LoraConfig(target_modules=...)``. Override with explicit
        names if the heuristic picks the wrong layers.
    """
    if substrings is None:
        substrings = ["q_lin", "v_lin", "k_lin", "q_proj", "k_proj", "v_proj", "o_proj",
                      "c_attn", "c_proj", "query_key_value", "query", "key", "value"]
    patterns = [s.lower() for s in substrings]

    res = []
    for name, module in model.named_modules():
        if not name:
            continue  # the root module itself cannot be targeted by name
        is_lora_layer = isinstance(module, torch.nn.Linear) or type(module).__name__ == "Conv1D"
        if not is_lora_layer:
            continue
        leaf = name.rsplit(".", 1)[-1].lower()
        if any(p in leaf for p in patterns):
            res.append(name)

    print(f"Detected {len(res)} candidate target modules (showing up to {max_matches}):")
    print(res[:max_matches])
    return res

candidate_targets = detect_target_modules(base_model)
# Nothing matched: fall back to a generic list of projection names used by
# several HF architectures. PEFT will complain if none of them exist in the model.
if len(candidate_targets) == 0:
    candidate_targets = ["query_key_value", "q_proj", "k_proj", "v_proj", "o_proj", "c_attn", "c_proj"]
    print("Fallback target_modules:", candidate_targets)

# Override here with exact module names if the heuristic picks the wrong layers.
target_modules = candidate_targets

# ---------- Create LoraConfig and wrap model ----------
lora_settings = dict(
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    lora_dropout=LORA_DROPOUT,
    bias=LORA_BIAS,
    target_modules=target_modules,
    # sequence classification; the trainable-parameter summary printed next
    # shows exactly which weights PEFT left trainable.
    task_type="SEQ_CLS",
)
lora_config = LoraConfig(**lora_settings)

print("Wrapping base model with PEFT LoRA (this will freeze base params)...")
peft_model = get_peft_model(base_model, lora_config)

# Print trainable parameter summary
def print_trainable_summary(model):
    """Print the total vs. trainable (requires_grad) parameter counts of ``model``.

    Args:
        model: any ``torch.nn.Module``.

    Returns:
        tuple[int, int]: ``(trainable, total)``, so callers can reuse the counts.
        For a parameter-less module the percentage is reported as 0 instead of
        raising ZeroDivisionError.
    """
    total = 0
    trainable = 0
    for p in model.parameters():
        n = p.numel()
        total += n
        if p.requires_grad:
            trainable += n
    pct = trainable / total * 100 if total else 0.0
    print(f"Total params: {total:,}, Trainable params (adapter): {trainable:,} ({pct:.4f}%)")
    return trainable, total
# Sanity check: after get_peft_model only adapter-related parameters should
# still have requires_grad=True, so "Trainable" should be a small fraction.
print_trainable_summary(peft_model)

# ---------------- Compatibility wrapper for TrainingArguments ----------------
def make_train_args(output_dir, **kwargs):
    """
    Create TrainingArguments robustly across transformers versions.

    - The eval-schedule keyword is spelled ``evaluation_strategy`` in some
      releases and ``eval_strategy`` in others. Whichever spelling the caller
      used is translated to the one the installed ``TrainingArguments``
      accepts, in either direction.
    - ``save_strategy`` / ``metric_for_best_model`` are dropped when the
      installed version does not accept them.

    Args:
        output_dir: checkpoint/log directory, passed straight through.
        **kwargs: any other ``TrainingArguments`` keyword arguments.

    Returns:
        TrainingArguments
    """
    ta_kwargs = dict(kwargs)
    param_names = set(inspect.signature(TrainingArguments.__init__).parameters)

    # Translate the eval-strategy keyword. The two conditions are mutually
    # exclusive, so at most one rename ever fires.
    for old, new in (("evaluation_strategy", "eval_strategy"),
                     ("eval_strategy", "evaluation_strategy")):
        if old in ta_kwargs and old not in param_names and new in param_names:
            ta_kwargs[new] = ta_kwargs.pop(old)

    # remove unsupported optional keys safely
    for optional in ("save_strategy", "metric_for_best_model"):
        if optional in ta_kwargs and optional not in param_names:
            ta_kwargs.pop(optional)

    return TrainingArguments(output_dir=output_dir, **ta_kwargs)

# ---------- TrainingArguments (LoRA) ----------
out_dir = LORA_OUTPUT_DIR
os.makedirs(out_dir, exist_ok=True)

lora_training_kwargs = dict(
    # evaluate + checkpoint every epoch; keep the epoch with the best
    # validation "macro_f1" (the key returned by compute_metrics)
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="macro_f1",
    greater_is_better=True,
    save_total_limit=3,
    # optimisation
    per_device_train_batch_size=PER_DEVICE_BATCH,
    per_device_eval_batch_size=PER_DEVICE_BATCH,
    gradient_accumulation_steps=GRAD_ACCUM,
    num_train_epochs=EPOCHS,
    learning_rate=LR,
    weight_decay=0.01,
    fp16=FP16,
    # logging
    logging_steps=50,
    report_to="none",
)
train_args = make_train_args(output_dir=out_dir, **lora_training_kwargs)

# small debug print to confirm settings
print("TrainingArguments created. per_device_train_batch_size =", train_args.per_device_train_batch_size)
print("Num epochs =", train_args.num_train_epochs)

# ---------- Metrics (robust compute_metrics) ----------
import numpy as _np

accuracy_metric = evaluate.load("accuracy")
f1_metric = evaluate.load("f1")


def _as_numpy(x):
    """Return ``x`` as an ndarray; torch tensors are detached and moved to CPU first."""
    if hasattr(x, "detach"):
        x = x.detach().cpu().numpy()
    return _np.asarray(x)


def compute_metrics(eval_pred):
    """
    Turn an ``(predictions, labels)`` eval pair into accuracy and macro-F1.

    Predictions may arrive as:
      - a tuple whose first item is the logits (any extra outputs are ignored),
      - a list of per-batch arrays (concatenated along axis 0; if they cannot
        be concatenated, only the first entry is kept),
      - a torch tensor or numpy array.
    A 1-D prediction vector is thresholded at 0.5; otherwise argmax over the
    last axis gives the predicted class.

    Returns:
        dict with float values under "accuracy" and "macro_f1".
    """
    logits, labels = eval_pred

    if isinstance(logits, tuple):
        logits = logits[0]

    if isinstance(logits, list):
        try:
            logits = _np.concatenate([_np.asarray(chunk) for chunk in logits], axis=0)
        except Exception:
            logits = _np.asarray(logits[0])

    logits = _as_numpy(logits)
    if logits.ndim == 1:
        preds = (logits > 0.5).astype("int32")
    else:
        preds = _np.argmax(logits, axis=-1)

    labels = _as_numpy(labels)

    acc = accuracy_metric.compute(predictions=preds, references=labels)["accuracy"]
    f1 = f1_metric.compute(predictions=preds, references=labels, average="macro")["f1"]
    return {"accuracy": float(acc), "macro_f1": float(f1)}


# ---------- Trainer ----------
def _logits_only(logits, labels):
    """Keep just the classification logits from each evaluation batch.

    The model config requests attentions and hidden states, so each eval
    forward pass returns them next to the logits. Without this hook the
    Trainer would gather all of them for the whole validation set before
    compute_metrics runs, which wastes (GPU) memory. ``labels`` is unused
    but required by the hook's signature.
    """
    return logits[0] if isinstance(logits, tuple) else logits


trainer_kwargs = dict(
    model=peft_model,
    args=train_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    preprocess_logits_for_metrics=_logits_only,
)
# Newer transformers versions replace Trainer(tokenizer=...) with
# processing_class=...; pass the tokenizer under whichever name is accepted.
if "processing_class" in inspect.signature(Trainer.__init__).parameters:
    trainer_kwargs["processing_class"] = tokenizer
else:
    trainer_kwargs["tokenizer"] = tokenizer
trainer = Trainer(**trainer_kwargs)

# ---------- Train ----------
print("Starting LoRA adapter training...")
trainer.train()

# ---------- Save adapter (PEFT saves only adapter weights) ----------
# Assuming load_best_model_at_end took effect, trainer.model now holds the
# best epoch's weights.
print("Saving LoRA adapter to:", out_dir)
trainer.model.save_pretrained(out_dir)
# optional: keep a tokenizer copy next to the adapter
tokenizer_out = os.path.join(out_dir, "tokenizer")
tokenizer.save_pretrained(tokenizer_out)

# ---------- Evaluate adapter by loading back into base and running eval ----------
print("Evaluating saved adapter...")
# Fresh base model (no Trainer wrappers), then attach the adapter just saved to out_dir.
base_for_eval = AutoModelForSequenceClassification.from_pretrained(base_model_source, config=base_config)
peft_loaded = PeftModel.from_pretrained(base_for_eval, out_dir, is_trainable=False)
peft_loaded.to(device)

# Encode the whole test split once, with the same fixed-length padding as training.
test_texts = test_df["review"].astype(str).tolist()
enc = tokenizer(test_texts, truncation=True, padding="max_length", max_length=MAX_LEN, return_tensors="pt")

# Contiguous row-index chunks, each 4x the per-device training batch size.
indices = list(range(len(test_df)))
eval_chunk = PER_DEVICE_BATCH * 4
batches = [indices[start:start + eval_chunk] for start in range(0, len(indices), eval_chunk)]

def eval_logits(model, device):
    """
    Run ``model`` over the pre-encoded test split and score its predictions.

    Reads the module-level ``enc`` (tokenized test reviews), ``batches``
    (lists of row indices) and ``test_df["label"]``.

    Args:
        model: sequence-classification model (here the reloaded PEFT model).
        device: torch device the input tensors are moved to. The model is
            expected to already be on it.

    Returns:
        tuple[float, float]: ``(accuracy, macro_f1)``.
    """
    # Convert the labels once, outside the loop, rather than per batch.
    test_labels = test_df["label"].astype(int).tolist()

    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for b in batches:
            out = model(input_ids=enc["input_ids"][b].to(device), attention_mask=enc["attention_mask"][b].to(device))
            logits = out.logits if hasattr(out, "logits") else out[0]
            all_preds.extend(torch.argmax(logits, dim=-1).cpu().tolist())
            # index by the batch's own rows so labels always match predictions
            all_labels.extend(test_labels[i] for i in b)

    acc = accuracy_metric.compute(predictions=all_preds, references=all_labels)["accuracy"]
    f1  = f1_metric.compute(predictions=all_preds, references=all_labels, average="macro")["f1"]
    return float(acc), float(f1)

# Parameter accounting.
# peft_loaded was loaded with is_trainable=False, i.e. frozen for inference,
# so a requires_grad count on it does not reflect the adapter size. Count
# the trainable parameters of the training-time wrapper (peft_model) instead:
# that is the set of weights the adapter was trained on.
adapter_trainable_params = sum(p.numel() for p in peft_model.parameters() if p.requires_grad)
# PeftModel.from_pretrained injects the LoRA layers (and any modules_to_save
# copies) into base_for_eval itself. Exclude them by name so this is the
# teacher's own parameter count.
teacher_total_params = sum(
    p.numel() for name, p in base_for_eval.named_parameters()
    if "lora_" not in name and "modules_to_save" not in name
)
adapter_kb = adapter_trainable_params * 4 / 1024.0  # assumes 4 bytes/param (fp32)

adapter_acc, adapter_f1 = eval_logits(peft_loaded, device)

print(f"Adapter eval -> Acc: {adapter_acc:.4f}, Macro-F1: {adapter_f1:.4f}")
print(f"Adapter trainable params: {adapter_trainable_params:,} (~{adapter_kb:.1f} KB), Teacher total params: {teacher_total_params:,}")

# ---------- Save a tiny metadata JSON for result-summary script to pick up ----------
# "label" is the adapter folder name; the literal fallback only applies when
# out_dir ends with a path separator (empty basename).
run_notes = f"r={LORA_R},alpha={LORA_ALPHA},dropout={LORA_DROPOUT},bias={LORA_BIAS},lr={LR},epochs={EPOCHS}"
meta = dict(
    label=os.path.basename(out_dir) or "Teacher+LoRA",
    accuracy=adapter_acc,
    macro_f1=adapter_f1,
    adapter_trainable_params=int(adapter_trainable_params),
    adapter_kb=float(adapter_kb),
    teacher_params=int(teacher_total_params),
    notes=run_notes,
)
meta_path = os.path.join(out_dir, "adapter_results.json")
with open(meta_path, "w") as fh:
    json.dump(meta, fh, indent=2)

# cleanup: move every model off the GPU, then drop the references.
# Each move is attempted independently and failures are reported, so a single
# error cannot silently skip the rest of the cleanup.
for _obj in (trainer.model, peft_loaded, base_for_eval, base_model):
    try:
        _obj.to("cpu")
    except Exception as _err:  # best-effort: report and keep going
        print("Cleanup warning (move to CPU failed):", _err)
del trainer, peft_loaded, base_for_eval, base_model, _obj
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

print("Done. Adapter saved at:", out_dir)
pprint.pprint(meta)
