In [1]:
import os 
import json 
import json_repair
from dotenv import load_dotenv
from pydantic import BaseModel , Field
from typing import Literal , List


In [2]:
load_dotenv()
api_key = os.getenv("OPENAI_API_KEY")
db_schema_path = os.path.join(os.getcwd(),"data/db_schema.json")

In [3]:
# Entities definition
RouteType = Literal["rag", "sql", "hybrid"]

class OutputField(BaseModel):
    question: str = Field(..., description="User question")
    sql: str = Field(..., description="SQL generated to answer the question")
    route: RouteType = Field(..., description="Routing decision: rag, sql, or hybrid")
    context: str = Field(..., description="Retrieved content if rag or hybrid route is used")

class OutputFields(BaseModel):
    final_output: List[OutputField]


In [4]:
# Load the db schema 
with open(db_schema_path, 'r') as f:
    db_schema_dict = json.load(f)
    
db_schema_string = json.dumps(db_schema_dict, indent=2)


In [5]:
docs = """"
# Northwind Marketing Calendar (1997) 

## Summer Beverages 1997 
- Dates: 1997-06-01 to 1997-06-30 
- Notes: Focus on Beverages and Condiments. 

## Winter Classics 1997 
- Dates: 1997-12-01 to 1997-12-31 
- Notes: Push Dairy Products and Confections for holiday gifting. 

# KPI Definitions 

## Average Order Value (AOV) 
- AOV = SUM(UnitPrice * Quantity * (1 - Discount)) / COUNT(DISTINCT OrderID) 
## Gross Margin 
- GM = SUM((UnitPrice - CostOfGoods) * Quantity * (1 - Discount)) 
- If cost is missing, approximate with category-level average. 

# Catalog Snapshot

- Categories include Beverages, Condiments, Confections, Dairy Products, 
Grains/Cereals, Meat/Poultry, Produce, Seafood. 
- Products map to categories as in the Northwind DB.

# Returns & Policy 
- Perishables (Produce, Seafood, Dairy): 3–7 days. 
- Beverages unopened: 14 days; opened: no returns. 
- Non-perishables: 30 days. 
"""

In [6]:
prompt_messages = [
    {
        "role": "system",
        "content": (
            "You are a Data Engineer. "
            "You generate exactly THREE high-quality training samples  for a MiPro v2 DSPy router that cover the three cases.\n\n"
            "Follow these rules:\n"
            "- Use ONLY information from the DB schema and documents.\n"
            "- Each item must contain: question, sql, route, and context.\n"
            "- Use 'rag' for questions answerable ONLY from documents.\n"
            "- Use 'sql' for questions answerable ONLY from the database.\n"
            "- Use 'hybrid' for questions requiring BOTH.\n"
            "- If route = 'sql', context MUST be an empty string.\n"
            "- If route = 'rag', sql MUST be an empty string.\n"
            "- Output ONLY a valid JSON following the OutputFields schema.\n"
            "NOTE : Make the questions simple that require simple SQL code to answer"
            "Make sure to cover three route cases (rag - hybrid - sql)"
        )
    },
    {
        "role": "user",
        "content": (
            "## Documents for rag, hybrid:\n"
            f"{docs.strip()}\n\n"
            "## Database Schema:\n"
            f"{db_schema_string.strip()}\n\n"
            "## Pydantic Schema:\n"
            f"{json.dumps(OutputFields.model_json_schema(), ensure_ascii=False)}\n\n"
            "## Generated Output:\n"
        )
    }
]


In [7]:
# Parsing json
def parse_json(text):
    try:
        return json_repair.loads(text)
    except:
        return None

# Training data generation

In [8]:
from openai import OpenAI
from tqdm.auto import tqdm


openai_client = OpenAI(
    api_key=api_key,
)
openai_model = "gpt-4o"

price_per_1m_input_tokens = 1.25
price_per_1m_output_tokens = 10

prompt_tokens = 0
completion_tokens = 0

data_dir = os.path.join(os.getcwd(),"data")
save_dir = os.path.join(data_dir,'training_data.jsonl')

iterations = [i for i in range(10)]
i = 0 

for i in tqdm(iterations):
    response = openai_client.chat.completions.create(
                        messages=prompt_messages,
                        model=openai_model,
                    )

    if response.choices[0].finish_reason != "stop":
        prompt_tokens += response.usage.prompt_tokens
        continue

    llm_response = response.choices[0].message.content
    llm_resp_dict = parse_json(llm_response)

    if not llm_resp_dict:
        continue

    with open(save_dir, "a", encoding="utf8") as f:
        f.write(json.dumps({
            "id": i,
            "response": llm_resp_dict,
        }, ensure_ascii=False, default=str)  + "\n" )

    i+= 1
    prompt_tokens += response.usage.prompt_tokens
    completion_tokens += response.usage.completion_tokens


cost_input = (prompt_tokens / 1_000_000) * price_per_1m_input_tokens
cost_output = (completion_tokens / 1_000_000) * price_per_1m_output_tokens
total_cost = cost_input + cost_output

print(f"Total Cost = ${total_cost:.4f} ")        



  from .autonotebook import tqdm as notebook_tqdm
100%|██████████| 10/10 [00:25<00:00,  2.50s/it]

Total Cost = $0.0375 



