# GSPO Training for Tool-Calling

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ProfSynapse/Toolset-Training/blob/main/Trainers/notebooks/gspo_tool_calling_notebook.ipynb)

## What is GSPO?

**GSPO (Group Sequence Policy Optimization)** is a reinforcement learning technique developed by the Qwen team. Unlike SFT (which learns from examples) or KTO (which learns preferences), GSPO:

- **Generates completions** during training (online RL)
- **Scores outputs** using custom reward functions
- **Optimizes at sequence level** (not token level like GRPO)

## Why GSPO for Tool-Calling?

| Method | Best For | Tool-Calling Use Case |
|--------|----------|----------------------|
| **SFT** | Teaching new skills | Initial tool syntax learning |
| **KTO** | Refining preferences | Good vs bad tool choices |
| **GSPO** | Reward-guided optimization | Complex tool selection & argument quality |

GSPO shines when you have **complex evaluation criteria** that can be expressed as reward functions:
- Did the model select the RIGHT tool?
- Are the arguments structurally correct?
- Is the context object complete?
- Do argument values make semantic sense?

## Hardware Requirements

- **7B models**: T4 (15GB) - Free Colab works!
- **13B+ models**: A100 recommended
- **Training time**: ~60-90 minutes for 7B

## 1. Installation

Install Unsloth and dependencies (~2 minutes).

In [1]:
# Install Unsloth for faster training
%%capture
!pip install unsloth
!pip uninstall unsloth -y && pip install --upgrade --no-cache-dir "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"

In [2]:
# Install training dependencies
%%capture
!pip install -U "transformers>=4.45.0"
!pip install "datasets==4.3.0"
!pip install -U accelerate bitsandbytes
!pip install -U trl peft xformers triton

## 2. Mount Google Drive (Optional)

Save checkpoints to persist across sessions.

In [3]:
from google.colab import drive
import os

drive.mount('/content/drive')

DRIVE_OUTPUT_DIR = "/content/drive/MyDrive/GSPO_Training"
os.makedirs(DRIVE_OUTPUT_DIR, exist_ok=True)

print(f"Checkpoints will be saved to: {DRIVE_OUTPUT_DIR}")

Mounted at /content/drive
Checkpoints will be saved to: /content/drive/MyDrive/GSPO_Training


## 3. HuggingFace Credentials

Add your HF token to Colab secrets (key icon in sidebar).

In [4]:
import os
from google.colab import userdata
from huggingface_hub import HfApi

HF_TOKEN = userdata.get('HF_TOKEN')
os.environ['HF_TOKEN'] = HF_TOKEN

api = HfApi()
hf_user = api.whoami(token=HF_TOKEN)["name"]

print(f"HuggingFace username: {hf_user}")

HuggingFace username: professorsynapse


## 4. Configuration

In [5]:
# @title Model & Dataset Configuration

# @markdown ### Base Model
MODEL_NAME = "unsloth/mistral-7b-instruct-v0.3-bnb-4bit" # @param ["unsloth/Llama-3.2-1B-Instruct-bnb-4bit", "unsloth/Qwen2.5-3B-Instruct-bnb-4bit", "unsloth/mistral-7b-instruct-v0.3-bnb-4bit", "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit"]

# @markdown ### Sequence Length
MAX_SEQ_LENGTH = 2048 # @param [1024, 2048, 4096] {type:"raw"}

# @markdown ### Dataset
DATASET_NAME = "professorsynapse/claudesidian-behaviors-merged" # @param {type:"string"}
DATASET_FILE = "behavior_gspo_v1.3.jsonl" # @param {type:"string"}

# @markdown ### Output Model Name
OUTPUT_MODEL_NAME = "nexus-tools-sft16-gspo-1" # @param {type:"string"}

print(f"Model: {MODEL_NAME}")
print(f"Dataset: {DATASET_NAME}/{DATASET_FILE}")
print(f"Output: {OUTPUT_MODEL_NAME}")

Model: unsloth/mistral-7b-instruct-v0.3-bnb-4bit
Dataset: professorsynapse/claudesidian-behaviors-merged/behavior_gspo_v1.3.jsonl
Output: nexus-tools-sft16-gspo-1


## 5. Load Model

In [6]:
from unsloth import FastLanguageModel
import torch

print(f"GPU: {torch.cuda.get_device_name(0)}")
print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
print()

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=MODEL_NAME,
    max_seq_length=MAX_SEQ_LENGTH,
    dtype=None,
    load_in_4bit=True,
    token=HF_TOKEN,
)

print(f"Model loaded: {sum(p.numel() for p in model.parameters()):,} parameters")

ðŸ¦¥ Unsloth: Will patch your computer to enable 2x faster free finetuning.
ðŸ¦¥ Unsloth Zoo will now patch everything to make training faster!
GPU: NVIDIA L4
VRAM: 22.2 GB

==((====))==  Unsloth 2025.11.4: Fast Mistral patching. Transformers: 4.57.3.
   \\   /|    NVIDIA L4. Num GPUs = 1. Max memory: 22.161 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.9.0+cu126. CUDA: 8.9. CUDA Toolkit: 12.6. Triton: 3.5.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.33.post1. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


model.safetensors:   0%|          | 0.00/4.14G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/157 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.model:   0%|          | 0.00/587k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/446 [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

Model loaded: 3,758,362,624 parameters


## 6. Apply LoRA Adapters

In [7]:
# @title LoRA Configuration

r = 32 # @param [8, 16, 32, 64] {type:"raw"}
lora_alpha = r * 2
lora_dropout = 0.05 # @param {type:"number"}

model = FastLanguageModel.get_peft_model(
    model,
    r=r,
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    bias="none",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
    use_gradient_checkpointing="unsloth",
    random_state=3407,
)

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"LoRA applied: {trainable:,} trainable parameters")

Unsloth: Dropout = 0 is supported for fast patching. You are using dropout = 0.05.
Unsloth will patch all other layers, except LoRA matrices, causing a performance hit.
Unsloth 2025.11.4 patched 32 layers with 0 QKV layers, 0 O layers and 0 MLP layers.


LoRA applied: 83,886,080 trainable parameters


## 7. Load GSPO Dataset

The GSPO dataset format:
```json
{
  "prompt": [{"role": "system", ...}, {"role": "user", ...}],
  "ground_truth_tool": "toolName",
  "ground_truth_args": {...}
}
```

In [8]:
from datasets import load_dataset

dataset = load_dataset(
    DATASET_NAME,
    data_files=DATASET_FILE,
    split="train"
)

print(f"Loaded {len(dataset)} examples")
print(f"\nSample prompt:")
print(dataset[0]["prompt"])
print(f"\nGround truth tool: {dataset[0]['ground_truth_tool']}")

README.md: 0.00B [00:00, ?B/s]

behavior_gspo_v1.3.jsonl: 0.00B [00:00, ?B/s]

Generating train split: 0 examples [00:00, ? examples/s]

Loaded 1324 examples

Sample prompt:
[{'role': 'user', 'content': 'Check which files I modified today'}]

Ground truth tool: vaultLibrarian_searchDirectory


## 8. Define Reward Functions

These functions evaluate the model's generated completions and return rewards.

**Reward breakdown:**
- `tool_selection_reward`: +1.0 if correct tool selected
- `json_structure_reward`: +0.3 if valid JSON in tool call
- `context_completeness_reward`: +0.5 if all 7 context fields present
- `format_reward`: +0.2 if output follows expected format

In [9]:
import re
import json

# Required context fields
CONTEXT_FIELDS = [
    "sessionId", "workspaceId", "sessionDescription",
    "sessionMemory", "toolContext", "primaryGoal", "subgoal"
]


def extract_tool_call(text: str) -> tuple:
    """
    Extract tool name and arguments from model output.

    Handles multiple formats:
    - OpenAI function calling format
    - Text-based tool_call: format
    - JSON tool_calls array
    """
    tool_name = None
    tool_args = None

    # Try to find function call pattern
    # Pattern 1: tool_call: toolName\narguments: {...}
    tc_match = re.search(r'tool_call:\s*(\w+)\s*\narguments:\s*({[^}]+})', text, re.DOTALL)
    if tc_match:
        tool_name = tc_match.group(1)
        try:
            tool_args = json.loads(tc_match.group(2))
        except:
            tool_args = {}
        return tool_name, tool_args

    # Pattern 2: {"name": "toolName", "arguments": ...}
    fn_match = re.search(r'"name"\s*:\s*"([^"]+)".*?"arguments"\s*:\s*"?({[^}]+})"?', text, re.DOTALL)
    if fn_match:
        tool_name = fn_match.group(1)
        try:
            args_str = fn_match.group(2)
            # Handle escaped JSON
            if '\\"' in args_str:
                args_str = args_str.replace('\\"', '"')
            tool_args = json.loads(args_str)
        except:
            tool_args = {}
        return tool_name, tool_args

    # Pattern 3: Just look for any tool name pattern
    name_match = re.search(r'(\w+Manager_\w+|\w+Librarian_\w+)', text)
    if name_match:
        tool_name = name_match.group(1)
        # Try to find JSON after it
        json_match = re.search(r'{[^{}]*"context"[^{}]*}', text)
        if json_match:
            try:
                tool_args = json.loads(json_match.group(0))
            except:
                tool_args = {}

    return tool_name, tool_args


def tool_selection_reward(completions, ground_truth_tool, **kwargs):
    """
    Reward for selecting the correct tool.
    +1.0 if exact match, +0.3 if same agent family, 0.0 otherwise.
    """
    rewards = []
    for completion in completions:
        text = completion[0]["content"] if isinstance(completion, list) else completion
        tool_name, _ = extract_tool_call(text)

        if tool_name == ground_truth_tool:
            rewards.append(1.0)
        elif tool_name and ground_truth_tool:
            # Partial credit for same agent family (e.g., vaultManager_*)
            pred_agent = tool_name.split('_')[0] if '_' in tool_name else ''
            true_agent = ground_truth_tool.split('_')[0] if '_' in ground_truth_tool else ''
            if pred_agent == true_agent:
                rewards.append(0.3)
            else:
                rewards.append(0.0)
        else:
            rewards.append(0.0)

    return rewards


def json_structure_reward(completions, **kwargs):
    """
    Reward for valid JSON structure in tool call.
    +0.3 if valid JSON with arguments, 0.0 otherwise.
    """
    rewards = []
    for completion in completions:
        text = completion[0]["content"] if isinstance(completion, list) else completion
        _, tool_args = extract_tool_call(text)

        if tool_args and isinstance(tool_args, dict) and len(tool_args) > 0:
            rewards.append(0.3)
        else:
            rewards.append(0.0)

    return rewards


def context_completeness_reward(completions, **kwargs):
    """
    Reward for complete context object.
    Up to +0.5 based on how many of the 7 required fields are present.
    """
    rewards = []
    for completion in completions:
        text = completion[0]["content"] if isinstance(completion, list) else completion
        _, tool_args = extract_tool_call(text)

        if not tool_args or "context" not in tool_args:
            rewards.append(0.0)
            continue

        context = tool_args.get("context", {})
        if not isinstance(context, dict):
            rewards.append(0.0)
            continue

        # Count present fields
        present = sum(1 for field in CONTEXT_FIELDS if field in context and context[field])
        reward = (present / len(CONTEXT_FIELDS)) * 0.5
        rewards.append(reward)

    return rewards


def format_reward(completions, **kwargs):
    """
    Reward for proper output format.
    +0.2 if output contains tool call structure, 0.0 otherwise.
    """
    rewards = []
    for completion in completions:
        text = completion[0]["content"] if isinstance(completion, list) else completion

        # Check for tool call indicators
        has_tool_call = (
            'tool_call' in text.lower() or
            'function' in text.lower() or
            'arguments' in text.lower() or
            re.search(r'\w+Manager_\w+|\w+Librarian_\w+', text)
        )

        rewards.append(0.2 if has_tool_call else 0.0)

    return rewards


# Combine all reward functions
REWARD_FUNCTIONS = [
    tool_selection_reward,
    json_structure_reward,
    context_completeness_reward,
    format_reward,
]

print(f"Defined {len(REWARD_FUNCTIONS)} reward functions")
print("  - tool_selection_reward (max +1.0)")
print("  - json_structure_reward (max +0.3)")
print("  - context_completeness_reward (max +0.5)")
print("  - format_reward (max +0.2)")
print(f"  Total max reward: 2.0")

Defined 4 reward functions
  - tool_selection_reward (max +1.0)
  - json_structure_reward (max +0.3)
  - context_completeness_reward (max +0.5)
  - format_reward (max +0.2)
  Total max reward: 2.0


## 9. Configure GSPO Training

GSPO uses `GRPOTrainer` with `importance_sampling_level="sequence"`.

In [11]:
from trl import GRPOConfig, GRPOTrainer
from unsloth import is_bfloat16_supported
from datetime import datetime

timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
output_dir = f"{DRIVE_OUTPUT_DIR}/{timestamp}"

# @title GSPO Hyperparameters

# @markdown ### Performance
per_device_train_batch_size = 4 # @param [1, 2, 4] {type:"raw"}
gradient_accumulation_steps = 4 # @param [4, 8, 16] {type:"raw"}

# @markdown ### Generation
num_generations = 4 # @param [2, 4, 8] {type:"raw"}
max_new_tokens = 512 # @param [256, 512, 1024] {type:"raw"}

# @markdown ### Learning
learning_rate_exponent = 6 # @param [5, 6, 7] {type:"raw"}
learning_rate_multiplier = 5 # @param [1, 2, 3, 4, 5] {type:"raw"}
learning_rate = learning_rate_multiplier * (10 ** -learning_rate_exponent)

num_train_epochs = 1 # @param {type:"integer"}

# @markdown ### GSPO-Specific
# @markdown epsilon controls the clipping range for policy updates
epsilon = 0.0003 # @param {type:"number"}
epsilon_high = 0.0004 # @param {type:"number"}

training_args = GRPOConfig(
    output_dir=output_dir,

    # Batch settings
    per_device_train_batch_size=per_device_train_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,

    # Generation settings
    num_generations=num_generations,
    max_completion_length=max_new_tokens, # Changed from max_new_tokens

    # GSPO: sequence-level importance sampling
    importance_sampling_level="sequence",

    # Learning settings
    learning_rate=learning_rate,
    num_train_epochs=num_train_epochs,

    # Precision
    fp16=not is_bfloat16_supported(),
    bf16=is_bfloat16_supported(),

    # Optimizer
    optim="adamw_8bit",

    # Logging
    logging_steps=5,
    save_steps=50,
    save_total_limit=3,

    # Misc
    seed=42,
    report_to="none",
)

print(f"GSPO Configuration:")
print(f"  Batch: {per_device_train_batch_size} x {gradient_accumulation_steps} = {per_device_train_batch_size * gradient_accumulation_steps}")
print(f"  Generations per prompt: {num_generations}")
print(f"  Learning rate: {learning_rate}")
print(f"  Importance sampling: sequence (GSPO)")
print(f"  Output: {output_dir}")

GSPO Configuration:
  Batch: 4 x 4 = 16
  Generations per prompt: 4
  Learning rate: 4.9999999999999996e-06
  Importance sampling: sequence (GSPO)
  Output: /content/drive/MyDrive/GSPO_Training/20251126_185254


## 10. Prepare Dataset for GRPO

Format the dataset so GRPOTrainer can use the ground truth for rewards.

In [13]:
# Mistral chat template
MISTRAL_CHAT_TEMPLATE = """{{ bos_token }}{% for message in messages %}{% if message['role'] == 'system' %}{% if loop.index == 1 %}{{ message['content'] + ' ' }}{% endif %}{% elif message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + eos_token }}{% endif %}{% endfor %}"""

if tokenizer.chat_template is None:
    if 'mistral' in MODEL_NAME.lower():
        tokenizer.chat_template = MISTRAL_CHAT_TEMPLATE
        print("Applied Mistral chat template")


def format_for_grpo(example):
    """
    Format example for GRPOTrainer.

    GRPOTrainer expects:
    - prompt: The formatted input text
    - Additional fields for reward computation
    """
    # Format prompt using chat template
    prompt_text = tokenizer.apply_chat_template(
        example["prompt"],
        tokenize=False,
        add_generation_prompt=True  # Add prompt for model to continue
    )

    return {
        "prompt": prompt_text,
        "ground_truth_tool": example["ground_truth_tool"],
        # Added default=str to handle datetime objects
        "ground_truth_args": json.dumps(example["ground_truth_args"], default=str) if example["ground_truth_args"] else "{}"
    }


# Apply formatting
formatted_dataset = dataset.map(
    format_for_grpo,
    remove_columns=["prompt"],  # Remove original, keep ground_truth fields
    desc="Formatting for GRPO"
)

print(f"Dataset formatted: {len(formatted_dataset)} examples")
print(f"\nSample formatted prompt (truncated):")
print(formatted_dataset[0]["prompt"][:300])

Formatting for GRPO:   0%|          | 0/1324 [00:00<?, ? examples/s]

Dataset formatted: 1324 examples

Sample formatted prompt (truncated):
<s>[INST] Check which files I modified today[/INST]


## 11. Create Custom Reward Function

Combine all reward functions into one that GRPOTrainer can use.

In [14]:
def combined_reward_function(completions, prompts, ground_truth_tool, **kwargs):
    """
    Combined reward function for tool-calling evaluation.

    Returns total reward (max 2.0) for each completion.
    """
    batch_size = len(completions)
    total_rewards = [0.0] * batch_size

    # Get rewards from each function
    tool_rewards = tool_selection_reward(completions, ground_truth_tool)
    json_rewards = json_structure_reward(completions)
    context_rewards = context_completeness_reward(completions)
    format_rewards = format_reward(completions)

    # Combine rewards
    for i in range(batch_size):
        total_rewards[i] = (
            tool_rewards[i] +
            json_rewards[i] +
            context_rewards[i] +
            format_rewards[i]
        )

    return total_rewards


# Test the reward function
test_completion = [[{"content": 'tool_call: vaultManager_listDirectory\narguments: {"context": {"sessionId": "test", "workspaceId": "test", "sessionDescription": "test", "sessionMemory": "test", "toolContext": "test", "primaryGoal": "test", "subgoal": "test"}, "path": "/"}'}]]
test_reward = combined_reward_function(
    test_completion,
    prompts=["test"],
    ground_truth_tool="vaultManager_listDirectory"
)
print(f"Test reward (perfect completion): {test_reward[0]:.2f} / 2.0")

Test reward (perfect completion): 1.20 / 2.0


## 12. Initialize GRPO Trainer

In [15]:
trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=combined_reward_function,
    args=training_args,
    train_dataset=formatted_dataset,
)

print("GSPO Trainer initialized!")
print(f"  Using sequence-level importance sampling (GSPO)")
print(f"  {num_generations} generations per prompt")
print(f"  Reward function: combined_reward_function (max 2.0)")

GSPO Trainer initialized!
  Using sequence-level importance sampling (GSPO)
  4 generations per prompt
  Reward function: combined_reward_function (max 2.0)


## 13. Train!

GSPO training will:
1. Generate multiple completions for each prompt
2. Score them using the reward functions
3. Update the model to favor higher-reward completions

**Expected metrics:**
- **reward/mean**: Average reward across generations (should increase)
- **reward/std**: Reward variance (should decrease as model converges)
- **loss**: Policy loss (can fluctuate, watch reward instead)

In [None]:
import torch

# Memory check
gpu_stats = torch.cuda.get_device_properties(0)
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
print(f"GPU: {gpu_stats.name}")
print(f"Memory reserved: {start_gpu_memory} GB / {gpu_stats.total_memory / 1024**3:.1f} GB")
print()

# Check for checkpoints
import glob
checkpoint_dirs = sorted(glob.glob(f"{output_dir}/checkpoint-*"))
resume_from = max(checkpoint_dirs, key=lambda x: int(x.split("-")[-1])) if checkpoint_dirs else None

if resume_from:
    print(f"Resuming from: {resume_from}")
else:
    print("Starting fresh training")

print("\n" + "=" * 60)
print("STARTING GSPO TRAINING")
print("=" * 60 + "\n")

trainer_stats = trainer.train(resume_from_checkpoint=resume_from)

print("\n" + "=" * 60)
print("TRAINING COMPLETED")
print("=" * 60)

The model is already on multiple devices. Skipping the move to device specified in `args`.


GPU: NVIDIA L4
Memory reserved: 6.766 GB / 22.2 GB

Starting fresh training

STARTING GSPO TRAINING



==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 1,324 | Num Epochs = 1 | Total steps = 331
O^O/ \_/ \    Batch size per device = 4 | Gradient accumulation steps = 4
\        /    Data Parallel GPUs = 1 | Total batch size (4 x 4 x 1) = 16
 "-____-"     Trainable parameters = 83,886,080 of 7,331,909,632 (1.14% trained)


## 14. Save & Upload Model

In [None]:
# Save locally
model.save_pretrained(f"{output_dir}/final_model")
tokenizer.save_pretrained(f"{output_dir}/final_model")
print(f"Saved to: {output_dir}/final_model")

In [None]:
# Upload LoRA adapters to HuggingFace
model.push_to_hub(
    f"{hf_user}/{OUTPUT_MODEL_NAME}",
    token=HF_TOKEN,
    private=False
)
tokenizer.push_to_hub(
    f"{hf_user}/{OUTPUT_MODEL_NAME}",
    token=HF_TOKEN,
    private=False
)

print(f"Uploaded to: https://huggingface.co/{hf_user}/{OUTPUT_MODEL_NAME}")

In [None]:
# Upload merged model + GGUF quantizations
print("Creating merged model and GGUF quantizations...")
print("This takes ~10 minutes")

model.push_to_hub_merged(
    f"{hf_user}/{OUTPUT_MODEL_NAME}-merged",
    tokenizer,
    save_method="merged_16bit",
    token=HF_TOKEN,
    private=False
)

model.push_to_hub_gguf(
    f"{hf_user}/{OUTPUT_MODEL_NAME}",
    tokenizer,
    quantization_method=["q4_k_m", "q5_k_m", "q8_0"],
    token=HF_TOKEN,
)

print(f"\nAll models uploaded!")
print(f"  LoRA: https://huggingface.co/{hf_user}/{OUTPUT_MODEL_NAME}")
print(f"  Merged: https://huggingface.co/{hf_user}/{OUTPUT_MODEL_NAME}-merged")
print(f"  GGUF: https://huggingface.co/{hf_user}/{OUTPUT_MODEL_NAME} (Files tab)")

## Done!

Your GSPO-trained model is ready!

### What You Accomplished

- Trained a model using **GSPO reinforcement learning**
- Used **custom reward functions** for tool-calling evaluation
- Created **multiple model formats** (LoRA, merged, GGUF)

### Next Steps

1. **Test** with LM Studio or Ollama
2. **Compare** GSPO vs SFT performance on tool selection accuracy
3. **Iterate** on reward functions to improve specific behaviors