##### This is just a sample code so that you can understand how to use function calling it not a starter code 

In [7]:
import openai
import os
import json 
import requests
from pprint import pprint
import json
import openai
import requests
from tenacity import retry, wait_random_exponential, stop_after_attempt
from termcolor import colored

GPT_MODEL = "gpt-3.5-turbo-0613"

##### In the example below LLM acts as an intermediary that translates a user's natural language request into a structured SQL query, which is then executed by a database tool to fetch and return the relevant data.


In [8]:
from dotenv import load_dotenv, find_dotenv
from openai import OpenAI
_ = load_dotenv(find_dotenv())
openai.api_key  = os.getenv('OPENAI_API_KEY')

In [9]:
# import gcp 
_SCOPE = 'https://www.googleapis.com/auth/cloud-platform'
from google.cloud import bigquery
client = bigquery.Client.from_service_account_json('../../data_warehousing/data_warehousing/include/gcp/service_account.json')

In [10]:
@retry(wait=wait_random_exponential(multiplier=1, max=40), stop=stop_after_attempt(3))
def chat_completion_request(messages, tools=None, tool_choice=None, model=GPT_MODEL):
    headers = {
        "Content-Type": "application/json",
        "Authorization": "Bearer " + openai.api_key,
    }
    json_data = {"model": model, "messages": messages}
    if tools is not None:
        json_data.update({"tools": tools})
    if tool_choice is not None:
        json_data.update({"tool_choice": tool_choice})
    try:
        response = requests.post(
            "https://api.openai.com/v1/chat/completions",
            headers=headers,
            json=json_data,
        )
        
        return response
    except Exception as e:
        print("Unable to generate ChatCompletion response")
        print(f"Exception: {e}")
        return e

In [11]:
def pretty_print_conversation(messages):
    role_to_color = {
        "system": "red",
        "user": "green",
        "assistant": "blue",
        "tool": "magenta",
    }
    
    for message in messages:
        if message["role"] == "system":
            print(colored(f"system: {message['content']}\n", role_to_color[message["role"]]))
        elif message["role"] == "user":
            print(colored(f"user: {message['content']}\n", role_to_color[message["role"]]))
        elif message["role"] == "assistant" and message.get("function_call"):
            print(colored(f"assistant: {message['function_call']}\n", role_to_color[message["role"]]))
        elif message["role"] == "assistant" and not message.get("function_call"):
            print(colored(f"assistant: {message['content']}\n", role_to_color[message["role"]]))
        elif message["role"] == "tool":
            print(colored(f"function ({message['name']}): {message['content']}\n", role_to_color[message["role"]]))


In [12]:
def get_table_names(client, dataset_id):
    """Return a list of table names in the specified dataset."""
    table_names = []
    dataset_ref = client.dataset(dataset_id)
    # List tables in the dataset
    tables = client.list_tables(dataset_ref)
    for table in tables:
        table_names.append(table.table_id)
    return table_names


def get_column_names(client, dataset_id, table_name):
    """Return a list of column names for the specified table."""
    column_names = []
    table_ref = client.dataset(dataset_id).table(table_name)
    table = client.get_table(table_ref)
    for field in table.schema:
        column_names.append(field.name)
    return column_names

def get_database_info(client, dataset_id):
    """Return a list of dicts containing the table name and columns for each table in the dataset."""
    table_dicts = []
    for table_name in get_table_names(client, dataset_id):
        column_names = get_column_names(client, dataset_id, table_name)
        table_dicts.append({"table_name": table_name, "column_names": column_names})
    return table_dicts


In [15]:
database_schema_dict = get_database_info(client,'youtube')
database_schema_string = "\n".join(
    [
        f"Table: {table['table_name']}\nColumns: {', '.join(table['column_names'])}"
        for table in database_schema_dict
    ]
)

In [16]:
print(database_schema_string)

Table: dim_dail_view_7day
Columns: Date, RollingAverageViews
Table: dim_top10
Columns: Date, TotalViews
Table: raw_cities
Columns: Cities, City_name, Geography, Geography_3, Views, Watch_time__hours_, Average_view_duration
Table: raw_gender
Columns: Date, Views, Watch_time__hours_, Average_view_duration
Table: raw_total
Columns: Date, Views


In [17]:
tools = [
    {
        "type": "function",
        "function": {
            "name": "ask_database",
            "description": "Use this function to answer user questions about music. Input should be a fully formed SQL query.",
            "parameters": {
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": f"""
                                SQL query extracting info to answer the user's question.
                                SQL should be written using this database schema:
                                {database_schema_string}
                                The query should be returned in plain text, not in JSON.
                                """,
                    }
                },
                "required": ["query"],
            },
        }
    }
]

In [23]:
def ask_database(client, query):
    """Function to query BigQuery dataset with a provided SQL query."""
    try:
        query_job = client.query(query)
        results = query_job.result()
        rows = [row.values() for row in results]
    except Exception as e:
        results = f"query failed with error: {e}"
        rows = []
    return rows

def execute_function_call(message):
    if message["tool_calls"][0]["function"]["name"] == "ask_database":
        query = json.loads(message["tool_calls"][0]["function"]["arguments"])["query"]
        results = ask_database(client, query)
    else:
        results = f"Error: function {message['tool_calls'][0]['function']['name']} does not exist"
    return results

In [24]:
messages = []
messages.append({"role": "system", "content": "Answer user questions by generating SQL queries against the youtube analytic Database."})
messages.append({"role": "user", "content": "Hi, show me all cities?"})

# Assuming chat_completion_request returns a response with tool calls
chat_response = chat_completion_request(messages, tools)
print("===================", chat_response.json())

assistant_message = chat_response.json()["choices"][0]["message"]
assistant_message['content'] = str(assistant_message["tool_calls"][0]["function"])
print("===================", assistant_message['content'])

messages.append(assistant_message)

if assistant_message.get("tool_calls"):
    results = execute_function_call(assistant_message)
    messages.append({
        "role": "tool",
        "tool_call_id": assistant_message["tool_calls"][0]['id'],
        "name": assistant_message["tool_calls"][0]["function"]["name"],
        "content": results
    })
pretty_print_conversation(messages)

[31msystem: Answer user questions by generating SQL queries against the youtube analytic Database.
[0m
[32muser: Hi, show me all cities?
[0m
[34massistant: {'name': 'ask_database', 'arguments': '{\n  "query": "SELECT DISTINCT City_name FROM raw_cities"\n}'}
[0m
[35mfunction (ask_database): []
[0m


In [26]:
messages.append({"role": "user", "content": "What is the name of the cities that start ET?"})
chat_response = chat_completion_request(messages, tools)
assistant_message = chat_response.json()["choices"][0]["message"]
assistant_message['content'] = str(assistant_message["tool_calls"][0]["function"])
messages.append(assistant_message)
if assistant_message.get("tool_calls"):
    results = execute_function_call(assistant_message)
    messages.append({"role": "tool", "tool_call_id": assistant_message["tool_calls"][0]['id'], "name": assistant_message["tool_calls"][0]["function"]["name"], "content": results})
pretty_print_conversation(messages)


[31msystem: Answer user questions by generating SQL queries against the youtube analytic Database.
[0m
[32muser: Hi, show me all cities?
[0m
[34massistant: {'name': 'ask_database', 'arguments': '{\n  "query": "SELECT DISTINCT City_name FROM raw_cities"\n}'}
[0m
[35mfunction (ask_database): []
[0m
[32muser: What is the name of the album with the most tracks?
[0m
[34massistant: {'name': 'ask_database', 'arguments': '{\n  "query": "SELECT Album_name, COUNT(Track_name) as Track_count FROM music_album GROUP BY Album_name ORDER BY Track_count DESC LIMIT 1"\n}'}
[0m
[35mfunction (ask_database): []
[0m
[32muser: What is the name of the cities that start ET?
[0m
[34massistant: {'name': 'ask_database', 'arguments': '{\n  "query": "SELECT City_name FROM raw_cities WHERE City_name LIKE \'ET%\'"\n}'}
[0m
[35mfunction (ask_database): []
[0m


In [27]:
messages.append({"role": "user", "content": "What are the avrage view per day?"})
chat_response = chat_completion_request(messages, tools)
assistant_message = chat_response.json()["choices"][0]["message"]
assistant_message['content'] = str(assistant_message["tool_calls"][0]["function"])
messages.append(assistant_message)
if assistant_message.get("tool_calls"):
    results = execute_function_call(assistant_message)
    messages.append({"role": "tool", "tool_call_id": assistant_message["tool_calls"][0]['id'], "name": assistant_message["tool_calls"][0]["function"]["name"], "content": results})
pretty_print_conversation(messages)


[31msystem: Answer user questions by generating SQL queries against the youtube analytic Database.
[0m
[32muser: Hi, show me all cities?
[0m
[34massistant: {'name': 'ask_database', 'arguments': '{\n  "query": "SELECT DISTINCT City_name FROM raw_cities"\n}'}
[0m
[35mfunction (ask_database): []
[0m
[32muser: What is the name of the album with the most tracks?
[0m
[34massistant: {'name': 'ask_database', 'arguments': '{\n  "query": "SELECT Album_name, COUNT(Track_name) as Track_count FROM music_album GROUP BY Album_name ORDER BY Track_count DESC LIMIT 1"\n}'}
[0m
[35mfunction (ask_database): []
[0m
[32muser: What is the name of the cities that start ET?
[0m
[34massistant: {'name': 'ask_database', 'arguments': '{\n  "query": "SELECT City_name FROM raw_cities WHERE City_name LIKE \'ET%\'"\n}'}
[0m
[35mfunction (ask_database): []
[0m
[32muser: What are the avrage view per day?
[0m
[34massistant: {'name': 'ask_database', 'arguments': '{\n  "query": "SELECT AVG(RollingAve