In [25]:
import os
from typing import Dict, List, Optional
from dataclasses import dataclass
import pandas as pd
from langchain_anthropic import ChatAnthropic
from langchain.agents import create_sql_agent
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain_community.utilities.sql_database import SQLDatabase
from langchain_core.messages import HumanMessage
from dotenv import load_dotenv

# Load environment variables from the local 'api_key.env' file; the code below
# looks up ANTHROPIC_API_KEY in the environment via os.getenv.
load_dotenv('api_key.env')

class ConfigError(Exception):
    """Raised when required configuration (API key, input file) is missing."""

@dataclass
class Config:
    """Application settings: input CSV, SQLite URI and model identifier."""

    # CSV file that is loaded into the SQLite database.
    db_path: str = "apple_last_year_data.csv"
    # SQLAlchemy-style URI of the SQLite database.
    sqlite_path: str = "sqlite:///consumption.db"
    # Model name passed to ChatAnthropic.
    model_name: str = "claude-3-sonnet-20240229"

    @property
    def api_key(self) -> str:
        """Return the Anthropic API key taken from the environment.

        Raises:
            ConfigError: If ANTHROPIC_API_KEY is unset or empty.
        """
        key = os.getenv("ANTHROPIC_API_KEY")
        if key:
            return key
        raise ConfigError(
            "ANTHROPIC_API_KEY not found in api_key.env file. "
            "Please check if the file exists and contains: ANTHROPIC_API_KEY=your-key-here"
        )

class StockAnalyzer:
    """Answer natural-language questions about stock data via a SQL agent.

    On construction the CSV named by ``config.db_path`` is loaded into the
    SQLite database at ``config.sqlite_path`` (table ``consumption``), the
    language model is created, and a SQL agent is built on top of both.
    """

    # Lower-case section labels the parser recognises at the start of a line,
    # mapped to the key of the parsed result they fill.
    _SECTION_LABELS = (
        ("interpretation:", "interpretation"),
        ("thought process:", "thought_process"),
        ("thought:", "thought_process"),
        ("sql query:", "sql_query"),
        ("sqlresult:", "results"),
        ("results:", "results"),
        ("result:", "results"),
        ("analysis:", "analysis"),
    )

    def __init__(self, config: "Config"):
        """Build the database, the language model and the agent, in that order."""
        self.config = config
        self.db = self._init_database()
        self.llm = self._init_llm()
        self.agent = self._setup_agent()

    def _init_database(self) -> "SQLDatabase":
        """Load the CSV into the ``consumption`` table and open the database.

        Raises:
            ConfigError: If the CSV file named by ``config.db_path`` is missing.
        """
        try:
            df = pd.read_csv(self.config.db_path)
        except FileNotFoundError as err:
            raise ConfigError(f"CSV file not found: {self.config.db_path}") from err
        # Write to the configured URI so the table lands in the same database
        # that is opened below, even when sqlite_path is customised.
        df.to_sql('consumption', self.config.sqlite_path, index=False, if_exists='replace')
        return SQLDatabase.from_uri(self.config.sqlite_path)

    def _init_llm(self) -> "ChatAnthropic":
        """Initialize the language model with deterministic sampling."""
        return ChatAnthropic(
            model=self.config.model_name,
            temperature=0,
            api_key=self.config.api_key
        )

    def _create_prompt(self, schema_info: str) -> str:
        """Create a detailed prompt for the SQL agent.

        Args:
            schema_info: Table description that is embedded in the prompt.

        Returns:
            The prompt text, including two worked examples and the response
            format that ``_parse_response`` expects.
        """
        return f"""You are an expert financial database analyst. Your task is to:
1. Interpret user questions about stock data
2. Develop a clear thought process
3. Create and execute appropriate SQL queries
4. Present results clearly

Available Schema:
{schema_info}

Example 1:
User Question: "How did Apple perform last month?"
Thought Process: We need to analyze key performance metrics for the previous month including price changes, trading volume, and volatility.
SQL Query: 
WITH last_month AS (
    SELECT 
        ROUND(MAX(close), 2) as highest_close,
        ROUND(MIN(close), 2) as lowest_close,
        ROUND(AVG(close), 2) as avg_close,
        ROUND(((MAX(close) - MIN(close)) / MIN(close) * 100), 2) as price_range_percent,
        ROUND(AVG(volume), 0) as avg_volume
    FROM consumption 
    WHERE strftime('%Y-%m', date) = strftime('%Y-%m', 'now', 'start of month', '-1 month')
)
SELECT * FROM last_month;

Example 2:
User Question: "What were the highest volume trading days?"
Thought Process: We should identify days with exceptional trading volume and analyze the corresponding price movements.
SQL Query:
WITH avg_vol AS (
    SELECT AVG(volume) as avg_daily_volume FROM consumption
)
SELECT 
    date,
    ROUND(close, 2) as closing_price,
    volume,
    ROUND(volume / avg_daily_volume, 2) as volume_ratio
FROM consumption, avg_vol
WHERE volume > avg_daily_volume
ORDER BY volume DESC
LIMIT 5;

Your Response Format:
1. Interpretation: Explain what the user is asking for
2. Thought Process: Detail your analytical approach
3. SQL Query: Show the complete SQL query
4. Results: Present the data in a clear format
5. Analysis: Provide insights about the results

Guidelines:
- Round numeric values to 2 decimal places
- Sort time-series data chronologically
- Include relevant column headers
- Show percentage changes where appropriate
- Explain any notable patterns or anomalies
- Handle edge cases and null values appropriately"""

    def _setup_agent(self):
        """Set up the SQL agent with necessary tools and prompts."""
        toolkit = SQLDatabaseToolkit(db=self.db, llm=self.llm)
        schema_info = self.db.get_table_info()
        prompt = self._create_prompt(schema_info)
        return create_sql_agent(
            llm=self.llm,
            toolkit=toolkit,
            agent_type="zero-shot-react-description",
            verbose=True,
            prefix=prompt
        )

    def analyze(self, query: str) -> Dict:
        """Analyze a stock-related query and return detailed results.

        Args:
            query: The user's natural-language question.

        Returns:
            On success, a dict with the keys ``user_question``,
            ``interpretation``, ``thought_process``, ``sql_query``,
            ``results`` and ``analysis``. If anything raises, a dict with
            ``error`` (the exception text) and ``user_question`` instead.
        """
        try:
            result = self.agent.invoke({
                "input": query
            })
            components = self._parse_response(result['output'])

            return {
                "user_question": query,
                "interpretation": components.get("interpretation", ""),
                "thought_process": components.get("thought_process", ""),
                "sql_query": components.get("sql_query", ""),
                "results": components.get("results", []),
                "analysis": components.get("analysis", "")
            }
        except Exception as e:
            # Best-effort boundary: callers check for the "error" key.
            return {
                "error": str(e),
                "user_question": query
            }

    def _parse_response(self, response: str) -> Dict:
        """Parse the agent's response into structured components.

        A line opens a section when, after any leading list numbering or
        markdown markers, it starts with one of the labels in
        ``_SECTION_LABELS`` (case-insensitive). Following lines belong to
        that section until the next header. Text that appears before any
        header is kept under ``analysis`` so a free-form answer is not lost.
        Blank lines and markdown code-fence lines are skipped.

        Args:
            response: The agent's final output text.

        Returns:
            A dict with the string fields ``interpretation``,
            ``thought_process``, ``sql_query`` and ``analysis``, plus
            ``results``, which is a list of row dicts when the results text
            is tabular JSON, the raw text otherwise, or ``[]`` when empty.
        """
        text = response if isinstance(response, str) else str(response)
        collected = {
            "interpretation": [],
            "thought_process": [],
            "sql_query": [],
            "results": [],
            "analysis": [],
        }
        current = "analysis"

        for raw_line in text.splitlines():
            line = raw_line.strip()
            if not line or line.startswith("```"):
                continue
            section, content = self._match_section_header(line)
            if section is not None:
                current = section
                line = content
            if line:
                collected[current].append(line)

        parsed = {key: "\n".join(lines) for key, lines in collected.items()}
        parsed["results"] = self._coerce_results(parsed["results"])
        return parsed

    @classmethod
    def _match_section_header(cls, line: str) -> tuple:
        """Return ``(section_key, rest_of_line)`` or ``(None, line)``."""
        # Ignore decorations such as "1. ", "- " or "**" before the label.
        candidate = line.lstrip("0123456789.)-*# \t")
        for label, section in cls._SECTION_LABELS:
            if candidate[:len(label)].lower() == label:
                return section, candidate[len(label):].lstrip("* \t").rstrip()
        return None, line

    @staticmethod
    def _coerce_results(text: str):
        """Turn results text into row dicts when it is tabular JSON."""
        if not text:
            return []
        if text[0] in "[{":
            from io import StringIO

            try:
                # read_json expects a file-like object; literal strings are deprecated.
                return pd.read_json(StringIO(text)).to_dict('records')
            except (ValueError, TypeError):
                # Not tabular JSON (e.g. a Python repr of tuples): keep the text.
                pass
        return text

def test_connection() -> bool:
    """Report whether ANTHROPIC_API_KEY is available in the environment.

    Prints a status message to stdout. No part of the key is echoed, so the
    secret cannot end up in terminal logs or saved notebook output; only its
    length is shown as a sanity check.

    Returns:
        True if the variable is set and non-empty, False otherwise.
    """
    api_key = os.getenv("ANTHROPIC_API_KEY")
    if not api_key:
        print("✗ API key not found in api_key.env")
        print("Please ensure your api_key.env file contains: ANTHROPIC_API_KEY=your-key-here")
        return False

    print("✓ API key loaded successfully!")
    print(f"✓ Key length: {len(api_key)} characters (value hidden)")
    return True

def main():
    """Run the demo: verify the API key, build the analyzer, print each answer.

    Configuration problems and any other exception are reported on stdout
    rather than propagated.
    """
    try:
        # Nothing to do without an API key.
        if not test_connection():
            return

        analyzer = StockAnalyzer(Config())

        for question in [
            "Show me the last 10 days of stock prices",
            "What's the average trading volume this month?",
            "Find days with unusual price movements"
        ]:
            print(f"\nAnalyzing: {question}")
            print("-" * 50)

            outcome = analyzer.analyze(question)

            if "error" in outcome:
                print(f"Error: {outcome['error']}")
                print("\n" + "="*50)
                continue

            print("Question:", outcome["user_question"])
            print("\nInterpretation:", outcome["interpretation"])
            print("\nThought Process:", outcome["thought_process"])
            print("\nSQL Query:", outcome["sql_query"])
            print("\nResults:")
            rows = outcome["results"]
            if isinstance(rows, list):
                # Row records are rendered as a table without the index column.
                table = pd.DataFrame(rows)
                print(table.to_string(index=False))
            else:
                print(rows)
            print("\nAnalysis:", outcome["analysis"])

            print("\n" + "="*50)

    except ConfigError as e:
        print(f"Configuration Error: {e}")
    except Exception as e:
        print(f"Unexpected error: {e}")

# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()

✓ API key loaded successfully!
✓ First few characters: sk-ant-a...

Analyzing: Show me the last 10 days of stock prices
--------------------------------------------------


[1m> Entering new SQL Agent Executor chain...[0m
[32;1m[1;3mThought: I should look at the tables in the database to see what I can query. Then I should query the schema of the most relevant tables.
Action: sql_db_list_tables
Action Input: 
[0m[38;5;200m[1;3mconsumption[0m[32;1m[1;3mThought: The consumption table seems to contain stock price data, so I should query its schema to see the relevant columns.
Action: sql_db_schema
Action Input: consumption
[0m[33;1m[1;3m
CREATE TABLE consumption (
	"Unnamed: 0" BIGINT, 
	"Date" TEXT, 
	"Price" FLOAT, 
	"Close" FLOAT, 
	"High" FLOAT, 
	"Low" FLOAT, 
	"Open" FLOAT, 
	"Volume" FLOAT
)

/*
3 rows from consumption table:
Unnamed: 0	Date	Price	Close	High	Low	Open	Volume
0	2023-12-18	194.9350128173828	195.88999938964844	196.6300048828125	194.38999938964844	196.08999