<a href="https://colab.research.google.com/github/aarushikankanmeli97/Generative_AI/blob/main/Multi_AI_Agent_with_RAG.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### **Multi AI Agent RAG with LangGraph and AstraDB**

This project builds an LLM application consisting of a router node, a Wikipedia search node, and a vector-DB search node. The application takes a user query, and the router decides whether to answer it with a Wikipedia search or a vector-DB search. The vector DB contains content from several web pages.

The database used here is *AstraDB*.

The retrieved documents can then be passed to the LLM along with a task-specific prompt to produce the final answer (this generation step is not implemented in this notebook).


In [None]:
# Install the notebook's dependencies. Among them:
# - cassio is the client used to connect to Astra DB.
# - tiktoken is used by the token-based text splitter.
# - wikipedia backs the Wikipedia search tool.
!pip install -q langchain langchain-community langchain-groq langchainhub langchain-huggingface tiktoken langgraph cassio wikipedia

In [None]:
import os

import cassio

## Connection to Astra DB
# Credentials are read from the environment so real secrets never get saved
# inside the notebook. The "****" fallback is only a placeholder: set the
# environment variables (or replace it) before running this cell.
ASTRA_DB_APPLICATION_TOKEN = os.environ.get("ASTRA_DB_APPLICATION_TOKEN", "****")
ASTRA_DB_ID = os.environ.get("ASTRA_DB_ID", "****")

# Register the Astra DB connection with cassio. The Cassandra vector store
# created later passes session=None / keyspace=None and presumably relies on
# this global connection.
cassio.init(
    token=ASTRA_DB_APPLICATION_TOKEN,
    database_id=ASTRA_DB_ID,
)

In [None]:
## Build Index

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader

# Web pages to index.
urls = [
    "https://lilianweng.github.io/posts/2023-06-23-agent/",
    "https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/",
    "https://lilianweng.github.io/posts/2023-10-25-adv-attack-llm/",
]

## Load each page; every loader yields a list of Documents.
docs = [WebBaseLoader(page_url).load() for page_url in urls]

# Flatten the per-page lists into a single list of Documents.
doc_list = []
for loaded_pages in docs:
    doc_list.extend(loaded_pages)
print(doc_list)

## Split into chunks of up to 1000 tiktoken tokens, overlapping by 200 tokens.
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    chunk_overlap=200,
    chunk_size=1000,
)
docs_split = text_splitter.split_documents(doc_list)

In [None]:
# Inspect the chunks. This is the cell's last expression, so Jupyter renders it.
docs_split

In [None]:
from langchain_huggingface import HuggingFaceEmbeddings
# Sentence-transformers MiniLM embedding model. Use the canonical Hugging Face
# Hub id (namespace + lowercase "v2") rather than relying on name resolution.
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")


In [None]:
# The Cassandra vector store lives in langchain_community. The old
# langchain.vectorstores path is a deprecated re-export.
from langchain_community.vectorstores import Cassandra

# Vector store backed by the Astra "demo" table. session/keyspace are None so
# the connection registered earlier via cassio.init() is used.
astra_vector_store = Cassandra(
    embedding=embeddings,
    table_name="demo",
    session=None,
    keyspace=None,
)

In [None]:
from langchain.indexes.vectorstore import VectorStoreIndexWrapper
# Embed every chunk and write it to the Astra table. add_documents returns the
# ids of the stored entries, so the reported count reflects what was inserted.
inserted_ids = astra_vector_store.add_documents(docs_split)
print(f"Inserted {len(inserted_ids)} chunks.")
# Convenience wrapper offering query-style access to the same store.
astra_vector_index = VectorStoreIndexWrapper(vectorstore=astra_vector_store)

In [None]:
# Retriever over the Astra vector store, with default search settings.
retriever = astra_vector_store.as_retriever()
# Smoke test. As the cell's last expression, the retrieved documents are displayed.
retriever.invoke("What is agent?")

In [None]:
## Langgraph application

from typing import Literal
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field

In [None]:
# Data Model
class RouteQuery(BaseModel):
    """Route a user query to the most relevant datasource."""

    # The router LLM must pick exactly one of these values. They are the same
    # strings route_question returns to select the next graph node.
    datasource: Literal["vectorstore", "wiki_search"] = Field(
        ...,
        description="Given a user question choose to route it to wikipedia or a vectorstore.",
    )

In [None]:
from langchain_groq import ChatGroq
from google.colab import userdata
import os

# userdata.get() takes the *name* of a Colab secret (configured in the Secrets
# panel), not the API key itself. An environment variable with the same name
# takes precedence, so the key can also be supplied without Colab secrets.
GROQ_SECRET_NAME = "GROQ_API_KEY"
groq_api_key = os.environ.get(GROQ_SECRET_NAME) or userdata.get(GROQ_SECRET_NAME)


In [None]:
# Groq-hosted chat model. The structured router below is built on top of it.
llm = ChatGroq(groq_api_key=groq_api_key, model_name="llama-3.3-70b-versatile")
# Smoke test that the key and model name work. As the cell's last expression,
# the model's reply is displayed.
llm.invoke("Hi")

In [None]:
# Wrap the LLM so each reply is parsed into a RouteQuery instance.
structured_llm_router = llm.with_structured_output(schema=RouteQuery)

In [None]:
## Prompt
# The option names in the instructions match RouteQuery.datasource's allowed
# values ("vectorstore" / "wiki_search"), so the model is never asked to pick
# a label that does not exist in the output schema.
system = """You are an expert at routing a user question to a vectorstore or wikipedia.
The vectorstore contains documents related to agents, prompt engineering, and adversarial attacks.
Use the vectorstore for questions on these topics. Otherwise, use wiki_search."""
route_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("human", "{question}"),
    ]
)

# Prompt -> structured LLM. Invoking with {"question": ...} yields a RouteQuery.
question_router = route_prompt | structured_llm_router

In [None]:
# Topic covered by the indexed pages: should be routed to "vectorstore".
agent_question = {"question": "What is agent?"}
print(question_router.invoke(agent_question))

In [None]:
# General-knowledge question: should be routed to "wiki_search".
moon_question = {"question": "Who was the first person to walk on moon?"}
print(question_router.invoke(moon_question))

In [None]:
from langchain_community.utilities import WikipediaAPIWrapper
from langchain_community.tools import WikipediaQueryRun

# Wikipedia tool limited to the single top result, truncated to 500
# characters, to keep the fetched context short.
api_wrapper = WikipediaAPIWrapper(doc_content_chars_max=500, top_k_results=1)
wiki = WikipediaQueryRun(api_wrapper=api_wrapper)

In [None]:
# Smoke test of the Wikipedia tool. As the cell's last expression, the
# returned text is displayed.
wiki.run("What was Apollo 11 mission?")

In [None]:
## AI Agent application using LangGraph

from typing import List
from typing_extensions import TypedDict

from langchain_core.documents import Document


class GraphState(TypedDict):
    """
    Represents the state of our graph.

    Attributes:
        question: the user's question
        generation: LLM generation (reserved for a generation step)
        documents: Documents produced by the retrieve / wiki_search nodes
    """

    question: str
    generation: str
    # Both nodes store langchain Document objects here, not plain strings.
    documents: List[Document]

In [None]:
from langchain.schema import Document

def retrieve(state):
    """
    Fetch documents for the current question from the vector-store retriever.

    Args:
        state (dict): The current graph state; must contain "question".

    Returns:
        dict: State update with the retrieved documents under "documents"
        and the unchanged question under "question".
    """
    print("----Retrieve----")
    query = state["question"]

    # Retrieval
    retrieved = retriever.invoke(query)
    return {"documents": retrieved, "question": query}

In [None]:
def wiki_search(state):
    """
    Look up the question on Wikipedia.

    The question is sent to the Wikipedia tool unchanged (no rephrasing). The
    tool's text result is wrapped in a Document and returned as a one-element
    list, matching the list-of-Documents shape produced by the retrieve node.

    Args:
        state (dict): The current graph state; must contain "question".

    Returns:
        dict: State update whose "documents" is replaced (not appended to)
        by ``[Document(page_content=<wikipedia text>)]``, plus the unchanged
        "question".
    """
    print("---wikipedia---")

    question = state["question"]
    print(question)

    # The Wikipedia tool returns its result as a single text string.
    wiki_text = wiki.invoke({"query": question})

    return {"documents": [Document(page_content=wiki_text)], "question": question}

In [None]:
### Edges ###


def route_question(state):
    """
    Route question to wiki search or RAG.

    Args:
        state (dict): The current graph state; must contain "question".

    Returns:
        str: "wiki_search" or "vectorstore", used as the key into the
        conditional edge's path map to pick the next node.

    Raises:
        ValueError: if the router produced no result or a datasource other
        than the two supported ones. The original code silently returned
        None, which only failed later with an unclear path-map error.
    """
    print("---ROUTE QUESTION---")
    question = state["question"]
    source = question_router.invoke({"question": question})
    # The structured-output parser may yield None if the model made no choice.
    datasource = getattr(source, "datasource", None)
    if datasource == "wiki_search":
        print("---ROUTE QUESTION TO Wiki SEARCH---")
        return "wiki_search"
    if datasource == "vectorstore":
        print("---ROUTE QUESTION TO RAG---")
        return "vectorstore"
    raise ValueError(f"Router returned an unsupported datasource: {source!r}")

In [None]:
### Edges ###


# NOTE: this cell redefines route_question. The graph uses whichever
# definition is bound when the graph-building cell runs.
def route_question(state):
    """
    Route question to wiki search or RAG.

    Args:
        state (dict): The current graph state; must contain "question".

    Returns:
        str: "wiki_search" or "vectorstore", used as the key into the
        conditional edge's path map to pick the next node.

    Raises:
        ValueError: if the router produced no result or a datasource other
        than the two supported ones. The original code silently returned
        None, which only failed later with an unclear path-map error.
    """
    print("---ROUTE QUESTION---")
    question = state["question"]
    source = question_router.invoke({"question": question})
    # The structured-output parser may yield None if the model made no choice.
    datasource = getattr(source, "datasource", None)
    if datasource == "wiki_search":
        print("---ROUTE QUESTION TO Wiki SEARCH---")
        return "wiki_search"
    if datasource == "vectorstore":
        print("---ROUTE QUESTION TO RAG---")
        return "vectorstore"
    raise ValueError(f"Router returned an unsupported datasource: {source!r}")

In [None]:
from langgraph.graph import END, StateGraph, START

workflow = StateGraph(GraphState)

# Nodes: one per datasource.
workflow.add_node("wiki_search", wiki_search)  # Wikipedia lookup
workflow.add_node("retrieve", retrieve)  # vector-store retrieval

# Entry point: route_question's return value selects the first node to run.
route_map = {
    "wiki_search": "wiki_search",
    "vectorstore": "retrieve",
}
workflow.add_conditional_edges(START, route_question, route_map)

# Both branches finish the run right after fetching documents.
for terminal_node in ("retrieve", "wiki_search"):
    workflow.add_edge(terminal_node, END)

# Compile
app = workflow.compile()

In [None]:
from IPython.display import Image, display

try:
    # Rendering the Mermaid PNG needs extra dependencies, so it is optional.
    display(Image(app.get_graph().draw_mermaid_png()))
except Exception as exc:
    # Best-effort: report why the diagram was skipped instead of hiding it.
    print(f"Graph diagram not rendered: {exc}")

In [None]:
from pprint import pprint

# Run
inputs = {
    "question": "What is agent?"
}
# Track the last node's state update explicitly instead of relying on the
# loop variable leaking out of the loop (undefined if nothing is streamed).
final_state = None
for output in app.stream(inputs):
    # Each streamed item maps the node that just ran to the update it returned.
    for key, value in output.items():
        pprint(f"Node '{key}':")
        final_state = value
    pprint("\n---\n")

# Summarize the top document: its "description" metadata when present,
# otherwise its text. Document.metadata is a plain dict, so .get() avoids a
# KeyError and the deprecated pydantic .dict() call is unnecessary.
retrieved = final_state.get("documents") if final_state else None
if isinstance(retrieved, list):
    top_doc = retrieved[0] if retrieved else None
else:
    # Tolerate a node that stores a single Document rather than a list.
    top_doc = retrieved
if top_doc is None:
    pprint("No documents were returned.")
else:
    pprint(top_doc.metadata.get("description", top_doc.page_content))

In [None]:
from pprint import pprint

# Run
inputs = {
    "question": "Avengers"
}
# Keep the last node's state update explicitly. The leaked loop variable
# would be undefined (or stale from a previous cell) if nothing were streamed.
final_state = None
for output in app.stream(inputs):
    # Each streamed item maps the node that just ran to the update it returned.
    for key, value in output.items():
        pprint(f"Node '{key}':")
        final_state = value
    pprint("\n---\n")

# Final documents (presumably from the Wikipedia node for this query).
if final_state is None:
    pprint("The graph produced no output.")
else:
    pprint(final_state.get("documents"))