In [1]:
import functools
import operator
import os
import sqlite3
from contextlib import closing
from typing import Annotated, List, Literal, Sequence, TypedDict, Union

import pandas as pd
from langchain.agents import AgentExecutor, create_sql_agent
from langchain_anthropic import ChatAnthropic
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain_community.utilities.sql_database import SQLDatabase
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langgraph.graph import END, START, StateGraph
from langgraph.prebuilt import ToolExecutor
from pydantic import BaseModel

# Configuration
# The API key is read from the environment rather than hard-coded: a key that is
# committed to source (or shared in a notebook) is effectively public.
# NOTE: the key that used to be hard-coded here should be treated as leaked and revoked.
ANTHROPIC_API_KEY = os.environ.get("ANTHROPIC_API_KEY", "")
# Path of the CSV file that init_database() imports into SQLite
# (despite the name, this is the CSV source, not the SQLite file).
DB_PATH = "apple_last_year_data.csv"

def init_database(
    csv_path: str,
    db_file: str = "consumption.db",
    table_name: str = "consumption",
):
    """Load a CSV file into a SQLite table and return a LangChain SQLDatabase for it.

    The table is replaced on every call, so its contents always mirror the CSV.

    Args:
        csv_path: Path of the CSV file to import.
        db_file: SQLite database file to write into (created if it does not exist).
        table_name: Name of the table the CSV rows are written to.

    Returns:
        A ``SQLDatabase`` opened on ``sqlite:///<db_file>``.

    Raises:
        Any exception from reading the CSV, writing the table, or opening the
        SQLDatabase is printed and re-raised unchanged.
    """
    try:
        # The raw sqlite3 connection is only needed for the import. Close it so
        # repeated calls don't leak connections; SQLDatabase opens its own below.
        with closing(sqlite3.connect(db_file)) as conn:
            df = pd.read_csv(csv_path)
            # NOTE(review): the sample CSV carries a leftover pandas index column
            # ("Unnamed: 0"); consider pd.read_csv(csv_path, index_col=0) if that
            # column carries no meaning for the agent.
            df.to_sql(table_name, conn, index=False, if_exists="replace")
            # Commit explicitly so nothing is rolled back when the connection closes.
            conn.commit()
        return SQLDatabase.from_uri(f"sqlite:///{db_file}")
    except Exception as e:
        print(f"Error initializing database: {str(e)}")
        raise

def init_llm(model: str = "claude-3-sonnet-20240229", temperature: float = 0):
    """Create the Anthropic chat model used by the SQL agent.

    Args:
        model: Anthropic model identifier to request.
        temperature: Sampling temperature passed to the model (default 0).

    Returns:
        A configured ``ChatAnthropic`` instance.

    Raises:
        ValueError: If the module-level ``ANTHROPIC_API_KEY`` is empty.
        Any other exception raised while constructing the client is printed
        and re-raised unchanged.
    """
    try:
        if not ANTHROPIC_API_KEY:
            # Fail at setup with an actionable message instead of an
            # authentication error on the first agent call.
            raise ValueError(
                "ANTHROPIC_API_KEY is not set; export it in the environment before running."
            )
        return ChatAnthropic(
            model=model,
            temperature=temperature,
            api_key=ANTHROPIC_API_KEY,
        )
    except Exception as e:
        print(f"Error initializing LLM: {str(e)}")
        raise

# Prompt prefix passed to create_sql_agent(agent_type="zero-shot-react-description").
# create_sql_agent calls prefix.format(dialect=..., top_k=...) and then builds a
# PromptTemplate from the result. The text must therefore contain no curly braces
# other than the {dialect} placeholder.
# The examples use the same "Action:" / "Action Input:" lines the ReAct output
# parser expects, so the model is not taught a format the agent cannot parse.
# They also use only columns present in the consumption table (Date, Close).
TWO_SHOT_PREFIX = """You are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.

Here are two examples. They show the format to use: the Action line names the tool and the Action Input line contains only the SQL query.

Example 1:
Question: What was the average closing price in January?
Thought: I need to average the Close column for rows whose Date falls in January.
Action: sql_db_query
Action Input: SELECT ROUND(AVG("Close"), 2) AS avg_close FROM consumption WHERE strftime('%m', "Date") = '01'

Example 2:
Question: By what percentage did the average closing price change between Q1 and Q2?
Thought: I need the average Close for each quarter, then the relative change between them.
Action: sql_db_query
Action Input: WITH q1 AS (
    SELECT AVG("Close") AS q1_close
    FROM consumption
    WHERE strftime('%m', "Date") IN ('01','02','03')
),
q2 AS (
    SELECT AVG("Close") AS q2_close
    FROM consumption
    WHERE strftime('%m', "Date") IN ('04','05','06')
)
SELECT ROUND((q2.q2_close - q1.q1_close) / q1.q1_close * 100, 2) AS growth_pct
FROM q1, q2

Now, follow these guidelines for new queries:
- Do not use LIMIT statements unless specifically asked
- Round numeric results to two decimal places
- Avoid complex SQL queries with division when possible
- Perform operations step by step
- Pay attention to all conditions mentioned in the query
- YTD or ytd = Year to Date
- Don't assume current year unless specified

All data is in a single table named consumption. Use the sql_db_schema tool to check its columns before writing a query.

Remember:
1. Include appropriate table joins if needed
2. Ensure all column references are valid
3. Use proper SQL syntax and formatting"""

def setup_sql_agent(db_path: str):
    """Build the ReAct-style SQL agent over the CSV found at ``db_path``.

    The CSV is (re)loaded into SQLite via init_database, wrapped in a
    SQLDatabaseToolkit together with the chat model from init_llm, and handed to
    create_sql_agent with TWO_SHOT_PREFIX as the prompt prefix.

    Returns:
        The agent executor produced by create_sql_agent.

    Raises:
        Any setup error is printed and re-raised unchanged.
    """
    try:
        database = init_database(db_path)
        model = init_llm()
        sql_toolkit = SQLDatabaseToolkit(db=database, llm=model)
        # Agent flavour and prompt settings, kept together for readability.
        agent_options = dict(
            agent_type="zero-shot-react-description",
            verbose=True,
            prefix=TWO_SHOT_PREFIX,
        )
        return create_sql_agent(llm=model, toolkit=sql_toolkit, **agent_options)
    except Exception as e:
        print(f"Error setting up SQL agent: {str(e)}")
        raise

class AgentState(TypedDict):
    """State schema shared by the nodes of the LangGraph workflow."""

    # Conversation history. The operator.add annotation is the LangGraph reducer:
    # a node's returned "messages" list is concatenated onto the existing one
    # instead of replacing it.
    messages: Annotated[Sequence[BaseMessage], operator.add]
    # NOTE(review): declared but not written by agent_node; presumably reserved
    # for routing between nodes — confirm whether it is still needed.
    next: str

def agent_node(state: AgentState, agent_executor):
    """Run the SQL agent on the latest message and return its answer as a state update.

    Args:
        state: Current graph state. The content of its last message is used as
            the agent's input question.
        agent_executor: Executor produced by setup_sql_agent (bound in via functools.partial).

    Returns:
        A partial state update ``{"messages": [AIMessage(...)]}`` carrying the
        agent's final output. If the agent raises, the message carries an
        ``"Error: ..."`` description instead.
    """
    try:
        question = state["messages"][-1].content
        result = agent_executor.invoke({"input": question})
        # The reply is produced by the model, so it is an AIMessage. Labelling it
        # a HumanMessage would misattribute it if the history is ever replayed
        # to a chat model.
        return {"messages": [AIMessage(content=str(result["output"]))]}
    except Exception as e:
        # Best-effort: report the failure as this node's reply rather than
        # aborting the whole graph run.
        print(f"Error in agent node: {str(e)}")
        return {"messages": [AIMessage(content=f"Error: {str(e)}")]}

def create_workflow(agent_executor):
    """Compile a single-node LangGraph workflow: START -> sql_agent -> END.

    Args:
        agent_executor: SQL agent executor, bound into agent_node via functools.partial.

    Returns:
        The compiled graph.

    Raises:
        Any graph-construction error is printed and re-raised unchanged.
    """
    node_name = "sql_agent"
    try:
        builder = StateGraph(AgentState)
        builder.add_node(
            node_name, functools.partial(agent_node, agent_executor=agent_executor)
        )
        # Linear topology: enter at the SQL node, finish right after it.
        for source, target in ((START, node_name), (node_name, END)):
            builder.add_edge(source, target)
        return builder.compile()
    except Exception as e:
        print(f"Error creating workflow: {str(e)}")
        raise

def _agent_database(agent_executor):
    """Return the SQLDatabase bound to the agent's SQL tools, or None if none is found."""
    # Each SQL toolkit tool carries the SQLDatabase it queries. Reusing it avoids
    # re-reading the CSV and rewriting the table just to print the schema.
    for tool in getattr(agent_executor, "tools", None) or []:
        database = getattr(tool, "db", None)
        if database is not None:
            return database
    return None


def run_sql_query(query: str, db_path: str):
    """Answer one natural-language question by running it through the SQL agent graph.

    Builds the agent and workflow for the CSV at ``db_path``, prints the table
    schema for debugging, then streams the graph on ``query``.

    Args:
        query: Natural-language question to ask.
        db_path: Path of the CSV file to load.

    Returns:
        The list of chunks yielded by ``graph.stream``. If any step raises, the
        result is a single ``{"messages": [HumanMessage("Error: ...")]}`` entry.
    """
    try:
        agent_executor = setup_sql_agent(db_path)

        # For debugging, print the table schema. Prefer the database the agent is
        # already bound to; only reload the CSV if it cannot be found.
        db = _agent_database(agent_executor)
        if db is None:
            db = init_database(db_path)
        print("\nDatabase Schema:")
        print(db.get_table_info())

        graph = create_workflow(agent_executor)

        results = []
        for step in graph.stream({"messages": [HumanMessage(content=query)]}):
            # Presumably filters a terminal "__end__" chunk emitted by some
            # LangGraph versions — kept from the original.
            if "__end__" not in step:
                results.append(step)
        return results
    except Exception as e:
        print(f"Error running SQL query: {str(e)}")
        return [{"messages": [HumanMessage(content=f"Error: {str(e)}")]}]

def main():
    """Run each example question through the SQL agent and print every result chunk.

    Any exception is printed rather than propagated.
    """
    separator = "-" * 50
    example_questions = (
        "which date has the highest close?",
        "Which month had the best performance?",
    )
    try:
        for question in example_questions:
            print(f"\nExecuting query: {question}")
            print(separator)
            for chunk in run_sql_query(question, DB_PATH):
                print(chunk)
                print(separator)
    except Exception as e:
        print(f"Error in main execution: {str(e)}")


if __name__ == "__main__":
    main()


Executing query: which date has the highest close?
--------------------------------------------------

Database Schema:

CREATE TABLE consumption (
	"Unnamed: 0" INTEGER, 
	"Date" TEXT, 
	"Price" REAL, 
	"Close" REAL, 
	"High" REAL, 
	"Low" REAL, 
	"Open" REAL, 
	"Volume" REAL
)

/*
3 rows from consumption table:
Unnamed: 0	Date	Price	Close	High	Low	Open	Volume
0	2023-12-18	194.9350128173828	195.88999938964844	196.6300048828125	194.38999938964844	196.08999633789065	55751900.0
1	2023-12-19	195.9798889160156	196.94000244140625	196.9499969482422	195.88999938964844	196.16000366210935	40714100.0
2	2023-12-20	193.88018798828125	194.8300018310547	197.67999267578125	194.8300018310547	196.8999938964844	52242800.0
*/


[1m> Entering new SQL Agent Executor chain...[0m
[32;1m[1;3mThought: To see what tables are available in the database, I should first list the tables.
Action: sql_db_list_tables
Action Input: ""[0m[38;5;200m[1;3mconsumption[0m[32;1m[1;3mThought: Since there is only one 