In [None]:
# Mount Google Drive at /content/drive; the dataset and output paths defined
# later in this notebook all live under this mount point.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import os
import json
import random
import gc
from collections import defaultdict
import numpy as np
import pandas as pd
from scipy.stats import spearmanr
import torch
from sklearn.model_selection import train_test_split
from datasets import Dataset, DatasetDict

from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    EarlyStoppingCallback,
    set_seed
)
from peft import (
    get_peft_model,
    LoraConfig,
    TaskType,
    PeftModel
)

# seeding for reproducibility
def seed_everything(seed=42):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random``, NumPy and PyTorch (plus all CUDA devices when
    CUDA is available), then calls the ``set_seed`` helper imported from
    ``transformers`` with the same value.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    cuda_present = torch.cuda.is_available()
    if cuda_present:
        torch.cuda.manual_seed_all(seed)

    set_seed(seed)

# Fix all RNG seeds once, before any data splitting or model construction.
seed_everything(42)

# paths (all inside the Google Drive mount)
# Gold train/dev JSON files read by load_json_records below.
GOLD_TRAIN_PATH = "/content/drive/MyDrive/nlp/train.json"
GOLD_DEV_PATH   = "/content/drive/MyDrive/nlp/dev.json"
# Trainer output directory; the best adapter is saved in a sub-folder of it.
OUTPUT_DIR      = "/content/drive/MyDrive/nlp/Qwen_SFT_Gold_Synth_FullMetrics"

# model configuration
# Hub id used for both the tokenizer and the sequence-classification model.
MODEL_ID = "Qwen/Qwen2.5-1.5B-Instruct"
# Token length every prompt is truncated/padded to during tokenization.
MAX_LENGTH = 384


# data loading & preprocessing
def load_json_records(path):
    """Read a UTF-8 JSON file and return its records.

    A top-level JSON object is reduced to a list of its values; any other
    top-level value (typically a list) is returned unchanged.
    """
    with open(path, "r", encoding="utf-8") as handle:
        payload = json.load(handle)

    if isinstance(payload, dict):
        return list(payload.values())
    return payload

def flatten_records(records):
    """Collect every dict found in an arbitrarily nested list structure.

    Nested lists are descended into depth-first, so dicts come out in the
    order they are encountered. Anything that is neither a list nor a dict
    is silently ignored.
    """
    def _walk(items):
        for item in items:
            if isinstance(item, list):
                yield from _walk(item)
            elif isinstance(item, dict):
                yield item

    return list(_walk(records))

def build_examples_chat_format(records):
    """
    Build regression examples with a hand-written ChatML-style prompt.

    Non-dict entries in ``records`` are skipped. For every dict record the
    text fields ``precontext``, ``sentence``, ``ending``, ``judged_meaning``
    and ``homonym`` are read (missing or null values become ""), stripped,
    and the story plus definition are embedded in a system/user/assistant
    prompt that ends with "Plausibility Score:".

    Returns a list of dicts with the keys:
        text     -- the prompt string
        label    -- ``average`` as a float (0.0 when missing or null)
        stdev    -- ``stdev`` as a float (0.0 when missing or null)
        group_id -- "homonym||precontext||sentence", used to group rows by
                    unique context for Macro-Spearman
    """
    def _text(record, key):
        # Missing keys and explicit nulls both collapse to the empty string.
        return (record.get(key) or "").strip()

    def _number(record, key):
        # An explicit JSON null must behave like a missing key; a plain
        # float(None) would raise TypeError, unlike the text fields above.
        value = record.get(key)
        return 0.0 if value is None else float(value)

    ex = []
    for r in records:
        if not isinstance(r, dict):
            continue

        pre = _text(r, "precontext")
        sent = _text(r, "sentence")
        ending = _text(r, "ending")
        meaning = _text(r, "judged_meaning")

        # Meta-data for final evaluation
        homonym = _text(r, "homonym")
        # Group ID: used for Macro-Spearman grouping by unique context
        gid = f"{homonym}||{pre}||{sent}"

        # build story text; the ending is optional
        story_text = f"{pre} {sent}"
        if ending:
            story_text += f" {ending}"

        # structure as an instruction
        prompt = (
            f"<|im_start|>system\n"
            f"You are a semantic judge. Rate the plausibility of the Definition given the Story.<|im_end|>\n"
            f"<|im_start|>user\n"
            f"Story: {story_text}\n"
            f"Definition: {meaning}<|im_end|>\n"
            f"<|im_start|>assistant\n"
            f"Plausibility Score:"
        )

        ex.append({
            "text": prompt,
            "label": _number(r, "average"),
            "stdev": _number(r, "stdev"),
            "group_id": gid,
        })
    return ex

# Load Data
# Read the raw gold train/dev records from the JSON files on Drive.
gold_train_records = load_json_records(GOLD_TRAIN_PATH)
gold_dev_records   = load_json_records(GOLD_DEV_PATH)


print(f"Loaded Gold Train: {len(gold_train_records)}")
print(f"Loaded Gold Dev:   {len(gold_dev_records)}")

# Build Examples
# Each record becomes {"text", "label", "stdev", "group_id"}.
gold_train_ex = build_examples_chat_format(gold_train_records)
gold_dev_ex   = build_examples_chat_format(gold_dev_records)


# Only the gold training examples are used here; nothing else is appended.
train_full_ex = gold_train_ex
train_df_full = pd.DataFrame(train_full_ex)
dev_df        = pd.DataFrame(gold_dev_ex) # gold dev is held out as the test set

# 80/20 split of train data (row-level, shuffled, fixed seed)
# NOTE(review): the split ignores group_id, so rows sharing a context may end
# up in both train and validation — confirm this is acceptable for the
# validation-based checkpoint selection used during training.
train_df, val_df = train_test_split(
    train_df_full,
    test_size=0.2,
    random_state=42,
    shuffle=True,
)

# print(f"Training Set Size:   {len(train_df)}")
# print(f"Validation Set Size: {len(val_df)}")
# print(f"Test Set (Gold Dev): {len(dev_df)}")

# Wrap the three frames as Hugging Face datasets; the pandas index is dropped.
dataset = DatasetDict({
    "train": Dataset.from_pandas(train_df, preserve_index=False),
    "validation": Dataset.from_pandas(val_df, preserve_index=False),
    "test": Dataset.from_pandas(dev_df, preserve_index=False)
})

# tokenization
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

# Fall back to the EOS token for padding when the tokenizer defines no pad token.
# NOTE(review): padding_side is only set inside this branch, so it keeps the
# tokenizer's own default whenever a pad token already exists — confirm that
# is intended.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

def tokenize_batch(batch):
    """Tokenize batch["text"], truncating and padding every row to MAX_LENGTH."""
    # NOTE(review): if a prompt exceeds MAX_LENGTH, truncation may drop the
    # Definition / "Plausibility Score:" part — TODO confirm prompts fit.
    return tokenizer(
        batch["text"],
        truncation=True,
        padding="max_length",
        max_length=MAX_LENGTH
    )

tokenized = dataset.map(tokenize_batch, batched=True)
# The Trainer expects the target column to be called "labels".
tokenized = tokenized.rename_column("label", "labels")

# only torch columns for training
cols_to_keep = ["input_ids", "attention_mask", "labels"]
tokenized.set_format(type="torch", columns=cols_to_keep)

# model setup (LoRA + Regression Head)

print("Loading Model...")

# Sequence-classification model with a single output, configured for regression
# and loaded in float32.
base_model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_ID,
    num_labels=1,
    problem_type="regression",
    device_map="auto",
    torch_dtype=torch.float32,
    trust_remote_code=True
)

# Tell the model which token id is padding, and disable the KV cache for training.
base_model.config.pad_token_id = tokenizer.pad_token_id
base_model.config.use_cache = False

# LoRA adapters (rank 16, alpha 32, dropout 0.05) on the q/k/v/o projections.
# "score" is listed in modules_to_save so it is trained and stored with the
# adapter — presumably this is the classification head's module name; verify
# against the model class.
peft_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    inference_mode=False,
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
    modules_to_save=["score"]
)

model = get_peft_model(base_model, peft_config)
model.print_trainable_parameters()


# training metrics
def compute_metrics(eval_pred):
    """Compute validation metrics for the Trainer.

    ``eval_pred`` unpacks into (predictions, labels). Predictions are squeezed
    and clipped to the [1.0, 5.0] range before scoring. Returns a dict with
    the Spearman correlation (0.0 when it is NaN) and the mean absolute error.
    """
    raw_scores, gold = eval_pred
    clipped = np.squeeze(raw_scores).clip(1.0, 5.0)
    gold = np.squeeze(gold)

    rho = spearmanr(gold, clipped).correlation
    if np.isnan(rho):
        # Correlation is undefined (e.g. constant input); report a neutral 0.0.
        rho = 0.0

    abs_err = np.abs(gold - clipped)
    return {
        "spearman": float(rho),
        "mae": float(abs_err.mean()),
    }

# training
# Evaluate and checkpoint once per epoch; the checkpoint with the highest
# validation "spearman" (from compute_metrics) is reloaded when training ends.
args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="spearman",
    greater_is_better=True,

    # Batch 4 x 2 accumulation steps = 8 examples per optimizer step per device.
    num_train_epochs=10,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=2,
    learning_rate=2e-4,
    weight_decay=0.01,
    warmup_ratio=0.05,
    fp16=True,

    logging_steps=20,
    save_total_limit=1,
    report_to="none"
)

# Train on the 80% split and validate on the 20% split; the gold dev set is
# not seen here. Early stopping uses a patience of 3 evaluations.
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=tokenized["train"],
    eval_dataset=tokenized["validation"],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
)

print("Starting Training...")
trainer.train()

# saving adapter
# With load_best_model_at_end=True the trainer should be holding the best
# checkpoint at this point, which is what gets written to best_adapter.
adapter_path = os.path.join(OUTPUT_DIR, "best_adapter")
trainer.save_model(adapter_path)
print(f"Best adapter saved to {adapter_path}")


# inference and evaluation
print("\n=== CLEANING MEMORY FOR INFERENCE ===")
del model, base_model, trainer
# Collect garbage first so tensors still held by unreachable objects are
# released; only then can empty_cache() hand those blocks back to the GPU.
gc.collect()
torch.cuda.empty_cache()

print("Reloading Base Model + Best Adapter for Inference...")

# Fresh base model in float16 for inference; the trained LoRA weights (and the
# modules saved with the adapter) are loaded on top of it from adapter_path.
inference_base = AutoModelForSequenceClassification.from_pretrained(
    MODEL_ID,
    num_labels=1,
    problem_type="regression",
    device_map="auto",
    torch_dtype=torch.float16,
    trust_remote_code=True
)
inference_base.config.pad_token_id = tokenizer.pad_token_id

inference_model = PeftModel.from_pretrained(inference_base, adapter_path)
inference_model.eval()

# A Trainer is used purely as a batched prediction loop here.
eval_args = TrainingArguments(
    output_dir=os.path.join(OUTPUT_DIR, "eval_temp"),
    per_device_eval_batch_size=16,
    report_to="none"
)

predictor = Trainer(
    model=inference_model,
    args=eval_args,
    tokenizer=tokenizer
)

print("Predicting on Test (Gold Dev) Set...")
pred_output = predictor.predict(tokenized["test"])
# Flatten to one score per example. reshape(-1) keeps a 1-D array even for a
# single-row test set, where squeeze() would produce a 0-d array that cannot
# be indexed later. Scores are clipped to the [1.0, 5.0] range.
raw_preds = np.clip(np.asarray(pred_output.predictions).reshape(-1), 1.0, 5.0)

# metrics calculations
# Retrieve True Labels and Metadata from the dataframe
# (Order is preserved by the Trainer)
y_true = dev_df["label"].to_numpy(float)
y_stdev = dev_df["stdev"].to_numpy(float)
groups = dev_df["group_id"].tolist()

# Global Spearman & MAE over all test rows
global_spearman = spearmanr(y_true, raw_preds).correlation
global_mae = np.mean(np.abs(y_true - raw_preds))

# Accuracy within Std Dev: share of rows whose absolute error does not exceed
# that row's stdev value
errors = np.abs(raw_preds - y_true)
within_stdev = errors <= y_stdev
acc_stdev = float(np.mean(within_stdev))

# Macro-Spearman: Spearman per group_id, averaged over groups
group_indices = defaultdict(list)
for i, gid in enumerate(groups):
    group_indices[gid].append(i)

group_corrs = []
for gid, idxs in group_indices.items():
    g_true = y_true[idxs]
    g_pred = raw_preds[idxs]
    # the gold labels need at least two distinct values for a correlation
    if len(set(g_true)) > 1:
        corr = spearmanr(g_true, g_pred).correlation
        # NaN (e.g. constant predictions within the group) is left out of the mean
        if not np.isnan(corr):
            group_corrs.append(corr)

# Falls back to 0.0 when no group produced a usable correlation.
macro_spearman = float(np.mean(group_corrs)) if group_corrs else 0.0

# NOTE(review): the header below says "Gold + Synth", but the data-loading code
# in this notebook only reads the gold train file — confirm the label.
print("\n" + "="*40)
print("FINAL RESULTS (Gold + Synth SFT)")
print("="*40)
print(f"Global Spearman:       {global_spearman:.4f}")
print(f"Macro Spearman:        {macro_spearman:.4f}")
print(f"MAE:                   {global_mae:.4f}")
print(f"Accuracy within Stdev: {acc_stdev:.4f}")
print("="*40)


In [None]:
# bootstrapping
def bootstrap_full_metrics(
    y_true,
    y_pred,
    y_stdev,
    groups,
    n_bootstrap=1000,
    seed=42
):
    """Bootstrap the four evaluation metrics by resampling rows with replacement.

    Args:
        y_true: Gold scores, one per example (any 1-D array-like).
        y_pred: Predicted scores aligned with ``y_true``.
        y_stdev: Per-example tolerance for the "accuracy within stdev" metric.
        groups: Per-example group identifiers for the macro Spearman.
        n_bootstrap: Number of bootstrap resamples.
        seed: Seed for the NumPy random generator.

    Returns:
        Dict keyed by "Global Spearman", "Macro Spearman", "MAE" and
        "Acc w/in Stdev"; each value holds the bootstrap "mean" and the
        2.5th/97.5th percentiles as "ci_low"/"ci_high".

    Raises:
        ValueError: If the inputs are empty or differ in length.
    """
    # Coerce to plain arrays so idx-based selection is always positional;
    # lists cannot be indexed by an index array, and a pandas Series would be
    # indexed by label instead of position.
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    y_stdev = np.asarray(y_stdev, dtype=float)
    groups = list(groups)

    n = len(y_true)
    if not (len(y_pred) == len(y_stdev) == len(groups) == n):
        raise ValueError(
            "y_true, y_pred, y_stdev and groups must have the same length"
        )
    if n == 0:
        raise ValueError("cannot bootstrap an empty sample")

    def _macro_spearman(true_vals, pred_vals, group_ids):
        # Mean per-group Spearman for one resample; NaN if no group is usable.
        members = defaultdict(list)
        for pos, g in enumerate(group_ids):
            members[g].append(pos)

        corrs = []
        for positions in members.values():
            gt = true_vals[positions]
            # A group needs at least two distinct gold values to be ranked.
            if len(set(gt)) > 1:
                c = spearmanr(gt, pred_vals[positions]).correlation
                if not np.isnan(c):
                    corrs.append(c)
        return np.mean(corrs) if corrs else np.nan

    def summarize(arr):
        # NaN entries (resamples without a usable group) are ignored.
        arr = np.array(arr, dtype=float)
        return {
            "mean": float(np.nanmean(arr)),
            "ci_low": float(np.nanpercentile(arr, 2.5)),
            "ci_high": float(np.nanpercentile(arr, 97.5)),
        }

    rng = np.random.default_rng(seed)

    res_global_sps = []
    res_macro_sps = []
    res_maes = []
    res_accs = []

    for _ in range(n_bootstrap):
        idx = rng.integers(0, n, size=n)

        bt_true = y_true[idx]
        bt_pred = y_pred[idx]
        bt_stdev = y_stdev[idx]
        bt_groups = [groups[i] for i in idx]
        abs_err = np.abs(bt_true - bt_pred)

        # Global Spearman (an undefined correlation is recorded as 0.0)
        gs = spearmanr(bt_true, bt_pred).correlation
        res_global_sps.append(gs if not np.isnan(gs) else 0.0)

        # MAE
        res_maes.append(np.mean(abs_err))

        # Accuracy within Stdev
        res_accs.append(np.mean(abs_err <= bt_stdev))

        # Macro Spearman, re-grouped on the resampled rows
        res_macro_sps.append(_macro_spearman(bt_true, bt_pred, bt_groups))

    return {
        "Global Spearman": summarize(res_global_sps),
        "Macro Spearman": summarize(res_macro_sps),
        "MAE": summarize(res_maes),
        "Acc w/in Stdev": summarize(res_accs),
    }

print("\nBOOTSTRAP RESULTS (TEST SET)")
bootstrap_stats = bootstrap_full_metrics(
    y_true=y_true,
    y_pred=raw_preds,
    y_stdev=y_stdev,
    groups=groups,
    n_bootstrap=1000,
)

# One line per metric: bootstrap mean followed by the 2.5-97.5 percentile range.
for metric, stats in bootstrap_stats.items():
    print(
        "{:20s}: {mean:.4f} [{ci_low:.4f}, {ci_high:.4f}]".format(metric, **stats)
    )