In [4]:
import os
from typing import List, Dict, TypedDict, Annotated
from dataclasses import dataclass
from langchain_anthropic import ChatAnthropic
from langchain_core.messages import SystemMessage, HumanMessage
from langchain.agents import AgentExecutor, create_sql_agent
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain_community.utilities.sql_database import SQLDatabase
from langgraph.graph import StateGraph, START, END

# Read the key from the environment instead of hard-coding it. A key that has
# been committed to source or shared in a notebook must be treated as leaked:
# revoke/rotate the key that was previously embedded on this line.
# Defaults to "" so the notebook still imports without the variable set;
# requests will presumably be rejected by the API until ANTHROPIC_API_KEY is set.
ANTHROPIC_API_KEY = os.environ.get("ANTHROPIC_API_KEY", "")

@dataclass
class AnalysisQuestion:
    """A single analytical sub-question plus whatever was found for it.

    ``question`` is set when the user's question is decomposed and ``answer``
    is filled in by the SQL step. ``sql_query`` is carried along, but nothing
    in the code shown here populates it.
    """

    # Natural-language sub-question handed to the SQL agent.
    question: str
    # Slot for the SQL text behind the answer; empty by default.
    sql_query: str = ""
    # Agent output for this question; empty until answered.
    answer: str = ""

class State(TypedDict):
    """Shared state passed between the LangGraph workflow nodes."""

    # Conversation as (role, text) tuples; the text of the first entry is the
    # user's original question.
    messages: List[tuple]
    # Decomposed sub-questions, populated with answers by the SQL step.
    questions: List[AnalysisQuestion]
    # Final synthesized analysis text; "" until the synthesize step runs.
    final_analysis: str

# System prompt for decompose_question: asks the LLM for a numbered list of
# SQL-answerable sub-questions; decompose_question parses the numbered items
# out of the reply.
# NOTE(review): the sql_db_schema output for the consumption table lists the
# columns Date, Price, Close, High, Low, Open, Volume, while the column list in
# this prompt uses lowercase names and omits Price — confirm whether the prompt
# should be updated to match the actual table.
QUESTION_DECOMPOSER_PROMPT = """You are an expert at breaking down complex stock analysis questions into specific sub-questions that can be answered using SQL queries.

Available data columns: date, open, high, low, close, volume

Break down the user's question into specific analytical questions that can be answered with SQL queries.
Your response should be in this format:
1. First specific question to answer
2. Second specific question to answer
...

Focus on questions that can be answered with the available data columns. Avoid questions about company fundamentals or external factors."""

# System prompt for synthesize_analysis: turns the answered sub-questions into
# one analysis of the user's original question.
ANALYSIS_SYNTHESIZER_PROMPT = """You are an expert financial analyst who synthesizes data-driven insights into clear, actionable analysis.

Using the answers to our analytical sub-questions, provide a comprehensive analysis that:
1. Directly answers the user's original question
2. Supports conclusions with specific data points
3. Provides context and implications
4. Highlights any important caveats or limitations

Be concise but thorough. Use actual numbers from the data to support your analysis."""

def init_sql_agent(db_uri: str = "sqlite:///consumption.db", *, verbose: bool = True):
    """Create a LangChain SQL agent that uses Claude to query a database.

    Args:
        db_uri: Database URI passed to ``SQLDatabase.from_uri`` (e.g.
            ``"sqlite:///path.db"``). Defaults to the local ``consumption.db``
            SQLite file, so existing no-argument callers are unaffected.
        verbose: Forwarded to ``create_sql_agent``; with the default ``True``
            the agent prints its intermediate actions and observations.

    Returns:
        The agent built by ``create_sql_agent``. Callers use it as
        ``agent.invoke({"input": question})["output"]``.
    """
    db = SQLDatabase.from_uri(db_uri)
    llm = ChatAnthropic(
        model="claude-3-sonnet-20240229",
        temperature=0,
        api_key=ANTHROPIC_API_KEY
    )
    # The toolkit provides the list-tables / schema / query tools the agent uses.
    toolkit = SQLDatabaseToolkit(db=db, llm=llm)
    return create_sql_agent(
        llm=llm,
        toolkit=toolkit,
        verbose=verbose
    )

def decompose_question(state: State):
    """Break the user's original question into SQL-answerable sub-questions.

    Sends the question to Claude with QUESTION_DECOMPOSER_PROMPT and parses the
    numbered items of the reply into AnalysisQuestion objects.

    Args:
        state: Graph state; the text of the first message
            (``state["messages"][0][1]``) is the original question.

    Returns:
        A full state dict with ``questions`` set to the parsed sub-questions
        (or, if no numbered items could be parsed, a single sub-question that
        is the original question itself) and ``final_analysis`` reset to "".
    """
    import re

    llm = ChatAnthropic(
        model="claude-3-sonnet-20240229",
        temperature=0,
        api_key=ANTHROPIC_API_KEY
    )

    original_question = state["messages"][0][1]
    response = llm.invoke([
        SystemMessage(content=QUESTION_DECOMPOSER_PROMPT),
        HumanMessage(content=original_question)
    ])

    # Match list items such as "1. text" or "2) text", allowing leading
    # indentation. Lines that merely begin with a digit (e.g. "2023 prices...")
    # are ignored instead of raising IndexError on a missing ". " separator.
    item_pattern = re.compile(r"^\s*\d+[.)]\s+(.*\S)")
    sub_questions = [
        AnalysisQuestion(question=match.group(1))
        for match in map(item_pattern.match, response.content.splitlines())
        if match
    ]

    # If the model ignored the requested format, answer the original question
    # directly rather than synthesizing an analysis from no data at all.
    if not sub_questions:
        sub_questions = [AnalysisQuestion(question=original_question)]

    return {
        "messages": state["messages"],
        "questions": sub_questions,
        "final_analysis": ""
    }

def get_sql_answers(state: State):
    """Answer every decomposed sub-question with the SQL agent.

    Each question's text is sent via ``agent.invoke({"input": ...})`` and the
    agent's ``"output"`` becomes that question's ``answer``. New
    ``AnalysisQuestion`` objects are returned (via ``dataclasses.replace``)
    rather than mutating the ones in the incoming state, so objects already
    emitted in earlier state snapshots are not changed after the fact.

    Args:
        state: Graph state whose ``questions`` list holds the sub-questions.

    Returns:
        A full state dict with the answered ``questions`` (same order) and
        ``final_analysis`` reset to "".
    """
    from dataclasses import replace

    questions = state["questions"]
    if questions:
        # Only build the agent (database connection + LLM client) when there
        # is something to ask.
        agent = init_sql_agent()
        updated_questions = [
            replace(q, answer=agent.invoke({"input": q.question})["output"])
            for q in questions
        ]
    else:
        updated_questions = []

    return {
        "messages": state["messages"],
        "questions": updated_questions,
        "final_analysis": ""
    }

def synthesize_analysis(state: State):
    """Combine the answered sub-questions into one final analysis.

    Every sub-question/answer pair is rendered as a "Question: ... / Answer:
    ..." paragraph and sent to Claude, together with the original question,
    under ANALYSIS_SYNTHESIZER_PROMPT.

    Returns:
        A full state dict whose ``final_analysis`` is the model's reply text;
        ``messages`` and ``questions`` are passed through unchanged.
    """
    model = ChatAnthropic(
        model="claude-3-sonnet-20240229",
        temperature=0,
        api_key=ANTHROPIC_API_KEY
    )

    sub_questions = state["questions"]
    context_blocks = []
    for item in sub_questions:
        context_blocks.append(f"Question: {item.question}\nAnswer: {item.answer}")
    analysis_context = "\n\n".join(context_blocks)

    user_question = state["messages"][0][1]
    prompt_text = f"""Original Question: {user_question}

Analysis Components:
{analysis_context}

Please provide a comprehensive analysis that answers the original question."""

    reply = model.invoke([
        SystemMessage(content=ANALYSIS_SYNTHESIZER_PROMPT),
        HumanMessage(content=prompt_text),
    ])

    return {
        "messages": state["messages"],
        "questions": sub_questions,
        "final_analysis": reply.content,
    }

def create_analysis_workflow():
    """Build and compile the linear decompose -> analyze -> synthesize graph."""
    pipeline = [
        ("decompose", decompose_question),
        ("analyze", get_sql_answers),
        ("synthesize", synthesize_analysis),
    ]

    graph = StateGraph(State)
    for node_name, node_fn in pipeline:
        graph.add_node(node_name, node_fn)

    # Chain START -> each node in pipeline order -> END.
    chain = [START] + [node_name for node_name, _ in pipeline] + [END]
    for source, target in zip(chain, chain[1:]):
        graph.add_edge(source, target)

    return graph.compile()

def analyze_stock(question: str) -> str:
    """
    Analyze a stock based on a user question using question decomposition and SQL analysis.

    Runs the compiled decompose -> analyze -> synthesize workflow to completion
    and returns the synthesized text.

    Args:
        question (str): User's analysis question

    Returns:
        str: Comprehensive analysis based on the data; "Analysis failed to
        complete." if the workflow produced no analysis text; or an
        "An error occurred during analysis: ..." message if any step raised.
    """
    try:
        workflow = create_analysis_workflow()
        # invoke() returns the final graph state. Iterating stream() with the
        # default stream mode yields per-node update chunks shaped like
        # {"synthesize": {...}}, so indexing the last chunk with
        # "final_analysis" would raise KeyError.
        final_state = workflow.invoke({
            "messages": [("user", question)],
            "questions": [],
            "final_analysis": ""
        })
    except Exception as e:
        # Top-level boundary: report any failure to the caller as text.
        return f"An error occurred during analysis: {str(e)}"

    analysis = final_state.get("final_analysis") if final_state else ""
    return analysis or "Analysis failed to complete."

# Example usage: run the full decompose -> SQL -> synthesize pipeline on a
# sample question and print the resulting analysis.
if __name__ == "__main__":
    question = (
        "Show me the last 10 days price of the stock as a json"
    )
    result = analyze_stock(question=question)
    print(result)



[1m> Entering new SQL Agent Executor chain...[0m
[32;1m[1;3mAction: sql_db_list_tables
Action Input: 
[0m[38;5;200m[1;3mconsumption[0m[32;1m[1;3mThought: The consumption table seems relevant to find date information, so I should query its schema.
Action: sql_db_schema
Action Input: consumption
[0m[33;1m[1;3m
CREATE TABLE consumption (
	"Unnamed: 0" INTEGER, 
	"Date" TEXT, 
	"Price" REAL, 
	"Close" REAL, 
	"High" REAL, 
	"Low" REAL, 
	"Open" REAL, 
	"Volume" REAL
)

/*
3 rows from consumption table:
Unnamed: 0	Date	Price	Close	High	Low	Open	Volume
0	2023-12-18	194.9350128173828	195.88999938964844	196.6300048828125	194.38999938964844	196.08999633789065	55751900.0
1	2023-12-19	195.9798889160156	196.94000244140625	196.9499969482422	195.88999938964844	196.16000366210935	40714100.0
2	2023-12-20	193.88018798828125	194.8300018310547	197.67999267578125	194.8300018310547	196.8999938964844	52242800.0
*/[0m[32;1m[1;3mThought: The "Date" column in the consumption table looks like 