In [1]:
import json
import re
import time
from pathlib import Path
from typing import Dict

from langchain_core.prompts import FewShotPromptTemplate, PromptTemplate
from langchain_experimental.tabular_synthetic_data.base import (
    SyntheticDataGenerator,
)
from langchain_experimental.tabular_synthetic_data.prompts import (
    SYNTHETIC_FEW_SHOT_PREFIX,
    SYNTHETIC_FEW_SHOT_SUFFIX,
)
from langchain_ollama import OllamaLLM, ChatOllama
from pydantic import BaseModel

In [2]:
# Schema for the "answer" part of one generated sample: which function the
# query maps to and the arguments to call it with.
class Answer(BaseModel):
    # Name of the target function. The few-shot examples in this notebook
    # all use "no_such_function".
    function: str
    # Argument name -> argument value, both as strings. The few-shot examples
    # in this notebook always pass an empty mapping.
    arguments: Dict[str, str]


# Schema for one complete sample: a natural-language query paired with the
# function call that answers it. Passed as `output_schema` to the
# SyntheticDataGenerator further down in this notebook.
class func_calls(BaseModel):
    # The natural-language request text.
    query: str
    # The function call (name + arguments) associated with the query.
    answer: Answer

In [3]:
# Ollama-backed completion model that produces the synthetic samples.
llm = OllamaLLM(
    temperature=0.7,
    model="llama3.1",
)

In [50]:
# Few-shot examples: every query maps to "no_such_function" with no arguments.
# Braces are doubled in the text; presumably so the prompt template keeps them
# as literal characters instead of treating them as placeholders.
_EXAMPLE_HEAD = '{{"query": "'
_EXAMPLE_TAIL = '", "answer": {{"function": "no_such_function", "arguments": {{}}}}}}'

_UNSUPPORTED_QUERIES = (
    "What is your name",
    "Where are you from",
    "Change the channel to 5",
    "What is 5 divided by 7",
    "What phase is the moon going to be on 10/29/24",
    "When is the next solar eclipse",
    "How many calories should I eat to loose 1/2lbs every two weeks as a man who weights 170lbs",
)

examples = [
    {"example": _EXAMPLE_HEAD + query_text + _EXAMPLE_TAIL}
    for query_text in _UNSUPPORTED_QUERIES
]

In [53]:
# Each few-shot example is rendered verbatim from its "example" field.
OPENAI_TEMPLATE = PromptTemplate.from_template(template="{example}")

# Few-shot prompt assembled from the library-provided prefix/suffix and the
# example list; "subject" and "extra" are supplied at generation time.
prompt_template = FewShotPromptTemplate(
    example_prompt=OPENAI_TEMPLATE,
    examples=examples,
    prefix=SYNTHETIC_FEW_SHOT_PREFIX,
    suffix=SYNTHETIC_FEW_SHOT_SUFFIX,
    input_variables=["subject", "extra"],
)

# Generator that ties together the prompt, the model and the output schema.
synthetic_data_generator = SyntheticDataGenerator(
    llm=llm,
    template=prompt_template,
    output_schema=func_calls,
)

def generate_synthetic_data(runs=1):
    """
    Run the synthetic data generator with this notebook's fixed prompt inputs
    and print how long the call took.

    Args:
        runs (int, optional): Forwarded to ``synthetic_data_generator.generate``.
            Defaults to 1; larger values were noted by the original author as
            currently breaking the pipeline.
    Returns:
        list: The value returned by ``synthetic_data_generator.generate``;
            expected to contain one entry per run.
    Example:
        synthetic_results = generate_synthetic_data(runs=1)
    """
    # perf_counter is monotonic, so the measured duration cannot be skewed by
    # system clock adjustments the way a time.time() difference can.
    started = time.perf_counter()
    synthetic_results = synthetic_data_generator.generate(
        subject="natural language queries for functions that don't exist.",
        extra="make the queries very unique and interesting. the arguments must be chosen at random. Don't report the arguments. Don't make chit-chat and don't have an introduction. Generate 3 examples",
        runs=runs,
    )
    elapsed = time.perf_counter() - started

    print(f"It took {elapsed:.2f}seconds to generate data")
    return synthetic_results

# Generate a single batch (runs=1) and show the raw generator output.
data = generate_synthetic_data(1)
print(data)

It took 3.20seconds to generate data
['{"query": "What is the average airspeed velocity of a snail on Mars", "answer": {"function": "no_such_function", "arguments": {}}}\n\n{"query": "Can you generate a poem about a cat that\'s 427 years old and has seen every historical event since the pyramids were built", "answer": {"function": "no_such_function", "arguments": {}}}\n\n{"query": "What is the probability of a coin landing on its edge on a surface in outer space while being held by an astronaut wearing a spacesuit made of a material that\'s 75% cotton and 25% silk", "answer": {"function": "no_such_function", "arguments": {}}}']


In [9]:
def fix_json_quotes(json_str):
    """Best-effort repair of common quoting mistakes in JSON-like text.

    Two repairs are applied, in this order:

    1. An alphabetic word wrapped in single quotes (e.g. 'query') is
       re-wrapped in double quotes, because JSON only accepts double quotes.
    2. Every unescaped double quote that does not look like a string
       delimiter gets a backslash in front of it. A quote counts as a
       delimiter when, ignoring whitespace, it follows one of ``{ [ , :`` (or
       starts the text) or precedes one of ``} ] , :`` (or ends the text).

    This is a heuristic: an interior quote that happens to sit next to a
    comma or colon is still treated as a delimiter, and single-quoted text
    containing anything other than ASCII letters is left untouched. Callers
    should re-validate the result.

    Args:
        json_str (str): JSON-like text that failed to parse.

    Returns:
        str: The text with the repairs applied. Text whose quotes all look
        like delimiters is returned unchanged.
    """
    # Step 1: single-quoted words -> double-quoted words.
    json_str = re.sub(r"'([a-zA-Z]+)'", r'"\1"', json_str)

    opener_context = "{[,:"
    closer_context = "}],:"
    size = len(json_str)
    repaired = []

    # Step 2: escape interior double quotes, keep delimiters as they are.
    for index, char in enumerate(json_str):
        if char != '"':
            repaired.append(char)
            continue

        # An odd number of backslashes right before the quote means it is
        # already escaped; leave it alone.
        slashes = 0
        while index - slashes - 1 >= 0 and json_str[index - slashes - 1] == "\\":
            slashes += 1
        if slashes % 2 == 1:
            repaired.append(char)
            continue

        # Nearest non-whitespace neighbours on each side of the quote.
        left = index - 1
        while left >= 0 and json_str[left].isspace():
            left -= 1
        right = index + 1
        while right < size and json_str[right].isspace():
            right += 1

        opens_string = left < 0 or json_str[left] in opener_context
        closes_string = right >= size or json_str[right] in closer_context
        if opens_string or closes_string:
            repaired.append(char)
        else:
            repaired.append('\\"')

    return "".join(repaired)

def validate_and_fix_json(json_str):
    """Parse a JSON string, retrying once after a quote-repair pass.

    Args:
        json_str (str): Text expected to contain a single JSON document.

    Returns:
        The decoded JSON value. If the text cannot be decoded even after
        ``fix_json_quotes`` has been applied, the decode error is printed and
        ``None`` is returned.
    """
    # First attempt: the text may already be valid.
    try:
        return json.loads(json_str)
    except json.JSONDecodeError:
        pass

    # Second attempt: repair the quoting, then decode again.
    repaired_text = fix_json_quotes(json_str)
    try:
        return json.loads(repaired_text)
    except json.JSONDecodeError as decode_error:
        print(f"Failed to fix JSON: {decode_error}")
        return None

In [11]:
# Split every generator result into individual JSON documents. Documents are
# separated by a blank line; whitespace-only chunks are dropped so they are
# not reported as invalid JSON further down. All entries of `data` are
# processed, not only the first one.
json_strings = [
    chunk.strip()
    for raw_output in data
    for chunk in re.split(r'\n\s*\n', raw_output)
    if chunk.strip()
]

# Keep each chunk that parses (possibly after quote repair); report the rest.
parsed_json_objects = []
for json_str in json_strings:
    fixed_json = validate_and_fix_json(json_str)
    # validate_and_fix_json signals failure with None, so compare against None
    # explicitly: valid but falsy values such as {} must not be discarded.
    if fixed_json is not None:
        parsed_json_objects.append(fixed_json)
    else:
        print(f"Skipping invalid JSON string: {json_str}")

combined_json = json.dumps({"queries": parsed_json_objects}, indent=4)

In [12]:
# Write the combined samples to disk. The target directory is created first
# so the cell also works on a fresh checkout where dump_json/ does not exist.
output_path = Path('dump_json/failed.json')
output_path.parent.mkdir(parents=True, exist_ok=True)
with open(output_path, 'w', encoding='utf-8') as json_file:
    json_file.write(combined_json)