In [1]:
# Importing Necessary Modules 
import openai
import os
import json
from pprint import pprint
import requests
from tenacity import retry, wait_random_exponential, stop_after_attempt
from termcolor import colored
import re
import sys 
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy import inspect
from dotenv import load_dotenv, find_dotenv
from openai import OpenAI 

In [2]:
# Importing Chat GPT Model
GPT_MODEL = "gpt-3.5-turbo-0613"

In [3]:
load_dotenv()
openai.api_key  = os.getenv('OPENAI_API_KEY')

In [4]:
username = os.getenv("POSTGRES_USERNAME")
password = os.getenv("POSTGRES_PASSWORD")
host = os.getenv("POSTGRES_HOST")
port = os.getenv("POSTGRES_PORT")
database = os.getenv("POSTGRES_DATABASE")
database_url = f"postgresql://{username}:{password}@{host}:{port}/{database}"
engine = create_engine(database_url)
Session = sessionmaker(bind=engine)


In [5]:
conn = engine.connect()
print("Opened database successfully")

Opened database successfully


In [6]:
@retry(wait=wait_random_exponential(multiplier=1, max=40), stop=stop_after_attempt(3))
def chat_completion_request(messages, tools=None, tool_choice=None, model=GPT_MODEL):
    headers = {
        "Content-Type": "application/json",
        "Authorization": "Bearer " + openai.api_key,
    }
    json_data = {"model": model, "messages": messages}
    if tools is not None:
        json_data.update({"tools": tools})
    if tool_choice is not None:
        json_data.update({"tool_choice": tool_choice})
    try:
        response = requests.post(
            "https://api.openai.com/v1/chat/completions",
            headers=headers,
            json=json_data,
        )

        return response
    except Exception as e:
        print("Unable to generate ChatCompletion response")
        print(f"Exception: {e}")
        return e

In [7]:
def pretty_print_conversation(messages):
    role_to_color = {
        "system": "red",
        "user": "green",
        "assistant": "blue",
        "tool": "magenta",
    }

    for message in messages:
        if message["role"] == "system":
            print(colored(f"system: {message['content']}\n", role_to_color[message["role"]]))
        elif message["role"] == "user":
            print(colored(f"user: {message['content']}\n", role_to_color[message["role"]]))
        elif message["role"] == "assistant" and message.get("function_call"):
            print(colored(f"assistant: {message['function_call']}\n", role_to_color[message["role"]]))
        elif message["role"] == "assistant" and not message.get("function_call"):
            print(colored(f"assistant: {message['content']}\n", role_to_color[message["role"]]))
        elif message["role"] == "tool":
            print(colored(f"function ({message['name']}): {message['content']}\n", role_to_color[message["role"]]))

In [8]:
def get_table_names(engine):
    """Return a list of table names."""
    table_names = []
    inspector = inspect(engine)
    for table_name in inspector.get_table_names():
        table_names.append(f'"{table_name}"')  # Add double quotes around table name
    return table_names

def get_column_names(engine, table_name):
    """Return a list of column names."""
    column_names = []
    inspector = inspect(engine)
    for column in inspector.get_columns(table_name):
        column_names.append(f'"{column["name"]}"')  # Add double quotes around column name
    return column_names

def get_database_info(engine):
    """Return a list of dicts containing the table name and columns for each table in the database."""
    table_dicts = []
    inspector = inspect(engine)
    for table_name in inspector.get_table_names():
        columns_names = get_column_names(engine, table_name)
        table_dicts.append({"table_name": f'"{table_name}"', "column_names": columns_names})  # Add double quotes around table name
    return table_dicts

In [9]:
database_schema_dict = get_database_info(conn)
database_schema_string = "\n".join(
    [
        f'Table: "{table["table_name"]}"\nColumns: {", ".join([f"{col}" for col in table["column_names"]])}'
        for table in database_schema_dict
    ]
)

In [10]:
tools = [
    {
        "type": "function",
        "function": {
            "name": "ask_database",
            "description": "Use this function to answer user questions about youtube. Input should be a fully formed SQL query.",
            "parameters": {
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": f"""
                                SQL query extracting info to answer the user's question.
                                SQL should be written using this database schema:
                                {database_schema_string}
                                The query should be returned in plain text, not in JSON.
                                """,
                    }
                },
                "required": ["query"],
            },
        }
    }
]

In [13]:
def ask_database(engine, query):
    """Function to query PostgreSQL database with a provided SQL query."""
    try:
        print("Executing Query:", query)  # Add this line
        with engine.connect() as conn:
            result = conn.execute(query)
            results = result.fetchall()
    except Exception as e:
        results = f"Query failed with error: {e}"
    return results

# def ask_database(engine, query):
#     """Function to query PostgreSQL database with a provided SQL query."""
#     try:
#         with engine.connect() as conn:
#             result = conn.execute(query)
#             results = result.fetchall()
#     except Exception as e:
#         print("Query execution failed with error:", e)  # Add this line
#         results = f"Query failed with error: {e}"
#     return results
def execute_function_call(message, engine):
    if message["tool_calls"][0]["function"]["name"] == "ask_database":
        query = json.loads(message["tool_calls"][0]["function"]["arguments"])["query"]
        results = ask_database(engine, query)
    else:
        results = f"Error: function {message['tool_calls'][0]['function']['name']} does not exist"
    return results

In [14]:
messages = []
messages.append({"role": "system", "content": "Answer user queries by generating SQL queries from the youtube data."})
messages.append({"role": "user", "content": "Hi, who are the top 5 cities by number of viewers?"})
chat_response = chat_completion_request(messages, tools)
print("===================",chat_response.json())
assistant_message = chat_response.json()["choices"][0]["message"]
assistant_message['content'] = str(assistant_message["tool_calls"][0]["function"])
print("===================",assistant_message['content'])
messages.append(assistant_message)
if assistant_message.get("tool_calls"):
    results = execute_function_call(assistant_message, engine)
    messages.append({"role": "tool", "tool_call_id": assistant_message["tool_calls"][0]['id'], "name": assistant_message["tool_calls"][0]["function"]["name"], "content": results})
pretty_print_conversation(messages)

Executing Query: SELECT "City name", SUM("Views") AS TotalViews FROM "cities_table_data" GROUP BY "City name" ORDER BY TotalViews DESC LIMIT 5;
[31msystem: Answer user queries by generating SQL queries from the youtube data.
[0m
[32muser: Hi, who are the top 5 cities by number of viewers?
[0m
[34massistant: {'name': 'ask_database', 'arguments': '{\n  "query": "SELECT \\"City name\\", SUM(\\"Views\\") AS TotalViews FROM \\"cities_table_data\\" GROUP BY \\"City name\\" ORDER BY TotalViews DESC LIMIT 5;"\n}'}
[0m
[35mfunction (ask_database): []
[0m


In [15]:
messages.append({"role": "user", "content": "What is the name of the city with the most views?"})
chat_response = chat_completion_request(messages, tools)
assistant_message = chat_response.json()["choices"][0]["message"]
assistant_message['content'] = str(assistant_message["tool_calls"][0]["function"])
messages.append(assistant_message)
if assistant_message.get("tool_calls"):
    results = execute_function_call(assistant_message, engine)
    messages.append({"role": "tool", "tool_call_id": assistant_message["tool_calls"][0]['id'], "name": assistant_message["tool_calls"][0]["function"]["name"], "content": results})
pretty_print_conversation(messages)

Executing Query: SELECT "City name" FROM ("cities_table_data") WHERE "Views" = (SELECT MAX("Views") FROM ("cities_table_data")) LIMIT 1;
[31msystem: Answer user queries by generating SQL queries from the youtube data.
[0m
[32muser: Hi, who are the top 5 cities by number of viewers?
[0m
[34massistant: {'name': 'ask_database', 'arguments': '{\n  "query": "SELECT \\"City name\\", SUM(\\"Views\\") AS TotalViews FROM \\"cities_table_data\\" GROUP BY \\"City name\\" ORDER BY TotalViews DESC LIMIT 5;"\n}'}
[0m
[35mfunction (ask_database): []
[0m
[32muser: What is the name of the city with the most views?
[0m
[34massistant: {'name': 'ask_database', 'arguments': '{\n  "query": "SELECT \\"City name\\" FROM (\\"cities_table_data\\") WHERE \\"Views\\" = (SELECT MAX(\\"Views\\") FROM (\\"cities_table_data\\")) LIMIT 1;"\n}'}
[0m
[35mfunction (ask_database): Query failed with error: (psycopg2.errors.SyntaxError) syntax error at or near ")"
LINE 1: SELECT "City name" FROM ("cities_table_

In [16]:
messages.append({"role": "user", "content": "What is the name of the city with the most views?"})
chat_response = chat_completion_request(messages, tools)
assistant_message = chat_response.json()["choices"][0]["message"]
assistant_message['content'] = str(assistant_message["tool_calls"][0]["function"])
messages.append(assistant_message)
if assistant_message.get("tool_calls"):
    results = execute_function_call(assistant_message, engine)
    messages.append({"role": "tool", "tool_call_id": assistant_message["tool_calls"][0]['id'], "name": assistant_message["tool_calls"][0]["function"]["name"], "content": results})
pretty_print_conversation(messages)

Query execution failed with error: Not an executable object: 'SELECT city, COUNT(*) AS view_count FROM viewers GROUP BY city ORDER BY view_count DESC LIMIT 1'
[31msystem: Answer user questions by generating SQL queries against the youtube data Database.
[0m
[32muser: Hi, who are the top 5 cities by number of viewers?
[0m
[34massistant: {'name': 'ask_database', 'arguments': '{\n  "query": "SELECT city, COUNT(DISTINCT user_id) AS viewer_count FROM viewers GROUP BY city ORDER BY viewer_count DESC LIMIT 5"\n}'}
[0m
[35mfunction (ask_database): Query failed with error: Not an executable object: 'SELECT city, COUNT(DISTINCT user_id) AS viewer_count FROM viewers GROUP BY city ORDER BY viewer_count DESC LIMIT 5'
[0m
[32muser: What is the name of the city with the most views?
[0m
[34massistant: {'name': 'ask_database', 'arguments': '{\n  "query": "SELECT city, COUNT(*) AS view_count FROM viewers GROUP BY city ORDER BY view_count DESC LIMIT 1"\n}'}
[0m
[35mfunction (ask_database): Q