In [25]:
import json
import os

import openai
import requests
from tenacity import retry, wait_random_exponential, stop_after_attempt
from termcolor import colored

# Chat model used by default for every request in this notebook.
GPT_MODEL = "gpt-3.5-turbo-0613"
# Read the API key from the OPENAI_API_KEY environment variable so a real key never
# has to be saved inside the notebook; the literal is only a placeholder fallback.
openai.api_key = os.environ.get("OPENAI_API_KEY", "sk-xx")


Utilities

First let's define a few utilities for making calls to the Chat Completions API and for maintaining and keeping track of the conversation state.


In [26]:
@retry(wait=wait_random_exponential(min=1, max=40), stop=stop_after_attempt(3), reraise=True)
def chat_completion_request(messages, functions=None, function_call=None, model=GPT_MODEL, timeout=60):
    """POST a conversation to the Chat Completions endpoint and return the raw response.

    Args:
        messages: List of chat message dicts to send as the conversation.
        functions: Optional list of function specifications; omitted from the
            request body when None.
        function_call: Optional function-calling directive (e.g. "none" or
            {"name": ...}); omitted from the request body when None.
        model: Model name placed in the request body.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        The `requests.Response` returned by `requests.post`. HTTP error statuses
        are not raised here; callers inspect the response themselves.

    Raises:
        Exception: Whatever `requests.post` raised on the last attempt. The
            exception is re-raised (instead of being returned) so the retry
            decorator can actually retry and callers never receive an exception
            object where they expect a response.
    """
    headers = {
        "Content-Type": "application/json",
        "Authorization": "Bearer " + openai.api_key,
    }
    json_data = {"model": model, "messages": messages}
    if functions is not None:
        json_data["functions"] = functions
    if function_call is not None:
        json_data["function_call"] = function_call
    try:
        return requests.post(
            "https://api.openai.com/v1/chat/completions",
            headers=headers,
            json=json_data,
            timeout=timeout,
        )
    except Exception as e:
        print("Unable to generate ChatCompletion response")
        print(f"Exception: {e}")
        # Propagate so tenacity can retry; after the final attempt the original error surfaces.
        raise

In [23]:
def pretty_print_conversation(messages):
    """Print each conversation message in a colour chosen by its role.

    Args:
        messages: Iterable of message dicts. Each must have a "role" key;
            system/user/assistant messages are read via "content", assistant
            messages with a truthy "function_call" print that call instead,
            and function messages also need a "name".

    Messages whose role is not one of system/user/assistant/function are
    skipped. The colour is taken from the message being printed, so skipped
    messages can no longer shift the colours of the ones that follow.
    """
    role_to_color = {
        "system": "red",
        "user": "green",
        "assistant": "blue",
        "function": "magenta",
    }
    for message in messages:
        role = message["role"]
        if role == "system":
            text = f"system: {message['content']}\n"
        elif role == "user":
            text = f"user: {message['content']}\n"
        elif role == "assistant":
            if message.get("function_call"):
                text = f"assistant: {message['function_call']}\n"
            else:
                text = f"assistant: {message['content']}\n"
        elif role == "function":
            text = f"function ({message['name']}): {message['content']}\n"
        else:
            # Unknown role: nothing to print for it.
            continue
        print(colored(text, role_to_color[role]))


Basic concepts

Let's create some function specifications to interface with a hypothetical weather API. We'll pass these function specifications to the Chat Completions API in order to generate function arguments that adhere to the specifications.


In [4]:
def _shared_weather_properties():
    """Return fresh copies of the parameter schemas common to both weather functions."""
    return {
        "location": {
            "type": "string",
            "description": "The city and state, e.g. San Francisco, CA",
        },
        "format": {
            "type": "string",
            "enum": ["celsius", "fahrenheit"],
            "description": "The temperature unit to use. Infer this from the users location.",
        },
    }


# Function specifications offered to the model: current conditions and an N-day forecast.
functions = [
    {
        "name": "get_current_weather",
        "description": "Get the current weather",
        "parameters": {
            "type": "object",
            "properties": _shared_weather_properties(),
            "required": ["location", "format"],
        },
    },
    {
        "name": "get_n_day_weather_forecast",
        "description": "Get an N-day weather forecast",
        "parameters": {
            "type": "object",
            "properties": {
                **_shared_weather_properties(),
                # The forecast additionally needs to know how many days to cover.
                "num_days": {
                    "type": "integer",
                    "description": "The number of days to forecast",
                },
            },
            "required": ["location", "format", "num_days"],
        },
    },
]

In [8]:
# Start a conversation whose user request leaves out the location.
messages = []
# The prompt is written on one line: a backslash continuation inside the string
# literal would embed the next line's indentation spaces in the text sent to the model.
messages.append({"role": "system", "content": "Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous."})
messages.append({"role": "user", "content": "What's the weather like today"})
chat_response = chat_completion_request(
    messages, functions=functions
)
assistant_message = chat_response.json()["choices"][0]["message"]
messages.append(assistant_message)
assistant_message

{'role': 'assistant',
 'content': 'Sure! Can you please provide me with the name of the city or state?'}

In [9]:
# Answer the model's clarifying question, then resend the whole history.
user_reply = {"role": "user", "content": "I'm in Glasgow, Scotland."}
messages.append(user_reply)
chat_response = chat_completion_request(messages, functions=functions)
first_choice = chat_response.json()["choices"][0]
assistant_message = first_choice["message"]
messages.append(assistant_message)
assistant_message

{'role': 'assistant',
 'content': None,
 'function_call': {'name': 'get_current_weather',
  'arguments': '{\n  "location": "Glasgow, Scotland",\n  "format": "celsius"\n}'}}

In [10]:


# New conversation: the request names a location but leaves the number of days as "x".
messages = [
    {"role": "system", "content": "Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous."},
    {"role": "user", "content": "what is the weather going to be like in Glasgow, Scotland over the next x days"},
]
chat_response = chat_completion_request(messages, functions=functions)
first_choice = chat_response.json()["choices"][0]
assistant_message = first_choice["message"]
messages.append(assistant_message)
assistant_message



{'role': 'assistant',
 'content': 'Sure, I can help you with that. Please provide the number of days you would like to forecast for.'}

In [11]:
# Supply the missing number of days and show the first choice of the reply.
user_reply = {"role": "user", "content": "5 days"}
messages.append(user_reply)
chat_response = chat_completion_request(messages, functions=functions)
chat_response.json()["choices"][0]

{'index': 0,
 'message': {'role': 'assistant',
  'content': None,
  'function_call': {'name': 'get_n_day_weather_forecast',
   'arguments': '{\n  "location": "Glasgow, Scotland",\n  "format": "celsius",\n  "num_days": 5\n}'}},
 'finish_reason': 'function_call'}

In [12]:
# Display the complete JSON body of the most recent response rather than a single choice.
chat_response.json()

{'id': 'chatcmpl-7TG1umIDus6cyqStxEEZLCyPurxzU',
 'object': 'chat.completion',
 'created': 1687207486,
 'model': 'gpt-3.5-turbo-0613',
 'choices': [{'index': 0,
   'message': {'role': 'assistant',
    'content': None,
    'function_call': {'name': 'get_n_day_weather_forecast',
     'arguments': '{\n  "location": "Glasgow, Scotland",\n  "format": "celsius",\n  "num_days": 5\n}'}},
   'finish_reason': 'function_call'}],
 'usage': {'prompt_tokens': 234, 'completion_tokens': 39, 'total_tokens': 273}}


Forcing the use of specific functions or no function

We can force the model to use a specific function, for example get_n_day_weather_forecast, by using the function_call argument. By doing so, we force the model to make assumptions about how to use it.


In [13]:
# Force the model to call get_n_day_weather_forecast by naming it in function_call.
messages = [
    {"role": "system", "content": "Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous."},
    {"role": "user", "content": "Give me a weather report for Toronto, Canada."},
]
forced_call = {"name": "get_n_day_weather_forecast"}
chat_response = chat_completion_request(messages, functions=functions, function_call=forced_call)
chat_response.json()["choices"][0]["message"]


{'role': 'assistant',
 'content': None,
 'function_call': {'name': 'get_n_day_weather_forecast',
  'arguments': '{\n  "location": "Toronto, Canada",\n  "format": "celsius",\n  "num_days": 1\n}'}}

In [14]:
# Same request without function_call: the model is free to pick a different function.
messages = [
    {"role": "system", "content": "Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous."},
    {"role": "user", "content": "Give me a weather report for Toronto, Canada."},
]
chat_response = chat_completion_request(messages, functions=functions)
chat_response.json()["choices"][0]["message"]

{'role': 'assistant',
 'content': None,
 'function_call': {'name': 'get_current_weather',
  'arguments': '{\n  "location": "Toronto, Canada",\n  "format": "celsius"\n}'}}

In [15]:


# Pass function_call="none" so the request asks for no function call at all.
messages = [
    {"role": "system", "content": "Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous."},
    {"role": "user", "content": "Give me the current weather (use Celcius) for Toronto, Canada."},
]
chat_response = chat_completion_request(messages, functions=functions, function_call="none")
chat_response.json()["choices"][0]["message"]



{'role': 'assistant',
 'content': '{\n  "location": "Toronto, Canada",\n  "format": "celsius"\n}'}


How to call functions with model generated arguments

In our next example, we'll demonstrate how to execute functions whose inputs are model-generated, and use this to implement an agent that can answer questions for us about a database. For simplicity we'll use a local DuckDB database of flight itineraries (`itineraries.duckdb`).

Note: SQL generation can be high-risk in a production environment since models are not perfectly reliable at generating correct SQL.
Specifying a function to execute SQL queries

First let's define some helpful utility functions to extract data from a duckdb database.


In [2]:
import duckdb
# Connect to the DuckDB file queried by the rest of the notebook.
DATABASE_FILE = 'itineraries.duckdb'
conn = duckdb.connect(database=DATABASE_FILE)
print("connected to db successfully")

connected to db successfully


In [16]:
def get_table_names(conn):
    """Return a list of table names listed in information_schema.tables.

    Args:
        conn: Database connection exposing `execute(sql)` whose result
            supports `fetchall()`.
    """
    cursor = conn.execute("SELECT table_name FROM information_schema.tables")
    # Each fetched row is a one-element tuple; keep only the name.
    return [row[0] for row in cursor.fetchall()]

def get_column_names(conn, table_name):
    """Return a list of the column names belonging to `table_name`.

    Args:
        conn: Database connection exposing `execute(sql, parameters)` whose
            result supports `fetchall()`.
        table_name: Name of the table whose columns are wanted.

    Returns:
        Column names of that table, in column order. Tables that share the
        same name in different schemas are not distinguished.
    """
    # Filter on the requested table; without the WHERE clause every table would
    # be reported as owning every column in the database.
    columns = conn.execute(
        "SELECT column_name FROM information_schema.columns "
        "WHERE table_name = ? ORDER BY ordinal_position",
        [table_name],
    )
    return [column[0] for column in columns.fetchall()]

def get_database_info(conn):
    """Describe every table in the database.

    Returns:
        A list with one dict per table, holding the table's name under
        "table_name" and its columns under "column_names".
    """
    return [
        {"table_name": name, "column_names": get_column_names(conn, name)}
        for name in get_table_names(conn)
    ]

In [18]:
# Now we can use these utility functions to extract a text representation of the database schema.

def _format_table_schema(table):
    """Render one table entry as a 'Table: ...' line followed by a 'Columns: ...' line."""
    joined_columns = ', '.join(table['column_names'])
    return f"Table: {table['table_name']}\nColumns: {joined_columns}"


database_schema_dict = get_database_info(conn)
database_schema_string = "\n".join(map(_format_table_schema, database_schema_dict))




As before, we'll define a function specification for the function we'd like the API to generate arguments for. Notice that we are inserting the database schema into the function specification. This will be important for the model to know about.


In [20]:
# Function specification for the SQL agent. This rebinds `functions`, replacing the
# weather specifications defined earlier in the notebook.
functions = [
    {
        "name": "ask_database",
        "description": "Use this function to answer user questions about flight itineraries. Output should be a fully formed SQL query.",
        "parameters": {
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    # The schema text is interpolated here so the model sees the table and
                    # column names it is allowed to use. The indentation inside the
                    # triple-quoted string is part of the text sent to the API.
                    # NOTE(review): the prompt asks for plain text rather than JSON, while
                    # execute_function_call parses the arguments with json.loads - confirm
                    # the wording is intended.
                    "description": f"""
                            SQL query extracting info to answer the user's question.
                            SQL should be written using this database schema:
                            {database_schema_string}
                            The query should be returned in plain text, not in JSON.
                            """,
                }
            },
            "required": ["query"],
        },
    }
]


Executing SQL queries

Now let's implement the function that will actually execute queries against the database.


In [21]:
def ask_database(conn, query):
    """Run a SQL query against the duckdb connection and return the outcome as text.

    Args:
        conn: Database connection exposing `execute(sql)` whose result
            supports `fetchall()`.
        query: SQL text to execute as-is.

    Returns:
        The string form of the fetched rows, or a "query failed with error"
        message when executing or fetching raises.
    """
    try:
        rows = conn.execute(query).fetchall()
        return str(rows)
    except Exception as err:
        # Failures are reported as text rather than raised.
        return f"query failed with error: {err}"

def execute_function_call(message):
    """Execute the function call requested in an assistant message.

    Args:
        message: Assistant message dict containing a "function_call" entry
            with "name" and "arguments" (a JSON-encoded string).

    Returns:
        The text returned by `ask_database` for an "ask_database" call, or an
        "Error: ..." string when the function name is unknown or the
        arguments cannot be parsed into an object with a "query" key.
    """
    function_call = message["function_call"]
    name = function_call["name"]
    if name != "ask_database":
        return f"Error: function {name} does not exist"
    try:
        query = json.loads(function_call["arguments"])["query"]
    except (json.JSONDecodeError, KeyError, TypeError) as e:
        # The arguments are model-generated and may be raw SQL or otherwise malformed;
        # report that as text, like the unknown-function case, instead of crashing.
        return f"Error: could not parse arguments for function {name}: {e}"
    return ask_database(conn, query)


In [27]:
# Ask a question, run the function call the model requests (if any), and print the transcript.
messages = [
    {"role": "system", "content": "Answer user questions by generating SQL queries against the itineraries database."},
    {"role": "user", "content": "Hi, whats the total count of flights between ATL and BOS"},
]
chat_response = chat_completion_request(messages, functions)
assistant_message = chat_response.json()["choices"][0]["message"]
messages.append(assistant_message)
requested_call = assistant_message.get("function_call")
if requested_call:
    results = execute_function_call(assistant_message)
    # Record the function output in the conversation under the function's name.
    messages.append({"role": "function", "name": requested_call["name"], "content": results})
pretty_print_conversation(messages)

[31msystem: Answer user questions by generating SQL queries against the itineraries database.
[0m
[32muser: Hi, whats the total count of flights between ATL and BOS
[0m
[34massistant: {'name': 'ask_database', 'arguments': '{\n  "query": "SELECT COUNT(*) FROM itineraries WHERE startingAirport = \'ATL\' AND destinationAirport = \'BOS\'"\n}'}
[0m
[35mfunction (ask_database): [(114860,)]
[0m


In [29]:
# Follow-up question in the same conversation; the earlier turns are sent along with it.
follow_up = {"role": "user", "content": "List the top 5 flights between ATL and BOS sorted by higest fare"}
messages.append(follow_up)
chat_response = chat_completion_request(messages, functions)
assistant_message = chat_response.json()["choices"][0]["message"]
messages.append(assistant_message)
requested_call = assistant_message.get("function_call")
if requested_call:
    results = execute_function_call(assistant_message)
    messages.append({"role": "function", "content": results, "name": requested_call["name"]})
pretty_print_conversation(messages)

[31msystem: Answer user questions by generating SQL queries against the itineraries database.
[0m
[32muser: Hi, whats the total count of flights between ATL and BOS
[0m
[34massistant: {'name': 'ask_database', 'arguments': '{\n  "query": "SELECT COUNT(*) FROM itineraries WHERE startingAirport = \'ATL\' AND destinationAirport = \'BOS\'"\n}'}
[0m
[35mfunction (ask_database): [(114860,)]
[0m
[32muser: List the top 5 flights between ATL and BOS
[0m
[34massistant: {'name': 'ask_database', 'arguments': '{\n  "query": "SELECT * FROM itineraries WHERE startingAirport = \'ATL\' AND destinationAirport = \'BOS\' LIMIT 5"\n}'}
[0m
[35mfunction (ask_database): [('2022-04-17', 'ATL', 'BOS', 'PT2H29M', False, False, True, 217.67, 248.6, 9, 'Delta', '2022-04-17T15:26:00.000-04:00', '2022-04-17T12:57:00.000-04:00', 'coach'), ('2022-04-17', 'ATL', 'BOS', 'PT2H30M', False, False, True, 217.67, 248.6, 4, 'Delta', '2022-04-17T09:00:00.000-04:00', '2022-04-17T06:30:00.000-04:00', 'coach'), ('202