In [1]:
!pip install -qqq datasets==3.2.0 transformers==4.47.1 trl==0.14.0 peft==0.14.0 accelerate==1.2.1 bitsandbytes==0.45.2 wandb==0.19.7 --progress-bar off
!pip install -qqq flash-attn --no-build-isolation --progress-bar off

In [1]:
import torch
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import GRPOConfig, GRPOTrainer


  from .autonotebook import tqdm as notebook_tqdm


INFO 03-09 22:35:20 __init__.py:183] Automatically detected platform cuda.


2025-03-09 22:35:21,069	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


# Wandb Initialization

In [2]:
# Weights & Biases client; the training configuration later in the notebook reports to it.
import wandb

# Authenticate this session with W&B. As the cell's last expression, the
# call's return value is what the notebook displays for this cell.
wandb.login()





[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33marthur-edmond-pro[0m ([33marthur-edmond-perso[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

# Dataset

In [3]:
# Load the SmolTLDR dataset; the result is a DatasetDict keyed by split name.
dataset = load_dataset("mlabonne/smoltldr")
# A bare last expression lets the notebook display the object itself rather
# than routing it through print().
dataset

DatasetDict({
    train: Dataset({
        features: ['prompt', 'completion'],
        num_rows: 2000
    })
    validation: Dataset({
        features: ['prompt', 'completion'],
        num_rows: 200
    })
    test: Dataset({
        features: ['prompt', 'completion'],
        num_rows: 200
    })
})


In [4]:
# Single checkpoint identifier shared by the tokenizer and the model weights.
model_id = "HuggingFaceTB/SmolLM-135M-Instruct"

# Tokenizer that matches the checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Causal LM: dtype and device placement are both left on "auto", and the
# FlashAttention-2 attention implementation is requested explicitly.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    attn_implementation="flash_attention_2",
    device_map="auto",
    torch_dtype="auto",
)

# LoRA

In [5]:
# Configure LoRA adapters (rank 16, alpha 32) on all linear layers and wrap
# the base model with them.
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,
    lora_alpha=32,
    target_modules="all-linear",
)
# NOTE(review): this rebinds `model`, so re-running the cell would wrap an
# already wrapped model — presumably reload the base model first; confirm.
model = get_peft_model(model, lora_config)
# print_trainable_parameters() prints the summary itself and returns None, so
# wrapping it in print() only added a stray "None" line to the cell output.
model.print_trainable_parameters()

trainable params: 4,884,480 || all params: 139,399,488 || trainable%: 3.5039
None


# Reward function

In [6]:
# Length-based reward: completions are scored by their distance from a target length.
ideal_length = 50  # target value compared against len(completion)


def reward_len(completions, **kwargs):
    """Reward each completion by how close its length is to ``ideal_length``.

    Length is measured with ``len()``; for string completions that counts
    characters, not tokens.

    Args:
        completions: Iterable of sized items (anything ``len()`` accepts).
        **kwargs: Accepted and ignored.

    Returns:
        A list with one reward per completion, in input order. The reward is
        ``0`` when ``len(completion) == ideal_length`` and otherwise the
        negated absolute difference between the two.
    """
    rewards = []
    for completion in completions:
        distance = abs(ideal_length - len(completion))
        rewards.append(-distance)
    return rewards

In [7]:
# GRPO training configuration, grouped by concern.
training_args = GRPOConfig(
    # Run bookkeeping: output location and metric reporting.
    output_dir="GRPO",
    report_to=["wandb"],
    logging_steps=1,
    # Optimisation settings.
    learning_rate=2e-5,
    optim="adamw_8bit",
    bf16=True,
    num_train_epochs=1,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=1,
    # Prompt / completion length limits and number of generations.
    max_prompt_length=256,
    max_completion_length=48,
    num_generations=8,
    # Do not drop dataset columns before they reach the trainer.
    remove_unused_columns=False,
)

In [8]:
# Build the GRPO trainer: LoRA-wrapped model, the length reward as the only
# reward function, and the train split as training data.
trainer = GRPOTrainer(
    args=training_args,
    model=model,
    train_dataset=dataset["train"],
    reward_funcs=[reward_len],
)

# Open a W&B run in the "GRPO" project, then start training.
wandb.init(project="GRPO")
trainer.train()

  0%|          | 1/2000 [00:02<1:29:00,  2.67s/it]

{'loss': -0.0, 'grad_norm': 0.2670714557170868, 'learning_rate': 1.9990000000000003e-05, 'completion_length': 48.0, 'rewards/reward_len': -155.5, 'reward': -155.5, 'reward_std': 22.97825050354004, 'kl': 0.0, 'epoch': 0.0}


  0%|          | 2/2000 [00:04<1:05:09,  1.96s/it]

{'loss': 0.0, 'grad_norm': 0.31133097410202026, 'learning_rate': 1.9980000000000002e-05, 'completion_length': 46.75, 'rewards/reward_len': -150.5, 'reward': -150.5, 'reward_std': 25.433385848999023, 'kl': 0.0009569795802235603, 'epoch': 0.0}


  0%|          | 3/2000 [00:05<57:47,  1.74s/it]  

{'loss': 0.0, 'grad_norm': 0.3228992223739624, 'learning_rate': 1.9970000000000004e-05, 'completion_length': 44.5, 'rewards/reward_len': -129.75, 'reward': -129.75, 'reward_std': 45.74385070800781, 'kl': 0.001077850116416812, 'epoch': 0.0}


  0%|          | 4/2000 [00:07<53:43,  1.62s/it]

{'loss': 0.0001, 'grad_norm': 0.3292877972126007, 'learning_rate': 1.9960000000000002e-05, 'completion_length': 48.0, 'rewards/reward_len': -133.25, 'reward': -133.25, 'reward_std': 18.903892517089844, 'kl': 0.001446358161047101, 'epoch': 0.0}


  0%|          | 5/2000 [00:08<51:52,  1.56s/it]

{'loss': 0.0, 'grad_norm': 0.251634806394577, 'learning_rate': 1.9950000000000004e-05, 'completion_length': 47.125, 'rewards/reward_len': -133.125, 'reward': -133.125, 'reward_std': 11.825366020202637, 'kl': 0.0010869547259062529, 'epoch': 0.0}


  0%|          | 6/2000 [00:09<50:29,  1.52s/it]

{'loss': 0.0, 'grad_norm': 0.2914084494113922, 'learning_rate': 1.9940000000000002e-05, 'completion_length': 48.0, 'rewards/reward_len': -142.0, 'reward': -142.0, 'reward_std': 24.442352294921875, 'kl': 0.0008702330524101853, 'epoch': 0.0}


  0%|          | 7/2000 [00:11<49:51,  1.50s/it]

{'loss': 0.0, 'grad_norm': 0.259736567735672, 'learning_rate': 1.9930000000000004e-05, 'completion_length': 47.375, 'rewards/reward_len': -150.0, 'reward': -150.0, 'reward_std': 18.063379287719727, 'kl': 0.0011881862301379442, 'epoch': 0.0}


  0%|          | 8/2000 [00:12<49:33,  1.49s/it]

{'loss': 0.0, 'grad_norm': 0.2797452509403229, 'learning_rate': 1.9920000000000002e-05, 'completion_length': 48.0, 'rewards/reward_len': -145.375, 'reward': -145.375, 'reward_std': 20.845949172973633, 'kl': 0.0010424640495330095, 'epoch': 0.0}


  0%|          | 9/2000 [00:14<49:12,  1.48s/it]

{'loss': 0.0, 'grad_norm': 0.2508179843425751, 'learning_rate': 1.9910000000000004e-05, 'completion_length': 48.0, 'rewards/reward_len': -152.125, 'reward': -152.125, 'reward_std': 22.850053787231445, 'kl': 0.0011901361867785454, 'epoch': 0.0}


  0%|          | 10/2000 [00:15<48:54,  1.47s/it]

{'loss': 0.0001, 'grad_norm': 0.27906671166419983, 'learning_rate': 1.9900000000000003e-05, 'completion_length': 48.0, 'rewards/reward_len': -136.0, 'reward': -136.0, 'reward_std': 18.723552703857422, 'kl': 0.001254600239917636, 'epoch': 0.01}


  1%|          | 11/2000 [00:17<48:41,  1.47s/it]

{'loss': 0.0, 'grad_norm': 0.7304924726486206, 'learning_rate': 1.989e-05, 'completion_length': 43.5, 'rewards/reward_len': -124.125, 'reward': -124.125, 'reward_std': 50.410987854003906, 'kl': 0.0010604489361867309, 'epoch': 0.01}


  1%|          | 12/2000 [00:18<49:05,  1.48s/it]

{'loss': 0.0, 'grad_norm': 0.28289905190467834, 'learning_rate': 1.9880000000000003e-05, 'completion_length': 45.125, 'rewards/reward_len': -130.25, 'reward': -130.25, 'reward_std': 24.898365020751953, 'kl': 0.0011891189496964216, 'epoch': 0.01}


  1%|          | 13/2000 [00:20<50:07,  1.51s/it]

{'loss': 0.0001, 'grad_norm': 0.36354485154151917, 'learning_rate': 1.987e-05, 'completion_length': 42.5, 'rewards/reward_len': -121.0, 'reward': -121.0, 'reward_std': 44.92215347290039, 'kl': 0.0016298487316817045, 'epoch': 0.01}


  1%|          | 14/2000 [00:21<50:02,  1.51s/it]

{'loss': 0.0, 'grad_norm': 0.23169583082199097, 'learning_rate': 1.9860000000000003e-05, 'completion_length': 48.0, 'rewards/reward_len': -129.875, 'reward': -129.875, 'reward_std': 20.965532302856445, 'kl': 0.0011165993055328727, 'epoch': 0.01}


  1%|          | 15/2000 [00:23<50:06,  1.51s/it]

{'loss': 0.0001, 'grad_norm': 0.34557268023490906, 'learning_rate': 1.985e-05, 'completion_length': 43.875, 'rewards/reward_len': -136.0, 'reward': -136.0, 'reward_std': 28.460498809814453, 'kl': 0.001399874105118215, 'epoch': 0.01}


  1%|          | 16/2000 [00:24<50:01,  1.51s/it]

{'loss': 0.0, 'grad_norm': 0.3343489468097687, 'learning_rate': 1.9840000000000003e-05, 'completion_length': 44.25, 'rewards/reward_len': -115.125, 'reward': -115.125, 'reward_std': 28.19036102294922, 'kl': 0.0010823841439560056, 'epoch': 0.01}


  1%|          | 17/2000 [00:26<49:32,  1.50s/it]

{'loss': 0.0001, 'grad_norm': 0.32576078176498413, 'learning_rate': 1.983e-05, 'completion_length': 48.0, 'rewards/reward_len': -150.375, 'reward': -150.375, 'reward_std': 16.927892684936523, 'kl': 0.0013191209873184562, 'epoch': 0.01}


  1%|          | 18/2000 [00:27<49:22,  1.49s/it]

{'loss': 0.0, 'grad_norm': 0.49370667338371277, 'learning_rate': 1.982e-05, 'completion_length': 41.25, 'rewards/reward_len': -126.875, 'reward': -126.875, 'reward_std': 61.756290435791016, 'kl': 0.001200187485665083, 'epoch': 0.01}


  1%|          | 19/2000 [00:29<49:30,  1.50s/it]

{'loss': 0.0001, 'grad_norm': 0.2953610420227051, 'learning_rate': 1.9810000000000002e-05, 'completion_length': 48.0, 'rewards/reward_len': -152.125, 'reward': -152.125, 'reward_std': 20.060176849365234, 'kl': 0.0013167248107492924, 'epoch': 0.01}


  1%|          | 20/2000 [00:30<48:59,  1.48s/it]

{'loss': 0.0, 'grad_norm': 0.2836567461490631, 'learning_rate': 1.98e-05, 'completion_length': 47.625, 'rewards/reward_len': -137.25, 'reward': -137.25, 'reward_std': 27.421316146850586, 'kl': 0.001072687329724431, 'epoch': 0.01}


  1%|          | 21/2000 [00:32<49:23,  1.50s/it]

{'loss': 0.0, 'grad_norm': 0.5088856816291809, 'learning_rate': 1.9790000000000002e-05, 'completion_length': 41.75, 'rewards/reward_len': -124.625, 'reward': -124.625, 'reward_std': 59.603302001953125, 'kl': 0.001152966869994998, 'epoch': 0.01}


  1%|          | 22/2000 [00:33<50:02,  1.52s/it]

{'loss': 0.0001, 'grad_norm': 0.2587743103504181, 'learning_rate': 1.978e-05, 'completion_length': 47.625, 'rewards/reward_len': -150.375, 'reward': -150.375, 'reward_std': 21.94758415222168, 'kl': 0.0014028369914740324, 'epoch': 0.01}


  1%|          | 23/2000 [00:35<49:52,  1.51s/it]

{'loss': 0.0001, 'grad_norm': 0.27805009484291077, 'learning_rate': 1.9770000000000002e-05, 'completion_length': 47.875, 'rewards/reward_len': -152.5, 'reward': -152.5, 'reward_std': 14.382529258728027, 'kl': 0.0012945778435096145, 'epoch': 0.01}


  1%|          | 24/2000 [00:36<50:22,  1.53s/it]

{'loss': 0.0001, 'grad_norm': 0.5915516018867493, 'learning_rate': 1.976e-05, 'completion_length': 44.125, 'rewards/reward_len': -120.125, 'reward': -120.125, 'reward_std': 48.277286529541016, 'kl': 0.0019164023688063025, 'epoch': 0.01}


  1%|▏         | 25/2000 [00:38<49:58,  1.52s/it]

{'loss': 0.0, 'grad_norm': 0.26935237646102905, 'learning_rate': 1.9750000000000002e-05, 'completion_length': 48.0, 'rewards/reward_len': -148.125, 'reward': -148.125, 'reward_std': 18.519777297973633, 'kl': 0.0012107526417821646, 'epoch': 0.01}


  1%|▏         | 26/2000 [00:40<50:26,  1.53s/it]

{'loss': 0.0, 'grad_norm': 0.2393655925989151, 'learning_rate': 1.974e-05, 'completion_length': 48.0, 'rewards/reward_len': -140.0, 'reward': -140.0, 'reward_std': 16.936857223510742, 'kl': 0.0010837222216650844, 'epoch': 0.01}


  1%|▏         | 27/2000 [00:41<50:00,  1.52s/it]

{'loss': 0.0, 'grad_norm': 0.5833212733268738, 'learning_rate': 1.9730000000000003e-05, 'completion_length': 43.625, 'rewards/reward_len': -123.375, 'reward': -123.375, 'reward_std': 52.98230743408203, 'kl': 0.0011073541827499866, 'epoch': 0.01}


  1%|▏         | 28/2000 [00:43<50:35,  1.54s/it]

{'loss': 0.0001, 'grad_norm': 0.3087054193019867, 'learning_rate': 1.972e-05, 'completion_length': 47.25, 'rewards/reward_len': -151.75, 'reward': -151.75, 'reward_std': 20.679527282714844, 'kl': 0.0015345026040449739, 'epoch': 0.01}


  1%|▏         | 29/2000 [00:44<50:14,  1.53s/it]

{'loss': 0.0, 'grad_norm': 0.2516864836215973, 'learning_rate': 1.9710000000000003e-05, 'completion_length': 48.0, 'rewards/reward_len': -133.5, 'reward': -133.5, 'reward_std': 15.946338653564453, 'kl': 0.0010452307760715485, 'epoch': 0.01}


  2%|▏         | 30/2000 [00:46<50:03,  1.52s/it]

{'loss': 0.0, 'grad_norm': 0.28931716084480286, 'learning_rate': 1.97e-05, 'completion_length': 47.5, 'rewards/reward_len': -156.125, 'reward': -156.125, 'reward_std': 19.305347442626953, 'kl': 0.0011866823770105839, 'epoch': 0.01}


  2%|▏         | 31/2000 [00:47<49:13,  1.50s/it]

{'loss': 0.0001, 'grad_norm': 0.40258273482322693, 'learning_rate': 1.9690000000000003e-05, 'completion_length': 47.0, 'rewards/reward_len': -124.125, 'reward': -124.125, 'reward_std': 25.323266983032227, 'kl': 0.0013976659392938018, 'epoch': 0.02}


  2%|▏         | 32/2000 [00:49<48:52,  1.49s/it]

{'loss': 0.0001, 'grad_norm': 0.3096458911895752, 'learning_rate': 1.968e-05, 'completion_length': 48.0, 'rewards/reward_len': -135.125, 'reward': -135.125, 'reward_std': 27.471738815307617, 'kl': 0.0013370157685130835, 'epoch': 0.02}


  2%|▏         | 33/2000 [00:50<49:39,  1.51s/it]

{'loss': 0.0001, 'grad_norm': 0.33135053515434265, 'learning_rate': 1.9670000000000003e-05, 'completion_length': 48.0, 'rewards/reward_len': -151.375, 'reward': -151.375, 'reward_std': 23.5853328704834, 'kl': 0.0012880334397777915, 'epoch': 0.02}


  2%|▏         | 34/2000 [00:52<49:58,  1.53s/it]

{'loss': 0.0, 'grad_norm': 0.33512774109840393, 'learning_rate': 1.966e-05, 'completion_length': 48.0, 'rewards/reward_len': -123.75, 'reward': -123.75, 'reward_std': 27.400468826293945, 'kl': 0.0012392684584483504, 'epoch': 0.02}


  2%|▏         | 35/2000 [00:53<50:19,  1.54s/it]

{'loss': 0.0001, 'grad_norm': 0.3297624886035919, 'learning_rate': 1.9650000000000003e-05, 'completion_length': 44.5, 'rewards/reward_len': -120.875, 'reward': -120.875, 'reward_std': 49.252811431884766, 'kl': 0.0014031091704964638, 'epoch': 0.02}


  2%|▏         | 36/2000 [00:55<50:01,  1.53s/it]

{'loss': 0.0, 'grad_norm': 0.3177897334098816, 'learning_rate': 1.9640000000000002e-05, 'completion_length': 42.75, 'rewards/reward_len': -125.625, 'reward': -125.625, 'reward_std': 42.8850212097168, 'kl': 0.0011699208989739418, 'epoch': 0.02}


  2%|▏         | 37/2000 [00:56<49:31,  1.51s/it]

{'loss': 0.0, 'grad_norm': 0.2984064519405365, 'learning_rate': 1.9630000000000003e-05, 'completion_length': 48.0, 'rewards/reward_len': -121.25, 'reward': -121.25, 'reward_std': 30.922714233398438, 'kl': 0.0011416767956689, 'epoch': 0.02}


  2%|▏         | 38/2000 [00:58<49:23,  1.51s/it]

{'loss': 0.0001, 'grad_norm': 0.26913920044898987, 'learning_rate': 1.9620000000000002e-05, 'completion_length': 48.0, 'rewards/reward_len': -144.0, 'reward': -144.0, 'reward_std': 14.040757179260254, 'kl': 0.00137932482175529, 'epoch': 0.02}


  2%|▏         | 39/2000 [00:59<49:55,  1.53s/it]

{'loss': 0.0, 'grad_norm': 0.3848605751991272, 'learning_rate': 1.9610000000000004e-05, 'completion_length': 42.875, 'rewards/reward_len': -119.75, 'reward': -119.75, 'reward_std': 40.566524505615234, 'kl': 0.0012116237776353955, 'epoch': 0.02}


  2%|▏         | 40/2000 [01:01<49:48,  1.52s/it]

{'loss': 0.0001, 'grad_norm': 0.4181653559207916, 'learning_rate': 1.9600000000000002e-05, 'completion_length': 45.375, 'rewards/reward_len': -119.625, 'reward': -119.625, 'reward_std': 31.865509033203125, 'kl': 0.001401825575158, 'epoch': 0.02}


  2%|▏         | 41/2000 [01:02<49:23,  1.51s/it]

{'loss': 0.0001, 'grad_norm': 0.2860542833805084, 'learning_rate': 1.9590000000000004e-05, 'completion_length': 48.0, 'rewards/reward_len': -141.0, 'reward': -141.0, 'reward_std': 15.874507904052734, 'kl': 0.001743866829201579, 'epoch': 0.02}


  2%|▏         | 42/2000 [01:04<49:18,  1.51s/it]

{'loss': 0.0001, 'grad_norm': 0.29055169224739075, 'learning_rate': 1.9580000000000002e-05, 'completion_length': 48.0, 'rewards/reward_len': -133.25, 'reward': -133.25, 'reward_std': 25.431982040405273, 'kl': 0.0014540841802954674, 'epoch': 0.02}


  2%|▏         | 43/2000 [01:05<48:49,  1.50s/it]

{'loss': 0.0001, 'grad_norm': 0.28042784333229065, 'learning_rate': 1.957e-05, 'completion_length': 48.0, 'rewards/reward_len': -160.625, 'reward': -160.625, 'reward_std': 29.43728256225586, 'kl': 0.0013543854001909494, 'epoch': 0.02}


  2%|▏         | 44/2000 [01:07<48:23,  1.48s/it]

{'loss': 0.0001, 'grad_norm': 0.3751433789730072, 'learning_rate': 1.9560000000000002e-05, 'completion_length': 40.375, 'rewards/reward_len': -113.875, 'reward': -113.875, 'reward_std': 48.250648498535156, 'kl': 0.0021025771275162697, 'epoch': 0.02}


  2%|▏         | 45/2000 [01:08<48:29,  1.49s/it]

{'loss': 0.0001, 'grad_norm': 0.3485558331012726, 'learning_rate': 1.955e-05, 'completion_length': 41.875, 'rewards/reward_len': -125.375, 'reward': -125.375, 'reward_std': 41.795204162597656, 'kl': 0.0018714983016252518, 'epoch': 0.02}


  2%|▏         | 46/2000 [01:10<47:55,  1.47s/it]

{'loss': 0.0001, 'grad_norm': 0.3457840085029602, 'learning_rate': 1.9540000000000003e-05, 'completion_length': 39.25, 'rewards/reward_len': -110.5, 'reward': -110.5, 'reward_std': 48.87886047363281, 'kl': 0.001786672743037343, 'epoch': 0.02}


  2%|▏         | 47/2000 [01:11<48:13,  1.48s/it]

{'loss': 0.0001, 'grad_norm': 0.3952575922012329, 'learning_rate': 1.953e-05, 'completion_length': 47.75, 'rewards/reward_len': -136.5, 'reward': -136.5, 'reward_std': 52.839378356933594, 'kl': 0.0016563218086957932, 'epoch': 0.02}


  2%|▏         | 48/2000 [01:13<47:41,  1.47s/it]

{'loss': 0.0001, 'grad_norm': 0.26238685846328735, 'learning_rate': 1.9520000000000003e-05, 'completion_length': 48.0, 'rewards/reward_len': -142.25, 'reward': -142.25, 'reward_std': 18.07721519470215, 'kl': 0.0016497017350047827, 'epoch': 0.02}


  2%|▏         | 49/2000 [01:14<50:12,  1.54s/it]

{'loss': 0.0001, 'grad_norm': 0.2996790409088135, 'learning_rate': 1.951e-05, 'completion_length': 47.375, 'rewards/reward_len': -116.75, 'reward': -116.75, 'reward_std': 23.9925594329834, 'kl': 0.001653410610742867, 'epoch': 0.02}


  2%|▎         | 50/2000 [01:16<51:26,  1.58s/it]

{'loss': 0.0001, 'grad_norm': 0.29553118348121643, 'learning_rate': 1.95e-05, 'completion_length': 46.5, 'rewards/reward_len': -126.0, 'reward': -126.0, 'reward_std': 25.900909423828125, 'kl': 0.0014916157815605402, 'epoch': 0.03}


  3%|▎         | 51/2000 [01:18<51:44,  1.59s/it]

{'loss': 0.0001, 'grad_norm': 0.30777648091316223, 'learning_rate': 1.949e-05, 'completion_length': 45.875, 'rewards/reward_len': -118.25, 'reward': -118.25, 'reward_std': 26.794189453125, 'kl': 0.0015456527471542358, 'epoch': 0.03}


  3%|▎         | 52/2000 [01:19<50:39,  1.56s/it]

{'loss': 0.0001, 'grad_norm': 0.2704229950904846, 'learning_rate': 1.948e-05, 'completion_length': 48.0, 'rewards/reward_len': -149.875, 'reward': -149.875, 'reward_std': 17.208282470703125, 'kl': 0.001584833487868309, 'epoch': 0.03}


  3%|▎         | 53/2000 [01:21<49:52,  1.54s/it]

{'loss': 0.0001, 'grad_norm': 0.338528573513031, 'learning_rate': 1.947e-05, 'completion_length': 44.25, 'rewards/reward_len': -116.375, 'reward': -116.375, 'reward_std': 52.18904495239258, 'kl': 0.0023924086708575487, 'epoch': 0.03}


  3%|▎         | 54/2000 [01:22<49:09,  1.52s/it]

{'loss': 0.0001, 'grad_norm': 0.28731563687324524, 'learning_rate': 1.946e-05, 'completion_length': 48.0, 'rewards/reward_len': -136.0, 'reward': -136.0, 'reward_std': 25.812511444091797, 'kl': 0.001823282684199512, 'epoch': 0.03}


  3%|▎         | 55/2000 [01:24<49:02,  1.51s/it]

{'loss': 0.0001, 'grad_norm': 0.3951714038848877, 'learning_rate': 1.9450000000000002e-05, 'completion_length': 46.25, 'rewards/reward_len': -151.125, 'reward': -151.125, 'reward_std': 32.717132568359375, 'kl': 0.0019224490970373154, 'epoch': 0.03}


  3%|▎         | 56/2000 [01:25<48:34,  1.50s/it]

{'loss': 0.0001, 'grad_norm': 0.40422236919403076, 'learning_rate': 1.944e-05, 'completion_length': 40.25, 'rewards/reward_len': -111.0, 'reward': -111.0, 'reward_std': 43.7525520324707, 'kl': 0.0026957858353853226, 'epoch': 0.03}


  3%|▎         | 57/2000 [01:26<48:25,  1.50s/it]

{'loss': 0.0001, 'grad_norm': 0.31703513860702515, 'learning_rate': 1.9430000000000002e-05, 'completion_length': 47.0, 'rewards/reward_len': -131.375, 'reward': -131.375, 'reward_std': 21.070884704589844, 'kl': 0.0016972852172330022, 'epoch': 0.03}


  3%|▎         | 58/2000 [01:28<47:48,  1.48s/it]

{'loss': 0.0001, 'grad_norm': 0.38841748237609863, 'learning_rate': 1.942e-05, 'completion_length': 44.25, 'rewards/reward_len': -136.125, 'reward': -136.125, 'reward_std': 45.90498733520508, 'kl': 0.0019454529974609613, 'epoch': 0.03}


  3%|▎         | 59/2000 [01:29<47:29,  1.47s/it]

{'loss': 0.0001, 'grad_norm': 0.32227474451065063, 'learning_rate': 1.9410000000000002e-05, 'completion_length': 43.75, 'rewards/reward_len': -129.5, 'reward': -129.5, 'reward_std': 30.46778106689453, 'kl': 0.0025641212705522776, 'epoch': 0.03}


  3%|▎         | 60/2000 [01:31<47:23,  1.47s/it]

{'loss': 0.0001, 'grad_norm': 0.37290868163108826, 'learning_rate': 1.94e-05, 'completion_length': 48.0, 'rewards/reward_len': -130.875, 'reward': -130.875, 'reward_std': 17.299360275268555, 'kl': 0.0022357283160090446, 'epoch': 0.03}


  3%|▎         | 61/2000 [01:32<47:37,  1.47s/it]

{'loss': 0.0001, 'grad_norm': 0.36565685272216797, 'learning_rate': 1.9390000000000002e-05, 'completion_length': 45.375, 'rewards/reward_len': -99.75, 'reward': -99.75, 'reward_std': 31.869152069091797, 'kl': 0.0019096725154668093, 'epoch': 0.03}


  3%|▎         | 62/2000 [01:34<47:57,  1.48s/it]

{'loss': 0.0001, 'grad_norm': 0.5324435234069824, 'learning_rate': 1.938e-05, 'completion_length': 44.25, 'rewards/reward_len': -129.375, 'reward': -129.375, 'reward_std': 47.9581298828125, 'kl': 0.002472899155691266, 'epoch': 0.03}


  3%|▎         | 63/2000 [01:35<47:53,  1.48s/it]

{'loss': 0.0001, 'grad_norm': 0.4082697331905365, 'learning_rate': 1.9370000000000003e-05, 'completion_length': 48.0, 'rewards/reward_len': -123.75, 'reward': -123.75, 'reward_std': 49.23921203613281, 'kl': 0.0019633937627077103, 'epoch': 0.03}


  3%|▎         | 64/2000 [01:37<47:49,  1.48s/it]

{'loss': 0.0001, 'grad_norm': 0.32216012477874756, 'learning_rate': 1.936e-05, 'completion_length': 47.125, 'rewards/reward_len': -129.5, 'reward': -129.5, 'reward_std': 17.221250534057617, 'kl': 0.0017591440118849277, 'epoch': 0.03}


  3%|▎         | 65/2000 [01:38<47:15,  1.47s/it]

{'loss': 0.0001, 'grad_norm': 0.47261181473731995, 'learning_rate': 1.9350000000000003e-05, 'completion_length': 40.5, 'rewards/reward_len': -79.625, 'reward': -79.625, 'reward_std': 49.64714813232422, 'kl': 0.002635023556649685, 'epoch': 0.03}


  3%|▎         | 66/2000 [01:40<47:26,  1.47s/it]

{'loss': 0.0001, 'grad_norm': 0.367867112159729, 'learning_rate': 1.934e-05, 'completion_length': 45.25, 'rewards/reward_len': -121.75, 'reward': -121.75, 'reward_std': 43.676570892333984, 'kl': 0.0034080822952091694, 'epoch': 0.03}


KeyboardInterrupt: 

## Save the model

In [10]:
# Hub repository that receives both the tokenizer and the merged model.
repo_id = "SmolGRPO-135M"

# Upload the tokenizer too: pushing only the weights leaves the repository
# without tokenizer files.
tokenizer.push_to_hub(repo_id, private=False)

# NOTE(review): the training cell's recorded output ends in KeyboardInterrupt,
# so this uploads a partially trained model — confirm that is intended.
# Fold the LoRA adapters into the base weights so a plain model is uploaded.
merged_model = trainer.model.merge_and_unload()
merged_model.push_to_hub(
    repo_id, private=False, tags=["GRPO", "Reasoning-Course"]
)

model.safetensors: 100%|██████████| 269M/269M [01:13<00:00, 3.64MB/s]


CommitInfo(commit_url='https://huggingface.co/Shumatsurontek/SmolGRPO-135M/commit/a8eb7466fd6fe16ea9e20346c375a9f42473be2d', commit_message='Upload LlamaForCausalLM', commit_description='', oid='a8eb7466fd6fe16ea9e20346c375a9f42473be2d', pr_url=None, repo_url=RepoUrl('https://huggingface.co/Shumatsurontek/SmolGRPO-135M', endpoint='https://huggingface.co', repo_type='model', repo_id='Shumatsurontek/SmolGRPO-135M'), pr_revision=None, pr_num=None)