In [1]:
# Fetch the project repository and make it the working directory: later cells import
# `src.functional_gemma` and read `data/*.json` via paths relative to it.
# NOTE(review): re-running this cell in the same runtime makes `git clone` fail because the
# directory already exists, and the relative `%cd` then resolves against the new cwd.
!git clone https://github.com/ArgoJ/functional-assistent
%cd functional-assistent

Cloning into 'functional-assistent'...
remote: Enumerating objects: 34, done.
remote: Counting objects: 100% (34/34), done.
remote: Compressing objects: 100% (22/22), done.
remote: Total 34 (delta 10), reused 34 (delta 10), pack-reused 0 (from 0)
Receiving objects: 100% (34/34), 15.68 KiB | 2.24 MiB/s, done.
Resolving deltas: 100% (10/10), done.
/content/functional-assistent


In [2]:
!pip install trl

Collecting trl
  Downloading trl-0.26.2-py3-none-any.whl.metadata (11 kB)
Downloading trl-0.26.2-py3-none-any.whl (518 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m518.9/518.9 kB[0m [31m16.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: trl
Successfully installed trl-0.26.2


In [None]:
import glob
import inspect
import json
import os

import torch
from datasets import ClassLabel, Dataset
from huggingface_hub import login
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.utils import get_json_schema
from trl import SFTConfig, SFTTrainer

import src.functional_gemma.tools as tools
from src.functional_gemma.checker import check_success_rate

In [None]:
base_model = "google/functiongemma-270m-it"
learning_rate = 5e-5

# JSON schemas of the tools the model may call: every public function *defined in*
# src.functional_gemma.tools. A bare `callable()` filter would also pick up whatever that
# module imports (classes, typing helpers, functions from other modules) and private helpers;
# those are not tools and may make get_json_schema fail.
TOOLS = [
    get_json_schema(tool)
    for name, tool in vars(tools).items()
    if inspect.isfunction(tool)
    and tool.__module__ == tools.__name__
    and not name.startswith("_")
]
# System prompt (German): "You are a helpful assistant that can make function calls with the
# following functions. Always answer in German."
DEFAULT_SYSTEM_MSG = "Du bist ein hilfreicher Assistent, der Funktionsaufrufe mit den folgenden Funktionen durchführen kann. Antworte immer auf Deutsch."

In [None]:
def create_conversation(sample, tool_names=None):
    """Turn one raw sample into a chat-format training example.

    Args:
        sample: Row providing "tool_name", "user_content" and "tool_arguments".
            "tool_name" may be an integer class-label index or the name itself.
        tool_names: Optional index -> name lookup used to decode an integer
            "tool_name"; ignored when falsy or when "tool_name" is not an int.

    Returns:
        dict with a "messages" list (developer system prompt, user turn,
        assistant tool call) and the shared "tools" schema list.
    """
    name = sample["tool_name"]
    # After cast_column(..., ClassLabel) the tool name arrives as an integer index.
    if tool_names and isinstance(name, int):
        name = tool_names[name]

    tool_call = {
        "type": "function",
        "function": {"name": name, "arguments": sample["tool_arguments"]},
    }
    messages = [
        {"role": "developer", "content": DEFAULT_SYSTEM_MSG},
        {"role": "user", "content": sample["user_content"]},
        {"role": "assistant", "tool_calls": [tool_call]},
    ]
    return {"messages": messages, "tools": TOOLS}

# Seed for the train/test split so the held-out set used for the before/after checks is reproducible.
SPLIT_SEED = 42

# Load every JSON file under data/; each is expected to contain a list of samples.
# Files are sorted because glob order is filesystem-dependent, which would otherwise change
# record order (and thus the split) between runs.
loaded_json = []
for file_path in sorted(glob.glob(os.path.abspath("data/*.json"))):
    # Explicit UTF-8: the data presumably contains German text (see DEFAULT_SYSTEM_MSG), and the
    # platform default encoding is not UTF-8 everywhere.
    with open(file_path, "r", encoding="utf-8") as f:
        records = json.load(f)
    if not isinstance(records, list):
        raise ValueError(f"{file_path}: expected a JSON list of samples, got {type(records).__name__}")
    loaded_json.extend(records)

if not loaded_json:
    raise FileNotFoundError("No samples found in data/*.json - is the working directory the cloned repo?")

dataset = Dataset.from_list(loaded_json)

# Encode tool_name as a ClassLabel so train_test_split can stratify on it.
tool_names = sorted({item["tool_name"] for item in loaded_json})
dataset = dataset.cast_column("tool_name", ClassLabel(names=tool_names))

original_columns = dataset.column_names
# create_conversation decodes the ClassLabel index back to the tool name via tool_names.
dataset = dataset.map(create_conversation, fn_kwargs={"tool_names": tool_names}, batched=False)

dataset = dataset.train_test_split(
    test_size=0.2, shuffle=True, stratify_by_column="tool_name", seed=SPLIT_SEED
)

# Keep only the columns added by create_conversation ("messages", "tools").
dataset["train"] = dataset["train"].remove_columns(original_columns)
dataset["test"] = dataset["test"].remove_columns(original_columns)

In [None]:
# Authenticate with the Hugging Face Hub so base_model can be downloaded (presumably a gated
# Gemma checkpoint). Use Colab's secret store when available, else the HF_TOKEN env variable,
# so the notebook also runs outside Colab.
try:
    from google.colab import userdata
    hf_token = userdata.get('HF_TOKEN')
except ImportError:
    hf_token = os.environ.get("HF_TOKEN")
login(token=hf_token)

model = AutoModelForCausalLM.from_pretrained(
    base_model,
    dtype="auto",                  # keep the dtype stored with the checkpoint
    device_map="auto",
    # NOTE(review): presumably eager attention because it is recommended for training
    # Gemma-family models; confirm.
    attn_implementation="eager"
)
tokenizer = AutoTokenizer.from_pretrained(base_model)

print(f"Device: {model.device}")
print(f"DType: {model.dtype}")
print(f"DType: {model.dtype}")

In [None]:
# Baseline: run check_success_rate on the held-out split with the un-tuned model, so the
# result can be compared with the identical call after training.
# NOTE(review): presumably reports how often the expected tool call is produced; see checker.py.
print("\n--- Initial check before training ---")
check_success_rate(dataset["test"], model, tokenizer, TOOLS)

In [None]:
# Pick the mixed-precision mode that matches the dtype the weights were loaded in (dtype="auto").
# NOTE(review): if the checkpoint were float16, fp16=True AMP on fp16 weights typically fails with
# "Attempting to unscale FP16 gradients"; this is only safe for bf16/fp32 weights.
torch_dtype = model.dtype
args = SFTConfig(
    output_dir="./functiongemma-tool-calling-sft",  # local output dir (also the hub repo id if pushing)
    max_length=512,                   # max tokenized sequence length; longer samples are truncated
    packing=False,                    # one sample per sequence (no packing of multiple samples)
    num_train_epochs=8,
    per_device_train_batch_size=1,
    gradient_checkpointing=False,     # off; enabling it requires disabling the model's KV cache
    optim="adamw_torch_fused",        # fused AdamW
    logging_steps=1,                  # log every step
    eval_strategy="epoch",            # evaluate on the eval dataset after every epoch
    # save_strategy is left at its default; set save_strategy="epoch" to checkpoint every epoch.
    learning_rate=learning_rate,
    fp16=(torch_dtype == torch.float16),
    bf16=(torch_dtype == torch.bfloat16),
    lr_scheduler_type="constant",     # fixed learning rate for the whole run
    push_to_hub=False,
    # Set report_to="tensorboard" to log metrics to TensorBoard.
)

In [None]:
# Supervised fine-tuning on the "messages"/"tools" rows built earlier; the test split doubles
# as the per-epoch eval set. `model` is trained in place, so later cells that use `model`
# see the fine-tuned weights.
trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=dataset['train'],
    eval_dataset=dataset['test'],
    processing_class=tokenizer,
)

train_result = trainer.train()
# Persist the final weights (and the tokenizer passed as processing_class) to args.output_dir;
# otherwise only periodic checkpoints from the save strategy, if any were reached, survive.
trainer.save_model()
train_result  # display the training summary, as the cell did before

In [None]:
# Re-run the baseline check on the same held-out split, now with the fine-tuned model,
# for a direct before/after comparison.
print("\n--- Check after training ---")
check_success_rate(dataset["test"], model, tokenizer, TOOLS)