In [1]:
import os
import re
import json
import random
from typing import Dict, Any

import torch
from datasets import load_dataset, DatasetDict
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig
from trl import SFTConfig, SFTTrainer
from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState

  from .autonotebook import tqdm as notebook_tqdm


ModuleNotFoundError: No module named 'peft'

In [None]:
# Run configuration. Each value falls back to the default shown here unless the
# environment variable of the same name is set.
MODEL_NAME = os.getenv("MODEL_NAME", "meta-llama/Llama-3.1-8B-Instruct")
DATA_DIR = os.getenv("DATA_DIR", "data/squad_llm_judge/processed/llm_training")
OUTPUT_DIR = os.getenv("OUTPUT_DIR", "outputs/llm_judge_lora")
SEED = int(os.getenv("SEED", "42"))

# Seed PyTorch's and Python's global RNGs from the same value.
torch.manual_seed(SEED)
random.seed(SEED)

In [None]:
# Compute dtype: bf16 when a CUDA device with major compute capability >= 8 is
# present, otherwise fp16 (the usual case on a T4, and also the fallback when
# CUDA is unavailable).
use_bf16 = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8
if use_bf16:
    DTYPE = torch.bfloat16
else:
    DTYPE = torch.float16

# QLoRA quantization settings: 4-bit NF4 weights with double quantization;
# computation runs in DTYPE.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=DTYPE,
)

# LoRA adapter settings for causal-LM fine-tuning.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    # Names of the modules that receive LoRA adapters.
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
)

In [None]:
# Judge prompt template. It is rendered with str.format, which substitutes the
# three placeholders {question}, {vector_reference} and {bm25_reference}. The
# text asks the model for two integer scores (dense retrieval first, BM25
# second) separated by a space and nothing else.
PROMPT_TMPL = """You are an evaluator assessing the retrieval effectiveness of dense
retrieval ( Cosine Distance ) and BM25 retrieval for finding the
correct answer .
## Task :
Given a question and two top1 search results ( one from dense retrieval ,
one from BM25 retrieval ) , score each retrieval method from **0 to 5**
based on whether the correct answer is likely to appear in top2 ,
top3 , etc .
### ** Scoring Criteria :**
1. ** Direct hit --> 5 points **
- If the retrieved document directly answers the question , assign **5
points **.
2. ** Good wrong result ( High likelihood correct answer is nearby ) --> 3 -4
points **
- If the top1 result is ** conceptually close ** to the correct answer (
e . g . , mentions relevant entities , related events , partial answer ) ,
it indicates the search method is in the right direction .
- Give **4** if it 's very close , **3** if somewhat close .
3. ** Bad wrong result ( Low likelihood correct answer is nearby ) --> 1 -2
points **
- If the top1 result is ** loosely related but misleading ** ( e . g . ,
shares keywords but changes context ) , correct answers might not be
in top2 , top3 .
- Give **2** if there 's a small chance correct answers are nearby ,
**1** if unlikely .
4. ** Completely off - track --> 0 points **
- If the result is ** totally unrelated ** , it means the retrieval
method is failing .
---
### ** Given Data :**
- ** Question :** "{question}"
- ** dense retrieval Top1 Result :** "{vector_reference}"
- ** BM25 retrieval Top1 Result :** "{bm25_reference}"
---
### ** Output Format :**
Return two integers separated by a space :
- ** First number :** dense retrieval score .
- ** Second number :** BM25 retrieval score .
- Example output : 3 4
( Vector : 3 , BM25 : 4)
** Do not output any other text .**
"""


def build_prompt(ex: Dict[str, Any]) -> str:
    """Render the judge prompt for a single example.

    Reads ``question``, ``ctx_dense_top1`` and ``ctx_bm25_top1`` from ``ex``.
    A missing or falsy value becomes the empty string; every value is stripped
    of surrounding whitespace and its double quotes are backslash-escaped
    before being substituted into ``PROMPT_TMPL``, which wraps each field in
    double quotes.
    """

    def _field(key: str) -> str:
        # Shared clean-up applied to each of the three text fields.
        return (ex.get(key) or "").strip().replace('"', '\\"')

    return PROMPT_TMPL.format(
        question=_field("question"),
        vector_reference=_field("ctx_dense_top1"),
        bm25_reference=_field("ctx_bm25_top1"),
    )


def build_completion(ex: Dict[str, Any]) -> str:
    """Return the training target ``"<dense> <bm25>"`` for one example.

    ``score_dense`` and ``score_bm25`` are each converted with ``int`` and
    joined by a single space, dense score first.
    """
    dense_score = int(ex["score_dense"])
    bm25_score = int(ex["score_bm25"])
    return " ".join((str(dense_score), str(bm25_score)))

In [None]:
def load_llm_judge_dataset(path: str) -> DatasetDict:
    """Load the judge dataset under ``path`` as prompt/completion rows.

    ``train.parquet`` and ``test.parquet`` are read from ``path``; the test
    split is exposed under the key ``"eval"``. Every row is mapped to four
    fields: ``prompt``, ``completion``, ``alpha_label`` (0.5 when the key is
    absent) and ``id`` ("" when absent). The column names of the train split
    are passed to ``map`` as ``remove_columns``.
    """
    raw = load_dataset(
        "parquet",
        data_files={
            "train": os.path.join(path, "train.parquet"),
            "test": os.path.join(path, "test.parquet"),
        },
    )
    splits = DatasetDict(train=raw["train"], eval=raw["test"])

    def to_prompt_completion(ex):
        # Build the row field by field, in the same order as the output schema.
        converted = {}
        converted["prompt"] = build_prompt(ex)
        converted["completion"] = build_completion(ex)
        converted["alpha_label"] = float(ex.get("alpha_label", 0.5))
        converted["id"] = ex.get("id", "")
        return converted

    original_columns = splits["train"].column_names
    return splits.map(to_prompt_completion, remove_columns=original_columns)

# Two single digits in the range 0-5, each on a word boundary, separated by
# whitespace (for example "3 4").
PAIR_RE = re.compile(r"\b([0-5])\s+([0-5])\b")


def parse_pair(txt: str):
    """Extract the first ``<dense> <bm25>`` score pair from ``txt``.

    Returns a ``(dense, bm25)`` tuple of ints, or ``(None, None)`` when ``txt``
    is empty/``None`` or contains no matching pair.
    """
    found = PAIR_RE.search(txt or "")
    if found is None:
        return None, None
    first, second = found.groups()
    return int(first), int(second)


def alpha_from(sb: int, sv: int) -> float:
    """Return the dense share of the combined score, ``sv / (sb + sv)``.

    ``sb`` is the BM25 score and ``sv`` the dense score. When their sum is not
    positive, the neutral value 0.5 is returned instead of dividing.
    """
    total = sb + sv
    if total > 0:
        return float(sv) / total
    return 0.5

class GenEvalCallback(TrainerCallback):
    """Log generation-based judge metrics after every evaluation pass.

    On each ``on_evaluate`` event a subset of the trainer's eval dataset is
    decoded greedily, the generated ``"<dense> <bm25>"`` pair is parsed with
    ``parse_pair`` and accuracy / mean-absolute-error metrics are logged through
    ``trainer.log`` under the ``gen/`` prefix.

    Args:
        trainer: Trainer whose ``model`` and ``eval_dataset`` are evaluated.
            Rows of the eval dataset are read through their ``prompt`` and
            ``completion`` fields.
        tokenizer: Tokenizer used to encode prompts and decode generations.
        eval_samples: Maximum number of eval rows scored per evaluation.
        max_new_tokens: Generation budget per prompt.
        seed: Seed of the private RNG that selects the subset. The same rows
            are scored at every evaluation, so metrics from different steps
            are comparable, and the global ``random`` state is left untouched.
    """

    def __init__(self, trainer: SFTTrainer, tokenizer, eval_samples=800, max_new_tokens=6, seed=0):
        self.trainer = trainer
        self.tok = tokenizer
        self.eval_samples = eval_samples
        self.max_new_tokens = max_new_tokens
        self.seed = seed

    def _sample_indices(self, n_rows: int):
        """Return up to ``eval_samples`` row indices, identical on every call."""
        idxs = list(range(n_rows))
        # A private RNG keeps the subset fixed and does not consume the global
        # random stream while training is running.
        random.Random(self.seed).shuffle(idxs)
        return idxs[: self.eval_samples]

    def _generate_text(self, model, prompt: str) -> str:
        """Greedy-decode a continuation of ``prompt`` and return only the new text."""
        inputs = self.tok(prompt, return_tensors="pt").to(model.device)
        # Greedy decoding: no temperature is passed because it only applies to
        # sampling-based generation.
        gen = model.generate(
            **inputs,
            max_new_tokens=self.max_new_tokens,
            do_sample=False,
            pad_token_id=self.tok.eos_token_id,
            eos_token_id=self.tok.eos_token_id,
        )
        prompt_len = inputs["input_ids"].shape[1]
        return self.tok.decode(gen[0][prompt_len:], skip_special_tokens=True)

    def on_evaluate(self, args, state: TrainerState, control: TrainerControl, **kwargs):
        """Score the eval subset by generation and log the ``gen/*`` metrics.

        Rows whose gold completion cannot be parsed are skipped. A generation
        that cannot be parsed counts as a miss for the accuracy metrics, as
        score 0 for the score MAEs and as alpha 0.5 for the alpha MAE.

        Returns:
            The ``control`` object, unchanged.
        """
        model = self.trainer.model
        ds = self.trainer.eval_dataset
        if ds is None or len(ds) == 0:
            return control

        gold_dense, gold_bm25 = [], []
        pred_dense, pred_bm25 = [], []
        gold_alpha, pred_alpha = [], []

        was_training = model.training
        model.eval()
        try:
            with torch.no_grad():
                for i in self._sample_indices(len(ds)):
                    row = ds[i]

                    gd, gb = parse_pair(row["completion"])
                    if gd is None or gb is None:
                        # Without a usable label the row cannot be scored;
                        # skipping avoids a TypeError in alpha_from.
                        continue

                    pred_d, pred_b = parse_pair(self._generate_text(model, row["prompt"]))

                    gold_dense.append(gd)
                    gold_bm25.append(gb)
                    gold_alpha.append(alpha_from(gb, gd))

                    if pred_d is None or pred_b is None:
                        # -1 marks "no parseable prediction".
                        pred_dense.append(-1)
                        pred_bm25.append(-1)
                        pred_alpha.append(0.5)
                    else:
                        pred_dense.append(pred_d)
                        pred_bm25.append(pred_b)
                        pred_alpha.append(alpha_from(pred_b, pred_d))
        finally:
            # Leave the model in the mode it was found in.
            if was_training:
                model.train()

        if not gold_dense:
            return control

        def acc(a, b):
            return sum(1 for x, y in zip(a, b) if x == y) / max(1, len(a))

        def mae(a, b):
            return sum(abs(x - y) for x, y in zip(a, b)) / max(1, len(a))

        # For the MAEs, an unparseable prediction is treated as score 0.
        pd_fill = [x if x >= 0 else 0 for x in pred_dense]
        pb_fill = [x if x >= 0 else 0 for x in pred_bm25]

        metrics = {
            "gen/acc_dense": acc(gold_dense, pred_dense),
            "gen/acc_bm25": acc(gold_bm25, pred_bm25),
            "gen/acc_both": acc(
                list(zip(gold_dense, gold_bm25)),
                list(zip(pred_dense, pred_bm25)),
            ),
            "gen/mae_dense": mae(gold_dense, pd_fill),
            "gen/mae_bm25": mae(gold_bm25, pb_fill),
            "gen/mae_alpha": mae(gold_alpha, pred_alpha),
        }

        self.trainer.log(metrics)
        return control

In [None]:
# Prompt/completion dataset with "train" and "eval" splits.
ds = load_llm_judge_dataset(DATA_DIR)

# Tokenizer for the base model.
tok = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)

if tok.pad_token is None:
    # The tokenizer defines no pad token, so EOS is reused for padding.
    tok.pad_token = tok.eos_token

# Base model, quantized according to bnb_config.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="auto",
    torch_dtype=DTYPE,
    quantization_config=bnb_config,
)

# Turn the generation cache off on the model config before training.
model.config.use_cache = False
# NOTE(review): `evaluation_strategy`, `max_seq_length` and the trainer's
# `tokenizer=` argument appear to be spelled differently in recent
# transformers/TRL releases — confirm against the installed versions.
sft_cfg = SFTConfig(
    output_dir=OUTPUT_DIR,
    seed=SEED,
    report_to=["none"],
    # Sequence handling
    max_seq_length=1536,
    packing=False,
    # Optimisation schedule
    num_train_epochs=2,
    learning_rate=1e-4,
    lr_scheduler_type="cosine",
    warmup_ratio=0.03,
    # Batching: 4 per device x 8 accumulation steps, effective batch ~32
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=8,
    # Memory / precision
    gradient_checkpointing=True,
    bf16=use_bf16,
    fp16=not use_bf16,
    # Logging, evaluation and checkpointing (eval and save share the same cadence)
    logging_steps=50,
    evaluation_strategy="steps",
    eval_steps=1000,
    save_strategy="steps",
    save_steps=1000,
    save_total_limit=2,
    # Best-checkpoint selection: lowest eval_loss wins
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
)
trainer = SFTTrainer(
    model=model,
    args=sft_cfg,
    tokenizer=tok,
    peft_config=peft_config,
    train_dataset=ds["train"],
    eval_dataset=ds["eval"],
)

# Callback that reports real, generation-based metrics at each evaluation.
trainer.add_callback(GenEvalCallback(trainer, tok, eval_samples=800, max_new_tokens=6))

# Resume only when OUTPUT_DIR already contains a "checkpoint-<step>" directory.
# Passing resume_from_checkpoint=True unconditionally makes the first run fail,
# because the trainer raises when it finds no checkpoint to resume from.
has_checkpoint = os.path.isdir(OUTPUT_DIR) and any(
    re.fullmatch(r"checkpoint-\d+", entry) is not None
    and os.path.isdir(os.path.join(OUTPUT_DIR, entry))
    for entry in os.listdir(OUTPUT_DIR)
)
trainer.train(resume_from_checkpoint=has_checkpoint)

# Persist the trained model and the tokenizer side by side.
trainer.save_model()
tok.save_pretrained(OUTPUT_DIR)

# Record how this run was produced next to the saved artifacts.
meta = {
    "model": MODEL_NAME,
    "dtype": str(DTYPE),
    "seed": SEED,
    "train_size": len(ds["train"]),
    "eval_size": len(ds["eval"]),
    "prompt_style": "paper_official_dense_first_space_sep",
}
# Make sure the target directory exists and write with an explicit encoding so
# the file does not depend on the platform default.
os.makedirs(OUTPUT_DIR, exist_ok=True)
with open(os.path.join(OUTPUT_DIR, "RUN_METADATA.json"), "w", encoding="utf-8") as f:
    json.dump(meta, f, indent=2)
print(f"DONE → {OUTPUT_DIR}")