# Fine-tuning an LLM with TRL using GRPO

_Authored by: [Sergio Paniego](https://github.com/sergiopaniego)_


# 1. Install Dependencies

Let’s start by installing the essential libraries we’ll need for fine-tuning! 🚀


In [2]:
!pip install  -U -q transformers trl datasets peft accelerate
# Tested with transformers==4.48.1, trl==0.14.0.dev0, datasets==3.2.0, peft==0.14.0, accelerate==1.3.0

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.4/44.4 kB[0m [31m3.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m9.7/9.7 MB[0m [31m127.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m293.4/293.4 kB[0m [31m28.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m480.6/480.6 kB[0m [31m39.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m69.1/69.1 MB[0m [31m34.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m336.6/336.6 kB[0m [31m29.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m12.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m179.3/179.3 kB[0m [31m18.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━

In [3]:
!pip install git+https://github.com/huggingface/trl.git@main

Collecting git+https://github.com/huggingface/trl.git@main
  Cloning https://github.com/huggingface/trl.git (to revision main) to /tmp/pip-req-build-9ywzwq3n
  Running command git clone --filter=blob:none --quiet https://github.com/huggingface/trl.git /tmp/pip-req-build-9ywzwq3n
  Resolved https://github.com/huggingface/trl.git to commit fe4b5efe4e23f4331ba9c5b0c8bd92dc8302c287
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Building wheels for collected packages: trl
  Building wheel for trl (pyproject.toml) ... [?25l[?25hdone
  Created wheel for trl: filename=trl-0.14.0.dev0-py3-none-any.whl size=306806 sha256=39d8f9e5e5f19d6a6006491ca47264a846d84cac5930c8606f332e54fd2711b0
  Stored in directory: /tmp/pip-ephem-wheel-cache-ba9juo4y/wheels/86/55/e9/4fb51fd8f4973abd44ac9118a3cf4610b1271263c00f8f85c9
Successfully built trl
Installing collected packages: trl
  Atte

In [10]:
!pip install -q flash-attn --no-build-isolation

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/3.2 MB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m3.2/3.2 MB[0m [31m141.4 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.2/3.2 MB[0m [31m79.8 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for flash-attn (setup.py) ... [?25l[?25hdone


Authenticate with your Hugging Face account to save and share your model directly from this notebook 🗝️.

In [1]:
import huggingface_hub

# Render the Hugging Face login widget; a valid token is needed for `push_to_hub` at the end.
huggingface_hub.notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [None]:
import wandb

# Log in to Weights & Biases (used via `report_to=["wandb"]` in the training config).
# Stored credentials (e.g. ~/.netrc or the WANDB_API_KEY env var) are reused and a prompt
# appears only when none are found. `relogin=True` would force a new prompt on every run.
wandb.login()

# 2. Load Dataset 📁


In [2]:
from datasets import load_dataset

dataset_id = "trl-lib/tldr"

# Take only the first 10% of each split so the demo stays quick.
split_specs = ["train[:10%]", "validation[:10%]", "test[:10%]"]
train_dataset, eval_dataset, test_dataset = load_dataset(dataset_id, split=split_specs)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [3]:
# Peek at the training split: its column names and number of rows.
train_dataset

Dataset({
    features: ['prompt', 'completion'],
    num_rows: 11672
})

In [4]:
# NOTE(review): `torch` is not referenced by any later cell; kept in case it is needed for dtype options.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hub ID of the policy model to fine-tune (an instruction-tuned Qwen2 checkpoint, ~0.5B parameters).
model_id = "Qwen/Qwen2-0.5B-Instruct"

# 4. Fine-Tune the Model using TRL


In [5]:
from peft import LoraConfig, get_peft_model

# Load the base policy model (dtype taken from the checkpoint, placed automatically) and its tokenizer.
base_model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)

# LoRA adapters on the `q_proj` / `v_proj` modules only: rank 8, alpha 32, dropout 0.1.
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "v_proj"],
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
)

# NOTE(review): the model is wrapped with LoRA here instead of passing `peft_config` to the trainer.
# Depending on the TRL version, the trainer may then build a separate full reference-model copy. Confirm.
model = get_peft_model(base_model, lora_config)
model.print_trainable_parameters()

trainable params: 540,672 || all params: 494,573,440 || trainable%: 0.1093


In [6]:
from transformers import AutoModelForSequenceClassification

# Reward model that the GRPO trainer uses to score sampled completions.
# `num_labels=1` gives a single scalar output per sequence.
# NOTE(review): no `torch_dtype` is passed, so weights load in the library default dtype; consider reducing it to save memory.
reward_model_id = "weqweasdas/RM-Gemma-2B"
reward_model = AutoModelForSequenceClassification.from_pretrained(reward_model_id, num_labels=1)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

'\nreward_model = AutoModelForSequenceClassification.from_pretrained(\n    reward_model_id, \n    num_labels=1, \n    torch_dtype=torch.bfloat16,\n    quantization_config=bnb_config,\n    _attn_implementation="flash_attention_2",\n)\n'

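GRPO samples a group of completions for each prompt and scores each one with the reward model. Each completion's reward relative to the rest of its group serves as its advantage, so no separate value model is needed.

![GRPO training overview diagram](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/grpo_visual.png)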

In [7]:
from trl import GRPOConfig

# Configure GRPO training.
# The small values below (short completions, two samples per prompt, frequent logging) keep this
# demo run fast. For a fuller run, larger settings such as logging_steps=10,
# gradient_accumulation_steps=16 and max_completion_length=128 are more representative.
# Use report_to=["tensorboard"] to log locally instead of to W&B.
training_args = GRPOConfig(
    output_dir="Qwen2-0.5B-GRPO",
    learning_rate=1e-5,
    logging_steps=2,               # log every 2 optimizer steps
    gradient_accumulation_steps=2,
    max_completion_length=16,      # maximum length of each sampled completion, in tokens
    num_generations=2,             # completions sampled per prompt (the GRPO "group" size)
    report_to=["wandb"],
)

## 4.3 Training the Model 🏃

In [8]:
from trl import GRPOTrainer

# Combine the LoRA-wrapped policy, the reward model and the config into a GRPO trainer.
# In released TRL (>= 0.14.0) the reward model is passed as `reward_funcs`, which accepts a model,
# a callable, or a list of them. The pre-release `reward_model=` keyword this notebook originally
# used is not accepted there.
trainer = GRPOTrainer(
    model=model,
    reward_funcs=reward_model,
    args=training_args,
    train_dataset=train_dataset,
)

Time to Train the Model! 🎉

In [9]:
# Launch GRPO training. Metrics are sent to the backend chosen with `report_to` in GRPOConfig.
trainer.train()



<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


Step,Training Loss
2,0.0
4,0.0
6,0.0
8,0.0001
10,0.0
12,0.0001
14,0.0001
16,0.0001
18,0.0002
20,0.0001


KeyboardInterrupt: 

Let's save the results 💾

In [None]:
# Save the trained model (LoRA adapters) to `output_dir`, then upload it to the Hugging Face Hub.
# `dataset_name` is forwarded to model-card creation so the card records the training dataset.
trainer.save_model(training_args.output_dir)
trainer.push_to_hub(dataset_name=dataset_id)