In [None]:
%%capture --no-stderr
%pip install --quiet -U langgraph google-genai==0.3.0 langchain_community langchain-core==0.3.21 langchain beautifulsoup4 tiktoken langchain-pinecone pinecone-notebooks sentence-transformers

In [7]:
import hashlib
import os
from typing import Literal

from IPython.display import Image, display
from google.genai import Client
from google.genai.types import FunctionDeclaration
from google.genai.types import Tool, GenerateContentConfig, GoogleSearch
from langchain import hub
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.tools.retriever import create_retriever_tool
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores import Chroma
from langchain_core.messages import AIMessage
from langchain_core.messages import HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_pinecone import PineconeVectorStore as lang_pinecone
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import StateGraph, MessagesState, START, END
from langgraph.prebuilt import ToolNode
from langgraph.prebuilt import tools_condition
from pydantic import BaseModel, Field
from langchain_pinecone import PineconeVectorStore as lang_pinecone

In [None]:
# Pinecone settings come from the environment. Fail fast with a clear message
# instead of letting a missing value surface later as an obscure Pinecone error.
index_name = os.getenv('INDEX_NAME')
pinecone_api_key = os.getenv('PINECONE_API_KEY')

_missing_env = [
    env_name
    for env_name, env_value in (("INDEX_NAME", index_name), ("PINECONE_API_KEY", pinecone_api_key))
    if not env_value
]
if _missing_env:
    raise RuntimeError(
        f"Missing required environment variable(s): {', '.join(_missing_env)}"
    )

# Sentence-transformers embedding model used to embed chunks for the Pinecone index.
embed_model = HuggingFaceEmbeddings(model_name='BAAI/bge-small-en-v1.5')

In [9]:
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")

# google-genai client used for the Gemini calls in this notebook.
client = Client(api_key=GOOGLE_API_KEY)

# In-memory checkpointer; it is handed to workflow.compile() as the graph's checkpointer.
memory: MemorySaver = MemorySaver()

# Run config selecting conversation thread "1" for checkpointed graph runs.
config = dict(configurable=dict(thread_id="1"))

In [None]:
urls = [
    "https://huggingface.co/datasets/ruslanmv/ai-medical-chatbot",
]

# Load every page and flatten the per-URL document lists into a single list.
docs = [WebBaseLoader(url).load() for url in urls]
docs_list = [item for sublist in docs for item in sublist]

# Token-based (tiktoken) splitting into small chunks with 50% overlap.
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    chunk_size=100, chunk_overlap=50
)
doc_splits = text_splitter.split_documents(docs_list)

# Deterministic vector IDs (source + chunk position + content). Re-running this
# cell therefore overwrites the same Pinecone records instead of appending
# duplicate vectors under fresh random IDs.
doc_ids = [
    hashlib.sha256(
        f"{split.metadata.get('source', '')}\x00{position}\x00{split.page_content}".encode("utf-8")
    ).hexdigest()
    for position, split in enumerate(doc_splits)
]

# Add to vectorDB
vectorstore = lang_pinecone.from_documents(
    doc_splits,
    embed_model,
    index_name=index_name,
    ids=doc_ids,
)

retriever = vectorstore.as_retriever()

In [6]:
# Expose the Pinecone retriever to the agent as a LangChain tool. The LLM sees
# the name and description when deciding whether to call the tool, so both
# describe what is actually indexed: the ruslanmv/ai-medical-chatbot dataset page.
retriever_tool = create_retriever_tool(
    retriever,
    "retrieve_medical_chatbot_docs",
    "Search and return information from the indexed Hugging Face dataset page "
    "https://huggingface.co/datasets/ruslanmv/ai-medical-chatbot (an AI medical chatbot dataset).",
)

# Gemini's built-in Google Search grounding tool.
google_search_tool = Tool(
    google_search = GoogleSearch()
)

tools = [retriever_tool, google_search_tool]

In [27]:
# One-off Gemini call grounded with Google Search, as a smoke test of the client.
# NOTE: despite its name, `model` holds the GenerateContentResponse returned by
# this call (read the answer via `model.text`); it is not a chat model.
model = client.models.generate_content(
    model="gemini-2.0-flash-exp",
    contents="what is the current weather in karachi?",
    config=GenerateContentConfig(
        response_modalities=["TEXT"],
        tools=[google_search_tool],
    ),
)


class State(MessagesState):
    """Graph state: only the accumulated chat messages inherited from MessagesState."""

In [18]:
### Edges

def grade_documents(state) -> Literal["generate", "rewrite"]:
    """
    Determines whether the retrieved documents are relevant to the question.

    The original user question (first message) and the retriever output (last
    message, produced by the ``retrieve`` ToolNode) are sent to Gemini. Gemini
    is asked for a JSON verdict matching the ``grade`` schema.

    Args:
        state (messages): The current graph state; ``state["messages"]`` must be non-empty.

    Returns:
        str: "generate" if the documents are judged relevant, otherwise "rewrite".
    """
    print("---CHECK RELEVANCE---")

    class grade(BaseModel):
        """Binary score for relevance check."""
        binary_score: str = Field(description="Relevance score 'yes' or 'no'")

    messages = state["messages"]
    question = messages[0].content
    docs = messages[-1].content

    prompt = f"""You are a grader assessing relevance of a retrieved document to a user question.
Here is the retrieved document:

{docs}

Here is the user question: {question}
If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant.
Give a binary score 'yes' or 'no' to indicate whether the document is relevant to the question."""

    # `model` elsewhere in this notebook is a response object, not a chat model.
    # So call the client directly and request JSON constrained to the `grade` schema.
    response = client.models.generate_content(
        model="gemini-2.0-flash",
        contents=prompt,
        config=GenerateContentConfig(
            temperature=0,
            response_mime_type="application/json",
            response_schema=grade,
        ),
    )
    score = grade.model_validate_json(response.text).binary_score.strip().lower()

    if score == "yes":
        print("---DECISION: DOCS RELEVANT---")
        return "generate"

    print("---DECISION: DOCS NOT RELEVANT---")
    return "rewrite"

In [19]:
### Nodes

def _format_for_transcript(message):
    """Render one LangChain message as a 'role: text' entry for a Gemini prompt."""
    text = message.content if isinstance(message.content, str) else str(message.content)
    tool_calls = getattr(message, "tool_calls", None)
    if tool_calls:
        calls = ", ".join(f"{call['name']}({call['args']})" for call in tool_calls)
        text = f"{text}\n[tool calls: {calls}]".strip()
    return f"{message.type}: {text}"


def agent(state):
    """
    Invokes the agent model to generate a response based on the current state. Given
    the question, it will decide to retrieve using the retriever tool, or simply end.

    The conversation is flattened into a text transcript and sent to Gemini, with
    the retriever exposed as a function declaration. Function calls in Gemini's
    reply become LangChain ``tool_calls`` on an ``AIMessage``, so
    ``tools_condition`` can route them to the ``retrieve`` ToolNode.

    Args:
        state (messages): The current state

    Returns:
        dict: The updated state with the agent response (an ``AIMessage``) appended to messages
    """
    print("---CALL AGENT---")
    messages = state["messages"]

    transcript = "\n\n".join(_format_for_transcript(message) for message in messages)

    # Tools built by create_retriever_tool take a single string argument named "query".
    retriever_declaration = FunctionDeclaration(
        name=retriever_tool.name,
        description=retriever_tool.description,
        parameters={
            "type": "OBJECT",
            "properties": {
                "query": {"type": "STRING", "description": "query to look up in retriever"},
            },
            "required": ["query"],
        },
    )

    # Gemini 2.0 Flash Model invocation. Sampling options belong in the config;
    # generate_content has no `temperature`/`streaming` keyword arguments.
    # Only the retriever is declared, because the graph can execute only
    # retriever calls (ToolNode([retriever_tool])).
    # NOTE(review): Gemini 2.0 generate_content is believed not to accept
    # Google Search grounding together with function declarations — confirm
    # before re-adding google_search_tool here.
    response = client.models.generate_content(
        model="gemini-2.0-flash",
        contents=transcript,
        config=GenerateContentConfig(
            temperature=0,
            tools=[Tool(function_declarations=[retriever_declaration])],
            response_modalities=["TEXT"],
        ),
    )

    text_parts = []
    tool_calls = []
    candidates = response.candidates or []
    content = candidates[0].content if candidates else None
    parts = (content.parts or []) if content else []
    for index, part in enumerate(parts):
        if part.function_call:
            tool_calls.append(
                {
                    "name": part.function_call.name,
                    "args": dict(part.function_call.args or {}),
                    "id": f"call_{index}",
                    "type": "tool_call",
                }
            )
        elif part.text:
            text_parts.append(part.text)

    return {"messages": [AIMessage(content="".join(text_parts), tool_calls=tool_calls)]}

In [20]:
def rewrite(state):
    """
    Transform the query to produce a better question.

    Args:
        state (messages): The current state; ``messages[0]`` is taken as the original question.

    Returns:
        dict: The updated state with the re-phrased question (an ``AIMessage``)
        appended to messages.
    """
    print("---TRANSFORM QUERY---")
    messages = state["messages"]
    question = messages[0].content

    prompt = f""" \n
    Look at the input and try to reason about the underlying semantic intent / meaning. \n
    Here is the initial question:
    \n ------- \n
    {question}
    \n ------- \n
    Formulate an improved question: """

    # `model` elsewhere in this notebook is a GenerateContentResponse (it has no
    # `.invoke`), so call Gemini through the client.
    response = client.models.generate_content(
        model="gemini-2.0-flash",
        contents=prompt,
        config=GenerateContentConfig(temperature=0),
    )

    return {"messages": [AIMessage(content=response.text or "")]}

In [21]:
def generate(state):
    """
    Generate an answer to the original question from the retrieved documents.

    Args:
        state (messages): The current state. ``messages[0]`` is taken as the
            user question and ``messages[-1]`` (the retriever output) as the context.

    Returns:
         dict: The updated state with the generated answer (an ``AIMessage``)
         appended to messages.
    """
    print("---GENERATE---")
    messages = state["messages"]
    question = messages[0].content
    docs = messages[-1].content

    # RAG prompt from the LangChain Hub, filled with `context` and `question`.
    prompt = hub.pull("rlm/rag-prompt")
    prompt_text = prompt.invoke({"context": docs, "question": question}).to_string()

    # Send the rendered prompt to Gemini. The prompt alone is not an answer.
    response = client.models.generate_content(
        model="gemini-2.0-flash",
        contents=prompt_text,
        config=GenerateContentConfig(temperature=0),
    )

    return {"messages": [AIMessage(content=response.text or "")]}

In [22]:
# Assemble the agentic-RAG workflow over the message-based State.
workflow = StateGraph(State)

# Nodes: the agent decides whether to retrieve, the ToolNode executes the
# retriever call, and rewrite/generate handle the two grading outcomes.
retrieve = ToolNode([retriever_tool])
workflow.add_node("agent", agent)
workflow.add_node("retrieve", retrieve)
workflow.add_node("rewrite", rewrite)
workflow.add_node("generate", generate)

# Every run starts at the agent.
workflow.add_edge(START, "agent")

# tools_condition's "tools" outcome goes to the retriever; its END outcome finishes the run.
workflow.add_conditional_edges(
    "agent",
    tools_condition,
    {"tools": "retrieve", END: END},
)

# After retrieval, grade_documents picks "generate" or "rewrite"; the possible
# destinations come from its Literal return annotation.
workflow.add_conditional_edges("retrieve", grade_documents)

# A rewritten question loops back to the agent; a generated answer ends the run.
workflow.add_edge("rewrite", "agent")
workflow.add_edge("generate", END)

# Compile with the in-memory checkpointer so state is checkpointed per thread_id.
graph = workflow.compile(checkpointer=memory)

In [None]:
# Render the compiled graph's topology inline as a Mermaid-generated PNG.
# NOTE(review): draw_mermaid_png() renders through a remote service by default,
# so this cell likely needs network access — confirm for offline runs.
graph_view = graph.get_graph()
display(Image(data=graph_view.draw_mermaid_png()))