In [2]:
# Import libraries and set up environment
import json
import asyncio
import logging
from dotenv import load_dotenv
import os
from langchain_groq import ChatGroq
from langchain.tools import StructuredTool
from langgraph.graph import StateGraph, END
from typing import TypedDict, Annotated, List
import operator
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage
from mcp.client.streamable_http import streamablehttp_client
from mcp.client.session import ClientSession

# Configure logging
logging.basicConfig(filename="app.log", level=logging.INFO)
load_dotenv()

# Verify environment variables
assert os.getenv("GROQ_API_KEY"), "GROQ_API_KEY not set"
# assert os.getenv("HOROSCOPE_ENDPOINT"), "HOROSCOPE_ENDPOINT not set"


In [3]:
# Load MCP tools from JSON configuration
async def load_mcp_tools(config_path: str):
    with open(config_path, 'r') as f:
        config = json.load(f)
    
    all_tools = []
    for server_name, server_config in config["mcpServers"].items():
        async with streamablehttp_client(server_config["url"]) as (read, write):
            async with ClientSession(read, write) as session:
                await session.initialize()
                tools = await session.list_tools()
                for tool in tools:
                    async def call_tool(params: dict, session=session, tool_name=tool["name"]):
                        return await session.call_tool(tool_name, params)
                    schema = {
                        "type": "object",
                        "properties": {
                            "query": {"type": "string"},
                            "zodiac_sign": {"type": "string"},
                            "horoscope_type": {"type": "string", "enum": ["DAILY", "MONTHLY"]}
                        },
                        "required": ["zodiac_sign"] if tool["name"] == "get_horoscope" else ["query"]
                    }
                    all_tools.append(
                        StructuredTool.from_function(
                            func=call_tool,
                            name=tool["name"],
                            description=tool["description"],
                            args_schema=schema
                        )
                    )
    return all_tools


In [10]:
# Define LangGraph state and workflow
class GraphState(TypedDict):
    messages: Annotated[List[HumanMessage | AIMessage | ToolMessage], operator.add]
    tools_called: List[str]

async def setup_llm_and_tools():
    tools = await load_mcp_tools("mcp_config.json")
    llm = ChatGroq(model="llama3-8b-8192", api_key=os.getenv("GROQ_API_KEY"))
    return llm.bind_tools(tools), tools

async def llm_node(state: GraphState, llm):
    logging.info(f"Processing messages: {[msg.content for msg in state['messages']]}")
    response = await llm.ainvoke(state["messages"])
    tools_called = state.get("tools_called", []) + [call["name"] for call in response.tool_calls]
    return {"messages": [response], "tools_called": tools_called}

async def tool_node(state: GraphState, tools):
    tool_calls = state["messages"][-1].tool_calls
    results = []
    for call in tool_calls:
        tool = next(t for t in tools if t.name == call["name"])
        logging.info(f"Invoking tool: {call['name']} with args: {call['args']}")
        result = await tool.ainvoke(call["args"])
        results.append(ToolMessage(content=str(result), tool_call_id=call["id"]))
    return {"messages": results}

def should_continue(state: GraphState):
    last_message = state["messages"][-1]
    return "continue" if hasattr(last_message, "tool_calls") and last_message.tool_calls else "end"

async def create_workflow():
    llm, tools = await setup_llm_and_tools()
    workflow = StateGraph(GraphState)
    workflow.add_node("llm", lambda state: llm_node(state, llm))
    workflow.add_node("tools", lambda state: tool_node(state, tools))
    workflow.set_entry_point("llm")
    workflow.add_conditional_edges("llm", should_continue, {"continue": "tools", "end": END})
    workflow.add_edge("tools", "llm")
    app = workflow.compile()
    print(app)
    return workflow.compile(), llm, tools


In [11]:
# Test the workflow with sample prompts
async def run_prompt(prompt: str):
    app, _, _ = await create_workflow()
    result = await app.ainvoke({"messages": [HumanMessage(content=prompt)], "tools_called": []})
    logging.info(f"Tools called: {result['tools_called']}")
    return result["messages"][-1].content

# Run prompts (Jupyter handles async automatically with IPython)
prompts = [
    "What's the daily horoscope for Virgo?",
    "What's the monthly horoscope for Leo?",
    "Search for recent AI news"
]
for prompt in prompts:
    result = await run_prompt(prompt)
    print(f"Prompt: {prompt}\nResponse: {result}\n")


  + Exception Group Traceback (most recent call last):
  |   File "/Users/saibhargavrallapalli/Documents/Git/mcp_integration/.venv/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3670, in run_code
  |     await eval(code_obj, self.user_global_ns, self.user_ns)
  |   File "/var/folders/n2/2s0tjc891jx80m_1lwcfhm_40000gn/T/ipykernel_89285/2450133448.py", line 15, in <module>
  |     result = await run_prompt(prompt)
  |              ^^^^^^^^^^^^^^^^^^^^^^^^
  |   File "/var/folders/n2/2s0tjc891jx80m_1lwcfhm_40000gn/T/ipykernel_89285/2450133448.py", line 3, in run_prompt
  |     app, _, _ = await create_workflow()
  |                 ^^^^^^^^^^^^^^^^^^^^^^^
  |   File "/var/folders/n2/2s0tjc891jx80m_1lwcfhm_40000gn/T/ipykernel_89285/3753619681.py", line 32, in create_workflow
  |     llm, tools = await setup_llm_and_tools()
  |                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  |   File "/var/folders/n2/2s0tjc891jx80m_1lwcfhm_40000gn/T/ipykernel_89285/3753619681.py", line 

In [None]:
app