In [None]:
#!/usr/bin/env python3
"""
Cross-Encoder Reranker with Grid Search and K-Fold CV

Changes from original:
 - Automatically splits input CSV into 90% (train+val) and 10% (test)
 - Performs 10-fold cross-validation with Grid Search for hyperparameter tuning
 - Uses BCEWithLogitsLoss and CLS embedding for relevance score
 - Evaluates final best model on test set
"""

import os
import math
import random
import argparse
import functools
import inspect
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from typing import List, Dict

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, Subset

from transformers import (
    AutoTokenizer,
    AutoModel,
    AdamW,
    get_linear_schedule_with_warmup,
)

from sklearn.model_selection import KFold, ParameterGrid, StratifiedKFold, train_test_split
from sklearn.metrics import roc_auc_score, accuracy_score

# ------------------- Dataset -------------------

class PairDataset(Dataset):
    """In-memory dataset of (query, chunk, label) triples for a cross-encoder.

    Each item is a dict with keys ``query``, ``chunk`` and ``label``; the label
    is cast to ``float`` because the training loss (BCEWithLogitsLoss) expects
    floating-point targets.
    """

    def __init__(self, queries: List[str], chunks: List[str], labels: List[int]):
        """Store the three parallel sequences.

        Args:
            queries: query texts, one per example.
            chunks: candidate passage texts, aligned with ``queries``.
            labels: relevance labels, aligned with ``queries``.

        Raises:
            ValueError: if the three sequences do not have the same length.
        """
        # Explicit check instead of ``assert`` so validation survives ``python -O``.
        if not (len(queries) == len(chunks) == len(labels)):
            raise ValueError(
                "queries, chunks and labels must have the same length; got "
                f"{len(queries)}, {len(chunks)} and {len(labels)}"
            )
        self.queries = queries
        self.chunks = chunks
        self.labels = labels

    def __len__(self):
        """Return the number of examples."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return example ``idx`` as a ``{"query", "chunk", "label"}`` dict."""
        return {
            "query": self.queries[idx],
            "chunk": self.chunks[idx],
            "label": float(self.labels[idx]),
        }

def collate_fn(batch: List[Dict], tokenizer, max_length: int):
    """Turn a list of PairDataset items into padded model-input tensors.

    Args:
        batch: items with ``query``, ``chunk`` and ``label`` keys.
        tokenizer: called as ``tokenizer(queries, chunks, ...)`` to encode the
            pairs jointly, with padding/truncation to ``max_length``.
        max_length: maximum tokenised pair length.

    Returns:
        dict with ``input_ids``, ``attention_mask``, ``token_type_ids`` and a
        float ``labels`` tensor.
    """
    # Transpose the list of per-example dicts into parallel columns.
    query_texts, chunk_texts, raw_labels = [], [], []
    for item in batch:
        query_texts.append(item["query"])
        chunk_texts.append(item["chunk"])
        raw_labels.append(item["label"])
    label_tensor = torch.tensor(raw_labels, dtype=torch.float)

    encoded = tokenizer(
        query_texts,
        chunk_texts,
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )

    # Some tokenizers do not emit segment ids; fall back to all zeros so the
    # returned batch always has the same keys.
    if "token_type_ids" in encoded:
        segment_ids = encoded["token_type_ids"]
    else:
        segment_ids = torch.zeros_like(encoded["input_ids"])

    return {
        "input_ids": encoded["input_ids"],
        "attention_mask": encoded["attention_mask"],
        "token_type_ids": segment_ids,
        "labels": label_tensor,
    }

# ------------------- Model -------------------

class CrossEncoder(nn.Module):
    """Pretrained transformer encoder plus a linear head that scores a text pair.

    The (query, chunk) pair is encoded jointly; the hidden state of the first
    token (e.g. [CLS] for BERT-style tokenizers) goes through dropout and a
    1-unit linear layer to give one relevance logit per pair.
    """

    def __init__(self, model_name_or_path: str, dropout_prob: float = 0.1):
        """Load the encoder and initialise the scoring head.

        Args:
            model_name_or_path: model id or local path passed to ``AutoModel.from_pretrained``.
            dropout_prob: dropout applied to the first-token embedding before the head.
        """
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name_or_path)
        config = self.encoder.config
        hidden_size = config.hidden_size
        self.dropout = nn.Dropout(dropout_prob)
        self.classifier = nn.Linear(hidden_size, 1)
        # Not every config defines initializer_range; 0.02 is the usual
        # Hugging Face default for BERT-family models.
        init_std = getattr(config, "initializer_range", 0.02)
        nn.init.normal_(self.classifier.weight, mean=0.0, std=init_std)
        nn.init.zeros_(self.classifier.bias)
        # Some encoders (e.g. DistilBERT) have no ``token_type_ids`` parameter
        # and raise TypeError if it is passed, so record whether it is accepted.
        self._accepts_token_type_ids = (
            "token_type_ids" in inspect.signature(self.encoder.forward).parameters
        )

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        """Score a batch of tokenised pairs.

        Returns:
            ``(logits, cls_emb)``: the logits have the trailing size-1 dimension
            squeezed away; ``cls_emb`` is the first-token hidden state before dropout.
        """
        encoder_kwargs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "return_dict": True,
        }
        if self._accepts_token_type_ids:
            encoder_kwargs["token_type_ids"] = token_type_ids
        outputs = self.encoder(**encoder_kwargs)
        cls_emb = outputs.last_hidden_state[:, 0, :]
        x = self.dropout(cls_emb)
        logits = self.classifier(x).squeeze(-1)
        return logits, cls_emb

# ------------------- Helper functions -------------------

def load_csv_dataset(path: str):
    """Read the training CSV and check that it has the required columns.

    Args:
        path: CSV path (anything ``pandas.read_csv`` accepts) with at least
            the columns ``query``, ``chunk`` and ``label``; extra columns are kept.

    Returns:
        The loaded DataFrame.

    Raises:
        ValueError: if any required column is missing.
    """
    df = pd.read_csv(path)
    required = {"query", "chunk", "label"}
    missing = required - set(df.columns)
    # Explicit error (not ``assert``) so the check survives ``python -O`` and
    # tells the user which columns are absent.
    if missing:
        raise ValueError(
            f"{path}: missing required column(s) {sorted(missing)}; "
            f"found {list(df.columns)}"
        )
    return df

def evaluate(model, dataloader, device):
    """Run ``model`` over ``dataloader`` and compute binary-classification metrics.

    Args:
        model: module called as ``model(input_ids, attention_mask=..., token_type_ids=...)``
            that returns ``(logits, embedding)`` with one logit per pair.
        dataloader: iterable of batches with ``input_ids``, ``attention_mask``,
            ``token_type_ids`` and float ``labels`` tensors (as built by ``collate_fn``).
        device: device the batch tensors are moved to.

    Returns:
        dict with
          - ``auc``: ROC-AUC of the sigmoid scores, or NaN when the labels
            contain a single class (AUC is undefined there);
          - ``acc``: accuracy when thresholding probabilities at 0.5;
          - ``loss``: mean BCE-with-logits loss per example.

    Raises:
        ValueError: if ``dataloader`` yields no examples.
    """
    model.eval()
    # Sum per batch and divide once at the end, giving the per-example mean.
    loss_fn = nn.BCEWithLogitsLoss(reduction="sum")
    logit_chunks, label_chunks = [], []
    total_loss = 0.0
    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch["input_ids"].to(device)
            attn_mask = batch["attention_mask"].to(device)
            token_type_ids = batch["token_type_ids"].to(device)
            labels = batch["labels"].to(device)
            logits, _ = model(input_ids, attention_mask=attn_mask, token_type_ids=token_type_ids)
            total_loss += loss_fn(logits, labels).item()
            logit_chunks.append(logits.float().cpu())
            label_chunks.append(labels.cpu())

    if not label_chunks:
        raise ValueError("evaluate() received a dataloader with no batches")

    labels = torch.cat(label_chunks).numpy()
    # torch.sigmoid avoids the overflow warning of 1 / (1 + exp(-x)) on large
    # negative logits.
    probs = torch.sigmoid(torch.cat(logit_chunks)).numpy()

    # roc_auc_score raises on single-class input (possible for small or
    # imbalanced splits); report NaN instead of aborting the whole run.
    if np.unique(labels).size < 2:
        auc = float("nan")
    else:
        auc = roc_auc_score(labels, probs)
    preds = (probs >= 0.5).astype(int)
    acc = accuracy_score(labels, preds)
    return {"auc": auc, "acc": acc, "loss": total_loss / len(labels)}

def train_one_fold(train_loader, val_loader, args, lr, dropout):
    """Train a fresh CrossEncoder on ``train_loader``, then score it on ``val_loader``.

    Args:
        train_loader: batches used for optimisation (``collate_fn`` format).
        val_loader: batches evaluated once, after training has finished; it
            never influences the weights.
        args: namespace providing ``model_name_or_path``, ``device``, ``epochs``,
            ``warmup_steps_ratio`` and ``max_grad_norm``.
        lr: peak learning rate of the linear warmup/decay schedule.
        dropout: dropout probability for the scoring head.

    Returns:
        ``(model, val_metrics)`` where ``val_metrics`` is the dict returned by ``evaluate``.
    """
    model = CrossEncoder(args.model_name_or_path, dropout_prob=dropout).to(args.device)
    # transformers.AdamW is deprecated in favour of torch.optim.AdamW; eps=1e-6
    # and weight_decay=0.0 keep the defaults the transformers version used.
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, eps=1e-6, weight_decay=0.0)
    total_steps = len(train_loader) * args.epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(args.warmup_steps_ratio * total_steps),
        num_training_steps=total_steps,
    )
    loss_fn = nn.BCEWithLogitsLoss()
    for _ in range(args.epochs):
        model.train()
        for batch in train_loader:
            input_ids = batch["input_ids"].to(args.device)
            attn_mask = batch["attention_mask"].to(args.device)
            token_type_ids = batch["token_type_ids"].to(args.device)
            labels = batch["labels"].to(args.device)
            logits, _ = model(input_ids, attention_mask=attn_mask, token_type_ids=token_type_ids)
            loss = loss_fn(logits, labels)
            loss.backward()
            # Clip before stepping to keep fine-tuning updates bounded.
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
    val_metrics = evaluate(model, val_loader, args.device)
    return model, val_metrics

# ------------------- Main -------------------

def _seed_everything(seed: int) -> None:
    """Seed Python's, NumPy's and PyTorch's global RNGs.

    Makes shuffling, dropout and head initialisation repeatable across runs
    (bit-exact GPU determinism additionally depends on cuDNN settings).
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)  # seeds the CPU and all CUDA devices


def _make_loader(dataset, batch_size, shuffle, tokenizer, max_length):
    """Wrap ``dataset`` in a DataLoader whose batches are tokenised by ``collate_fn``."""
    # Unlike a lambda, a partial of a module-level function can be pickled,
    # which matters if DataLoader worker processes are ever enabled.
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=functools.partial(collate_fn, tokenizer=tokenizer, max_length=max_length),
    )


def main(args):
    """Grid-search hyperparameters with stratified 10-fold CV, refit, and test.

    Steps:
      1. Split the CSV 90/10 (stratified on ``label``) into train+val and test.
      2. For each hyperparameter combination, train one model per CV fold and
         average the validation ROC-AUC.
      3. Retrain on the whole train+val split with the best combination.
      4. Report that model's test metrics and save it with the tokenizer.

    ``args.seed`` is used when present; otherwise 42, the seed the splits use.

    Raises:
        RuntimeError: if no combination yields a defined mean CV AUC.
    """
    seed = getattr(args, "seed", 42)
    _seed_everything(seed)

    df = load_csv_dataset(args.data_csv)

    # Split 90% train+val, 10% test
    trainval_df, test_df = train_test_split(df, test_size=0.1, random_state=seed, stratify=df["label"])

    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True)

    trainval_ds = PairDataset(trainval_df["query"].tolist(), trainval_df["chunk"].tolist(), trainval_df["label"].tolist())
    test_ds = PairDataset(test_df["query"].tolist(), test_df["chunk"].tolist(), test_df["label"].tolist())
    trainval_labels = trainval_df["label"].to_numpy()

    # Stratified folds keep both classes in every validation fold; plain KFold
    # on imbalanced data can produce a single-class fold where AUC is undefined.
    skf = StratifiedKFold(n_splits=10, shuffle=True, random_state=seed)
    param_grid = {"lr": [2e-5, 3e-5, 5e-5], "dropout": [0.1, 0.2]}

    best_auc = -math.inf
    best_params = None

    for params in ParameterGrid(param_grid):
        fold_aucs = []
        print(f"\nGrid params: {params}")
        folds = skf.split(np.zeros(len(trainval_ds)), trainval_labels)
        for fold, (train_idx, val_idx) in enumerate(folds):
            train_loader = _make_loader(
                Subset(trainval_ds, train_idx), args.batch_size, True, tokenizer, args.max_length
            )
            val_loader = _make_loader(
                Subset(trainval_ds, val_idx), args.eval_batch_size, False, tokenizer, args.max_length
            )
            model, val_metrics = train_one_fold(train_loader, val_loader, args, lr=params["lr"], dropout=params["dropout"])
            # Fold models are only needed for their score; release each one so
            # at most one model occupies (GPU) memory at a time.
            del model
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            fold_aucs.append(val_metrics["auc"])
            print(f"Fold {fold+1} AUC: {val_metrics['auc']:.4f}")
        mean_auc = np.mean(fold_aucs)
        print(f"Mean AUC for {params}: {mean_auc:.4f}")
        # A NaN mean (some fold had an undefined AUC) compares False and never wins.
        if mean_auc > best_auc:
            best_auc = mean_auc
            best_params = params

    if best_params is None:
        raise RuntimeError("No hyperparameter setting produced a defined cross-validation AUC")

    print("\nBest params:", best_params)
    print("Best CV AUC:", best_auc)

    # Refit on all train+val data with the selected hyperparameters (as
    # GridSearchCV(refit=True) does) rather than keeping one arbitrary fold's
    # model that saw only 90% of train+val.
    full_train_loader = _make_loader(trainval_ds, args.batch_size, True, tokenizer, args.max_length)
    test_loader = _make_loader(test_ds, args.eval_batch_size, False, tokenizer, args.max_length)
    # train_one_fold reads its second loader only after training finishes (to
    # evaluate), so using the test loader here yields test metrics without
    # leaking test data into training. Revisit if early stopping is ever added.
    best_model, test_metrics = train_one_fold(
        full_train_loader, test_loader, args, lr=best_params["lr"], dropout=best_params["dropout"]
    )
    print("\nTest Set Results:", test_metrics)

    os.makedirs(args.output_dir, exist_ok=True)
    torch.save(best_model.state_dict(), os.path.join(args.output_dir, "best_model.pt"))
    tokenizer.save_pretrained(args.output_dir)
    print("Saved best model and tokenizer to", args.output_dir)

# ------------------- Args -------------------

def parse_args():
    """Build and parse the command-line interface for a training run.

    Returns:
        argparse.Namespace holding the data/model locations plus the
        optimisation, batching and device settings.
    """
    parser = argparse.ArgumentParser(
        description="Cross-Encoder fine-tuning with Grid Search + K-Fold CV"
    )

    # Mandatory inputs.
    parser.add_argument("--data_csv", type=str, required=True, help="CSV with columns query,chunk,label")
    parser.add_argument("--model_name_or_path", type=str, required=True, help="Model name or local path")

    # Optional knobs as (flag, converter, default), registered in this order.
    default_device = "cuda" if torch.cuda.is_available() else "cpu"
    optional_specs = (
        ("--output_dir", str, "./crossenc_out"),
        ("--epochs", int, 2),
        ("--batch_size", int, 16),
        ("--eval_batch_size", int, 64),
        ("--max_length", int, 256),
        ("--warmup_steps_ratio", float, 0.06),
        ("--max_grad_norm", float, 1.0),
        ("--device", str, default_device),
    )
    for flag, converter, default in optional_specs:
        parser.add_argument(flag, type=converter, default=default)

    return parser.parse_args()

if __name__ == "__main__":
    cli_args = parse_args()
    # Turn the device string into a torch.device once, before any training code uses it.
    cli_args.device = torch.device(cli_args.device)
    main(cli_args)