<a href="https://colab.research.google.com/github/anthonysauer/colab-playground/blob/main/nb/Llama3.1_(8B)-mcts-ttt-GRPO.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

To run this, press "*Runtime*" and press "*Run all*" on a **free** Tesla T4 Google Colab instance!
<div class="align-center">
<a href="https://unsloth.ai/"><img src="https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png" width="115"></a>
<a href="https://discord.gg/unsloth"><img src="https://github.com/unslothai/unsloth/raw/main/images/Discord button.png" width="145"></a>
<a href="https://docs.unsloth.ai/"><img src="https://github.com/unslothai/unsloth/blob/main/images/documentation%20green%20button.png?raw=true" width="125"></a> Join Discord if you need help + ⭐ <i>Star us on <a href="https://github.com/unslothai/unsloth">GitHub</a> </i> ⭐
</div>

To install Unsloth on your own computer, follow the [installation instructions](https://docs.unsloth.ai/get-started/installing-+-updating) in our documentation.

You will learn how to do [data prep](#Data), how to [train](#Train), how to [run the model](#Inference), and [how to save it](#Save).


### Installation

In [1]:
%%capture
import os
!pip install open_spiel
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth vllm
else:
    # [NOTE] Do the below ONLY in Colab! Elsewhere, use `pip install unsloth vllm`
    !pip install --no-deps unsloth vllm

In [2]:
#@title Colab Extra Install { display-mode: "form" }
%%capture
import os
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth vllm
else:
    # [NOTE] Do the below ONLY in Colab! Elsewhere, use `pip install unsloth vllm`
    !pip install --no-deps unsloth vllm
    # Skip restarting message in Colab
    import sys, re, requests; modules = list(sys.modules.keys())
    for x in modules: sys.modules.pop(x) if "PIL" in x or "google" in x else None
    !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft "trl==0.15.2" triton cut_cross_entropy unsloth_zoo
    !pip install sentencepiece protobuf "datasets>=3.4.1" huggingface_hub hf_transfer

    # vLLM requirements - vLLM breaks Colab due to reinstalling numpy
    f = requests.get("https://raw.githubusercontent.com/vllm-project/vllm/refs/heads/main/requirements/common.txt").content
    with open("vllm_requirements.txt", "wb") as file:
        file.write(re.sub(rb"(transformers|numpy|xformers)[^\n]{1,}\n", b"", f))
    !pip install -r vllm_requirements.txt

### Unsloth

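Load `Llama 3.1 8B` and set the parameters. The active cell loads the base `unsloth/Meta-Llama-3.1-8B` checkpoint; the commented-out alternative loads the Instruct variant with vLLM fast inference.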

In [3]:
from unsloth import FastLanguageModel
import torch
max_seq_length = 2048 # Can increase for longer reasoning traces
lora_rank = 32 # Larger rank = smarter, but slower

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Meta-Llama-3.1-8B",
    max_seq_length = max_seq_length,
    dtype = None, # None for auto detection.
    load_in_4bit = True, # False for LoRA 16bit
    max_lora_rank = lora_rank,
    # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
)

# model, tokenizer = FastLanguageModel.from_pretrained(
#     model_name = "meta-llama/meta-Llama-3.1-8B-Instruct",
#     max_seq_length = max_seq_length,
#     load_in_4bit = True, # False for LoRA 16bit
#     fast_inference = True, # Enable vLLM fast inference
#     max_lora_rank = lora_rank,
#     gpu_memory_utilization = 0.6, # Reduce if out of memory
# )

model = FastLanguageModel.get_peft_model(
    model,
    r = lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],
    lora_alpha = lora_rank,
    lora_dropout = 0, # Supports any, but = 0 is optimized
    bias = "none",    # Supports any, but = "none" is optimized
    # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
    use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
    random_state = 3407,
    use_rslora = False,  # We support rank stabilized LoRA
    loftq_config = None, # And LoftQ
)

# model = FastLanguageModel.get_peft_model(
#     model,
#     r = lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
#     target_modules = [
#         "q_proj", "k_proj", "v_proj", "o_proj",
#         "gate_proj", "up_proj", "down_proj",
#     ], # Remove QKVO if out of memory
#     lora_alpha = lora_rank,
#     use_gradient_checkpointing = "unsloth", # Enable long context finetuning
#     random_state = 3407,
# )

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
INFO 05-27 19:23:06 [importing.py:53] Triton module has been replaced with a placeholder.
INFO 05-27 19:23:06 [__init__.py:239] Automatically detected platform cuda.
==((====))==  Unsloth 2025.5.7: Fast Llama patching. Transformers: 4.52.2. vLLM: 0.8.5.post1.
   \\   /|    Tesla T4. Num GPUs = 1. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 7.5. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.29.post3. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


Unsloth 2025.5.7 patched 32 layers with 32 QKV layers, 32 O layers and 32 MLP layers.
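
The size of the adapter configured by `get_peft_model` can be estimated by hand. The sketch below assumes the publicly documented Llama 3.1 8B dimensions (hidden size 4096, 8 key/value heads of dimension 128, MLP width 14336, 32 decoder layers), which are not stated in this notebook, and counts `r * (in_features + out_features)` parameters for each adapted projection.

```python
# Assumed Llama 3.1 8B dimensions (from the public model config, not from this notebook).
hidden, kv_dim, mlp, layers = 4096, 8 * 128, 14336, 32
rank = 32  # lora_rank used above

# (in_features, out_features) of each adapted projection in one decoder layer.
shapes = {
    "q_proj": (hidden, hidden),
    "k_proj": (hidden, kv_dim),
    "v_proj": (hidden, kv_dim),
    "o_proj": (hidden, hidden),
    "gate_proj": (hidden, mlp),
    "up_proj": (hidden, mlp),
    "down_proj": (mlp, hidden),
}

# A LoRA adapter adds an (in x r) and an (r x out) matrix per projection.
per_layer = sum(rank * (n_in + n_out) for n_in, n_out in shapes.values())
total = per_layer * layers
print(f"{total:,}")
```

Under these assumptions the total agrees with the `Trainable parameters = 83,886,080` line that the trainer logs later in this notebook.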


### Data Prep
<a name="Data"></a>

Helper data prep functions and prompt formats to use in both SFT and GRPO stages

In [4]:
import re
from datasets import load_dataset, Dataset

# Load and prep dataset
SYSTEM_PROMPT = """
You are a powerful gaming agent who can make proper decisions to beat the user in gaming tasks.
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

XML_COT_FORMAT = """\
<reasoning>
{}
</reasoning>
<answer>
{}
</answer>
"""

def extract_xml_reasoning(text: str) -> str:
    reasoning = text.split("<reasoning>")[-1]
    reasoning = reasoning.split("</reasoning>")[0]
    return reasoning.strip()

def extract_xml_answer(text: str) -> str:
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()
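
As a quick sanity check of the helpers above, the sketch below formats a made-up reasoning/answer pair with the same template and extracts both parts again. The helpers are repeated so the snippet runs on its own; the sample strings are invented for illustration and are not taken from the dataset.

```python
XML_COT_FORMAT = """\
<reasoning>
{}
</reasoning>
<answer>
{}
</answer>
"""

def extract_xml_reasoning(text: str) -> str:
    # Keep what follows the last <reasoning> tag, up to the closing tag.
    reasoning = text.split("<reasoning>")[-1]
    reasoning = reasoning.split("</reasoning>")[0]
    return reasoning.strip()

def extract_xml_answer(text: str) -> str:
    # Keep what follows the last <answer> tag, up to the closing tag.
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()

# Invented sample values.
sample = XML_COT_FORMAT.format("Start state:\n...\n...\n...", "x(1,1)")
reasoning = extract_xml_reasoning(sample)
answer = extract_xml_answer(sample)
print(repr(reasoning))
print(repr(answer))

# Without any tags, the helpers fall back to the whole stripped text.
untagged = extract_xml_answer("  x(0,2)  ")
print(repr(untagged))
```

Note the fallback in the last call: a completion that omits the tags entirely is returned whole, so downstream checks should not treat a non-empty result as proof that the tags were present.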


### SFT Data Prep
<a name="SFT Data"></a>

Prepare SFT training data

In [5]:
from unsloth.chat_templates import get_chat_template

tokenizer = get_chat_template(
    tokenizer,
    chat_template="chatml",
)

def sft_formatting_prompts_func(examples):
  messages = [
       [
          {"role": "system", "content": SYSTEM_PROMPT},
          {"role": "user", "content": prompt},
          {"role": "assistant", "content": XML_COT_FORMAT.format(reasoning, answer)}
       ]
       for prompt, reasoning, answer in zip(examples["prompt"], examples["reasoning"], examples["answer"])
  ]
  # print(messages)
  texts = [tokenizer.apply_chat_template(message, tokenize = False, add_generation_prompt = False) for message in messages]
  return { "text" : texts, }

sft_dataset = load_dataset("json", data_files="mcts_ttt_train_sft.jsonl", split = "train")
sft_dataset = sft_dataset.map(sft_formatting_prompts_func, batched = True,)
print(sft_dataset[0])

Unsloth: Will map <|im_end|> to EOS = <|end_of_text|>.


Generating train split: 0 examples [00:00, ? examples/s]

Map:   0%|          | 0/857 [00:00<?, ? examples/s]

{'current_moves': ['x(2,2)', 'o(2,1)', 'x(1,0)'], 'player': 'o', 'prompt': "Tic Tac Toe is a two-player game played on a grid. Players take turns marking a space with their respective symbols. The goal is to get 3 of one's own symbols in a row, either horizontally, vertically, or diagonally, before the opponent does. If all nine squares are filled and no player has three in a row, the game is a draw. \nEach move is represented by a string consisting of two parts: the current player ('x' or 'o') and the coordinate of their move (column, row), in that order. For instance, o(0,2) means that 'o' moves at the first column and the third row of the grid. \nThe current move sequence is: x(2,2), o(2,1), x(1,0). You, player o, will move next. \nThink about your current situation, then choose the best next move by exploring the search space. \nYour output must be in the following format strictly:\n<reasoning>\nThe search space trace.\n</reasoning>\n<answer>The best move for player o (you), i.e., 
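
For orientation, the sketch below approximates what the `chatml` template produces for a list of messages. `render_chatml` is a hypothetical stand-in written for this illustration; the real text comes from `tokenizer.apply_chat_template`, which may differ in details (for example, the tokenizer also adds a beginning-of-text token, as the decoded output later in this notebook shows).

```python
def render_chatml(messages, add_generation_prompt=False):
    """Approximate ChatML rendering: one <|im_start|>role ... <|im_end|> block per message."""
    parts = [
        f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n"
        for m in messages
    ]
    if add_generation_prompt:
        # Open the assistant turn so generation starts inside it.
        parts.append("<|im_start|>assistant\n")
    return "".join(parts)

# Invented messages for illustration.
messages = [
    {"role": "system", "content": "Respond in XML."},
    {"role": "user", "content": "The current move sequence is: x(1,1)."},
]
full_text = render_chatml(messages)
prompt_text = render_chatml(messages, add_generation_prompt=True)
print(prompt_text)
```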

### SFT Training
<a name="SFT Train"></a>

Use `SFTTrainer` to train the model to imitate the desired MCTS format

In [None]:
from trl import SFTTrainer
from transformers import TrainingArguments
from unsloth import is_bfloat16_supported

sft_trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = sft_dataset,
    dataset_text_field = "text",
    max_seq_length = max_seq_length,
    dataset_num_proc = 2,
    packing = False, # Can make training 5x faster for short sequences.
    args = TrainingArguments(
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 4,
        warmup_steps = 5,
        num_train_epochs = 1, # Set this for 1 full training run.
        # max_steps = 60,
        learning_rate = 2e-4,
        fp16 = not is_bfloat16_supported(),
        bf16 = is_bfloat16_supported(),
        logging_steps = 1,
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "linear",
        seed = 3407,
        output_dir = "sft_outputs",
        report_to = "none", # Use this for WandB etc
    ),
)

Unsloth: Tokenizing ["text"] (num_proc=2):   0%|          | 0/857 [00:00<?, ? examples/s]

In [None]:
trainer_stats = sft_trainer.train()

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 857 | Num Epochs = 1 | Total steps = 107
O^O/ \_/ \    Batch size per device = 2 | Gradient accumulation steps = 4
\        /    Data Parallel GPUs = 1 | Total batch size (2 x 4 x 1) = 8
 "-____-"     Trainable parameters = 83,886,080/8,000,000,000 (1.05% trained)


Unsloth: Will smartly offload gradients to save VRAM!


Step,Training Loss
1,1.1453
2,1.2177
3,1.0903
4,0.8192
5,0.7737
6,0.6636
7,0.5241
8,0.3941
9,0.3479
10,0.2549
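
The `Total steps = 107` in the log above follows from the dataset size and batch settings. The arithmetic below is a sketch that assumes the trainer counts optimizer steps as the number of per-device batches floor-divided by the gradient accumulation steps. That assumption reproduces the logged value, but other library versions may round differently.

```python
import math

num_examples = 857        # "Num examples" in the log
per_device_batch = 2      # per_device_train_batch_size
grad_accum = 4            # gradient_accumulation_steps

# Number of mini-batches in one epoch (the last one may be partial).
batches = math.ceil(num_examples / per_device_batch)

# Assumed rule: one optimizer step per full group of grad_accum batches.
steps = batches // grad_accum
effective_batch = per_device_batch * grad_accum

print(batches, steps, effective_batch)
```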


<a name="Save"></a>
### Saving, loading SFT models
To save the final model as LoRA adapters, use either Hugging Face's `push_to_hub` for an online save or `save_pretrained` for a local save.

**[NOTE]** This ONLY saves the LoRA adapters, and not the full model.

In [None]:
model.save_pretrained("sft_lora_model")  # Local saving
tokenizer.save_pretrained("sft_lora_model")
# model.push_to_hub("your_name/lora_model", token = "...") # Online saving
# tokenizer.push_to_hub("your_name/lora_model", token = "...") # Online saving

import shutil

shutil.make_archive("mcts_ttt_sft_lora", "zip", "sft_lora_model")

'/content/mcts_ttt_sft_lora.zip'
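
The zip round trip used here can be tried without a trained model. The sketch below uses a temporary directory and a placeholder file (both invented for the illustration) to show that `shutil.make_archive` returns the archive path and that `shutil.unpack_archive` restores the directory contents.

```python
import shutil
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    # Stand-in for the saved adapter directory, with a placeholder file.
    src = Path(tmp) / "sft_lora_model"
    src.mkdir()
    (src / "adapter_config.json").write_text("{}")

    # make_archive appends the ".zip" extension and returns the full path.
    archive = shutil.make_archive(str(Path(tmp) / "demo_lora"), "zip", root_dir=str(src))
    archive_name = Path(archive).name

    # Unpack into a fresh directory and list what came back.
    dest = Path(tmp) / "restored"
    shutil.unpack_archive(archive, str(dest))
    restored = sorted(p.name for p in dest.iterdir())

print(archive_name, restored)
```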

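To load the LoRA adapters we just saved for inference, run the cell below. Both `if` guards are already set to `True`; change one to `False` to skip unpacking the archive or reloading the model.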

In [6]:
if True: # if saved in .zip file
    import shutil

    shutil.unpack_archive("mcts_ttt_sft_lora.zip", "sft_lora_model")

if True:
    from unsloth import FastLanguageModel
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name = "sft_lora_model", # YOUR MODEL YOU USED FOR TRAINING
        max_seq_length = max_seq_length,
        dtype = None,
        load_in_4bit = True,
    )
    # FastLanguageModel.for_inference(model) # Enable native 2x faster inference

==((====))==  Unsloth 2025.5.7: Fast Llama patching. Transformers: 4.52.2. vLLM: 0.8.5.post1.
   \\   /|    Tesla T4. Num GPUs = 1. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 7.5. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.29.post3. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


<a name="SFT Inference"></a>
### SFT Inference
Test the SFT-trained model. The main goal at this stage is for the output to follow the MCTS format, even if the search trace itself is incorrect or inconsistent.

In [7]:
test_dataset = load_dataset("json", data_files="mcts_ttt_test.jsonl", split="train")
test_sample = test_dataset[9]

chatml_prompt = [
    {"role": "system", "content": SYSTEM_PROMPT},
    {"role": "user", "content": test_sample["prompt"]},
    {"role": "assistant", "content": XML_COT_FORMAT.format(test_sample["reasoning"], test_sample["answer"])}
]

text = tokenizer.apply_chat_template(chatml_prompt, tokenize = False, add_generation_prompt = False)
# print(text)

inputs = tokenizer(
[
    text[:text.find("<|im_start|>assistant")]
], return_tensors = "pt").to("cuda")

outputs = model.generate(**inputs, max_new_tokens = 2048, use_cache = True)
print(tokenizer.batch_decode(outputs)[0])

Generating train split: 0 examples [00:00, ? examples/s]

<|begin_of_text|><|im_start|>system

You are a powerful gaming agent who can make proper decisions to beat the user in gaming tasks.
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
<|im_end|>
<|im_start|>user
Tic Tac Toe is a two-player game played on a grid. Players take turns marking a space with their respective symbols. The goal is to get 3 of one's own symbols in a row, either horizontally, vertically, or diagonally, before the opponent does. If all nine squares are filled and no player has three in a row, the game is a draw. 
Each move is represented by a string consisting of two parts: the current player ('x' or 'o') and the coordinate of their move (column, row), in that order. For instance, o(0,2) means that 'o' moves at the first column and the third row of the grid. 
The current move sequence is: x(1,2), o(0,0), x(1,1), o(1,0), x(2,1). You, player o, will move next. 
Think about your current situation, then choose the best next move by ex
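
The inference cell keeps only the part of the rendered conversation that precedes the assistant turn. A small sketch of that slicing step on plain strings follows; `prompt_only` is a hypothetical helper name. It also guards against the marker being absent, because `str.find` returns `-1` in that case and slicing with `-1` would silently drop the last character.

```python
ASSISTANT_MARKER = "<|im_start|>assistant"

def prompt_only(text: str) -> str:
    """Return the text before the assistant turn, or the text unchanged if there is none."""
    idx = text.find(ASSISTANT_MARKER)
    return text if idx == -1 else text[:idx]

# Invented conversation text for illustration.
full = (
    "<|im_start|>user\nYour move.<|im_end|>\n"
    "<|im_start|>assistant\n<answer>\no(2,2)\n</answer>\n<|im_end|>\n"
)
prompt = prompt_only(full)
no_marker = prompt_only("no assistant turn here")

# What an unguarded slice would do when the marker is missing.
unguarded = "abc"["abc".find(ASSISTANT_MARKER):] if False else "abc"[:"abc".find(ASSISTANT_MARKER)]
print(repr(prompt), repr(no_marker), repr(unguarded))
```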

### GRPO Data Prep
<a name="GRPO Data"></a>

Prepare GRPO training data

In [8]:
def grpo_formatting_prompts_func(examples):
  messages = [
       [
          {"role": "system", "content": SYSTEM_PROMPT},
          {"role": "user", "content": prompt},
       ]
       for prompt in examples["prompt"]
  ]
  # print(messages)
  texts = [tokenizer.apply_chat_template(message, tokenize = False, add_generation_prompt = False) for message in messages]
  return { "prompt" : texts, }

def get_mcts_ttt_grpo(stage="grpo", split="train") -> Dataset:
    path = "mcts_ttt_test.jsonl"
    if split == "train":
      path = "mcts_ttt_train_" + stage + ".jsonl"
    data = load_dataset("json", data_files=path, split = "train")
    data = data.map(grpo_formatting_prompts_func, batched = True,)
    return data # type: ignore

grpo_dataset = get_mcts_ttt_grpo()
print(grpo_dataset[0])

Generating train split: 0 examples [00:00, ? examples/s]

Map:   0%|          | 0/3001 [00:00<?, ? examples/s]

{'current_moves': ['x(0,0)', 'o(0,1)', 'x(0,2)', 'o(1,0)', 'x(2,1)'], 'player': 'o', 'prompt': "<|im_start|>system\n\nYou are a powerful gaming agent who can make proper decisions to beat the user in gaming tasks.\nRespond in the following format:\n<reasoning>\n...\n</reasoning>\n<answer>\n...\n</answer>\n<|im_end|>\n<|im_start|>user\nTic Tac Toe is a two-player game played on a grid. Players take turns marking a space with their respective symbols. The goal is to get 3 of one's own symbols in a row, either horizontally, vertically, or diagonally, before the opponent does. If all nine squares are filled and no player has three in a row, the game is a draw. \nEach move is represented by a string consisting of two parts: the current player ('x' or 'o') and the coordinate of their move (column, row), in that order. For instance, o(0,2) means that 'o' moves at the first column and the third row of the grid. \nThe current move sequence is: x(0,0), o(0,1), x(0,2), o(1,0), x(2,1). You, player

Check the maximum number of tokens needed for the prompt and for the entire sequence.

(These must fit within `max_prompt_length` and `max_seq_length`, respectively, when training.)

In [9]:
max_prompt = 0
max_seq = 0
for example in grpo_dataset:
    curr_prompt = len(tokenizer(example["prompt"]).input_ids)
    if curr_prompt > max_prompt:
        max_prompt = curr_prompt

    curr_seq = len(tokenizer(example["prompt"] + XML_COT_FORMAT.format(example["reasoning"], example["answer"])).input_ids)
    if curr_seq > max_seq:
        max_seq = curr_seq

print("Max prompt tokens: " + str(max_prompt))
print("Max sequence tokens: " + str(max_seq))

Max prompt tokens: 325
Max sequence tokens: 1880
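
These measurements can be turned into length settings for training. The sketch below is only an illustration: the variable names mirror the `max_prompt_length` and `max_completion_length` fields of TRL's `GRPOConfig`, but the values chosen here are an assumption, not the settings this notebook necessarily uses.

```python
max_seq_length = 2048        # set when the model was loaded
max_prompt_tokens = 325      # "Max prompt tokens" measured above
max_sequence_tokens = 1880   # "Max sequence tokens" measured above

# Give the prompt exactly what the longest prompt needs...
max_prompt_length = max_prompt_tokens
# ...and leave the rest of the context window for the completion.
max_completion_length = max_seq_length - max_prompt_length

# Rough length of the longest reference completion (tokenization of a
# concatenation is not exactly additive, so treat this as an estimate).
longest_reference_completion = max_sequence_tokens - max_prompt_tokens

print(max_prompt_length, max_completion_length, longest_reference_completion)
```

With these numbers the completion budget exceeds the longest reference trace, so a response of the same shape as the training data should not be cut off.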


<a name="Reward functions"></a>
### Set up reward functions

In [10]:
import numpy as np
import pyspiel

#### Formatting reward functions

In [11]:
# Strict format of overall response
def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format."""
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    # re.DOTALL lets ".*?" span the multi-line reasoning and answer bodies
    matches = [re.match(pattern, c, flags=re.DOTALL) for c in completions]
    return [0.5 if match else 0.0 for match in matches]


# Soft format of overall response
def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format."""
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    # re.DOTALL lets ".*?" span the multi-line reasoning and answer bodies
    matches = [re.match(pattern, c, flags=re.DOTALL) for c in completions]
    return [0.5 if match else 0.0 for match in matches]


def count_xml(text) -> float:
    count = 0.0
    if text.count("<reasoning>\n") == 1:
        count += 0.125
    if text.count("\n</reasoning>\n") == 1:
        count += 0.125
    if text.count("\n<answer>\n") == 1:
        count += 0.125
        count -= len(text.split("\n</answer>\n")[-1])*0.001
    if text.count("\n</answer>") == 1:
        count += 0.125
        count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001
    return count


# Correct number of xml tags
def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    return [count_xml(c) for c in completions]


# Strict reasoning formatting -- nearly exact match with training data format
def strict_reasoning_format_reward_func(completions, **kwargs) -> list[float]:
    extracted_reasonings = [extract_xml_reasoning(c) for c in completions]
    pattern = r"^Start state:\s*\n([xo.]{3}\n){3}Evaluating state with random playout:\s*\n([xo]\([0-2],[0-2]\)\n([xo.]{3}\n){3})+(([xo] wins)|(draw)) in random playout\nUpdating move rewards:\s*\n\n(Exploring move sequence: \[('[xo]\([0-2],[0-2]\)'(, )?)+\]\nResulting state:\s*\n([xo.]{3}\n){3}((draw\n)|([xo] wins\n))?(Evaluating state with random playout:\s*\n([xo]\([0-2],[0-2]\)\n([xo.]{3}\n){3})+(([xo] wins)|(draw)) in random playout\n)?Updating move rewards:\s*\n((\[('[xo]\([0-2],[0-2]\)'(, )?)+\]: -?\d.\d \+ -?\d.\d = -?\d.\d\n)|(All possible moves starting from move sequence \[('[xo]\([0-2],[0-2]\)'(, )?)+\] have been solved\. This move sequence has a maximum reward of -?\d\.\d for x\n)|(\[('[xo]\([0-2],[0-2]\)'(, )?)+\] is a winning move sequence for [xo]\n))+\n)+(((Explored moves at best resulted in a draw for [xo].\nChoosing move that results in draw that was explored the most: [xo]\([0-2],[0-2]\)\n)|(Choosing move that is proven to win for [xo]: [xo]\([0-2],[0-2]\)\n))|(No explored moves are a winning move for [xo].\nChoosing move with highest reward that is not a losing move for [xo]: [xo]\([0-2],[0-2]\)\n))$"
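    # extract_xml_reasoning() strips the trailing newline that the pattern expects, so add it back
    return [1.0 if re.match(pattern, r + "\n") else 0.0 for r in extracted_reasonings]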


# Soft reasoning formatting
def soft_reasoning_format_reward_func(completions, **kwargs) -> list[float]:
    extracted_reasonings = [extract_xml_reasoning(c) for c in completions]
    pattern = r"^[sS]tart.*\n([xo.]{3}\n){3}.*playout.*\n([xo]\([0-2]\s*,[0-2]\)\n([xo.]{3}\n){3})+(([xo] wins)|(draw)).*\n[uU]pdating.*\n\n?([Ee]xploring.*\[('?[xo]\([0-2]\s*,[0-2]\)'?(,\s*)?)+\]\n[rR]esult.*\n([xo.]{3}\n){3}((draw\n)|([xo] wins?\n))?(.*playout.*\n([xo]\([0-2]\s*,[0-2]\)\n([xo.]{3}\n){3})+(([xo] wins)|(draw)).*\n)?[uU]pdating.*\n((\[('?[xo]\([0-2]\s*,[0-2]\)'?(,\s*)?)+\]:\s*?-?\d.\d \+ -?\d.\d = -?\d.\d\s*\n)|([aA]ll possible moves.*-?\d\.\d.*\n)|(\[('?[xo]\([0-2]\s*,[0-2]\)'?(,\s*)?)+\].*winning move sequence for [xo]\s*\n))+\n)+.*\n?[cC]hoosing.*[xo]\([0-2]\s*,[0-2]\)*.\n$"
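    # extract_xml_reasoning() strips the trailing newline that the pattern expects, so add it back
    return [1.0 if re.match(pattern, r + "\n") else 0.0 for r in extracted_reasonings]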


# Final move is in the strict format, e.g. x(0,1)
def move_format_reward_func(completions, **kwargs) -> list[float]:
    extracted_responses = [extract_xml_answer(c) for c in completions]
    # Python takes regex flags as arguments; a trailing "/gm" would be matched literally
    pattern = r"^[xo]\([0-2],[0-2]\)$"
    return [0.5 if re.match(pattern, r) else 0.0 for r in extracted_responses]
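
The format rewards above match regular expressions against whole completions. The sketch below illustrates two Python `re` behaviours that matter for patterns of this kind: `.` does not match a newline unless `re.DOTALL` is passed, and `re.fullmatch` anchors a pattern at both ends without any JavaScript-style flag suffix. The sample completion is invented.

```python
import re

completion = (
    "<reasoning>\nline one\nline two\n</reasoning>\n"
    "<answer>\nx(0,1)\n</answer>\n"
)
pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"

# By default "." stops at newlines, so a multi-line reasoning body fails to match.
default_match = re.match(pattern, completion) is not None

# With re.DOTALL, ".*?" may span several lines.
dotall_match = re.match(pattern, completion, flags=re.DOTALL) is not None

# fullmatch requires the entire string to be a single move such as x(0,1).
move_pattern = r"[xo]\([0-2],[0-2]\)"
move_ok = re.fullmatch(move_pattern, "x(0,1)") is not None
move_out_of_range = re.fullmatch(move_pattern, "x(0,3)") is not None

print(default_match, dotall_match, move_ok, move_out_of_range)
```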


#### MCTS reward functions

In [17]:
# RegEx for parsing reasoning
state_pattern = r"(?:[xo.]{3}\n){2}(?:[xo.]{3}\n?)"
exploration_pattern = r"[Ee]xploring.*\[(?:'?[xo]\([0-2]\s*,[0-2]\)'?(?:,\s*)?)+\]\n[rR]esult.*\n(?:[xo.]{3}\n){3}(?:(?:draw\n)|(?:[xo] wins?\n))?"
playout_pattern = exploration_pattern + r".*playout.*\n(?:[xo]\([0-2]\s*,[0-2]\)\n(?:[xo.]{3}\n){3})+(?:(?:[xo] wins)|(?:draw)).*\n"
single_move_pattern = r"[xo]\([0-2]\s*,[0-2]\)"
outcome_pattern = r"(?:draw)|(?:[xo] wins?)"


def get_action(state, action_str):
    for action in state.legal_actions():
        if action_str == state.action_to_string(state.current_player(), action):
          return action
    return None


def get_start_state(current_moves):
    game = pyspiel.load_game("tic_tac_toe")
    state = game.new_initial_state()
    for action_str in current_moves:
        state.apply_action(get_action(state, action_str))
    return state


# Simulates Tic-Tac-Toe game using the given move sequences and start state,
# and provides reward counts for the number of valid moves, result states, and terminal states
def count_valid_ttt_explorations(start_state, move_sequences, result_states, outcomes):
    valid_move_sequence_count = 0.0
    valid_result_state_count = 0.0
    correct_terminal_count = 0.0

    for move_sequence, result_state_str, outcome in zip(move_sequences, result_states, outcomes):
        state = start_state.clone()
        move_sequence_valid = True
        true_outcome = None
        terminal_valid = True

        for i, move in enumerate(move_sequence):
            action = get_action(state, move)
            if action is None:
                move_sequence_valid = False
                break

            state.apply_action(action)

            if state.is_terminal():
                if i < (len(move_sequence) - 1): # If a terminal state is reached but there are still moves remaining
                    terminal_valid = False
                else:
                    true_returns = state.returns()
                    if true_returns[0] == 0.0 and true_returns[1] == 0.0:
                        true_outcome = "draw"
                    elif true_returns[0] == 1.0:
                        true_outcome = "x wins"
                    else:
                        true_outcome = "o wins"

        if move_sequence_valid:
            valid_move_sequence_count += 0.1

            if (result_state_str is not None) and (str(state) == result_state_str.strip()):
                valid_result_state_count += 0.1

            # Equal outcomes cover both cases: both None (non-terminal) or the same terminal result
            if terminal_valid and (true_outcome == outcome):
                correct_terminal_count += 0.1

    return valid_move_sequence_count, valid_result_state_count, correct_terminal_count


# Simulates Tic-Tac-Toe game using the given playout move sequences and start state,
# and provides reward counts for the number of valid moves, result states, and terminal states
def count_valid_ttt_playouts(start_state, playout_move_sequences, playout_states, outcomes):
    valid_move_sequence_count = 0.0
    valid_playout_state_count = 0.0
    correct_terminal_count = 0.0

    for move_sequence, playout, outcome in zip(playout_move_sequences, playout_states, outcomes):
        state = start_state.clone()
        move_sequence_valid = True
        playout_valid = True
        true_outcome = None
        terminal_valid = True

        for i, (move, playout_state_str) in enumerate(zip(move_sequence, playout)):
            action = get_action(state, move)
            if action is None:
                move_sequence_valid = False
                break

            state.apply_action(action)

            if str(state) != playout_state_str.strip():
                playout_valid = False

            if state.is_terminal():
                if i < (len(move_sequence) - 1): # If a terminal state is reached but there are still moves remaining
                    terminal_valid = False
                else:
                    true_returns = state.returns()
                    if true_returns[0] == 0.0 and true_returns[1] == 0.0:
                        true_outcome = "draw"
                    elif true_returns[0] == 1.0:
                        true_outcome = "x wins"
                    else:
                        true_outcome = "o wins"

        if move_sequence_valid:
            valid_move_sequence_count += 0.1

            if playout_valid:
                valid_playout_state_count += 0.1

            # Equal outcomes cover both cases: both None (non-terminal) or the same terminal result
            if terminal_valid and (true_outcome == outcome):
                correct_terminal_count += 0.1

    return valid_move_sequence_count, valid_playout_state_count, correct_terminal_count


# Correct start state based on previous moves
def mcts_individual_rewards(completions, current_moves):
    extracted_reasonings = [extract_xml_reasoning(c) for c in completions]
    start_states = [get_start_state(c) for c in current_moves]

    # Rewards for correct start state given previous moves
    start_state_matches = [re.search(state_pattern, r) for r in extracted_reasonings]
    start_state_rewards = [2.0 if (m is not None and m.group().strip() == str(s)) else 0.0 for m, s in zip(start_state_matches, start_states)]

    # Rewards for exploring valid moves
    explores_valid_moves_rewards = []

    # Rewards for arriving at correct state from explored moves
    explored_states_rewards = []

    # Rewards for correctly identifying terminal vs non-terminal states
    explored_terminal_rewards = []

    exploration_lists = [re.findall(exploration_pattern, r) for r in extracted_reasonings]
    exploration_moves = [] # List (completions) of lists (explorations) of lists (move sequences)
    exploration_result_states = [] # List (completions) of lists (explorations) of lists (result states)
    exploration_outcomes = [] # List (completions) of lists (terminal result of exploration or None if non-terminal)

    for explorations in exploration_lists:
        moves = []
        result_state_matches = []
        outcome_matches = []
        for e in explorations:
            moves.append(re.findall(single_move_pattern, e))
            result_state_matches.append(re.search(state_pattern, e))
            outcome_matches.append(re.search(outcome_pattern, e))

        exploration_moves.append(moves)
        exploration_result_states.append([m.group() if m is not None else None for m in result_state_matches])
        exploration_outcomes.append([m.group() if m is not None else None for m in outcome_matches])

    valid_exploration_counts = [
        count_valid_ttt_explorations(s, m, r, o)
        for s, m, r, o in zip(start_states, exploration_moves, exploration_result_states, exploration_outcomes)
    ]

    for count in valid_exploration_counts:
        explores_valid_moves_rewards.append(count[0])
        explored_states_rewards.append(count[1])
        explored_terminal_rewards.append(count[2])

    # Rewards for adding valid moves to the playouts
    playout_valid_moves_rewards = []

    # Rewards for arriving at correct state for each playout step
    playout_states_rewards = []

    # Rewards for correctly identifying terminal vs non-terminal states in playouts
    playout_terminal_rewards = []

    playout_lists = [re.findall(playout_pattern, r) for r in extracted_reasonings]
    playout_moves = [] # List (completions) of lists (playouts) of lists (moves)
    playout_states = [] # List (completions) of lists (playouts) of lists (states)
    playout_outcomes = [] # List (completions) of lists (terminal result of playout or None if non-terminal)

    for playouts in playout_lists:
        moves = []
        states = []
        outcome_matches = []
        for p in playouts:
            moves.append(re.findall(single_move_pattern, p))
            states.append(re.findall(state_pattern, p))
            outcome_matches.append(re.search(outcome_pattern, p))

        playout_moves.append(moves)
        playout_states.append(states)
        playout_outcomes.append([m.group() if m is not None else None for m in outcome_matches])

    valid_playout_counts = [
        count_valid_ttt_playouts(s, m, p, o)
        for s, m, p, o in zip(start_states, playout_moves, playout_states, playout_outcomes)
    ]

    for count in valid_playout_counts:
        playout_valid_moves_rewards.append(count[0])
        playout_states_rewards.append(count[1])
        playout_terminal_rewards.append(count[2])

    return start_state_rewards, explores_valid_moves_rewards, explored_states_rewards, explored_terminal_rewards, playout_valid_moves_rewards, playout_states_rewards, playout_terminal_rewards


# Sums together individual MCTS reward functions, to get final rewards
def mcts_reward_func(completions, current_moves, **kwargs) -> list[float]:
    return list(np.sum(np.array(mcts_individual_rewards(completions, current_moves)), axis=0))
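
The MCTS rewards work by replaying the moves a completion claims to explore and comparing the resulting board with the board the completion printed. The sketch below shows the same replay-and-compare idea in plain Python, without `pyspiel`. `replay` and `board_to_str` are hypothetical helpers, moves are assumed to be written as `player(row,col)` (which is how the test cases in this notebook line up with the printed boards), and win detection is deliberately left out.

```python
import re

MOVE_RE = re.compile(r"([xo])\(([0-2]),([0-2])\)")

def board_to_str(board):
    # Three rows of three characters, "." for an empty square.
    return "\n".join("".join(row) for row in board)

def replay(moves):
    """Replay moves from an empty board; return the board string, or None if a move is illegal."""
    board = [["."] * 3 for _ in range(3)]
    expected = "x"  # x always moves first
    for move in moves:
        m = MOVE_RE.fullmatch(move)
        if m is None:
            return None  # not a well-formed move string
        player, row, col = m.group(1), int(m.group(2)), int(m.group(3))
        if player != expected or board[row][col] != ".":
            return None  # wrong player to move, or square already taken
        board[row][col] = player
        expected = "o" if player == "x" else "x"
    return board_to_str(board)

history = ["x(0,1)", "o(0,2)", "x(1,0)", "o(1,1)", "x(2,2)", "o(2,1)"]
start = replay(history)
claimed = ".xo\nxo.\n.ox"           # board text as a completion might print it
state_is_correct = start == claimed

occupied = replay(history + ["x(0,1)"])    # square already taken
wrong_player = replay(history + ["o(2,0)"])  # it is x's turn
print(start)
```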


#### Final move reward function

In [13]:
# Final move is in list of optimal moves
def optimality_reward_func(completions, optimal_moves, **kwargs) -> list[float]:
    extracted_responses = [extract_xml_answer(c) for c in completions]
    return [4.0 if r in m else 0.0 for r, m in zip(extracted_responses, optimal_moves)]
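
A toy run of the optimality reward is sketched below. The completions and the lists of optimal moves are invented, and the helper is repeated so the snippet runs on its own. It assumes each entry of `optimal_moves` is a list of move strings.

```python
def extract_xml_answer(text: str) -> str:
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()

def optimality_reward_func(completions, optimal_moves, **kwargs) -> list[float]:
    extracted_responses = [extract_xml_answer(c) for c in completions]
    return [4.0 if r in m else 0.0 for r, m in zip(extracted_responses, optimal_moves)]

# Invented examples: the first answer is among the optimal moves, the second is not.
completions = [
    "<reasoning>\n...\n</reasoning>\n<answer>\nx(1,1)\n</answer>\n",
    "<answer>o(0,0)</answer>",
]
optimal_moves = [["x(1,1)", "x(0,0)"], ["o(2,2)"]]
rewards = optimality_reward_func(completions, optimal_moves)
print(rewards)

# Caveat: if an entry were a single string instead of a list, "in" would do
# substring matching, and even an empty answer would be rewarded.
substring_case = optimality_reward_func(["<answer></answer>"], ["o(2,2)"])
print(substring_case)
```

If the dataset stores optimal moves as strings rather than lists, an exact comparison against the individual moves would be the safer check.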


<a name="Test reward functions"></a>
### Test reward functions

Non-comprehensive tests to check whether reward functions are working properly

In [18]:
import math

# =============================================================================
# START STATE
# =============================================================================

# Test correct start state
correct_start_state_rewards = mcts_individual_rewards(
    ["<reasoning>Start state: \n.xo\nxo.\n.ox\n</reasoning>"],
    [["x(0,1)", "o(0,2)", "x(1,0)", "o(1,1)", "x(2,2)", "o(2,1)"]],
)[0]
assert correct_start_state_rewards == [2.0]

# Test incorrect start state
incorrect_start_state_rewards = mcts_individual_rewards(
    ["<reasoning>Start state: \nx.o\nxo.\n.ox\n</reasoning>"],
    [["x(0,1)", "o(0,2)", "x(1,0)", "o(1,1)", "x(2,2)", "o(2,1)"]],
)[0]
assert incorrect_start_state_rewards == [0.0]


# =============================================================================
# VALID MOVES
# =============================================================================

# Test that valid moves are correctly rewarded in exploration
explores_valid_moves = mcts_individual_rewards(
    ["Exploring move sequence: ['x(2,0)']\nResulting state: \n.xo\nxo.\nxox\nEvaluating state with random playout: \no(0,0)\noxo\nxo.\nxox\nx(1,2)\noxo\nxox\nxox\ndraw in random playout\nUpdating move rewards: \n['x(2,0)']: 0.0 + 0.0 = 0.0\n\nExploring move sequence: ['x(0,0)']\nResulting state: \nxxo\nxo.\n.ox\nEvaluating state with random playout: \no(2,0)\nxxo\nxo.\noox\no wins in random playout\nUpdating move rewards: \n['x(0,0)']: 0.0 + -1.0 = -1.0\n\nExploring move sequence: ['x(1,2)', 'o(2,0)']\nResulting state: \n.xo\nxox\noox\no wins\nUpdating move rewards: \n['x(1,2)', 'o(2,0)']: 0.0 + -1.0 = -1.0\n['x(1,2)']: 0.0 + -1.0 = -1.0\n['x(1,2)'] is a winning move sequence for o\n\n"],
    [["x(0,1)", "o(0,2)", "x(1,0)", "o(1,1)", "x(2,2)", "o(2,1)"]]
)[1]
assert math.isclose(explores_valid_moves[0], 0.3)

# Test that invalid moves are not rewarded in exploration
explores_invalid_moves = mcts_individual_rewards(
    ["Exploring move sequence: ['o(2,0)']\nResulting state: \n.xo\nxo.\nxox\nEvaluating state with random playout: \no(0,0)\noxo\nxo.\nxox\nx(1,2)\noxo\nxox\nxox\ndraw in random playout\nUpdating move rewards: \n['x(2,0)']: 0.0 + 0.0 = 0.0\n\nExploring move sequence: ['x(0,0)']\nResulting state: \nxxo\nxo.\n.ox\nEvaluating state with random playout: \no(2,0)\nxxo\nxo.\noox\no wins in random playout\nUpdating move rewards: \n['x(0,0)']: 0.0 + -1.0 = -1.0\n\nExploring move sequence: ['x(0,1)', 'o(2,0)']\nResulting state: \n.xo\nxox\noox\no wins\nUpdating move rewards: \n['x(1,2)', 'o(2,0)']: 0.0 + -1.0 = -1.0\n['x(1,2)']: 0.0 + -1.0 = -1.0\n['x(1,2)'] is a winning move sequence for o\n\n"],
    [["x(0,1)", "o(0,2)", "x(1,0)", "o(1,1)", "x(2,2)", "o(2,1)"]]
)[1]
assert math.isclose(explores_invalid_moves[0], 0.1)

# Test that valid moves are correctly rewarded in playouts
playout_valid_moves = mcts_individual_rewards(
    ["Exploring move sequence: ['x(2,0)']\nResulting state: \n.xo\nxo.\nxox\nEvaluating state with random playout: \no(0,0)\noxo\nxo.\nxox\nx(1,2)\noxo\nxox\nxox\ndraw in random playout\nUpdating move rewards: \n['x(2,0)']: 0.0 + 0.0 = 0.0\n\nExploring move sequence: ['x(0,0)']\nResulting state: \nxxo\nxo.\n.ox\nEvaluating state with random playout: \no(2,0)\nxxo\nxo.\noox\no wins in random playout\nUpdating move rewards: \n['x(0,0)']: 0.0 + -1.0 = -1.0\n\nExploring move sequence: ['x(1,2)', 'o(2,0)']\nResulting state: \n.xo\nxox\noox\no wins\nUpdating move rewards: \n['x(1,2)', 'o(2,0)']: 0.0 + -1.0 = -1.0\n['x(1,2)']: 0.0 + -1.0 = -1.0\n['x(1,2)'] is a winning move sequence for o\n\n"],
    [["x(0,1)", "o(0,2)", "x(1,0)", "o(1,1)", "x(2,2)", "o(2,1)"]]
)[4]
assert math.isclose(playout_valid_moves[0], 0.2)

# Test that invalid moves are not rewarded in playouts
playout_invalid_moves = mcts_individual_rewards(
    ["Exploring move sequence: ['x(2,0)']\nResulting state: \n.xo\nxo.\nxox\nEvaluating state with random playout: \nx(0,0)\noxo\nxo.\nxox\nx(1,2)\noxo\nxox\nxox\ndraw in random playout\nUpdating move rewards: \n['x(2,0)']: 0.0 + 0.0 = 0.0\n\nExploring move sequence: ['x(0,0)']\nResulting state: \nxxo\nxo.\n.ox\nEvaluating state with random playout: \nx(1,0)\nxxo\nxo.\noox\no wins in random playout\nUpdating move rewards: \n['x(0,0)']: 0.0 + -1.0 = -1.0\n\nExploring move sequence: ['x(1,2)', 'o(2,0)']\nResulting state: \n.xo\nxox\noox\no wins\nUpdating move rewards: \n['x(1,2)', 'o(2,0)']: 0.0 + -1.0 = -1.0\n['x(1,2)']: 0.0 + -1.0 = -1.0\n['x(1,2)'] is a winning move sequence for o\n\n"],
    [["x(0,1)", "o(0,2)", "x(1,0)", "o(1,1)", "x(2,2)", "o(2,1)"]]
)[4]
assert math.isclose(playout_invalid_moves[0], 0.0)


# =============================================================================
# RESULTING STATES
# =============================================================================

# Test that correct states are rewarded in exploration
explores_correct_states = mcts_individual_rewards(
    ["Exploring move sequence: ['x(2,0)']\nResulting state: \n.xo\nxo.\nxox\nEvaluating state with random playout: \no(0,0)\noxo\nxo.\nxox\nx(1,2)\noxo\nxox\nxox\ndraw in random playout\nUpdating move rewards: \n['x(2,0)']: 0.0 + 0.0 = 0.0\n\nExploring move sequence: ['x(0,0)']\nResulting state: \nxxo\nxo.\n.ox\nEvaluating state with random playout: \no(2,0)\nxxo\nxo.\noox\no wins in random playout\nUpdating move rewards: \n['x(0,0)']: 0.0 + -1.0 = -1.0\n\nExploring move sequence: ['x(1,2)', 'o(2,0)']\nResulting state: \n.xo\nxox\noox\no wins\nUpdating move rewards: \n['x(1,2)', 'o(2,0)']: 0.0 + -1.0 = -1.0\n['x(1,2)']: 0.0 + -1.0 = -1.0\n['x(1,2)'] is a winning move sequence for o\n\n"],
    [["x(0,1)", "o(0,2)", "x(1,0)", "o(1,1)", "x(2,2)", "o(2,1)"]]
)[2]
assert math.isclose(explores_correct_states[0], 0.3)

# Test that incorrect states are not rewarded in exploration
explores_correct_states = mcts_individual_rewards(
    ["Exploring move sequence: ['x(2,0)']\nResulting state: \nxxo\nxo.\nxox\nEvaluating state with random playout: \no(0,0)\noxo\nxo.\nxox\nx(1,2)\noxo\nxox\nxox\ndraw in random playout\nUpdating move rewards: \n['x(2,0)']: 0.0 + 0.0 = 0.0\n\nExploring move sequence: ['x(0,0)']\nResulting state: \nxxo\nxo.\n.ox\nEvaluating state with random playout: \no(2,0)\nxxo\nxo.\noox\no wins in random playout\nUpdating move rewards: \n['x(0,0)']: 0.0 + -1.0 = -1.0\n\nExploring move sequence: ['x(1,2)', 'o(2,0)']\nResulting state: \n.xo\nxox\n.ox\no wins\nUpdating move rewards: \n['x(1,2)', 'o(2,0)']: 0.0 + -1.0 = -1.0\n['x(1,2)']: 0.0 + -1.0 = -1.0\n['x(1,2)'] is a winning move sequence for o\n\n"],
    [["x(0,1)", "o(0,2)", "x(1,0)", "o(1,1)", "x(2,2)", "o(2,1)"]]
)[2]
assert math.isclose(explores_correct_states[0], 0.1)

# Test that correct states are rewarded in playouts
playout_correct_states = mcts_individual_rewards(
    ["Exploring move sequence: ['x(2,0)']\nResulting state: \n.xo\nxo.\nxox\nEvaluating state with random playout: \no(0,0)\noxo\nxo.\nxox\nx(1,2)\noxo\nxox\nxox\ndraw in random playout\nUpdating move rewards: \n['x(2,0)']: 0.0 + 0.0 = 0.0\n\nExploring move sequence: ['x(0,0)']\nResulting state: \nxxo\nxo.\n.ox\nEvaluating state with random playout: \no(2,0)\nxxo\nxo.\noox\no wins in random playout\nUpdating move rewards: \n['x(0,0)']: 0.0 + -1.0 = -1.0\n\nExploring move sequence: ['x(1,2)', 'o(2,0)']\nResulting state: \n.xo\nxox\noox\no wins\nUpdating move rewards: \n['x(1,2)', 'o(2,0)']: 0.0 + -1.0 = -1.0\n['x(1,2)']: 0.0 + -1.0 = -1.0\n['x(1,2)'] is a winning move sequence for o\n\n"],
    [["x(0,1)", "o(0,2)", "x(1,0)", "o(1,1)", "x(2,2)", "o(2,1)"]]
)[5]
assert math.isclose(playout_correct_states[0], 0.2)

# Test that incorrect states are not rewarded in playouts
playout_correct_states = mcts_individual_rewards(
    ["Exploring move sequence: ['x(2,0)']\nResulting state: \n.xo\nxo.\nxox\nEvaluating state with random playout: \no(0,0)\n.xo\nxo.\nxox\nx(1,2)\noxo\nxox\nxox\ndraw in random playout\nUpdating move rewards: \n['x(2,0)']: 0.0 + 0.0 = 0.0\n\nExploring move sequence: ['x(0,0)']\nResulting state: \nxxo\nxo.\n.ox\nEvaluating state with random playout: \no(2,0)\nxxo\nxox\noox\no wins in random playout\nUpdating move rewards: \n['x(0,0)']: 0.0 + -1.0 = -1.0\n\nExploring move sequence: ['x(1,2)', 'o(2,0)']\nResulting state: \n.xo\nxox\noox\no wins\nUpdating move rewards: \n['x(1,2)', 'o(2,0)']: 0.0 + -1.0 = -1.0\n['x(1,2)']: 0.0 + -1.0 = -1.0\n['x(1,2)'] is a winning move sequence for o\n\n"],
    [["x(0,1)", "o(0,2)", "x(1,0)", "o(1,1)", "x(2,2)", "o(2,1)"]]
)[5]
assert math.isclose(playout_correct_states[0], 0.0)


# =============================================================================
# RECOGNIZES TERMINAL STATES
# =============================================================================

# Test when terminal states are correctly recognized in exploration
explores_correct_terminal = mcts_individual_rewards(
    ["Exploring move sequence: ['x(1,2)', 'o(2,0)']\nResulting state: \n.xo\nxox\noox\no wins\nUpdating move rewards: \n['x(1,2)', 'o(2,0)']: 0.0 + -1.0 = -1.0\n['x(1,2)']: 0.0 + -1.0 = -1.0\n['x(1,2)'] is a winning move sequence for o\n\n"],
    [["x(0,1)", "o(0,2)", "x(1,0)", "o(1,1)", "x(2,2)", "o(2,1)"]]
)[3]
assert math.isclose(explores_correct_terminal[0], 0.1)

# Test when terminal states are incorrectly recognized in exploration
explores_incorrect_terminal = mcts_individual_rewards(
    ["Exploring move sequence: ['x(2,0)']\nResulting state: \n.xo\nxo.\nxox\no wins\nUpdating move rewards: \n['x(2,0)']: 0.0 + 0.0 = 0.0\n\\nExploring move sequence: ['x(1,2)', 'o(2,0)']\nResulting state: \n.xo\nxox\noox\nx wins\nUpdating move rewards: \n['x(1,2)', 'o(2,0)']: 0.0 + -1.0 = -1.0\n['x(1,2)']: 0.0 + -1.0 = -1.0\n['x(1,2)'] is a winning move sequence for o\n\nExploring move sequence: ['x(1,2)', 'o(2,0)']\nResulting state: \n.xo\nxox\noox\nEvaluating state with random playout: \no(0,0)\noxo\nxo.\nxox\nx(1,2)\noxo\nxox\nxox\ndraw in random playout\nUpdating move rewards: \n['x(1,2)', 'o(2,0)']: 0.0 + -1.0 = -1.0\n['x(1,2)']: 0.0 + -1.0 = -1.0\n['x(1,2)'] is a winning move sequence for o\n\n"],
    [["x(0,1)", "o(0,2)", "x(1,0)", "o(1,1)", "x(2,2)", "o(2,1)"]]
)[3]
assert math.isclose(explores_incorrect_terminal[0], 0.0)

# Test when terminal states are correctly recognized in playouts
playout_correct_terminal = mcts_individual_rewards(
    ["Exploring move sequence: ['x(2,0)']\nResulting state: \n.xo\nxo.\nxox\nEvaluating state with random playout: \no(0,0)\n.xo\nxo.\nxox\nx(1,2)\noxo\nxox\nxox\ndraw in random playout\nUpdating move rewards: \n['x(2,0)']: 0.0 + 0.0 = 0.0\n\nExploring move sequence: ['x(0,0)']\nResulting state: \nxxo\nxo.\n.ox\nEvaluating state with random playout: \no(2,0)\nxxo\nxox\noox\no wins in random playout\nUpdating move rewards: \n['x(0,0)']: 0.0 + -1.0 = -1.0\n\nExploring move sequence: ['x(1,2)', 'o(2,0)']\nResulting state: \n.xo\nxox\noox\no wins\nUpdating move rewards: \n['x(1,2)', 'o(2,0)']: 0.0 + -1.0 = -1.0\n['x(1,2)']: 0.0 + -1.0 = -1.0\n['x(1,2)'] is a winning move sequence for o\n\n"],
    [["x(0,1)", "o(0,2)", "x(1,0)", "o(1,1)", "x(2,2)", "o(2,1)"]]
)[6]
assert math.isclose(playout_correct_terminal[0], 0.2)

# Test when terminal states are incorrectly recognized in playouts
playout_incorrect_terminal = mcts_individual_rewards(
    ["Exploring move sequence: ['x(2,0)']\nResulting state: \n.xo\nxo.\nxox\nEvaluating state with random playout: \no(0,0)\n.xo\nxo.\nxox\ndraw in random playout\nUpdating move rewards: \n['x(2,0)']: 0.0 + 0.0 = 0.0\n\nExploring move sequence: ['x(0,0)']\nResulting state: \nxxo\nxo.\n.ox\nEvaluating state with random playout: \no(2,0)\nxxo\nxox\noox\ndraw in random playout\nUpdating move rewards: \n['x(0,0)']: 0.0 + -1.0 = -1.0\n\nExploring move sequence: ['x(1,2)', 'o(2,0)']\nResulting state: \n.xo\nxox\noox\no wins\nUpdating move rewards: \n['x(1,2)', 'o(2,0)']: 0.0 + -1.0 = -1.0\n['x(1,2)']: 0.0 + -1.0 = -1.0\n['x(1,2)'] is a winning move sequence for o\n\n"],
    [["x(0,1)", "o(0,2)", "x(1,0)", "o(1,1)", "x(2,2)", "o(2,1)"]]
)[6]
assert math.isclose(playout_incorrect_terminal[0], 0.0)


print("All passed")


All passed


<a name="Train"></a>
### Train the model

Now set up the GRPO trainer and its configuration.

In [None]:
max_prompt_length = 512

from trl import GRPOConfig, GRPOTrainer
training_args = GRPOConfig(
    learning_rate = 1e-6,
    # lr_scheduler_type = "cosine",
    optim = "paged_adamw_8bit",
    logging_steps = 10,
    per_device_train_batch_size = 6,
    gradient_accumulation_steps = 4, # Accumulate 4 micro-batches for smoother training; lower it for faster steps
    num_generations = 6, # Decrease if out of memory
    max_prompt_length = max_prompt_length,
    max_completion_length = max_seq_length - max_prompt_length,
    num_train_epochs = 1,
    # max_steps = 250,
    save_steps = 250,
    # max_grad_norm = 0.1,
    report_to = "none", # Can use Weights & Biases
    output_dir = "grpo_outputs",
    # scale_rewards = False, # Disables std scaling, recommended by Dr. GRPO paper
)
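
The batch settings above interact with `num_generations`. The sketch below works through that arithmetic under one assumption: that `per_device_train_batch_size` counts completions, so the number of unique prompts per optimizer step is the completion count divided by `num_generations`. This matches the banner printed by the training cell (3,001 examples, 3 epochs, batch size 6, accumulation 1, 9,003 total steps), but the trainer's exact accounting may differ between TRL versions. Both function names are hypothetical.

```python
def prompts_per_optimizer_step(per_device_batch_size, grad_accum_steps,
                               num_gpus, num_generations):
    """Unique prompts consumed per optimizer step (assumed accounting)."""
    completions = per_device_batch_size * grad_accum_steps * num_gpus
    if completions % num_generations != 0:
        raise ValueError(
            "completions per step must be divisible by num_generations"
        )
    return completions // num_generations


def approx_total_steps(num_examples, num_epochs, prompts_per_step):
    """Rough optimizer-step count; the trainer's rounding may differ."""
    return (num_examples * num_epochs) // prompts_per_step


# Values from the logged banner, with num_generations = 6 from the config.
logged = prompts_per_optimizer_step(6, 1, 1, 6)
print(logged, approx_total_steps(3001, 3, logged))  # 1 9003

# Values from the GRPOConfig cell above.
configured = prompts_per_optimizer_step(6, 4, 1, 6)
print(configured)  # 4
```

Under this assumption, lowering `num_generations` to save memory also requires keeping the completion count divisible by it.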

Now let's run the trainer! The training output includes a table of rewards, and the goal is to see the `reward` column increase.

You might have to wait 150 to 200 steps before the reward starts to move, and it will probably stay near 0 for the first 100 steps. Please be patient!

Example:

| Step | Training Loss | reward    | reward_std | completion_length | kl       |
|------|---------------|-----------|------------|-------------------|----------|
| 1    | 0.000000      | 0.125000  | 0.000000   | 200.000000        | 0.000000 |
| 2    | 0.000000      | 0.072375  | 0.248112   | 200.000000        | 0.000000 |
| 3    | 0.000000      | -0.079000 | 0.163776   | 182.500000        | 0.000005 |
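
The `reward` and `reward_std` columns matter because GRPO scores each completion relative to the other completions sampled for the same prompt. The sketch below illustrates that group-relative advantage in plain Python. It assumes the common formulation (subtract the group mean, then optionally divide by the group standard deviation plus a small epsilon); the sample standard deviation and `eps = 1e-4` are illustrative choices, and `group_advantages` is a hypothetical name, not the trainer's implementation.

```python
import statistics


def group_advantages(rewards, scale_rewards=True, eps=1e-4):
    """Advantages for one prompt's group of sampled completions."""
    mean = statistics.fmean(rewards)
    centered = [r - mean for r in rewards]
    if not scale_rewards:
        # Mean-centering only, as when std scaling is disabled.
        return centered
    std = statistics.stdev(rewards)  # needs at least two completions
    return [c / (std + eps) for c in centered]


# One completion out of six found the optimal move (reward 4.0).
print(group_advantages([0.0, 0.0, 4.0, 0.0, 0.0, 0.0]))

# Every completion scored the same: all advantages are zero.
print(group_advantages([0.0] * 6))
```

When every completion in a group earns the same reward, the advantages are all zero and that prompt contributes no learning signal, which is one reason early training can look flat.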


In [None]:
trainer = GRPOTrainer(
    model = model,
    processing_class = tokenizer,
    reward_funcs = [
        strict_format_reward_func,
        soft_format_reward_func,
        xmlcount_reward_func,
        strict_reasoning_format_reward_func,
        soft_reasoning_format_reward_func,
        move_format_reward_func,
        start_state_reward_func,
        explores_valid_moves_reward_func,
        playout_valid_moves_reward_func,
        explored_states_reward_func,
        playout_states_reward_func,
        explored_terminal_reward_func,
        playout_terminal_reward_func,
        optimality_reward_func,
    ],
    args = training_args,
    train_dataset = grpo_dataset,
)
trainer.train()

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 3,001 | Num Epochs = 3 | Total steps = 9,003
O^O/ \_/ \    Batch size per device = 6 | Gradient accumulation steps = 1
\        /    Data Parallel GPUs = 1 | Total batch size (6 x 1 x 1) = 6
 "-____-"     Trainable parameters = 83,886,080/8,000,000,000 (1.05% trained)


Step,Training Loss,reward,reward_std,completion_length,kl,rewards / strict_format_reward_func,rewards / soft_format_reward_func,rewards / xmlcount_reward_func,rewards / strict_reasoning_format_reward_func,rewards / soft_reasoning_format_reward_func,rewards / move_format_reward_func,rewards / start_state_reward_func,rewards / explores_valid_moves_reward_func,rewards / playout_valid_moves_reward_func,rewards / explored_states_reward_func,rewards / playout_states_reward_func,rewards / explored_terminal_reward_func,rewards / playout_terminal_reward_func,rewards / optimality_reward_func


Unsloth: Will smartly offload gradients to save VRAM!


<a name="Inference"></a>
### Inference
Now let's try the model we just trained! First, let's try the model without any GRPO training:

In [None]:
text = tokenizer.apply_chat_template([
    {"role" : "user", "content" : "Calculate pi."},
], tokenize = False, add_generation_prompt = True)

from vllm import SamplingParams
sampling_params = SamplingParams(
    temperature = 0.8,
    top_p = 0.95,
    max_tokens = 1024,
)
output = model.fast_generate(
    [text],
    sampling_params = sampling_params,
    lora_request = None,
)[0].outputs[0].text

output

Processed prompts: 100%|██████████| 1/1 [00:51<00:00, 51.81s/it, est. speed input: 0.75 toks/s, output: 17.01 toks/s]


'**Calculating Pi using Python**\n\nPi (π) is a mathematical constant representing the ratio of a circle\'s circumference to its diameter. Here\'s a simple and efficient way to calculate an approximation of pi using Python.\n\n### Using the Monte Carlo Method\n\nThe Monte Carlo method is a computational algorithm that uses random sampling to approximate a value. In this case, we can use it to estimate pi by generating random points within a square and checking if they fall inside a quarter-circle inscribed within it.\n\n```python\nimport random\nimport math\n\ndef estimate_pi(num_samples):\n    """\n    Estimate the value of pi using the Monte Carlo method.\n\n    Args:\n    num_samples (int): The number of random points to generate.\n\n    Returns:\n    float: An approximation of pi.\n    """\n    points_inside_circle = 0\n\n    for _ in range(num_samples):\n        x, y = random.random(), random.random()\n        distance = x**2 + y**2\n        if distance <= 1:\n            points_i

And now with the LoRA we just trained with GRPO. We save the LoRA first!

In [None]:
model.save_lora("grpo_saved_lora")

Now we load the LoRA and test:

In [None]:
text = tokenizer.apply_chat_template([
    {"role" : "system", "content" : SYSTEM_PROMPT},
    {"role" : "user", "content" : "Calculate pi."},
], tokenize = False, add_generation_prompt = True)

from vllm import SamplingParams
sampling_params = SamplingParams(
    temperature = 0.8,
    top_p = 0.95,
    max_tokens = 1024,
)
output = model.fast_generate(
    text,
    sampling_params = sampling_params,
    lora_request = model.load_lora("grpo_saved_lora"),
)[0].outputs[0].text

output

Processed prompts: 100%|██████████| 1/1 [00:23<00:00, 23.17s/it, est. speed input: 2.63 toks/s, output: 15.80 toks/s]


'Calculating pi to a high degree of accuracy is a complex task that requires a large amount of computational power. However, I can provide you with an approximate value of pi or show you a simple method to calculate it.\n\nOne of the simplest methods to calculate pi is the Leibniz formula, which is an infinite series:\n\npi/4 = 1 - 1/3 + 1/5 - 1/7 + 1/9 - ...\n\nThis series can be used to calculate an approximation of pi.\n\n<reasoning>\npi = 4 * (1 - 1/3 + 1/5 - 1/7 + 1/9 - ...)\n</reasoning>\n\nThis is a simple, yet effective method to calculate pi. However, the more terms you use, the more accurate the result will be.\n\nTo calculate pi to a high degree of accuracy, you would need to use a computer program to perform the calculation.\n\n<answer>\n3.141592653589793 (approximately)\n</answer>\n\nFor a more accurate result, I can provide you with a Python code snippet to calculate pi:\n\n```python\nimport math\n\ndef calculate_pi(n):\n    pi = 0.0\n    for i in range(n):\n        pi +=

Our reasoning model is much better. It's not always correct, since we only trained it for an hour or so, but it will improve if we extend the sequence length and train for longer!
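
To inspect a completion like the one above programmatically, the tagged sections can be pulled out with a regular expression. This is a minimal sketch assuming the `<reasoning>` and `<answer>` tags shown in the output. `extract_tag` is a hypothetical helper and may behave differently from the notebook's own `extract_xml_answer`.

```python
import re


def extract_tag(text, tag):
    """Return the stripped contents of the first <tag>...</tag> pair, or None."""
    match = re.search(rf"<{tag}>(.*?)</{tag}>", text, flags=re.DOTALL)
    return match.group(1).strip() if match else None


sample = (
    "<reasoning>\npi = 4 * (1 - 1/3 + 1/5 - 1/7 + 1/9 - ...)\n</reasoning>\n"
    "<answer>\n3.141592653589793 (approximately)\n</answer>"
)
print(extract_tag(sample, "answer"))
print(extract_tag(sample, "reasoning"))
```

Returning `None` for a missing tag makes it easy to count how often the trained model follows the expected format.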

<a name="Save"></a>
### Saving to float16 for vLLM

We also support saving to `float16` directly. Select `merged_16bit` for float16 or `merged_4bit` for int4. We also allow `lora` adapters as a fallback. Use `push_to_hub_merged` to upload to your Hugging Face account! You can go to https://huggingface.co/settings/tokens for your personal tokens.

In [None]:
# Merge to 16bit
if False: model.save_pretrained_merged("model", tokenizer, save_method = "merged_16bit",)
if False: model.push_to_hub_merged("hf/model", tokenizer, save_method = "merged_16bit", token = "")

# Merge to 4bit
if False: model.save_pretrained_merged("model", tokenizer, save_method = "merged_4bit",)
if False: model.push_to_hub_merged("hf/model", tokenizer, save_method = "merged_4bit", token = "")

# Just LoRA adapters
if False: model.save_pretrained_merged("model", tokenizer, save_method = "lora",)
if False: model.push_to_hub_merged("hf/model", tokenizer, save_method = "lora", token = "")

### GGUF / llama.cpp Conversion
We now support saving to `GGUF` / `llama.cpp` natively! We clone `llama.cpp` and save to `q8_0` by default. Other methods such as `q4_k_m` are also supported. Use `save_pretrained_gguf` for local saving and `push_to_hub_gguf` for uploading to HF.

Some supported quant methods (full list on our [Wiki page](https://github.com/unslothai/unsloth/wiki#gguf-quantization-options)):
* `q8_0` - Fast conversion. High resource use, but generally acceptable.
* `q4_k_m` - Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q4_K.
* `q5_k_m` - Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q5_K.

[**NEW**] To finetune and auto export to Ollama, try our [Ollama notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3_(8B)-Ollama.ipynb)

In [None]:
# Save to 8bit Q8_0
if False: model.save_pretrained_gguf("model", tokenizer,)
# Remember to go to https://huggingface.co/settings/tokens for a token!
# And change hf to your username!
if False: model.push_to_hub_gguf("hf/model", tokenizer, token = "")

# Save to 16bit GGUF
if False: model.save_pretrained_gguf("model", tokenizer, quantization_method = "f16")
if False: model.push_to_hub_gguf("hf/model", tokenizer, quantization_method = "f16", token = "")

# Save to q4_k_m GGUF
if False: model.save_pretrained_gguf("model", tokenizer, quantization_method = "q4_k_m")
if False: model.push_to_hub_gguf("hf/model", tokenizer, quantization_method = "q4_k_m", token = "")

# Save to multiple GGUF options - much faster if you want multiple!
if False:
    model.push_to_hub_gguf(
        "hf/model", # Change hf to your username!
        tokenizer,
        quantization_method = ["q4_k_m", "q8_0", "q5_k_m",],
        token = "",
    )

Now, use the `model-unsloth.gguf` file or `model-unsloth-Q4_K_M.gguf` file in llama.cpp or a UI-based system like Jan or Open WebUI. You can install Jan [here](https://github.com/janhq/jan) and Open WebUI [here](https://github.com/open-webui/open-webui).

And we're done! If you have any questions about Unsloth, find a bug, need help, or want to join projects and keep up with the latest LLM news, feel free to join our [Discord](https://discord.gg/unsloth)!

Some other links:
1. Train your own reasoning model - Llama GRPO notebook [Free Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.1_(8B)-GRPO.ipynb)
2. Saving finetunes to Ollama. [Free notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3_(8B)-Ollama.ipynb)
3. Llama 3.2 Vision finetuning - Radiography use case. [Free Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.2_(11B)-Vision.ipynb)
4. See notebooks for DPO, ORPO, Continued pretraining, conversational finetuning and more on our [documentation](https://docs.unsloth.ai/get-started/unsloth-notebooks)!

<div class="align-center">
  <a href="https://unsloth.ai"><img src="https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png" width="115"></a>
  <a href="https://discord.gg/unsloth"><img src="https://github.com/unslothai/unsloth/raw/main/images/Discord.png" width="145"></a>
  <a href="https://docs.unsloth.ai/"><img src="https://github.com/unslothai/unsloth/blob/main/images/documentation%20green%20button.png?raw=true" width="125"></a>

  Join Discord if you need help + ⭐️ <i>Star us on <a href="https://github.com/unslothai/unsloth">Github</a> </i> ⭐️
</div>
