<a href="https://colab.research.google.com/github/Ifeanyi55/Gemma-Agents/blob/main/Synthetic_Data_Agent.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# install required libraries
!pip install -q transformers torch
!pip install sdv

In [None]:
# Log in to the Hugging Face Hub with your HF token.
# The token is read from Colab's userdata store under the key "HF_TOKEN".
from huggingface_hub import login
from google.colab import userdata
hf_token = userdata.get("HF_TOKEN")
login(token=hf_token)

In [None]:
# Download the FunctionGemma processor and model from the Hugging Face Hub.
from transformers import AutoProcessor, AutoModelForCausalLM

# Hub repository id of the checkpoint to load.
GEMMA_MODEL_ID = "google/functiongemma-270m-it"

# NOTE(review): device_map="auto" is also passed to the processor; it is
# presumably only meaningful for the model weights — confirm.
processor = AutoProcessor.from_pretrained(GEMMA_MODEL_ID, device_map="auto")
# "auto" leaves the choice of dtype and device placement to transformers.
model = AutoModelForCausalLM.from_pretrained(GEMMA_MODEL_ID, dtype="auto", device_map="auto")

## You can generate synthetic data of the following kinds: companies, persons, addresses, places, users, credit cards, books, and text. These are the categories supported by the [Faker API](https://fakerapi.it/), which the custom function below calls.

In [None]:
# build agent for generating synthetic data
from sdv.single_table import GaussianCopulaSynthesizer
from sdv.metadata import SingleTableMetadata
import pandas as pd
import requests
import json
import re

# define function schema
# Each argument of the tool is described separately, then assembled below.
_category_param = {
    "type": "string",
    "description": "The category of data to generate, e.g. persons",
}
_quantity_param = {
    "type": "integer",
    "description": "The quantity of data points to generate, e.g. 1000",
}

# Both arguments are mandatory for a call to be valid.
_generator_parameters = {
    "type": "object",
    "properties": {
        "category": _category_param,
        "quantity": _quantity_param,
    },
    "required": ["category", "quantity"],
}

# Complete declaration of the `synthetic_data_generator` tool.
synthetic_data_function_schema = {
    "type": "function",
    "function": {
        "name": "synthetic_data_generator",
        "description": "Generates synthetic data points.",
        "parameters": _generator_parameters,
    },
}

# build the function
def synthetic_data_generator(category: str, quantity: int) -> pd.DataFrame:
    """
    Generate synthetic tabular data modelled on a sample from the Faker API.

    A seed sample of 200 rows for the requested category is downloaded from
    https://fakerapi.it, a GaussianCopulaSynthesizer is fitted to it, and
    `quantity` new rows are sampled from the fitted model.

    Args:
        category: Faker API resource name inserted into the request URL,
            e.g. "persons".
        quantity: The number of synthetic data points to generate.
    Returns:
        A dataframe of synthetic data points.
    Raises:
        requests.HTTPError: If the Faker API responds with an error status.
        requests.Timeout: If the Faker API does not answer in time.
    """
    # call Faker API; the seed sample size is fixed, `quantity` only controls
    # how many rows are sampled from the fitted model afterwards
    url = f"https://fakerapi.it/api/v2/{category}?_quantity=200"

    resp = requests.get(url, timeout=30)
    # fail with a clear HTTP error instead of a confusing KeyError below
    resp.raise_for_status()

    df = pd.DataFrame(resp.json()["data"])

    # drop the 'address' / 'addresses' columns when present
    # (presumably because they hold nested records the synthesizer cannot
    # model — confirm); a missing column is not an error
    df = df.drop(columns=["address", "addresses"], errors="ignore")

    # The synthesis must run for every category. Previously it only ran when
    # neither column existed, so the function returned None otherwise.
    metadata = SingleTableMetadata()

    # extract metadata from the table
    metadata.detect_from_dataframe(df)

    synthesizer = GaussianCopulaSynthesizer(metadata)

    # train ML model for generating synthetic data
    synthesizer.fit(data=df)

    # generate synthetic data
    return synthesizer.sample(num_rows=quantity)

# registry of callable tools, keyed by each function's own name
tools = {fn.__name__: fn for fn in (synthetic_data_generator,)}

# model's turn
# Conversation handed to the chat template: a developer instruction followed
# by the user's request.
message = [
    {
        "role": "developer",
        # It is important to include this developer prompt to enable the model to call tools.
        "content": "You are a model that can do function calling with the following functions"
    },
    {
        "role": "user",
        "content": "Can you generate 8000 synthetic data points of different addresses?"
    }
]

# Render the conversation together with the tool schema into model inputs,
# returned as a dict of PyTorch tensors.
inputs = processor.apply_chat_template(message, tools = [synthetic_data_function_schema],add_generation_prompt=True, return_dict=True, return_tensors="pt")

# Generate at most 128 new tokens, using the EOS token id as the padding id.
out = model.generate(**inputs.to(model.device), pad_token_id=processor.eos_token_id, max_new_tokens=128)
# Decode only the generated continuation: the slice skips the prompt tokens,
# whose count is the length of the first input_ids row.
output = processor.decode(out[0][len(inputs["input_ids"][0]):], skip_special_tokens=True)

# developer's turn
def extract_tool_calls(text):
    """
    Parse tool calls out of the model's generated text.

    A call is expected in the form
    ``<start_function_call>call:NAME{key:value,...}<end_function_call>``,
    where a value is either wrapped in ``<escape>...<escape>`` or runs up to
    the next ``,`` or ``}``.

    Args:
        text: The decoded model output to scan.
    Returns:
        A list with one ``{"name": str, "arguments": dict}`` entry per call
        found, in order of appearance; an empty list when there is none.
    """
    def cast(v):
        # Convert a raw argument string to int, float, bool or str, trying the
        # narrowest type first. Only ValueError is caught so that interrupts
        # such as KeyboardInterrupt are no longer swallowed.
        try:
            return int(v)
        except ValueError:
            pass
        try:
            return float(v)
        except ValueError:
            pass
        # "true"/"false" (any casing) become booleans; anything else is kept
        # as a string with surrounding quotes removed.
        return {'true': True, 'false': False}.get(v.lower(), v.strip("'\""))

    def parse_arguments(args):
        # Each match yields (key, escaped_value, bare_value); exactly one of
        # the two value groups is non-empty, or both are empty.
        pairs = re.findall(r"(\w+):(?:<escape>(.*?)<escape>|([^,}]*))", args)
        return {k: cast((v1 or v2).strip()) for k, v1, v2 in pairs}

    # DOTALL lets the argument body of a call span several lines.
    found = re.findall(
        r"<start_function_call>call:(\w+)\{(.*?)\}<end_function_call>",
        text,
        re.DOTALL,
    )
    return [
        {"name": name, "arguments": parse_arguments(args)}
        for name, args in found
    ]

# extract the tool call output
calls = extract_tool_calls(output)

# Start with an empty result list so the report at the end also works when
# the model produced no tool call (previously this raised a NameError).
results = []

if calls:
    # record the model's tool calls in the conversation history
    message.append({
        "role": "assistant",
        "tool_calls": [{"type": "function", "function": call} for call in calls]
    })
    print(message[-1])

    # call each requested function and collect the results
    for c in calls:
        tool = tools.get(c["name"])
        if tool is None:
            # the model may ask for a function that is not in the registry
            print(f"Skipping unknown tool requested by the model: {c['name']}")
            continue
        results.append({"name": c["name"], "response": tool(**c["arguments"])})

if results:
    print(results[0]["response"])
else:
    print("No tool call was executed. Raw model output:", output)
