
# Function Calling with RFT

This notebook shows how to use reinforcement fine-tuning (RFT) to train an adapter for a function-calling use case.

Function calling (or tool calling) is a type of prompting/inference in which the model is given a set of tools it can use to generate a response. For example, you might give the model a function that multiplies two numbers. When a user asks for the product of two numbers, the model responds with a call to the multiply function, passing the two numbers as its arguments.

Here, we train a model to select the correct tool, and the correct arguments to pass to it, given a request and a set of available tools.
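
The following sketch makes the multiply example concrete. It assumes the model wraps its call in `<tool>...</tool>` tags containing a Python-literal dict with `name` and `arguments` keys, which is the format the reward functions later in this notebook parse. The `multiply` tool, the `TOOLS` registry, and the sample completion are hypothetical. In practice the model writes the tool call, and your application executes it.

```python
import ast
import re


def multiply(a: float, b: float) -> float:
    """Hypothetical tool: return the product of two numbers."""
    return a * b


# Hypothetical registry mapping tool names to Python callables.
TOOLS = {"multiply": multiply}


def run_tool_call(completion: str):
    """Extract the <tool>...</tool> payload from a completion and execute it."""
    match = re.search(r"<tool>(.*?)</tool>", completion, re.DOTALL)
    if match is None:
        raise ValueError("no <tool> block found")
    # ast.literal_eval only parses Python literals, so model output is never run as code.
    call = ast.literal_eval(match.group(1).strip())
    fn = TOOLS[call["name"]]
    return fn(**call["arguments"])


# A completion the model might produce for "What is 6 times 7?" (illustrative only).
completion = (
    "<think>The user wants a product.</think>"
    "<tool>{'name': 'multiply', 'arguments': {'a': 6, 'b': 7}}</tool>"
)
print(run_tool_call(completion))  # 42
```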

In [None]:
!pip uninstall numpy -y --quiet
# Install a known compatible version of numpy (pip may warn that it conflicts with other packages preinstalled on Colab)
!pip install numpy==1.24.0 --quiet
# Uninstall and reinstall predibase and datasets to ensure compatibility with the installed numpy version
!pip uninstall predibase -y --quiet
!pip uninstall datasets -y --quiet
!pip install predibase --quiet
!pip install --upgrade datasets huggingface_hub --quiet

# Prepare Your Data

For RFT, you need, at minimum, a text dataset with a `prompt` column. Your dataset can also include other columns. Their values are available to your reward functions, for example so that a reward function can compare a completion against a reference answer.
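
As a minimal sketch of that layout, the following code writes a toy dataset with the required `prompt` column and one extra column to a JSON Lines file, then reads it back. The rows, the `add` tool, and the file name are hypothetical. Each line that is read back becomes a dict, which is roughly the shape a reward function receives as its `example` argument.

```python
import json
import os
import tempfile

# Hypothetical rows: `prompt` is the required column; `tool_call` is an extra
# column holding the reference answer that a reward function can compare against.
rows = [
    {"prompt": "What is 6 times 7?",
     "tool_call": "{'name': 'multiply', 'arguments': {'a': 6, 'b': 7}}"},
    {"prompt": "What is 2 plus 3?",
     "tool_call": "{'name': 'add', 'arguments': {'a': 2, 'b': 3}}"},
]

# JSON Lines: one JSON object per line, the same layout that
# DataFrame.to_json(..., lines=True, orient='records') produces.
path = os.path.join(tempfile.gettempdir(), "toy_function_calling.jsonl")
with open(path, "w", encoding="utf-8") as f:
    for row in rows:
        f.write(json.dumps(row) + "\n")

# Reading the file back yields one dict per row.
with open(path, encoding="utf-8") as f:
    examples = [json.loads(line) for line in f]

print(examples[0]["tool_call"])
```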

For this function-calling use case, you can load the `predibase/glaive_function_calling` dataset from the Hugging Face Hub.

[How to prepare your data for GRPO](https://docs.predibase.com/user-guide/fine-tuning/grpo#prepare-your-dataset)

In [None]:
import pandas as pd
from datasets import load_dataset

# Load the dataset
dataset = load_dataset("predibase/glaive_function_calling")
train_df = pd.DataFrame(dataset["train"])
train_df.head()

In [None]:
print("SAMPLE PROMPT")
print(train_df['prompt'][2])
print("\n\nSAMPLE TOOL CALL")
print(train_df['tool_call'][2])

In [None]:
train_df.to_json('./glaive_function_calling.jsonl', lines=True, orient='records')

# Set up Your Reward Functions

Reward functions will be used to score your model's generations during training.

All reward functions must follow this function signature:

```
def reward_fn(prompt: str, completion: str, example: dict[str, str]) -> float
```

`prompt` is a prompt from your dataset. `completion` is one of the N outputs the model generated for that prompt. `example` is the original dataset row, including any extra columns, represented as a dictionary. If you define more than one reward function, give each one a unique function name.

[Reward Functions Documentation](https://docs.predibase.com/user-guide/fine-tuning/grpo#defining-reward-functions)
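
To make the signature concrete, the sketch below uses a toy reward function to score a group of N completions for one prompt. It then normalizes the scores within the group. The function `has_tool_tags_reward` is hypothetical and is not one of the rewards used in this notebook. The normalization follows the standard GRPO formulation of a group-relative advantage, (r - mean) / std. Predibase's internal implementation may differ in details, such as the epsilon value, the choice of standard deviation, or how multiple reward functions are combined.

```python
import statistics


def has_tool_tags_reward(prompt: str, completion: str, example: dict) -> float:
    """Toy reward: 1.0 if the completion contains a <tool>...</tool> block, else 0.0."""
    return 1.0 if "<tool>" in completion and "</tool>" in completion else 0.0


def group_advantages(rewards: list[float], eps: float = 1e-4) -> list[float]:
    """Normalize rewards within one prompt's group of generations (GRPO-style sketch)."""
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)  # population std; implementations vary
    return [(r - mean) / (std + eps) for r in rewards]


prompt = "What is 6 times 7?"
example = {"prompt": prompt,
           "tool_call": "{'name': 'multiply', 'arguments': {'a': 6, 'b': 7}}"}

# Four hypothetical generations for the same prompt (the job below uses num_generations=16).
completions = [
    "<think>...</think><tool>{'name': 'multiply', 'arguments': {'a': 6, 'b': 7}}</tool>",
    "<think>...</think>The answer is 42.",
    "<think>...</think><tool>{'name': 'multiply', 'arguments': {'a': 7, 'b': 6}}</tool>",
    "I don't know.",
]

rewards = [has_tool_tags_reward(prompt, c, example) for c in completions]
print(rewards)                    # [1.0, 0.0, 1.0, 0.0]
print(group_advantages(rewards))  # positive for tagged completions, negative otherwise
```

If every completion in a group receives the same reward, all advantages are zero and that prompt contributes no learning signal. This is one plausible reason the reward functions below award partial credit instead of a single pass/fail score.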

In [None]:
def global_tool_use_reward_func(prompt: str, completion: str, example: dict):
    """Check that the correct tools are being called, max score is 1.0."""
    import re
    import ast
    import json

    rscale = 1.0
    reward = 0.0

    true_tool_call = ast.literal_eval(example['tool_call'])

    try:
        pred_tool_call = re.search(r'<tool>(.*?)</tool>', completion, re.DOTALL).group(1).strip()
        pred_tool_call = ast.literal_eval(pred_tool_call)
        if pred_tool_call != '':
            assert isinstance(pred_tool_call, dict)
            assert 'name' in pred_tool_call
            assert 'arguments' in pred_tool_call
    except Exception as e:
        print(f'Error parsing tool call: {e} for {completion}')
        return reward * rscale

    if pred_tool_call == '':
        pred_tool_call = None

    # If both don't use a tool then we give full credit
    if true_tool_call is None and pred_tool_call is None:
        reward += 1.0
        print(f'(CORRECT - NO TOOL). global tool use reward: {reward}')
        return reward * rscale

    # If both use a tool
    if true_tool_call is not None and pred_tool_call is not None:
        # For using a tool when true_tool_call is not None
        reward += 0.1

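        # The ground-truth arguments may be stored as a JSON string (apparently a side effect of
        # uploading and reloading the dataset), whereas ast.literal_eval already yields a dict for
        # the model's arguments. Parse the string so both sides compare as dicts.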
        if isinstance(true_tool_call['arguments'], str):
            true_tool_call['arguments'] = json.loads(true_tool_call['arguments'])

        # Name match
        if true_tool_call['name'] != pred_tool_call['name']:
            print(f'(PARTIAL - NAME) {true_tool_call=}, {pred_tool_call=} did not match. global tool use reward: {reward}')
            return reward * rscale
        reward += 0.4

        # Arguments match
        if true_tool_call['arguments'] != pred_tool_call['arguments']:
            print(f'(PARTIAL - ARGUMENTS) {true_tool_call=}, {pred_tool_call=} did not match. global tool use reward: {reward}')
            return reward * rscale

        reward += 0.5
        print(f'(CORRECT) {true_tool_call=}, {pred_tool_call=} global tool use reward: {reward}')
        return reward * rscale
    else:
        print(f'(INCORRECT - TYPE) {true_tool_call=}, {pred_tool_call=} did not match. global tool use reward: {reward}')
        return reward * rscale

In [None]:
def global_format_reward_func(prompt: str, completion: str, example: dict):
    """Check that the generated text is in the requested format, max score is 0.5 (1.0 before rscale)."""
    import re
    import ast

    reward = 0.0
    rscale = 0.5

    # Score the completion as if it began with an opening <think> tag; if the model
    # also writes its own <think>, the stray-tag check below catches the duplicate.
    completion = f'<think>{completion}'

    # Find <think> and </think> tags
    think_start = completion.find('<think>')
    think_end = completion.find('</think>')

    # Find <tool> and </tool> tags
    tool_start = completion.find('<tool>')
    tool_end = completion.find('</tool>')

    # Missing tool tags are caught by the ordering check below
    if think_start == -1 or think_end == -1:
        print(f'(PARTIAL - FORMAT) missing think tags. format reward: {reward}')
        return reward * rscale

    reward += 0.1

    if not (think_start < think_end < tool_start < tool_end):
        print(f'(PARTIAL - FORMAT) tags present but not in the correct order. format reward: {reward}')
        return reward * rscale

    reward += 0.1

    # Check if there are any stray tags
    think_tags = re.findall(r'</?think>', completion)
    tool_tags = re.findall(r'</?tool>', completion)

    if len(think_tags) != 2 or len(tool_tags) != 2:
        print(f'(PARTIAL - FORMAT) found stray think or tool tags. format reward: {reward}')
        return reward * rscale

    reward += 0.2

    # Check if tool call syntax is valid.
    try:
        pred_tool_call = re.search(r'<tool>(.*?)</tool>', completion, re.DOTALL).group(1).strip()
        pred_tool_call = ast.literal_eval(pred_tool_call)
        if pred_tool_call != '':
            assert isinstance(pred_tool_call, dict)
            assert 'name' in pred_tool_call
            assert 'arguments' in pred_tool_call
    except Exception as e:
        print(f'(PARTIAL - FORMAT) could not parse result. Exception: {e} for {pred_tool_call=}. format reward: {reward}')
        return reward * rscale

    reward += 0.6

    return reward * rscale

In [None]:
def global_length_reward_func(prompt: str, completion: str, example: dict):
    """Set a hard limit on the completion length, max score is 1.0"""
    norm_length_char = 2000 # 500 tokens * 4 characters per token
    num_chars = len(completion)
    return 1.0 if num_chars <= norm_length_char else 0.0

# Create the Predibase Client and Upload Your Data

In [None]:
from predibase import Predibase
from google.colab import userdata

pb = Predibase(api_token=userdata.get("PREDIBASE_API_TOKEN"))

In [None]:
# Upload the dataset, or fetch it if a dataset with this name already exists
try:
    dataset = pb.datasets.from_file("./glaive_function_calling.jsonl", name="glaive_function_calling")
except Exception:
    dataset = pb.datasets.get("glaive_function_calling")

# Train with RFT!

Set up an adapter repo and start training your model! To train with RFT, pass a `GRPOConfig` containing the name of the base model you want to fine-tune and your reward functions. You can also optionally adjust the parameters that control how the model is sampled during training, such as temperature and the maximum number of tokens to generate.

[Supported Base Models List](https://docs.predibase.com/user-guide/fine-tuning/finetuning-models)

[`GRPOConfig` Documentation](https://docs.predibase.com/sdk-guide/SDKv2/ConfigClasses/FineTuningConfig#grpoconfig)

[`RewardFunctionsConfig` Documentation](https://docs.predibase.com/sdk-guide/SDKv2/ConfigClasses/RewardFunctionsConfig)

[`SamplingParamsConfig` Documentation](https://docs.predibase.com/sdk-guide/SDKv2/ConfigClasses/SamplingParamsConfig)


In [None]:
repo = pb.repos.create(name="function-calling-rft", description="Train a function calling model with RFT!", exists_ok=True)

In [None]:
from predibase import GRPOConfig, RewardFunctionsConfig, SamplingParamsConfig

adapter = pb.finetuning.jobs.create(
    config=GRPOConfig(
        base_model="qwen2-5-7b-instruct",
        num_generations=16,
        sampling_params=SamplingParamsConfig(
            max_tokens=1024
        ),
        reward_fns=RewardFunctionsConfig(
            functions={
                "correctness": global_tool_use_reward_func,
                "format": global_format_reward_func,
                "length": global_length_reward_func
            },
        ),
    ),
    dataset=dataset,
    repo=repo,
    description="Function calling model",
)

# Update the Reward Functions

During training, you may notice that the model is not performing well on a particular reward. Maybe the reward function is too restrictive, or there is a bug in its logic. If this is the case, you can update your reward functions in a training job without interrupting training. You just need to submit an updated reward function to the running job with a name matching an existing reward function.

[How to Update a Reward Function](https://docs.predibase.com/user-guide/fine-tuning/grpo#how-do-i-update-my-reward-functions)


In [None]:
def global_length_reward_func_v2(prompt: str, completion: str, example: dict):
    """Reward shorter completions with a smooth decay, max score is 0.5 (1.0 before rscale)."""
    import math
    rscale = 0.5
    norm_length_char = 2000 # 500 tokens * 4 characters per token
    num_chars = len(completion)

    # 1 - tanh(x) is 1.0 for an empty completion and decays toward 0.0 as length grows
    reward = 1 - math.tanh(num_chars / norm_length_char)

    return reward * rscale
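
To compare the two length rewards, the sketch below re-implements both formulas and evaluates them at several completion lengths. It uses `math.tanh`, which computes the same hyperbolic tangent as the v2 cell. The standalone functions `length_reward_v1` and `length_reward_v2` are hypothetical stand-ins for the notebook's reward functions. The v1 reward is a step function at 2,000 characters. The v2 reward decays smoothly, so a 1,000-character completion earns noticeably less than a 100-character one. Because v2 is scaled by `rscale = 0.5`, its maximum is 0.5.

```python
import math


def length_reward_v1(num_chars: int, limit: int = 2000) -> float:
    """Hard cutoff, mirroring global_length_reward_func."""
    return 1.0 if num_chars <= limit else 0.0


def length_reward_v2(num_chars: int, norm: int = 2000, rscale: float = 0.5) -> float:
    """Smooth decay, mirroring global_length_reward_func_v2."""
    return (1 - math.tanh(num_chars / norm)) * rscale


# Print both rewards side by side for a few lengths.
for n in [0, 500, 1000, 2000, 4000]:
    print(f"{n:>5} chars  v1={length_reward_v1(n):.3f}  v2={length_reward_v2(n):.3f}")
```

Algebraically, 0.5 * (1 - tanh(x)) equals 1 / (e^(2x) + 1). At 1,000 characters (x = 0.5), v2 is therefore 1 / (e + 1), or about 0.269.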

In [None]:
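# Adapter version created by the training job above; adjust the version number if this repo already had earlier versions
adapter_path = "function-calling-rft/1"

# Fetch the job's current config, swap in the new function under the existing "length" name, and push the update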
cfg = pb.adapters.get_config(adapter_path)
cfg.reward_fns["length"] = global_length_reward_func_v2
pb.adapters.update_config(adapter_path, cfg)