In [5]:
import functools
import operator
import os
import sqlite3
from typing import List, Union, Literal, Sequence, TypedDict, Annotated

import pandas as pd
from langchain.agents import AgentExecutor, create_sql_agent
from langchain_anthropic import ChatAnthropic
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain_community.utilities.sql_database import SQLDatabase
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langgraph.graph import END, StateGraph, START
from langgraph.prebuilt import ToolExecutor
from pydantic import BaseModel

# Configuration
# The Anthropic API key is read from the ANTHROPIC_API_KEY environment variable
# so that no credential lives in the source file. It falls back to an empty
# string when the variable is unset, which keeps the constant a ``str``.
# NOTE(review): the key that was previously hardcoded on this line has been
# exposed and should be revoked and rotated.
ANTHROPIC_API_KEY = os.environ.get("ANTHROPIC_API_KEY", "")
# Despite the name, this is the path of the source CSV that init_database()
# loads into SQLite, not the path of the SQLite file itself.
DB_PATH = "apple_last_year_data.csv"

def init_database(csv_path: str):
    """Load a CSV file into the local SQLite database and wrap it for LangChain.

    The CSV is read with pandas and written to the ``consumption`` table of
    ``consumption.db``, replacing any existing table of that name.

    Args:
        csv_path: Path of the CSV file to load.

    Returns:
        A ``SQLDatabase`` created from the ``sqlite:///consumption.db`` URI.

    Raises:
        Exception: Any error raised while reading the CSV or writing the
            table is printed and then re-raised unchanged.
    """
    try:
        # Read the CSV first so a missing or malformed file fails before a
        # database connection is opened.
        df = pd.read_csv(csv_path)
        conn = sqlite3.connect('consumption.db', check_same_thread=False)
        try:
            df.to_sql('consumption', conn, index=False, if_exists='replace')
            # Commit explicitly so the rows are persisted before the handle
            # is released; SQLDatabase opens its own connection afterwards.
            conn.commit()
        finally:
            # The raw sqlite3 connection is only needed for the load; close
            # it so the file handle is not leaked.
            conn.close()
        return SQLDatabase.from_uri("sqlite:///consumption.db")
    except Exception as e:
        print(f"Error initializing database: {str(e)}")
        raise

def get_schema_info(db: SQLDatabase) -> str:
    """Return the database's table description under a ``Database Schema:`` header.

    Args:
        db: Object exposing ``get_table_info()``; its result is placed on the
            line after the header.

    Returns:
        The header line, a newline, and the table description.
    """
    schema_template = "Database Schema:\n{}"
    return schema_template.format(db.get_table_info())

def init_llm():
    """Build the ChatAnthropic chat model used by this file.

    The model name and temperature are fixed here; the API key comes from the
    module-level ``ANTHROPIC_API_KEY`` constant.

    Returns:
        The constructed ``ChatAnthropic`` instance.

    Raises:
        Exception: Any construction error is printed and then re-raised.
    """
    try:
        llm_settings = {
            "model": "claude-3-sonnet-20240229",
            "temperature": 0,
            "api_key": ANTHROPIC_API_KEY,
        }
        client = ChatAnthropic(**llm_settings)
    except Exception as e:
        print(f"Error initializing LLM: {str(e)}")
        raise
    return client

def create_interpretation_prompt(schema_info: str) -> str:
    """Build the instruction prefix given to the SQL agent.

    The prompt embeds the supplied schema description, two worked examples
    (question, interpretation, SQL) and a list of query-writing guidelines.

    Example 1 filters on year *and* month (``'%Y-%m'``). Comparing only
    ``'%m'`` would match the same calendar month of every year in the table,
    which is not "last month".

    Args:
        schema_info: Text describing the database schema; inserted verbatim.

    Returns:
        The complete prompt text.
    """
    return f"""You are an agent designed to interpret user questions and create SQL queries for a financial database.
First, analyze the user's question in the context of the available schema. Then create a syntactically correct SQL query.

{schema_info}

Here are two examples showing how to interpret questions and create appropriate queries:

Example 1:
User Question: "How did Apple perform last month?"
Interpretation: This question is asking about Apple stock's performance metrics for the previous month. We should look at key indicators like closing price changes, volume, and price range.
SQL Query: 
WITH last_month AS (
    SELECT 
        ROUND(MAX(close), 2) as highest_close,
        ROUND(MIN(close), 2) as lowest_close,
        ROUND(AVG(close), 2) as avg_close,
        ROUND(((MAX(close) - MIN(close)) / MIN(close) * 100), 2) as price_range_percent,
        ROUND(AVG(volume), 0) as avg_volume
    FROM consumption 
    WHERE strftime('%Y-%m', date) = strftime('%Y-%m', date('now', '-1 month'))
)
SELECT 
    highest_close,
    lowest_close,
    avg_close,
    price_range_percent,
    avg_volume
FROM last_month;

Example 2:
User Question: "What's the trading pattern on high volume days?"
Interpretation: This question is asking about stock behavior on days with above-average trading volume. We should analyze price movements on high volume days.
SQL Query:
WITH avg_volume AS (
    SELECT AVG(volume) as mean_volume
    FROM consumption
)
SELECT 
    date,
    ROUND(open, 2) as open_price,
    ROUND(close, 2) as close_price,
    ROUND(((close - open) / open * 100), 2) as price_change_percent,
    volume,
    ROUND((volume / avg.mean_volume), 2) as volume_ratio
FROM consumption
CROSS JOIN avg_volume avg
WHERE volume > avg.mean_volume
ORDER BY volume DESC;

Now, follow these guidelines for new queries:
1. Always interpret the user's question first
2. Consider the available columns in the schema
3. Round numeric results to two decimal places
4. Include relevant performance metrics
5. Pay attention to time periods mentioned
6. Use proper SQL syntax and formatting
7. Avoid LIMIT unless specifically requested
8. Consider multiple aspects of stock performance when relevant

Remember to:
1. Include appropriate table joins if needed
2. Ensure all column references are valid
3. Consider the business context of the question
4. Provide comprehensive analysis when needed"""

class InterpretationState(TypedDict):
    """State dictionary used by the workflow graph built in create_workflow()."""
    # Conversation messages. The annotation attaches operator.add as metadata;
    # presumably langgraph uses it to append each node's returned messages to
    # the existing list rather than overwrite it — confirm against langgraph docs.
    messages: Annotated[Sequence[BaseMessage], operator.add]
    # Declared here but not read or written by any code visible in this file.
    next: str
    # The LLM's interpretation of the user's question, as returned by
    # interpret_and_execute().
    interpretation: str

def interpret_and_execute(state: InterpretationState, agent_executor, schema_info: str):
    """Graph node: interpret the latest question, then run it through the SQL agent.

    The content of the last message in ``state["messages"]`` is treated as the
    user's question. An LLM first produces a short interpretation of it against
    the schema; the question plus that interpretation is then passed to the
    agent executor, and both are combined into one response message.

    The response (and any error text) is emitted as an ``AIMessage`` because it
    is generated by the system, not typed by the user. Labelling it as a
    ``HumanMessage`` would misattribute it in the accumulated history.

    Args:
        state: Current graph state; only ``messages`` is read.
        agent_executor: Object whose ``invoke({"input": ...})`` returns a
            mapping with an ``"output"`` entry.
        schema_info: Schema description included in the interpretation prompt.

    Returns:
        A partial state update with a single-element ``messages`` list and the
        ``interpretation`` text. On any exception the error is printed and
        returned as the message content instead of being raised.
    """
    try:
        last_message = state["messages"][-1].content

        # First, have the LLM interpret the question
        llm = init_llm()
        interpretation_prompt = f"""Given this user question: "{last_message}"
        And this database schema:
        {schema_info}
        
        Provide a brief interpretation of what data we need to query and why."""

        interpretation = llm.invoke(interpretation_prompt).content

        # Execute the query with both the original question and interpretation
        enhanced_question = f"""Question: {last_message}
        Interpretation: {interpretation}
        Please provide an SQL query that addresses this interpretation."""

        result = agent_executor.invoke({
            "input": enhanced_question
        })

        # Combine interpretation and results
        complete_response = f"""Interpretation: {interpretation}\n\nAnalysis: {result['output']}"""

        return {
            "messages": [AIMessage(content=complete_response)],
            "interpretation": interpretation
        }
    except Exception as e:
        # Best-effort node: report the failure as a message so the graph run
        # still completes and the caller sees what went wrong.
        print(f"Error in interpretation node: {str(e)}")
        return {
            "messages": [AIMessage(content=f"Error: {str(e)}")],
            "interpretation": "Error occurred during interpretation"
        }

def setup_sql_agent(db_path: str):
    """Create the SQL agent executor together with the database it queries.

    Args:
        db_path: Path handed to ``init_database`` to populate the database.

    Returns:
        A ``(agent_executor, db)`` tuple.

    Raises:
        Exception: Any setup error is printed and then re-raised.
    """
    try:
        database = init_database(db_path)
        model = init_llm()
        sql_toolkit = SQLDatabaseToolkit(db=database, llm=model)

        # The agent's prompt prefix is derived from the live schema text.
        prompt_prefix = create_interpretation_prompt(get_schema_info(database))

        agent_options = {
            "llm": model,
            "toolkit": sql_toolkit,
            "agent_type": "zero-shot-react-description",
            "verbose": True,
            "prefix": prompt_prefix,
        }
        executor = create_sql_agent(**agent_options)
    except Exception as e:
        print(f"Error setting up SQL agent: {str(e)}")
        raise
    return executor, database

def create_workflow(agent_executor, schema_info: str):
    """Compile a single-node graph: START -> interpret_and_execute -> END.

    Args:
        agent_executor: Bound into the node as its ``agent_executor`` argument.
        schema_info: Bound into the node as its ``schema_info`` argument.

    Returns:
        The compiled graph.

    Raises:
        Exception: Any construction error is printed and then re-raised.
    """
    node_name = "interpret_and_execute"
    try:
        builder = StateGraph(InterpretationState)

        # Pre-bind the collaborators so the node only receives the state.
        bound_node = functools.partial(
            interpret_and_execute,
            agent_executor=agent_executor,
            schema_info=schema_info,
        )
        builder.add_node(node_name, bound_node)

        # Linear flow through the one node.
        for source, target in ((START, node_name), (node_name, END)):
            builder.add_edge(source, target)

        return builder.compile()
    except Exception as e:
        print(f"Error creating workflow: {str(e)}")
        raise

def run_sql_query(query: str, db_path: str):
    """Run one natural-language question through a freshly built workflow.

    Args:
        query: The user's question.
        db_path: Path passed on to ``setup_sql_agent``.

    Returns:
        A list of the items streamed by the graph, excluding any item that
        contains the ``"__end__"`` key. If anything fails, the error is
        printed and a one-element list holding an error message is returned.
    """
    try:
        # Build the agent, then the graph around it.
        agent_executor, db = setup_sql_agent(db_path)
        graph = create_workflow(agent_executor, get_schema_info(db))

        initial_state = {
            "messages": [HumanMessage(content=query)],
            "interpretation": "",
        }
        return [
            update
            for update in graph.stream(initial_state)
            if "__end__" not in update
        ]
    except Exception as e:
        print(f"Error running SQL query: {str(e)}")
        return [{"messages": [HumanMessage(content=f"Error: {str(e)}")]}]

def main():
    """Run the demo questions against ``DB_PATH`` and print each streamed result."""
    divider = "-" * 50
    try:
        # Example queries that demonstrate the enhanced interpretation
        queries = [
            "Show me the last 10 days price of the stock"
        ]

        for query in queries:
            print(f"\nExecuting query: {query}")
            print(divider)

            for result in run_sql_query(query, DB_PATH):
                print(result)
                print(divider)

    except Exception as e:
        print(f"Error in main execution: {str(e)}")


if __name__ == "__main__":
    main()


Executing query: Show me the last 10 days price of the stock
--------------------------------------------------


[1m> Entering new SQL Agent Executor chain...[0m
[32;1m[1;3mThought: To get started, I should first check what tables are available in the database.

Action: sql_db_list_tables
Action Input: 
[0m[38;5;200m[1;3mconsumption[0m[32;1m[1;3mThought: The only table in the database is the `consumption` table, so I will need to query that table to get the requested information.

Action: sql_db_schema
Action Input: consumption
[0m[33;1m[1;3m
CREATE TABLE consumption (
	"Unnamed: 0" INTEGER, 
	"Date" TEXT, 
	"Price" REAL, 
	"Close" REAL, 
	"High" REAL, 
	"Low" REAL, 
	"Open" REAL, 
	"Volume" REAL
)

/*
3 rows from consumption table:
Unnamed: 0	Date	Price	Close	High	Low	Open	Volume
0	2023-12-18	194.9350128173828	195.88999938964844	196.6300048828125	194.38999938964844	196.08999633789065	55751900.0
1	2023-12-19	195.9798889160156	196.94000244140625	196.9499969482422	195.8899