In [8]:
import os
from dataclasses import dataclass
from typing import List, Dict
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoConfig,
)

from peft import LoraConfig, get_peft_model
import json
from tqdm import tqdm
from collections import Counter
import matplotlib.pyplot as plt

In [9]:
CACHE_DIR = "/scratch/users/jiaxun1218"
MODEL_NAME = "Qwen/Qwen2.5-Math-7B"  # adjust if your HF ID differs
MAX_LEN = 1024  # prompt + answer snippet; tune as needed

BATCH_SIZE = 2
NUM_EPOCHS = 10
LR = 3e-4

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {DEVICE}")
label2id = {
    "model_a": 0,
    "model_b": 1,
    "tie": 2,
    "both_bad": 3,
}
id2label = {v: k for k, v in label2id.items()}

Device: cuda


In [10]:
class PairwiseArenaDataset(Dataset):
    """
    Expects a list of dicts, each with at least:
      {
        "question": str,
        "answer_a": str,
        "answer_b": str,
        "human_label": "model_a" | "model_b" | "tie" | "both_bad"
      }
    """
    def __init__(self, data: List[Dict], tokenizer: AutoTokenizer, max_len: int = 1024):
        self.data = data
        self.tokenizer = tokenizer
        self.max_len = max_len

        self.system_prefix = (
            "You are a strict math answer judge. "
            "You will be given a question and an answer. "
            "Evaluate the answer's correctness and reasoning quality.\n\n"
        )

    def build_text(self, question: str, answer: str) -> str:
        return (
            self.system_prefix
            + "Question:\n"
            + question.strip()
            + "\n\nAnswer:\n"
            + answer.strip()
        )

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx: int):
        item = self.data[idx]

        q = item["question"]
        aA = item["answer_a"]
        aB = item["answer_b"]
        lab_str = item["human_label"]

        # map label string -> int id
        label_id = label2id[lab_str]

        text_a = self.build_text(q, aA)
        text_b = self.build_text(q, aB)

        enc_a = self.tokenizer(
            text_a,
            max_length=self.max_len,
            truncation=True,
            padding=False,
            return_tensors="pt",
        )
        enc_b = self.tokenizer(
            text_b,
            max_length=self.max_len,
            truncation=True,
            padding=False,
            return_tensors="pt",
        )

        return {
            "input_ids_a": enc_a["input_ids"].squeeze(0),
            "attention_mask_a": enc_a["attention_mask"].squeeze(0),
            "input_ids_b": enc_b["input_ids"].squeeze(0),
            "attention_mask_b": enc_b["attention_mask"].squeeze(0),
            "labels": torch.tensor(label_id, dtype=torch.long),
        }


# -------------------------
# 3. Collator
# -------------------------

@dataclass
class PairwiseCollator:
    tokenizer: AutoTokenizer
    pad_to_multiple_of: int = 8

    def __call__(self, batch: List[Dict]) -> Dict[str, torch.Tensor]:
        # We need to pad A and B separately
        ids_a = [x["input_ids_a"] for x in batch]
        mask_a = [x["attention_mask_a"] for x in batch]
        ids_b = [x["input_ids_b"] for x in batch]
        mask_b = [x["attention_mask_b"] for x in batch]
        labels = torch.stack([x["labels"] for x in batch], dim=0)

        enc_a = self.tokenizer.pad(
            {"input_ids": ids_a, "attention_mask": mask_a},
            padding=True,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )
        enc_b = self.tokenizer.pad(
            {"input_ids": ids_b, "attention_mask": mask_b},
            padding=True,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

        return {
            "input_ids_a": enc_a["input_ids"],
            "attention_mask_a": enc_a["attention_mask"],
            "input_ids_b": enc_b["input_ids"],
            "attention_mask_b": enc_b["attention_mask"],
            "labels": labels,
        }


In [11]:
def reward_loss_4way(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """
    logits: shape (batch, 4) â€“ unnormalized scores for [model_a, model_b, tie, both_bad]
    labels: shape (batch,) with values in {0,1,2,3}
    """
    return F.cross_entropy(logits, labels)

class RewardModel(nn.Module):
    """
    Wraps a causal LM with:
      - a scalar reward head (for scores_a / scores_b)
      - a 4-way classification head over (A,B) pairs.

    LoRA is applied to the LM via PEFT; the reward_head and classifier
    are normal nn.Linear modules.
    """

    def __init__(self, base_model: AutoModelForCausalLM):
        super().__init__()
        self.base_model = base_model
        hidden_size = base_model.config.hidden_size

        base_dtype = next(base_model.parameters()).dtype

        # Reward head: scalar score from hidden vector
        self.reward_head = nn.Linear(hidden_size, 1)
        self.reward_head.to(dtype=base_dtype)

        # 4-way classification head: takes [hA, hB] -> 4 logits
        self.classifier = nn.Linear(2 * hidden_size, 4)
        self.classifier.to(dtype=base_dtype)

    def encode_hidden(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        """
        Return last hidden state vector for each sequence in the batch.
        Shape: (batch, hidden_size)
        """
        outputs = self.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        hidden_states = outputs.hidden_states[-1]  # (batch, seq, hidden)

        last_indices = attention_mask.sum(dim=1) - 1
        batch_idx = torch.arange(hidden_states.size(0), device=hidden_states.device)
        last_hidden = hidden_states[batch_idx, last_indices]  # (batch, hidden)
        return last_hidden

    def forward(
        self,
        input_ids_a: torch.Tensor,
        attention_mask_a: torch.Tensor,
        input_ids_b: torch.Tensor,
        attention_mask_b: torch.Tensor,
        labels: torch.Tensor = None,
    ):
        # 1. Get hidden vectors for A and B
        hA = self.encode_hidden(input_ids_a, attention_mask_a)  # (batch, hidden)
        hB = self.encode_hidden(input_ids_b, attention_mask_b)  # (batch, hidden)

        # 2. Scalar reward scores (for later use / diagnostics)
        sA = self.reward_head(hA).squeeze(-1)  # (batch,)
        sB = self.reward_head(hB).squeeze(-1)  # (batch,)

        # 3. 4-way classification logits
        pair_repr = torch.cat([hA, hB], dim=-1)  # (batch, 2*hidden)
        logits = self.classifier(pair_repr)      # (batch, 4)

        if labels is not None:
            loss = reward_loss_4way(logits, labels)
            return {
                "loss": loss,
                "scores_a": sA,
                "scores_b": sB,
                "logits": logits,
            }
        else:
            return {
                "scores_a": sA,
                "scores_b": sB,
                "logits": logits,
            }
def build_model_with_lora():
    config = AutoConfig.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
    config.output_hidden_states = True  # so we can access hidden states

    base_model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        config=config,
        torch_dtype=torch.bfloat16,
        device_map={"": DEVICE},
        cache_dir=CACHE_DIR,
    )

    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                        "up_proj", "down_proj", "gate_proj"],
    )
    base_model = get_peft_model(base_model, lora_config)
    base_model.print_trainable_parameters()

    rm = RewardModel(base_model)
    rm.to(DEVICE)
    return rm

@torch.inference_mode()
def predict_pair_label_from_logits(question, answer_a, answer_b):
    # build inputs as usual, get model outputs ...
    out = model(
        input_ids_a=...,
        attention_mask_a=...,
        input_ids_b=...,
        attention_mask_b=...,
        labels=None,
    )
    logits = out["logits"]  # (1, 4)
    pred_id = logits.argmax(dim=-1).item()
    pred_label = id2label[pred_id]  # {0:"model_a",1:"model_b",2:"tie",3:"both_bad"}
    return pred_label


In [12]:
def train_reward_model(
    train_data: List[Dict],
    val_data: List[Dict] = None,   # optional, still unused
    test_data: List[Dict] = None,
):

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    train_ds = PairwiseArenaDataset(train_data, tokenizer, max_len=MAX_LEN)
    collator = PairwiseCollator(tokenizer)

    train_loader = DataLoader(
        train_ds,
        batch_size=BATCH_SIZE,
        shuffle=True,
        collate_fn=collator,
    )

    # Build model with LoRA (now has classifier head)
    model = build_model_with_lora()
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR)

    # Optional test loader
    if test_data is not None:
        test_ds = PairwiseArenaDataset(test_data, tokenizer, max_len=MAX_LEN)
        test_loader = DataLoader(
            test_ds,
            batch_size=1,          # easier for per-example stats
            shuffle=False,
            collate_fn=collator,
        )
    else:
        test_loader = None

    # --------------------------
    # Training loop
    # --------------------------
    for epoch in range(NUM_EPOCHS):
        model.train()
        total_loss = 0.0

        for step, batch in enumerate(train_loader):
            batch = {k: v.to(DEVICE) for k, v in batch.items()}

            out = model(
                input_ids_a=batch["input_ids_a"],
                attention_mask_a=batch["attention_mask_a"],
                input_ids_b=batch["input_ids_b"],
                attention_mask_b=batch["attention_mask_b"],
                labels=batch["labels"],              # used by CE loss
            )
            loss = out["loss"]

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            total_loss += loss.item()

            if (step + 1) % 50 == 0:
                print(
                    f"Epoch {epoch+1} | Step {step+1} | "
                    f"Train Loss {total_loss/(step+1):.4f}"
                )

        print(f"[Epoch {epoch+1}] Train Avg Loss = {total_loss / len(train_loader):.4f}")

        # --------------------------
        #   Test-set evaluation (4-way classification)
        # --------------------------
        if test_loader is not None:
            model.eval()
            test_loss = 0.0
            total = 0
            correct = 0

            per_label_counts = Counter()
            per_label_correct = Counter()

            with torch.no_grad():
                for batch in test_loader:
                    batch = {k: v.to(DEVICE) for k, v in batch.items()}

                    out = model(
                        input_ids_a=batch["input_ids_a"],
                        attention_mask_a=batch["attention_mask_a"],
                        input_ids_b=batch["input_ids_b"],
                        attention_mask_b=batch["attention_mask_b"],
                        labels=batch["labels"],      # still compute loss on test
                    )

                    loss = out["loss"]
                    test_loss += loss.item()

                    logits = out["logits"]          # shape (1, 4)
                    pred_id = logits.argmax(dim=-1).item()
                    gold_id = batch["labels"].item()

                    pred_label = id2label[pred_id]
                    gold_label = id2label[gold_id]

                    total += 1
                    per_label_counts[gold_label] += 1

                    if pred_id == gold_id:
                        correct += 1
                        per_label_correct[gold_label] += 1

            overall_acc = correct / total if total > 0 else 0.0
            avg_test_loss = test_loss / total if total > 0 else 0.0

            print(
                f"[Epoch {epoch+1}] Test Avg Loss = {avg_test_loss:.4f} | "
                f"Overall Test Acc = {overall_acc:.4f}"
            )
            print("  Per-label accuracy:")
            for lab, cnt in per_label_counts.items():
                acc_lab = per_label_correct[lab] / cnt if cnt > 0 else 0.0
                print(f"    {lab:9s}: {acc_lab:.4f} (n={cnt})")

    # --------------------------
    # Save trained LoRA RM
    # --------------------------
    save_dir = "./models/qwen2_5_math7b_reward_lora_classify_bs_2"
    os.makedirs(save_dir, exist_ok=True)
    model.base_model.save_pretrained(save_dir)
    torch.save(model.reward_head.state_dict(), os.path.join(save_dir, "reward_head.pt"))
    print(f"Saved LoRA RM to {save_dir}")

In [13]:
from sklearn.model_selection import train_test_split

def load_arena_json(path: str):
    """
    Expects a JSON file containing a list of dicts like:

    [
      {
        "id": "...",
        "question": "...",
        "model_a": "...",
        "model_b": "...",
        "answer_a": "...",
        "answer_b": "...",
        "human_label": "model_a" | "model_b" | "tie" | "both_bad"
      },
      ...
    ]
    """
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)

    # Filter to examples with valid labels
    filtered = [ex for ex in data if ex.get("human_label") in label2id]
    return filtered

def split_train_val_test(data, val_ratio=0.1, test_ratio=0.1, seed=42):
    # First train+temp vs test
    train_val, test = train_test_split(
        data, test_size=test_ratio, random_state=seed, shuffle=True
    )
    # Then split train vs val
    val_size = val_ratio / (1.0 - test_ratio)
    train, val = train_test_split(
        train_val, test_size=val_size, random_state=seed, shuffle=True
    )
    return train, val, test

In [14]:
arena_data = load_arena_json("./data/arena_140k_math_filtered.json")
print("Loaded", len(arena_data), "examples")

train_data, val_data, test_data = split_train_val_test(arena_data, val_ratio=0.1, test_ratio=0.1)
print("Train:", len(train_data), "Val:", len(val_data), "Test:", len(test_data))
# print(train_data)
train_reward_model(train_data, val_data, test_data)

Loaded 894 examples
Train: 714 Val: 90 Test: 90


The following generation flags are not valid and may be ignored: ['output_hidden_states']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['output_hidden_states']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

You're using a Qwen2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


trainable params: 40,370,176 || all params: 7,655,986,688 || trainable%: 0.5273
Epoch 1 | Step 50 | Train Loss 2.9537
Epoch 1 | Step 100 | Train Loss 2.7170
Epoch 1 | Step 150 | Train Loss 2.4427
Epoch 1 | Step 200 | Train Loss 2.3119
Epoch 1 | Step 250 | Train Loss 2.1666
Epoch 1 | Step 300 | Train Loss 2.0797
Epoch 1 | Step 350 | Train Loss 2.0054
[Epoch 1] Train Avg Loss = 2.0052
[Epoch 1] Test Avg Loss = 1.8012 | Overall Test Acc = 0.3556
  Per-label accuracy:
    model_a  : 1.0000 (n=32)
    model_b  : 0.0000 (n=21)
    both_bad : 0.0000 (n=18)
    tie      : 0.0000 (n=19)
Epoch 2 | Step 50 | Train Loss 1.4822
Epoch 2 | Step 100 | Train Loss 1.4528
Epoch 2 | Step 150 | Train Loss 1.4863
Epoch 2 | Step 200 | Train Loss 1.4770
Epoch 2 | Step 250 | Train Loss 1.4697
Epoch 2 | Step 300 | Train Loss 1.4757
Epoch 2 | Step 350 | Train Loss 1.4762
[Epoch 2] Train Avg Loss = 1.4773
[Epoch 2] Test Avg Loss = 1.4480 | Overall Test Acc = 0.2333
  Per-label accuracy:
    model_a  : 0.0000 (n=3