<a href="https://colab.research.google.com/github/Vinooj/llm-fine_tuning-experiments/blob/main/fine_tuning_for_tool_support_with_grpo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Using GRPO (Group Relative Policy Optimization) for tool-calling

This notebook fine-tunes the Phi-3-mini-4k-instruct model for general-purpose tool calling using GRPO, with Unsloth providing 4-bit loading and LoRA adapters.

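It includes a minimal dataset (with one negative example where no tool should be called), a prompt builder that lists each example's available tools in the system message, and a reward function that checks generated tool calls against those tool definitions: the tool must exist, every required argument must be present, and undeclared arguments are penalized.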
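GRPO does not train a separate value model. For each prompt it samples a group of completions, scores each one with the reward function, and turns the scores into advantages by comparing every completion with the rest of its group: completions that beat their siblings are reinforced, the others are discouraged. A minimal sketch of that normalization in plain Python, assuming the common form (r - mean) / (std + eps) with the sample standard deviation; TRL's exact choices (eps, std estimator, optional scaling) vary between versions.

```python
import statistics

def group_relative_advantages(rewards: list[float], eps: float = 1e-4) -> list[float]:
    """Normalize one prompt's group of rewards: (r - mean) / (std + eps)."""
    mean = statistics.fmean(rewards)
    # Sample standard deviation (n - 1); a group of identical rewards gives std 0,
    # so every advantage is 0 and that prompt contributes no learning signal.
    std = statistics.stdev(rewards) if len(rewards) > 1 else 0.0
    return [(r - mean) / (std + eps) for r in rewards]

# Four completions for the same weather prompt, scored with the rules used later:
# correct call (+3), missing required argument (0), no call (-2), malformed JSON (-5).
advantages = group_relative_advantages([3.0, 0.0, -2.0, -5.0])
print([round(a, 3) for a in advantages])
# → [1.188, 0.297, -0.297, -1.188]
```

Only the relative ordering within a group matters, which is why the reward scale in this notebook (+5 to -5) can be chosen freely as long as better behavior scores higher.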

### Cell 1: Setup and Installation

In [None]:
%%capture
# `--torch-backend=auto` is a uv option (plain pip does not support it): uv inspects the
# installed CUDA driver and selects the matching PyTorch wheel index automatically.
# Manual alternative with pip:
# !pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
!pip install uv
!uv pip install --system vllm torch torchvision torchaudio --torch-backend=auto

# Install core packages without dependencies (to avoid version conflicts)
!pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl

# Triton must match the installed torch build, and torch already installs its own triton.
# Forcing an old release such as 2.1.0 can break recent torch and Unsloth kernels, so only
# pin a version you know matches your torch:
# !pip install triton==2.1.0 --no-deps

# Install Unsloth and its companion packages
!pip install --no-deps cut_cross_entropy unsloth_zoo
!pip install --no-deps unsloth

# Install remaining packages with dependencies (these are generally stable)
!pip install sentencepiece protobuf datasets huggingface_hub hf_transfer

In [None]:
import torch
import triton
import unsloth
print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"Triton: {triton.__version__}")
print("All packages installed successfully!")

### Imports

In [None]:
# Import Unsloth first so it can patch transformers/TRL before they are loaded.
# FastLanguageModel replaces transformers' AutoModelForCausalLM.from_pretrained and adds
# optimized kernels, memory-efficient loading, and automatic precision (dtype) selection
# to improve training speed.
from unsloth import FastLanguageModel

import torch
import json
import re
from datasets import Dataset
from trl import GRPOTrainer, GRPOConfig
from google.colab import userdata

### Cell 2: Imports and Model Loading with Unsloth
We load the model with Unsloth's FastLanguageModel instead of plain Hugging Face transformers. Unsloth patches the model with optimized kernels and loads it in 4-bit, which speeds up training and reduces GPU memory use.
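Most of the memory saving comes from loading the weights in 4-bit. A back-of-the-envelope estimate of weight memory alone, assuming Phi-3-mini has roughly 3.8 billion parameters and ignoring quantization metadata, layers Unsloth keeps in higher precision, activations, the KV cache, and optimizer state:

```python
def weight_memory_gb(num_params: float, bits_per_param: int) -> float:
    """Approximate memory for the model weights alone, in gigabytes (1e9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9

PHI3_MINI_PARAMS = 3.8e9  # approximate parameter count (assumption)

for bits in (16, 4):
    print(f"{bits}-bit weights: ~{weight_memory_gb(PHI3_MINI_PARAMS, bits):.1f} GB")
# → 16-bit weights: ~7.6 GB
# → 4-bit weights: ~1.9 GB
```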

In [None]:
# More info about parameters: https://huggingface.co/docs/peft/v0.11.0/en/package_reference/lora#peft.LoraConfig
target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                  "gate_proj", "up_proj", "down_proj"]

# Set to True only if you add new special tokens to the tokenizer: the input
# embeddings and the output head then need to be trained as well.
train_embeddings = False

if train_embeddings:
  target_modules = target_modules + ["embed_tokens", "lm_head"]

# Sets the maximum number of tokens that this specific instance of the model and
# its tokenizer will be configured to handle during our finetuning and subsequent
# inference.
max_seq_length = 2048


# dtype = None tells Unsloth to choose the precision automatically for the available
# GPU: bfloat16 where the hardware supports it (Ampere or newer), otherwise float16.
# Both are faster and more memory-efficient than float32.
dtype = None


model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/phi-3-mini-4k-instruct", # Use the Unsloth version for optimizations
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit = True,
    token=userdata.get('HF_TOKEN')
)


# PEFT ("Parameter-Efficient Fine-Tuning"): Unsloth integrates with PEFT methods such as LoRA and QLoRA
model = FastLanguageModel.get_peft_model(
    model,
    r = 16,              # LoRA rank. 16 is a common balance between expressiveness and parameter
                         # efficiency; a higher rank adds more adapter parameters, allowing larger
                         # changes but slightly increasing the risk of overfitting.
    target_modules = target_modules,  # Modules of the LLM that receive LoRA adapters
    lora_alpha = 16,     # Scales the adapter output (higher = more influence on the base model).
                         # Setting lora_alpha = r, i.e. lora_alpha / r = 1, is a common default.
    lora_dropout = 0,    # Dropout regularization on the adapters. Many tutorials use 0.05, but
                         # Unsloth recommends 0, which is its optimized setting.
    bias = "none",       # Do not train bias terms; "none" is Unsloth's optimized setting and saves VRAM.
    use_gradient_checkpointing = "unsloth",  # Unsloth's gradient checkpointing: reduces VRAM, useful for long contexts.
    random_state = 3407,
    use_rslora = False,  # Rank-Stabilized LoRA: when True, the adapter output is scaled by
                         # lora_alpha / sqrt(r) instead of lora_alpha / r.
    loftq_config = None, # No LoftQ initialization
)

# The Phi-3 Instruct model already has a chat template, so we don't need to set it manually.
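To put r, target_modules, and use_rslora into numbers: a LoRA adapter on a d_in × d_out linear layer adds r × (d_in + d_out) trainable parameters, and its output is multiplied by lora_alpha / r (or lora_alpha / sqrt(r) with rsLoRA). The sketch below assumes Phi-3-mini's dimensions (hidden size 3072, MLP size 8192, 32 layers, full-size key/value projections); check model.config if you use a different checkpoint.

```python
import math

def lora_params(d_in: int, d_out: int, r: int) -> int:
    """Trainable parameters of one LoRA adapter: A is r x d_in, B is d_out x r."""
    return r * (d_in + d_out)

# Assumed Phi-3-mini dimensions and the LoRA settings used above
hidden, mlp, layers, r, alpha = 3072, 8192, 32, 16, 16

per_layer = (
    4 * lora_params(hidden, hidden, r)   # q_proj, k_proj, v_proj, o_proj
    + 2 * lora_params(hidden, mlp, r)    # gate_proj, up_proj
    + lora_params(mlp, hidden, r)        # down_proj
)
total = per_layer * layers

print(f"Trainable LoRA parameters: {total:,}")                 # 29,884,416
print(f"Standard LoRA scale alpha/r: {alpha / r}")             # 1.0
print(f"rsLoRA scale alpha/sqrt(r): {alpha / math.sqrt(r)}")   # 4.0
```

The estimate can be compared with the count reported by model.print_trainable_parameters() on the PEFT model.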

### Cell 3: The Training Dataset
This section defines a small, diverse dataset for teaching tool use, including a negative example where the model should answer without calling a tool.

In [None]:
# @title 3. The Training Dataset
# A minimal, diverse dataset to teach tool use and when NOT to use tools.
training_data = [
    {
        "prompt": "What is the weather like in New York City?",
        "is_negative": False,
        "tools": [
            {
                "name": "get_weather",
                "description": "Fetches the current weather for a given city.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "city": {"type": "string", "description": "The city name."}
                    },
                    "required": ["city"],
                },
            }
        ],
    },
    {
        "prompt": "Please send an email to john.doe@example.com with the subject 'Hello' and body 'How are you?'",
        "is_negative": False,
        "tools": [
            {
                "name": "send_email",
                "description": "Sends an email.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "recipient": {"type": "string"},
                        "subject": {"type": "string"},
                        "body": {"type": "string"},
                    },
                    "required": ["recipient", "subject", "body"],
                },
            }
        ],
    },
    {
        "prompt": "Hello, how are you today?",
        "is_negative": True, # This is a negative sample. No tool should be called.
        "tools": [
            {"name": "get_weather", "description": "Gets the weather."},
            {"name": "send_email", "description": "Sends an email."},
        ],
    },
    {
        "prompt": "What's the current price of the AAPL stock?",
        "is_negative": False,
        "tools": [
            {"name": "search_news", "description": "Searches for news articles."},
            {
                "name": "get_stock_price",
                "description": "Gets the current price of a stock symbol.",
                "parameters": {
                    "type": "object",
                    "properties": {"symbol": {"type": "string"}},
                    "required": ["symbol"],
                },
            },
        ],
    },
]

# Convert to Hugging Face Dataset object
dataset = Dataset.from_list(training_data)
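The rows above have differently shaped tools entries (some tools have parameters, some do not, and each has its own properties), while Arrow stores a column with one shared schema, filling keys a row lacks with None. An alternative layout, not the one this notebook uses, is to store each row's tool list as a JSON string and decode it where it is needed. A minimal stdlib sketch of that round trip, using hypothetical rows:

```python
import json

# Hypothetical rows in the same shape as training_data
rows = [
    {"prompt": "What is the weather like in Paris?", "is_negative": False,
     "tools": [{"name": "get_weather",
                "parameters": {"type": "object",
                               "properties": {"city": {"type": "string"}},
                               "required": ["city"]}}]},
    {"prompt": "Hi there!", "is_negative": True,
     "tools": [{"name": "send_email", "description": "Sends an email."}]},
]

# Encode: every "tools" value becomes a plain string, so all rows share one column type
encoded = [{**row, "tools": json.dumps(row["tools"])} for row in rows]

# Decode wherever the schemas are needed (prompt builder, reward function)
decoded_tools = [json.loads(row["tools"]) for row in encoded]
print(type(encoded[0]["tools"]).__name__, decoded_tools[1])
```

With this layout you would pass the encoded rows to Dataset.from_list and call json.loads(sample["tools"]) at the start of the prompt builder and the reward function.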

### Cell 4: Dynamic Prompt Formatting
This function dynamically creates the system prompt with the available tools for each specific training example.

In [None]:
# @title 4. Dynamic Prompt Formatting
def drop_nulls(obj):
    """Recursively drop None values, which Arrow adds for keys a row's tool dict lacks."""
    if isinstance(obj, dict):
        return {k: drop_nulls(v) for k, v in obj.items() if v is not None}
    if isinstance(obj, list):
        return [drop_nulls(v) for v in obj]
    return obj

# This function creates the full prompt the model will see.
def create_prompt(sample):
    system_message = (
        "You are a helpful assistant. You have access to the following tools. "
        "When a user's request can be fulfilled by a tool, respond with a tool call in the format: "
        "<tool_call>{\"name\": \"tool_name\", \"arguments\": {\"arg1\": \"value1\"}}</tool_call>. "
        "If no tool is appropriate, answer conversationally."
    )
    # Add the JSON representation of the tools to the system message, with Arrow's
    # None placeholders removed so each tool shows only its own schema
    system_message += "\nAvailable Tools:\n" + json.dumps(drop_nulls(sample["tools"]), indent=2)

    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": sample["prompt"]},
    ]

    # Render the chat template as text and append the assistant marker for generation
    prompt_str = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    sample["formatted_prompt"] = prompt_str
    return sample

# Apply the formatting to the entire dataset
dataset = dataset.map(create_prompt)
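The system prompt asks for calls of the form <tool_call>{...}</tool_call>. A small hypothetical helper that renders exactly that string is handy for writing reference completions or unit-testing the reward function; the round trip below checks that the same regex the reward function uses recovers the call.

```python
import json
import re

def format_tool_call(name: str, arguments: dict) -> str:
    """Hypothetical helper: render a call in the <tool_call>{JSON}</tool_call> format."""
    return "<tool_call>" + json.dumps({"name": name, "arguments": arguments}) + "</tool_call>"

call = format_tool_call("get_weather", {"city": "New York City"})
print(call)
# → <tool_call>{"name": "get_weather", "arguments": {"city": "New York City"}}</tool_call>

# Parse it back with the same pattern the reward function uses
match = re.search(r"<tool_call>(.*?)</tool_call>", call, re.DOTALL)
parsed = json.loads(match.group(1))
print(parsed["name"], parsed["arguments"])
```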

### Cell 5: The General-Purpose Reward Function
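This is the core of the GRPO setup: the reward function is the only training signal, so it defines what good tool use means. It scores each generated response on whether it calls a tool listed for that example, supplies every required argument, avoids undeclared arguments, and refrains from calling tools on negative examples.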

In [None]:
# @title 5. The General-Purpose Reward Function
def reward_function(generated_responses: list[str], sample: dict) -> torch.Tensor:
    """Score each generated response for one sample; higher is better."""
    rewards = []
    tool_schemas = {tool['name']: tool for tool in sample['tools']}
    is_negative_sample = sample.get('is_negative', False)

    for response in generated_responses:
        score = 0.0
        # Regex to find the (first) tool call within the response
        match = re.search(r"<tool_call>(.*?)</tool_call>", response, re.DOTALL)

        if match:
            # If a tool call is found
            if is_negative_sample:
                # Penalty for calling a tool on a negative sample
                score -= 5.0
            else:
                try:
                    tool_call_obj = json.loads(match.group(1))
                    # Valid JSON that is not a {"name": str, "arguments": {...}} object
                    # counts as malformed (json.JSONDecodeError is a ValueError subclass)
                    if not isinstance(tool_call_obj, dict):
                        raise ValueError("tool call must be a JSON object")
                    tool_name = tool_call_obj.get("name")
                    arguments = tool_call_obj.get("arguments") or {}
                    if not isinstance(tool_name, str) or not isinstance(arguments, dict):
                        raise ValueError("name must be a string and arguments an object")

                    if tool_name in tool_schemas:
                        score += 2.0  # Reward for selecting a valid tool
                        # `or {}` / `or []` also cover keys that Arrow filled with None
                        schema = tool_schemas[tool_name].get("parameters") or {}
                        required_args = schema.get("required") or []
                        # Ignore None-valued properties merged in from other tools' schemas
                        properties = {k: v for k, v in (schema.get("properties") or {}).items()
                                      if v is not None}

                        # Reward for providing all required arguments
                        for req_arg in required_args:
                            if req_arg in arguments:
                                score += 1.0
                            else:
                                score -= 2.0 # Heavy penalty for missing required arg

                        # Penalize extraneous arguments
                        for arg in arguments:
                            if arg not in properties:
                                score -= 1.0
                    else:
                        score -= 3.0 # Penalty for calling a non-existent tool
                except ValueError:
                    score -= 5.0 # Heavy penalty for malformed JSON
        else:
            # If no tool call is found
            if is_negative_sample:
                score += 5.0 # High reward for correctly ignoring a negative sample
            else:
                # Penalty for not calling a tool when one was likely needed
                score -= 2.0

        rewards.append(score)

    return torch.tensor(rewards, dtype=torch.float32)
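Because the reward is the only training signal, it is worth checking that it ranks behaviors as intended before starting a run. The sketch below is a hypothetical torch-free restatement of the scoring rules above (one float per response instead of a tensor), applied to hand-written completions for the weather tool.

```python
import json
import re

def score(response: str, tools: list[dict], is_negative: bool) -> float:
    """Hypothetical torch-free mirror of the reward rules, for sanity checks."""
    match = re.search(r"<tool_call>(.*?)</tool_call>", response, re.DOTALL)
    if not match:
        return 5.0 if is_negative else -2.0   # no call
    if is_negative:
        return -5.0                            # any call on a negative sample
    try:
        call = json.loads(match.group(1))
    except json.JSONDecodeError:
        return -5.0                            # malformed JSON
    if not isinstance(call, dict):
        return -5.0
    schemas = {t["name"]: t.get("parameters") or {} for t in tools}
    name, args = call.get("name"), call.get("arguments") or {}
    if not isinstance(name, str) or name not in schemas:
        return -3.0                            # unknown tool
    if not isinstance(args, dict):
        return -5.0
    schema = schemas[name]
    total = 2.0                                # valid tool selected
    total += sum(1.0 if req in args else -2.0 for req in schema.get("required", []))
    total -= sum(1.0 for arg in args if arg not in schema.get("properties", {}))
    return total

weather_tools = [{"name": "get_weather", "parameters": {
    "type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}}]

cases = {
    "correct call": '<tool_call>{"name": "get_weather", "arguments": {"city": "Paris"}}</tool_call>',
    "missing argument": '<tool_call>{"name": "get_weather", "arguments": {}}</tool_call>',
    "extra argument": '<tool_call>{"name": "get_weather", "arguments": {"city": "Paris", "unit": "C"}}</tool_call>',
    "unknown tool": '<tool_call>{"name": "get_time", "arguments": {}}</tool_call>',
    "malformed JSON": '<tool_call>{"name": get_weather}</tool_call>',
    "no call": "It is probably sunny.",
}
for label, response in cases.items():
    print(f"{label:>16}: {score(response, weather_tools, is_negative=False):+.1f}")
```

Correct calls score highest (+3.0), a call with an extra argument still beats one missing a required argument (+2.0 vs 0.0), and malformed JSON is punished hardest (-5.0), which is the ordering GRPO's within-group comparison relies on.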

### Cell 6: GRPO Trainer Setup and Execution
Here, we configure and launch the GRPOTrainer. Unsloth's optimized kernels should make this faster and more memory-efficient than a plain Hugging Face setup, though the actual speedup depends on the GPU and settings.

In [None]:
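# @title 6. GRPO Trainer Setup and Execution
from unsloth import is_bfloat16_supported

# GRPO Configuration
grpo_config = GRPOConfig(
    output_dir="./grpo_phi3_tool_caller",
    num_train_epochs=5,             # Increased epochs for better learning on a small dataset
    per_device_train_batch_size=4,  # Completions per device step; must divide evenly into groups of num_generations
    gradient_accumulation_steps=4,
    num_generations=4,              # Completions sampled per prompt and compared as a group
    learning_rate=2e-4,
    logging_steps=1,
    max_prompt_length=1024,         # Prompt + completion must fit within max_seq_length (2048)
    max_completion_length=512,
    beta=0.1,                       # KL penalty coefficient toward the reference model
    bf16=is_bfloat16_supported(),   # bfloat16 on GPUs that support it (Ampere or newer) ...
    fp16=not is_bfloat16_supported(),  # ... otherwise float16 (e.g. T4)
    optim="adamw_8bit",             # Use 8-bit AdamW optimizer
)

# TRL's GRPOTrainer reads prompts from a "prompt" column, so expose the chat-formatted
# prompt under that name (the "tools" and "is_negative" columns are kept).
train_dataset = dataset.remove_columns("prompt").rename_column("formatted_prompt", "prompt")

# TRL calls each reward function with the batch of prompts and completions, plus every other
# dataset column as a keyword argument (one value per completion), and expects one float per
# completion. This adapter forwards each completion to reward_function defined above.
def grpo_reward(prompts, completions, tools, is_negative, **kwargs):
    return [
        reward_function([completion], {"tools": tool_list, "is_negative": negative}).item()
        for completion, tool_list, negative in zip(completions, tools, is_negative)
    ]

# Initialize the GRPO Trainer
trainer = GRPOTrainer(
    model=model,
    args=grpo_config,
    processing_class=tokenizer,
    train_dataset=train_dataset,
    reward_funcs=[grpo_reward],
)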

# Start training!
print("Starting GRPO training with Unsloth...")
trainer.train()
print("Training complete!")

# Save the trained LoRA adapter
trainer.save_model("./grpo_phi3_tool_caller/final_adapter")
print("Model adapter saved!")
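Two pieces of arithmetic constrain a GRPO config. GRPO needs at least two completions per prompt to compare, and TRL requires the batch of generated completions to split evenly into groups of num_generations (the exact rule varies by TRL version; the sketch checks the stricter per-device form, per_device_train_batch_size × number of processes). The prompt and completion budgets must also fit in the model's max_seq_length. A hypothetical pre-flight check under those assumptions:

```python
def check_grpo_config(per_device_train_batch_size: int, num_generations: int,
                      max_prompt_length: int, max_completion_length: int,
                      max_seq_length: int, num_processes: int = 1) -> list[str]:
    """Hypothetical pre-flight check; returns a list of problems (empty if consistent)."""
    problems = []
    if num_generations < 2:
        problems.append("num_generations must be at least 2 for group-relative advantages")
    global_batch = per_device_train_batch_size * num_processes
    if global_batch % num_generations != 0:
        problems.append(f"batch of {global_batch} completions is not a multiple of "
                        f"num_generations={num_generations}")
    if max_prompt_length + max_completion_length > max_seq_length:
        problems.append(f"prompt ({max_prompt_length}) + completion ({max_completion_length}) "
                        f"tokens exceed max_seq_length={max_seq_length}")
    return problems

# Batch size 1 with 8 generations and a 2048-token prompt budget fails both checks
print(check_grpo_config(1, 8, 2048, 512, 2048))
# Four completions per step, grouped in fours, within a 2048-token context passes
print(check_grpo_config(4, 4, 1024, 512, 2048))
```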

### Cell 7: Inference
Finally, let's test the trained model with a prompt that needs two tool calls, which goes beyond the single-call examples seen during training.

In [None]:
# @title 7. Inference
# Switch Unsloth into its faster inference mode
FastLanguageModel.for_inference(model)

# Keep the tokenizer's built-in Phi-3 chat template: create_prompt uses this same
# tokenizer, so the inference prompt is formatted exactly as during training.

# Create a test sample
test_sample = {
    "prompt": "I need to know the weather in London and also send a confirmation email to supervisor@company.com with the subject 'Task Complete'",
    "tools": [
        {
            "name": "get_weather",
            "description": "Fetches the current weather.",
            "parameters": {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
        },
        {
            "name": "send_email",
            "description": "Sends an email.",
            "parameters": {"type": "object", "properties": {"recipient": {"type": "string"}, "subject": {"type": "string"}}, "required": ["recipient", "subject"]},
        },
    ]
}

# Format the prompt using the same function as in training
formatted_prompt = create_prompt(test_sample)["formatted_prompt"]
inputs = tokenizer(formatted_prompt, return_tensors="pt", add_special_tokens=False).to("cuda")

outputs = model.generate(**inputs, max_new_tokens=200, use_cache=True)
# Decode only the newly generated tokens (everything after the prompt)
new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
response_text = tokenizer.decode(new_tokens, skip_special_tokens=True)

print("\nUser Prompt:\n", test_sample["prompt"])
print("\nModel Response:\n", response_text)
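The test prompt needs two calls, while the reward function only inspects the first <tool_call> block. A small hypothetical helper that extracts every well-formed tool call from a response, skipping blocks that are not valid JSON objects, makes it easy to check whether the model emitted both:

```python
import json
import re

def extract_tool_calls(response: str) -> list[dict]:
    """Hypothetical helper: return every well-formed <tool_call>{...}</tool_call> in order."""
    calls = []
    for body in re.findall(r"<tool_call>(.*?)</tool_call>", response, re.DOTALL):
        try:
            obj = json.loads(body)
        except json.JSONDecodeError:
            continue  # skip malformed blocks
        if isinstance(obj, dict) and "name" in obj:
            calls.append(obj)
    return calls

# Hand-written example response (not real model output)
sample_response = (
    '<tool_call>{"name": "get_weather", "arguments": {"city": "London"}}</tool_call>\n'
    '<tool_call>{"name": "send_email", "arguments": {"recipient": "supervisor@company.com", '
    '"subject": "Task Complete"}}</tool_call>'
)
print([call["name"] for call in extract_tool_calls(sample_response)])
# → ['get_weather', 'send_email']
```

In the notebook you would run it on the decoded model output, e.g. extract_tool_calls(response_text).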