In [None]:
import os
import yfinance as yf
import streamlit as st
from agno.agent import Agent
from agno.models.google import Gemini
import plotly.graph_objects as go

# Make sure GOOGLE_API_KEY exists in the process environment: a value that is
# already set is kept, otherwise the variable is initialised to an empty string.
os.environ.setdefault('GOOGLE_API_KEY', '')

In [None]:
def compare_stocks(symbols, period="6mo"):
    """Compute the total price return of each symbol over a look-back window.

    The return is ``last_close / first_close - 1`` over the history returned by
    ``yf.Ticker(symbol).history(period=period)``. This is the actual percentage
    change across the window; summing the daily percentage changes (as was done
    before) is not, e.g. 100 -> 50 -> 100 sums to +0.5 although the price is
    unchanged.

    Args:
        symbols: Iterable of ticker symbols.
        period: Look-back window forwarded to ``Ticker.history``. Defaults to
            ``"6mo"``.

    Returns:
        dict mapping each symbol that had usable data to its return as a plain
        ``float`` (0.21 means +21%). Symbols with no history, fewer than two
        valid closes, a zero first close, or a failing lookup are left out.
    """
    import logging  # local import: keeps this cell self-contained

    log = logging.getLogger(__name__)
    returns = {}
    for symbol in symbols:
        try:
            hist = yf.Ticker(symbol).history(period=period)
            if hist.empty:
                continue
            closes = hist['Close'].dropna()
            # A start and an end price are both needed, and the start price
            # must be non-zero for the ratio to be meaningful.
            if len(closes) < 2 or closes.iloc[0] == 0:
                continue
            returns[symbol] = float(closes.iloc[-1] / closes.iloc[0] - 1.0)
        except Exception as exc:
            # Best effort: one bad symbol must not abort the comparison, but
            # the failure is logged instead of vanishing silently.
            log.warning("Skipping %s: could not compute return (%s)", symbol, exc)
            continue
    return returns

In [None]:
def get_company_info(symbol):
    """Build a short company profile from ``yf.Ticker(symbol).info``.

    Returns:
        dict with the keys ``"name"``, ``"sector"``, ``"market_cap"`` and
        ``"summary"``, read from the info entries ``longName``, ``sector``,
        ``marketCap`` and ``longBusinessSummary``. An entry that is absent
        from ``info`` is reported as ``"N/A"``.
    """
    info = yf.Ticker(symbol).info
    # (output key, key in the yfinance info mapping)
    field_map = (
        ("name", "longName"),
        ("sector", "sector"),
        ("market_cap", "marketCap"),
        ("summary", "longBusinessSummary"),
    )
    return {out_key: info.get(src_key, "N/A") for out_key, src_key in field_map}

In [None]:
def get_company_news(symbol, limit=5):
    """Return the first ``limit`` news items yfinance reports for ``symbol``.

    Args:
        symbol: Ticker symbol.
        limit: Maximum number of items to return (default 5, the previous
            hard-coded value). It is used as a slice bound, so ``None``
            returns every item.

    Returns:
        list of news items exactly as provided by ``yf.Ticker(symbol).news``;
        an empty list when there is no news.
    """
    # Defensive: treat a missing/None news value as "no news" rather than
    # failing on the slice below.
    news = yf.Ticker(symbol).news or []
    return list(news[:limit])

In [None]:
# Agent definitions
# The four agents share one model id; each gets its own Gemini instance and is
# configured with markdown=True.
_GEMINI_MODEL_ID = "gemini-2.0-flash-exp"

# Compares and ranks the performance of the requested stocks.
market_analyst = Agent(
    description="Analyzes and compares stock performance.",
    instructions=[
        "Retrieve and compare stock performance from Yahoo Finance.",
        "Calculate percentage change over a 6-month period.",
        "Rank stocks based on performance.",
    ],
    model=Gemini(id=_GEMINI_MODEL_ID),
    markdown=True,
)

# Summarises a single company's profile and recent news.
company_researcher = Agent(
    description="Fetches company profiles and news.",
    instructions=[
        "Retrieve company info from Yahoo Finance.",
        "Summarize latest news and key metrics.",
    ],
    model=Gemini(id=_GEMINI_MODEL_ID),
    markdown=True,
)

# Turns the analyses into stock recommendations.
stock_strategist = Agent(
    description="Recommends top stocks for investment.",
    instructions=[
        "Analyze performance and fundamentals.",
        "Recommend stocks with justification.",
    ],
    model=Gemini(id=_GEMINI_MODEL_ID),
    markdown=True,
)

# Merges every intermediate result into the final report.
team_lead = Agent(
    description="Combines insights for final report.",
    instructions=[
        "Summarize analysis into investor-ready report.",
        "Rank stocks based on all inputs.",
    ],
    model=Gemini(id=_GEMINI_MODEL_ID),
    markdown=True,
)

In [None]:
def get_market_analysis(symbols):
    """Have the market analyst agent compare the performance of ``symbols``.

    Returns:
        The ``content`` of the agent's response, or the literal string
        ``"No valid data."`` when ``compare_stocks`` yields nothing.
    """
    performance = compare_stocks(symbols)
    if performance:
        prompt = f"Compare these stock performances: {performance}"
        return market_analyst.run(prompt).content
    return "No valid data."

def get_company_analysis(symbol):
    """Have the company researcher agent analyse one company.

    The prompt is assembled from ``get_company_info(symbol)`` (name, sector,
    market cap, summary) and ``get_company_news(symbol)``.

    Returns:
        The ``content`` of the agent's response.
    """
    profile = get_company_info(symbol)
    headlines = get_company_news(symbol)
    prompt = (
        f"Analysis for {profile['name']} in {profile['sector']} sector.\n"
        f"Market Cap: {profile['market_cap']}\n"
        f"Summary: {profile['summary']}\n"
        f"News: {headlines}"
    )
    response = company_researcher.run(prompt)
    return response.content

def _recommend_from_analyses(market_analysis, company_details):
    """Ask the strategist agent for recommendations from precomputed analyses."""
    return stock_strategist.run(
        f"Based on the following: {market_analysis}\n{company_details}\nRecommend best stocks."
    ).content


def get_stock_recommendations(symbols):
    """Produce stock recommendations for ``symbols``.

    Runs the market analysis and one company analysis per symbol, then hands
    both to the strategist agent.

    Returns:
        The ``content`` of the strategist agent's response.
    """
    analysis = get_market_analysis(symbols)
    details = {s: get_company_analysis(s) for s in symbols}
    return _recommend_from_analyses(analysis, details)


def get_final_investment_report(symbols):
    """Produce the final investor report for ``symbols``.

    The market analysis and each company analysis are computed exactly once
    and reused for the recommendation step. Previously this function called
    ``get_stock_recommendations``, which recomputed all of them, doubling the
    number of agent and data requests and basing the recommendations on
    different analysis text than the one shown in the report.

    Returns:
        The ``content`` of the team lead agent's response.
    """
    symbols = list(symbols)  # iterated more than once below
    market_analysis = get_market_analysis(symbols)

    # One analysis per distinct symbol, in first-seen order.
    details = {}
    for s in symbols:
        if s not in details:
            details[s] = get_company_analysis(s)
    analyses = [details[s] for s in symbols]

    recommendations = _recommend_from_analyses(market_analysis, details)
    return team_lead.run(
        f"Market: {market_analysis}\n\nCompany: {analyses}\n\nRecommendations: {recommendations}"
    ).content

In [None]:
# --- Streamlit page ----------------------------------------------------------
# set_page_config is the first Streamlit call made by this script.
st.set_page_config(page_title="AI Investment Strategist", layout="wide")

st.markdown("""
<h1 style='text-align: center;'> AI Investment Strategist</h1>
<h3 style='text-align: center;'>Get real-time, AI-powered market insights.</h3>
""", unsafe_allow_html=True)

st.sidebar.markdown("""
<h2>Configuration</h2>
<p>Enter stock symbols (e.g., AAPL, MSFT).</p>
""", unsafe_allow_html=True)

symbols_input = st.sidebar.text_input("Stock Symbols", "AAPL, TSLA, GOOG")
api_key = st.sidebar.text_input("Google API Key", type="password")

# Normalise the comma-separated input: trim whitespace, upper-case, drop blank
# entries (e.g. from a trailing comma) and duplicates while keeping the order.
# NOTE(review): upper-casing is meant to match the column labels of the frame
# returned by yf.download, which are understood to be upper-cased -- confirm
# against the installed yfinance version.
stock_symbols = list(dict.fromkeys(
    s.strip().upper() for s in symbols_input.split(",") if s.strip()
))

if st.sidebar.button("Generate Investment Report"):
    if api_key:
        os.environ['GOOGLE_API_KEY'] = api_key

    if not stock_symbols:
        st.error("Please enter at least one stock symbol.")
        st.stop()

    try:
        with st.spinner("Generating investment report..."):
            report = get_final_investment_report(stock_symbols)
    except Exception as exc:
        # Show failures (missing API key, network errors, ...) in the page
        # instead of a raw traceback, then end this script run.
        st.error(f"Could not generate the investment report: {exc}")
        st.stop()

    st.subheader(" Investment Report")
    st.markdown(report)

    st.markdown("### Stock Chart (6 Months)")
    stock_data = yf.download(stock_symbols, period="6mo")['Close']
    # Depending on the yfinance version, a single-ticker download may give a
    # 1-D Series here; promote it to a one-column frame so the per-symbol
    # lookup below works either way.
    if getattr(stock_data, "ndim", 2) == 1:
        stock_data = stock_data.to_frame(name=stock_symbols[0])

    fig = go.Figure()
    for symbol in stock_symbols:
        # A symbol with no column in the downloaded data is skipped rather
        # than raising KeyError.
        if symbol not in stock_data.columns:
            st.warning(f"No price data available for {symbol}.")
            continue
        fig.add_trace(go.Scatter(x=stock_data.index, y=stock_data[symbol], mode='lines', name=symbol))
    fig.update_layout(title="Stock Trends", xaxis_title="Date", yaxis_title="Price (USD)", template="plotly_dark")
    st.plotly_chart(fig)