In [None]:
from dotenv import load_dotenv
import websockets
import json
import asyncio
import sqlite3
import requests

load_dotenv()

import os
# OpenAI key comes from the environment (load_dotenv() above may populate it from .env).
# The variable keeps its original spelling because get_gpt3_response reads it by this name.
oopenai_api_key = os.environ.get("OPENAI_API_KEY")
# Prefer a BOT_TOKEN provided through the environment / .env file. The literal is kept
# only as a backward-compatible fallback; it should be rotated and removed from the notebook.
BOT_TOKEN = os.environ.get("BOT_TOKEN", "5a3556bb2de44a73ab2e5643cb633a6c")
# Chat thread the bot joins; interpolated into the websocket URLs.
THREAD_ID = "default"
# SQLite database the generated SQL is executed against.
DB_PATH = '../datasets/nba_sql.db'
# Hosted chat websocket endpoint used by chatbot().
uri = f"wss://www.approx.dev/ws/chat?thread_id={THREAD_ID}"

In [None]:
def get_gpt3_response(prompt, stop=None):
    """Send ``prompt`` to the OpenAI completions endpoint and return the completion text.

    Args:
        prompt: Full prompt text to complete.
        stop: Optional stop sequence; it is only sent when truthy.

    Returns:
        The ``text`` field of the first choice in the JSON response.

    Raises:
        requests.HTTPError: If the API answers with a 4xx/5xx status.
        requests.Timeout: If the API does not answer within the timeout.
    """
    headers = {"Authorization": f"Bearer {oopenai_api_key}", "Content-Type": "application/json"}
    data = {"prompt": prompt, "max_tokens": 500, "temperature": 0, "model": "text-davinci-002"}
    if stop:
        data["stop"] = stop
    # requests has no default timeout; without one a stalled connection blocks forever.
    response = requests.post(
        "https://api.openai.com/v1/completions", headers=headers, json=data, timeout=60
    )
    # Fail with the HTTP status instead of a KeyError on "choices" for error payloads.
    response.raise_for_status()
    return response.json()["choices"][0]["text"]

In [None]:
async def chatbot_messages():
    """Collect the replayed messages of ``THREAD_ID`` from the local chat server.

    Connects to the localhost websocket endpoint (a local ``uri``, not the
    module-level one), reads frames until none arrives within 0.5 seconds or
    a frame cannot be read/decoded, and keeps the frames whose ``replay``
    field is truthy.

    Returns:
        A list of the decoded replay frames, in arrival order.
    """
    uri = f"ws://localhost:8000/ws/chat?thread_id={THREAD_ID}"
    messages = []
    async with websockets.connect(uri, extra_headers={"Cookie": f"Authorization=Bearer {BOT_TOKEN}"}) as ws:
        while True:
            try:
                raw_data = await asyncio.wait_for(ws.recv(), timeout=0.5)
                data = json.loads(raw_data)
                if data.get("replay", False):
                    messages.append(data)
            except Exception:
                # Best effort: a timeout, a closed socket or an undecodable frame ends
                # the collection. Catching Exception (not a bare except) lets
                # KeyboardInterrupt and task cancellation propagate.
                break
    return messages

In [None]:
from databases import Database
import uuid

# Name of the table that holds the current schema version (a single `version` column).
MIGRATION_VERSION_TABLE = "mochaver"


async def table_exists(db: Database, table_name: str):
    """Return True when ``sqlite_master`` lists a table called ``table_name``, else False."""
    row = await db.fetch_one(
        "SELECT name FROM sqlite_master WHERE type='table' AND name=:table_name;",
        values={"table_name": table_name},
    )
    if row is None:
        return False
    return True


async def get_migration_version(db: Database):
    """Return the stored schema version, or None when the version table does not exist."""
    has_version_table = await table_exists(db, MIGRATION_VERSION_TABLE)
    if not has_version_table:
        return None
    row = await db.fetch_one(f"SELECT version FROM {MIGRATION_VERSION_TABLE};")
    # The row has exactly one column; unpack it to the bare value.
    (stored_version,) = row
    return stored_version


async def set_version(db: Database, version: int):
    """Overwrite the stored schema version with ``version`` (updates every row of the version table)."""
    bound = {"version": version}
    await db.execute(f"UPDATE {MIGRATION_VERSION_TABLE} SET version = :version;", values=bound)


# Registry of migration runners keyed by version number; filled in by the @migration decorator.
MIGRATIONS = {}


async def setup_database(db: Database):
    """Run every registered migration, in ascending version order, inside one transaction.

    The stored version is read once up front and handed to each runner, which
    decides for itself whether it still has to be applied.
    """
    async with db.transaction():
        current_version = await get_migration_version(db)
        for version in sorted(MIGRATIONS):
            await MIGRATIONS[version](db, current_version)


def migration(version: int):
    """Decorator factory that registers an async migration under ``version``.

    The decorated coroutine function (taking ``db``) is replaced by a runner
    taking ``(db, db_version)``. The runner applies the migration and stores
    ``version`` only when ``db_version`` is None or lower than ``version``.
    The runner is also stored in ``MIGRATIONS[version]``.
    """
    def decorator(func):
        async def run_migration(db: Database, db_version: int):
            already_applied = db_version is not None and db_version >= version
            if already_applied:
                return
            print("Running migration", version)
            await func(db)
            await set_version(db, version)

        MIGRATIONS[version] = run_migration
        return run_migration

    return decorator

@migration(0)
async def migration_0(db: Database):
    """Version 0: create the version-tracking table and seed it with version 0."""
    statements = (
        f"""
        CREATE TABLE {MIGRATION_VERSION_TABLE} (
            version INTEGER NOT NULL PRIMARY KEY
        ) WITHOUT ROWID;
    """,
        f"INSERT INTO {MIGRATION_VERSION_TABLE} (version) VALUES (0);",
    )
    for statement in statements:
        await db.execute(statement)



@migration(1)
async def migration_1(db: Database):
    """Version 1: create the ``promptHistory`` table that record_prompt writes to."""
    create_prompt_history = """
        CREATE TABLE promptHistory (
            id TEXT NOT NULL PRIMARY KEY,
            prompt_id TEXT NOT NULL,
            prompt_name TEXT NOT NULL,
            inputs TEXT NOT NULL,
            response TEXT NOT NULL,
            duration REAL NOT NULL,
            timestamp TEXT NOT NULL
        ) WITHOUT ROWID;
        """
    await db.execute(create_prompt_history)


async def record_prompt(db: Database, prompt_id, prompt_name, inputs, response, duration):
    """Insert one prompt invocation into ``promptHistory``.

    Args:
        db: Database handle exposing an awaitable ``execute(query, values=...)``.
        prompt_id: Identifier of the prompt version.
        prompt_name: Human-readable prompt name.
        inputs: Call arguments; stored as JSON text.
        response: Value returned by the prompt.
        duration: Elapsed time of the call.

    The row id is a fresh UUID4 and the timestamp is SQLite's ``datetime()``.
    """
    query = """
        INSERT INTO promptHistory (id, prompt_id, prompt_name, inputs, response, duration, timestamp)
        VALUES (:id, :prompt_id, :prompt_name, :inputs, :response, :duration, datetime());
    """
    await db.execute(
        query,
        values={
            "id": str(uuid.uuid4()),
            "prompt_id": prompt_id,
            "prompt_name": prompt_name,
            # default=str: arguments that are not JSON-serialisable are stored as their
            # string form instead of raising and dropping the whole record.
            "inputs": json.dumps(inputs, default=str),
            "response": response,
            "duration": duration,
        },
    )

async def get_prompts(db: Database, prompt_name: str, n=5):
    """Return the ``n`` most recent recorded calls of the prompt named ``prompt_name``.

    Args:
        db: Database handle exposing an awaitable ``fetch_all(query, values=...)``.
        prompt_name: Value matched against the ``prompt_name`` column.
        n: Maximum number of rows, newest first by ``timestamp``.

    Returns:
        A list of ``(inputs, response)`` tuples with ``inputs`` decoded from JSON.
    """
    # LIMIT is bound as a parameter rather than formatted into the SQL text.
    query = """
        SELECT inputs, response FROM promptHistory
        WHERE prompt_name = :prompt_name
        ORDER BY timestamp DESC
        LIMIT :n;
    """
    rows = await db.fetch_all(query, values={"prompt_name": prompt_name, "n": int(n)})
    return [(json.loads(inputs), response) for inputs, response in rows]

In [None]:
from re import S
from jinja2 import Environment, meta
import time
import inspect
from hashlib import md5

# Prompt-history store: a SQLite file opened through the `databases` package (aiosqlite driver URL).
database = Database("sqlite+aiosqlite:///promptHistory.db")
# Apply any pending migrations (top-level await: only valid inside IPython/Jupyter cells).
await setup_database(database)
# Shared Jinja2 environment used by GPT3Prompt to parse and render prompt templates.
env = Environment()


class Prompt:
    """A named callable whose every invocation is recorded in the prompt history.

    Prompts are functions that take in inputs and output a response. Calling
    an instance runs ``execute`` and schedules ``record_prompt`` as a
    background task, so it must be called while an asyncio event loop is
    running (as in a Jupyter kernel).
    """

    # Strong references to in-flight recording tasks. The event loop only keeps weak
    # references, so an unreferenced task may be garbage-collected before it finishes.
    _background_tasks = set()

    def __init__(self, name, function=None):
        """Store the prompt ``name`` and the optional ``function`` that implements it."""
        self.name = name
        self.function = function

    def execute(self, *args, **kwargs):
        """Call the wrapped function with the given arguments and return its result.

        Raises:
            NotImplementedError: If no function was supplied and ``execute`` was not overridden.
        """
        if self.function is None:
            raise NotImplementedError("Must implement function")
        return self.function(*args, **kwargs)

    def __call__(self, *args, **kwargs):
        """Run ``execute``, schedule the history record, and return the response."""
        # perf_counter is monotonic, so the measured duration cannot be skewed by clock changes.
        started = time.perf_counter()
        response = self.execute(*args, **kwargs)
        elapsed = time.perf_counter() - started
        task = asyncio.create_task(
            record_prompt(database, self.id, self.name, {'args': args, 'kwargs': kwargs}, response, elapsed)
        )
        # Keep the task alive until it completes, then drop the reference.
        Prompt._background_tasks.add(task)
        task.add_done_callback(Prompt._background_tasks.discard)
        return response

    @property
    def id(self):
        """MD5 hex digest of the wrapped function's source code."""
        return md5(inspect.getsource(self.function).encode('utf-8')).hexdigest()

# https://zetcode.com/python/jinja/
class GPT3Prompt(Prompt):
    """A Prompt whose text is a Jinja2 template completed through get_gpt3_response."""

    def __init__(self, name, prompt_template_string, stop=None):
        """Keep the raw template text, its compiled form, and the optional stop sequence."""
        super().__init__(name)
        self.stop = stop
        self.prompt_template_string = prompt_template_string
        self.prompt_template = env.from_string(prompt_template_string)

    def get_named_args(self):
        """Return the set of variable names the template expects to be supplied."""
        parsed = env.parse(self.prompt_template_string)
        return meta.find_undeclared_variables(parsed)

    def get_prompt(self, **kwargs):
        """Render the template with the given variables and return the prompt text."""
        return self.prompt_template.render(**kwargs)

    def execute(self, **kwargs):
        """Render the prompt and return the completion produced for it."""
        rendered = self.get_prompt(**kwargs)
        return get_gpt3_response(rendered, self.stop)

    @property
    def id(self):
        """MD5 hex digest of the template text (not of any source code)."""
        template_bytes = self.prompt_template_string.encode('utf-8')
        return md5(template_bytes).hexdigest()
    

In [None]:
# y/n prompt: given `userIntent` and `conversation`, asks whether datAI should answer with SQL.
# The completion is expected to start with "y" or "n" (see how example_promptchain consumes it).
isSQLyn = GPT3Prompt("isSQLyn",
"""
The following is a conversation with a data expert (datAI). 
In order to respond to the prompt, should the datAI use SQL or just say something friendly in response?
The users intent: {{ userIntent }}
===
{{ conversation }}
===
The datAI should use SQL to answer the question (y/n):"""
)

In [None]:
# isSQLyn(userIntent="I want to know how many people are in the database", conversation="Hi there")

In [None]:
# SQL-writing prompt. Template variables: `userIntent`, `database_context`, `conversation`
# and the optional `previousAttempt` (its section is only rendered when truthy).
# The prompt ends by opening a code fence and the stop sequence is the closing fence,
# so the completion is the text up to that closing fence.
sqlExecutor = GPT3Prompt("sqlExecutor",
"""
The following is a conversation with a data expert (datAI). In order to be accurate, datAI would like to use SQL. 
There is a sqlite database (nba.sql)
UserIntent: {{userIntent}}
---
|| Database Context ||
{{ database_context }}
===
{{ conversation }}
===
{% if previousAttempt %}
The AI has tried already but failed: Previous Query and response
{{ previousAttempt }}
---
New Attempt that is different
{% endif %}
The SQLite query that datAI would like to execute is:
```"""
, stop="```")

In [None]:
# y/n prompt: given `conversation`, the executed `sql` and its `result`, asks whether the
# query executed successfully.
gotAValidSQLResponse = GPT3Prompt("gotAValidSQLResponse",
"""
===
{{ conversation }}
===
In order to answer a question (above), a data expert (datAI) executed the following SQL query:
{{ sql }}
And the result of the execution was:
{{ result }}
===
Did the query execute successfully (y/n):"""
)

In [None]:
# Answer-writing prompt: given `conversation`, the executed `sql` and its `result`, asks for
# the reply datAI should give. The literal opens with exactly three quotes so that the
# prompt text does not begin with a stray '"' character.
composeDataDrivenAnswer = GPT3Prompt("composeDataDrivenAnswer",
"""
The following is a conversation with a data expert (datAI). In order to be accurate, the data expert used a database (queried with SQL).
===
{{ conversation }}
===
The query and response:
{{ sql }}
Result:
{{ result }}
===
What should datAI say in response to the last message summarizing what it ran and its result (it can just directly copy if that is the best answer)?
"""
)

In [None]:
# Fallback prompt used after the SQL attempts failed: given `conversation`, asks datAI for a
# clarifying question.
# NOTE(review): "dataAI" in the template looks like a typo for "datAI" - confirm before changing,
# since editing the template also changes the prompt's id hash.
neededHelpButStillFriendly = GPT3Prompt("neededHelpButStillFriendly",
"""
The following is a conversation with a friendly chatbot (datAI, a data expert), who loves basketball.
The dataAI tried to run some SQL, but failed to get a good response to the question, 
and now wants to ask a clarifying question so that it can query the database better to help the user...
===
{{ conversation }}
datAI:"""
)

In [None]:
# Conversational prompt: given `userIntent` and `conversation`, produces datAI's next chat turn
# while instructing the model not to state data without executed SQL as proof.
friendlyChatbot = GPT3Prompt("friendlyChatbot",
"""
The following is a conversation with a friendly chatbot (datAI, a data expert), who loves basketball.
The response will never contain a data response unless it has proof in the form of executed SQL. 
datAI does not respond with guesses, so will ask questions to clarify the users intent to help it formulate a SQL query.
datAI never repeats itself either.
userIntent: {{ userIntent }}
===
{{ conversation }}
datAI:"""
)

In [None]:
# y/n prompt: given a `statement`, asks whether it attempts to state factual information that
# would exist in a database.
checkIfNeedsSQL = GPT3Prompt("checkIfNeedsSQL",
"""
The following is a conversation with a friendly chatbot (datAI, a data expert), who loves basketball.
Does the following statement contain any attempt at factual information that would exist in a database?

Statement: {{ statement }}
===
y/n:"""
)

In [None]:
# Summarising prompt: given `conversation`, asks for the user's intent as free text.
getUserIntent = GPT3Prompt("getUserIntent",
"""
The following is a conversation with a friendly chatbot (datAI, a data expert).
===
{{ conversation }}
===
What is the users intent in this conversation?
"""
)

In [None]:
def starts_with_y(string_data):
    """Return True when the text, ignoring surrounding whitespace and letter case, begins with "y"."""
    normalized = string_data.strip().lower()
    return normalized.startswith('y')

In [None]:
def execute_sql(sql, response_limit=500):
    """Run ``sql`` against the SQLite file at ``DB_PATH`` and return the rows as text.

    Args:
        sql: SQL statement to execute.
        response_limit: Maximum number of characters of the stringified rows to return.

    Returns:
        ``str(rows)`` truncated to ``response_limit`` characters, or a string starting
        with ``"Error executing SQL"`` followed by the exception text when anything
        fails. Errors are reported in the return value, never raised.
    """
    # NOTE(review): the statement comes from a language model and runs on a read-write
    # connection; nothing is committed here, but consider opening the file read-only.
    conn = None
    try:
        conn = sqlite3.connect(DB_PATH)
        rows = conn.execute(sql).fetchall()
        return str(rows)[:response_limit]
    except Exception as e:
        return f"Error executing SQL {e}"
    finally:
        # Always release the connection; previously one was left open on every call.
        if conn is not None:
            conn.close()

# Wrap execute_sql in a Prompt so each call goes through Prompt.__call__ under the name "exec_sql".
exec_sql = Prompt("exec_sql", execute_sql)

In [None]:
def _describe_database(db_path):
    """Return a text summary of every table in the SQLite file at ``db_path``.

    For each table listed in ``sqlite_master`` the summary contains
    ``<table>( <col1>  <col2> ... ) `` on one line, followed by a line with the
    first row's values separated by spaces (or ``Empty Table``).
    """
    def _quote(identifier):
        # Double-quote identifiers so table names with spaces or keywords stay valid SQL.
        return '"' + identifier.replace('"', '""') + '"'

    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        parts = []
        for (table,) in cursor.fetchall():
            columns = cursor.execute(f"PRAGMA table_info({_quote(table)})").fetchall()
            parts.append(table + "(")
            # table_info rows are (cid, name, type, ...); only the column name is used.
            parts.extend(f" {column} " for _cid, column, *_rest in columns)
            parts.append(") \n")
            sample = cursor.execute(f"select * from {_quote(table)} limit 1").fetchall()
            if sample:
                parts.append(" ".join(str(value) for value in sample[0]))
            else:
                parts.append("Empty Table")
            parts.append("\n")
        return "".join(parts)
    finally:
        # The connection is only needed while the summary is built.
        conn.close()


# Schema summary plus one worked example query; passed to sqlExecutor as `database_context`.
outputschema = _describe_database(DB_PATH)
database_context = outputschema + "Example Query (Russell Westbrook's total Triple-Doubles)\n"
database_context += """
SELECT SUM(td3) 
FROM player_game_log 
LEFT JOIN player ON player.player_id = player_game_log.player_id 
WHERE player.player_name = 'Russell Westbrook';
=========
"""


In [None]:


def example_promptchain(conversation, max_trials=3):
    """Produce datAI's reply to ``conversation``, using SQL when the prompts call for it.

    Steps: infer the user's intent, ask whether SQL is needed, draft a friendly
    answer and ask whether that draft states database facts. If either check
    starts with "y", try up to ``max_trials`` times to generate a query,
    execute it, and have the result judged. The first attempt judged
    successful is turned into a data-driven answer.

    Args:
        conversation: Conversation transcript as a single string.
        max_trials: Maximum number of SQL generation attempts (default 3).

    Returns:
        The friendly draft when no SQL is needed, the data-driven answer when an
        attempt succeeds, otherwise a clarifying question.
    """
    userIntent = getUserIntent(conversation=conversation)
    shouldUseSQL = isSQLyn(conversation=conversation, userIntent=userIntent)
    answer = friendlyChatbot(userIntent=userIntent, conversation=conversation)
    needsSql = checkIfNeedsSQL(statement=answer)
    if not (starts_with_y(shouldUseSQL) or starts_with_y(needsSql)):
        return answer

    # Each failed attempt is fed back so the next query differs from the previous one.
    previousAttempt = None
    for _ in range(max_trials):
        sqlresponse = sqlExecutor(userIntent=userIntent, conversation=conversation, previousAttempt=previousAttempt, database_context=database_context)
        res = exec_sql(sqlresponse)
        validSQL = gotAValidSQLResponse(conversation=conversation, sql=sqlresponse, result=res)
        if starts_with_y(validSQL):
            return composeDataDrivenAnswer(conversation=conversation, sql=sqlresponse, result=res)
        # Only the first 100 characters of the result are carried into the retry prompt.
        previousAttempt = f"```{sqlresponse}```\nResult:{res[:100]}\n"
    return neededHelpButStillFriendly(conversation=conversation)

In [None]:
# Wrap the whole chain in a Prompt so end-to-end runs are recorded under "example_promptchain".
epc = Prompt("example_promptchain", example_promptchain)

In [None]:
# Manual smoke run: a conversation whose last turn is a follow-up question about a team id
# returned by an earlier query. Calls the completion API through the chain.
epc("Justin: Who won the most games in 2016?\nDatAI:The query ran was:\n\nSELECT team_id_winner, COUNT(*) as num_wins\nFROM game\nWHERE season_id = 2016\nGROUP BY team_id_winner\nORDER BY num_wins DESC\nLIMIT 1;\n\nThe result was:\n\n[(1610612744, 67)]\nJustin:What team is 1610612744?")

In [None]:
async def chatbot():
    """Run the chat bot on the websocket at ``uri`` until a "break" message arrives.

    Every received frame is decoded and logged. Replayed frames and frames sent
    by "datAI" are added to the transcript without triggering a reply. Frames
    flagged ``meta`` are ignored. Any other frame is added to the transcript,
    the whole transcript is passed to ``epc``, and the reply is sent back.
    """
    transcript = []
    auth_headers = {"Cookie": f"Authorization=Bearer {BOT_TOKEN}"}
    async with websockets.connect(uri, extra_headers=auth_headers) as ws:
        while True:
            try:
                frame = await asyncio.wait_for(ws.recv(), timeout=0.5)
                payload = json.loads(frame)
                print("log:", payload['message'])
                is_history = payload.get("replay", False) or payload["sender"] == "datAI"
                if is_history:
                    transcript.append(payload)
                    continue
                if payload["message"] == "break":
                    break
                if payload.get("meta", False):
                    continue
                transcript.append(payload)
            except asyncio.TimeoutError:
                # Nothing arrived within the 0.5 second window; poll again.
                continue
            lines = [f"{m['sender']}: {m['message']}" for m in transcript]
            response = epc("\n".join(lines))
            await ws.send(json.dumps({"message": response}))


In [None]:
# Start the bot; this cell keeps running until a message whose text is "break" is received.
await chatbot()

In [None]:
# messages = await chatbot_messages()
# response = example_promptchain(messages)

In [None]:
# Load recorded (inputs, response) pairs for the "sqlExecutor" prompt (get_prompts defaults to n=5).
thing = await get_prompts(database, "sqlExecutor")

In [None]:
import pandas as pd
# Show the recorded pairs as a table, one row per recorded call.
pd.DataFrame(thing)

In [None]:
# Experimental variant of the SQL-writing prompt without the database context and retry
# sections. Template variables: `userIntent` and `conversation`. The conversation is closed
# with a plain "===" delimiter (the stray "e`" after it was removed).
sqlExecutor_test = GPT3Prompt("sqlExecutor_test",
"""
The following is a conversation with a data expert (datAI). In order to be accurate, datAI would like to use SQL. 
There is a sqlite database (nba.sql)
UserIntent: {{userIntent}}
===
{{ conversation }}
===
The SQLite query that datAI would like to execute is:
```"""
, stop="```")


In [None]:
# for t_in, t_out in thing:
#     print(f"Was: {t_out}, now is: {sqlExecutor_test(**t_in['kwargs'])}")