In [3]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import get_peft_model, LoraConfig
from trl import GRPOTrainer, GRPOConfig
from beir.retrieval.evaluation import EvaluateRetrieval
from utils import get_data,creat_index,make_conversation


In [4]:
# Load the BEIR dataset via the project helper: corpus, queries and qrels.
data_set = "scifact"
data_path = "data/text_data/"
corpus, queries, qrels = get_data(data_set, data_path=data_path)

# Id lists in the dicts' insertion order.
queries_ids = [*queries]
document_ids = [*corpus]

  0%|          | 0/5183 [00:00<?, ?it/s]

In [5]:
# Build the document index with the project helper, then open it for BM25 search.
index_path = f"data/index_data/{data_set}_docIndex/doc_index"
creat_index(index_path, corpus)
from pyserini.search.lucene import LuceneSearcher
searcher = LuceneSearcher(index_path)
# BM25 parameters used for every search in this notebook.
searcher.set_bm25(k1=0.9, b=0.4)


mai 13, 2025 1:33:03 PM org.apache.lucene.store.MMapDirectory lookupProvider


In [10]:
# 2) Model
# Pick the best available accelerator instead of hard-coding "mps", which fails
# on machines without Apple Metal. Apple machines still get "mps" as before.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

model_name = "Qwen/Qwen2.5-1.5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
base_model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype="auto",
).to(device)

Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.


In [11]:
# 4) Wrap the base model with LoRA adapters.
# Attention and MLP projection layers that receive LoRA weights.
LORA_TARGET_MODULES = [
    "q_proj", "k_proj", "v_proj", "o_proj",
    "gate_proj", "up_proj", "down_proj",
]
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=32,  # rank of the low-rank update matrices
    lora_alpha=32,  # LoRA scaling factor (alpha)
    lora_dropout=0.1,
    target_modules=LORA_TARGET_MODULES,
)
model = get_peft_model(base_model, peft_config)
model.print_trainable_parameters()

trainable params: 36,929,536 || all params: 1,580,643,840 || trainable%: 2.3364


In [13]:
def reward_func(completions, queries_id, **_):
    """GRPO reward: NDCG@10 of BM25 retrieval using each generated query rewrite.

    Args:
        completions: Generated rewrites, one per sample. Either plain strings or
            conversational completions (lists whose first message dict holds
            the text under ``'content'``).
        queries_id: Original query id for each completion, aligned with
            ``completions``. GRPO repeats every prompt ``num_generations``
            times, so ids repeat within a batch.
        **_: Any other keyword arguments supplied by the trainer; ignored.

    Returns:
        list[float]: One NDCG@10 score per completion, in input order. A
        completion that retrieves nothing (e.g. an empty rewrite) scores 0.0.
    """
    if not completions:
        return []

    # --- gather all newly-generated strings -------------------------------
    if isinstance(completions[0], list):
        batch_queries = [c[0]["content"] for c in completions]
    else:
        batch_queries = list(completions)
    batch_qids = list(queries_id)

    # --- single batched call to Lucene ------------------------------------
    # batch_search returns {search_id: [hits...]}. Searching under the real
    # qids would collapse all rewrites of the same prompt into one dict entry,
    # giving them identical rewards (zero GRPO advantage). Instead, each
    # completion gets its own positional search id.
    search_ids = [str(i) for i in range(len(batch_queries))]
    hits = searcher.batch_search(batch_queries,
                                 search_ids,
                                 k=10,       # retrieval depth
                                 threads=4)  # use 4–8; tune for your CPU / SSD

    # --- score each completion against the qrels of its original query ----
    rewards = []
    for search_id, qid in zip(search_ids, batch_qids):
        # Drop a retrieved document whose id equals the query id (self-match).
        run = {d.docid: d.score for d in hits.get(search_id, [])
               if d.docid != str(qid)}
        if not run:
            # Nothing retrieved, so nothing relevant was found: NDCG@10 is 0.
            rewards.append(0.0)
            continue
        ndcg = EvaluateRetrieval.evaluate(qrels, {qid: run}, [10])[0]
        rewards.append(ndcg["NDCG@10"])
    return rewards

In [6]:
from datasets import Dataset

# One row per query: `prompt` holds the raw query text and `queries_id` its id
# (reward_func receives the id column as a keyword argument).
dataset = Dataset.from_list(
    [dict(prompt=query_text, queries_id=query_id) for query_id, query_text in queries.items()]
)

# Longer conversational instruction; defined for reference but not used below.
SYSTEM_PROMPT = (
    "A conversation between a User and an Assistant. The User provides a query to be reformulated and rich, "
    "and the Assistant must rewrite the query without asking a question so that a keywords retrieval system can better find the correct document. "
    "No additional text is needed — only the new, reformulated query should be returned.\n"
    "Format: USER: [ORIGINAL QUERY]  ASSISTANT: [NEW QUERY]"
)
# Short prefix actually used to build the training prompts.
SYSTEM_PROMPT2 = (
    "Reformulate the query so keywords retrieval system can better find the correct document.\n\n\n [ORIGINAL QUERY]: "
)


def _build_prompt(example):
    """Format one dataset row with SYSTEM_PROMPT2 via the project helper make_conversation."""
    return make_conversation(example, SYSTEM_PROMPT2, False)


train_dataset = dataset.map(_build_prompt)


Map:   0%|          | 0/300 [00:00<?, ? examples/s]

In [18]:
# Inspect one formatted example to confirm the prompt template.
first_example = train_dataset[0]
first_example

{'prompt': 'Reformulate the query so keywords retrieval system can better find the correct document.\n\n\n [ORIGINAL QUERY]: 0-dimensional biomaterials show inductive properties. [NEW QUERY]:',
 'queries_id': '1'}

In [20]:
# 5) GRPO trainer
grpo_config = GRPOConfig(
    output_dir="checkpoints/Qwen_GRPO_ex",
    # Optimisation
    learning_rate=1e-5,
    gradient_accumulation_steps=20,
    num_train_epochs=25,
    # bf16=True,
    # Keep extra dataset columns (e.g. `queries_id`) so reward_func receives them.
    remove_unused_columns=False,
    # Data preprocessing / generation
    max_prompt_length=250,  # default: 512
    max_completion_length=48,  # default: 256
    num_generations=20,  # rewritten queries sampled per prompt
    # Reporting and checkpointing
    report_to=["tensorboard"],
    logging_steps=20,  # write metrics to TensorBoard every 20 steps
    save_strategy="steps",
    save_steps=100,  # save a checkpoint every 100 steps
)

trainer = GRPOTrainer(
    model=model,
    args=grpo_config,
    reward_funcs=reward_func,
    train_dataset=train_dataset,
)

No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


In [None]:
# 6) Train: run the GRPO loop configured above and show its summary.
train_result = trainer.train()
train_result

In [None]:
# Save the final trained model to a local directory.
final_model_dir = "Ex_Great_model"
trainer.save_model(final_model_dir)

In [None]:
# Extra: export training metrics and hyperparameters to TensorBoard.
import json
import os

from torch.utils.tensorboard import SummaryWriter

log_dir = "./runs/experiment-ONLY-GRPO"

# 1. Full config as a plain dict (GRPOConfig is a dataclass with to_dict()).
config_dict = grpo_config.to_dict()

# TensorBoard's `add_hparams()` only accepts scalar hparams, so filter them.
scalar_hparams = {
    k: v for k, v in config_dict.items()
    if isinstance(v, (int, float, str, bool))
}

# Most recent logged training loss, or None if no entry has one.
final_loss = next(
    (log["loss"] for log in reversed(trainer.state.log_history) if "loss" in log),
    None,
)

# Context manager guarantees the writer is flushed/closed even on error.
with SummaryWriter(log_dir) as writer:
    for step_idx, log in enumerate(trainer.state.log_history):
        # Plot against the trainer's real step (logs happen every
        # `logging_steps`), not the position in log_history.
        global_step = log.get("step", step_idx)
        for key, value in log.items():
            if key == "step":
                continue
            if isinstance(value, (int, float)):  # Log scalar values only
                writer.add_scalar(f"metrics/{key}", value, global_step)

    # Save full config as a JSON file next to the event files.
    with open(os.path.join(log_dir, "config.json"), "w") as f:
        json.dump(config_dict, f, indent=2)

    # 2. Add all scalar hyperparameters to TensorBoard (0 if no loss was logged).
    writer.add_hparams(
        scalar_hparams,
        {"final_loss": final_loss if final_loss is not None else 0},
    )

In [None]:
# Resources:
# - https://huggingface.co/docs/trl/main/en/grpo_trainer
# - https://medium.com/%40rajveer.rathod1301/fine-tuning-deepseek-7b-with-grpo-a-comprehensive-guide-1b7a89ae21b1
# - https://huggingface.co/learn/cookbook/en/fine_tuning_llm_grpo_trl
# - https://www.entrypointai.com/blog/lora-fine-tuning/