## Import Required Packages

In [24]:
import os
import re
import json
import uuid
import sqlite3
import difflib
from typing import Optional, Annotated

from dotenv import load_dotenv
from typing_extensions import TypedDict, NotRequired

from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_ollama import ChatOllama

from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from langgraph.checkpoint.sqlite import SqliteSaver

from langchain_community.utilities.duckduckgo_search import DuckDuckGoSearchAPIWrapper
from langchain_core.tools import tool
from pprint import pprint
from utils.logger_config import setup_logger

from utils import build_tools, safe_float, safe_json_loads, quiet_call


## Setup

In [2]:
load_dotenv()

LOGS_DIR = os.getenv("LOGS_DIR", "logs")
os.makedirs(LOGS_DIR, exist_ok=True)

logger = setup_logger(
    debug_mode=True,
    log_name="InventmentAgent",
    log_dir=LOGS_DIR,
)

In [3]:
class Config:
    OLLAMA_MODEL = os.getenv("OLLAMA_MODEL","granite4:350m") # "llama3.1:8b")
    TEMPERATURE = float(os.getenv("TEMPERATURE", "0.0"))
    DDG_RESULTS = int(os.getenv("DDG_RESULTS", "5"))
    CHECKPOINT_DB = os.getenv("CHECKPOINT_DB", "invest_agent_checkpoints.sqlite")
    MASTER_STOCK_FILE = os.getenv("MASTER_STOCK_FILE", "all_nse_stocks.json")

In [4]:
llm = ChatOllama(
    model=Config.OLLAMA_MODEL,
    temperature=Config.TEMPERATURE,
    verbose=True,
)

In [6]:
#ddg = DuckDuckGoSearchResults(num_results=Config.DDG_RESULTS)
ddg_api = DuckDuckGoSearchAPIWrapper()

In [None]:
# mlflow.set_tracking_uri(Config.MLFLOW_TRACKING_URI)
# mlflow.set_experiment(Config.MLFLOW_EXPERIMENT)
# mlflow.langchain.autolog()

In [7]:
def print_workflow_info(workflow, app=None, *, title="WORKFLOW INFORMATION"):
    """
    Pretty-prints a LangGraph workflow using Rich (with a pprint fallback).

    - Shows nodes, edges, and finish points (best-effort attribute probing).
    - Optionally displays a notebook visualization if `app` is provided.
    """
    # --- Extract workflow info (best effort / defensive) ---
    nodes = getattr(workflow, "nodes", None)
    edges = getattr(workflow, "edges", None)

    # finish points: try public attr first, then common private fallback(s)
    finish_points = None
    for attr in ("finish_points", "_finish_points", "_finish_point", "finish_point"):
        if hasattr(workflow, attr):
            finish_points = getattr(workflow, attr)
            break

    info = {
        "nodes": nodes,
        "edges": edges,
        "finish_points": finish_points,
        "node_count": len(nodes) if hasattr(nodes, "__len__") else None,
        "edge_count": len(edges) if hasattr(edges, "__len__") else None,
    }

    # --- Pretty output ---
    try:
        from rich.console import Console
        from rich.panel import Panel
        from rich.table import Table
        from rich.pretty import Pretty

        console = Console()

        table = Table(show_header=False, box=None)
        table.add_row("Nodes", str(info["node_count"]))
        table.add_row("Edges", str(info["edge_count"]))
        table.add_row("Finish points", "Yes" if info["finish_points"] is not None else "Unknown")

        console.print(Panel(table, title=title))
        console.print(Panel(Pretty(info, expand_all=True), title="Details"))

    except Exception:
        # Fallback: standard library pprint
        import pprint
        print(title)
        print("=" * len(title))
        pprint.pprint(info, width=100, sort_dicts=False)

    # --- Optional notebook visualization ---
    if app:
        try:
            from IPython.display import display, Image
            display(Image(app.get_graph().draw_mermaid_png()))
        except Exception as e:
            print(f"(Visualization unavailable: {e})")

In [8]:
with open(Config.MASTER_STOCK_FILE, "r", encoding="utf-8") as f:
    NSE_STOCKS: list[dict[str, str]] = json.load(f)

SYMBOL_INDEX: dict[str, dict[str, str]] = {row["symbol"].upper(): row for row in NSE_STOCKS}
NAME_INDEX: list[tuple[str, str]] = [
    (row["symbol"].upper(), row["companyName"].lower()) for row in NSE_STOCKS
]

In [10]:
SYMBOL_INDEX
NAME_INDEX

[('20MICRONS', '20 microns limited'),
 ('21STCENMGM', '21st century management services limited'),
 ('360ONE', '360 one wam limited'),
 ('3IINFOLTD', '3i infotech limited'),
 ('3MINDIA', '3m india limited'),
 ('3PLAND', '3p land holdings limited'),
 ('5PAISA', '5paisa capital limited'),
 ('63MOONS', '63 moons technologies limited'),
 ('A2ZINFRA', 'a2z infra engineering limited'),
 ('AADHARHFC', 'aadhar housing finance limited'),
 ('AAKASH', 'aakash exploration services limited'),
 ('AAREYDRUGS', 'aarey drugs & pharmaceuticals limited'),
 ('AARON', 'aaron industries limited'),
 ('AARTECH', 'aartech solonics limited'),
 ('AARTIDRUGS', 'aarti drugs limited'),
 ('AARTIIND', 'aarti industries limited'),
 ('AARTIPHARM', 'aarti pharmalabs limited'),
 ('AARTISURF', 'aarti surfactants limited'),
 ('AARVI', 'aarvi encon limited'),
 ('AAVAS', 'aavas financiers limited'),
 ('ABB', 'abb india limited'),
 ('ABBOTINDIA', 'abbott india limited'),
 ('ABCAPITAL', 'aditya birla capital limited'),
 ('ABCO

## Agent State

In [11]:
class StockState(TypedDict):
    symbol: str
    company: str

    price: Optional[float]
    change: Optional[float]
    p_change: Optional[float]
    previous_close: Optional[float]
    open: Optional[float]
    day_high: Optional[float]
    day_low: Optional[float]
    vwap: Optional[float]
    w52_high: Optional[float]
    w52_low: Optional[float]

    pe: NotRequired[Optional[float]]
    sector_pe: NotRequired[Optional[float]]
    industry: NotRequired[str]
    last_update: NotRequired[str]

    raw: NotRequired[dict]
    error: NotRequired[str]

In [12]:
class InvestState(StockState):
    messages: Annotated[list[BaseMessage], add_messages]
    query: str

    news: NotRequired[list[dict]]
    final: NotRequired[dict]

    resolved_from: NotRequired[str]
    symbol_candidates: NotRequired[list[str]]

In [13]:

def init_state(query: str) -> InvestState:
    return {
        "query": query,
        "messages": [HumanMessage(content=query)],

        "symbol": "",
        "company": "N/A",

        "price": None,
        "change": None,
        "p_change": None,
        "previous_close": None,
        "open": None,
        "day_high": None,
        "day_low": None,
        "vwap": None,
        "w52_high": None,
        "w52_low": None,
    }


## Tools

In [14]:
bundle = build_tools(
    llm=llm,
    ddg_api=ddg_api,
    ddg_results=Config.DDG_RESULTS,
    safe_json_loads=safe_json_loads,
    safe_float=safe_float,
    quiet_call=quiet_call,
    logger=logger,
)

search_tool = bundle.tool_map["search_tool"]
llm_extract_company = bundle.tool_map["llm_extract_company"]
get_nse_stock_data = bundle.tool_map["get_nse_stock_data"]
llm_with_tools = bundle.llm_with_tools
llm_choose_symbol = bundle.llm_choose_symbol

06:43:44,595 tools.py [MainThread] - INFO :165 - Tools bound to LLM: ['search_tool', 'llm_extract_company', 'get_nse_stock_data']


### Search

In [17]:
print("Name: \n", search_tool.name)
print("\nDescription: \n", search_tool.description) 
print("\nArgs: \n", search_tool.args)

Name: 
 search_tool

Description: 
 Search the web using DuckDuckGo and return structured results (title/link/snippet).

Args: 
 {'query': {'title': 'Query', 'type': 'string'}}


In [19]:
stock_symbol = "HINDUSTAN AERONAUTICS LIMITED" 
pprint(search_tool.invoke(stock_symbol))

[{'link': 'https://en.wikipedia.org/wiki/Hindustan_Aeronautics_Limited',
  'snippet': '2 days ago - Hindustan Aeronautics Limited (HAL) is an Indian '
             'public sector aerospace and defence company . Headquartered in '
             'Bengaluru, it is involved in the designing, manufacturing and '
             'overhaul of combat aircraft, helicopters, unmanned aerial '
             'vehicles, jet and turbine engines, avionics, and other hardware.',
  'title': 'Hindustan Aeronautics Limited - Wikipedia'},
 {'link': 'https://grokipedia.com/page/Hindustan_Aeronautics_Limited',
  'snippet': 'HAL LogoHindustan Aeronautics Limited (HAL) is a Government of '
             'India-owned aerospace and defence company headquartered in '
             'Bengaluru, Karnataka, that designs, develops,...',
  'title': 'Hindustan Aeronautics Limited'},
 {'link': 'https://vajiramandravi.com/upsc-exam/hindustan-aeronautics-limited-hal/',
  'snippet': '3 weeks ago - Hindustan Aeronautics Limited (H

### Extract Company Name

In [None]:
import json
import re
from rapidfuzz import process, fuzz
from langchain_core.tools import tool

# ---------------------------
# Normalization
# ---------------------------

def normalize(text: str) -> str:
    text = text.lower()
    text = re.sub(r'[^a-z0-9 ]', ' ', text)
    text = re.sub(r'\b(limited|ltd|private|pvt|india)\b', '', text)
    return re.sub(r'\s+', ' ', text).strip()

# ---------------------------
# Load NSE data
# ---------------------------

with open("all_nse_stocks.json") as f:
    nse_data = json.load(f)

NSE_LOOKUP = [
    {
        "symbol": item["symbol"],
        "name": item["companyName"],
        "norm": normalize(item["companyName"])
    }
    for item in nse_data
]

NSE_NAMES = [c["norm"] for c in NSE_LOOKUP]

# ---------------------------
# Abbreviations (optional boost)
# ---------------------------

ABBREVIATIONS = {
    "hal": "HAL",
    "ril": "RELIANCE",
    "tcs": "TCS"
}

# ---------------------------
# Candidate detection
# ---------------------------

STOPWORDS = {
    "of", "with", "and", "compare", "stock", "shares",
    "industries", "india", "limited", "ltd"
}

def detect_candidates(sentence: str) -> set:
    sentence_norm = normalize(sentence)
    words = sentence_norm.split()

    candidates = set()

    for n in (1, 2, 3):
        for i in range(len(words) - n + 1):
            span = " ".join(words[i:i+n])

            if len(span) < 4:
                continue
            if span in STOPWORDS:
                continue
            if span in sentence_norm:
                candidates.add(span)

    return candidates

# ---------------------------
# Token containment rule
# ---------------------------

def sentence_tokens(sentence: str) -> set:
    return set(normalize(sentence).split())

def is_valid_match(company_norm: str, sent_tokens: set) -> bool:
    IGNORE = {"limited", "ltd", "india"}
    company_tokens = set(company_norm.split()) - IGNORE
    return company_tokens.issubset(sent_tokens)

# ---------------------------
# Symbol matching (SAFE)
# ---------------------------

def match_symbols(candidates: set, sentence: str, threshold: int = 90) -> list:
    found = set()
    sent_tokens = sentence_tokens(sentence)

    for candidate in candidates:
        match, score, idx = process.extractOne(
            candidate,
            NSE_NAMES,
            scorer=fuzz.token_set_ratio
        )

        if score < threshold:
            continue

        company = NSE_LOOKUP[idx]

        if not is_valid_match(company["norm"], sent_tokens):
            continue

        found.add(company["symbol"])

    return sorted(found)

# ---------------------------
# Main extractor
# ---------------------------

def extract_symbols_from_sentence(sentence: str) -> list:
    sentence_norm = normalize(sentence)

    # 1️⃣ Abbreviation shortcut
    symbols = {
        sym for key, sym in ABBREVIATIONS.items()
        if key in sentence_norm
    }

    # 2️⃣ Candidate + fuzzy validation
    candidates = detect_candidates(sentence)
    symbols.update(match_symbols(candidates, sentence))

    return list(symbols)

# ---------------------------
# LangChain tool
# ---------------------------

@tool
def extract_nse_symbols(sentence: str) -> list:
    """
    Extract NSE trading symbols from a full sentence.
    """
    symbols = extract_symbols_from_sentence(sentence)
    return symbols or ["No NSE-listed company found"]

# ---------------------------
# Test
# ---------------------------

query = "Compare of Reliance Industries, Tata COnsultancy  with HAL Stock?"
extract_nse_symbols.invoke(query)


['TCS', 'HAL', 'RELIANCE']

In [64]:
import json
import re
from rapidfuzz import process, fuzz
from langchain_core.tools import tool

# =====================================================
# Normalization (aggressive – for human input)
# =====================================================

def normalize(text: str) -> str:
    text = text.lower()
    text = re.sub(r'[^a-z0-9 ]', ' ', text)
    text = re.sub(r'\b(limited|ltd|private|pvt|india|stock|share|shares)\b', '', text)
    return re.sub(r'\s+', ' ', text).strip()

# =====================================================
# Load NSE data
# =====================================================

with open("all_nse_stocks.json") as f:
    nse_data = json.load(f)

NSE_LOOKUP = [
    {
        "symbol": item["symbol"],
        "norm": normalize(item["companyName"])
    }
    for item in nse_data
]

NSE_NAMES = [c["norm"] for c in NSE_LOOKUP]

# =====================================================
# Alias & abbreviation map (chatbot boost)
# =====================================================

ALIASES = {
    "hal": "HAL",
    "ril": "RELIANCE",
    "reliance india": "RELIANCE",
    "tcs": "TCS",
    "tata consultancy": "TCS",
    "tata consultncy": "TCS",
    "hdfc": "HDFCBANK",
    "l&t": "LT"
}

# =====================================================
# Candidate detection (forgiving)
# =====================================================

def detect_candidates(sentence: str) -> set:
    sentence_norm = normalize(sentence)
    words = sentence_norm.split()
    candidates = set()

    # 1–3 gram spans
    for n in (1, 2, 3):
        for i in range(len(words) - n + 1):
            span = " ".join(words[i:i+n])
            if len(span) >= 4:
                candidates.add(span)

    return candidates

# =====================================================
# Relaxed validation (chatbot mode)
# =====================================================

def token_overlap_ok(company_norm: str, sentence_norm: str) -> bool:
    company_tokens = set(company_norm.split())
    sentence_tokens = set(sentence_norm.split())

    overlap = company_tokens & sentence_tokens
    return len(overlap) / max(len(company_tokens), 1) >= 0.6

# =====================================================
# Symbol matching (safe + forgiving)
# =====================================================

def match_symbols(candidates: set, sentence: str, threshold: int = 80) -> list:
    sentence_norm = normalize(sentence)
    found = set()

    for candidate in candidates:
        match, score, idx = process.extractOne(
            candidate,
            NSE_NAMES,
            scorer=fuzz.token_set_ratio
        )

        if score < threshold:
            continue

        company = NSE_LOOKUP[idx]

        if not token_overlap_ok(company["norm"], sentence_norm):
            continue

        found.add(company["symbol"])

    return sorted(found)

# =====================================================
# Main extractor
# =====================================================

def extract_symbols_from_sentence(sentence: str) -> list:
    sentence_norm = normalize(sentence)
    symbols = set()

    # 1️⃣ Alias resolution first (very important for chat)
    for key, sym in ALIASES.items():
        if key in sentence_norm:
            symbols.add(sym)

    # 2️⃣ Fuzzy dictionary match
    candidates = detect_candidates(sentence)
    symbols.update(match_symbols(candidates, sentence))

    return list(symbols)

# =====================================================
# LangChain tool
# =====================================================

@tool
def extract_nse_symbols(sentence: str) -> list:
    """
    Extract NSE trading symbols from user sentences.
    Forgiving to typos, partial names, and abbreviations.
    """
    symbols = extract_symbols_from_sentence(sentence)
    return symbols or ["No NSE-listed company found"]

# =====================================================
# Test
# =====================================================

query = "Compare of Reliance Industries, Tata COnsultancy with HAL Stock?"
extract_nse_symbols.invoke(query)


['TCS', 'RELIANCE', 'HAL', 'RELCHEMQ']

['RELIANCE', 'RELCHEMQ']

In [None]:
@tool
def llm_extract_company(query: str) -> dict:
    """Extract company mention + intent. Returns JSON: {company, intent}."""
    prompt = (
        "Extract the Indian listed company mention from the user query.\n"
        "Return ONLY JSON: {\"company\":\"...\",\"intent\":\"price|analysis|news|unknown\"}.\n"
        "Rules: company should be short (e.g., 'Larsen and Toubro', 'TCS'); "
        "exclude words like price/stock/share/today/lastprice.\n"
        f"User query: {query}"
    )
    resp = llm.invoke(prompt)
    obj = safe_json_loads(resp.content) or {}
    return {
        "company": (obj.get("company") or "").strip(),
        "intent": (obj.get("intent") or "unknown").strip(),
    }

In [20]:
print("Name: \n", llm_extract_company.name)
print("\nDescription: \n", llm_extract_company.description) 
print("\nArgs: \n", llm_extract_company.args)

Name: 
 llm_extract_company

Description: 
 Extract company mention + intent. Returns JSON: {company, intent}.

Args: 
 {'query': {'title': 'Query', 'type': 'string'}}


In [21]:
query = 'What is the current price of Reliance India?'
pprint(llm_extract_company.invoke(query))

{'company': 'Reliance Industries', 'intent': 'price|analysis|news|unknown'}


In [None]:
query = 'Compare of Reliance India with HAL Stock?'
pprint(llm_extract_company.invoke(query))

{'company': 'Reliance Industries', 'intent': 'analysis'}


### Stock Data

In [15]:
print("Name: \n", get_nse_stock_data.name)
print("\nDescription: \n", get_nse_stock_data.description) 
print("\nArgs: \n", get_nse_stock_data.args)

Name: 
 get_nse_stock_data

Description: 
 Fetch NSE equity quote via nsepython (nse_eq primary, fallback to nse_fno, then secfno lastPrice).

Args: 
 {'symbol': {'title': 'Symbol', 'type': 'string'}}


In [16]:
stock_symbol = "HAL" 
pprint(get_nse_stock_data.invoke(stock_symbol))

{'52w_high': 5165.0,
 '52w_low': 3046.05,
 'change': 284.5,
 'company': 'Hindustan Aeronautics Limited',
 'day_high': 4638.6,
 'day_low': 4361.5,
 'industry': 'Aerospace & Defense',
 'last_update': '28-Jan-2026 16:00:00',
 'open': 4370.0,
 'p_change': 6.542485914683224,
 'pe': 34.5,
 'previous_close': 4348.5,
 'price': 4633.0,
 'raw': {'currentMarketType': 'NM',
         'industryInfo': {'basicIndustry': 'Aerospace & Defense',
                          'industry': 'Aerospace & Defense',
                          'macro': 'Industrials',
                          'sector': 'Capital Goods'},
         'info': {'activeSeries': ['EQ'],
                  'companyName': 'Hindustan Aeronautics Limited',
                  'debtSeries': [],
                  'identifier': 'HALEQN',
                  'industry': 'Aerospace & Defense',
                  'isCASec': False,
                  'isDebtSec': False,
                  'isDelisted': False,
                  'isETFSec': False,
               

## Prompt

In [40]:
SYSTEM_TEXT = (
    "You are an Investment agent who uses tools to get the latest stock data. "
    "Respond ONLY in valid JSON with keys: input, tool_called, output."
)

investor_prompt = ChatPromptTemplate.from_messages([
    ("system", SYSTEM_TEXT),
    ("human", "{query}"),
])


In [41]:
prompt_value = investor_prompt.invoke({"query": "HAL"})   # ChatPromptValue [web:52]
messages = prompt_value.to_messages()
messages

[SystemMessage(content='You are an Investment agent who uses tools to get the latest stock data. Respond ONLY in valid JSON with keys: input, tool_called, output.', additional_kwargs={}, response_metadata={}),
 HumanMessage(content='HAL', additional_kwargs={}, response_metadata={})]

In [39]:
tools = [search_tool, get_nse_stock_data]

llm_with_tools = llm.bind_tools(tools)

In [None]:
def call_model(state: MessagesState):
    resp = model_with_tools.invoke(state["messages"])
    return {"messages": [resp]}

def route(state: MessagesState):
    last = state["messages"][-1]
    return "tools" if getattr(last, "tool_calls", None) else END