<a href="https://colab.research.google.com/gist/virattt/0e4c7740472177a327b61449c9af721d/hedge-fund-agent-team-v1-3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This notebook is a tutorial on building a multi-agent system with LangGraph.

Specifically, we use the **supervisor** pattern, where we have 1 supervisor agent and 3 analyst agents:
1. fundamental analyst
2. technical analyst
3. sentiment analyst

This code will be a part of an evolving series.

If you have any questions, please message me on X at [virattt](https://twitter.com/virattt).

# Setup

In [None]:
%%capture --no-stderr
%pip install -U langgraph langchain langchain_openai langchain_experimental langsmith pandas ta

In [None]:
import getpass
import os


def _set_if_undefined(var: str):
    """Prompt for environment variable ``var`` unless it already has a non-empty value.

    The entered value is read with ``getpass.getpass`` and stored in
    ``os.environ[var]``.
    """
    already_set = bool(os.environ.get(var))
    if already_set:
        return
    os.environ[var] = getpass.getpass(f"Please provide your {var}")


# Credentials used by this notebook, prompted for in this order when missing:
#   OPENAI_API_KEY              - for the agent. Get from https://platform.openai.com
#   FINANCIAL_DATASETS_API_KEY  - for getting financial data. Get from https://financialdatasets.ai
#   TAVILY_API_KEY              - for surfing the web. Get from https://tavily.com
for _required_key in ("OPENAI_API_KEY", "FINANCIAL_DATASETS_API_KEY", "TAVILY_API_KEY"):
    _set_if_undefined(_required_key)

# Define agent tools

In [None]:
from langchain_core.tools import tool
from typing import List, Dict, Optional, Union
import requests
import os
from typing import Dict, Union
from pydantic import BaseModel, Field
import requests
from langchain_core.tools import tool

import pandas as pd
import ta
from datetime import datetime, timedelta

class GetIncomeStatementsInput(BaseModel):
    # Argument schema for the `get_income_statements` tool.
    # `ticker` is required; `period` and `limit` fall back to "ttm" and 10.
    ticker: str = Field(
        ...,
        description="The ticker of the stock.",
    )
    period: str = Field(
        default="ttm",
        description="The period of the income statements. Valid values are 'ttm', 'quarterly' or 'annual'.",
    )
    limit: int = Field(
        default=10,
        description="The maximum number of income statements to return. Default is 10.",
    )

# Tool: fetch income statements for `ticker` from the Financial Datasets API.
#
# Parameters:
#   ticker - stock ticker to query.
#   period - statement period sent to the API (default "ttm").
#   limit  - maximum number of statements requested (default 10).
# Returns the decoded JSON response. If the request or the JSON decoding
# fails, returns a dict with the ticker, an empty "income_statements" list
# and the error text instead of raising.
# Raises ValueError when FINANCIAL_DATASETS_API_KEY is not set.
@tool("get_income_statements", args_schema=GetIncomeStatementsInput, return_direct=True)
def get_income_statements(ticker: str, period: str = "ttm", limit: int = 10) -> Union[Dict, str]:
    """
    Get income statements for a ticker with specified period and limit.
    """
    api_key = os.environ.get("FINANCIAL_DATASETS_API_KEY")
    if not api_key:
        raise ValueError("Missing FINANCIAL_DATASETS_API_KEY.")

    url = 'https://api.financialdatasets.ai/financials/income-statements'
    # Let requests build the query string so the values are percent-encoded
    # (same approach as get_options_chain).
    params = {'ticker': ticker, 'period': period, 'limit': limit}

    try:
        # Bound the wait so a stalled connection cannot hang the agent run.
        response = requests.get(url, headers={'X-API-Key': api_key}, params=params, timeout=30)
        return response.json()
    except Exception as e:
        # Tool boundary: report the failure to the caller instead of raising.
        return {"ticker": ticker, "income_statements": [], "error": str(e)}

class GetBalanceSheetsInput(BaseModel):
    # Argument schema for the `get_balance_sheets` tool.
    # `ticker` is required; `period` and `limit` fall back to "ttm" and 10.
    ticker: str = Field(
        ...,
        description="The ticker of the stock.",
    )
    period: str = Field(
        default="ttm",
        description="The period of the balance sheets. Valid values are 'ttm', 'quarterly' or 'annual'.",
    )
    limit: int = Field(
        default=10,
        description="The maximum number of balance sheets to return. Default is 10.",
    )

# Tool: fetch balance sheets for `ticker` from the Financial Datasets API.
#
# Parameters:
#   ticker - stock ticker to query.
#   period - statement period sent to the API (default "ttm").
#   limit  - maximum number of balance sheets requested (default 10).
# Returns the decoded JSON response. If the request or the JSON decoding
# fails, returns a dict with the ticker, an empty "balance_sheets" list and
# the error text instead of raising.
# Raises ValueError when FINANCIAL_DATASETS_API_KEY is not set.
@tool("get_balance_sheets", args_schema=GetBalanceSheetsInput, return_direct=True)
def get_balance_sheets(ticker: str, period: str = "ttm", limit: int = 10) -> Union[Dict, str]:
    """
    Get balance sheets for a ticker with specified period and limit.
    """
    api_key = os.environ.get("FINANCIAL_DATASETS_API_KEY")
    if not api_key:
        raise ValueError("Missing FINANCIAL_DATASETS_API_KEY.")

    url = 'https://api.financialdatasets.ai/financials/balance-sheets'
    # Let requests build the query string so the values are percent-encoded
    # (same approach as get_options_chain).
    params = {'ticker': ticker, 'period': period, 'limit': limit}

    try:
        # Bound the wait so a stalled connection cannot hang the agent run.
        response = requests.get(url, headers={'X-API-Key': api_key}, params=params, timeout=30)
        return response.json()
    except Exception as e:
        # Tool boundary: report the failure to the caller instead of raising.
        return {"ticker": ticker, "balance_sheets": [], "error": str(e)}

class GetCashFlowStatementsInput(BaseModel):
    # Argument schema for the `get_cash_flow_statements` tool.
    # `ticker` is required; `period` and `limit` fall back to "ttm" and 10.
    ticker: str = Field(
        ...,
        description="The ticker of the stock.",
    )
    period: str = Field(
        default="ttm",
        description="The period of the cash flow statements. Valid values are 'ttm', 'quarterly' or 'annual'.",
    )
    limit: int = Field(
        default=10,
        description="The maximum number of cash flow statements to return. Default is 10.",
    )

# Tool: fetch cash flow statements for `ticker` from the Financial Datasets API.
#
# Parameters:
#   ticker - stock ticker to query.
#   period - statement period sent to the API (default "ttm").
#   limit  - maximum number of statements requested (default 10).
# Returns the decoded JSON response. If the request or the JSON decoding
# fails, returns a dict with the ticker, an empty "cash_flow_statements" list
# and the error text instead of raising.
# Raises ValueError when FINANCIAL_DATASETS_API_KEY is not set.
@tool("get_cash_flow_statements", args_schema=GetCashFlowStatementsInput, return_direct=True)
def get_cash_flow_statements(ticker: str, period: str = "ttm", limit: int = 10) -> Union[Dict, str]:
    """
    Get cash flow statements for a ticker with specified period and limit.
    """
    api_key = os.environ.get("FINANCIAL_DATASETS_API_KEY")
    if not api_key:
        raise ValueError("Missing FINANCIAL_DATASETS_API_KEY.")

    url = 'https://api.financialdatasets.ai/financials/cash-flow-statements'
    # Let requests build the query string so the values are percent-encoded
    # (same approach as get_options_chain).
    params = {'ticker': ticker, 'period': period, 'limit': limit}

    try:
        # Bound the wait so a stalled connection cannot hang the agent run.
        response = requests.get(url, headers={'X-API-Key': api_key}, params=params, timeout=30)
        return response.json()
    except Exception as e:
        # Tool boundary: report the failure to the caller instead of raising.
        return {"ticker": ticker, "cash_flow_statements": [], "error": str(e)}

class GetPricesInput(BaseModel):
    # Argument schema for the `get_stock_prices` tool.
    # `ticker`, `start_date` and `end_date` are required; the remaining fields
    # have defaults. The descriptions are the field documentation exposed
    # through the schema, so they are kept precise and mutually consistent.
    ticker: str = Field(..., description="The ticker of the stock.")
    start_date: str = Field(..., description="The start of the price time window. Either a date with the format YYYY-MM-DD or a millisecond timestamp.")
    # Wording aligned with `start_date` ("price time window").
    end_date: str = Field(..., description="The end of the price time window. Either a date with the format YYYY-MM-DD or a millisecond timestamp.")
    # Fixed the missing opening quote around 'second' in the list of valid values.
    interval: str = Field(default="day", description="The time interval of the prices. Valid values are 'second', 'minute', 'day', 'week', 'month', 'quarter', 'year'.")
    interval_multiplier: int = Field(default=1, description="The multiplier for the interval. For example, if interval is 'day' and interval_multiplier is 1, the prices will be daily. If interval is 'minute' and interval_multiplier is 5, the prices will be every 5 minutes.")
    limit: int = Field(default=5000, description="The maximum number of prices to return. The default is 5000 and the maximum is 50000.")

# Tool: fetch historical prices for `ticker` from the Financial Datasets API.
#
# Parameters:
#   ticker              - stock ticker to query.
#   start_date/end_date - bounds of the price window, passed through to the API.
#   interval            - bar interval (default 'day').
#   interval_multiplier - multiplier applied to the interval (default 1).
#   limit               - maximum number of prices requested (default 5000).
# Returns the decoded JSON response. If the request or the JSON decoding
# fails, returns a dict with the ticker, an empty "prices" list and the error
# text instead of raising.
# Raises ValueError when FINANCIAL_DATASETS_API_KEY is not set.
@tool("get_stock_prices", args_schema=GetPricesInput, return_direct=True)
def get_stock_prices(ticker: str, start_date: str, end_date: str, interval: str = 'day', interval_multiplier: int = 1, limit: int = 5000) -> Union[Dict, str]:
    """
    Get prices for a ticker over a given date range and interval.
    """
    api_key = os.environ.get("FINANCIAL_DATASETS_API_KEY")
    if not api_key:
        raise ValueError("Missing FINANCIAL_DATASETS_API_KEY.")

    url = "https://api.financialdatasets.ai/prices"
    # Let requests build the query string so the values are percent-encoded
    # (same approach as get_options_chain).
    params = {
        'ticker': ticker,
        'start_date': start_date,
        'end_date': end_date,
        'interval': interval,
        'interval_multiplier': interval_multiplier,
        'limit': limit,
    }

    try:
        # Bound the wait so a stalled connection cannot hang the agent run.
        response = requests.get(url, headers={'X-API-Key': api_key}, params=params, timeout=30)
        return response.json()
    except Exception as e:
        # Tool boundary: report the failure to the caller instead of raising.
        return {"ticker": ticker, "prices": [], "error": str(e)}

class GetCurrentPriceInput(BaseModel):
    # Argument schema for the `get_current_stock_price` tool: a single required ticker.
    ticker: str = Field(
        ...,
        description="The ticker of the stock.",
    )

# Tool: fetch the price snapshot for `ticker` from the Financial Datasets API.
#
# Returns the decoded JSON response. If the request or the JSON decoding
# fails, returns a dict with the ticker, a None "price" and the error text
# instead of raising.
# Raises ValueError when FINANCIAL_DATASETS_API_KEY is not set.
@tool("get_current_stock_price", args_schema=GetCurrentPriceInput, return_direct=True)
def get_current_stock_price(ticker: str) -> Union[Dict, str]:
    """
    Get the current (latest) stock price for a ticker.
    """
    api_key = os.environ.get("FINANCIAL_DATASETS_API_KEY")
    if not api_key:
        raise ValueError("Missing FINANCIAL_DATASETS_API_KEY.")

    url = "https://api.financialdatasets.ai/prices/snapshot"
    # Let requests build the query string so the ticker is percent-encoded
    # (same approach as get_options_chain).
    params = {'ticker': ticker}

    try:
        # Bound the wait so a stalled connection cannot hang the agent run.
        response = requests.get(url, headers={'X-API-Key': api_key}, params=params, timeout=30)
        return response.json()
    except Exception as e:
        # Tool boundary: report the failure to the caller instead of raising.
        return {"ticker": ticker, "price": None, "error": str(e)}

class GetOptionsChainInput(BaseModel):
    # Argument schema for the `get_options_chain` tool.
    # `ticker` is required; `limit` defaults to 10 and both filters default to None.
    ticker: str = Field(
        ...,
        description="The ticker of the stock.",
    )
    limit: int = Field(
        default=10,
        description="The maximum number of options to return. Default is 10.",
    )
    strike_price: Optional[float] = Field(
        default=None,
        description="Optional filter for specific strike price.",
    )
    option_type: Optional[str] = Field(
        default=None,
        description="Optional filter for option type. Valid values are 'call' or 'put'.",
    )

# Tool: fetch options chain data for `ticker` from the Financial Datasets API.
#
# Parameters:
#   ticker       - stock ticker to query.
#   limit        - maximum number of options requested (default 10).
#   strike_price - optional strike filter; only sent when not None.
#   option_type  - optional type filter; only sent when not None.
# Returns the decoded JSON response. If the request or the JSON decoding
# fails, returns a dict with the ticker, an empty "options_chain" list and the
# error text instead of raising.
# Raises ValueError when FINANCIAL_DATASETS_API_KEY is not set.
@tool("get_options_chain", args_schema=GetOptionsChainInput, return_direct=True)
def get_options_chain(
    ticker: str,
    limit: int = 10,
    strike_price: Optional[float] = None,
    option_type: Optional[str] = None
) -> Union[Dict, str]:
    """
    Get options chain data for a ticker with optional filters for strike price and option type.
    """
    api_key = os.environ.get("FINANCIAL_DATASETS_API_KEY")
    if not api_key:
        raise ValueError("Missing FINANCIAL_DATASETS_API_KEY.")

    params = {
        'ticker': ticker,
        'limit': limit
    }

    # Only include the optional filters the caller actually supplied.
    if strike_price is not None:
        params['strike_price'] = strike_price
    if option_type is not None:
        params['option_type'] = option_type

    url = 'https://api.financialdatasets.ai/options/chain'

    try:
        # Bound the wait so a stalled connection cannot hang the agent run.
        response = requests.get(url, headers={'X-API-Key': api_key}, params=params, timeout=30)
        return response.json()
    except Exception as e:
        # Tool boundary: report the failure to the caller instead of raising.
        return {"ticker": ticker, "options_chain": [], "error": str(e)}

class GetInsiderTradesInput(BaseModel):
    # Argument schema for the `get_insider_trades` tool.
    # `ticker` is required; `limit` defaults to 10.
    ticker: str = Field(
        ...,
        description="The ticker of the stock.",
    )
    limit: int = Field(
        default=10,
        description="The maximum number of insider transactions to return. Default is 10.",
    )

# Tool: fetch insider transactions for `ticker` from the Financial Datasets API.
#
# Parameters:
#   ticker - stock ticker to query.
#   limit  - maximum number of transactions requested (default 10).
# Returns the decoded JSON response. If the request or the JSON decoding
# fails, returns a dict with the ticker, an empty "insider_transactions" list
# and the error text instead of raising.
# Raises ValueError when FINANCIAL_DATASETS_API_KEY is not set.
@tool("get_insider_trades", args_schema=GetInsiderTradesInput, return_direct=True)
def get_insider_trades(ticker: str, limit: int = 10) -> Union[Dict, str]:
    """
    Get insider trading transactions for a ticker.
    """
    api_key = os.environ.get("FINANCIAL_DATASETS_API_KEY")
    if not api_key:
        raise ValueError("Missing FINANCIAL_DATASETS_API_KEY.")

    url = 'https://api.financialdatasets.ai/insider-transactions'
    # Let requests build the query string so the values are percent-encoded
    # (same approach as get_options_chain).
    params = {'ticker': ticker, 'limit': limit}

    try:
        # Bound the wait so a stalled connection cannot hang the agent run.
        response = requests.get(url, headers={'X-API-Key': api_key}, params=params, timeout=30)
        return response.json()
    except Exception as e:
        # Tool boundary: report the failure to the caller instead of raising.
        return {"ticker": ticker, "insider_transactions": [], "error": str(e)}

class GetTechnicalIndicatorsInput(BaseModel):
    """Input schema for technical indicators calculation."""
    # `ticker` and `indicator` are required; every other field is optional
    # and carries the default shown below.
    ticker: str = Field(
        ...,
        description="The ticker of the stock.",
    )
    indicator: str = Field(
        ...,
        description="The technical indicator to calculate. Valid values are 'rsi', 'macd', 'sma', 'ema', 'bbands'.",
    )
    period: Optional[int] = Field(
        default=14,
        description="The period for the indicator calculation. Default is 14.",
    )
    start_date: Optional[str] = Field(
        default=None,
        description="Start date in YYYY-MM-DD format.",
    )
    end_date: Optional[str] = Field(
        default=None,
        description="End date in YYYY-MM-DD format.",
    )
    interval: Optional[str] = Field(
        default="day",
        description="The time interval for price data.",
    )
    interval_multiplier: Optional[int] = Field(
        default=1,
        description="Multiplier for the time interval.",
    )

def _indicator_padding_days(indicator: str, period: int, interval: str, interval_multiplier: int) -> int:
    """Return how many extra calendar days of history to request so the indicator can warm up."""
    # MACD is computed with fixed 26/12/9 windows, so it needs at least
    # 26 + 9 bars of history whatever `period` is.
    warmup_bars = max(period, 26 + 9) if indicator == "macd" else period
    # Rough calendar length of one bar for day-or-longer intervals.
    days_per_bar = {"day": 1, "week": 7, "month": 31, "quarter": 92, "year": 366}.get(interval)
    if days_per_bar is None:
        # Other (intraday) intervals: keep the plain day-based padding.
        return warmup_bars * 2
    # Doubling leaves headroom for days without bars (weekends, holidays).
    return warmup_bars * 2 * days_per_bar * max(interval_multiplier, 1)


# Tool: compute a technical indicator over the closing prices of `ticker`.
#
# Parameters:
#   ticker              - stock ticker to query.
#   indicator           - one of 'rsi', 'macd', 'sma', 'ema', 'bbands' (case-insensitive).
#   period              - indicator window (default 14; MACD uses fixed 26/12/9 windows).
#   interval            - price bar interval passed to get_stock_prices (default "day").
#   interval_multiplier - interval multiplier passed to get_stock_prices (default 1).
#   start_date/end_date - YYYY-MM-DD bounds of the reported range. When omitted,
#                         end_date defaults to today and start_date to 30 days
#                         before end_date.
# Returns a dict with "ticker", "indicator", "period" and a "data" list of
# points ("time", "time_milliseconds", "value", plus "signal_line"/"histogram"
# for MACD or "upper_band"/"lower_band" for Bollinger Bands). Rows where the
# indicator is still warming up (NaN) are omitted. On failure returns a dict
# with an "error" entry instead of raising.
@tool("get_technical_indicators", args_schema=GetTechnicalIndicatorsInput)
def get_technical_indicators(
    ticker: str,
    indicator: str,
    period: int = 14,
    interval: str = "day",
    interval_multiplier: int = 1,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
) -> Union[Dict, str]:
    """
    Calculate technical indicators for a given ticker and time period.
    Supports RSI, MACD, SMA, EMA, and Bollinger Bands calculations.
    """
    supported_indicators = ("rsi", "macd", "sma", "ema", "bbands")
    date_format = "%Y-%m-%d"
    default_window_days = 30

    try:
        kind = indicator.lower()
        if kind not in supported_indicators:
            return {
                "ticker": ticker,
                "indicator": indicator,
                "error": (
                    f"Error calculating {indicator}: unsupported indicator. "
                    f"Valid values are {', '.join(supported_indicators)}."
                )
            }

        # The input schema declares these fields Optional, so explicit nulls
        # can arrive here; fall back to the documented defaults.
        period = 14 if period is None else period
        interval = interval or "day"
        interval_multiplier = 1 if interval_multiplier is None else interval_multiplier

        # Both dates are optional in the schema; give them usable defaults
        # instead of failing inside strptime.
        if end_date is None:
            end_date = datetime.now().strftime(date_format)
        if start_date is None:
            start_date = (
                datetime.strptime(end_date, date_format) - timedelta(days=default_window_days)
            ).strftime(date_format)

        # Fetch historical price data with padding for calculations
        padding_days = _indicator_padding_days(kind, period, interval, interval_multiplier)
        adjusted_start = (
            datetime.strptime(start_date, date_format) - timedelta(days=padding_days)
        ).strftime(date_format)

        price_data = get_stock_prices.invoke({
            "ticker": ticker,
            "start_date": adjusted_start,
            "end_date": end_date,
            "interval": interval,
            "interval_multiplier": interval_multiplier
        })

        if "error" in price_data:
            return price_data

        prices = price_data.get("prices")
        if not prices:
            return {
                "ticker": ticker,
                "indicator": indicator,
                "error": f"Error calculating {indicator}: no price data returned for the requested range"
            }

        # Convert to pandas DataFrame with proper datetime handling
        df = pd.DataFrame(prices)

        # Clean datetime strings by removing the timezone suffix, then parse them
        df['time'] = pd.to_datetime(
            df['time'].apply(lambda x: x.split(' EDT')[0].split(' EST')[0])
        )
        # Sort so that the string-based date slice below works on an ordered index.
        df = df.set_index('time').sort_index()

        result = {
            "ticker": ticker,
            "indicator": indicator,
            "period": period,
            "data": []
        }

        close = df['close']
        extra_columns = ()
        if kind == "rsi":
            df['indicator_value'] = ta.momentum.RSIIndicator(close, window=period).rsi()
        elif kind == "macd":
            macd = ta.trend.MACD(
                close,
                window_slow=26,
                window_fast=12,
                window_sign=9
            )
            df['indicator_value'] = macd.macd()
            df['signal_line'] = macd.macd_signal()
            df['histogram'] = macd.macd_diff()
            extra_columns = ("signal_line", "histogram")
        elif kind == "sma":
            df['indicator_value'] = ta.trend.SMAIndicator(close, window=period).sma_indicator()
        elif kind == "ema":
            df['indicator_value'] = ta.trend.EMAIndicator(close, window=period).ema_indicator()
        else:  # "bbands"
            bb = ta.volatility.BollingerBands(
                close,
                window=period,
                window_dev=2
            )
            df['indicator_value'] = bb.bollinger_mavg()
            df['upper_band'] = bb.bollinger_hband()
            df['lower_band'] = bb.bollinger_lband()
            extra_columns = ("upper_band", "lower_band")

        # Filter to requested date range
        df = df.loc[start_date:end_date]

        # Drop rows where the indicator has no value yet rather than
        # back-filling them with values computed from later prices.
        df = df.dropna(subset=["indicator_value", *extra_columns])

        for idx, row in df.iterrows():
            data_point = {
                "time": idx.strftime("%Y-%m-%d %H:%M:%S"),  # Clean datetime format
                "time_milliseconds": int(idx.timestamp() * 1000),
                "value": float(row['indicator_value'])
            }
            data_point.update({column: float(row[column]) for column in extra_columns})
            result["data"].append(data_point)

        return result

    except Exception as e:
        # Tool boundary: report the failure to the caller instead of raising.
        return {
            "ticker": ticker,
            "indicator": indicator,
            "error": f"Error calculating {indicator}: {str(e)}"
        }

In [None]:
from typing import Annotated

from langchain_community.tools.tavily_search import TavilySearchResults

# News tool
# Tavily web-search tool, configured with max_results=5; used below as the
# sentiment analyst's news source.
get_news_tool = TavilySearchResults(max_results=5)

In [None]:
# Group tools by analyst
# Financial-statement tools.
fundamental_tools = [
    get_income_statements,
    get_balance_sheets,
    get_cash_flow_statements,
]
# Price and indicator tools.
technical_tools = [
    get_stock_prices,
    get_current_stock_price,
    get_technical_indicators,
]
# Options, insider-trade and news tools.
sentiment_tools = [
    get_options_chain,
    get_insider_trades,
    get_news_tool,
]

# Helper functions

In [None]:
from langchain_core.messages import HumanMessage

def agent_node(state, agent, name):
    """Invoke ``agent`` on ``state`` and wrap its last message as a named HumanMessage.

    Returns a state update whose ``messages`` list holds that single message.

    NOTE: a later cell defines another ``agent_node`` with the same name; once
    that cell runs, it replaces this definition.
    """
    agent_output = agent.invoke(state)
    last_reply = agent_output["messages"][-1]
    report = HumanMessage(content=last_reply.content, name=name)
    return {"messages": [report]}

# Create LangGraph

In [None]:
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_openai import ChatOpenAI
from pydantic import BaseModel
from typing import Literal, Sequence, List
from typing_extensions import TypedDict
import functools
import operator
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from langgraph.graph import END, StateGraph, START
from langgraph.prebuilt import create_react_agent

# Define team members
# These names are also used as graph node names and as the values the
# routing prompt asks the model to answer with.
members = ["fundamental_analyst", "technical_analyst", "sentiment_analyst"]

# Create routing prompt template
# System instructions followed by the conversation so far; the model is asked
# to reply with a comma-separated list of analyst names.
routing_prompt = ChatPromptTemplate.from_messages([
    (
        "system",
        "You are a portfolio manager supervising a hedge fund team with the following analysts:"
        "\n- fundamental_analyst: Analyzes financial statements and company health"
        "\n- technical_analyst: Analyzes price patterns and market trends"
        "\n- sentiment_analyst: Analyzes insider trading activity, options flow, and the news"
        "\nDetermine which analyst(s) should analyze the request. Respond with ONLY the analyst names"
        " separated by commas (e.g., 'technical_analyst,sentiment_analyst'). Choose analysts based on:"
        "\n- Use fundamental_analyst for questions about financials, valuations, or company health"
        "\n- Use technical_analyst for questions about price action, trends, or chart patterns"
        "\n- Use sentiment_analyst for questions about market sentiment, news impact, or trading activity"
    ),
    MessagesPlaceholder(variable_name="messages"),
])

# Create the summary prompt template
# System instructions, then all accumulated messages, then a closing human
# turn that asks for the summary and recommendation.
summary_prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are a portfolio manager responsible for synthesizing analysis from your team of analysts. "
            "Review all the analysts' reports and provide a comprehensive summary including:\n"
            "1. Key financial metrics and their implications (only when you have this data) \n"
            "2. Technical analysis insights (only when you have this data) \n"
            "3. Market sentiment and news impact (only when you have this data) \n"
            "4. Overall investment recommendation\n"
            "Make sure to highlight any discrepancies or conflicting signals between different analyses."
        ),
        MessagesPlaceholder(variable_name="messages"),
        (
            "human",
            "Based on all the analyst reports above, provide a comprehensive summary and investment recommendation."
        ),
    ]
)

# Initialize LLM
# Shared by the router, the analyst agents and the final summary.
llm = ChatOpenAI(model="gpt-4")

class AgentState(TypedDict):
    """State shared by every node of the hedge-fund graph."""
    # Conversation history; annotated with operator.add as its reducer, so a
    # node's returned messages are appended to the existing ones.
    messages: Annotated[Sequence[BaseMessage], operator.add]
    # Analyst node names chosen by the supervisor, in execution order.
    selected_analysts: List[str]
    # Index into `selected_analysts` of the analyst to run next.
    current_analyst_idx: int

def supervisor_router(state):
    """Ask the LLM which analysts should handle the request and record the plan.

    Args:
        state: Graph state; it is passed to ``routing_prompt``, which reads
            its ``messages``.

    Returns:
        A state update containing one routing message, the ordered list of
        selected analysts (restricted to ``members``, without duplicates) and
        ``current_analyst_idx`` reset to 0.
    """
    # Create the routing chain
    routing_chain = routing_prompt | llm

    # Get the routing decision
    result = routing_chain.invoke(state)

    # The model answers in free text. Keep only names that are real analyst
    # nodes: an unknown name would have no matching conditional edge, and a
    # repeated name would run the same analyst twice.
    selected_analysts = []
    for raw_name in result.content.strip().split(','):
        candidate = raw_name.strip().strip("'\"")
        if candidate in members and candidate not in selected_analysts:
            selected_analysts.append(candidate)

    # Add routing message to state
    message = SystemMessage(
        content=f"Routing query to: {', '.join(selected_analysts)}",
        name="supervisor"
    )

    return {
        # `messages` is merged with operator.add, so return only the new
        # message; returning the existing history again would duplicate it.
        "messages": [message],
        "selected_analysts": selected_analysts,
        "current_analyst_idx": 0
    }

def get_next_step(state):
    """Return the name of the node to run next.

    Yields the analyst at ``current_analyst_idx`` while one remains in
    ``selected_analysts``; otherwise (no analysts selected, or all of them
    already processed) returns ``"final_summary"``.
    """
    analysts = state["selected_analysts"]
    if analysts:
        position = state["current_analyst_idx"]
        if position < len(analysts):
            return analysts[position]
    return "final_summary"

def agent_node(state, agent, name):
    """Run one analyst agent and advance the analyst cursor.

    Invokes ``agent`` on ``state``, wraps the content of its last message in a
    HumanMessage named ``name``, and returns a state update that carries that
    message, the unchanged ``selected_analysts`` and ``current_analyst_idx``
    incremented by one.
    """
    agent_output = agent.invoke(state)
    report = HumanMessage(content=agent_output["messages"][-1].content, name=name)

    update = {"messages": [report]}
    update["selected_analysts"] = state["selected_analysts"]
    update["current_analyst_idx"] = state["current_analyst_idx"] + 1
    return update

def final_summary_agent(state):
    """Produce the portfolio manager's summary of the analyst reports.

    Pipes ``summary_prompt`` into ``llm``, invokes the chain on ``state`` and
    returns the reply as a HumanMessage named "portfolio_manager"; the
    analyst selection and cursor are passed through unchanged.
    """
    llm_reply = (summary_prompt | llm).invoke(state)
    summary = HumanMessage(content=llm_reply.content, name="portfolio_manager")

    update = {"messages": [summary]}
    update["selected_analysts"] = state["selected_analysts"]
    update["current_analyst_idx"] = state["current_analyst_idx"]
    return update

# Initialize workflow
workflow = StateGraph(AgentState)

# Create the analysts: one ReAct agent per tool group
fundamental_analyst = create_react_agent(llm, tools=fundamental_tools)
technical_analyst = create_react_agent(llm, tools=technical_tools)
sentiment_analyst = create_react_agent(llm, tools=sentiment_tools)

# Bind each agent and its display name into agent_node so the graph can call it with the state only
fundamental_analyst_node = functools.partial(agent_node, agent=fundamental_analyst, name="fundamental_analyst")
technical_analyst_node = functools.partial(agent_node, agent=technical_analyst, name="technical_analyst")
sentiment_analyst_node = functools.partial(agent_node, agent=sentiment_analyst, name="sentiment_analyst")

# Add nodes
for node_name, node_fn in (
    ("supervisor", supervisor_router),
    ("fundamental_analyst", fundamental_analyst_node),
    ("technical_analyst", technical_analyst_node),
    ("sentiment_analyst", sentiment_analyst_node),
    ("final_summary", final_summary_agent),
):
    workflow.add_node(node_name, node_fn)

# Add conditional edges: from the supervisor and from each analyst,
# get_next_step picks either another analyst or the final summary
next_step_targets = {
    "fundamental_analyst": "fundamental_analyst",
    "technical_analyst": "technical_analyst",
    "sentiment_analyst": "sentiment_analyst",
    "final_summary": "final_summary"
}
for analyst in ["supervisor", *members]:
    # Each source node gets its own copy of the mapping.
    workflow.add_conditional_edges(analyst, get_next_step, dict(next_step_targets))

# Add entry point and final edges
workflow.add_edge(START, "supervisor")
workflow.add_edge("final_summary", END)

# Compile the graph
graph = workflow.compile()

In [None]:

# Example usage
# Only `messages` is supplied: the supervisor node sets `selected_analysts`
# and `current_analyst_idx`. The former `next_analyst` entry is dropped
# because AgentState declares no such key.
response = graph.invoke({
    "messages": [HumanMessage(content="Analyze AAPL's recent price action and market sentiment")],
})

# Code to pretty print Agent output

In [None]:
from typing import Dict, Any
import json
import re
from langchain_core.messages import HumanMessage
from rich.console import Console
from rich.panel import Panel
from rich.text import Text
from rich.rule import Rule

# Shared rich console used by all the pretty-printing helpers below.
console = Console()

def format_bold_text(content: str) -> Text:
    """Build a rich ``Text`` from ``content``, rendering ``**...**`` spans in bold."""
    rendered = Text()

    # re.split with a single capture group alternates plain text (even
    # positions) with the captured bold text (odd positions).
    segments = re.split(r'\*\*(.*?)\*\*', content)
    for position, segment in enumerate(segments):
        is_bold = position % 2 == 1
        if is_bold:
            rendered.append(segment, style="bold")
        else:
            rendered.append(segment)

    return rendered

def format_message_content(content: str) -> Union[str, Text]:
    """Format message content for display.

    Args:
        content: Raw message content.

    Returns:
        A 2-space-indented JSON string when ``content`` parses as JSON
        (non-ASCII characters are kept readable instead of being escaped);
        otherwise a rich ``Text`` with bold spans when ``content`` contains
        ``**`` markers; otherwise ``content`` unchanged.
    """
    try:
        # Try to parse as JSON for prettier formatting
        data = json.loads(content)
    except (ValueError, TypeError, RecursionError):
        # Not JSON (or not a string at all): fall through to plain-text
        # handling. Unrelated exceptions such as KeyboardInterrupt are no
        # longer swallowed.
        pass
    else:
        return json.dumps(data, indent=2, ensure_ascii=False)

    # If not JSON, check for bold markers
    if '**' in content:
        return format_bold_text(content)
    return content

def format_agent_message(message: HumanMessage) -> Union[str, Text]:
    """Return the display form of one agent message (its content run through format_message_content)."""
    raw_content = message.content
    return format_message_content(raw_content)

def get_agent_title(agent: str, message: HumanMessage) -> str:
    """Get the title for the agent panel, with fallback handling.

    Args:
        agent: Node name of the step being printed; used as the fallback.
        message: Message whose ``name`` attribute is preferred when usable.

    Returns:
        The chosen name with underscores replaced by spaces and title-cased.
        ``message.name`` is used when it is a non-blank string; otherwise the
        title is derived from ``agent``.
    """
    name = getattr(message, 'name', None)
    # An explicit type check replaces the former bare `except:`; a blank name
    # also falls back so the panel never gets an empty title.
    if isinstance(name, str) and name.strip():
        return name.replace('_', ' ').title()
    return agent.replace('_', ' ').title()

def print_step(step: Dict[str, Any]) -> None:
    """Pretty print a single step of the agent execution.

    Args:
        step: Mapping of node name to the update that node produced.

    For each node: if the update has a ``next`` entry, a "Supervision Step"
    panel is printed; if it has messages, the most recent one is printed as
    the final summary (node ``final_summary``) or as an analyst report panel.
    """
    for agent, data in step.items():
        # Defensive: skip updates that are not dict payloads.
        if not isinstance(data, dict):
            continue

        # Handle supervisor steps
        if 'next' in data:
            next_agent = data['next']
            text = Text()
            text.append("Portfolio Manager ", style="bold magenta")
            text.append("assigns next task to ", style="white")

            if next_agent == "final_summary":
                text.append("FINAL SUMMARY", style="bold yellow")
            elif next_agent == "END":
                text.append("END", style="bold red")
            else:
                text.append(f"{next_agent}", style="bold green")

            console.print(Panel(
                text,
                title="[bold blue]Supervision Step",
                border_style="blue"
            ))

        # Handle agent responses and final summary
        messages = data.get('messages')
        if not messages:
            # Nothing to render (missing or empty message list).
            continue

        # Show the newest message: when an update carries several messages,
        # the first one is the oldest, not the one this node just produced.
        message = messages[-1]
        formatted_content = format_agent_message(message)

        if agent == "final_summary":
            # Final summary formatting
            console.print(Rule(style="yellow", title="Portfolio Analysis"))
            console.print(Panel(
                formatted_content,
                title="[bold yellow]Investment Summary and Recommendation",
                border_style="yellow",
                padding=(1, 2)
            ))
            console.print(Rule(style="yellow"))
        else:
            # Regular analyst reports
            title = get_agent_title(agent, message)
            console.print(Panel(
                formatted_content,
                title=f"[bold blue]{title} Report",
                border_style="green"
            ))

def stream_agent_execution(graph, input_data: Dict, config: Dict) -> None:
    """Stream the graph's execution and pretty print every step.

    Prints a start banner, then each streamed step that does not contain an
    ``"__end__"`` key (followed by a blank line), then a completion banner.
    """
    console.print("\n[bold blue]Starting Agent Execution...[/bold blue]\n")

    # Lazily filter the stream so steps are still printed as they arrive.
    printable_steps = (
        update for update in graph.stream(input_data, config)
        if "__end__" not in update
    )
    for update in printable_steps:
        print_step(update)
        console.print("\n")

    console.print("[bold blue]Analysis Complete[/bold blue]\n")

# Run the Hedge Fund team

In [None]:
# Initial state: a single user question; the supervisor fills in the rest.
input_data = {
    "messages": [HumanMessage(content="What is AAPL's current price and latest revenue?")]
}
# Cap the number of graph steps for this run.
config = {"recursion_limit": 10}
stream_agent_execution(graph, input_data, config)

In [None]:
# Initial state: a single user question; the supervisor fills in the rest.
input_data = {
    "messages": [HumanMessage(content="What is AAPL's latest news?")]
}
# Cap the number of graph steps for this run.
config = {"recursion_limit": 10}
stream_agent_execution(graph, input_data, config)