In [2]:
import os
from keys import OPENAI_KEY
from langchain_openai import ChatOpenAI
from utils.db import get_database_schema_execute_all, run_query_and_return_df

os.environ['OPENAI_API_KEY'] = OPENAI_KEY  

In [3]:
# Base directory where CSV files are located
path_to_csv_files = 'testDBs/test1/db/'
db_output_dir = 'example.db'
llm_model = ChatOpenAI(model="gpt-3.5-turbo", temperature=0) 

In [4]:
db_schema = get_database_schema_execute_all(path_to_csv_files = path_to_csv_files, db_output_dir= db_output_dir)

CREATE TABLE "invoice" ("id" PRIMARY KEY, "order_id" TEXT, "creation_date" TEXT, FOREIGN KEY ("order_id") REFERENCES "order"("id"))
CREATE TABLE "order" ("id" PRIMARY KEY, "creation_date" TEXT)
CREATE TABLE "payment" ("id" PRIMARY KEY, "invoice_id" TEXT, "creation_date" TEXT, FOREIGN KEY ("invoice_id") REFERENCES "invoice"("id"))
CREATE TABLE "shipment" ("id" PRIMARY KEY, "order_id" TEXT, "creation_date" TEXT, FOREIGN KEY ("order_id") REFERENCES "order"("id"))
Data inserted into table order
Data inserted into table invoice
Data inserted into table payment
Data inserted into table shipment
Database created: example.db


In [5]:
def get_sql_query(state):
    messages = state['messages']
    user_input = messages[-1]
    response = llm_model.invoke(user_input)
    state['messages'].append(response.content) # appending AIMessage response to the AgentState
    return state

def run_sql_query(state):
    messages = state['messages']
    agent_response = messages[-1]
    #agent_response = 'SELECT * FROM "order" where "id"="o1"'
    try:
        df = run_query_and_return_df(path_to_db = db_output_dir, query = agent_response)
        state['df'] = df
    except Exception as e:
        state['df'] = 'ERROR'
    return state 

In [6]:
from langgraph.graph import Graph

workflow = Graph()
# nodes
workflow.add_node("agent", get_sql_query)
workflow.add_node("tool", run_sql_query)
# edges
workflow.add_edge('agent', 'tool')
# entry, exit
workflow.set_entry_point("agent")
workflow.set_finish_point("tool")

app = workflow.compile()

In [9]:
prompt="""Consider the following db schema:
            {db_schema}
            Write a sql statements that returns all in the order table. Use quotes for identifiers. Provide only the query."""
AgentState = {"messages": [prompt]}
app.invoke(AgentState)

Skipping write for channel tags which has no readers
Skipping write for channel tags which has no readers


{'messages': ['How are you?',
  "I'm just a computer program, so I don't have feelings, but I'm here to help you with anything you need. How can I assist you today?"],
 'df': 'ERROR'}