# Finetune LLMs with GRPO

This notebook shows how to finetune an LLM with Group Relative Policy Optimization (GRPO), using the `trl` library.

It was written by [Ben Burtenshaw](https://huggingface.co/burtenshaw) and [Maxime Labonne](https://huggingface.co/mlabonne).

This is a minimal example. For a complete example, refer to the GRPO chapter in the [course](https://huggingface.co/course/en/chapter12/1).

## Install dependencies

In [1]:
!pip install -qqq datasets==3.2.0 transformers==4.47.1 trl==0.14.0 peft==0.14.0 accelerate==1.2.1 bitsandbytes==0.45.2 wandb==0.19.7 --progress-bar off
!pip install -qqq flash-attn --no-build-isolation --progress-bar off

In [16]:
import torch
print(torch.version.cuda)


12.1


## Load Dataset

In [1]:
from huggingface_hub import notebook_login
notebook_login()


In [2]:
from dotenv import load_dotenv
import wandb

# Load environment variables (e.g. WANDB_API_KEY) from .env
load_dotenv()

# W&B uses WANDB_API_KEY from the environment if it is set; otherwise it prompts for a key
wandb.login()


[34m[1mwandb[0m: (1) Create a W&B account
[34m[1mwandb[0m: (2) Use an existing W&B account
[34m[1mwandb[0m: (3) Don't visualize my results
[34m[1mwandb[0m: Enter your choice:[34m[1mwandb[0m: You chose 'Use an existing W&B account'
[34m[1mwandb[0m: Logging into https://api.wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: Find your API key here: https://wandb.ai/authorize?ref=models
[34m[1mwandb[0m: Paste an API key from your profile and hit enter:[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mbeshoyarnest01[0m ([33mbeshoyarnest01-minia-university[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [5]:
import torch
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import GRPOConfig, GRPOTrainer

In [6]:
# Load dataset
dataset = load_dataset("mlabonne/smoltldr")
print(dataset)

Error while fetching `HF_TOKEN` secret value from your vault: 'Requesting secret HF_TOKEN timed out. Secrets can only be fetched when running from the Colab UI.'.
You are not authenticated with the Hugging Face Hub in this notebook.
If the error persists, please let us know by opening an issue on GitHub (https://github.com/huggingface/huggingface_hub/issues/new).



DatasetDict({
    train: Dataset({
        features: ['prompt', 'completion'],
        num_rows: 2000
    })
    validation: Dataset({
        features: ['prompt', 'completion'],
        num_rows: 200
    })
    test: Dataset({
        features: ['prompt', 'completion'],
        num_rows: 200
    })
})


## Load Model

In [7]:
# Load model with fallback if flash_attn is not available
model_id = "HuggingFaceTB/SmolLM-135M-Instruct"
try:
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype="auto",
        device_map="auto",
        attn_implementation="flash_attention_2",
    )
except ImportError:
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype="auto",
        device_map="auto",
    )
tokenizer = AutoTokenizer.from_pretrained(model_id)


`torch_dtype` is deprecated! Use `dtype` instead!



### LoRA Config

In [8]:
# Configure LoRA adapters on all linear layers
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,
    lora_alpha=32,
    target_modules="all-linear",
)

In [9]:
model = get_peft_model(model, lora_config)
# print_trainable_parameters() prints the summary itself and returns None
model.print_trainable_parameters()

trainable params: 4,884,480 || all params: 139,399,488 || trainable%: 3.5039
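The trainable-parameter count above can be reproduced by hand. The sketch below assumes each LoRA adapter adds `r * (d_in + d_out)` weights to a linear layer and that `target_modules="all-linear"` covers the seven projection layers in every transformer block but not the output head. The layer dimensions are assumptions about SmolLM-135M (they are not shown in this notebook), so treat this as a sanity check rather than a definitive derivation.

```python
# Assumed SmolLM-135M dimensions (not taken from this notebook).
hidden = 576
intermediate = 1536
num_layers = 30
num_heads = 9
num_kv_heads = 3
head_dim = hidden // num_heads      # 64
kv_dim = num_kv_heads * head_dim    # 192 (grouped-query attention)
r = 16                              # LoRA rank from lora_config

# (input features, output features) of each adapted linear layer in one block
linear_shapes = {
    "q_proj": (hidden, hidden),
    "k_proj": (hidden, kv_dim),
    "v_proj": (hidden, kv_dim),
    "o_proj": (hidden, hidden),
    "gate_proj": (hidden, intermediate),
    "up_proj": (hidden, intermediate),
    "down_proj": (intermediate, hidden),
}

# Each adapter holds A (r x d_in) and B (d_out x r)
per_layer = sum(r * (d_in + d_out) for d_in, d_out in linear_shapes.values())
lora_params = per_layer * num_layers
trainable_pct = 100 * lora_params / 139_399_488  # "all params" from the output above

print(f"{lora_params:,} trainable ({trainable_pct:.4f}%)")
```

Under these assumptions the total comes to 4,884,480 parameters, matching the value reported by `print_trainable_parameters()`.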


## Define Reward Function

In [10]:
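# Reward function: prefer completions that are close to 50 characters long
def reward_len(completions, **kwargs):
    # len() counts characters, not tokens; the best possible reward is 0
    return [-abs(50 - len(completion)) for completion in completions]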
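To get a feel for how this reward behaves, here is a small sketch that scores a few made-up completions. Because `len(completion)` counts characters rather than tokens, the reward peaks at 0 for a 50-character string and drops by one for every character of deviation in either direction. The sample strings are purely illustrative.

```python
def reward_len(completions, **kwargs):
    return [-abs(50 - len(completion)) for completion in completions]

# Hypothetical completions of 50, 20 and 96 characters
samples = ["x" * 50, "x" * 20, "x" * 96]
rewards = reward_len(samples)
print(rewards)  # [0, -30, -46]
```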

## Define Training Arguments

In [23]:
# Training arguments
training_args = GRPOConfig(
    output_dir="GRPO",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,
    max_prompt_length=512,
    max_completion_length=96,
    num_generations=8,
    optim="adamw_torch",
    num_train_epochs=1,
    bf16=True,
    report_to=["wandb"],
    remove_unused_columns=False,
    logging_steps=1,
)
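The batch settings above determine how many prompts are consumed per optimizer step. The arithmetic below is a sketch under two assumptions: training runs in a single process (one GPU), and the installed TRL version counts `per_device_train_batch_size` in completions, so that each prompt occupies `num_generations` slots of the batch. The variable `num_processes` is introduced here only for illustration.

```python
per_device_train_batch_size = 8
gradient_accumulation_steps = 2
num_generations = 8
num_train_epochs = 1
num_processes = 1      # assumption: single GPU
train_prompts = 2000   # size of the train split printed earlier

# Completions that contribute to one optimizer step
completions_per_step = (
    per_device_train_batch_size * num_processes * gradient_accumulation_steps
)
# Each prompt is sampled num_generations times
prompts_per_step = completions_per_step // num_generations
total_steps = train_prompts * num_train_epochs // prompts_per_step

print(completions_per_step, prompts_per_step, total_steps)  # 16 2 1000
```

Under these assumptions one epoch takes 1000 optimizer steps, which agrees with the `global_step=1000` reported by `trainer.train()` further down.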



### Initialize the trainer

In [24]:
# Trainer
trainer = GRPOTrainer(
    model=model,
    reward_funcs=[reward_len],
    args=training_args,
    train_dataset=dataset["train"],
)

The model is already on multiple devices. Skipping the move to device specified in `args`.
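During training, the trainer samples `num_generations` completions per prompt, scores them with the reward function, and compares each reward against the others from the same prompt. The sketch below illustrates that group-relative normalization in plain Python. The helper name `group_relative_advantages` and the `eps` value are hypothetical, and this is a simplified illustration of the idea rather than TRL's actual implementation.

```python
import statistics

def group_relative_advantages(rewards, eps=1e-4):
    """Normalize rewards within one group (all completions of a single prompt)."""
    mean = statistics.fmean(rewards)
    std = statistics.stdev(rewards)  # sample standard deviation of the group
    # eps avoids division by zero when every completion gets the same reward
    return [(r - mean) / (std + eps) for r in rewards]

# Hypothetical reward_len scores for four completions of the same prompt
group_rewards = [0, -30, -46, -10]
advantages = group_relative_advantages(group_rewards)
for reward, advantage in zip(group_rewards, advantages):
    print(f"reward={reward:>4}  advantage={advantage:+.3f}")
```

Completions that score above the group mean receive a positive advantage and are reinforced, while those below the mean are discouraged. If all completions in a group earn the same reward, every advantage is zero and the group contributes no learning signal.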


In [25]:
wandb.init(project="GRPO")

Metric,History (sparkline)
profiling/Time taken: GRPOTrainer._calculate_rewards,▂▃█▅▅▄▃▂▂▂▂▂▂▂▂▂▃▂▁▂▂▂▂▃▁▂▂▂▂
profiling/Time taken: GRPOTrainer._get_per_token_logps_and_entropies,▃▃▁▁▂▃▂▄▄▁▂▃▃▂▃▂▁▂▃▃▁▁▂▂▃▁▁▃▃█▁▃▂▂▃▂▂▂▂▃
profiling/Time taken: GRPOTrainer._prepare_inputs,█▁█▁▁▁▁█▇▇▁▇▁▇▁▁██▁█▇█▁█▁▁█▁█▁▁█▁█▁▁████
profiling/Time taken: GRPOTrainer.compute_loss,▄▄▂▂▃▂▂▅▅▂▃▃▃▃▃▂▂▂▂▄▄▄▁▁▂▃▁▃▃█▁▄▂▂▄▂▂▃▃▃
profiling/Time taken: GRPOTrainer.reward_len,▂▂█▄▄▂▃▃▁▁▃▂▂▂▁▂▁▂▂▁▁▁▁▂▁▁▂▁▁
profiling/Time taken: GRPOTrainer.transformers.generate,█▅▆▆▄▆▃▃▂▂▂▁▄▆▆▃▅▆▄▆▆▄▆▅▆▅▅▆▆
train/clip_ratio/high_max,▁▁▁▁▁▁▁▁▁
train/clip_ratio/high_mean,▁▁▁▁▁▁▁▁▁
train/clip_ratio/low_mean,▁▁▁▁▁▁▁▁▁
train/clip_ratio/low_min,▁▁▁▁▁▁▁▁▁

Metric,Value
profiling/Time taken: GRPOTrainer._calculate_rewards,0.00079
profiling/Time taken: GRPOTrainer._get_per_token_logps_and_entropies,0.33982
profiling/Time taken: GRPOTrainer._prepare_inputs,8.95228
profiling/Time taken: GRPOTrainer.compute_loss,0.49418
profiling/Time taken: GRPOTrainer.reward_len,8e-05
profiling/Time taken: GRPOTrainer.transformers.generate,8.91702
train/clip_ratio/high_max,0
train/clip_ratio/high_mean,0
train/clip_ratio/low_mean,0
train/clip_ratio/low_min,0


In [26]:
# Train model
trainer.train()

Step,Training Loss
1,0.1567
2,0.1864
3,0.1917
4,0.3145
5,0.088
6,0.2849
7,0.2729
8,0.2354
9,0.2308
10,0.1616


TrainOutput(global_step=1000, training_loss=0.12347275123046711, metrics={'train_runtime': 6562.9425, 'train_samples_per_second': 0.305, 'train_steps_per_second': 0.152, 'total_flos': 0.0, 'train_loss': 0.12347275123046711})

## Push Model to Hub

In [33]:
# Merge the LoRA adapters into the base model, then push the merged model to the Hub
merged_model = trainer.model.merge_and_unload()
merged_model.push_to_hub("fine_tuned_SmolGRPO-135M_using_GRPO", private=False, tags=["GRPO"])


No files have been modified since last commit. Skipping to prevent empty commit.


CommitInfo(commit_url='https://huggingface.co/7beshoyarnest/fine_tuned_SmolGRPO-135M_using_GRPO/commit/1d8a123b78172cf6dc9b4169ed1a0326f0bece5b', commit_message='Upload LlamaForCausalLM', commit_description='', oid='1d8a123b78172cf6dc9b4169ed1a0326f0bece5b', pr_url=None, repo_url=RepoUrl('https://huggingface.co/7beshoyarnest/fine_tuned_SmolGRPO-135M_using_GRPO', endpoint='https://huggingface.co', repo_type='model', repo_id='7beshoyarnest/fine_tuned_SmolGRPO-135M_using_GRPO'), pr_revision=None, pr_num=None)

In [38]:
tokenizer.save_pretrained("fine_tuned_SmolGRPO-135M_using_GRPO")

('fine_tuned_SmolGRPO-135M_using_GRPO/tokenizer_config.json',
 'fine_tuned_SmolGRPO-135M_using_GRPO/special_tokens_map.json',
 'fine_tuned_SmolGRPO-135M_using_GRPO/chat_template.jinja',
 'fine_tuned_SmolGRPO-135M_using_GRPO/vocab.json',
 'fine_tuned_SmolGRPO-135M_using_GRPO/merges.txt',
 'fine_tuned_SmolGRPO-135M_using_GRPO/added_tokens.json',
 'fine_tuned_SmolGRPO-135M_using_GRPO/tokenizer.json')

In [39]:
tokenizer.push_to_hub("fine_tuned_SmolGRPO-135M_using_GRPO", private=False, tags=["GRPO"])

CommitInfo(commit_url='https://huggingface.co/7beshoyarnest/fine_tuned_SmolGRPO-135M_using_GRPO/commit/9b01868f9efde21e3a61c1da0d517cc89f9bdb52', commit_message='Upload tokenizer', commit_description='', oid='9b01868f9efde21e3a61c1da0d517cc89f9bdb52', pr_url=None, repo_url=RepoUrl('https://huggingface.co/7beshoyarnest/fine_tuned_SmolGRPO-135M_using_GRPO', endpoint='https://huggingface.co', repo_type='model', repo_id='7beshoyarnest/fine_tuned_SmolGRPO-135M_using_GRPO'), pr_revision=None, pr_num=None)

## Generate Text

In [40]:
prompt = """
# A long document about the Cat

The cat (Felis catus), also referred to as the domestic cat or house cat, is a small 
domesticated carnivorous mammal. It is the only domesticated species of the family Felidae.
Advances in archaeology and genetics have shown that the domestication of the cat occurred
in the Near East around 7500 BC. It is commonly kept as a pet and farm cat, but also ranges
freely as a feral cat avoiding human contact. It is valued by humans for companionship and
its ability to kill vermin. Its retractable claws are adapted to killing small prey species
such as mice and rats. It has a strong, flexible body, quick reflexes, and sharp teeth,
and its night vision and sense of smell are well developed. It is a social species,
but a solitary hunter and a crepuscular predator. Cat communication includes
vocalizations—including meowing, purring, trilling, hissing, growling, and grunting—as
well as body language. It can hear sounds too faint or too high in frequency for human ears,
such as those made by small mammals. It secretes and perceives pheromones.
"""

messages = [
    {"role": "user", "content": prompt},
]

In [44]:
# Generate text
import json

from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Load the merged model and tokenizer back from the Hub
repo_id = "7beshoyarnest/fine_tuned_SmolGRPO-135M_using_GRPO"
model = AutoModelForCausalLM.from_pretrained(repo_id)
tokenizer = AutoTokenizer.from_pretrained(repo_id)

generator = pipeline("text-generation", model=model, tokenizer=tokenizer)

generate_kwargs = {
    "max_new_tokens": 256,
    "do_sample": True,
    "temperature": 0.5,
    "min_p": 0.1,
}

generated_text = generator(messages, **generate_kwargs)

# For chat-style input, "generated_text" holds the whole conversation;
# the last message is the assistant's reply.
output_dict = {
    "user": messages,
    "assistant": generated_text[0]["generated_text"][-1]["content"],
}

# Pretty-print
print(json.dumps(output_dict, indent=4, ensure_ascii=False))

Device set to use cuda:0


{
    "user": [
        {
            "role": "user",
            "content": "\n# A long document about the Cat\n\nThe cat (Felis catus), also referred to as the domestic cat or house cat, is a small \ndomesticated carnivorous mammal. It is the only domesticated species of the family Felidae.\nAdvances in archaeology and genetics have shown that the domestication of the cat occurred\nin the Near East around 7500 BC. It is commonly kept as a pet and farm cat, but also ranges\nfreely as a feral cat avoiding human contact. It is valued by humans for companionship and\nits ability to kill vermin. Its retractable claws are adapted to killing small prey species\nsuch as mice and rats. It has a strong, flexible body, quick reflexes, and sharp teeth,\nand its night vision and sense of smell are well developed. It is a social species,\nbut a solitary hunter and a crepuscular predator. Cat communication includes\nvocalizations—including meowing, purring, trilling, hissing, growling, and grunting

In [1]:
import wandb
api = wandb.Api()
run = api.run("/beshoyarnest01-minia-university/GRPO/runs/b19l41iz")
# Show the logged metric history (returned as a pandas DataFrame)
print(run.history())

    train/completions/clipped_ratio  train/step_time  \
0                               NaN              NaN   
1                               NaN              NaN   
2                               NaN              NaN   
3                               NaN              NaN   
4                               NaN              NaN   
5                               NaN              NaN   
6                               NaN              NaN   
7                               NaN              NaN   
8                               NaN              NaN   
9                            0.9375        13.612489   
10                              NaN              NaN   
11                              NaN              NaN   
12                              NaN              NaN   
13                              NaN              NaN   
14                              NaN              NaN   
15                              NaN              NaN   
16                              NaN             