In [None]:
from dotenv import load_dotenv
import websockets
import json
import asyncio
import sqlite3
import requests

load_dotenv()

import os
# Secrets come from the environment (load_dotenv() above also reads a local .env file).
# The doubled "o" in oopenai_api_key is a typo kept because get_gpt3_response reads that name.
oopenai_api_key = os.environ.get("OPENAI_API_KEY")
# Bearer token the bot presents to the chat server. It used to be hard-coded here;
# that value remains in version history and should be rotated.
BOT_TOKEN = os.environ.get("BOT_TOKEN")
if not BOT_TOKEN:
    import warnings
    warnings.warn("BOT_TOKEN is not set; the chat server will likely reject the bot's websocket connection.")
THREAD_ID = "default"  # chat thread the bot joins
DB_PATH = '../datasets/nba_sql.db'  # SQLite database the bot queries (relative to the notebook's working directory)
uri = f"ws://localhost:8000/ws/chat?thread_id={THREAD_ID}"

In [None]:
def get_gpt3_response(prompt, stop=None, model="text-davinci-002", max_tokens=500, temperature=0, timeout=60):
    """Return the completion text for ``prompt`` from OpenAI's legacy Completions endpoint.

    Args:
        prompt: full prompt text to complete.
        stop: optional stop sequence (string or list); omitted from the request when falsy.
        model: completion model name. NOTE(review): ``text-davinci-002`` has been
            retired by OpenAI; pass a current completions model (e.g.
            ``gpt-3.5-turbo-instruct``) if requests fail with model_not_found.
        max_tokens: upper bound on generated tokens.
        temperature: sampling temperature (0 keeps outputs as deterministic as the API allows).
        timeout: seconds to wait for the HTTP response before giving up.

    Returns:
        str: the text of the first choice.

    Raises:
        requests.HTTPError: on a non-2xx response (bad key, rate limit, unknown model, ...),
            instead of the opaque ``KeyError: 'choices'`` the error payload used to cause.
        requests.Timeout: if the API does not respond within ``timeout`` seconds.
    """
    headers = {"Authorization": f"Bearer {oopenai_api_key}", "Content-Type": "application/json"}
    data = {"prompt": prompt, "max_tokens": max_tokens, "temperature": temperature, "model": model}
    if stop:
        data["stop"] = stop
    # Without a timeout, requests can wait forever on a stalled connection.
    response = requests.post("https://api.openai.com/v1/completions", headers=headers, json=data, timeout=timeout)
    response.raise_for_status()
    return response.json()["choices"][0]["text"]

In [None]:
async def chatbot_messages():
    """Connect to the chat thread and collect the replayed message history.

    Frames flagged ``"replay"`` are collected; other frames are ignored, and frames
    that are not JSON objects are logged and skipped. Collection stops once no frame
    arrives for 0.5 s (taken to mean the replay is finished) or the connection closes.
    Cancellation and KeyboardInterrupt propagate; the old bare ``except:`` swallowed them.

    Returns:
        list[dict]: the decoded replay frames, in arrival order.
    """
    uri = f"ws://localhost:8000/ws/chat?thread_id={THREAD_ID}"
    messages = []
    async with websockets.connect(uri, extra_headers={"Cookie": f"Authorization=Bearer {BOT_TOKEN}"}) as ws:
        while True:
            try:
                raw_data = await asyncio.wait_for(ws.recv(), timeout=0.5)
            except asyncio.TimeoutError:
                break  # quiet for 0.5 s: treat the replay as complete
            except websockets.exceptions.ConnectionClosed:
                break
            try:
                data = json.loads(raw_data)
            except json.JSONDecodeError:
                data = None
            if not isinstance(data, dict):
                print("log: skipping malformed frame:", raw_data)
                continue
            if data.get("replay", False):
                messages.append(data)
    return messages

In [None]:
from jinja2 import Template

verbose = True  # module-wide switch: GPT3Prompt.execute echoes every prompt/response pair when True


# Jinja2 template syntax reference: https://zetcode.com/python/jinja/
class GPT3Prompt:
    """A Jinja2 prompt template paired with an optional stop sequence.

    Instances are callable: ``prompt(**vars)`` renders the template with ``vars``,
    sends the text to ``get_gpt3_response`` and returns the completion.

    Args:
        prompt_template_string: Jinja2 template source for the prompt.
        stop: optional stop sequence forwarded to ``get_gpt3_response``.
    """

    def __init__(self, prompt_template_string, stop=None):
        self.stop = stop
        self.prompt_template = Template(prompt_template_string)

    def get_prompt(self, **kwargs):
        """Return the template rendered with ``kwargs`` as its variables (no API call)."""
        rendered = self.prompt_template.render(**kwargs)
        return rendered

    def execute(self, **kwargs):
        """Render the prompt, request its completion, and return the completion text.

        When the module-level ``verbose`` flag is set, the prompt and the response
        are printed between separator rules.
        """
        prompt_text = self.get_prompt(**kwargs)
        if verbose:
            print("=" * 80)
            print(f"prompt:\n{prompt_text}")
        completion = get_gpt3_response(prompt_text, self.stop)
        if verbose:
            for banner in ("-" * 80, f"response:\n{completion}", "=" * 80):
                print(banner)
        return completion

    def __call__(self, **kwargs):
        """Shorthand for ``execute``."""
        return self.execute(**kwargs)

In [None]:
# Yes/no classifier: given the conversation (`messages`, dicts with "sender" and
# "message" keys), should datAI answer with SQL or just reply conversationally?
# NOTE(review): not referenced anywhere else in this notebook.
isSQLyn = GPT3Prompt(
"""
The following is a conversation with a data expert (datAI). 
In order to respond to the prompt, should the datAI use SQL or just say something friendly in response?
===
{% for message in messages %}
{{ message["sender"] }}: {{ message["message"] }}
{% endfor %}
===
The datAI should use SQL to answer the question (y/n):"""
)

In [None]:
# Asks the model for the SQL query to run. Template variables:
#   messages          - recent conversation (dicts with "sender" and "message")
#   database_context  - schema text built from sqlite_master / PRAGMA table_info
#   previousAttempt   - on retries, the previous query plus its result/error text; omitted when falsy
# The completion is executed against the database by example_promptchain.
# NOTE(review): there is no stop sequence, so any prose the model adds after the
# query is likely sent to sqlite too; previousAttempt is inserted without a label
# explaining that it failed.
sqlExecutor = GPT3Prompt(
"""
The following is a conversation with a data expert (datAI). In order to be accurate, datAI would like to use SQL. 
===
{% for message in messages %}
{{ message["sender"] }}: {{ message["message"] }}
{% endfor %}
===
|| Database Context ||
{{ database_context }}
||||||||||||||||||||||
{% if previousAttempt %}
{{ previousAttempt }}
{% endif %}
The SQL query that datAI would like to execute is:
"""
)

In [None]:
# Judge step: given the conversation, the executed `sql` and its `result` text,
# asks the model (y/n) whether the result answers the user's question.
# example_promptchain reads the first character of the completion as the verdict.
# NOTE(review): the closing question is ungrammatical ("correct and answer the users
# question"); rewording it would change model outputs, so it is left as-is here.
gotAValidSQLResponse = GPT3Prompt(
"""
===
{% for message in messages %}
{{ message["sender"] }}: {{ message["message"] }}
{% endfor %}
===
In order to answer a question (above), a data expert (datAI) executed the following SQL query:
{{ sql }}
And the result of the execution was:
{{ result }}
===
Is the result of the SQL query correct and answer the users question? (y/n):"""
)

In [None]:
# Final answer step: writes datAI's reply to the last message from the conversation
# plus the accepted SQL query (`sql`) and its result text (`result`).
composeDataDrivenAnswer = GPT3Prompt(
"""
The following is a conversation with a data expert (datAI). In order to be accurate, the data expert used a database (queried with SQL).
===
{% for message in messages %}
{{ message["sender"] }}: {{ message["message"] }}
{% endfor %}
===
The query and response:
{{ sql }}
{{ result }}
===
What should datAI say in response to the last message?
"""
)

In [None]:
# Clarifying-question reply for when SQL attempts fail to answer the question.
# stop="\n" limits the completion to a single line.
# NOTE(review): not referenced anywhere else in this notebook.
neededHelpButStillFriendly = GPT3Prompt(
"""
The following is a conversation with a friendly chatbot (datAI, a data expert), who loves basketball.
The datAI tried to run some SQL, but failed to get a good response to the question, 
and now wants to ask a clarifying question so that it can query the database better to help the user...
===
{% for message in messages %}
{{ message["sender"] }}: {{ message["message"] }}
{% endfor %}
datAI:""", stop="\n"
)

In [None]:
# First-pass reply from the recent conversation. example_promptchain screens this
# draft with checkIfNeedsSQL and returns it directly when no SQL is deemed needed.
# NOTE(review): no stop sequence is set (unlike neededHelpButStillFriendly), so the
# completion may run on past datAI's turn.
friendlyChatbot = GPT3Prompt(
"""
The following is a conversation with a friendly chatbot (datAI, a data expert), who loves basketball.
The response will never contain a data response unless it has proof in the form of executed SQL. 
datAI does not respond with guesses, so will ask questions to clarify the users intent to help it formulate a SQL query.
===
{% for message in messages %}
{{ message["sender"] }}: {{ message["message"] }}
{% endfor %}
datAI:"""
)

In [None]:
# Screens a drafted reply (`statement`): y/n on whether it makes factual claims that
# would live in the database and so should be backed by a SQL query.
checkIfNeedsSQL = GPT3Prompt(
"""
The following is a conversation with a friendly chatbot (datAI, a data expert), who loves basketball.
Does the following statement contain any attempt at factual information that would exist in a database?
Statement: {{ statement }}
===
y/n:"""
)

In [None]:
# Describe the database schema as one line per table, "table||[col:type,col:type,]",
# for use as the sqlExecutor prompt's database_context.
conn = sqlite3.connect(DB_PATH)
try:
    c = conn.cursor()
    c.execute("SELECT name FROM sqlite_master WHERE type='table';")
    outputschema = ""
    for table, in c.fetchall():
        # Quote the identifier (doubling embedded quotes) so table names containing
        # spaces, keywords or quotes don't break the PRAGMA statement.
        quoted_table = '"' + table.replace('"', '""') + '"'
        c.execute(f"PRAGMA table_info({quoted_table})")
        columns = "".join(f"{column}:{dtype}," for _, column, dtype, *_ in c.fetchall())
        outputschema += f"{table}||[{columns}]\n"
finally:
    # Only needed to build the schema text; example_promptchain opens its own connections.
    conn.close()
database_context = outputschema

def _answered_no(completion):
    """Return True when a y/n classifier completion starts with "n" or "N" (after stripping).

    Anything else, including an empty completion, counts as yes. This matches the
    chain's original fail-open reading; only the case-sensitivity is fixed, since
    "No" used to be read as yes.
    """
    return completion.strip()[:1].lower() == "n"


def _run_sql(sql, db_path, max_chars=500):
    """Run one model-generated SQL statement read-only and return its rows, or the error, as text.

    Errors are returned rather than raised so the caller can show them to the judge
    prompt and feed them into the next attempt. The connection is always closed.
    """
    conn = None
    try:
        conn = sqlite3.connect(db_path)
        # The SQL is model output, i.e. untrusted. query_only makes SQLite reject any
        # write. Without it, DDL such as DROP TABLE would persist, because Python's
        # default sqlite3 transaction handling does not wrap DDL in the implicit
        # (never-committed) transaction.
        conn.execute("PRAGMA query_only = ON")
        return str(conn.execute(sql).fetchall())[:max_chars]
    except Exception as e:
        return f"Error executing SQL {e}"
    finally:
        if conn is not None:
            conn.close()


def example_promptchain(messages, history_window=5, max_trials=3):
    """Answer the latest chat message, grounding factual claims in SQL when needed.

    Steps:
      1. friendlyChatbot drafts a reply from the last ``history_window`` messages.
      2. checkIfNeedsSQL asks whether that draft makes database-checkable claims;
         if not, the draft is returned as-is.
      3. Otherwise, up to ``max_trials`` times: sqlExecutor writes a query, which is
         run read-only against DB_PATH, and gotAValidSQLResponse judges the result.
         The first accepted result is turned into the reply by composeDataDrivenAnswer.
         A rejected query and the start of its result are fed back as ``previousAttempt``.
      4. If every attempt is rejected, the step-1 draft is returned.

    Args:
        messages: chat history, a list of dicts with "sender" and "message" keys
            (the fields the prompt templates render).
        history_window: number of trailing messages the prompts see (positive int).
        max_trials: maximum number of SQL attempts.

    Returns:
        str: the reply text for datAI to send.
    """
    recent = messages[-history_window:]
    answer = friendlyChatbot(messages=recent)
    if _answered_no(checkIfNeedsSQL(statement=answer)):
        return answer
    previousAttempt = None
    for attempt in range(1, max_trials + 1):
        sqlresponse = sqlExecutor(messages=recent, previousAttempt=previousAttempt, database_context=database_context)
        res = _run_sql(sqlresponse, DB_PATH)
        if not _answered_no(gotAValidSQLResponse(messages=recent, sql=sqlresponse, result=res)):
            return composeDataDrivenAnswer(messages=recent, sql=sqlresponse, result=res)
        if attempt < max_trials:
            print("I tried and failed... trying again")
        # Keep the retry context short: the query plus the start of its result/error text.
        previousAttempt = f"{sqlresponse}\n{res[:100]}"
    # NOTE(review): this returns the unverified draft even though checkIfNeedsSQL flagged
    # it as containing factual claims. neededHelpButStillFriendly (defined in this
    # notebook but unused) looks designed for this case.
    return answer


In [None]:

async def chatbot():
    """Serve datAI on the chat websocket until a "break" message arrives.

    Frame handling, in order:
      * replayed history (``"replay"``) and the bot's own messages (sender "datAI")
        are stored as context without triggering a reply;
      * a message whose text is "break" ends the session;
      * ``"meta"`` frames are ignored;
      * anything else is stored and answered with ``example_promptchain``.
    Frames that are not JSON objects are logged and skipped.
    """
    messages = []
    loop = asyncio.get_running_loop()
    async with websockets.connect(uri, extra_headers={"Cookie": f"Authorization=Bearer {BOT_TOKEN}"}) as ws:
        # Iterating the connection awaits each frame and stops when the server closes it
        # normally, replacing the old 0.5 s wait_for/continue polling loop.
        async for raw_data in ws:
            try:
                data = json.loads(raw_data)
            except json.JSONDecodeError:
                data = None
            if not isinstance(data, dict):
                print("log: skipping malformed frame:", raw_data)
                continue
            print("log:", data.get("message"))
            if data.get("replay", False) or data.get("sender") == "datAI":
                messages.append(data)
                continue
            if data.get("message") == "break":
                break
            if data.get("meta", False):
                continue
            messages.append(data)
            # example_promptchain makes blocking HTTP and sqlite calls. Running it in a
            # worker thread keeps the event loop free to answer websocket keepalive
            # pings, which could otherwise time out during a slow multi-step chain.
            response = await loop.run_in_executor(None, example_promptchain, messages)
            await ws.send(json.dumps({"message": response}))


In [None]:
# Serve the bot until a "break" message arrives (top-level await works in IPython/Jupyter only).
await chatbot()

In [None]:
# Offline alternative to chatbot(): fetch the replayed history once and run the
# prompt chain on it, without serving live messages.
# messages = await chatbot_messages()
# response = example_promptchain(messages)