<a href="https://colab.research.google.com/github/abhinavsb3/Function-Callingwith-RFT-With-Predibase/blob/main/Function_Calling_with_RFT(with_predibase).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install the Predibase SDK and the Hugging Face `datasets` library.
# NOTE(review): numpy is uninstalled first, presumably so the installs below pull in
# a numpy version compatible with them — confirm. The Colab runtime may need a
# restart before the reinstalled numpy is picked up. Consider pinning versions
# (e.g. `%pip install -q predibase==X.Y.Z`) so re-runs are reproducible.
!pip uninstall numpy -y --quiet
!pip install predibase --quiet
!pip install datasets --quiet

In [None]:
import pandas as pd
from datasets import load_dataset
from predibase import Predibase
from google.colab import userdata
from predibase import GRPOConfig, RewardFunctionsConfig, SamplingParamsConfig

# Load the Glaive function-calling dataset from the Hugging Face Hub and view the
# train split as a DataFrame.
# Bound to `hf_dataset` (not `dataset`) because `dataset` is later rebound to the
# uploaded Predibase dataset. Sharing the name means re-running this cell after the
# upload cell would silently hand the Hugging Face object to the fine-tuning job.
hf_dataset = load_dataset("predibase/glaive_function_calling")
train_df = pd.DataFrame(hf_dataset["train"])
train_df.head()

In [None]:
# Show one training row: the prompt text and its reference tool call.
EXAMPLE_ROW = 2

print("Prompt Examples")
print(train_df.loc[EXAMPLE_ROW, "prompt"])
print("\n\nSAMPLE TOOL CALLING")
print(train_df.loc[EXAMPLE_ROW, "tool_call"])

In [None]:
# Write the train split as JSON Lines (one record per line) for upload to Predibase.
train_df.to_json(
    './glaive_function_calling.jsonl',
    orient='records',
    lines=True,
)

In [None]:
# Tool-use reward function
def global_tool_use_reward_func(prompt: str, completion: str, example: dict):
    """Check that the correct tool is called with the correct arguments; max score is 1.0.

    The completion is expected to contain a ``<tool>...</tool>`` span. The span body is
    either empty / ``''`` (meaning "no tool call") or a dict literal with ``name`` and
    ``arguments`` keys.

    Scoring:
      * neither the reference nor the prediction calls a tool -> 1.0
      * both call a tool -> 0.1, +0.4 if the names match, +0.5 if the arguments also match
      * exactly one side calls a tool, or the tool span is missing/unparseable -> 0.0

    Args:
        prompt: The prompt text (unused; part of the reward-function signature).
        completion: The model's generated text.
        example: The dataset row. ``example['tool_call']`` holds the reference call,
            normally as a string literal (e.g. ``"None"`` or a dict literal).

    Returns:
        The reward as a float in [0.0, 1.0].
    """
    # Imports stay function-local so the function is self-contained (presumably each
    # reward function is shipped to Predibase on its own — keep it that way).
    import re
    import ast
    import json

    rscale = 1.0
    reward = 0.0

    # Parse the reference call. Tolerate an already-parsed value. Treat an empty
    # string (raw or quoted) as "no tool call", like the prediction side below.
    true_tool_call = example['tool_call']
    if isinstance(true_tool_call, str):
        true_tool_call = true_tool_call.strip()
        true_tool_call = ast.literal_eval(true_tool_call) if true_tool_call else None
    if true_tool_call == '':
        true_tool_call = None

    # Parse the predicted call from the <tool>...</tool> span.
    try:
        match = re.search(r'<tool>(.*?)</tool>', completion, re.DOTALL)
        if match is None:
            raise ValueError('no <tool>...</tool> span found')
        raw_tool_call = match.group(1).strip()
        # An empty span means "no tool call". literal_eval('') raises SyntaxError,
        # so map it to '' directly instead of scoring it as a parse failure.
        pred_tool_call = ast.literal_eval(raw_tool_call) if raw_tool_call else ''
        if pred_tool_call != '':
            # Explicit checks rather than `assert`, which is stripped under `python -O`.
            if not isinstance(pred_tool_call, dict):
                raise ValueError(f'tool call must be a dict, got {type(pred_tool_call).__name__}')
            if 'name' not in pred_tool_call or 'arguments' not in pred_tool_call:
                raise ValueError("tool call must have 'name' and 'arguments' keys")
    except Exception as e:
        print(f'Error parsing tool call: {e} for {completion}')
        return reward * rscale

    if pred_tool_call == '':
        pred_tool_call = None

    # If both don't use a tool then we give full credit.
    if true_tool_call is None and pred_tool_call is None:
        reward += 1.0
        print(f'(CORRECT - NO TOOL). global tool use reward: {reward}')
        return reward * rscale

    # Exactly one side uses a tool: no credit.
    if true_tool_call is None or pred_tool_call is None:
        print(f'(INCORRECT - TYPE) {true_tool_call=}, {pred_tool_call=} did not match. global tool use reward: {reward}')
        return reward * rscale

    # Both use a tool: partial credit for attempting a call.
    reward += 0.1

    # The reference arguments are sometimes a JSON string (seemingly an artifact of
    # dataset upload/reload). Decode them into a local so `example` is never mutated.
    true_arguments = true_tool_call['arguments']
    if isinstance(true_arguments, str):
        true_arguments = json.loads(true_arguments)

    # Name match
    if true_tool_call['name'] != pred_tool_call['name']:
        print(f'(PARTIAL - NAME) {true_tool_call=}, {pred_tool_call=} did not match. global tool use reward: {reward}')
        return reward * rscale
    reward += 0.4

    # Arguments match
    if true_arguments != pred_tool_call['arguments']:
        print(f'(PARTIAL - ARGUMENTS) {true_tool_call=}, {pred_tool_call=} did not match. global tool use reward: {reward}')
        return reward * rscale

    reward += 0.5
    print(f'(CORRECT) {true_tool_call=}, {pred_tool_call=} global tool use reward: {reward}')
    return reward * rscale

In [None]:
# Format reward function
def global_format_reward_func(prompt: str, completion: str, example: dict):
    """Check that the completion follows the `<think>...</think><tool>...</tool>` format.

    A `<think>` opener is prepended before checking, so the completion itself is
    expected to contain only the closing `</think>`. Points accumulate stage by stage,
    and scoring stops at the first failing stage:

      +0.1  a `</think>` tag is present
      +0.1  tags are ordered <think> < </think> < <tool> < </tool>
      +0.2  exactly one pair of think tags and one pair of tool tags (no stray tags)
      +0.6  the tool body is empty / `''` (no tool call) or a dict literal with
            `name` and `arguments` keys

    The sum (at most 1.0) is multiplied by rscale = 0.5, so the maximum score is 0.5.

    Args:
        prompt: The prompt text (unused; part of the reward-function signature).
        completion: The model's generated text.
        example: The dataset row (unused here).

    Returns:
        The scaled format reward, between 0.0 and 0.5.
    """
    import re
    import ast

    reward = 0.0
    rscale = 0.5

    # Presumably the prompt template ends inside an opened <think> block, so the
    # generation only contains the closing tag — TODO confirm against the template.
    completion = f'<think>{completion}'

    # Find <think> and </think> tags
    think_start = completion.find('<think>')
    think_end = completion.find('</think>')

    # Find <tool> and </tool> tags
    tool_start = completion.find('<tool>')
    tool_end = completion.find('</tool>')

    if think_start == -1 or think_end == -1:
        print(f'(PARTIAL - FORMAT) missing think or tool tags. format reward: {reward}')
        return reward * rscale

    reward += 0.1

    # A missing tool tag makes find() return -1, which also fails this ordering check.
    if not (think_start < think_end < tool_start < tool_end):
        print(f'(PARTIAL - FORMAT) tags present but not in the correct order. format reward: {reward}')
        return reward * rscale

    reward += 0.1

    # Check if there are any stray tags
    think_tags = re.findall(r'</?think>', completion)
    tool_tags = re.findall(r'</?tool>', completion)

    if len(think_tags) != 2 or len(tool_tags) != 2:
        print(f'(PARTIAL - FORMAT) found stray think or tool tags. format reward: {reward}')
        return reward * rscale

    reward += 0.2

    # Check if tool call syntax is valid. Initialised up front so the error message
    # below can never reference an unbound name.
    pred_tool_call = None
    try:
        pred_tool_call = re.search(r'<tool>(.*?)</tool>', completion, re.DOTALL).group(1).strip()
        # An empty span means "no tool call". literal_eval('') would raise SyntaxError.
        if pred_tool_call:
            pred_tool_call = ast.literal_eval(pred_tool_call)
        if pred_tool_call != '':
            # Explicit checks rather than `assert`, which is stripped under `python -O`.
            if not isinstance(pred_tool_call, dict):
                raise ValueError(f'tool call must be a dict, got {type(pred_tool_call).__name__}')
            if 'name' not in pred_tool_call or 'arguments' not in pred_tool_call:
                raise ValueError("tool call must have 'name' and 'arguments' keys")
    except Exception as e:
        print(f'(PARTIAL - FORMAT) could not parse result. Exception: {e} for {pred_tool_call=}. format reward: {reward}')
        return reward * rscale

    reward += 0.6

    return reward * rscale

In [None]:
# Length reward function
def global_length_reward_func(prompt: str, completion: str, example: dict):
    """Hard length limit: 1.0 if the completion is at most 2000 characters, else 0.0.

    2000 characters approximates 500 tokens at ~4 characters per token.

    Args:
        prompt: The prompt text (unused; part of the reward-function signature).
        completion: The model's generated text.
        example: The dataset row (unused here).

    Returns:
        1.0 when ``len(completion) <= 2000``, otherwise 0.0.
    """
    norm_length_char = 2000  # 500 tokens * 4 characters per token
    return 1.0 if len(completion) <= norm_length_char else 0.0

In [None]:
# Create the Predibase client (API token read from Colab secrets) and upload the data.
pb = Predibase(api_token=userdata.get("PREDIBASE_API_TOKEN"))
try:
    dataset = pb.datasets.from_file("./glaive_function_calling.jsonl", name="glaive_function_calling")
except Exception as upload_error:
    # Presumably the upload fails because a dataset with this name exists from a
    # previous run, so fall back to fetching it. Catching Exception (not a bare
    # `except:`) lets KeyboardInterrupt/SystemExit through. Printing the error keeps
    # genuine failures (bad path, auth) visible instead of silently swallowing them.
    print(f"Upload failed ({upload_error!r}); fetching existing dataset 'glaive_function_calling' instead.")
    dataset = pb.datasets.get("glaive_function_calling")

In [None]:
# Create the adapter repository for the RFT adapter versions.
# exists_ok=True presumably returns the existing repo instead of failing if it exists.
repo = pb.repos.create(
    name="function-calling-rft",
    description="Train a function calling model with RFT!",
    exists_ok=True,
)

In [None]:
# Launch the GRPO (RFT) fine-tuning job, scoring sampled completions with the three
# reward functions defined earlier in the notebook.
# NOTE(review): max_tokens=1024 allows completions well past the ~2000-character
# (~500-token) limit enforced by the length reward — confirm this is intended.
reward_functions = {
    "correctness": global_tool_use_reward_func,
    "format": global_format_reward_func,
    "length": global_length_reward_func,
}

grpo_config = GRPOConfig(
    base_model="qwen2-5-7b-instruct",
    num_generations=16,  # presumably completions sampled per prompt (GRPO group size)
    sampling_params=SamplingParamsConfig(max_tokens=1024),
    reward_fns=RewardFunctionsConfig(functions=reward_functions),
)

adapter = pb.finetuning.jobs.create(
    config=grpo_config,
    dataset=dataset,
    repo=repo,
    description="Function calling model",
)

In [None]:
# Updated length reward: a smooth penalty instead of a hard cutoff
def global_length_reward_func_v2(prompt: str, completion: str, example: dict):
    """Reward shorter completions smoothly; max score is 0.5 (for an empty completion).

    reward = rscale * (1 - tanh(num_chars / norm_length_char)), with rscale = 0.5 and
    norm_length_char = 2000. The score therefore decays from 0.5 toward 0.0 as the
    completion grows (about 0.12 at 2000 characters).

    Args:
        prompt: The prompt text (unused; part of the reward-function signature).
        completion: The model's generated text.
        example: The dataset row (unused here).

    Returns:
        A float in [0.0, 0.5].
    """
    import math
    rscale = 0.5
    norm_length_char = 2000  # 500 tokens * 4 characters per token
    num_chars = len(completion)

    # math.tanh saturates cleanly to 1.0 for large inputs. A hand-rolled
    # (e^x - e^-x) / (e^x + e^-x) raises OverflowError once x exceeds ~709.
    reward = 1 - math.tanh(num_chars / norm_length_char)

    return reward * rscale

In [None]:
# Swap the hard-cutoff length reward for the smooth v2 version on an existing adapter.
# NOTE(review): the version is hard-coded to 1. The repo is created with
# exists_ok=True, so if it already held adapters, the job launched earlier may be a
# later version — confirm the version before updating.
REPO_NAME = "function-calling-rft"
ADAPTER_VERSION = 1
adapter_path = f"{REPO_NAME}/{ADAPTER_VERSION}"

# Fetch the current config, replace the "length" reward, and push the config back.
cfg = pb.adapters.get_config(adapter_path)
cfg.reward_fns["length"] = global_length_reward_func_v2
pb.adapters.update_config(adapter_path, cfg)