<a href="https://colab.research.google.com/github/Ifeanyi55/Gemma-Agents/blob/main/Synthetic_Data_Agent.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# install required libraries
# transformers/torch run FunctionGemma; sdv provides the GaussianCopulaSynthesizer
# and SingleTableMetadata used by the agent cell.
# NOTE(review): gradio and requests are imported by later cells but not installed
# here; presumably the Colab runtime already provides them — verify.
!pip install -q transformers torch
!pip install sdv

In [None]:
# Authenticate with the Hugging Face Hub using the token stored in the Colab
# secret named "HF_TOKEN" (needed to download the model in the next cell).
from google.colab import userdata
from huggingface_hub import login

hf_token = userdata.get("HF_TOKEN")
login(token=hf_token)

In [None]:
# Download the FunctionGemma checkpoint named by GEMMA_MODEL_ID, together with
# its processor (used for chat templating and decoding).
from transformers import AutoModelForCausalLM, AutoProcessor

GEMMA_MODEL_ID = "google/functiongemma-270m-it"

processor = AutoProcessor.from_pretrained(
    GEMMA_MODEL_ID,
    device_map="auto",
)
model = AutoModelForCausalLM.from_pretrained(
    GEMMA_MODEL_ID,
    dtype="auto",
    device_map="auto",
)

## You can generate synthetic data of the following kinds: companies, persons, addresses, places, users, credit cards (spell as creditCards in your prompt), books, and texts. These are the categories supported by the [Faker API](https://fakerapi.it/) that the custom function calls.

In [None]:
# build agent for generating synthetic data
from sdv.single_table import GaussianCopulaSynthesizer
from sdv.metadata import SingleTableMetadata
import pandas as pd
import requests
import json
import re

# Tool description handed to the chat template (tools=[...]) so the model knows
# the name and arguments of the function it may call. Key order is kept as-is
# because the template serializes this structure into the prompt.
synthetic_data_function_schema = dict(
    type="function",
    function=dict(
        name="synthetic_data_generator",
        description="Generates synthetic data points.",
        parameters=dict(
            type="object",
            properties=dict(
                category=dict(
                    type="string",
                    description="The category of data to generate, e.g. persons",
                ),
                quantity=dict(
                    type="integer",
                    description="The quantity of data points to generate, e.g. 1000",
                ),
            ),
            required=["category", "quantity"],
        ),
    ),
)

# build the function
def synthetic_data_generator(category: str, quantity: int) -> pd.DataFrame:
    """Generate synthetic rows for a Faker API category.

    Fetches 200 seed rows from the Faker API (https://fakerapi.it/), drops the
    ``address`` / ``addresses`` / ``contact`` columns when present, fits a
    Gaussian copula synthesizer on what remains, and samples ``quantity`` rows.

    Args:
        category: Faker API v2 resource name, e.g. "persons", "addresses",
            "creditCards".
        quantity: Number of synthetic rows to sample.

    Returns:
        A dataframe of ``quantity`` synthetic data points.

    Raises:
        requests.HTTPError: If the Faker API responds with an HTTP error status.
        requests.Timeout: If the Faker API does not respond within 30 seconds.
    """
    # call Faker API for a fixed-size seed sample to train on
    url = f"https://fakerapi.it/api/v2/{category}?_quantity=200"
    resp = requests.get(url, timeout=30)
    # Fail with a clear HTTP error rather than an opaque KeyError on "data".
    resp.raise_for_status()

    df = pd.DataFrame(resp.json()["data"])

    # Drop each of these columns independently when present; previously
    # "addresses" was only dropped if "contact" also existed.
    # NOTE(review): presumably these hold nested objects the synthesizer can't
    # model — confirm against the Faker API response shapes.
    nested_columns = [c for c in ("address", "addresses", "contact") if c in df.columns]
    df = df.drop(columns=nested_columns)

    # extract metadata (column types) from the seed table
    metadata = SingleTableMetadata()
    metadata.detect_from_dataframe(df)

    # train ML model for generating synthetic data
    synthesizer = GaussianCopulaSynthesizer(metadata)
    synthesizer.fit(data=df)

    # generate synthetic data
    return synthesizer.sample(num_rows=quantity)

# Registry mapping the tool names the model may emit to Python callables.
tools = {"synthetic_data_generator": synthetic_data_generator}

# Conversation for the model's turn. The developer message is the system prompt
# that enables tool calling; the user message is the actual request.
message = [
    {
        "role": "developer",
        "content": "You are a model that can do function calling with the following functions",
    },
    {
        "role": "user",
        "content": "Can you generate 8000 synthetic data points of different addresses?",
    },
]

# Render the prompt with the tool schema, generate up to 128 new tokens, and
# decode only the tokens produced after the prompt.
inputs = processor.apply_chat_template(
    message,
    tools=[synthetic_data_function_schema],
    add_generation_prompt=True,
    return_dict=True,
    return_tensors="pt",
)

out = model.generate(
    **inputs.to(model.device),
    pad_token_id=processor.eos_token_id,
    max_new_tokens=128,
)
output = processor.decode(
    out[0][len(inputs["input_ids"][0]):],
    skip_special_tokens=True,
)

# developer's turn
def extract_tool_calls(text):
    """Parse FunctionGemma tool calls out of decoded model output.

    Each call is matched in the form
    ``<start_function_call>call:NAME{key:value,...}<end_function_call>``, where a
    value is either wrapped in ``<escape>...<escape>`` (and may then contain
    commas) or runs up to the next ``,`` or ``}``.

    Args:
        text: Decoded model output.

    Returns:
        A list of ``{"name": str, "arguments": dict}`` dicts in the order the
        calls appear; empty when no call is found.
    """
    def cast(v):
        # Try int, then float, then the literals true/false; otherwise keep the
        # string with surrounding quotes stripped.
        try:
            return int(v)
        except ValueError:
            pass
        try:
            return float(v)
        except ValueError:
            pass
        return {'true': True, 'false': False}.get(v.lower(), v.strip("'\""))

    return [{
        "name": name,
        "arguments": {
            # Exactly one of v1 (escaped value) / v2 (bare value) is captured.
            k: cast((v1 or v2).strip())
            for k, v1, v2 in re.findall(r"(\w+):(?:<escape>(.*?)<escape>|([^,}]*))", args)
        }
    } for name, args in re.findall(r"<start_function_call>call:(\w+)\{(.*?)\}<end_function_call>", text, re.DOTALL)]

# extract the tool call output
calls = extract_tool_calls(output)

if calls:
    # Record the model's tool-call turn in the conversation.
    message.append({
        "role": "assistant",
        "tool_calls": [{"type": "function", "function": call} for call in calls]
    })
    print(message[-1])

    # call the function and get the result
    results = [
        {"name": c['name'], "response": tools[c['name']](**c['arguments'])}
        for c in calls
    ]

    # Printed inside the branch: `results` only exists when a call was found
    # (previously this raised NameError when the model emitted no tool call).
    print(results[0]["response"])
else:
    print(f"No tool call found in the model output:\n{output}")


# **Launch Web UI**

In [None]:
from sdv.single_table import GaussianCopulaSynthesizer
from sdv.metadata import SingleTableMetadata
import pandas as pd
import gradio as gr
import requests
import json
import re

def dataAgent(prompt):
    """Turn a natural-language instruction into a synthetic dataset for the UI.

    The prompt is sent to FunctionGemma together with the
    ``synthetic_data_generator`` tool schema. Every tool call found in the
    model's reply is executed, and the first result is written to
    ``synthetic_data.csv`` so the download button can serve it.

    Args:
        prompt: The user's instruction, e.g.
            "Can you generate 500 synthetic data points of different persons?"

    Returns:
        A ``(dataframe, csv_path)`` tuple for the ``[data_output, download_btn]``
        outputs. ``(None, None)`` is returned, after showing a ``gr.Info``
        message, when the model emits no tool call or when any step fails.
    """
    # define function schema
    synthetic_data_function_schema = {
        "type": "function",
        "function": {
            "name": "synthetic_data_generator",
            "description": "Generates synthetic data points.",
            "parameters": {
                "type": "object",
                "properties": {
                    "category": {
                        "type": "string",
                        "description": "The category of data to generate, e.g. persons",
                    },
                    "quantity": {
                        "type": "integer",
                        "description": "The quantity of data points to generate, e.g. 1000",
                    }
                },
                "required": ["category","quantity"],
            },
        }
    }

    # build the function
    def synthetic_data_generator(category: str, quantity: int) -> pd.DataFrame:
        """Fit a Gaussian copula on 200 Faker API rows of ``category`` and
        sample ``quantity`` synthetic rows from it.

        Raises:
            requests.HTTPError: If the Faker API responds with an error status.
            requests.Timeout: If the Faker API does not respond within 30 s.
        """
        url = f"https://fakerapi.it/api/v2/{category}?_quantity=200"
        resp = requests.get(url, timeout=30)
        resp.raise_for_status()

        df = pd.DataFrame(resp.json()["data"])

        # Drop each of these columns independently when present.
        # NOTE(review): presumably nested objects the synthesizer can't model.
        nested_columns = [c for c in ("address", "addresses", "contact") if c in df.columns]
        df = df.drop(columns=nested_columns)

        # extract metadata from the table
        metadata = SingleTableMetadata()
        metadata.detect_from_dataframe(df)

        synthesizer = GaussianCopulaSynthesizer(metadata)
        synthesizer.fit(data=df)

        return synthesizer.sample(num_rows=quantity)

    # add function to a registry of tools
    tools = {
        "synthetic_data_generator": synthetic_data_generator
    }

    def extract_tool_calls(text):
        """Return ``[{"name", "arguments"}]`` for each FunctionGemma call in *text*."""
        def cast(v):
            # int, then float, then true/false, else the unquoted string
            try:
                return int(v)
            except ValueError:
                pass
            try:
                return float(v)
            except ValueError:
                pass
            return {'true': True, 'false': False}.get(v.lower(), v.strip("'\""))

        return [{
            "name": name,
            "arguments": {
                k: cast((v1 or v2).strip())
                for k, v1, v2 in re.findall(r"(\w+):(?:<escape>(.*?)<escape>|([^,}]*))", args)
            }
        } for name, args in re.findall(r"<start_function_call>call:(\w+)\{(.*?)\}<end_function_call>", text, re.DOTALL)]

    # model's turn
    message = [
        {
            "role": "developer",
            "content": "You are a model that can do function calling with the following functions"
        },
        {
            "role": "user",
            "content": prompt
        }
    ]

    try:
        inputs = processor.apply_chat_template(message, tools=[synthetic_data_function_schema], add_generation_prompt=True, return_dict=True, return_tensors="pt")

        out = model.generate(**inputs.to(model.device), pad_token_id=processor.eos_token_id, max_new_tokens=128)
        output = processor.decode(out[0][len(inputs["input_ids"][0]):], skip_special_tokens=True)

        # developer's turn: extract the tool call output
        calls = extract_tool_calls(output)
        if not calls:
            # Previously this path hit a NameError on `results` and was
            # misreported as an unsupported category.
            gr.Info("The model did not produce a tool call. Please rephrase your instruction, e.g. \"Generate 500 synthetic persons\".")
            return None, None

        message.append({
            "role": "assistant",
            "tool_calls": [{"type": "function", "function": call} for call in calls]
        })
        print(message[-1])

        # call the function and get the result
        results = [
            {"name": c['name'], "response": tools[c['name']](**c['arguments'])}
            for c in calls
        ]

        # Persist the first result so the DownloadButton can serve it.
        csv_path = "synthetic_data.csv"
        results[0]["response"].to_csv(csv_path, index=False)
        return results[0]["response"], csv_path

    except Exception as e:
        # UI boundary: log the real cause for debugging, show a hint to the
        # user, and return one value per Gradio output.
        print(f"dataAgent failed: {e!r}")
        gr.Info("Please make sure the data category you specified in your prompt is supported by the Faker API!")
        return None, None

#UI
# custom CSS
# Stylesheet passed to gr.Blocks(css=...). Selectors target component elem_ids:
# #generate_btn / #clear_btn / #download_btn get a rounded (border-radius 56px)
# button style that scales to 1.05 on hover; #data and #text_input add a border
# around their rows; #frame styles the gr.DataFrame. The #frame rules list both
# MUI DataGrid classes and plain table / .gr-dataframe markup — presumably to
# cover differing DataFrame DOM across Gradio versions; verify.
css = """
#generate_btn {
  background-color: #0078d0;
  border: 0;
  border-radius: 56px;
  color: black;
  cursor: pointer;
  display: inline-block;
  font-family: system-ui,-apple-system,system-ui,"Segoe UI",Roboto,Ubuntu,"Helvetica Neue",sans-serif;
  font-size: 18px;
  font-weight: 600;
  outline: 0;
  padding: 16px 21px;
  position: relative;
  text-align: center;
  text-decoration: none;
  transition: all .3s;
  user-select: none;
  -webkit-user-select: none;
  touch-action: manipulation;
}

#generate_btn:before {
  background-color: initial;
  background-image: linear-gradient(#fff 0, rgba(255, 255, 255, 0) 100%);
  border-radius: 125px;
  content: "";
  height: 50%;
  left: 4%;
  opacity: .5;
  position: absolute;
  top: 0;
  transition: all .3s;
  width: 92%;
}

#generate_btn:hover {
  box-shadow: rgba(255, 255, 255, .2) 0 3px 15px inset, rgba(0, 0, 0, .1) 0 3px 5px, rgba(0, 0, 0, .1) 0 10px 13px;
  transform: scale(1.05);
}

@media (min-width: 768px) {
  #generate_btn {
    padding: 16px 48px;
  }
}

#clear_btn {
  background-color: red;
  border: 0;
  border-radius: 56px;
  color: black;
  cursor: pointer;
  display: inline-block;
  font-family: system-ui,-apple-system,system-ui,"Segoe UI",Roboto,Ubuntu,"Helvetica Neue",sans-serif;
  font-size: 18px;
  font-weight: 600;
  outline: 0;
  padding: 16px 21px;
  position: relative;
  text-align: center;
  text-decoration: none;
  transition: all .3s;
  user-select: none;
  -webkit-user-select: none;
  touch-action: manipulation;
}

#clear_btn:before {
  background-color: initial;
  background-image: linear-gradient(#fff 0, rgba(255, 255, 255, 0) 100%);
  border-radius: 125px;
  content: "";
  height: 50%;
  left: 4%;
  opacity: .5;
  position: absolute;
  top: 0;
  transition: all .3s;
  width: 92%;
}

#clear_btn:hover {
  box-shadow: rgba(255, 255, 255, .2) 0 3px 15px inset, rgba(0, 0, 0, .1) 0 3px 5px, rgba(0, 0, 0, .1) 0 10px 13px;
  transform: scale(1.05);
}

@media (min-width: 768px) {
  #clear_btn {
    padding: 16px 48px;
  }
}

#download_btn {
  background-color: yellow;
  border: 0;
  border-radius: 56px;
  color: black;
  cursor: pointer;
  display: inline-block;
  font-family: system-ui,-apple-system,system-ui,"Segoe UI",Roboto,Ubuntu,"Helvetica Neue",sans-serif;
  font-size: 18px;
  font-weight: 600;
  outline: 0;
  padding: 16px 21px;
  position: relative;
  text-align: center;
  text-decoration: none;
  transition: all .3s;
  user-select: none;
  -webkit-user-select: none;
  touch-action: manipulation;
}

#download_btn:before {
  background-color: initial;
  background-image: linear-gradient(#fff 0, rgba(255, 255, 255, 0) 100%);
  border-radius: 125px;
  content: "";
  height: 50%;
  left: 4%;
  opacity: .5;
  position: absolute;
  top: 0;
  transition: all .3s;
  width: 92%;
}

#download_btn:hover {
  box-shadow: rgba(255, 255, 255, .2) 0 3px 15px inset, rgba(0, 0, 0, .1) 0 3px 5px, rgba(0, 0, 0, .1) 0 10px 13px;
  transform: scale(1.05);
}

@media (min-width: 768px) {
  #download_btn {
    padding: 16px 48px;
  }
}

#data {
  border: 5px solid #0078d0;
  border-radius:10px;
}

#text_input {
  border: 5px solid green;
  border-radius:10px;
}

#frame,
#frame .MuiDataGrid-root,
#frame .dataframe,
#frame .gr-dataframe,
#frame table {
  background-color: #ffffff !important;
  color: #000000 !important;
  border: 2px solid #0078d0 !important;
  border-radius: 8px !important;
  font-size: 14px;
}

/* Headers for both MUI and standard tables */
#frame .MuiDataGrid-columnHeaders,
#frame thead,
#frame th {
  background-color: #f2f8ff !important;
  color: #000000 !important;
  border-bottom: 2px solid #0078d0 !important;
  font-weight: 600 !important;
}

#frame .MuiDataGrid-columnHeader {
  border-right: 1px solid #0078d0 !important;
}

/* Cells for both MUI and standard tables */
#frame .MuiDataGrid-cell,
#frame td {
  background-color: #ffffff !important;
  color: #000000 !important;
  border-bottom: 1px solid #0078d0 !important;
  border-right: 1px solid #0078d0 !important;
  padding: 8px 12px !important;
}

#frame .MuiDataGrid-cell:last-of-type {
  border-right: none !important;
}

/* Rows for both MUI and standard tables */
#frame .MuiDataGrid-row,
#frame tr {
  background-color: #ffffff !important;
}

#frame .MuiDataGrid-row:nth-of-type(even),
#frame tr:nth-of-type(even) {
  background-color: #f5faff !important;
}

#frame .MuiDataGrid-row:hover,
#frame tr:hover {
  background-color: #e6f2ff !important;
}

#frame .MuiDataGrid-row:nth-of-type(even):hover,
#frame tr:nth-of-type(even):hover {
  background-color: #e6f2ff !important;
}

#frame .MuiDataGrid-footerContainer {
  background-color: #f2f8ff !important;
  color: #000000 !important;
  border-top: 2px solid #0078d0 !important;
}

/* Gradio-specific dataframe styling */
#frame .gr-dataframe,
#frame .gr-dataframe table,
#frame .gr-dataframe tbody,
#frame .gr-dataframe thead {
  background-color: #ffffff !important;
  color: #000000 !important;
}

#frame .gr-dataframe th {
  background-color: #f2f8ff !important;
  color: #000000 !important;
  border: 1px solid #0078d0 !important;
}

#frame .gr-dataframe td {
  background-color: #ffffff !important;
  color: #000000 !important;
  border: 1px solid #0078d0 !important;
}

#frame .gr-dataframe tr:nth-child(even) td {
  background-color: #f5faff !important;
}

#frame .gr-dataframe tr:hover td {
  background-color: #e6f2ff !important;
}

/* Force override any default styling */
#frame * {
  box-sizing: border-box;
}

#frame .gr-dataframe .gr-button {
  background-color: #e6f2ff !important;
  color: #000000 !important;
  border: 1px solid #0078d0 !important;
}

#frame .gr-dataframe .gr-button:hover {
  background-color: #d9ecff !important;
}

/* Scrollbar styling */
#frame .MuiDataGrid-virtualScroller::-webkit-scrollbar {
  width: 8px;
  height: 8px;
}

#frame .MuiDataGrid-virtualScroller::-webkit-scrollbar-track {
  background: #f2f8ff;
}

#frame .MuiDataGrid-virtualScroller::-webkit-scrollbar-thumb {
  background: #0078d0;
  border-radius: 4px;
}

#frame .MuiDataGrid-virtualScroller::-webkit-scrollbar-thumb:hover {
  background: #005da3;
}

/* Selection styling */
#frame .MuiDataGrid-row.Mui-selected {
  background-color: #cfe9ff !important;
}

#frame .MuiDataGrid-row.Mui-selected:hover {
  background-color: #bfe1ff !important;
}

#frame .MuiDataGrid-cell.Mui-selected {
  background-color: #cfe9ff !important;
}

/* Sorting and filtering icons */
#frame .MuiDataGrid-iconSeparator {
  color: #0078d0 !important;
}

#frame .MuiDataGrid-sortIcon {
  color: #0078d0 !important;
}

#frame .MuiDataGrid-filterIcon {
  color: #0078d0 !important;
}

"""

# Gradio UI: instruction box + Generate button, the generated table, then
# Download / Clear buttons. elem_id values are styled by the `css` string.
with gr.Blocks(title="Data Agent", css=css) as app:
    # Header and the list of supported Faker API categories. (Fixed a stray
    # extra quote that closed the <h1> style attribute twice.)
    gr.HTML("""
      <h1 style='text-align:center;font-size:40px;color:green;'><strong>Data Agent</strong></h1>
      <br>
      <p style='text-align:center;font-size:17px;'>
        <strong>
          You can generate synthetic data of the following kinds:
          companies, persons, addresses, places, users, credit cards (spell as creditCards in your prompt), books, and texts.
        </strong>
      </p>
      """)

    with gr.Row():
        with gr.Row(elem_id="text_input"):
            text_input = gr.Text(
                label="Instruction",
                placeholder="Enter your instruction here",
                scale=1
            )
        generate_btn = gr.Button(
            "Generate",
            elem_id="generate_btn",
            variant="primary",
            scale=1
        )

    with gr.Row(elem_id="data"):
        data_output = gr.DataFrame(
            elem_id="frame",
            label="Synthetic Data Generated",
            show_search=True,
            interactive=True,
            buttons=["fullscreen", "copy"]
        )

    with gr.Row():
        download_btn = gr.DownloadButton(
            "Download CSV",
            elem_id="download_btn",
            variant="primary",
            scale=1
        )
        clear_btn = gr.Button(
            "Clear",
            elem_id="clear_btn",
            variant="stop",
            scale=1
        )

    # Clear also resets the download button so it no longer serves the CSV
    # from the previous generation.
    clear_btn.click(
        fn=lambda: (None, None, None),
        inputs=None,
        outputs=[text_input, data_output, download_btn]
    )

    # dataAgent returns (dataframe, csv_path) matching these two outputs.
    generate_btn.click(
        fn=dataAgent,
        inputs=text_input,
        outputs=[data_output, download_btn]
    )

if __name__ == "__main__":
    # Start the Gradio app built above (debug mode enabled).
    app.launch(debug=True)