In [1]:
#FinRagAssist - Smart Investment Advisor

In [2]:

# Block 1 — Imports, config, load_data

import math
import os
from typing import Any, Dict, List

import joblib
import numpy as np
import pandas as pd
import xgboost as xgb
import yfinance as yf
from dotenv import load_dotenv
from langchain.schema import Document
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, StandardScaler

# Load .env first so the path overrides below can come from it.
load_dotenv()

# Paths — overridable via environment variables (or .env) so the notebook is
# portable; the defaults keep the original local locations.
DATA_CSV = os.getenv("FINRAG_DATA_CSV", r"C:\Users\rahil\Downloads\cleaned_data_final.csv")
XGB_MODEL_PATH = os.getenv("FINRAG_XGB_MODEL_PATH", r"C:\Users\rahil\Downloads\xgb_risk_model.joblib")

# Final feature set (training uses rich features, UI uses grouped features).
# preprocess() keeps only the ones actually present in the data.
FEATURE_COLS = [
    # Rich numeric features from CSV
    "Age",
    "Income Level",
    "Account Balance",
    "Deposits",
    "Withdrawals",
    "Transfers",
    "International Transfers",
    "Investments",
    "Loan Amount",
    "Loan Term (Months)",
    "Net Savings",
    "Loan to Income Ratio",
    "Investment Ratio",

    # Simplified features used in UI (derived by _ensure_features when absent)
    "AgeGroup",
    "IncomeGroup",
    "EmploymentStatus",
    "LoanStatus",
    "InvestmentGoal",
    "InvestmentAmount",
]

# Label column the classifier predicts.
TARGET_COL = "Risk Tolerance"

def load_data(path: str = DATA_CSV) -> pd.DataFrame:
    """Read the CSV at ``path`` into a DataFrame and print its dimensions.

    Raises:
        FileNotFoundError: if ``path`` does not exist.
    """
    if os.path.exists(path):
        frame = pd.read_csv(path)
        n_rows, n_cols = frame.shape
        print(f"Loaded {n_rows} rows and {n_cols} columns")
        return frame
    raise FileNotFoundError(f"Data file not found at {path}")


In [74]:

# Block 2 — Stock summaries & comparison (compact)

import numpy as _np

def _cagr(s: pd.Series):
    if s.empty: 
        return _np.nan
    yrs = (s.index[-1] - s.index[0]).days / 365.25
    return (s.iloc[-1] / s.iloc[0]) ** (1 / yrs) - 1 if yrs > 0 else _np.nan

def _max_dd(s: pd.Series):
    if s.empty: 
        return _np.nan
    return ((s - s.cummax()) / s.cummax()).min()

def summarize_stock_price_df(df: pd.DataFrame, price_col="Close", name="STOCK"):
    """Markdown summary of one price series: latest value, CAGR, annualised vol, max drawdown.

    The ``price_col`` column is NaN-dropped and index-sorted first. Volatility is
    the std of daily percentage changes scaled by sqrt(252). If the column is
    missing or has no values, a one-line "no price data" header is returned.
    """
    has_prices = price_col in df and not df[price_col].dropna().empty
    if not has_prices:
        return f"### {name} — no price data."

    prices = df[price_col].dropna().sort_index()
    daily_returns = prices.pct_change().dropna()
    annual_vol = daily_returns.std() * _np.sqrt(252)

    lines = [
        f"### {name}",
        f"- Latest: {prices.iloc[-1]:.2f}",
        f"- CAGR: {_cagr(prices):.2%}",
        f"- Vol: {annual_vol:.2%}",
        f"- Max DD: {_max_dd(prices):.2%}",
    ]
    return "\n".join(lines)

def compare_two_price_series(df1, df2, price_col="Close", name1="A", name2="B"):
    """
    Clean, UI-friendly comparison of two assets.
    Always returns readable Markdown (no raw dicts or dtype junk).

    For each frame, the NaN-dropped ``price_col`` series is summarised as:
      - last price;
      - CAGR (``_cagr``);
      - annualised volatility (std of daily pct changes x sqrt(252));
      - max drawdown (``_max_dd``).

    A metric that cannot be computed is shown as "N/A" and skipped in the
    pros/cons and the verdict. This covers too little data, an index without
    a time axis, NaN or inf values, and similar cases. Previously such a
    metric crashed the formatter.

    Returns:
        Markdown containing a snapshot, pros & cons for each asset, and a
        verdict. The verdict names the asset that wins more of the three
        metrics, or "Tie".
    """

    def _to_float(x):
        # Finite float, else None (NaN/inf/unconvertible are treated as unknown).
        try:
            value = float(x)
        except (TypeError, ValueError):
            return None
        return value if _np.isfinite(value) else None

    def _metrics(df):
        # Metrics dict for one frame, or None when there are no usable prices.
        if df is None or df.empty or price_col not in df.columns:
            return None
        s = df[price_col].dropna()
        if s.empty:
            return None

        try:
            cagr = _to_float(_cagr(s))
        except Exception:  # e.g. an index without a time axis
            cagr = None

        daily = s.pct_change().dropna()
        vol = _to_float(daily.std() * (252 ** 0.5)) if not daily.empty else None

        try:
            maxdd = _to_float(_max_dd(s))
        except Exception:
            maxdd = None

        return {"last": _to_float(s.iloc[-1]), "cagr": cagr, "vol": vol, "maxdd": maxdd}

    def _pct(v):
        return "N/A" if v is None else f"{v * 100:.2f}%"

    def _price(v):
        return "N/A" if v is None else f"₹{v:.2f}"

    def _bullets(items):
        return "\n".join(f"- {x}" for x in items) if items else "- None"

    m1 = _metrics(df1)
    m2 = _metrics(df2)

    md = [f"## Comparison: {name1} vs {name2}"]
    if m1 is None or m2 is None:
        md.append("**Not enough price data to compare these two symbols.**")
        return "\n".join(md)

    # Summary table
    md.append("### Snapshot\n")
    for name, m in ((name1, m1), (name2, m2)):
        md.append(
            f"- **{name}**: Price {_price(m['last'])}, CAGR {_pct(m['cagr'])}, "
            f"Vol {_pct(m['vol'])}, MaxDD {_pct(m['maxdd'])}"
        )

    # (metric, higher_is_better, pro text, con text). For positive prices the
    # max drawdown is <= 0, so a higher value means a smaller drawdown.
    criteria = (
        ("cagr", True, "Higher returns (CAGR).", "Lower returns (CAGR)."),
        ("vol", False, "Lower volatility.", "Higher volatility."),
        ("maxdd", True, "Smaller drawdowns.", "Larger drawdowns."),
    )

    def _beats(a, b, key, higher_is_better):
        # Strictly better on ``key``; an unknown value on either side never wins.
        x, y = a[key], b[key]
        if x is None or y is None:
            return False
        return x > y if higher_is_better else x < y

    pros1 = [pro for key, hib, pro, _ in criteria if _beats(m1, m2, key, hib)]
    cons1 = [con for key, hib, _, con in criteria if _beats(m2, m1, key, hib)]
    pros2 = [pro for key, hib, pro, _ in criteria if _beats(m2, m1, key, hib)]
    cons2 = [con for key, hib, _, con in criteria if _beats(m1, m2, key, hib)]

    md.append("\n### Pros & Cons\n")
    md.append(f"#### {name1} Pros\n" + _bullets(pros1))
    md.append(f"#### {name1} Cons\n" + _bullets(cons1))
    md.append(f"\n#### {name2} Pros\n" + _bullets(pros2))
    md.append(f"#### {name2} Cons\n" + _bullets(cons2))

    # Verdict: whoever wins more of the three metrics.
    score1, score2 = len(pros1), len(pros2)
    verdict = name1 if score1 > score2 else name2 if score2 > score1 else "Tie"
    md.append(f"\n### Final Verdict\n**{verdict}** looks better overall for a conservative investor.")

    return "\n".join(md)



In [4]:

#Block 3 — XGBoost training + predict (UI-aligned)


def _age_to_group(age):
    try:
        a = float(age)
    except:
        return "26-35"
    if a < 26:
        return "18-25"
    elif a < 36:
        return "26-35"
    elif a < 46:
        return "36-45"
    elif a < 61:
        return "46-60"
    else:
        return "60+"

def _income_to_group(inc):
    try:
        v = float(inc)
    except:
        return "30,000-70,000"
    if v < 30000:
        return "<30,000"
    elif v <= 70000:
        return "30,000-70,000"
    else:
        return "70,000+"

def _ensure_features(df: pd.DataFrame) -> pd.DataFrame:
    """Add simplified/grouped features that UI uses."""
    d = df.copy()

    if "AgeGroup" not in d:
        d["AgeGroup"] = d["Age"].apply(_age_to_group) if "Age" in d else "26-35"

    if "IncomeGroup" not in d:
        d["IncomeGroup"] = d["Income Level"].apply(_income_to_group) if "Income Level" in d else "30,000-70,000"

    d["EmploymentStatus"] = d.get("Employment Status", "Salaried").astype(str)
    d["LoanStatus"] = d.get("Loan Status", "No").astype(str)

    if "InvestmentGoal" not in d:
        d["InvestmentGoal"] = d.get("Investment Goals", "Growth").astype(str)

    if "InvestmentAmount" not in d:
        if "Investments" in d:
            d["InvestmentAmount"] = d["Investments"].fillna(0)
        elif "Net Savings" in d:
            d["InvestmentAmount"] = d["Net Savings"].fillna(0)
        elif "Account Balance" in d:
            d["InvestmentAmount"] = d["Account Balance"].fillna(0)
        else:
            d["InvestmentAmount"] = 0.0

    return d

def preprocess(df: pd.DataFrame):
    """Prepare X: keep rich numeric features + encode grouped ones.

    Steps:
      1. Add the UI-style grouped columns via ``_ensure_features``.
      2. Keep whichever FEATURE_COLS are present.
      3. Integer-code object columns; NaN becomes a "missing" category.
      4. Coerce everything to numeric (unparseable/NaN -> 0).
      5. Standardize with a freshly fitted StandardScaler.

    Returns:
        (X, cat_maps, scaler):
          - X: the scaled float feature frame, with the same index and columns.
          - cat_maps: maps each encoded column to its category list
            (list position == integer code).
          - scaler: the fitted scaler.

    Raises:
        ValueError: if none of FEATURE_COLS is present.
    """
    df = _ensure_features(df)

    use_cols = [c for c in FEATURE_COLS if c in df.columns]
    if not use_cols:
        raise ValueError("None of the FEATURE_COLS found in dataset")

    X = df[use_cols].copy()

    # Encode categoricals
    cat_maps = {}
    for c in X.select_dtypes(include=["object"]).columns:
        X[c] = X[c].fillna("missing").astype("category")
        cat_maps[c] = list(X[c].cat.categories)
        X[c] = X[c].cat.codes

    X = X.apply(pd.to_numeric, errors="coerce").fillna(0)

    scaler = StandardScaler()
    # Build a new float frame rather than writing floats into int columns in
    # place (X.loc[:, :] = ...). Pandas >= 2.1 deprecates that incompatible-dtype
    # setitem, and it will raise in a future release.
    X = pd.DataFrame(scaler.fit_transform(X.values), columns=X.columns, index=X.index)

    return X, cat_maps, scaler

def train_xgb_model(df: pd.DataFrame):
    """Train XGBoost and print accuracy.

    Features come from ``preprocess``; the scaler is fit on all rows. Labels
    are ``TARGET_COL`` values as strings, label-encoded. An 80/20 stratified
    split (random_state=42) produces the printed accuracy and classification
    report. The resulting bundle is saved to XGB_MODEL_PATH with joblib.

    Returns:
        The saved bundle dict with keys: model, columns, cat_maps, scaler,
        label_encoder.

    Raises:
        ValueError: if ``TARGET_COL`` is missing. ``preprocess`` also raises
        ValueError when no feature columns are present.
    """
    if TARGET_COL not in df.columns:
        raise ValueError(f"CSV must contain '{TARGET_COL}'")

    X, cat_maps, scaler = preprocess(df)
    label_enc = LabelEncoder()
    y = label_enc.fit_transform(df[TARGET_COL].astype(str))

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42, stratify=y
    )

    model = xgb.XGBClassifier(
        n_estimators=350,
        max_depth=4,
        learning_rate=0.05,
        subsample=0.8,
        colsample_bytree=0.8,
        eval_metric="mlogloss",
        # NOTE(review): deprecated in newer xgboost releases. Labels are already
        # integer-encoded above; consider dropping this once the version is pinned.
        use_label_encoder=False,
    )
    model.fit(X_train, y_train)

    preds = model.predict(X_test)
    acc = accuracy_score(y_test, preds)
    print(f"XGBoost test accuracy: {acc:.4f}")
    # Pin the full label set. If a class is absent from both y_test and preds,
    # classification_report would otherwise see fewer labels than target_names
    # and raise ValueError.
    all_labels = np.arange(len(label_enc.classes_))
    print(classification_report(y_test, preds, labels=all_labels, target_names=label_enc.classes_))

    bundle = {
        "model": model,
        "columns": X.columns.tolist(),
        "cat_maps": cat_maps,
        "scaler": scaler,
        "label_encoder": label_enc,
    }
    joblib.dump(bundle, XGB_MODEL_PATH)
    print(f"Saved model to {XGB_MODEL_PATH}")
    return bundle

def load_xgb_model(path: str = XGB_MODEL_PATH):
    """Load the model bundle written by ``train_xgb_model``.

    The file is read with joblib, which unpickles it, so only load bundles
    you created yourself.

    Raises:
        FileNotFoundError: if nothing exists at ``path`` yet.
    """
    if os.path.exists(path):
        return joblib.load(path)
    raise FileNotFoundError("Model not trained yet.")

def predict_risk(bundle: Dict[str, Any], inp: Dict[str, Any]):
    """Score one user profile with a bundle produced by ``train_xgb_model``.

    Args:
        bundle: dict with "columns", "cat_maps", "scaler", "label_encoder" and "model".
        inp: feature name -> value. Keys that are not model columns are ignored.
            Model columns absent from ``inp`` get the "missing" sentinel, which
            ends up as 0 before scaling for numeric columns.

    Returns:
        A dict with three keys:
          - "prediction": the top label;
          - "probability": that label's probability;
          - "class_probabilities": a label -> probability map.
    """
    columns = bundle["columns"]
    category_maps = bundle["cat_maps"]
    scaler = bundle["scaler"]
    label_encoder = bundle["label_encoder"]
    model = bundle["model"]

    frame = pd.DataFrame([{col: inp.get(col, "missing") for col in columns}])

    # Replace each categorical value with its training-time code. Unseen values
    # map to the "missing" category when training had one, else to the first.
    for col, categories in category_maps.items():
        value = frame.at[0, col]
        if value not in categories:
            value = "missing" if "missing" in categories else categories[0]
        frame[col] = categories.index(value)

    frame = frame.apply(pd.to_numeric, errors="coerce").fillna(0)
    frame[columns] = scaler.transform(frame[columns].values)

    probabilities = model.predict_proba(frame)[0]
    best = int(np.argmax(probabilities))

    per_class = {label: float(p) for label, p in zip(label_encoder.classes_, probabilities)}
    return {
        "prediction": label_encoder.inverse_transform([best])[0],
        "probability": float(probabilities[best]),
        "class_probabilities": per_class,
    }


In [5]:
# Prediction

# if __name__ == "__main__":
#     df = load_data()
#     train_xgb_model(df)


In [8]:

# Block 4 — ChromaDB + embeddings setup


# Vector-store settings. A value already defined in the kernel (e.g. set in an
# earlier cell) wins; otherwise these defaults apply.
CHROMA_COLLECTION_NAME = globals().get("CHROMA_COLLECTION_NAME", "fin_docs")
CHROMA_DIR = globals().get("CHROMA_DIR", r"C:\Users\rahil\Downloads\chroma_db")
EMBEDDING_MODEL_NAME = globals().get("EMBEDDING_MODEL_NAME", "all-MiniLM-L6-v2")
DEFAULT_TOP_K = globals().get("DEFAULT_TOP_K", 4)


from sentence_transformers import SentenceTransformer
import chromadb
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma


def init_chroma(collection_name: str = CHROMA_COLLECTION_NAME, persist_dir: str = CHROMA_DIR):
    """Open the persistent Chroma store at ``persist_dir`` and get or create a collection.

    Any error from ``get_collection`` is treated as "does not exist yet", and
    the collection is then created.

    Returns:
        (client, collection)
    """
    client = chromadb.PersistentClient(path=persist_dir)

    try:
        collection = client.get_collection(name=collection_name)
    except Exception:
        collection = client.create_collection(name=collection_name)
        print(f"Created new Chroma collection: {collection_name}")
    else:
        print(f"Loaded existing Chroma collection: {collection_name}")

    return client, collection


# Loaded SentenceTransformer models keyed by model name. Repeated upserts
# (rag_query calls this after every live fetch) then reuse the already-loaded
# model instead of re-instantiating it each time.
_SENTENCE_MODELS: dict = {}


def build_embeddings_and_upsert(docs, collection, embed_model_name: str = EMBEDDING_MODEL_NAME):
    """
    Encode text and insert into ChromaDB.
    Each doc: {"id": "...", "text": "...", "metadata": {...}}; "metadata" is optional.

    Existing ids are overwritten (upsert). An empty ``docs`` is a no-op, so no
    model is loaded and no empty batch is sent to Chroma.
    """
    if not docs:
        print("No documents to upsert.")
        return

    model = _SENTENCE_MODELS.get(embed_model_name)
    if model is None:
        model = SentenceTransformer(embed_model_name)
        _SENTENCE_MODELS[embed_model_name] = model

    texts = [d["text"] for d in docs]
    ids = [d["id"] for d in docs]
    metas = [d.get("metadata", {}) for d in docs]

    embs = model.encode(texts, show_progress_bar=True, convert_to_numpy=True)

    collection.upsert(
        ids=ids,
        documents=texts,
        metadatas=metas,
        embeddings=embs.tolist(),
    )

    print(f"Upserted {len(ids)} documents into collection '{collection.name}'")


def make_retriever(k: int = DEFAULT_TOP_K):
    """Similarity-search retriever over the persisted Chroma collection.

    Queries are embedded with HuggingFaceEmbeddings(EMBEDDING_MODEL_NAME), and
    the top ``k`` matches are returned per query.
    """
    store = Chroma(
        collection_name=CHROMA_COLLECTION_NAME,
        embedding_function=HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME),
        persist_directory=CHROMA_DIR,
    )
    return store.as_retriever(search_kwargs={"k": k}, search_type="similarity")


In [22]:
# Block 5 - RAG: vector search + live stock/news fetch


import re

# Loose ticker pattern: 1-10 letters, digits, dots or hyphens (e.g. "AAPL",
# "RELIANCE.NS", "BRK-B"). Used with .match(). Note that "$" also matches just
# before a trailing newline, which is why callers strip their input first.
_TICKER_RE = re.compile(r"^[A-Za-z0-9\.\-]{1,10}$")


def fetch_stock_data(ticker: str) -> Document | None:
    """
    Get a short stock summary from yfinance and wrap it in a Document.
    Returns None if no data was found or on error.

    The summary covers:
      - the latest close and the 5-day average close, from one month of history;
      - sector and industry from ``Ticker.info`` when available ('N/A' otherwise).

    The ticker is stripped and upper-cased first. Tickers that then fail to
    match ``_TICKER_RE`` are rejected without touching yfinance.
    """
    ticker = str(ticker).strip().upper()
    if not _TICKER_RE.match(ticker):
        return None

    try:
        stock = yf.Ticker(ticker)
        hist = stock.history(period="1mo")
        if hist is None or hist.empty:
            return None

        # Ignore NaN closes, if any, so the summary never shows "nan".
        closes = hist["Close"].dropna()
        if closes.empty:
            return None
        last_close = float(closes.iloc[-1])
        avg_5 = float(closes.tail(5).mean())

        sector = industry = None
        try:
            # Read .info once. It is a property that may hit the network, and
            # some tickers / yfinance versions raise when it is accessed.
            info = stock.info
            if isinstance(info, dict):
                sector = info.get("sector")
                industry = info.get("industry")
        except Exception:
            pass  # metadata is optional; 'N/A' is shown below

        text = (
            f"{ticker} latest close: {last_close:.2f}. "
            f"5-day avg close: {avg_5:.2f}. "
            f"Sector: {sector or 'N/A'}. "
            f"Industry: {industry or 'N/A'}."
        )
        return Document(page_content=text, metadata={"source": "yfinance", "ticker": ticker})
    except Exception as e:
        print(f"[fetch_stock_data] error for {ticker}: {e}")
        return None


def rag_query(query: str, retriever, collection, k: int | None = None) -> List[Document]:
    """
    Query retriever (if present). If no relevant docs, fetch live data
    (yfinance + news) and upsert into collection (if upsert function present).

    Steps:
      1. Ask ``retriever`` for documents. ``k`` defaults to DEFAULT_TOP_K, or 4.
      2. If the top hit literally contains the query text (content or
         metadata), return the retrieved documents.
      3. Otherwise build live documents:
           - a yfinance summary when the query looks like a ticker;
           - plus ``fetch_news_tavily`` results.
         New documents are upserted via ``build_embeddings_and_upsert`` when
         both it and ``collection`` are available.
      4. If the live fetch produced nothing, return the retrieved documents
         rather than an empty list.

    Retriever, news and upsert errors are printed, not raised.
    """
    if k is None:
        k = globals().get("DEFAULT_TOP_K", 4)

    # 1) Retrieval (pass k where the retriever signature allows it).
    results = []
    if retriever is not None:
        try:
            try:
                results = retriever.get_relevant_documents(query, k=k)
            except TypeError:
                # older retriever signature
                results = retriever.get_relevant_documents(query)
        except Exception as exc:
            print(f"[rag_query] retriever error: {exc}")
            results = []

    # 2) Crude relevance check: the query text must appear in the top hit.
    if results:
        try:
            top_text = (results[0].page_content or "").lower()
            meta_text = str(results[0].metadata or "").lower()
            if query.lower() in top_text or query.lower() in meta_text:
                print(f"[rag_query] found {len(results)} relevant docs for '{query}'")
                return results
            print(f"[rag_query] docs found but not clearly relevant for '{query}' — fetching live data")
        except Exception:
            print("[rag_query] could not inspect retrieved docs; fetching live data")

    # 3) Live fetch fallback.
    print(f"[rag_query] fetching live data for '{query}' (yfinance + news)...")
    new_docs: List[Document] = []

    q_clean = query.strip()
    # If query looks like a ticker, try stock summary
    if _TICKER_RE.match(q_clean):
        sd = fetch_stock_data(q_clean.upper())
        if sd:
            new_docs.append(sd)

    # Fetch news via Tavily (caller must implement actual API call). Any error,
    # including fetch_news_tavily not being defined, is logged and skipped.
    tavily_results = None  # <-- replace with real call where available
    try:
        news_docs = fetch_news_tavily(query, max_results=5, results=tavily_results)
        new_docs.extend(news_docs)
    except Exception as exc:
        print(f"[rag_query] fetch_news_tavily failed: {exc}")

    if new_docs:
        # Upsert into collection if possible
        upsert_items = [{"id": f"live_{i}_{q_clean}", "text": d.page_content, "metadata": d.metadata} for i, d in enumerate(new_docs)]
        upsert_fn = globals().get("build_embeddings_and_upsert")
        if callable(upsert_fn) and collection is not None:
            try:
                upsert_fn(upsert_items, collection)
            except Exception as exc:
                print(f"[rag_query] upsert failed: {exc}")
        return new_docs

    # 4) Nothing live. The substring check above is strict (a phrase like
    # "Tell me about Tesla" rarely appears verbatim), so loosely matching
    # retrieved docs are more useful to the agent than an empty answer.
    if results:
        print(f"[rag_query] no live data for '{query}'; returning {len(results)} retrieved docs")
        return results
    return new_docs


In [24]:
# Block 6: Tools for the agent

from langchain.agents import Tool


# Risk profiling tool
def risk_tool_func(user_input: dict):
    """Return ``predict_risk`` for ``user_input``, using the saved XGBoost bundle.

    ``user_input`` maps model feature names to values, e.g. AgeGroup,
    IncomeGroup, EmploymentStatus, LoanStatus, InvestmentGoal,
    InvestmentAmount.

    NOTE: Block 7 redefines ``risk_tool_func`` and ``risk_tool`` with a
    text-parsing version. When cells run in order, that later definition is
    the one ``get_agent`` wires in.
    """
    return predict_risk(load_xgb_model(), user_input)


risk_tool = Tool(
    description=(
        "Predicts a user's risk tolerance (High, Medium, Low) based on "
        "their financial profile (age group, income group, employment, loans, goals, etc.)."
    ),
    func=risk_tool_func,
    name="RiskProfiler",
)


# Market data / RAG tool
def market_tool_func(query: str):
    """Return the page contents ``rag_query`` finds for ``query``, separated by blank lines.

    NOTE: Block 7 redefines this function and ``market_tool``.
    """
    _client, collection = init_chroma()
    found = rag_query(query, make_retriever(), collection) or []
    return "\n\n".join(d.page_content for d in found) if found else "No market information found."


market_tool = Tool(
    description=(
        "Provides market/stock context from the internal RAG store. "
        "Accepts a stock ticker (e.g. AAPL) or company name (e.g. Apple)."
    ),
    func=market_tool_func,
    name="MarketData",
)


# Stock comparison tool
def compare_tool_func(stock1: str, stock2: str):
    """Collect RAG context for two stocks and ask the LLM to compare them.

    The returned text lists each stock's retrieved content ("No data found."
    when empty) and ends with an instruction to produce pros/cons and a verdict.

    NOTE: Block 7 redefines this function (single-string input) and ``compare_tool``.
    """
    _client, collection = init_chroma()
    retriever = make_retriever()

    def _context(symbol):
        found = rag_query(symbol, retriever, collection) or []
        return "\n\n".join(d.page_content for d in found) or "No data found."

    summary1 = _context(stock1)
    summary2 = _context(stock2)

    return (
        f"Stock 1: {stock1}\n{summary1}\n\n"
        f"Stock 2: {stock2}\n{summary2}\n\n"
        "Now compare them with pros/cons and a final verdict."
    )


def _compare_wrapper(pair: str):
    # Tool adapter: split 'AAPL,MSFT' (spaces around the comma allowed) into two args.
    first, *rest = (part.strip() for part in pair.split(","))
    if len(rest) != 1:
        return "Please provide two tickers as 'TICKER1,TICKER2'."
    return compare_tool_func(first, rest[0])


compare_tool = Tool(
    description="Compare two stocks by providing input as 'TICKER1,TICKER2'.",
    func=_compare_wrapper,
    name="CompareStocks",
)


In [26]:
# Block 7: Agent setup with LangChain

import re
from langchain.agents import initialize_agent, AgentType, Tool
from langchain.chat_models import ChatOpenAI


def parse_user_profile(text: str) -> dict:
    """
    Turn comma-separated 'key=value' pairs into a dict.

    Keys and values are whitespace-stripped. A value containing "." is tried
    as float, any other value as int. Values that fail conversion stay
    strings. Segments without "=" are ignored. Only the first "=" splits, so
    "a=b=c" gives {"a": "b=c"}.

    Example:
        "Age=30, Occupation=Engineer, Income=60000"
    """
    def _coerce(raw: str):
        try:
            return float(raw) if "." in raw else int(raw)
        except ValueError:
            return raw

    profile = {}
    for chunk in map(str.strip, text.split(",")):
        if "=" not in chunk:
            continue
        key, _, value = chunk.partition("=")
        profile[key.strip()] = _coerce(value.strip())
    return profile


# 1) Risk profiling tool
def risk_tool_func(user_input: str):
    """
    Accepts text like:
      "Age=30, Income=60000, Employment Status=Salaried, Loan Amount=10000, Investment Goals=Growth, Investments=25000"
    and maps it into the simplified feature space expected by predict_risk.

    Steps:
      - Parse the text with ``parse_user_profile``.
      - Bucket age/income into AgeGroup/IncomeGroup (defaults 30 / 60000).
      - Derive LoanStatus:
          - an explicit approved/pending/rejected value is kept;
          - otherwise "Yes" for a positive loan amount or a yes-like Loan Status;
          - else "No".
      - Forward every parsed key unchanged as well. Values typed with the
        model's exact column names (e.g. "Age", "Loan Amount", "Investments")
        then reach predict_risk instead of its "missing" -> 0 fallback.
        "Income" is also forwarded as "Income Level".

    Non-numeric amounts (e.g. "Loan Amount=none") are treated as 0 instead of raising.

    Returns:
        The dict returned by predict_risk.
    """
    bundle = load_xgb_model()
    raw = parse_user_profile(user_input)

    def _to_float(value, default=0.0):
        try:
            return float(value)
        except (TypeError, ValueError, OverflowError):
            return default

    # Map raw values into grouped features used by the model
    age_val = raw.get("Age", raw.get("age", 30))
    inc_val = raw.get("Income", raw.get("Income Level", raw.get("income", 60000)))

    emp = raw.get("Employment Status", raw.get("EmploymentStatus", "Salaried"))
    loan_status_raw = str(raw.get("Loan Status", raw.get("LoanStatus", ""))).lower()
    loan_amt = _to_float(raw.get("Loan Amount", raw.get("loan_amount", 0)) or 0)

    # NOTE(review): the profile-parser prompt and run_risk_profile (Block 8)
    # use approved/pending/rejected for loan status. The Yes/No fallback below
    # matches _ensure_features' default. Confirm which values the training
    # CSV's "Loan Status" column actually contains.
    if "approve" in loan_status_raw:
        loan_status = "approved"
    elif "pend" in loan_status_raw:
        loan_status = "pending"
    elif "reject" in loan_status_raw:
        loan_status = "rejected"
    else:
        has_loan = loan_amt > 0 or loan_status_raw in ["yes", "y", "true", "1"]
        loan_status = "Yes" if has_loan else "No"

    goal = raw.get("Investment Goals", raw.get("Goal", raw.get("InvestmentGoal", "Growth")))
    invest_amt = _to_float(
        raw.get(
            "InvestmentAmount",
            raw.get("Investments", raw.get("Net Savings", raw.get("Account Balance", 0.0))),
        )
    )

    # Raw keys first, so exact model column names pass straight through.
    # predict_risk ignores keys that are not model columns. The derived
    # grouped features below take precedence.
    feature_row = dict(raw)
    if "Income Level" not in feature_row and "Income" in raw:
        feature_row["Income Level"] = raw["Income"]
    feature_row.update({
        "AgeGroup": _age_to_group(age_val),
        "IncomeGroup": _income_to_group(inc_val),
        "EmploymentStatus": str(emp),
        "LoanStatus": loan_status,
        "InvestmentGoal": str(goal),
        "InvestmentAmount": invest_amt,
    })

    return predict_risk(bundle, feature_row)


risk_tool = Tool(
    name="RiskProfiler",
    func=risk_tool_func,
    description=(
        "Predicts a user's risk tolerance (High, Medium, Low) from profile data. "
        "Input format: key=value pairs separated by commas. "
        "Example: 'Age=30, Income=60000, Employment Status=Salaried, "
        "Loan Amount=10000, Investment Goals=Growth, Investments=25000'."
    ),
)


# 2) Market data tool (RAG over your vector store)
def market_tool_func(query: str):
    """Return the page contents ``rag_query`` finds for ``query``, one document per line."""
    _client, collection = init_chroma()
    documents = rag_query(query, make_retriever(), collection) or []
    if documents:
        return "\n".join(doc.page_content for doc in documents)
    return "No market context found for this query."


market_tool = Tool(
    description=(
        "Returns market-related context from the internal RAG setup. "
        "Input can be a stock ticker (AAPL, TSLA) or a company name (Apple, Tesla)."
    ),
    func=market_tool_func,
    name="MarketData",
)


# 3) Comparison tool
def compare_tool_func(query: str):
    """
    Compare two stocks given as 'AAPL,MSFT' or text like 'Compare Apple and Microsoft'.

    Accepted forms:
      - "AAPL,MSFT" (comma-separated);
      - free text joined by "and", "vs", "vs." or "versus".
    Either form may start with the word "compare", which is dropped. Otherwise
    it would be glued to the first name (e.g. "Compare Apple") and break the
    ticker lookup.

    Returns:
        The retrieved context for each stock plus an instruction for the LLM.
        A usage message is returned instead when exactly two non-empty names
        cannot be identified.
    """
    query = query.strip()
    # A leading "compare" is an instruction, not part of the first stock name.
    body = re.sub(r"^compare\b\s*", "", query, flags=re.IGNORECASE)

    if "," in body:
        parts = [x.strip() for x in body.split(",") if x.strip()]
        if len(parts) != 2:
            return "Please provide exactly two stocks, e.g. 'AAPL,MSFT'."
    else:
        parts = [p.strip() for p in re.split(r"\b(?:and|versus|vs)\b\.?", body, flags=re.IGNORECASE)]
        if len(parts) != 2 or not all(parts):
            return "Please provide two stocks, e.g. 'AAPL,MSFT'."
    stock1, stock2 = parts

    client, coll = init_chroma()
    retriever = make_retriever()
    docs1 = rag_query(stock1, retriever, coll) or []
    docs2 = rag_query(stock2, retriever, coll) or []

    summary1 = "\n".join(d.page_content for d in docs1) or "No context found."
    summary2 = "\n".join(d.page_content for d in docs2) or "No context found."

    return (
        f"Stock 1: {stock1}\n{summary1}\n\n"
        f"Stock 2: {stock2}\n{summary2}\n\n"
        "Now provide pros/cons and a final verdict."
    )


# Agent-facing wrapper around compare_tool_func (single free-text input).
compare_tool = Tool(
    func=compare_tool_func,
    name="CompareStocks",
    description=(
        "Compare two stocks. "
        "Input: 'AAPL,MSFT' or 'Compare Apple and Microsoft'."
    ),
)


def get_agent():
    """
    Create a LangChain agent wired up with risk, market, and comparison tools.

    The agent is a zero-shot ReAct agent over ``risk_tool``, ``market_tool``
    and ``compare_tool``. It uses ChatOpenAI "gpt-4o-mini" at temperature 0.2.

    The API key is resolved as follows:
      - the notebook-level ``OPENAI_KEY`` if it is set (cell 8.1 assigns it);
      - otherwise the ``OPENAI_API_KEY`` environment variable.
    Reading it this way avoids a NameError when this is called before that
    cell has run.

    Raises:
        ValueError: if no API key is available.
    """
    api_key = globals().get("OPENAI_KEY") or os.getenv("OPENAI_API_KEY")
    if not api_key:
        raise ValueError("Missing OPENAI_API_KEY.")

    llm = ChatOpenAI(
        model="gpt-4o-mini",
        temperature=0.2,
        openai_api_key=api_key,
    )

    tools = [risk_tool, market_tool, compare_tool]

    agent = initialize_agent(
        tools=tools,
        llm=llm,
        agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
        verbose=True,
        handle_parsing_errors=True,
    )
    return agent


In [18]:
#Test

# agent = get_agent()

# # 1. Risk profiling
# query1 = "Age=30, Occupation=Engineer, Income=60000, Loan Amount=10000, Employment Status=Employed, Investment Goals=Growth"
# print(agent.run(query1))

# # 2. Market data
# query2 = "Tell me about Tesla"
# print(agent.run(query2))

# # 3. Compare stocks
# query3 = "Compare AAPL and MSFT"
# print(agent.run(query3))


In [44]:
# Block 8.1 - Prompt + LLM chain

from langchain.prompts import PromptTemplate
from langchain.chat_models import ChatOpenAI
from langchain.chains import LLMChain
import os

# Read the key from the environment; load_dotenv() in Block 1 may have
# populated it from a .env file.
OPENAI_KEY = os.getenv("OPENAI_API_KEY")

# Extraction prompt for free-text profiles. run_risk_profile passes the reply
# straight to json.loads and then uses .get(key, form_value) for each field.
# The model must therefore return bare JSON (no ``` fences) and omit unknown
# fields rather than guess them or emit null.
profile_prompt = PromptTemplate(
    input_variables=["text"],
    template="""
You're a helpful financial assistant. Read the user's profile text and return a JSON object with these keys:
- Age Group: one of "18-25", "26-35", "36-45", "46-60", "60+"
- Income Group: one of "<30,000", "30,000-70,000", "70,000+"
- Employment Status: e.g. "Salaried", "Self-employed", "Retired"
- Loan Status: one of "approved", "pending", "rejected"
- Investment Goal: e.g. "Growth", "Wealth Preservation", "Short-term Safety"
- Investment Amount: numeric (INR)

Use exactly these key names. Omit any key whose value the text does not mention.
Respond with the raw JSON object only: no markdown code fences, no explanation.

User text: {text}
"""
)

# Temperature 0 for deterministic structured extraction.
llm_client = ChatOpenAI(
    model="gpt-4o-mini",
    temperature=0.0,
    openai_api_key=OPENAI_KEY
)

profile_parser = LLMChain(llm=llm_client, prompt=profile_prompt)


In [32]:
# 8.2 - Allocation and projection utilities
def allocation_for_risk(risk: str, goal: str):
    if risk == "High":
        return {"Stocks": 0.70, "FD": 0.20, "Gold": 0.10}
    if risk == "Medium":
        return {"Stocks": 0.50, "FD": 0.30, "Gold": 0.20}
    return {"Stocks": 0.30, "FD": 0.40, "Gold": 0.30}

def simulate_growth(principal: float, split: dict, years: int = 5):
    rates = {"Stocks": 0.12, "FD": 0.06, "Gold": 0.08}
    portfolio_vals = [principal]
    for _ in range(years):
        prev = portfolio_vals[-1]
        next_total = sum(prev * split[k] * (1 + rates[k]) for k in split)
        portfolio_vals.append(next_total)
    fd_series = [principal * ((1 + rates["FD"])**i) for i in range(years+1)]
    gold_series = [principal * ((1 + rates["Gold"])**i) for i in range(years+1)]
    return portfolio_vals, fd_series, gold_series


In [54]:
# 8.3 - Fetch history + basic indicators (with NSE .NS fallback)
import yfinance as yf
import pandas as pd
import matplotlib.pyplot as plt


def get_history(ticker: str, period: str = "6mo", interval: str = "1d"):
    """
    Fetch price history for a symbol, falling back to its NSE listing.

    Symbols tried, in order:
      1. the stripped symbol as given (e.g. INFY, AAPL);
      2. if it has no "." suffix, also "<SYMBOL>.NS" (e.g. TCS, RELIANCE).

    The first non-empty result is returned after three adjustments:
      - the index is reset;
      - the first column is renamed to "Date" if no "Date" column exists;
      - a "__symbol__" column records which symbol worked.

    yfinance errors are printed and the next candidate is tried. An empty
    DataFrame is returned for a blank ticker or when every attempt fails.
    """
    symbol = (ticker or "").strip()
    if not symbol:
        return pd.DataFrame()

    attempts = [symbol] if "." in symbol else [symbol, symbol.upper() + ".NS"]

    for candidate in attempts:
        try:
            history = yf.Ticker(candidate).history(period=period, interval=interval)
            if history is None or history.empty:
                continue
            history = history.reset_index()
            if "Date" not in history.columns:
                # e.g. intraday data, whose index column is not called "Date"
                history = history.rename(columns={history.columns[0]: "Date"})
            history["__symbol__"] = candidate
            return history
        except Exception as exc:
            print(f"[get_history] yfinance error for {candidate}: {exc}")

    return pd.DataFrame()


def add_indicators(df: pd.DataFrame):
    """
    Return a copy of ``df`` with indicator columns computed from "Close".

    Columns added, in this order:
      - SMA20, EMA20
      - RSI14
      - MACD, MACD_Signal

    If there is no "Close" column, an unmodified copy is returned.
    """
    out = df.copy()
    if "Close" not in out.columns:
        return out

    close = out["Close"]

    # Trend: 20-period simple and exponential moving averages.
    out["SMA20"] = close.rolling(20, min_periods=1).mean()
    out["EMA20"] = close.ewm(span=20, adjust=False).mean()

    # RSI14 with Wilder-style smoothing (com=13 <=> alpha=1/14). The 1e-9
    # guards against division by zero when there are no losses.
    change = close.diff()
    avg_gain = change.clip(lower=0).ewm(com=13, adjust=False).mean()
    avg_loss = (-change.clip(upper=0)).ewm(com=13, adjust=False).mean()
    out["RSI14"] = 100 - (100 / (1 + avg_gain / (avg_loss + 1e-9)))

    # MACD line = EMA12 - EMA26; signal line = EMA9 of the MACD line.
    macd = close.ewm(span=12, adjust=False).mean() - close.ewm(span=26, adjust=False).mean()
    out["MACD"] = macd
    out["MACD_Signal"] = macd.ewm(span=9, adjust=False).mean()

    return out


def plot_with_indicators(df: pd.DataFrame, ticker: str):
    """
    Three stacked panels sharing the date axis:
      - price with SMA20/EMA20;
      - RSI14 with 70/30 guide lines;
      - MACD with its signal line.

    ``df`` needs "Date" and "Close" columns; indicators come from ``add_indicators``.
    Returns the matplotlib Figure, or None if ``df`` is missing, empty, or lacks
    those columns.
    """
    required = {"Close", "Date"}
    if df is None or df.empty or not required.issubset(df.columns):
        return None

    data = add_indicators(df)
    dates = data["Date"]

    fig = plt.figure(constrained_layout=True, figsize=(8, 8))
    grid = fig.add_gridspec(3, 1, height_ratios=[3, 1, 1])

    # Panel 1: close price and its moving averages.
    ax_price = fig.add_subplot(grid[0, 0])
    for column in ("Close", "SMA20", "EMA20"):
        ax_price.plot(dates, data[column], lw=1, label=column)
    ax_price.set_title(f"{ticker} — Price with SMA/EMA")
    ax_price.legend(loc="upper left")

    # Panel 2: RSI with overbought/oversold guide lines.
    ax_rsi = fig.add_subplot(grid[1, 0], sharex=ax_price)
    ax_rsi.plot(dates, data["RSI14"], label="RSI14")
    for level in (70, 30):
        ax_rsi.axhline(level, linestyle="--", linewidth=0.7)
    ax_rsi.set_ylabel("RSI")

    # Panel 3: MACD and signal line.
    ax_macd = fig.add_subplot(grid[2, 0], sharex=ax_price)
    ax_macd.plot(dates, data["MACD"], label="MACD")
    ax_macd.plot(dates, data["MACD_Signal"], label="Signal")
    ax_macd.set_ylabel("MACD")
    ax_macd.legend(loc="upper left")

    for tick_label in ax_macd.get_xticklabels():
        tick_label.set_rotation(30)

    return fig


In [46]:
# 8.4 - Risk profiler runner
import numpy as np
import json
import matplotlib.pyplot as plt

def _prefer_parsed(parsed: dict, key: str, fallback):
    """Return ``parsed[key]`` when present and non-blank, otherwise ``fallback``.

    ``None`` and whitespace-only strings count as blank. A parser that emits
    null for fields the user never mentioned therefore cannot wipe out the
    values chosen in the form.
    """
    value = parsed.get(key)
    if value is None or (isinstance(value, str) and not value.strip()):
        return fallback
    return value


def run_risk_profile(user_text: str,
                     age_sel, income_sel, employment_sel,
                     loan_sel, goal_sel, invest_amount):
    """
    Gradio generator behind the "Risk Profiler" tab.

    Args:
        user_text: optional free-text profile. It is sent to ``profile_parser``
            and the result is decoded as JSON. Non-blank parsed fields
            override the form values.
        age_sel, income_sel, employment_sel, loan_sel, goal_sel: dropdown values.
        invest_amount: value of the amount field (None or invalid -> 0.0).

    Yields:
        ``(status, markdown, allocation_fig, projection_fig)`` tuples: first a
        "Parsing / running..." status, then either the final result or a
        "Prediction failed" message with no figures.
    """
    # initial status
    yield ("Parsing / running...", "", None, None)

    parsed = None
    parse_note = None

    # 1) Try the free-text LLM parser; any failure falls back to form values.
    if user_text and user_text.strip():
        try:
            raw = profile_parser.run({"text": user_text})
            parsed = json.loads(raw)
        except Exception as exc:
            parsed = None
            parse_note = f"Parsing failed: {exc} — using form values."

    # 2) Start from form values
    age_group = age_sel
    income_group = income_sel
    employment = employment_sel
    loan_status = loan_sel
    goal = goal_sel
    try:
        amount = float(invest_amount or 0)
    except Exception:
        amount = 0.0

    # 3) Override with non-blank parsed fields (LLM keys -> model keys).
    #    A null / blank parsed value keeps the corresponding form value.
    if isinstance(parsed, dict):
        age_group = _prefer_parsed(parsed, "Age Group", age_group)
        income_group = _prefer_parsed(parsed, "Income Group", income_group)
        employment = _prefer_parsed(parsed, "Employment Status", employment)
        loan_status = _prefer_parsed(parsed, "Loan Status", loan_status)
        goal = _prefer_parsed(parsed, "Investment Goal", goal)
        try:
            amount = float(_prefer_parsed(parsed, "Investment Amount", amount))
        except (TypeError, ValueError, OverflowError):
            pass  # unparseable value such as "50k": keep the form amount

    # 4) Normalise loan status to approved / pending / rejected.
    # NOTE(review): these are plain substring checks, so a phrase like
    # "not approved" maps to "approved"; tighten if the parser can emit negations.
    loan_str = str(loan_status).strip().lower()
    if "approve" in loan_str:
        loan_norm = "approved"
    elif "pend" in loan_str:
        loan_norm = "pending"
    elif "reject" in loan_str:
        loan_norm = "rejected"
    else:
        # fall back to dropdown value (should already be one of the three)
        loan_norm = str(loan_sel)

    # 5) Build feature dict exactly as the XGBoost pipeline expects
    features = {
        "AgeGroup": age_group,
        "IncomeGroup": income_group,
        "EmploymentStatus": employment,
        "InvestmentGoal": goal,
        "LoanStatus": loan_norm,
        "InvestmentAmount": amount,
    }

    # 6) Model prediction
    try:
        model_bundle = load_xgb_model()
        pred = predict_risk(model_bundle, features)
    except Exception as exc:
        yield ("", f"Prediction failed: {exc}", None, None)
        return

    risk_label = pred.get("prediction", "Medium")
    confidence = pred.get("probability", 0.6)
    # Format defensively: a missing or non-numeric probability must not crash
    # the generator after the model has already produced a label.
    try:
        confidence_text = f"{float(confidence):.2%}"
    except (TypeError, ValueError):
        confidence_text = "n/a"

    # 7) Allocation + plots
    split = allocation_for_risk(risk_label, goal)

    # pie chart
    fig1, ax1 = plt.subplots(figsize=(4, 3))
    ax1.pie(list(split.values()), labels=list(split.keys()),
            autopct="%1.1f%%", startangle=90)
    ax1.set_title("Allocation")

    # multi-year projection; the x-axis covers year 0 through `years`
    years = 5
    port, fd_s, gold_s = simulate_growth(amount, split, years=years)
    yrs = np.arange(0, years + 1)
    fig2, ax2 = plt.subplots(figsize=(6, 3))
    ax2.plot(yrs, port, "o-", label="Portfolio")
    ax2.plot(yrs, fd_s, "s--", label="FD-only")
    ax2.plot(yrs, gold_s, "d--", label="Gold-only")
    ax2.set_title("5-year projection")
    ax2.set_xlabel("Year")
    ax2.set_ylabel("Value")
    ax2.legend()

    # Detach both figures from pyplot's global figure registry. The Figure
    # objects can still be saved/rendered after close, and without this every
    # request leaves two more figures alive for the lifetime of the kernel.
    plt.close(fig1)
    plt.close(fig2)

    # 8) Markdown output
    text = f"### Risk: **{risk_label}**\n\nConfidence: {confidence_text}\n\nAllocation: "
    text += ", ".join(f"{k}: {v*100:.0f}%" for k, v in split.items())
    if parse_note:
        text = f"**Note:** {parse_note}\n\n" + text

    yield ("", text, fig1, fig2)


In [88]:
# 8.5 - Stock analysis runner (filtered RAG per ticker)
import re
from langchain.prompts import PromptTemplate
from langchain.chat_models import ChatOpenAI
from langchain.chains import LLMChain

# Prompt for the short investor-facing chart explanation. The four
# placeholders are filled with the trend / RSI / MACD notes computed by
# the stock-analysis runner.
_EXPLAIN_TEMPLATE = """
You're writing a short, plain-language explanation for an investor.

Ticker: {ticker}
Trend (one line): {trend}
RSI note: {rsi}
MACD note: {macd}

Write a 2-3 sentence explanation.
"""

# Chat model used only for the explanation; temperature 0.0 keeps the
# wording as stable as possible between runs.
explain_llm = ChatOpenAI(
    openai_api_key=OPENAI_KEY,
    model="gpt-4o-mini",
    temperature=0.0,
)

explain_prompt = PromptTemplate(
    template=_EXPLAIN_TEMPLATE,
    input_variables=["ticker", "trend", "rsi", "macd"],
)

explain_chain = LLMChain(prompt=explain_prompt, llm=explain_llm)


def _filter_docs_for_ticker(docs, ticker: str):
    """
    Keep only documents that look relevant to the given ticker:
    - metadata['ticker'] == ticker (if present), OR
    - the text itself mentions the ticker.
    """
    if not docs:
        return []

    t = ticker.strip().upper()
    out = []
    for d in docs:
        text = (getattr(d, "page_content", "") or "")
        meta = getattr(d, "metadata", {}) or {}
        meta_ticker = str(meta.get("ticker", "")).upper()
        if not t:
            out.append(d)
            continue

        if meta_ticker == t or t in text.upper():
            out.append(d)

    return out


def _clean_research_snip(docs, ticker: str, max_chars: int = 1200) -> str:
    """
    Build a short research snippet:
    - Only use docs relevant to the ticker.
    - Inside each doc, keep only sentences that mention the ticker.
    """
    if not docs:
        return "No research available."

    t = ticker.strip().upper()
    snippets = []

    for d in docs:
        text = (getattr(d, "page_content", "") or "").strip()
        if not text:
            continue

        if t:
            # split into sentences, keep those mentioning the ticker
            sentences = re.split(r'(?<=[.!?])\s+', text)
            hit_sents = [s for s in sentences if t in s.upper()]
            text = " ".join(hit_sents).strip()

        if text:
            snippets.append(text)

    if not snippets:
        return "No research available."

    joined = "\n\n".join(snippets).strip()
    return joined[:max_chars] if len(joined) > max_chars else joined


def run_stock_analysis_enhanced(query: str):
    """
    Gradio generator behind the "Stock Analysis" tab.

    Steps, each best-effort (failures are printed and replaced by fallbacks):
      1. research text from the vector store, filtered to the ticker;
      2. 6 months of price history, charted with ``plot_with_indicators``;
      3. trend / RSI / MACD notes from ``add_indicators``;
      4. a plain-language note from ``explain_chain``; if that call fails,
         the raw notes are used instead.

    Args:
        query: ticker symbol or company name typed by the user.

    Yields:
        (status_text, markdown, figure_or_None): first a "Fetching chart..."
        status, then the final report, or a prompt when ``query`` is blank.
    """
    yield ("Fetching chart...", "", None)

    ticker = (query or "").strip()
    if not ticker:
        yield ("", "Please enter a ticker symbol or company name.", None)
        return

    # --- RAG / research text (filtered) ---
    try:
        _client, coll = init_chroma()
        retriever = make_retriever(4)
        docs_raw = rag_query(ticker, retriever, coll)
        docs_filt = _filter_docs_for_ticker(docs_raw, ticker)
        rag_snip = _clean_research_snip(docs_filt, ticker)
    except Exception as exc:
        print("[run_stock_analysis_enhanced] RAG error:", exc)
        rag_snip = "No research available."

    # --- Price history + chart (best effort, like the research step) ---
    try:
        hist = get_history(ticker, period="6mo")
    except Exception as exc:
        print("[run_stock_analysis_enhanced] history error:", exc)
        hist = None
    has_hist = hist is not None and not hist.empty

    fig = None
    if has_hist:
        try:
            fig = plot_with_indicators(hist, ticker)
        except Exception as exc:
            print("[run_stock_analysis_enhanced] chart error:", exc)

    trend = "insufficient data"
    rsi_note = "RSI unavailable"
    macd_note = "MACD unavailable"

    try:
        if has_hist:
            dfi = add_indicators(hist)
            # Relative change over the whole window; +/-3% separates trends
            # from "Sideways".
            change = (dfi["Close"].iloc[-1] - dfi["Close"].iloc[0]) / (dfi["Close"].iloc[0] + 1e-9)
            trend = "Uptrend" if change > 0.03 else ("Downtrend" if change < -0.03 else "Sideways")

            last_rsi = dfi["RSI14"].iloc[-1]
            rsi_note = (
                f"RSI ~{last_rsi:.1f} — "
                + ("Overbought" if last_rsi > 70 else ("Oversold" if last_rsi < 30 else "Neutral"))
            )

            m_val = dfi["MACD"].iloc[-1]
            s_val = dfi["MACD_Signal"].iloc[-1]
            macd_note = (
                f"MACD {m_val:.3f} vs {s_val:.3f} — "
                + ("Bullish" if m_val > s_val else "Bearish/Neutral")
            )
    except Exception as exc:
        print("[run_stock_analysis_enhanced] indicator error:", exc)

    # --- Natural-language explanation ---
    try:
        expl = explain_chain.run(
            {"ticker": ticker, "trend": trend, "rsi": rsi_note, "macd": macd_note}
        )
    except Exception:
        expl = f"{trend}. {rsi_note}. {macd_note}."

    md = (
        f"**Research:**\n\n{rag_snip}\n\n"
        f"**Chart note:**\n\n{expl}\n\n"
        f"**Tech details:**\n- Trend: {trend}\n- {rsi_note}\n- {macd_note}"
    )

    yield ("", md, fig)


In [78]:
# 8.6 - Compare two tickers (simple & robust)

import yfinance as yf
import pandas as pd

def _download_with_ns(symbol: str, period: str = "3y") -> pd.DataFrame:
    """
    Try history for a symbol. If plain symbol fails and there's no suffix,
    also try SYMBOL.NS (common for Indian listings).
    """
    base = (symbol or "").strip()
    if not base:
        return pd.DataFrame()

    candidates = [base]
    if "." not in base:
        candidates.append(base.upper() + ".NS")

    for sym in candidates:
        # Try yf.download
        try:
            df = yf.download(sym, period=period, progress=False)
            if df is not None and not df.empty:
                df["__symbol__"] = sym
                return df
        except Exception as exc:
            print(f"[download_with_ns] download error for {sym}: {exc}")

        # Fallback: Ticker().history
        try:
            df2 = yf.Ticker(sym).history(period=period)
            if df2 is not None and not df2.empty:
                df2["__symbol__"] = sym
                return df2
        except Exception as exc:
            print(f"[download_with_ns] history error for {sym}: {exc}")

    return pd.DataFrame()


def run_compare_generator(sym1: str, sym2: str):
    """
    Gradio generator: compares two tickers using RAG text + price data.

    Research text comes from the vector store (best effort). Price history
    comes from ``_download_with_ns`` (3 years, with ``.NS`` fallback), and the
    numeric comparison is delegated to ``compare_two_price_series``.

    Args:
        sym1, sym2: ticker symbols; surrounding whitespace is ignored.

    Yields:
        (status_text, markdown): first a "Gathering data..." status, then the
        final report, a prompt when a symbol is missing, or an error message.
    """
    # First yield: status only
    yield ("Gathering data...", "")

    sym1 = (sym1 or "").strip()
    sym2 = (sym2 or "").strip()
    if not sym1 or not sym2:
        yield ("", "Please enter two ticker symbols to compare.")
        return

    def _docs_text(docs) -> str:
        # Join non-blank page contents. An all-blank (or empty) retrieval
        # yields the placeholder instead of a string of bare newlines.
        texts = [getattr(d, "page_content", "") or "" for d in docs or []]
        return "\n".join(t for t in texts if t.strip()) or "No RAG data."

    try:
        # ---------- RAG text (best-effort) ----------
        try:
            _client, coll = init_chroma()
            retriever = make_retriever(4)
            txt1 = _docs_text(rag_query(sym1, retriever, coll))
            txt2 = _docs_text(rag_query(sym2, retriever, coll))
        except Exception as exc:
            print(f"[run_compare_generator] RAG error: {exc}")
            txt1 = txt2 = "No RAG available."

        # ---------- Price history (3y, with .NS fallback) ----------
        try:
            h1 = _download_with_ns(sym1, period="3y")
        except Exception as exc:
            print(f"[run_compare_generator] history error for {sym1}: {exc}")
            h1 = pd.DataFrame()

        try:
            h2 = _download_with_ns(sym2, period="3y")
        except Exception as exc:
            print(f"[run_compare_generator] history error for {sym2}: {exc}")
            h2 = pd.DataFrame()

        parts = [
            "## Research Summary\n",
            f"### {sym1}\n{txt1}\n\n",
            f"### {sym2}\n{txt2}\n\n",
        ]

        # ---------- Price-based comparison ----------
        if (
            h1 is not None and not h1.empty and "Close" in h1.columns and
            h2 is not None and not h2.empty and "Close" in h2.columns
        ):
            try:
                # compare_two_price_series comes from Block 2
                cmp_md = compare_two_price_series(
                    h1,
                    h2,
                    price_col="Close",
                    name1=sym1,
                    name2=sym2,
                )
                parts.append("\n\n")
                parts.append(cmp_md)
            except Exception as exc:
                print(f"[run_compare_generator] compare helper error: {exc}")
                parts.append("\n\n**Price data available but comparison failed.**")
        else:
            parts.append("**Not enough price data to compare these symbols.**")

        final_md = "".join(parts)
        yield ("", final_md)

    except Exception as exc:
        # Catch-all so Gradio never shows generic ERROR
        yield ("", f"Error while comparing {sym1} and {sym2}: {exc}")


In [90]:
# 8.7 - UI layout and launch
import gradio as gr

# Best-effort load of the dataset, used only to auto-fill dropdown choices.
# On failure the UI falls back to hard-coded defaults; the reason is printed
# instead of being silently swallowed.
try:
    df_opts = load_data()
except Exception as exc:
    print(f"[ui] could not load dropdown options, using defaults: {exc}")
    df_opts = None


def get_choices(col, default):
    """
    Return the distinct non-null values of ``df_opts[col]`` as strings.

    Values are deduplicated by their string form and sorted
    case-insensitively; ties keep first-appearance order.

    Args:
        col: dataset column to read.
        default: returned when the dataset did not load, lacks ``col``, or the
            column has no non-null values.
    """
    columns = getattr(df_opts, "columns", None)
    if columns is None or col not in columns:
        return default
    # dict.fromkeys dedupes while keeping a deterministic order for the
    # stable sort below.
    distinct = dict.fromkeys(str(x) for x in df_opts[col].dropna().unique())
    vals = sorted(distinct, key=str.lower)
    return vals if vals else default


# -----------------------------
# Dropdown options for the Risk Profiler tab
# -----------------------------
# Age and income buckets are fixed lists. Employment status and investment
# goals are read from the dataset through get_choices when it is available,
# otherwise the defaults below are used.
age_opts = ["18-25", "26-35", "36-45", "46-60", "60+"]
income_opts = ["<30,000", "30,000-70,000", "70,000+"]

# Loan status uses the approved / pending / rejected vocabulary.
loan_opts = ["approved", "pending", "rejected"]

_EMPLOYMENT_DEFAULTS = ["Salaried", "Self-employed", "Retired"]
_GOAL_DEFAULTS = ["Growth", "Wealth Preservation", "Short-term Safety"]

employment_opts = get_choices("Employment Status", _EMPLOYMENT_DEFAULTS)
goal_opts = get_choices("Investment Goals", _GOAL_DEFAULTS)


# Gradio App
# Three tabs (Risk Profiler, Stock Analysis, Compare Stocks). Each tab wires
# one button to a generator callback whose yielded tuples map, in order, onto
# the click handler's `outputs`. Inside each `with` context, the order in which
# components are created determines their on-screen layout.

with gr.Blocks(title="FinRagAssist") as demo:
    gr.Markdown("# **FinRagAssist — Smart Investment Advisor**")

    # RISK PROFILER TAB
    # `inputs` are passed positionally to run_risk_profile, so their order
    # must match its parameters (user_text, age, income, employment, loan,
    # goal, amount). Its 4-tuple yields fill status, markdown, pie, growth.
    with gr.Tab("Risk Profiler"):
        with gr.Row():
            # Left column: free text + profile dropdowns.
            with gr.Column():
                free_text = gr.Textbox(
                    label="Optional: profile text",
                    placeholder="e.g. I'm 32, salaried, 80k/month, loan approved, goal growth, invest 50k"
                )
                age_in = gr.Dropdown(label="Age Group", choices=age_opts, value="26-35")
                income_in = gr.Dropdown(label="Income Group", choices=income_opts, value="30,000-70,000")
                # Default is the first available option (dataset-derived when loaded).
                emp_in = gr.Dropdown(label="Employment Status", choices=employment_opts, value=employment_opts[0])
            
            # Right column: remaining inputs, the action button, and results.
            with gr.Column():
                loan_in = gr.Dropdown(label="Loan Status", choices=loan_opts, value="approved")
                goal_in = gr.Dropdown(label="Investment Goal", choices=goal_opts, value=goal_opts[0])
                amt_in = gr.Number(label="Investment Amount (₹)", value=25000)

                btn = gr.Button("Get Recommendation")

                status_box = gr.Textbox(label="Status", interactive=False)
                out_md = gr.Markdown()
                pie = gr.Plot()
                growth = gr.Plot()

        btn.click(
            run_risk_profile,
            inputs=[free_text, age_in, income_in, emp_in, loan_in, goal_in, amt_in],
            outputs=[status_box, out_md, pie, growth]
        )

    # STOCK ANALYSIS TAB
    # run_stock_analysis_enhanced yields (status, markdown, figure).
    with gr.Tab("Stock Analysis"):
        # NOTE(review): this module-level `ticker` holds the Textbox component;
        # the callbacks use their own local `ticker`, so there is no clash.
        ticker = gr.Textbox(label="Ticker or Name")
        analyze_btn = gr.Button("Analyze")

        st_status = gr.Textbox(label="Status", interactive=False)
        st_md = gr.Markdown()
        st_fig = gr.Plot()

        analyze_btn.click(
            run_stock_analysis_enhanced,
            inputs=[ticker],
            outputs=[st_status, st_md, st_fig]
        )

    # COMPARE TAB
    # run_compare_generator yields (status, markdown); `a` and `b` are the
    # two ticker inputs, passed as sym1 and sym2.
    with gr.Tab("Compare Stocks"):
        a = gr.Textbox(label="Ticker 1")
        b = gr.Textbox(label="Ticker 2")
        cmp_btn = gr.Button("Compare")

        cmp_status = gr.Textbox(label="Status", interactive=False)
        cmp_out = gr.Markdown()

        cmp_btn.click(
            run_compare_generator,
            inputs=[a, b],
            outputs=[cmp_status, cmp_out]
        )


if __name__ == "__main__":
    # Shut down any Gradio servers left running by earlier executions of this
    # cell. Without this, each re-run starts another server on the next free
    # port while the old apps keep running in the background.
    gr.close_all()
    demo.launch()


Loaded 5000 rows and 18 columns
* Running on local URL:  http://127.0.0.1:7872
* To create a public link, set `share=True` in `launch()`.
