# In [9]:
from langgraph.graph import StateGraph
from prophet import Prophet
from typing_extensions import TypedDict
import pandas as pd
import yfinance as yf
import matplotlib.pyplot as plt
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.prompts import ChatPromptTemplate

# Define the shared state structure
class FinancialAnalysisState(TypedDict):
    """Shared state passed between the nodes of the analysis graph."""

    ticker: str  # symbol passed to yf.Ticker
    period: str  # history window passed to Ticker.history(period=...)
    interval: str  # bar size passed to Ticker.history(interval=...)
    prompt: str  # free-form analysis request typed by the user
    # Column name -> list of values (DataFrame.to_dict(orient='list')); contains
    # at least 'ds' (timestamps) and 'y' (closing price) after data ingestion.
    stock_data: dict
    # {'ds': [...], 'yhat': [...]} for the forecast horizon, produced by
    # DataFrame.to_dict(orient='list') in the forecasting node.
    predictions: dict
    report: str  # text content of the LLM reply

# Node functions
def process_stock_data(state: FinancialAnalysisState) -> FinancialAnalysisState:
    """Download price history and normalise it into Prophet's ``ds``/``y`` layout.

    Reads ``ticker``, ``period`` and ``interval`` from *state* and fetches the
    history with yfinance. It then renames the timestamp column to ``ds``
    (tz-naive) and the closing-price column to ``y``. The resulting frame is
    stored in ``state['stock_data']`` as a dict of column lists.

    Raises:
        ValueError: if no rows were returned, or no close-price column exists.
    """
    ticker = state['ticker']

    # Fetch historical data from Yahoo Finance
    df = yf.Ticker(ticker).history(period=state['period'], interval=state['interval'])
    if df.empty:
        # Fail here with a readable message instead of deep inside Prophet.fit().
        raise ValueError(
            f"No price data returned for ticker {ticker!r} "
            f"(period={state['period']!r}, interval={state['interval']!r})."
        )
    df = df.reset_index()

    # Rename the date column to "ds"; when there is no explicit 'Date' column
    # (the index may be named differently, e.g. for intraday bars), assume the
    # first column is the timestamp.
    date_col = 'Date' if 'Date' in df.columns else df.columns[0]
    df = df.rename(columns={date_col: 'ds'})

    # Prophet requires tz-naive timestamps.
    df['ds'] = pd.to_datetime(df['ds']).dt.tz_localize(None)

    # Rename the first available price column (in priority order) to "y".
    for candidate in ('Close', 'Adj Close', 'close', 'adj close'):
        if candidate in df.columns:
            df = df.rename(columns={candidate: 'y'})
            break
    else:
        raise ValueError("Dataframe does not have a 'Close' or 'Adj Close' column.")

    state['stock_data'] = df.to_dict(orient='list')
    return state

# Node 2: Forecasting using Facebook Prophet
def predict_stock_movement(state: FinancialAnalysisState) -> FinancialAnalysisState:
    """Fit a Prophet model to the closing prices and forecast 7 future steps.

    Reads ``state['stock_data']``, which must contain ``ds`` and ``y`` columns.
    Stores ``{'ds': [...], 'yhat': [...]}`` for the forecast horizon in
    ``state['predictions']``.

    Raises:
        ValueError: if ``ds`` or ``y`` is missing from the stock data.
    """
    horizon = 7  # number of future rows to forecast

    df = pd.DataFrame(state['stock_data'])

    # Verify required columns exist
    if 'ds' not in df.columns or 'y' not in df.columns:
        raise ValueError("Dataframe must have columns 'ds' and 'y'.")

    model = Prophet(daily_seasonality=True)
    # No extra regressors are configured, so only ds/y are relevant to the fit.
    model.fit(df[['ds', 'y']])

    # include_history=False yields exactly the future rows that tail(horizon)
    # used to select, without predicting over the whole history first.
    # NOTE(review): make_future_dataframe steps in calendar days by default,
    # so weekends/holidays appear in the forecast for equity data.
    future = model.make_future_dataframe(periods=horizon, include_history=False)
    forecast = model.predict(future)

    state['predictions'] = forecast[['ds', 'yhat']].to_dict(orient='list')
    return state


def generate_report(state: FinancialAnalysisState) -> FinancialAnalysisState:
    """Ask Gemini for a written analysis of the price history and forecast.

    Builds a system message from summary statistics of the ``y`` column of
    ``state['stock_data']`` and the rows of ``state['predictions']``. The
    user's ``state['prompt']`` is sent as the user message. The reply text is
    stored in ``state['report']``.

    The API key is read from the ``GOOGLE_API_KEY`` environment variable.

    Raises:
        RuntimeError: if ``GOOGLE_API_KEY`` is not set.
    """
    import os

    user_prompt = state['prompt'] or "Please provide your analysis."
    predictions = state['predictions']
    closes = pd.DataFrame(state['stock_data'])['y']

    # One "timestamp: price" line per forecast row.
    forecast_str = "\n".join(
        f"{ds}: {price:.2f}" for ds, price in zip(predictions['ds'], predictions['yhat'])
    )

    # All dynamic values go in through format_messages(), which inserts them
    # verbatim. Braces in the user's prompt therefore can no longer be
    # misparsed as template variables, as they were when baked in via f-string.
    template = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                "You are a financial analyst. Based solely on the following historical "
                "stock data and forecast predictions, please generate a detailed report "
                "that summarizes recent market trends and provides actionable insights "
                "for investors.\n\n"
                "Historical Data Summary:\n"
                "- Last Closing Price: {last_price}\n"
                "- Average Closing Price: {avg_price}\n"
                "- Minimum Closing Price: {min_price}\n"
                "- Maximum Closing Price: {max_price}\n\n"
                "Predicted Prices for the Next {horizon} Days:\n"
                "{forecast}\n\n"
                "Please provide your analysis.",
            ),
            ("user", "{user_input}"),
        ]
    )
    # Render to concrete messages: a ChatPromptTemplate object itself is not a
    # valid input for llm.invoke().
    messages = template.format_messages(
        last_price=f"{closes.iloc[-1]:.2f}",
        avg_price=f"{closes.mean():.2f}",
        min_price=f"{closes.min():.2f}",
        max_price=f"{closes.max():.2f}",
        horizon=len(predictions['yhat']),
        forecast=forecast_str,
        user_input=user_prompt,
    )

    # Never hard-code credentials; any key previously committed in source
    # should be treated as leaked and revoked.
    api_key = os.environ.get("GOOGLE_API_KEY")
    if not api_key:
        raise RuntimeError("Set the GOOGLE_API_KEY environment variable before generating a report.")
    llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash", api_key=api_key, max_tokens=500)

    report = llm.invoke(messages)
    state['report'] = report.content
    print('done reporting')
    return state

def plot_stock_trend(state: FinancialAnalysisState) -> FinancialAnalysisState:
    """Display a line chart of the closing price (``y``) over time (``ds``).

    The state is returned unchanged; this node only renders a figure.
    """
    prices = pd.DataFrame(state['stock_data'])

    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(prices['ds'], prices['y'], label="Stock Price")
    ax.set_title(f"Stock Price Movement for {state['ticker']}")
    ax.set_xlabel("Date")
    ax.set_ylabel("Price")
    ax.legend()

    plt.show()
    return state


# Build the state graph
def build_graph():
    """Assemble the linear analysis workflow (uncompiled).

    Nodes run in order: data_ingestion -> predict_stock -> generate_report ->
    visualize_data. Call ``.compile()`` on the result before invoking it.
    """
    pipeline = [
        ("data_ingestion", process_stock_data),
        ("predict_stock", predict_stock_movement),
        ("generate_report", generate_report),
        ("visualize_data", plot_stock_trend),
    ]

    workflow = StateGraph(FinancialAnalysisState)
    for node_name, node_fn in pipeline:
        workflow.add_node(node_name, node_fn)

    # Chain each step to the one that follows it.
    for (source, _), (target, _) in zip(pipeline, pipeline[1:]):
        workflow.add_edge(source, target)

    workflow.set_entry_point(pipeline[0][0])
    workflow.set_finish_point(pipeline[-1][0])
    return workflow

if __name__ == "__main__":
    # Collect user inputs. Blank period/interval fall back to the examples
    # shown in the prompts; a ticker is mandatory.
    ticker = input("Enter the stock ticker symbol (e.g., 'AAPL'): ").strip()
    if not ticker:
        raise SystemExit("A stock ticker symbol is required.")
    period = input("Enter the period for data retrieval (e.g., '1y'): ").strip() or "1y"
    interval = input("Enter the interval for data points (e.g., '1d'): ").strip() or "1d"
    prompt = input("Enter your analysis prompt: ").strip()

    # Initialize the state with user inputs. predictions starts as an empty
    # dict because the forecasting node stores a dict of column lists there.
    initial_state = FinancialAnalysisState(
        ticker=ticker,
        period=period,
        interval=interval,
        prompt=prompt,
        stock_data={},
        predictions={},
        report="",
    )

    # Build and run the graph
    graph = build_graph()
    app = graph.compile()
    final_state = app.invoke(initial_state)

    # Output the generated report
    print("\nGenerated Report:\n", final_state['report'])


# Cell output: KeyboardInterrupt: Interrupted by user