In [1]:
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    BitsAndBytesConfig,
    DataCollatorWithPadding,
)
from torch.utils.data import DataLoader
import numpy as np
import torch
import datasets
from torch.nn import CrossEntropyLoss
from tqdm import tqdm

In [2]:
def remove_extra_brackets(input: str) -> str:
    text = input[2:-2]
    text = text.strip()
    return text


CLASSIFICATION_PROMPT = """
You are an expert AI assistant that is specialized in selecting user preferences.
The task it that you are provided with a prompt and two responses (A and B) to that prompt from different LLMs.
The possible outcomes are 3 classes:
- Winner A: Response A is better
- Winner B: Response B is better
- Tie: Both responses are equally good

The prompt to the models is:
```
{prompt}
```

Then response A is:
```
{response_a}
```

And response B is:
```
{response_b}
```

Based on the above, classify which response is better by choosing one of the following options: "Winner A", "Winner B", or "Tie".
"""


In [3]:
# Load multiple CSV files
df = datasets.load_dataset(
    "csv", data_files={"train": "../data/train.csv", "test": "../data/test.csv"}
)

In [4]:
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B-Instruct-2507")
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForSequenceClassification.from_pretrained(
    "Vmpletsos/qwen3-4b-merged-preference-classifier",
    attn_implementation="sdpa",
    num_labels=3,
    torch_dtype=torch.bfloat16,
    device_map="cuda",
)
model.config.pad_token_id = model.config.eos_token_id

`torch_dtype` is deprecated! Use `dtype` instead!


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Some weights of Qwen3ForSequenceClassification were not initialized from the model checkpoint at Qwen/Qwen3-4B-Instruct-2507 and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [5]:
df = df['train'].shuffle().select(range(1000))

In [6]:
def fix_dataset(row):
    cleaned_prompt = remove_extra_brackets(row["prompt"])
    cleaned_response_a = remove_extra_brackets(row["response_a"])
    cleaned_response_b = remove_extra_brackets(row["response_b"])

    full_prompt = CLASSIFICATION_PROMPT.format(
        prompt=cleaned_prompt,
        response_a=cleaned_response_a,
        response_b=cleaned_response_b,
    )

    tokenized = tokenizer(full_prompt)

    winner = [row["winner_model_a"], row["winner_model_b"], row["winner_tie"]]

    return {**tokenized, "winner": winner, "length": len(tokenized["input_ids"])}

df = df.map(fix_dataset, batched=False).remove_columns(
    [
        "id",
        "model_a",
        "model_b",
        "prompt",
        "response_a",
        "response_b",
        "winner_model_a",
        "winner_model_b",
        "winner_tie",
    ]
)

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

In [7]:
df = df.with_format("torch")
data_collator = DataCollatorWithPadding(tokenizer=tokenizer, padding=True)
dataloader = DataLoader(
    df, batch_size=2, shuffle=False, collate_fn=data_collator
)

In [None]:
loss_fn = CrossEntropyLoss()
total_loss = 0
total_correct = 0
total_count = 0
data_size = len(dataloader)
for data in tqdm(dataloader):
    with torch.no_grad():
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            data = {
                key: value.to("cuda") for key, value in data.items()
            }
            outputs = model(data["input_ids"]).logits
            _, predicted = torch.max(outputs, 1)
            _, true_labels = torch.max(data["winner"], 1)
            total_loss += loss_fn(outputs, true_labels).item()
            total_count += true_labels.size(0)
            total_correct += (predicted == true_labels).sum().item()

accuracy = 100 * (total_correct / total_count)
print(f"Validation Accuracy: {accuracy:.2f}%")
print(f"Validation Loss: {total_loss/data_size:.4f}")

100%|██████████| 500/500 [01:09<00:00,  7.20it/s]

Validation Accuracy: 60.50%
Validation Loss: 0.8676%



