In [None]:
# train_and_infer_fixed.py
import json
import pandas as pd
import numpy as np
import torch
from datasets import Dataset
from sklearn.model_selection import GroupShuffleSplit
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    Trainer,
    TrainingArguments,
    EarlyStoppingCallback,
)
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
import torch.nn as nn

# ----------------------
# 1. Load data (flatten)
# ----------------------
train_path = "data/train_v2.jsonl"
test_path = "data/test_v4.jsonl"


def _read_jsonl(path):
    """Return one parsed JSON object per non-blank line of the JSONL file at *path*.

    The file is opened in a ``with`` block so the handle is always closed, and
    blank lines (e.g. trailing empty lines) are skipped instead of crashing
    ``json.loads``.
    """
    with open(path, "r", encoding="utf-8") as fh:
        return [json.loads(line) for line in fh if line.strip()]


train_rows = _read_jsonl(train_path)
test_rows = _read_jsonl(test_path)

# Flatten every training item into one binary row per candidate option.
# NOTE(review): item["options"] is assumed to map option text -> correctness
# flag (bool/0/1), as int(is_correct) implies; confirm against the data.
rows = []
for item in train_rows:
    text = item["text"].strip()
    acronym = item["acronym"].strip()
    for opt, is_correct in item["options"].items():
        rows.append({
            "text": text,
            "acronym": acronym,
            "option_text": opt.strip(),
            "label": int(is_correct),
        })

df = pd.DataFrame(rows)
# All options of the same (text, acronym) question share a group id so the
# grouped split below never puts one question's options on both sides.
# text/acronym were already stripped above, so no further strip is needed.
df["group_id"] = (df["text"] + "||" + df["acronym"]).factorize()[0]
print("Total binary pairs:", len(df))
print("Positive count:", int(df["label"].sum()), "Negative count:", int((1 - df["label"]).sum()))

# ----------------------
# 2. Grouped split
# ----------------------
# Hold out 20% of the *questions* (groups), not 20% of the rows, so that no
# (text, acronym) pair contributes options to both train and validation.
gss = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
train_idx, val_idx = next(gss.split(df, groups=df["group_id"]))
df_train = df.iloc[train_idx].reset_index(drop=True)
df_val = df.iloc[val_idx].reset_index(drop=True)
print("Train size:", len(df_train), "Val size:", len(df_val))

# ----------------------
# 3. Tokenizer & preprocess (batched)
# ----------------------
# The large checkpoint is used here; switch to "xlm-roberta-base" if GPU memory
# (or training time) is a constraint.
model_name = "xlm-roberta-large"
tokenizer = AutoTokenizer.from_pretrained(model_name)

def preprocess(batch, max_length=256):
    """Tokenize a batch of (text, acronym, option) rows into model inputs.

    Args:
        batch: dict of equal-length lists with keys "text", "acronym",
            "option_text" and "label", as passed by ``Dataset.map(batched=True)``.
        max_length: maximum number of tokens per prompt; longer prompts are
            truncated.

    Returns:
        The tokenizer output (``input_ids``, ``attention_mask``, ...) with an
        added ``labels`` list copied from ``batch["label"]``.
    """
    inputs = [
        f"In the context: '{t}', what does the acronym '{a}' mean? Option: {o}"
        for t, a, o in zip(batch["text"], batch["acronym"], batch["option_text"])
    ]
    # No fixed-length padding: sequences keep their natural length and the
    # Trainer (which is given the tokenizer) pads each batch to its longest
    # member via its default DataCollatorWithPadding. Padding everything to
    # max_length spent most of the compute on pad tokens.
    # NOTE(review): truncation removes tokens from the right, so for very long
    # contexts the option text at the end of the prompt is lost first; consider
    # encoding (context question, option) as a pair with truncation="only_first"
    # (and mirroring that at inference time).
    tokenized = tokenizer(
        inputs,
        truncation=True,
        max_length=max_length,
    )
    tokenized["labels"] = batch["label"]
    return tokenized

# Build Hugging Face datasets from the two grouped pandas splits, keeping only
# the columns that preprocess() reads.
ds_train, ds_val = (
    Dataset.from_pandas(split_frame[["text", "acronym", "option_text", "label"]])
    for split_frame in (df_train, df_val)
)

# Tokenize both splits; the raw string columns are dropped so that only the
# tokenizer outputs and labels remain.
ds_train, ds_val = (
    ds.map(preprocess, batched=True, remove_columns=ds.column_names)
    for ds in (ds_train, ds_val)
)

# Sanity check: show one tokenized training example.
print("Example tokenized train sample:")
print(ds_train[0])

# Ask both datasets to return torch tensors (explicit; the Trainer would
# otherwise convert on its own).
ds_train.set_format(type="torch")
ds_val.set_format(type="torch")

# ----------------------
# 4. Model (standard PreTrainedModel)
# ----------------------
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)

# ----------------------
# 5. Compute class weights for loss
# ----------------------
# Inverse-frequency weighting: class 0 keeps weight 1.0 and class 1 gets
# n_neg / n_pos, so both classes contribute comparably to the loss.
counts = df_train["label"].value_counts().to_dict()
n_pos = counts.get(1, 0)
n_neg = counts.get(0, 0)
if n_pos == 0 or n_neg == 0:
    # With no positives the ratio is undefined; with no negatives it is 0,
    # which would silently drop the positive class from the loss. In both
    # cases fall back to an unweighted loss.
    pos_weight = 1.0
else:
    pos_weight = n_neg / n_pos
print(f"pos_weight (neg/pos) = {pos_weight:.3f}")

class_weights = torch.tensor([1.0, pos_weight], dtype=torch.float)

# ----------------------
# 6. Custom Trainer to apply weighted loss
# ----------------------
class WeightedTrainer(Trainer):
    """Trainer whose loss is cross-entropy weighted per class.

    The weights are read from the module-level ``class_weights`` tensor.
    """

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Return the class-weighted cross-entropy loss, plus outputs on request.

        ``**kwargs`` swallows extra arguments that newer Trainer versions pass
        (such as ``num_items_in_batch``); they are not used here.
        """
        forward_kwargs = dict(inputs)
        labels = forward_kwargs.pop("labels", None)
        outputs = model(**forward_kwargs)
        n_classes = model.config.num_labels

        criterion = nn.CrossEntropyLoss(weight=class_weights.to(outputs.logits.device))
        loss = criterion(outputs.logits.view(-1, n_classes), labels.view(-1))

        if not return_outputs:
            return loss
        return loss, outputs

# ----------------------
# 7. Metrics
# ----------------------
def compute_metrics(eval_pred):
    """Compute accuracy, macro-F1 and ROC-AUC for the binary classifier.

    Args:
        eval_pred: ``(predictions, labels)`` from the Trainer. Predictions are
            the (n_examples, 2) logits; if a tuple of arrays is returned, its
            first element is taken as the logits.

    Returns:
        dict with "accuracy", "f1_macro" and "roc_auc". "roc_auc" is NaN when
        scikit-learn reports it as undefined (e.g. only one class in labels).
    """
    logits, labels = eval_pred
    if isinstance(logits, tuple):
        logits = logits[0]
    logits = np.asarray(logits, dtype=np.float64)

    preds = np.argmax(logits, axis=-1)
    acc = accuracy_score(labels, preds)
    # zero_division=0 matches sklearn's default value but avoids the warning
    # when the model predicts a single class (common early in training).
    f1m = f1_score(labels, preds, average="macro", zero_division=0)

    # Numerically stable softmax -> probability of the positive class (index 1).
    exp = np.exp(logits - logits.max(axis=-1, keepdims=True))
    pos_prob = exp[:, 1] / exp.sum(axis=-1)
    try:
        auc = roc_auc_score(labels, pos_prob)
    except ValueError:
        # Raised when AUC is undefined (e.g. a single class in labels); other
        # errors are real bugs and are no longer silently swallowed.
        auc = float("nan")
    return {"accuracy": acc, "f1_macro": f1m, "roc_auc": auc}

# ----------------------
# 8. TrainingArguments + Trainer
# ----------------------
training_args = TrainingArguments(
    output_dir="./results_fixed_v2",
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="f1_macro",
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=2,  # effective train batch of 4 per device
    num_train_epochs=30,  # upper bound; early stopping below may end sooner
    learning_rate=2e-5,
    weight_decay=0.01,
    seed=42,
    logging_steps=50,
    save_total_limit=3,
)

trainer = WeightedTrainer(
    model=model,
    args=training_args,
    train_dataset=ds_train,
    eval_dataset=ds_val,
    compute_metrics=compute_metrics,
    # `tokenizer=` is deprecated (it emits a FutureWarning); processing_class is
    # its replacement and still selects DataCollatorWithPadding by default.
    processing_class=tokenizer,
    # Stop after 3 consecutive epochs without f1_macro improvement.
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
)

# ----------------------
# 9. Train
# ----------------------
trainer.train()
# load_best_model_at_end=True means the best checkpoint (by f1_macro) is the
# one saved here.
trainer.save_model("./results_fixed_v2")
tokenizer.save_pretrained("./results_fixed_v2")

# ----------------------
# 10. Inference (example)
# ----------------------
def _select_predictions(probs, threshold=0.5):
    """Pick option indices from per-option positive-class probabilities.

    Options are visited from most to least probable. Those with probability
    >= *threshold* are kept while their indices keep increasing; the walk
    stops at the first index that is not larger than the previous pick. If
    nothing qualifies, the single most probable option is returned.
    """
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    selected = []
    last = -1
    for idx in ranked:
        # ranked is descending, so once one option is below the threshold all
        # remaining ones are too.
        if probs[idx] < threshold:
            break
        if idx <= last:
            break
        selected.append(idx)
        last = idx
    if not selected and ranked:
        selected = [ranked[0]]
    return selected


model.eval()
submission = []
for item in test_rows:
    text = item["text"].strip()
    acronym = item["acronym"].strip()
    # Strip options exactly as the training rows were (option_text = opt.strip()).
    # NOTE(review): assumes test options are strings, like the training keys.
    options = [opt.strip() for opt in item["options"]]

    if not options:
        # Nothing to score; tokenizing an empty batch would fail.
        submission.append({"id": item["id"], "prediction": str([])})
        continue

    inputs = [
        f"In the context: '{text}', what does the acronym '{acronym}' mean? Option: {opt}"
        for opt in options
    ]
    enc = tokenizer(inputs, truncation=True, padding=True, max_length=256, return_tensors="pt")
    # The Trainer moved the model to its training device (e.g. GPU); the
    # CPU tensors from the tokenizer must follow or the forward pass fails.
    enc = enc.to(model.device)
    with torch.no_grad():
        logits = model(**enc).logits
        probs = torch.softmax(logits, dim=1)[:, 1].cpu().numpy()

    selected = _select_predictions(probs, threshold=0.5)
    submission.append({"id": item["id"], "prediction": str(selected)})

pd.DataFrame(submission).to_csv("submission_fixed_v2.csv", index=False)
print("Saved submission_fixed_v2.csv")


  import pynvml  # type: ignore[import]
  from .autonotebook import tqdm as notebook_tqdm


Total binary pairs: 2177
Positive count: 433 Negative count: 1744
Train size: 1715 Val size: 462


Map: 100%|██████████| 1715/1715 [00:00<00:00, 4093.71 examples/s]
Map: 100%|██████████| 462/462 [00:00<00:00, 6728.76 examples/s]


Example tokenized train sample:
{'input_ids': [0, 360, 70, 43701, 12, 242, 397, 2246, 97566, 20, 1363, 224, 6, 34440, 25, 4, 2367, 14602, 70, 10, 15322, 5264, 242, 34440, 25, 29459, 32, 86769, 12, 57212, 4188, 253, 17019, 152, 110267, 224, 915, 19, 4, 32762, 5911, 224, 27998, 7, 8, 915, 19, 4, 52088, 7162, 15, 73, 6000, 137656, 16, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1

Some weights of XLMRobertaForSequenceClassification were not initialized from the model checkpoint at xlm-roberta-large and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  trainer = WeightedTrainer(


pos_weight (neg/pos) = 4.000




Epoch,Training Loss,Validation Loss,Accuracy,F1 Macro,Roc Auc
1,0.6705,0.739216,0.805195,0.446043,0.619668


