In [1]:
import os
import openai
from dotenv import load_dotenv
from sql_market_agent.agent.agent import create_sql_market_agent
from langchain_openai.chat_models import ChatOpenAI
from langchain.callbacks import StdOutCallbackHandler
import matplotlib.pyplot as plt

# Load variables from a local .env file (if present) into os.environ before reading them.
load_dotenv()

# The OpenAI key is required by every later cell; fail fast with a clear message here
# instead of letting a None key surface later as an opaque authentication error.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
if not OPENAI_API_KEY:
    raise RuntimeError(
        "OPENAI_API_KEY is not set; add it to your environment or to a .env file."
    )

# Optional Postgres settings (host/port fall back to local defaults). They are only
# needed when the agent is pointed at Postgres; without a connection string the agent
# uses its local on-disk database (see the cell 3 log output).
DB_HOST = os.getenv("POSTGRES_HOST", "localhost")
DB_PORT = os.getenv("POSTGRES_PORT", "5432")
DB_NAME = os.getenv("POSTGRES_DB")
DB_USER = os.getenv("POSTGRES_USER")
DB_PASSWORD = os.getenv("POSTGRES_PASSWORD")

2024-01-29 16:11:14,791 - INFO - NumExpr defaulting to 8 threads.


In [2]:
# Render matplotlib figures inline in the notebook output area (default in modern Jupyter; kept explicit).
%matplotlib inline


In [3]:
# Build the chat model and the SQL market agent.
# langchain_openai's ChatOpenAI reads its key from its own `openai_api_key` argument or the
# OPENAI_API_KEY env var, not from the legacy `openai.api_key` global, so the key is passed
# explicitly. The global is still set for any direct use of the `openai` module.
openai.api_key = OPENAI_API_KEY

# temperature=0 keeps the agent's SQL generation as repeatable as the API allows.
llm = ChatOpenAI(
    temperature=0,
    model="gpt-4",
    streaming=True,
    openai_api_key=OPENAI_API_KEY,
)

# No db_connection_string is passed, so the agent uses its local on-disk database
# (see the log output below). To target Postgres instead, pass
#   db_connection_string=(
#       f"postgresql+psycopg2://{DB_USER}:{quote_plus(DB_PASSWORD)}"
#       f"@{DB_HOST}:{DB_PORT}/{DB_NAME}"
#   )
# with `from urllib.parse import quote_plus`. URL-encoding the password keeps characters
# such as '@', ':' or '/' from corrupting the connection URL.
agent_executor = create_sql_market_agent(
    llm=llm,
    preinitialize_database=True,  # per the log below: create the table and load market data now
    stocks=["CAT", "NKE", "XOM"],
)

2024-01-29 16:11:17,915 - INFO - DB connection string not provided, using local db on disc...
2024-01-29 16:11:17,928 - INFO - Initializing database with StockData table...
2024-01-29 16:11:17,945 - INFO - Database (sqlite) initialized.
2024-01-29 16:11:17,946 - INFO - Fetching data for analysis...
  df.index += _pd.TimedeltaIndex(dst_error_hours, 'h')
  df.index += _pd.TimedeltaIndex(dst_error_hours, 'h')
  df.index += _pd.TimedeltaIndex(dst_error_hours, 'h')
2024-01-29 16:11:19,428 - INFO - Data fetched and inserted successfully.
2024-01-29 16:11:19,430 - INFO - Connecting to db for sql agent...
2024-01-29 16:11:19,438 - INFO - Connected to db for sql agent successfully
  warn_deprecated(


In [5]:
# Ask the agent a question; `chat_history` is empty because this is the first turn.
# Calling a chain directly (`Chain.__call__`) is deprecated since LangChain 0.1.0;
# `invoke` is the supported entry point and returns the same result dict.
# To trace intermediate steps, add: config={"callbacks": [StdOutCallbackHandler()]}
response = agent_executor.invoke(
    {"input": "plot close price for XOM most recent 5 days", "chat_history": []},
)

# Show the agent's final answer as the cell output.
response["output"]



[1m> Entering new AgentExecutor chain...[0m


2024-01-29 16:13:13,336 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"


[32;1m[1;3m
Invoking: `sql_database_tool` with `What is the most recent date for which you have data in the database?`


[0m

[1m> Entering new AgentExecutor chain...[0m


2024-01-29 16:13:15,282 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"


[32;1m[1;3mTo find the most recent date in the database, I need to identify the tables and their schemas first. Then, I can look for fields that likely contain date information, such as those named 'date', 'timestamp', etc. After identifying these fields, I can write a SQL query to find the most recent date.
Action: sql_db_list_tables
Action Input: ""[0m[38;5;200m[1;3mstockdata[0m

2024-01-29 16:13:20,400 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"


[32;1m[1;3mThe database contains a table named 'stockdata'. I need to check the schema of this table to identify any fields that might contain date information.
Action: sql_db_schema
Action Input: "stockdata"[0m[33;1m[1;3m
CREATE TABLE stockdata (
	id INTEGER, 
	symbol TEXT NOT NULL, 
	sector TEXT, 
	date TEXT NOT NULL, 
	open REAL NOT NULL, 
	high REAL NOT NULL, 
	low REAL NOT NULL, 
	close REAL NOT NULL, 
	volume INTEGER NOT NULL, 
	dailychangepercent REAL NOT NULL, 
	PRIMARY KEY (id), 
	UNIQUE (symbol, date)
)

/*
3 rows from stockdata table:
id	symbol	sector	date	open	high	low	close	volume	dailychangepercent
1	CAT	Farm & Heavy Construction Machinery	2024-01-03	287.06008758155457	287.597575469304	282.01346088643976	282.9889221191406	3043400	-2.8731506923064756
2	CAT	Farm & Heavy Construction Machinery	2024-01-04	283.0984307494382	286.89085828753036	282.48129477500663	284.7806396484375	2995400	0.6331405186746312
3	CAT	Farm & Heavy Construction Machinery	2024-01-05	284.9299310610

2024-01-29 16:13:23,268 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"


[32;1m[1;3mThe 'stockdata' table contains a 'date' field. I can write a SQL query to find the most recent date in this field.
Action: sql_db_query_checker
Action Input: "SELECT MAX(date) FROM stockdata"[0m

2024-01-29 16:13:26,545 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"


[36;1m[1;3mSELECT MAX(date) FROM stockdata[0m

2024-01-29 16:13:27,262 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"


[32;1m[1;3mThe SQL query is correct. Now I can execute it to find the most recent date in the 'stockdata' table.
Action: sql_db_query
Action Input: "SELECT MAX(date) FROM stockdata"[0m[36;1m[1;3m[('2024-01-26',)][0m

2024-01-29 16:13:30,232 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"


[32;1m[1;3mI now know the final answer.
Final Answer: The most recent date for which we have data in the database is 2024-01-26.[0m

[1m> Finished chain.[0m
[38;5;200m[1;3mThe most recent date for which we have data in the database is 2024-01-26.[0m

2024-01-29 16:13:32,177 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"


[32;1m[1;3mThe most recent date for which we have data in the database is January 26, 2024.[0m

[1m> Finished chain.[0m
