In [None]:
#!/usr/bin/env python3
"""
Cross-Encoder Reranker with LoRA fine-tuning + Grid Search and K-Fold CV

- Uses LoRA adapters (injected into all nn.Linear layers of the encoder)
- Freezes base pretrained weights; trains only LoRA adapters + classifier head
- Grid search over learning rate and dropout, 10-fold CV on train+val, held-out test (10%)
- Saves best model (weights include LoRA adapters) and tokenizer to output_dir
"""

import os
import math
import random
import argparse
import numpy as np
import pandas as pd
from typing import List, Dict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset

from transformers import (
    AutoTokenizer,
    AutoModel,
    get_linear_schedule_with_warmup,
)
from torch.optim import AdamW

from sklearn.model_selection import KFold, ParameterGrid, train_test_split
from sklearn.metrics import roc_auc_score, accuracy_score

# ------------------- Dataset -------------------

class PairDataset(Dataset):
    def __init__(self, queries: List[str], chunks: List[str], labels: List[int]):
        assert len(queries) == len(chunks) == len(labels)
        self.queries = queries
        self.chunks = chunks
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return {
            "query": self.queries[idx],
            "chunk": self.chunks[idx],
            "label": float(self.labels[idx]),
        }

def collate_fn(batch: List[Dict], tokenizer, max_length: int):
    queries = [b["query"] for b in batch]
    chunks = [b["chunk"] for b in batch]
    labels = torch.tensor([b["label"] for b in batch], dtype=torch.float)

    enc = tokenizer(
        queries, chunks, padding=True, truncation=True, max_length=max_length, return_tensors="pt"
    )
    # Some tokenizers/models (e.g., RoBERTa) don't return token_type_ids -> create zeros
    if "token_type_ids" not in enc:
        enc["token_type_ids"] = torch.zeros_like(enc["input_ids"])

    return {
        "input_ids": enc["input_ids"],
        "attention_mask": enc["attention_mask"],
        "token_type_ids": enc["token_type_ids"],
        "labels": labels,
    }

# ------------------- LoRA Implementation -------------------

class LoRALinear(nn.Module):
    """
    Replacement for nn.Linear that adds a low-rank update:
      W_eff = W (frozen) + (B @ A) * scaling
    Where:
      - A: (r, in_features)
      - B: (out_features, r)
      - scaling = alpha / r
    Only A and B (and optional dropout) are trainable by default.
    """
    def __init__(self, orig_linear: nn.Linear, r: int = 8, alpha: float = 32.0, dropout: float = 0.0):
        super().__init__()
        self.in_features = orig_linear.in_features
        self.out_features = orig_linear.out_features
        self.r = r
        self.alpha = alpha
        self.scaling = alpha / max(1, r)
        self.dropout_p = dropout

        # copy original weight and bias and freeze them
        self.weight = nn.Parameter(orig_linear.weight.data.clone(), requires_grad=False)
        if orig_linear.bias is not None:
            self.bias = nn.Parameter(orig_linear.bias.data.clone(), requires_grad=False)
        else:
            self.bias = None

        # LoRA parameters (trainable)
        if r > 0:
            # A: (r, in_features)
            self.A = nn.Parameter(torch.randn(r, self.in_features) * 0.01)
            # B: (out_features, r)
            self.B = nn.Parameter(torch.zeros(self.out_features, r))
        else:
            self.A = None
            self.B = None

        # optional dropout on input to adapter
        self.dropout = nn.Dropout(self.dropout_p) if self.dropout_p > 0.0 else None

    def forward(self, x):
        # x: (batch, seq_len, in_features) or (batch, in_features)
        # compute base linear: using frozen weight
        # use torch.nn.functional.linear to use weight.T semantics
        base = F.linear(x, self.weight, self.bias)  # (.., out_features)

        if self.r > 0:
            # compute adapter output: (.., out_features)
            # x_flat: (N, in_features) where N = batch*seq_len maybe
            orig_shape = x.shape
            if x.dim() == 3:
                N, S, D = x.shape
                x_flat = x.reshape(-1, D)  # (N*S, D)
            else:
                x_flat = x  # (N, D)

            if self.dropout is not None:
                x_drop = self.dropout(x_flat)
            else:
                x_drop = x_flat

            # A @ x^T -> (r, N) then B @ (A x)^T -> (out_features, N) -> transpose -> (N, out_features)
            # Efficient: (x_drop @ A.T) -> (N, r), then @ B.T -> (N, out_features)
            low_rank = (x_drop @ self.A.t()) @ self.B.t()  # (N, out_features)
            low_rank = low_rank * self.scaling
            if x.dim() == 3:
                low_rank = low_rank.view(N, S, self.out_features)
                # match base shape (N,S,out)
                out = base + low_rank
            else:
                out = base + low_rank
            return out
        else:
            return base

def replace_linear_with_lora(module: nn.Module, r: int, alpha: float, dropout: float, prefix=""):
    """
    Recursively replace nn.Linear modules **inside** the given module with LoRALinear wrappers.
    The original linear layers are replaced in-place.
    """
    for name, child in list(module.named_children()):
        child_prefix = f"{prefix}.{name}" if prefix else name
        if isinstance(child, nn.Linear):
            # create LoRA-wrapped module
            lora_linear = LoRALinear(child, r=r, alpha=alpha, dropout=dropout)
            # assign to parent
            setattr(module, name, lora_linear)
        else:
            replace_linear_with_lora(child, r=r, alpha=alpha, dropout=dropout, prefix=child_prefix)

def inject_lora(model: nn.Module, r: int, alpha: float, dropout: float):
    """
    Inject LoRA into the encoder of the CrossEncoder (AutoModel).
    We'll operate on model.encoder (AutoModel instance).
    Freeze original base parameters and ensure only LoRA params + classifier are trainable.
    """
    # Replace all nn.Linear inside encoder with LoRALinear
    encoder = model.encoder
    replace_linear_with_lora(encoder, r=r, alpha=alpha, dropout=dropout)

    # Freeze all parameters of base encoder (the original weights are now frozen in LoRALinear,
    # but other params may exist like LayerNorm, embeddings â€” to keep stable we freeze them)
    for name, p in encoder.named_parameters():
        # LoRA adapter parameters inside LoRALinear are trainable by default (A,B)
        # They appear as parameters of encoder because LoRALinear registers A/B as parameters.
        # Only freeze parameters that were originally from pretrained model: we detect them by attribute names
        # but that's tricky; simpler: set requires_grad=False for all, then set True for LoRA and classifier.
        p.requires_grad = False

    # Now unfreeze LoRA adapter params (A and B) by searching for them
    for name, module in encoder.named_modules():
        if isinstance(module, LoRALinear):
            if module.A is not None:
                module.A.requires_grad = True
            if module.B is not None:
                module.B.requires_grad = True
            # LoRALinear has weight and bias frozen by design (requires_grad=False)
    # Return model (modified in place)
    return model

# ------------------- Model -------------------

class CrossEncoderLoRA(nn.Module):
    def __init__(self, model_name_or_path: str, dropout_prob: float = 0.1):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name_or_path)
        hidden_size = self.encoder.config.hidden_size
        self.dropout = nn.Dropout(dropout_prob)
        self.classifier = nn.Linear(hidden_size, 1)
        # initialize classifier
        nn.init.normal_(self.classifier.weight, mean=0.0, std=self.encoder.config.initializer_range)
        nn.init.zeros_(self.classifier.bias)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            return_dict=True,
        )
        cls_emb = outputs.last_hidden_state[:, 0, :]
        x = self.dropout(cls_emb)
        logits = self.classifier(x).squeeze(-1)
        return logits, cls_emb

# ------------------- Helper functions -------------------

def load_csv_dataset(path: str):
    df = pd.read_csv(path)
    assert {"query", "chunk", "label"} <= set(df.columns)
    return df

def evaluate(model, dataloader, device):
    model.eval()
    loss_fn = nn.BCEWithLogitsLoss()
    all_logits, all_labels = [], []
    total_loss = 0.0
    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch["input_ids"].to(device)
            attn_mask = batch["attention_mask"].to(device)
            token_type_ids = batch["token_type_ids"].to(device)
            labels = batch["labels"].to(device)
            logits, _ = model(input_ids, attention_mask=attn_mask, token_type_ids=token_type_ids)
            loss = loss_fn(logits, labels)
            total_loss += loss.item() * len(labels)
            all_logits.append(logits.cpu())
            all_labels.append(labels.cpu())
    logits = torch.cat(all_logits).numpy()
    labels = torch.cat(all_labels).numpy()
    probs = 1 / (1 + np.exp(-logits))
    auc = roc_auc_score(labels, probs) if len(np.unique(labels)) > 1 else 0.0
    preds = (probs >= 0.5).astype(int)
    acc = accuracy_score(labels, preds)
    return {"auc": auc, "acc": acc, "loss": total_loss / len(labels)}

def train_one_fold(train_loader, val_loader, args, lr, dropout, lora_rank, lora_alpha, lora_dropout):
    # build model with LoRA injection
    model = CrossEncoderLoRA(args.model_name_or_path, dropout_prob=dropout)
    # inject LoRA into encoder
    model = inject_lora(model, r=lora_rank, alpha=lora_alpha, dropout=lora_dropout)
    model.to(args.device)

    # Ensure classifier head is trainable
    for p in model.classifier.parameters():
        p.requires_grad = True

    # Build optimizer with only trainable parameters (LoRA A/B and classifier)
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = AdamW(trainable_params, lr=lr)

    total_steps = len(train_loader) * args.epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(args.warmup_steps_ratio * total_steps),
        num_training_steps=total_steps,
    )
    loss_fn = nn.BCEWithLogitsLoss()

    for _ in range(args.epochs):
        model.train()
        for batch in train_loader:
            input_ids = batch["input_ids"].to(args.device)
            attn_mask = batch["attention_mask"].to(args.device)
            token_type_ids = batch["token_type_ids"].to(args.device)
            labels = batch["labels"].to(args.device)

            logits, _ = model(input_ids, attention_mask=attn_mask, token_type_ids=token_type_ids)
            loss = loss_fn(logits, labels)
            loss.backward()

            torch.nn.utils.clip_grad_norm_(trainable_params, args.max_grad_norm)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

    val_metrics = evaluate(model, val_loader, args.device)
    return model, val_metrics

# ------------------- Main -------------------

def main(args):
    df = load_csv_dataset(args.data_csv)

    # Split 90% train+val, 10% test
    trainval_df, test_df = train_test_split(df, test_size=0.1, random_state=42, stratify=df["label"])

    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True)

    trainval_ds = PairDataset(trainval_df["query"].tolist(), trainval_df["chunk"].tolist(), trainval_df["label"].tolist())
    test_ds = PairDataset(test_df["query"].tolist(), test_df["chunk"].tolist(), test_df["label"].tolist())

    kf = KFold(n_splits=10, shuffle=True, random_state=42)
    param_grid = {
        "lr": [2e-5, 3e-5, 5e-5],
        "dropout": [0.1, 0.2],
        # LoRA hyperparams could be included in grid if desired. We add them as command-line args.
    }

    best_auc = -1
    best_params = None
    best_model = None

    for params in ParameterGrid(param_grid):
        fold_aucs = []
        print(f"\nGrid params: {params}")
        for fold, (train_idx, val_idx) in enumerate(kf.split(np.arange(len(trainval_ds)))):
            train_subset = Subset(trainval_ds, train_idx)
            val_subset = Subset(trainval_ds, val_idx)
            train_loader = DataLoader(
                train_subset,
                batch_size=args.batch_size,
                shuffle=True,
                collate_fn=lambda b: collate_fn(b, tokenizer, args.max_length),
            )
            val_loader = DataLoader(
                val_subset,
                batch_size=args.eval_batch_size,
                shuffle=False,
                collate_fn=lambda b: collate_fn(b, tokenizer, args.max_length),
            )
            model, val_metrics = train_one_fold(
                train_loader,
                val_loader,
                args,
                lr=params["lr"],
                dropout=params["dropout"],
                lora_rank=args.lora_rank,
                lora_alpha=args.lora_alpha,
                lora_dropout=args.lora_dropout,
            )
            fold_aucs.append(val_metrics["auc"])
            print(f"Fold {fold+1} AUC: {val_metrics['auc']:.4f}")
        mean_auc = np.mean(fold_aucs)
        print(f"Mean AUC for {params}: {mean_auc:.4f}")
        if mean_auc > best_auc:
            best_auc = mean_auc
            best_params = params
            best_model = model  # last fold model (trained for this param setting). You may choose to re-train final model later.

    print("\nBest params:", best_params)
    print("Best CV AUC:", best_auc)

    # Evaluate on test set using best_model
    test_loader = DataLoader(
        test_ds,
        batch_size=args.eval_batch_size,
        shuffle=False,
        collate_fn=lambda b: collate_fn(b, tokenizer, args.max_length),
    )
    test_metrics = evaluate(best_model, test_loader, args.device)
    print("\nTest Set Results:", test_metrics)

    # Save model and tokenizer (best_model includes LoRA adapters)
    os.makedirs(args.output_dir, exist_ok=True)
    # Save entire model state_dict and tokenizer
    torch.save(best_model.state_dict(), os.path.join(args.output_dir, "best_model_with_lora.pt"))
    tokenizer.save_pretrained(args.output_dir)
    print("Saved best model (with LoRA adapters) and tokenizer to", args.output_dir)

# ------------------- Args -------------------

class Args:
    # file paths
    data_csv = "./training_pairs.csv"
    model_name_or_path = "bert-base-uncased"
    output_dir = "./crossenc_lora_out"

    # training config
    epochs = 5
    batch_size = 16
    eval_batch_size = 64
    max_length = 256
    warmup_steps_ratio = 0.06
    max_grad_norm = 1.0

    # device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"[INFO] Using device: {device}")
    # LoRA config
    lora_rank = 32
    lora_alpha = 64
    lora_dropout = 0.0

if __name__ == "__main__":
    args = Args()
    args.device = torch.device(args.device)
    main(args)


[INFO] Using device: cuda

Grid params: {'dropout': 0.1, 'lr': 2e-05}


AcceleratorError: CUDA error: CUDA-capable device(s) is/are busy or unavailable
Search for `cudaErrorDevicesUnavailable' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information.
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [5]:
print("cuda available:", torch.cuda.is_available())
print("cuda device count:", torch.cuda.device_count())
print("CUDA_VISIBLE_DEVICES:", os.environ.get("CUDA_VISIBLE_DEVICES"))


cuda available: True
cuda device count: 1
CUDA_VISIBLE_DEVICES: None
