In [7]:
import json
import os
import re
from dataclasses import dataclass
from typing import Dict, List, Optional, TypedDict

import pandas as pd
from dotenv import load_dotenv
from langchain.agents import create_sql_agent
from langchain_anthropic import ChatAnthropic
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain_community.utilities.sql_database import SQLDatabase
from langchain_core.messages import HumanMessage, SystemMessage
from langgraph.graph import END, START, StateGraph

# Load environment variables
load_dotenv('api_key.env')

# State carried between the LangGraph workflow nodes:
#   user_query           - the question as asked by the user
#   decomposed_questions - sub-questions produced by the decompose step
#   sql_results          - per-sub-question SQL agent outcome, keyed "question_<n>"
#   analysis             - free-text analysis produced by the analyst LLM
#   final_output         - report dict assembled by the format step
AnalysisState = TypedDict(
    "AnalysisState",
    {
        "user_query": str,
        "decomposed_questions": List[str],
        "sql_results": Dict,
        "analysis": str,
        "final_output": Dict,
    },
)

class ConfigError(Exception):
    """Raised when the analyzer cannot be set up (e.g. missing API key or a failed database load)."""

@dataclass
class Config:
    """Settings for StockAnalyzer.

    Attributes:
        db_path: CSV file loaded with ``pd.read_csv`` (a CSV, not a database file).
        sqlite_path: SQLAlchemy URI of the SQLite database opened by ``SQLDatabase``.
        model_name: Anthropic model id passed to ``ChatAnthropic``.
            NOTE(review): this is an early Claude 3 model id; confirm it is still served.
    """

    db_path: str = "apple_last_year_data.csv"
    sqlite_path: str = "sqlite:///consumption.db"
    model_name: str = "claude-3-sonnet-20240229"

    @property
    def api_key(self) -> str:
        """Return ANTHROPIC_API_KEY from the process environment.

        The module calls ``load_dotenv('api_key.env')`` at import time, which may
        have populated the variable.

        Raises:
            ConfigError: if the variable is unset or empty.
        """
        key = os.environ.get("ANTHROPIC_API_KEY")
        if key:
            return key
        raise ConfigError("ANTHROPIC_API_KEY not found in api_key.env file")

# Prompt prefix for the LangChain SQL agent (passed as ``prefix=`` to
# create_sql_agent in StockAnalyzer._setup_sql_agent).
# NOTE: create_sql_agent appears to run str.format() on the prefix to fill its
# dialect/top_k placeholders, so keep literal "{" / "}" out of this text.
# The default data file (apple_last_year_data.csv) is a historical extract, so the
# examples anchor relative date windows to MAX(date) in the table. Filtering on
# date('now') would select nothing once the data is older than the window.
SQL_AGENT_PROMPT = """You are an expert financial database analyst. Your task is to:
1. Analyze stock market queries
2. Create appropriate SQL queries
3. Provide clear results

Database Schema:
- Table: consumption
- Columns: date, open, high, low, close, volume
- The data is historical: interpret "last week", "recent", etc. relative to the latest date in the table (MAX(date)), not today's date.

Example 1:
User: "What's the stock's performance last week?"
Thought: Need to analyze daily price changes and volume for the 7 days up to the latest date in the table
SQL:
SELECT 
    date,
    ROUND(open, 2) as open_price,
    ROUND(close, 2) as close_price,
    ROUND(((close - open) / open * 100), 2) as daily_return,
    ROUND(high, 2) as high,
    ROUND(low, 2) as low,
    volume
FROM consumption
WHERE date >= (SELECT date(MAX(date), '-7 days') FROM consumption)
ORDER BY date DESC;

Example 2:
User: "Find volatile trading days"
Thought: Looking for days with large price ranges and high volume
SQL:
WITH metrics AS (
    SELECT AVG(volume) as avg_vol,
           AVG((high - low) / open * 100) as avg_range
    FROM consumption
)
SELECT 
    date,
    ROUND(open, 2) as open_price,
    ROUND(close, 2) as close_price,
    ROUND(((high - low) / open * 100), 2) as price_range_pct,
    volume,
    ROUND(volume / avg_vol, 2) as vol_ratio
FROM consumption, metrics
WHERE (high - low) / open * 100 > avg_range
AND volume > avg_vol
ORDER BY price_range_pct DESC
LIMIT 5;

Your responses should include:
1. Thought process
2. SQL query
3. Result interpretation"""

# System prompt for the interpretation step: StockAnalyzer._analyze_results sends
# this as the SystemMessage, followed by a HumanMessage containing the original
# question and the JSON-dumped per-sub-question SQL results.
ANALYST_PROMPT = """You are an expert financial analyst. Analyze the provided SQL results and provide insights.

Focus on:
1. Price trends and patterns
2. Volume analysis
3. Technical indicators
4. Risk assessment
5. Notable patterns

Example Analysis Structure:
1. Key Findings
   - Main price trends
   - Volume patterns
   - Notable events

2. Technical Analysis
   - Support/resistance levels
   - Pattern recognition
   - Momentum indicators

3. Risk Assessment
   - Volatility measures
   - Liquidity analysis
   - Risk factors

4. Recommendations
   - Key levels to watch
   - Risk considerations
   - Potential scenarios

Be specific and data-driven in your analysis."""

class StockAnalyzer:
    """Natural-language stock analysis over a CSV of price data.

    Construction loads ``config.db_path`` (a CSV) into the SQLite table
    ``consumption`` at ``config.sqlite_path``. It then creates the Claude chat
    model and a LangChain SQL agent, and compiles a linear LangGraph workflow:

        decompose -> sql_analysis -> analyze -> format

    Use :meth:`analyze` to run the workflow for one question.

    Raises (from ``__init__``):
        ConfigError: if the database cannot be built or ANTHROPIC_API_KEY is missing.
    """

    # A numbered list item such as "1. text" or "2) text"; group 1 is the text.
    # Requiring whitespace after the delimiter keeps lines like "1.5% drop ..." out.
    _NUMBERED_ITEM = re.compile(r"^\s*\d+[.)]\s+(.+?)\s*$")

    def __init__(self, config: Config):
        self.config = config
        # Order matters: the SQL agent needs both the database and the LLM.
        self.db = self._init_database()
        self.llm = self._init_llm()
        self.sql_agent = self._setup_sql_agent()
        self.workflow = self._create_workflow()

    def _init_database(self) -> SQLDatabase:
        """Load the CSV into SQLite table ``consumption`` and return a handle to that database.

        Raises:
            ConfigError: if reading the CSV or writing/opening the database fails.
        """
        try:
            df = pd.read_csv(self.config.db_path)
            # Write to the same URI that SQLDatabase opens below, so a non-default
            # sqlite_path is honoured instead of silently using a different file.
            df.to_sql('consumption', self.config.sqlite_path, index=False, if_exists='replace')
            return SQLDatabase.from_uri(self.config.sqlite_path)
        except Exception as e:
            raise ConfigError(f"Database initialization failed: {str(e)}") from e

    def _init_llm(self) -> ChatAnthropic:
        """Create the Claude chat model (temperature=0 to minimise sampling randomness)."""
        return ChatAnthropic(
            model=self.config.model_name,
            temperature=0,
            api_key=self.config.api_key
        )

    def _setup_sql_agent(self):
        """Create a zero-shot ReAct SQL agent over ``self.db`` using SQL_AGENT_PROMPT as prefix.

        ``verbose=True`` prints the agent's intermediate steps to stdout.
        """
        toolkit = SQLDatabaseToolkit(db=self.db, llm=self.llm)
        return create_sql_agent(
            llm=self.llm,
            toolkit=toolkit,
            agent_type="zero-shot-react-description",
            verbose=True,
            prefix=SQL_AGENT_PROMPT
        )

    def _decompose_question(self, state: AnalysisState) -> AnalysisState:
        """Workflow node: ask the LLM to split the user query into sub-questions.

        Only reply lines that look like numbered list items ("1. ..." or "1) ...")
        are kept, with the numbering removed. If the reply contains no such line,
        the original query becomes the single sub-question so the SQL step still
        has something to run.
        """
        response = self.llm.invoke([
            SystemMessage(content="Break down this stock analysis question into specific sub-questions:"),
            HumanMessage(content=state['user_query'])
        ])

        matches = (self._NUMBERED_ITEM.match(line) for line in response.content.splitlines())
        questions = [m.group(1) for m in matches if m]
        if not questions:
            questions = [state['user_query']]

        return {**state, "decomposed_questions": questions}

    def _run_sql_analysis(self, state: AnalysisState) -> AnalysisState:
        """Workflow node: run the SQL agent once per sub-question.

        Results are keyed ``question_1``, ``question_2``, ... in sub-question
        order. A failure on one sub-question is recorded as
        ``{"error": ..., "question": ...}`` and does not stop the others.

        NOTE(review): ``result['output']`` is presumably only the agent's final
        answer, so the "Thought:"/"SQL:"/"SQLResult:" markers the extractors look
        for may often be absent (they then return ""). Consider
        ``return_intermediate_steps`` if the raw SQL and rows are needed.
        """
        results = {}
        for i, question in enumerate(state["decomposed_questions"], 1):
            try:
                result = self.sql_agent.invoke({"input": question})
                results[f"question_{i}"] = {
                    "question": question,
                    "thought": self._extract_thought(result['output']),
                    "sql": self._extract_sql(result['output']),
                    "result": self._extract_result(result['output'])
                }
            except Exception as e:
                # Best-effort per question: record the failure and continue.
                results[f"question_{i}"] = {
                    "error": str(e),
                    "question": question
                }

        return {**state, "sql_results": results}

    def _analyze_results(self, state: AnalysisState) -> AnalysisState:
        """Workflow node: have the LLM interpret the collected SQL results using ANALYST_PROMPT."""
        results_context = json.dumps(state["sql_results"], indent=2)
        analysis = self.llm.invoke([
            SystemMessage(content=ANALYST_PROMPT),
            HumanMessage(content=f"""
            Original Question: {state['user_query']}
            
            Analysis Results:
            {results_context}
            
            Provide a comprehensive analysis.""")
        ])

        return {**state, "analysis": analysis.content}

    def _format_output(self, state: AnalysisState) -> AnalysisState:
        """Workflow node: assemble the report dict returned by :meth:`analyze`."""
        return {
            **state,
            "final_output": {
                "user_query": state["user_query"],
                "sub_questions": state["decomposed_questions"],
                "sql_analysis": state["sql_results"],
                "expert_analysis": state["analysis"],
                "timestamp": pd.Timestamp.now().isoformat()
            }
        }

    def _extract_thought(self, text: str) -> str:
        """Return the text between the first "Thought:" and the next "SQL", or "" if absent."""
        if "Thought:" in text:
            return text.split("Thought:")[1].split("SQL")[0].strip()
        return ""

    def _extract_sql(self, text: str) -> str:
        """Return the text between the first "SQL:" and the next "Result", or "" if absent."""
        if "SQL:" in text:
            return text.split("SQL:")[1].split("Result")[0].strip()
        return ""

    def _extract_result(self, text: str) -> str:
        """Return everything after the first "SQLResult:", or "" if absent."""
        if "SQLResult:" in text:
            return text.split("SQLResult:")[1].strip()
        return ""

    def _create_workflow(self) -> StateGraph:
        """Build the linear decompose -> sql_analysis -> analyze -> format graph.

        Returns the *compiled* graph (``workflow.compile()``), which is what
        :meth:`analyze` invokes, not the ``StateGraph`` builder itself.
        """
        workflow = StateGraph(AnalysisState)

        workflow.add_node("decompose", self._decompose_question)
        workflow.add_node("sql_analysis", self._run_sql_analysis)
        workflow.add_node("analyze", self._analyze_results)
        workflow.add_node("format", self._format_output)

        workflow.add_edge(START, "decompose")
        workflow.add_edge("decompose", "sql_analysis")
        workflow.add_edge("sql_analysis", "analyze")
        workflow.add_edge("analyze", "format")
        workflow.add_edge("format", END)

        return workflow.compile()

    def analyze(self, query: str) -> Dict:
        """Run the full workflow for *query* and return its ``final_output`` report dict.

        Any exception raised while running is caught and reported as
        ``{"error": <message>, "query": query}`` instead of propagating.
        """
        initial_state = {
            "user_query": query,
            "decomposed_questions": [],
            "sql_results": {},
            "analysis": "",
            "final_output": {}
        }
        try:
            # invoke() returns the final merged state. stream() on a compiled
            # StateGraph yields per-node update chunks keyed by node name
            # (e.g. {"format": {...}}), which have no top-level "final_output".
            final_state = self.workflow.invoke(initial_state)
            return final_state["final_output"]
        except Exception as e:
            return {"error": str(e), "query": query}

def format_output(results: Dict) -> None:
    """Print a human-readable report for one successful ``StockAnalyzer.analyze`` result.

    Args:
        results: Report dict with keys ``user_query``, ``sub_questions``,
            ``sql_analysis`` and ``expert_analysis``.

    Each ``sql_analysis`` entry is printed as thought/SQL/result. An entry that
    carries an ``error`` key is printed as its error message instead. A result
    string that is JSON pandas can turn into a DataFrame is printed as a table;
    any other result is printed verbatim.
    """
    print("\n=== Stock Analysis Results ===")
    print(f"\nQuery: {results['user_query']}")

    print("\nSub-Questions:")
    for i, q in enumerate(results['sub_questions'], 1):
        print(f"{i}. {q}")

    print("\nSQL Analysis:")
    # The keys are only "question_<n>" labels; each entry repeats its question.
    for data in results['sql_analysis'].values():
        print(f"\nQuestion: {data['question']}")
        if 'error' in data:
            print(f"Error: {data['error']}")
            continue
        print(f"Thought Process: {data['thought']}")
        print(f"SQL Query: {data['sql']}")
        try:
            df = pd.DataFrame(json.loads(data['result']))
        except Exception:
            # Best-effort rendering: plain text, "", scalars, etc. are shown verbatim.
            # Catch Exception rather than using a bare except, so Ctrl-C / SystemExit
            # still propagate.
            print(f"Results: {data['result']}")
        else:
            print("\nResults:")
            print(df.to_string(index=False))

    print("\nExpert Analysis:")
    print(results['expert_analysis'])

def main():
    """Run the example queries through StockAnalyzer, print each report and save it as JSON.

    Configuration failures (ConfigError) and any other exception are reported on
    stdout rather than raised.
    """
    try:
        # Initialize analyzer
        config = Config()
        analyzer = StockAnalyzer(config)

        # Example queries
        queries = [
            "Find the most volatile trading days and analyze their patterns"
        ]

        for idx, query in enumerate(queries, 1):
            print(f"\nProcessing: {query}")
            print("=" * 50)

            results = analyzer.analyze(query)

            if "error" not in results:
                # Display formatted results
                format_output(results)

                # Save to JSON file. The timestamp has only second resolution, so
                # append the query index to keep several queries that finish within
                # the same second from overwriting each other's file.
                timestamp = pd.Timestamp.now().strftime('%Y%m%d_%H%M%S')
                filename = f"analysis_{timestamp}_{idx}.json"
                with open(filename, 'w', encoding='utf-8') as f:
                    json.dump(results, f, indent=2)
                print(f"\nDetailed results saved to {filename}")
            else:
                print(f"Error: {results['error']}")

            print("\n" + "="*50)

    except ConfigError as e:
        print(f"Configuration Error: {e}")
    except Exception as e:
        print(f"Unexpected error: {e}")

if __name__ == "__main__":
    # Script entry point. main() already reports its own errors; these handlers
    # only cover Ctrl-C and anything that still escapes main().
    try:
        main()
    except KeyboardInterrupt:
        print("\nAnalysis interrupted by user")
    except Exception as fatal:
        print(f"Fatal error: {fatal!s}")


Processing: Find the most volatile trading days and analyze their patterns


[1m> Entering new SQL Agent Executor chain...[0m
[32;1m[1;3mAction: sql_db_list_tables
Action Input: ""[0m[38;5;200m[1;3mconsumption[0m[32;1m[1;3mThought: The `consumption` table seems to contain stock data, so I should query its schema to see what columns are available.
Action: sql_db_schema
Action Input: consumption[0m[33;1m[1;3m
CREATE TABLE consumption (
	"Unnamed: 0" BIGINT, 
	"Date" TEXT, 
	"Price" FLOAT, 
	"Close" FLOAT, 
	"High" FLOAT, 
	"Low" FLOAT, 
	"Open" FLOAT, 
	"Volume" FLOAT
)

/*
3 rows from consumption table:
Unnamed: 0	Date	Price	Close	High	Low	Open	Volume
0	2023-12-18	194.9350128173828	195.88999938964844	196.6300048828125	194.38999938964844	196.08999633789065	55751900.0
1	2023-12-19	195.9798889160156	196.94000244140625	196.9499969482422	195.88999938964844	196.16000366210935	40714100.0
2	2023-12-20	193.88018798828125	194.8300018310547	197.67999267578125	194.8300018310547	196.89