In [14]:
# import the required libraries
import json
import os
import sqlite3
from contextlib import closing

import gradio as gr
from dotenv import load_dotenv
from openai import OpenAI

In [2]:
# Load variables from .env (letting them override the current environment),
# then fail fast if no OpenAI key is available before building the client.
load_dotenv(override=True)

openai_api_key = os.getenv("OPENAI_API_KEY")
if openai_api_key:
    print("OpenAI API key loaded successfully")
else:
    raise ValueError("OPENAI_API_KEY not found in environment variables")

MODEL = "gpt-4.1-mini"
openai = OpenAI()  # shared client instance; note it shadows the `openai` package name



OpenAI API key loaded successfully


In [3]:
# System prompt placed at the start of every conversation sent to the model.
system_message = (
    "\n"
    "You are a helpful assistant for an Airline called FlightAI.\n"
    "Give short, courteous answers, no more than 1 sentence.\n"
    "Always be accurate. If you don't know the answer, say so.\n"
)

In [22]:
def _history_to_messages(history):
    """Convert Gradio chat history into OpenAI chat messages (role/content only).

    Accepts both Gradio history formats: "messages" (dicts with ``role`` and
    ``content`` keys; any extra keys such as metadata are dropped) and the
    legacy "tuples" format (``[user, assistant]`` pairs, either side may be None).
    """
    converted = []
    for turn in history or []:
        if isinstance(turn, dict):
            converted.append({"role": turn["role"], "content": turn["content"]})
        else:
            user_text, assistant_text = turn
            if user_text is not None:
                converted.append({"role": "user", "content": user_text})
            if assistant_text is not None:
                converted.append({"role": "assistant", "content": assistant_text})
    return converted


def chat(message, history):
    """Answer one user turn of the FlightAI chat, running tools as requested.

    Args:
        message: The new user message text.
        history: Prior turns as passed by ``gr.ChatInterface`` (either the
            "messages" or the legacy "tuples" format).

    Returns:
        The assistant's final reply text ("" if the model returned no content),
        or a fallback apology if the model keeps requesting tools beyond the
        round limit.
    """
    # Hard cap so a model that keeps asking for tools cannot loop (and bill) forever.
    max_tool_rounds = 5

    messages = [{"role": "system", "content": system_message}]
    messages += _history_to_messages(history)
    messages.append({"role": "user", "content": message})
    response = openai.chat.completions.create(model=MODEL, messages=messages, tools=tools)

    rounds = 0
    while response.choices[0].finish_reason == "tool_calls":
        if rounds >= max_tool_rounds:
            return "Sorry, I couldn't complete that request."
        rounds += 1
        assistant_message = response.choices[0].message
        # The assistant message carrying the tool_calls must precede the tool replies.
        messages.append(assistant_message)
        messages.extend(handle_tool_calls(assistant_message))
        response = openai.chat.completions.create(model=MODEL, messages=messages, tools=tools)

    return response.choices[0].message.content or ""

In [23]:
# Hard-coded price table keyed by lower-case city name; prices are display strings.
ticket_prices = {
    "london": "$799",
    "paris": "$899",
    "tokyo": "$1400",
    "berlin": "$499",
}

def get_ticket_price(destination_city):
    """Return a sentence stating the hard-coded ticket price for a city.

    The lookup is case-insensitive. Cities missing from ``ticket_prices``
    are reported as "Unknown ticket price".

    NOTE: a later cell redefines ``get_ticket_price`` to read from SQLite,
    so after Run All this dictionary-backed version is shadowed.
    """
    print(f"Tool called for city: {destination_city}")
    key = destination_city.lower()
    price = ticket_prices[key] if key in ticket_prices else "Unknown ticket price"
    return f"The ticket price to {destination_city} is {price}"
    

In [24]:
# Function schema advertising get_ticket_price to the model (OpenAI tools format).
price_function = dict(
    name="get_ticket_price",
    description="Get the ticket price for a destination city",
    parameters=dict(
        type="object",
        properties=dict(
            destination_city=dict(
                type="string",
                description="The destination city",
            ),
        ),
        required=["destination_city"],
        additionalProperties=False,
    ),
)


In [25]:
# Function schema advertising set_ticket_price to the model (OpenAI tools format).
set_price_function = {
    "name": "set_ticket_price",
    "description": "Set the ticket price for a destination city",
    "parameters": {
        "type": "object",
        "properties": {
            "destination_city": {
                "type": "string",
                "description": "The destination city"
            },
            "price": {
                "type": "integer",
                "description": "The price of the ticket in dollars"
            }
        },
        # set_ticket_price(destination_city, price) has no default for `price`,
        # so both must be required; otherwise a call that omits the price
        # fails with TypeError when the arguments are splatted into the function.
        "required": ["destination_city", "price"],
        "additionalProperties": False
    }
}

In [26]:
# Every function schema, wrapped in the {"type": "function", ...} envelope
# expected by the `tools` argument of chat.completions.create.
tools = [
    {"type": "function", "function": schema}
    for schema in (price_function, set_price_function)
]

In [27]:
DB = "prices.db"

# Create the prices table once. sqlite3's connection context manager only
# commits or rolls back; it does not close the connection. `closing` is what
# closes it and releases the file handle. The inner `conn` context commits.
with closing(sqlite3.connect(DB)) as conn, conn:
    conn.execute("CREATE TABLE IF NOT EXISTS prices (city TEXT PRIMARY KEY, price REAL)")

In [40]:
def set_ticket_price(destination_city, price):
    """Insert or update a city's ticket price in the SQLite ``prices`` table.

    Args:
        destination_city: City name; stored lower-cased, matching the
            lower-cased lookup done by the database-backed get_ticket_price.
        price: New ticket price (the tool schema declares it as an integer).

    Returns:
        A confirmation sentence for the model.
    """
    # `closing` guarantees the connection is closed; the inner `conn` context
    # commits the upsert (or rolls back on error). UPSERT needs SQLite >= 3.24.
    with closing(sqlite3.connect(DB)) as conn, conn:
        conn.execute(
            "INSERT INTO prices (city, price) VALUES (?, ?) "
            "ON CONFLICT(city) DO UPDATE SET price = excluded.price",
            (destination_city.lower(), price),
        )
    return f"Success: The ticket price to {destination_city} has been updated to ${price}."

In [41]:
def get_ticket_price(destination_city):
    """Look up a city's ticket price in the SQLite ``prices`` table.

    This redefinition replaces the dictionary-backed ``get_ticket_price``
    from an earlier cell.

    Args:
        destination_city: City name; matched case-insensitively by
            lower-casing it, the same way set_ticket_price stores it.

    Returns:
        A sentence with the stored price, or a "no data" message if the
        city has no row.
    """
    print(f"DATABASE TOOL CALLED: Getting price for {destination_city}", flush=True)
    # `closing` ensures the connection is released; a read needs no commit.
    with closing(sqlite3.connect(DB)) as conn:
        row = conn.execute(
            "SELECT price FROM prices WHERE city = ?",
            (destination_city.lower(),),
        ).fetchone()
    return f"Ticket price to {destination_city} is ${row[0]}" if row else "No price data available for this city"

In [42]:
# Map each tool name to the Python function that implements it.
# NOTE(review): handle_tool_calls dispatches with a `match` statement and does
# not read this mapping; keep the two in sync or remove one of them.
available_functions = dict(
    get_ticket_price=get_ticket_price,
    set_ticket_price=set_ticket_price,
)

In [43]:
def handle_tool_calls(message):
    responses = []
    for tool_call in message.tool_calls:
        arguments = json.loads(tool_call.function.arguments)
        match tool_call.function.name:
            case "get_ticket_price":
                content = get_ticket_price(**arguments)
            case "set_ticket_price":
                content = set_ticket_price(**arguments)
            case _:
                raise ValueError("Unknown tool call: " + tool_call.function.name)
        responses.append({
            "role": "tool",
            "tool_call_id": tool_call.id,
            "name": tool_call.function.name,
            "content": str(content)
        })
    return responses

In [44]:
# Serve `chat` in a Gradio chat UI; ChatInterface takes the handler as its first argument.
gr.ChatInterface(chat).launch()

* Running on local URL:  http://127.0.0.1:7878
* To create a public link, set `share=True` in `launch()`.




DATABASE TOOL CALLED: Getting price for Tokyo
