In [16]:
# Part 1: Imports and Basic Setup
# Import required libraries for data processing, database operations, language models and environment variables
import os
from typing import Dict, List, Optional, TypedDict, Literal, Union
from dataclasses import dataclass
from enum import Enum
import pandas as pd
import json
from langchain_anthropic import ChatAnthropic
from langchain.agents import create_sql_agent
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain_community.utilities.sql_database import SQLDatabase
from langchain_core.messages import SystemMessage, HumanMessage
from langgraph.graph import StateGraph, START, END
from dotenv import load_dotenv
import time
from anthropic import Anthropic, HUMAN_PROMPT, AI_PROMPT
import sqlite3
import re

# Load API keys from environment file
load_dotenv('api_key.env')

# Part 2: Type Definitions and Base Classes
# Define types and base classes for query handling and state management

# Enum for different types of queries that can be processed
class QueryType(Enum):
    """Processing routes a user question can be assigned to."""

    # Simple data retrieval: handled by a single SQL-agent call
    DIRECT_SQL = "direct_sql"
    # Complex analysis: handled by the multi-step workflow
    ANALYSIS = "analysis"

# Class to store query classification results
@dataclass
class QueryClassification:
    """Outcome of classifying one user question."""

    # Route chosen for the question (direct SQL vs. full analysis)
    type: QueryType
    # Short justification for the chosen route
    explanation: str
    # Raw classifier reply, or the error text when classification failed
    raw_response: str

# TypedDict to maintain state during analysis workflow
class AnalysisState(TypedDict):
    """State dictionary handed from node to node in the analysis workflow."""

    # The question exactly as the caller passed it in
    user_query: str
    # Classification details (type, explanation, raw response)
    query_classification: Dict
    # Sub-questions produced by the decomposition step
    decomposed_questions: List[str]
    # Per-sub-question SQL output, keyed "question_<n>"
    sql_results: Dict
    # Narrative analysis written by the analyst step
    analysis: str
    # Assembled result that is returned to the caller and saved to disk
    final_output: Dict
    # Prompt / completion token counters
    token_usage: Dict
    # Wall-clock duration of the run, in seconds
    processing_time: float
    # Intermediate states and raw outputs of each step
    agent_states: Dict
    # Raw LLM responses
    raw_responses: Dict

# Custom exception for configuration related errors
class ConfigError(Exception):
    """Raised when required configuration is missing or setup fails."""

# Configuration class to store and validate settings
@dataclass
class Config:
    """Settings for the analyzer: data locations and model choice.

    The API key is deliberately not a stored field; it is read from the
    environment each time ``api_key`` is accessed.
    """

    # CSV file that is loaded into the SQLite database
    db_path: str = "apple_last_year_data.csv"
    # SQLAlchemy-style URI of the SQLite database
    sqlite_path: str = "sqlite:///consumption.db"
    # Anthropic model identifier
    model_name: str = "claude-3-sonnet-20240229"

    @property
    def api_key(self) -> str:
        """Return ANTHROPIC_API_KEY from the environment.

        Raises:
            ConfigError: if the variable is unset or empty.
        """
        key = os.getenv("ANTHROPIC_API_KEY")
        if key:
            return key
        raise ConfigError("ANTHROPIC_API_KEY not found in api_key.env file")

# Part 3: Prompt Templates
# Define system prompts for different stages of analysis

# System prompt used by StockAnalyzer._classify_query to route a question.
# The model's reply is parsed as JSON and its "type" value is converted to a
# QueryType, so the two labels named here must stay in sync with the enum
# values ("direct_sql" / "analysis").
QUERY_CLASSIFIER_PROMPT = """You are a query classifier that determines if a stock market question needs complex analysis or can be answered with a direct SQL query.

Example 1:
Question: "Show me the last 5 days of stock prices"
Classification: direct_sql
Explanation: This is a straightforward data retrieval request.

Example 2:
Question: "What are the emerging trends in trading volume and their impact on price?"
Classification: analysis
Explanation: This requires complex analysis of relationships and patterns.

Respond in JSON format:
{
    "type": "direct_sql" or "analysis",
    "explanation": "brief explanation of classification"
}
"""

# Prompt handed to create_sql_agent as its `prefix` in
# StockAnalyzer._setup_sql_agent. The "Thought:" and "SQL:" labels used in the
# examples are the same markers that _extract_thought and _extract_sql search
# for in the agent's output, so keep them unchanged if the examples are edited.
SQL_AGENT_PROMPT = """You are an expert financial database analyst. Your task is to:
1. Analyze stock market queries
2. Create appropriate SQL queries
3. Provide clear results

Example 1:
User: "What's the stock's performance last week?"
Thought: Need to analyze daily price changes and volume for the past week
SQL:
SELECT 
    date,
    ROUND(open, 2) as open_price,
    ROUND(close, 2) as close_price,
    ROUND(((close - open) / open * 100), 2) as daily_return,
    ROUND(high, 2) as high,
    ROUND(low, 2) as low,
    volume
FROM consumption
WHERE date >= date('now', '-7 days')
ORDER BY date DESC;

Example 2:
User: "Find volatile trading days"
Thought: Looking for days with large price ranges and high volume
SQL:
WITH metrics AS (
    SELECT AVG(volume) as avg_vol,
           AVG((high - low) / open * 100) as avg_range
    FROM consumption
)
SELECT 
    date,
    ROUND(open, 2) as open_price,
    ROUND(close, 2) as close_price,
    ROUND(((high - low) / open * 100), 2) as price_range_pct,
    volume,
    ROUND(volume / avg_vol, 2) as vol_ratio
FROM consumption, metrics
WHERE (high - low) / open * 100 > avg_range
AND volume > avg_vol
ORDER BY price_range_pct DESC
LIMIT 5;

Your responses should include:
1. Thought process
2. SQL query
3. Result interpretation"""

# System prompt used by StockAnalyzer._analyze_results. It is sent together
# with the user's original question and the JSON-encoded SQL results, and the
# model's reply becomes the "expert_analysis" section of the final output.
ANALYST_PROMPT = """You are an expert financial analyst. Analyze the provided SQL results and provide insights.

Focus on:
1. Price trends and patterns
2. Volume analysis
3. Technical indicators
4. Risk assessment
5. Notable patterns

Example Analysis Structure:
1. Key Findings
   - Main price trends
   - Volume patterns
   - Notable events

2. Technical Analysis
   - Support/resistance levels
   - Pattern recognition
   - Momentum indicators

3. Risk Assessment
   - Volatility measures
   - Liquidity analysis
   - Risk factors

4. Recommendations
   - Key levels to watch
   - Risk considerations
   - Potential scenarios

Be specific and data-driven in your analysis."""

# Part 4: Main StockAnalyzer Class
# Core class that orchestrates the entire analysis process
class StockAnalyzer:
    def __init__(self, config: Config):
        """Wire up the database, model, SQL agent and workflow from ``config``.

        Args:
            config: Settings object providing the CSV path, SQLite URI,
                model name and API key.
        """
        self.config = config

        # Per-run bookkeeping; analyze() resets all three on every call
        self.token_usage = {"prompt_tokens": 0, "completion_tokens": 0}
        self.agent_states = {}
        self.raw_responses = {}

        # Build order matters: the agent needs both the database and the model
        self.db = self._init_database()
        self.llm = self._init_llm()
        self.sql_agent = self._setup_sql_agent()
        self.workflow = self._create_workflow()

        # Direct SDK client, used by _update_token_usage as a fallback
        self.anthropic_client = Anthropic(api_key=config.api_key)

        # Raw connection used to run the SQL extracted from agent output.
        # NOTE: this file name is hard-coded and does not follow
        # config.sqlite_path.
        self.conn = sqlite3.connect('consumption.db')

    def _init_database(self) -> SQLDatabase:
        """Load the configured CSV into SQLite and return a handle to it.

        The CSV at ``config.db_path`` is written to a table named
        ``consumption`` (replacing any existing table) in the database at
        ``config.sqlite_path``. The same URI is used for writing and for the
        returned handle, so the SQL agent always sees the table just loaded.

        Returns:
            An ``SQLDatabase`` opened on ``config.sqlite_path``.

        Raises:
            ConfigError: if the CSV cannot be read or the database cannot be
                written or opened; the original exception is chained.
        """
        try:
            frame = pd.read_csv(self.config.db_path)
            frame.to_sql(
                'consumption',
                self.config.sqlite_path,
                index=False,
                if_exists='replace'
            )
            return SQLDatabase.from_uri(self.config.sqlite_path)
        except Exception as e:
            raise ConfigError(f"Database initialization failed: {str(e)}") from e

    def _init_llm(self) -> ChatAnthropic:
        """Create the chat model described by the configuration.

        Temperature is fixed at 0.
        """
        model_settings = {
            "model": self.config.model_name,
            "temperature": 0,
            "api_key": self.config.api_key,
        }
        return ChatAnthropic(**model_settings)

    def _setup_sql_agent(self):
        """Build the SQL agent from the database handle and the chat model.

        SQL_AGENT_PROMPT is supplied as the agent's prompt prefix and verbose
        logging is switched on.
        """
        sql_toolkit = SQLDatabaseToolkit(db=self.db, llm=self.llm)
        agent_options = {
            "llm": self.llm,
            "toolkit": sql_toolkit,
            "agent_type": "zero-shot-react-description",
            "verbose": True,
            "prefix": SQL_AGENT_PROMPT,
        }
        return create_sql_agent(**agent_options)

    def _update_token_usage(self, response):
        """Track token usage from model responses"""
        try:
            if hasattr(response, '_raw_response') and 'usage' in response._raw_response:
                usage = response._raw_response['usage']
                self.token_usage["prompt_tokens"] += usage.get('input_tokens', 0)
                self.token_usage["completion_tokens"] += usage.get('output_tokens', 0)
            elif isinstance(response, dict) and 'usage' in response:
                usage = response['usage']
                self.token_usage["prompt_tokens"] += usage.get('input_tokens', 0)
                self.token_usage["completion_tokens"] += usage.get('output_tokens', 0)
            elif hasattr(response, 'usage'):
                usage = response.usage
                self.token_usage["prompt_tokens"] += usage.input_tokens if hasattr(usage, 'input_tokens') else 0
                self.token_usage["completion_tokens"] += usage.output_tokens if hasattr(usage, 'output_tokens') else 0
            else:
                # Make a direct API call to get token count
                message = response.content if hasattr(response, 'content') else str(response)
                result = self.anthropic_client.messages.create(
                    model=self.config.model_name,
                    messages=[{"role": "user", "content": message}],
                    max_tokens=1
                )
                if hasattr(result, 'usage'):
                    self.token_usage["prompt_tokens"] += result.usage.input_tokens
                    self.token_usage["completion_tokens"] += result.usage.output_tokens
        except Exception as e:
            print(f"Error updating token usage: {str(e)}")

    def _classify_query(self, query: str) -> QueryClassification:
        """Decide whether ``query`` needs the full workflow or a direct SQL run.

        The classifier model is asked for a JSON object with ``type`` and
        ``explanation`` keys. The JSON is accepted even when the model wraps it
        in a Markdown fence or surrounds it with extra text.

        Args:
            query: The user's question.

        Returns:
            The parsed classification. Any failure (model call, JSON parsing,
            unknown type) yields a classification of ``QueryType.ANALYSIS``
            whose ``raw_response`` is the error text.
        """
        try:
            response = self.llm.invoke([
                SystemMessage(content=QUERY_CLASSIFIER_PROMPT),
                HumanMessage(content=f"Classify this question: {query}")
            ])

            self._update_token_usage(response)
            raw_text = response.content

            # Record the reply before parsing so it is kept even when the
            # parse below fails
            self.raw_responses['classification'] = raw_text

            # Pull out the outermost {...} span; fall back to the whole reply
            json_span = re.search(r'\{.*\}', raw_text, re.DOTALL)
            classification = json.loads(json_span.group(0) if json_span else raw_text)

            return QueryClassification(
                type=QueryType(classification['type']),
                explanation=classification['explanation'],
                raw_response=raw_text
            )
        except Exception as e:
            return QueryClassification(
                type=QueryType.ANALYSIS,
                explanation="Classification failed, defaulting to analysis",
                raw_response=str(e)
            )

    def _direct_sql_query(self, query: str) -> Dict:
        """Process simple queries that only need SQL execution"""
        start_time = time.time()
        try:
            result = self.sql_agent.invoke({"input": query})
            self._update_token_usage(result)
            
            # Store agent state
            self.agent_states['direct_sql'] = result
            
            thought = self._extract_thought(result['output'])
            sql = self._extract_sql(result['output'])
            
            # Execute SQL query and get results
            try:
                # Extract SQL query from agent output
                sql_query = sql if sql else result['output'].split('SELECT')[1].split('Final Answer')[0].strip()
                if not sql_query.lower().startswith('select'):
                    sql_query = 'SELECT ' + sql_query
                
                # Clean up SQL query by removing any text after the first semicolon
                sql_query = sql_query.split(';')[0] + ';'
                
                # Execute query and get results
                df = pd.read_sql_query(sql_query, self.conn)
                formatted_results = df.to_dict('records')
            except Exception as e:
                formatted_results = f"Error executing SQL: {str(e)}"
            
            processing_time = time.time() - start_time
            
            output_data = {
                "query_type": "direct_sql",
                "user_query": query,
                "thought_process": thought if thought else "No thought process provided",
                "sql_query": sql if sql else sql_query,
                "results": formatted_results,
                "raw_agent_output": result['output'],
                "timestamp": pd.Timestamp.now().isoformat(),
                "token_usage": self.token_usage,
                "processing_time": processing_time,
                "agent_state": result
            }
            
            # Save to query-specific output file
            filename = f"{query[:50].replace(' ', '_').lower()}_analysis.json"
            with open(filename, 'w') as f:
                json.dump(output_data, f, indent=2)
                
            return output_data
            
        except Exception as e:
            return {"error": str(e), "query": query}

    def _decompose_question(self, state: Dict) -> Dict:
        """Split the user's question into sub-questions answerable with SQL.

        The model's reply is read line by line; every line whose first
        non-blank character is a digit is treated as a list item, and a leading
        ``"1. "`` / ``"1) "`` style number is removed. If the reply contains no
        such line, the original question is used as the only sub-question so
        the later steps still have something to run.

        Args:
            state: Workflow state; reads ``user_query`` and writes
                ``decomposed_questions`` and ``agent_states['decomposition']``.

        Returns:
            The same ``state`` object, updated in place.
        """
        response = self.llm.invoke([
            SystemMessage(content="Break down this stock analysis question into specific sub-questions that can be answered with SQL queries:"),
            HumanMessage(content=state['user_query'])
        ])

        self._update_token_usage(response)

        # Store raw response
        self.raw_responses['decomposition'] = response.content

        list_number = re.compile(r'^\d+[.)]\s+')
        questions = []
        for raw_line in response.content.split("\n"):
            # Strip first so indented list items are recognised as well
            line = raw_line.strip()
            if not line or not line[0].isdigit():
                continue
            item = list_number.sub('', line, count=1).strip()
            if item:
                questions.append(item)

        if not questions:
            # Nothing parseable came back: run the question as a single item
            questions = [state['user_query']]

        state['decomposed_questions'] = questions
        state['agent_states']['decomposition'] = {
            'raw_response': response.content,
            'parsed_questions': questions
        }
        return state

    def _run_sql_analysis(self, state: Dict) -> Dict:
        """Execute SQL queries for each sub-question"""
        results = {}
        agent_states = {}
        
        for i, question in enumerate(state["decomposed_questions"], 1):
            try:
                result = self.sql_agent.invoke({"input": question})
                self._update_token_usage(result)
                
                # Store agent state
                agent_states[f"question_{i}"] = result
                
                thought = self._extract_thought(result['output'])
                sql = self._extract_sql(result['output'])
                
                # Execute SQL query and get results
                try:
                    # Clean up SQL query by removing any text after the first semicolon
                    sql = sql.split(';')[0] + ';'
                    
                    df = pd.read_sql_query(sql, self.conn)
                    parsed_result = df.to_dict('records')
                except Exception as e:
                    parsed_result = f"Error executing SQL: {str(e)}"
                
                results[f"question_{i}"] = {
                    "question": question,
                    "thought": thought if thought else "No thought process provided",
                    "sql": sql if sql else "No SQL query provided",
                    "result": parsed_result,
                    "raw_output": result['output']
                }
                    
            except Exception as e:
                results[f"question_{i}"] = {
                    "error": str(e),
                    "question": question
                }
        
        state['sql_results'] = results
        state['agent_states']['sql_analysis'] = agent_states
        return state

    def _analyze_results(self, state: Dict) -> Dict:
        """Ask the analyst model to interpret the collected SQL results.

        Args:
            state: Workflow state; reads ``user_query`` and ``sql_results``
                and writes ``analysis`` and ``agent_states['analysis']``.

        Returns:
            The same ``state`` object, updated in place.
        """
        results_context = json.dumps(state["sql_results"], indent=2)

        # Build the request text first, then the two-message conversation
        analyst_request = f"""
            Original Question: {state['user_query']}
            
            Analysis Results:
            {results_context}
            
            Provide a comprehensive analysis."""
        conversation = [
            SystemMessage(content=ANALYST_PROMPT),
            HumanMessage(content=analyst_request),
        ]
        response = self.llm.invoke(conversation)

        self._update_token_usage(response)

        # The reply is kept in three places: raw-response log, state, and the
        # per-step record
        expert_text = response.content
        self.raw_responses['analysis'] = expert_text
        state['analysis'] = expert_text
        state['agent_states']['analysis'] = {'raw_response': expert_text}

        return state

    def _format_output(self, state: Dict) -> Dict:
        """Format final analysis results"""
        state['final_output'] = {
            "query_type": "analysis",
            "user_query": state["user_query"],
            "query_classification": state.get("query_classification", {}),
            "sub_questions": state["decomposed_questions"],
            "sql_analysis": state["sql_results"],
            "expert_analysis": state["analysis"],
            "timestamp": pd.Timestamp.now().isoformat(),
            "token_usage": self.token_usage,
            "processing_time": state.get("processing_time", 0),
            "agent_states": self.agent_states,
            "raw_responses": self.raw_responses
        }
        
        # Save complete analysis output with query-specific filename
        filename = f"{state['user_query'][:50].replace(' ', '_').lower()}_analysis.json"
        with open(filename, 'w') as f:
            json.dump(state['final_output'], f, indent=2)
            
        return state

    def _extract_thought(self, text: str) -> str:
        """Extract thought process from agent response"""
        if "Thought:" in text:
            return text.split("Thought:")[1].split("SQL")[0].strip()
        return ""

    def _extract_sql(self, text: str) -> str:
        """Extract SQL query from agent response"""
        if "SQL:" in text:
            sql_part = text.split("SQL:")[1]
            if "SQLResult:" in sql_part:
                return sql_part.split("SQLResult:")[0].strip()
            if "Final Answer:" in sql_part:
                return sql_part.split("Final Answer:")[0].strip()
            return sql_part.strip()
        return ""

    def _extract_result(self, text: str) -> str:
        """Extract results from agent response"""
        if "SQLResult:" in text:
            return text.split("SQLResult:")[1].strip()
        return ""

    def _create_workflow(self) -> StateGraph:
        """Build and compile the linear analysis graph.

        The steps run in a fixed order:
        decompose -> sql_analysis -> analyze -> format.
        """
        graph = StateGraph(AnalysisState)

        # Step name -> handler, in execution order
        steps = [
            ("decompose", self._decompose_question),
            ("sql_analysis", self._run_sql_analysis),
            ("analyze", self._analyze_results),
            ("format", self._format_output),
        ]
        for step_name, handler in steps:
            graph.add_node(step_name, handler)

        # Connect each stop on the route to the one after it
        route = [START] + [step_name for step_name, _ in steps] + [END]
        for source, target in zip(route, route[1:]):
            graph.add_edge(source, target)

        return graph.compile()

    def analyze(self, query: str) -> Dict:
        """Main method to process queries and generate analysis"""
        start_time = time.time()
        try:
            # Reset storages for new analysis
            self.token_usage = {"prompt_tokens": 0, "completion_tokens": 0}
            self.agent_states = {}
            self.raw_responses = {}
            
            # First, classify the query
            classification = self._classify_query(query)
            
            # For direct SQL queries, use simplified processing
            if classification.type == QueryType.DIRECT_SQL:
                return self._direct_sql_query(query)
            
            # For analysis queries, use the full workflow
            initial_state = {
                "user_query": query,
                "query_classification": {
                    "type": classification.type.value,
                    "explanation": classification.explanation,
                    "raw_response": classification.raw_response
                },
                "decomposed_questions": [],
                "sql_results": {},
                "analysis": "",
                "final_output": {},
                "token_usage": self.token_usage,
                "processing_time": 0,
                "agent_states": {},
                "raw_responses": {}
            }
            
            final_state = self.workflow.invoke(initial_state)
            
            processing_time = time.time() - start_time
            final_state["processing_time"] = processing_time
            final_state["token_usage"] = self.token_usage
            
            # Update final output
            final_state["final_output"]["processing_time"] = processing_time
            final_state["final_output"]["token_usage"] = self.token_usage
            final_state["final_output"]["agent_states"] = self.agent_states
            final_state["final_output"]["raw_responses"] = self.raw_responses
            
            return final_state["final_output"]
            
        except Exception as e:
            return {"error": str(e), "query": query}
        finally:
            self.conn.close()

# Part 5: Helper Functions
def format_output(results: Dict) -> str:
    """Render an analysis result dict as human-readable text.

    The report always starts with the query, processing time and token usage.
    It then shows the error, the direct-SQL section, or the analysis sections
    (sub-questions, per-question SQL output, expert analysis), depending on
    what ``results`` contains.

    Args:
        results: Dict returned by ``StockAnalyzer.analyze``.

    Returns:
        The formatted multi-line report.
    """
    token_usage = results.get('token_usage', {})
    prompt_tokens = token_usage.get('prompt_tokens', 0)
    completion_tokens = token_usage.get('completion_tokens', 0)

    output = [
        "=== Stock Analysis Results ===",
        f"\nQuery: {results.get('user_query', 'N/A')}",
        # Performance metrics
        f"\nProcessing Time: {results.get('processing_time', 0):.2f} seconds",
        "Token Usage:",
        f"  Prompt Tokens: {prompt_tokens}",
        f"  Completion Tokens: {completion_tokens}",
        f"  Total Tokens: {prompt_tokens + completion_tokens}",
    ]

    if "error" in results:
        output.append(f"\nError: {results['error']}")
        return "\n".join(output)

    if results.get('query_type') == 'direct_sql':
        output.append(f"\nThought Process: {results.get('thought_process', 'N/A')}")
        output.append(f"\nSQL Query: {results.get('sql_query', 'N/A')}")
        output.append("\nResults:")
        if isinstance(results.get('results'), list):
            output.append(str(pd.DataFrame(results['results'])))
        else:
            output.append(str(results.get('results', 'No results available')))
    else:
        output.append("\nSub-Questions:")
        for i, q in enumerate(results.get('sub_questions', []), 1):
            output.append(f"{i}. {q}")

        output.append("\nSQL Analysis:")
        for data in results.get('sql_analysis', {}).values():
            output.append(f"\nQuestion: {data.get('question', 'N/A')}")
            if 'error' in data:
                output.append(f"Error: {data['error']}")
                continue

            output.append(f"Thought Process: {data.get('thought', 'N/A')}")
            output.append(f"SQL Query: {data.get('sql', 'N/A')}")
            plain_result = f"Results: {data.get('result', 'No results available')}"
            try:
                if isinstance(data.get('result'), (list, dict)):
                    output.append(str(pd.DataFrame(data['result'])))
                else:
                    output.append(plain_result)
            except Exception:
                # Rows that cannot be tabulated are shown as plain text
                output.append(plain_result)

        output.append("\nExpert Analysis:")
        output.append(results.get('expert_analysis', 'No analysis available'))

    return "\n".join(output)

def analyze_stock_query(query: str) -> str:
    """Run one stock question end to end and return a printable report.

    A fresh ``StockAnalyzer`` with default ``Config`` is built for every call.
    Failures are returned as ``"Error..."`` strings rather than raised.
    """
    try:
        analyzer = StockAnalyzer(Config())
        results = analyzer.analyze(query)

        # Empty or failed results are reported as an error line
        if not results or "error" in results:
            return f"Error: {results.get('error', 'Unknown error occurred')}"

        report = format_output(results)
        filename = f"{query[:50].replace(' ', '_').lower()}_analysis.json"
        return report + f"\n\nDetailed results saved to {filename}"
    except Exception as e:
        return f"Error during analysis: {str(e)}"

# Part 6: Main Execution
if __name__ == "__main__":
    # Example queries for testing.
    # NOTE: `test_queries` is a module-level name; the display loop at the end
    # of this file iterates over it too, so keep the name unchanged.
    test_queries = [
        "Show me the last 5 days of stock prices"
    ]
    
    # Run each query through the full pipeline and print its report between
    # separator lines
    for query in test_queries:
        print(f"\nProcessing: {query}")
        print("=" * 50)
        result = analyze_stock_query(query)
        print(result)
        print("\n" + "="*50)


Processing: Show me the last 5 days of stock prices


[1m> Entering new SQL Agent Executor chain...[0m
[32;1m[1;3mAction: sql_db_list_tables
Action Input: ""[0m[38;5;200m[1;3mconsumption[0m[32;1m[1;3mThought: The 'consumption' table seems relevant for stock price data. Let me check the schema for that table.
Action: sql_db_schema
Action Input: consumption[0m[33;1m[1;3m
CREATE TABLE consumption (
	"Unnamed: 0" BIGINT, 
	"Date" TEXT, 
	"Price" FLOAT, 
	"Close" FLOAT, 
	"High" FLOAT, 
	"Low" FLOAT, 
	"Open" FLOAT, 
	"Volume" FLOAT
)

/*
3 rows from consumption table:
Unnamed: 0	Date	Price	Close	High	Low	Open	Volume
0	2023-12-18	194.9350128173828	195.88999938964844	196.6300048828125	194.38999938964844	196.08999633789065	55751900.0
1	2023-12-19	195.9798889160156	196.94000244140625	196.9499969482422	195.88999938964844	196.16000366210935	40714100.0
2	2023-12-20	193.88018798828125	194.8300018310547	197.67999267578125	194.8300018310547	196.8999938964844	52242800.0
*/[0m[32;1m[1

In [17]:
def display_json_details(query: str) -> None:
    """Print the saved JSON analysis for ``query`` in an indented layout.

    The file name is derived from the query the same way the analyzer derives
    it when saving: first 50 characters, spaces replaced by underscores,
    lower-cased, plus ``_analysis.json``.

    Problems (missing file, invalid JSON, anything else) are reported on
    stdout; nothing is raised.

    Args:
        query: The question whose saved analysis should be shown.
    """
    try:
        # Generate filename from query
        filename = f"{query[:50].replace(' ', '_').lower()}_analysis.json"

        # Read and parse JSON file
        with open(filename, 'r') as f:
            data = json.load(f)

        print("\n=== DETAILED ANALYSIS REPORT ===\n")

        def format_value(value, indent=0):
            """Recursively print a JSON value with proper indentation"""
            indent_str = "    " * indent

            if isinstance(value, dict):
                print()
                for k, v in value.items():
                    key_str = k.replace('_', ' ').title()
                    print(f"{indent_str}{key_str}:")
                    format_value(v, indent + 1)

            elif isinstance(value, list):
                print()
                for item in value:
                    print(f"{indent_str}•", end=' ')
                    format_value(item, indent + 1)

            # bool must be tested before int/float: bool is a subclass of int,
            # so the numeric branch would otherwise print True/False as 1/0
            elif isinstance(value, bool):
                print(str(value))

            elif isinstance(value, (int, float)):
                print(f"{value:,}")

            elif value is None:
                print("None")

            else:
                # Handle string values
                print(str(value).strip())

        # Process each top-level key
        for key, value in data.items():
            print(f"\n{key.replace('_', ' ').title()}:", end='')
            format_value(value)

        print("\n" + "="*50 + "\n")

    except FileNotFoundError:
        print(f"\nError: Analysis file '{filename}' not found\n")

    except json.JSONDecodeError:
        print(f"\nError: Unable to parse JSON from '{filename}'\n")

    except Exception as e:
        print(f"\nError displaying JSON details: {str(e)}\n")


# Example usage: show the saved JSON report for every query that was run.
# NOTE: `test_queries` is not defined in this cell; it is the list created in
# the main-execution block earlier in the file, which must have run first.
if __name__ == "__main__":
    for query in test_queries:
        print(f"\nDisplaying detailed analysis for: {query}")
        display_json_details(query)



Displaying detailed analysis for: Show me the last 5 days of stock prices

=== DETAILED ANALYSIS REPORT ===


Query Type:direct_sql

User Query:Show me the last 5 days of stock prices

Thought Process:No thought process provided

Sql Query:No SQL query provided

Results:No results available

Raw Agent Output:The query provided the last 5 days of stock prices, including the open, close, high, low and volume for each day. The most recent trading day was 2024-12-16, where the stock opened at $247.99, reached a high of $251.38, low of $247.65, and closed at $251.04 with volume of 51,665,600 shares traded. The previous 4 days of price data were also returned in descending date order.

Timestamp:2024-12-22T19:38:26.497416

Token Usage:
Prompt Tokens:
195
Completion Tokens:
2

Processing Time:16.305702924728394

Agent State:
Input:
Show me the last 5 days of stock prices
Output:
The query provided the last 5 days of stock prices, including the open, close, high, low and volume for each day. 