In [2]:
import os
import math
import json
import random
from dataclasses import dataclass
from typing import List, Dict, Optional, Tuple

import numpy as np
import pandas as pd
from sklearn.model_selection import GroupKFold
from sklearn.metrics import log_loss

import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F

from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    get_cosine_schedule_with_warmup,
)

from peft import LoraConfig, get_peft_model, TaskType



In [19]:
# ==========================
# CONFIG
# ==========================
# Central configuration for the student (distillation) run. Every tunable value
# and every input/output path lives here so the rest of the notebook only reads
# from CFG.
CFG = {
    # Official competition data (read by load_official_data).
    # These keys were previously missing, which made load_official_data raise
    # KeyError on a fresh "Restart & Run All".
    "train_csv": "./data/train.csv",
    "test_csv": "./data/test.csv",

    # Teacher outputs (A/B/Tie) on TRAIN ONLY
    # Expected file: reward_teacher_outputs/teacher_logits_train.csv
    #   columns: id + either
    #     (pA, pB, pTie)  OR  (logit_A, logit_B, logit_Tie)
    "teacher_train_path": "./reward_teacher_outputs/teacher_logits_train.csv",

    # Where to save student artifacts (used by combination.ipynb)
    "out_dir": "./student_outputs",
    "oof_path": "./student_outputs/student_oof_probs.csv",
    "test_pred_path": "./student_outputs/student_test_probs.csv",
    "avg_model_path": "./student_outputs/student_model_avg.pt",

    # Base model (student backbone)
    # You can change this to a smaller model if you can't run Gemma-2-9B.
    "model_name": "google/gemma-2-9b-it",  # e.g. "microsoft/deberta-v3-base"
    "max_length": 512,
    "num_labels": 3,  # classes: A, B, Tie

    # LoRA hyperparams
    "lora_r": 64,
    "lora_alpha": 128,
    "lora_dropout": 0.05,

    # Training hyperparams
    "n_folds": 5,
    "num_epochs": 3,
    "train_batch_size": 2,
    "eval_batch_size": 4,
    "learning_rate": 2e-5,
    "weight_decay": 0.01,
    "warmup_ratio": 0.1,

    # Knowledge distillation hyperparams
    "use_kd": True,
    "temperature": 2.0,   # T in KL term
    "ce_weight": 1.0,     # weight on CE(y, p_student)
    "kl_weight": 1.0,     # weight on KL(teacher_T || student)

    # Misc
    "seed": 42,
    "num_workers": 2,
}



In [18]:
def get_device():
    """
    Pick the torch device name to run on.

    Preference order: "cuda" if a CUDA device is available, otherwise "mps"
    if the MPS backend is available, otherwise "cpu".
    """
    device_name = "cpu"
    if torch.cuda.is_available():
        device_name = "cuda"
    elif torch.backends.mps.is_available():
        device_name = "mps"
    return device_name

DEVICE = get_device()
# train_one_fold and predict_with_model read the device from CFG["device"];
# register it here so those lookups do not raise KeyError on a fresh kernel.
CFG["device"] = DEVICE
DEVICE

'cpu'

In [17]:
# Load the raw competition CSVs from the "data" folder next to the notebook.
# (The dead, string-quoted copy of an older path setup was removed.)
notebook_dir = os.getcwd()
data_dir = os.path.join(notebook_dir, "data")
train = pd.read_csv(os.path.join(data_dir, "train.csv"))
test = pd.read_csv(os.path.join(data_dir, "test.csv"))


In [4]:
# ==========================
# UTILS
# ==========================
def seed_everything(seed: int = 42):
    """
    Seed Python's `random`, NumPy and torch (CPU and all CUDA devices) with the
    same value, and switch cuDNN to deterministic, non-benchmark mode.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def ensure_dir(path: str):
    """Make sure `path` exists as a directory, creating parents as needed."""
    if not os.path.isdir(path):
        os.makedirs(path, exist_ok=True)


def label_from_winners(row: pd.Series) -> int:
    """
    Convert the one-hot winner columns of a train row into a class index.

    Returns 0 when `winner_model_a` is 1, else 1 when `winner_model_b` is 1,
    else 2 (treated as Tie). `winner_tie` itself is never read.
    """
    for label, column in enumerate(("winner_model_a", "winner_model_b")):
        if row[column] == 1:
            return label
    return 2


def make_prompt_groups(df: pd.DataFrame) -> np.ndarray:
    """
    Build GroupKFold group ids from the `prompt` column so the same prompt
    never appears in both train and validation for a given fold.

    Prompts are compared by their string form; each distinct prompt gets one
    integer id, numbered in order of first appearance.

    The previous implementation used the built-in `hash()`, which is salted
    per Python process for strings, so the ids (and therefore the folds)
    changed on every kernel restart. Reducing the hash modulo 10**9 could also
    merge different prompts into one group. `pd.factorize` is deterministic
    and collision-free.

    Returns: integer numpy array of shape [len(df)].
    """
    group_ids, _ = pd.factorize(df["prompt"].astype(str))
    return np.asarray(group_ids)


def load_official_data() -> Tuple[pd.DataFrame, pd.DataFrame]:
    """
    Read the train and test CSVs named by CFG["train_csv"] / CFG["test_csv"].

    The train frame gets an extra integer `label` column (0/1/2) computed
    row by row with `label_from_winners`. Returns (train, test).
    """
    train_frame, test_frame = (
        pd.read_csv(CFG[key]) for key in ("train_csv", "test_csv")
    )

    # one class index per row, derived from the winner_* columns
    train_frame["label"] = train_frame.apply(label_from_winners, axis=1)
    return train_frame, test_frame


def load_teacher_train_logits(train_df: pd.DataFrame) -> np.ndarray:
    """
    Load teacher outputs for TRAIN rows, aligned by id.

    teacher_logits_train.csv must contain:
      - 'id' (unique per row)
      - either (pA, pB, pTie)  -> probabilities
        OR    (logit_A, logit_B, logit_Tie) -> raw logits

    Probabilities are clipped away from 0/1 and converted with log(), so a
    softmax over the returned values recovers (approximately) the teacher
    distribution.

    Returns: float32 numpy array of shape [len(train_df), 3] with logits for
    A/B/Tie, in the same row order as `train_df`.

    Raises:
      FileNotFoundError: the teacher file does not exist.
      ValueError: the file has no 'id' column, contains duplicated ids,
        lacks some train ids, lacks the expected value columns, or holds
        non-finite values for the requested rows.
    """
    path = CFG["teacher_train_path"]
    if not os.path.exists(path):
        raise FileNotFoundError(
            f"Teacher train file not found at {path}. "
            f"Make sure reward_teacher_outputs/teacher_logits_train.csv exists."
        )

    tdf = pd.read_csv(path)
    if "id" not in tdf.columns:
        raise ValueError("Teacher file must contain an 'id' column.")

    # A duplicated id would make `.loc[train ids]` return extra rows and
    # silently misalign teacher targets with the train rows.
    dup_mask = tdf["id"].duplicated()
    if dup_mask.any():
        dup_ids = tdf.loc[dup_mask, "id"].unique()
        raise ValueError(
            f"Teacher file contains {len(dup_ids)} duplicated ids, e.g. {dup_ids[:5].tolist()}"
        )

    tdf = tdf.set_index("id")

    # Check all train ids are present in teacher file
    present = train_df["id"].isin(tdf.index)
    if not present.all():
        missing = train_df.loc[~present, "id"]
        raise ValueError(
            f"Teacher file is missing {len(missing)} train ids, e.g. {missing.iloc[:5].tolist()}"
        )

    prob_cols = ["pA", "pB", "pTie"]
    logit_cols = ["logit_A", "logit_B", "logit_Tie"]

    if set(prob_cols).issubset(tdf.columns):
        # Case 1: teacher saved probabilities pA/pB/pTie
        use_cols, is_prob = prob_cols, True
    elif set(logit_cols).issubset(tdf.columns):
        # Case 2: teacher saved explicit logits
        use_cols, is_prob = logit_cols, False
    else:
        raise ValueError(
            "Teacher file does not have expected columns. "
            "Need either [pA, pB, pTie] or [logit_A, logit_B, logit_Tie]."
        )

    values = tdf.loc[train_df["id"], use_cols].values.astype(np.float32)

    # NaN/inf targets would propagate into the KD loss and poison training.
    if not np.isfinite(values).all():
        n_bad = int((~np.isfinite(values)).any(axis=1).sum())
        raise ValueError(
            f"Teacher file has non-finite values in {n_bad} train rows (columns {use_cols})."
        )

    if is_prob:
        eps = 1e-9
        values = np.clip(values, eps, 1.0 - eps)
        values = np.log(values)  # safe "logits" for KD

    return values



In [5]:
# ==========================
# DATASET
# ==========================
class PreferenceDataset(Dataset):
    """
    Map-style dataset yielding one tokenized comparison per row of `df`.

    Each item is a dict with:
      - input_ids / attention_mask: tokenizer output with the batch dim removed
      - labels: int class index, only when `df` has a "label" column
      - teacher_logits: float32 tensor row, only when teacher logits were given
    """

    def __init__(
        self,
        df: pd.DataFrame,
        tokenizer,
        max_length: int,
        teacher_logits: Optional[np.ndarray] = None,
    ):
        # positional index 0..N-1 so it lines up with teacher_logits rows
        self.df = df.reset_index(drop=True)
        self.tokenizer = tokenizer
        self.max_length = max_length

        self.teacher_logits = None
        if teacher_logits is not None:
            assert len(teacher_logits) == len(self.df)
            self.teacher_logits = torch.tensor(teacher_logits, dtype=torch.float32)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx: int):
        record = self.df.iloc[idx]
        prompt, ra, rb = (
            str(record[column]) for column in ("prompt", "response_a", "response_b")
        )

        # Plain-text framing: prompt first, then both candidate responses
        text = "".join(
            (
                "Prompt:\n",
                prompt,
                "\n\nResponse A:\n",
                ra,
                "\n\nResponse B:\n",
                rb,
            )
        )

        encoded = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )

        # tokenizer returns [1, L] tensors; drop the leading batch dim
        item = {
            field: encoded[field].squeeze(0)
            for field in ("input_ids", "attention_mask")
        }

        # hard label is present for train/val frames only
        if "label" in self.df.columns:
            item["labels"] = int(record["label"])

        # distillation target, when supplied at construction
        if self.teacher_logits is not None:
            item["teacher_logits"] = self.teacher_logits[idx]

        return item


def collate_fn(batch: List[Dict]) -> Dict[str, torch.Tensor]:
    """
    Stack per-sample dicts from the Dataset into one dict of batch tensors.

    `input_ids` and `attention_mask` are always stacked. `labels` (as a long
    tensor) and `teacher_logits` are added only when the first sample has them.
    """
    collated = {
        field: torch.stack([sample[field] for sample in batch])
        for field in ("input_ids", "attention_mask")
    }

    first = batch[0]

    if "labels" in first:
        collated["labels"] = torch.tensor(
            [sample["labels"] for sample in batch], dtype=torch.long
        )

    if "teacher_logits" in first:
        collated["teacher_logits"] = torch.stack(
            [sample["teacher_logits"] for sample in batch]
        )

    return collated



In [6]:
# ==========================
# MODEL + LORA HELPERS
# ==========================
def build_tokenizer_and_model():
    """
    Build the tokenizer and a LoRA-wrapped sequence-classification model for
    CFG["model_name"] with CFG["num_labels"] output classes.

    Returns: (tokenizer, peft_model).

    Notes:
      - If the tokenizer has no pad token, the EOS token is reused for padding.
      - The model config's `pad_token_id` is filled in from the tokenizer when
        it is unset. Without it, a padded batch cannot be handled by
        classification heads that use the pad id to find the last real token
        (per the Transformers sequence-classification documentation).
      - LoRA is attached to every `torch.nn.Linear` module name found in the
        model, except the classification head ("classifier" / "score").
    """
    tokenizer = AutoTokenizer.from_pretrained(CFG["model_name"])
    # Some models may not have pad_token; reuse eos_token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    base_model = AutoModelForSequenceClassification.from_pretrained(
        CFG["model_name"],
        num_labels=CFG["num_labels"],
    )

    # Keep the model's notion of padding consistent with the tokenizer's.
    if base_model.config.pad_token_id is None:
        base_model.config.pad_token_id = tokenizer.pad_token_id

    # Automatically find all linear modules to attach LoRA to
    def find_all_linear_names(model):
        lora_module_names = {
            name.split(".")[-1]
            for name, module in model.named_modules()
            if isinstance(module, torch.nn.Linear)
        }
        # Usually we do NOT LoRA the final classifier head
        lora_module_names -= {"classifier", "score"}
        # sorted -> the LoRA config is identical from run to run
        return sorted(lora_module_names)

    target_modules = find_all_linear_names(base_model)

    lora_config = LoraConfig(
        r=CFG["lora_r"],
        lora_alpha=CFG["lora_alpha"],
        lora_dropout=CFG["lora_dropout"],
        bias="none",
        task_type=TaskType.SEQ_CLS,
        target_modules=target_modules,
    )

    model = get_peft_model(base_model, lora_config)
    model.print_trainable_parameters()

    return tokenizer, model


def kd_loss_fn(
    student_logits: torch.Tensor,
    teacher_logits: torch.Tensor,
    labels: torch.Tensor,
) -> torch.Tensor:
    """
    Distillation loss:

      L = ce_weight * CE(y, softmax(student))
        + kl_weight * T^2 * KL(softmax(teacher / T) || softmax(student / T))

    - CE term uses hard labels (competition labels) at temperature 1.
    - KL term encourages the student to mimic the teacher distribution, with
      BOTH distributions softened by the same temperature T = CFG["temperature"].

    Fix over the previous version: the temperature was applied to the teacher
    only, so the KL term compared a softened teacher against an unsoftened
    student. That term was non-zero even when student logits equalled teacher
    logits and pulled the student towards over-smoothed predictions whenever
    T > 1. The T^2 factor keeps the KL gradient magnitude comparable to the CE
    term as T changes (Hinton et al., 2015).

    Args:
      student_logits: student outputs, class dimension last.
      teacher_logits: teacher logits with the same shape as `student_logits`.
      labels: integer class indices for the CE term.

    Returns: scalar loss tensor.
    """
    T = CFG["temperature"]
    ce_w = CFG["ce_weight"]
    kl_w = CFG["kl_weight"]

    # standard cross-entropy on hard labels
    ce = F.cross_entropy(student_logits, labels)

    # teacher soft probabilities at temperature T (no grad)
    with torch.no_grad():
        t_probs_T = F.softmax(teacher_logits / T, dim=-1)

    # student log-probabilities at the SAME temperature
    log_p_student_T = F.log_softmax(student_logits / T, dim=-1)

    # KL(p_teacher_T || p_student_T); F.kl_div expects log-probs as input
    kl = F.kl_div(
        log_p_student_T,
        t_probs_T,
        reduction="batchmean",
    ) * (T * T)

    return ce_w * ce + kl_w * kl



In [7]:
# ==========================
# TRAIN / EVAL LOOPS
# ==========================
def train_one_fold(
    fold: int,
    model,
    tokenizer,
    train_df: pd.DataFrame,
    val_df: pd.DataFrame,
    teacher_logits_train: np.ndarray,
    train_idx: np.ndarray,
    val_idx: np.ndarray,
) -> Tuple[np.ndarray, Dict, dict]:
    """
    Train `model` on one cross-validation fold and score it on the held-out part.

    Args:
      fold: fold number, stored as-is in the returned metrics.
      model: model to train; moved to CFG["device"] and updated in place.
      tokenizer: tokenizer handed to PreferenceDataset.
      train_df: training rows of this fold.
      val_df: validation rows of this fold (needs a "label" column, since
        labels are read from every validation batch).
      teacher_logits_train: teacher logits for the full train set; indexed
        with `train_idx` when CFG["use_kd"] is true, otherwise ignored.
      train_idx: positions of `train_df` rows within the full train set.
      val_idx: accepted for symmetry; not used inside this function.

    Returns:
      - val_probs: (len(val_df), 3) predicted probabilities
      - fold_metrics: dict with "fold", "log_loss" and "n_val"
      - state_dict: CPU clone of every entry of model.state_dict()
    """
    device = CFG["device"]
    model = model.to(device)

    # Teacher logits for the train fold (KD on train only)
    fold_teacher_logits = teacher_logits_train[train_idx] if CFG["use_kd"] else None

    train_dataset = PreferenceDataset(
        train_df,
        tokenizer=tokenizer,
        max_length=CFG["max_length"],
        teacher_logits=fold_teacher_logits,
    )
    val_dataset = PreferenceDataset(
        val_df,
        tokenizer=tokenizer,
        max_length=CFG["max_length"],
        teacher_logits=None,  # we don't need KD on val
    )

    # Training batches are reshuffled each epoch; validation keeps row order so
    # the concatenated predictions line up with val_df.
    train_loader = DataLoader(
        train_dataset,
        batch_size=CFG["train_batch_size"],
        shuffle=True,
        num_workers=CFG["num_workers"],
        collate_fn=collate_fn,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=CFG["eval_batch_size"],
        shuffle=False,
        num_workers=CFG["num_workers"],
        collate_fn=collate_fn,
    )

    # Two optimizer groups, split by substring match on the parameter name:
    # names containing "bias" or "LayerNorm.weight" get no weight decay.
    # NOTE(review): every named parameter is passed in, including ones that may
    # be frozen by the LoRA wrapper — confirm this is intended.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p
                for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay": CFG["weight_decay"],
        },
        {
            "params": [
                p
                for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.0,
        },
    ]
    optimizer = torch.optim.AdamW(
        optimizer_grouped_parameters,
        lr=CFG["learning_rate"],
    )

    # Cosine schedule with warmup; one scheduler step per optimizer step, so the
    # total step count is epochs * batches-per-epoch.
    num_training_steps = CFG["num_epochs"] * len(train_loader)
    num_warmup_steps = int(CFG["warmup_ratio"] * num_training_steps)
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps,
    )

    # Training loop (no per-epoch validation; the model is scored once after
    # the final epoch)
    model.train()
    for epoch in range(CFG["num_epochs"]):
        for batch in train_loader:
            optimizer.zero_grad()

            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
            )
            student_logits = outputs.logits

            # Distillation loss when teacher targets are in the batch,
            # otherwise plain cross-entropy on the hard labels.
            if CFG["use_kd"] and "teacher_logits" in batch:
                t_logits = batch["teacher_logits"].to(device)
                loss = kd_loss_fn(student_logits, t_logits, labels)
            else:
                loss = F.cross_entropy(student_logits, labels)

            loss.backward()
            # clip the global gradient norm to 1.0 before the update
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            scheduler.step()

    # Evaluation on validation split
    model.eval()
    all_probs = []
    all_labels = []
    with torch.no_grad():
        for batch in val_loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].cpu().numpy()

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
            )
            logits = outputs.logits
            probs = F.softmax(logits, dim=-1).cpu().numpy()

            all_probs.append(probs)
            all_labels.append(labels)

    all_probs = np.concatenate(all_probs, axis=0)
    all_labels = np.concatenate(all_labels, axis=0)
    assert len(all_probs) == len(val_df)

    # Fold log-loss; the explicit label list keeps the metric defined even if a
    # class is absent from this validation split.
    fold_ll = log_loss(all_labels, all_probs, labels=[0, 1, 2])
    fold_metrics = {
        "fold": fold,
        "log_loss": float(fold_ll),
        "n_val": int(len(val_df)),
    }

    # Save this fold's weights (for later averaging). Every state_dict entry is
    # cloned to CPU, not only the trainable ones.
    state_dict = {k: v.cpu().clone() for k, v in model.state_dict().items()}

    return all_probs, fold_metrics, state_dict


def average_state_dicts(state_dicts: List[dict]) -> dict:
    """
    Element-wise average of a list of state_dicts.

    All state_dicts must have exactly the same keys (same architecture).
    Floating-point tensors are averaged as `sum / n`. Integer and bool tensors
    are averaged in float64, rounded, and cast back to their original dtype,
    so the result can be loaded without a silent dtype change.

    Args:
      state_dicts: non-empty list of state_dicts with identical keys.

    Returns: a new dict mapping each key to the averaged tensor.

    Raises:
      ValueError: if `state_dicts` is empty or the key sets differ.
    """
    if not state_dicts:
        raise ValueError("average_state_dicts needs at least one state_dict.")

    n = len(state_dicts)
    reference = state_dicts[0]
    ref_keys = set(reference.keys())
    for position, sd in enumerate(state_dicts[1:], start=1):
        if set(sd.keys()) != ref_keys:
            diff = sorted(ref_keys.symmetric_difference(sd.keys()))
            raise ValueError(
                f"state_dict {position} has different keys than state_dict 0, e.g. {diff[:5]}"
            )

    avg_state = {}
    for key, ref_value in reference.items():
        is_int_like = torch.is_tensor(ref_value) and not (
            ref_value.is_floating_point() or ref_value.is_complex()
        )
        if is_int_like:
            # true division would turn these into float tensors
            stacked = torch.stack([sd[key] for sd in state_dicts]).to(torch.float64)
            avg_state[key] = stacked.mean(dim=0).round().to(ref_value.dtype)
        else:
            avg_state[key] = sum(sd[key] for sd in state_dicts) / n
    return avg_state


def predict_with_model(model, tokenizer, df: pd.DataFrame) -> np.ndarray:
    """
    Score every row of `df` with `model` and return class probabilities.

    The model is moved to CFG["device"] and put in eval mode; rows are
    processed in order, without gradients, in batches of CFG["eval_batch_size"].
    Returns an array of shape (len(df), num_classes).
    """
    device = CFG["device"]
    model = model.to(device)
    model.eval()

    loader = DataLoader(
        PreferenceDataset(
            df,
            tokenizer=tokenizer,
            max_length=CFG["max_length"],
            teacher_logits=None,
        ),
        batch_size=CFG["eval_batch_size"],
        shuffle=False,  # keep predictions aligned with df rows
        num_workers=CFG["num_workers"],
        collate_fn=collate_fn,
    )

    batch_probs = []
    with torch.no_grad():
        for batch in loader:
            logits = model(
                input_ids=batch["input_ids"].to(device),
                attention_mask=batch["attention_mask"].to(device),
            ).logits
            batch_probs.append(F.softmax(logits, dim=-1).cpu().numpy())

    stacked = np.concatenate(batch_probs, axis=0)
    assert len(stacked) == len(df)
    return stacked



In [8]:
# ==========================
# MAIN
# ==========================
def main():
    """
    End-to-end student training:

      1. seed, create the output directory, load train/test
      2. load teacher logits (only when CFG["use_kd"] is true)
      3. GroupKFold by prompt; train a fresh LoRA model per fold and collect
         out-of-fold probabilities, fold metrics and fold state_dicts
      4. save OOF probabilities, average the fold weights, save them
      5. predict the test set with the averaged model and save a JSON report

    Changes over the previous version:
      - teacher logits are no longer loaded when KD is disabled, so a CE-only
        run does not fail on a missing teacher file;
      - the tokenizer comes from the same call that builds each fold's model,
        so a full backbone is no longer built and thrown away up front;
      - each fold's model is released before the next fold is built;
      - fold numbers are printed 1-based in both log lines.
    """
    seed_everything(CFG["seed"])
    ensure_dir(CFG["out_dir"])

    print("Loading official train/test...")
    train_df, test_df = load_official_data()

    if CFG["use_kd"]:
        print("Loading teacher logits for KD...")
        teacher_logits_train = load_teacher_train_logits(train_df)
    else:
        # train_one_fold only indexes the teacher logits when use_kd is true
        print("KD disabled (use_kd=False); skipping teacher logits.")
        teacher_logits_train = None

    # GroupKFold by prompt
    groups = make_prompt_groups(train_df)
    gkf = GroupKFold(n_splits=CFG["n_folds"])

    oof_preds = np.zeros((len(train_df), CFG["num_labels"]), dtype=np.float32)
    fold_metrics_list = []
    fold_state_dicts = []

    # Cross-validation
    for fold, (tr_idx, va_idx) in enumerate(gkf.split(train_df, groups=groups)):
        print(f"=== Fold {fold+1}/{CFG['n_folds']} ===")

        # Fresh model for each fold (the tokenizer is the same on every call)
        tokenizer, model = build_tokenizer_and_model()

        train_sub = train_df.iloc[tr_idx].reset_index(drop=True)
        val_sub = train_df.iloc[va_idx].reset_index(drop=True)

        val_probs, fold_metrics, state_dict = train_one_fold(
            fold=fold,
            model=model,
            tokenizer=tokenizer,
            train_df=train_sub,
            val_df=val_sub,
            teacher_logits_train=teacher_logits_train,
            train_idx=tr_idx,
            val_idx=va_idx,
        )

        # Fill OOF predictions
        oof_preds[va_idx] = val_probs
        fold_metrics_list.append(fold_metrics)
        fold_state_dicts.append(state_dict)

        print(f"Fold {fold+1}: log-loss = {fold_metrics['log_loss']:.6f}")

        # Release this fold's model before the next one is built
        del model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # Overall OOF log-loss
    overall_ll = log_loss(train_df["label"].values, oof_preds, labels=[0, 1, 2])
    print(f"Overall OOF log-loss: {overall_ll:.6f}")

    # Save OOF probabilities for combination.ipynb
    oof_df = pd.DataFrame(
        {
            "id": train_df["id"].values,
            "pA": oof_preds[:, 0],
            "pB": oof_preds[:, 1],
            "pTie": oof_preds[:, 2],
        }
    )
    oof_df.to_csv(CFG["oof_path"], index=False)
    print(f"Saved student OOF probabilities -> {CFG['oof_path']}")

    # Average state dicts across folds (their "trick")
    print("Averaging LoRA + classifier weights across folds...")
    avg_state = average_state_dicts(fold_state_dicts)

    # Rebuild fresh model and load averaged weights
    tokenizer, final_model = build_tokenizer_and_model()
    final_model.load_state_dict(avg_state)
    torch.save(avg_state, CFG["avg_model_path"])
    print(f"Saved averaged student model state_dict -> {CFG['avg_model_path']}")

    # Predict on test set with averaged student
    print("Predicting on test set with averaged student model...")
    test_probs = predict_with_model(final_model, tokenizer, test_df)

    test_df_out = pd.DataFrame(
        {
            "id": test_df["id"].values,
            "pA": test_probs[:, 0],
            "pB": test_probs[:, 1],
            "pTie": test_probs[:, 2],
        }
    )
    test_df_out.to_csv(CFG["test_pred_path"], index=False)
    print(f"Saved student test probabilities -> {CFG['test_pred_path']}")

    # Small JSON summary (handy for your report)
    metrics = {
        "folds": fold_metrics_list,
        "overall_oof_log_loss": float(overall_ll),
        "config": {
            "model_name": CFG["model_name"],
            "n_folds": CFG["n_folds"],
            "num_epochs": CFG["num_epochs"],
            "learning_rate": CFG["learning_rate"],
            "temperature": CFG["temperature"],
            "ce_weight": CFG["ce_weight"],
            "kl_weight": CFG["kl_weight"],
        },
    }
    with open(os.path.join(CFG["out_dir"], "student_training_report.json"), "w") as f:
        json.dump(metrics, f, indent=2)
    print("Saved student_training_report.json")



In [9]:
# Entry point: run the whole pipeline when this code executes as the main
# module; the guard keeps main() from running if the code is imported instead.
if __name__ == "__main__":
    main()


Loading official train/test...
Loading teacher logits for KD...


FileNotFoundError: Teacher train file not found at ./reward_teacher_outputs/teacher_logits_train.csv. Make sure reward_teacher_outputs/teacher_logits_train.csv exists.