In [None]:
import os

import cassio
from dotenv import load_dotenv

# Populate os.environ from a local .env file so the ASTRA_DB_* and GROQ_API_KEY
# lookups in later cells can find their values.
load_dotenv()

In [None]:
# Read the Astra credentials (loaded from .env above). Fail fast with a readable
# message if either is missing, instead of handing None to cassio.init and
# getting an obscure connection error later.
astra_token = os.getenv("ASTRA_DB_APPLICATION_TOKEN")
astra_db_id = os.getenv("ASTRA_DB_ID")
missing_astra_vars = [
    name
    for name, value in (("ASTRA_DB_APPLICATION_TOKEN", astra_token), ("ASTRA_DB_ID", astra_db_id))
    if not value
]
if missing_astra_vars:
    raise RuntimeError(
        f"Missing Astra DB setting(s): {', '.join(missing_astra_vars)}; add them to your .env file."
    )
cassio.init(token=astra_token, database_id=astra_db_id)

In [None]:
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader

In [None]:
# Lilian Weng blog posts that form the vector-store corpus
# (agents, prompt engineering, adversarial attacks on LLMs).
BLOG_POSTS_BASE_URL = "https://lilianweng.github.io/posts/"
urls = [
    BLOG_POSTS_BASE_URL + slug
    for slug in (
        "2023-06-23-agent/",
        "2023-03-15-prompt-engineering/",
        "2023-10-25-adv-attack-llm",
    )
]

In [None]:
# Fetch every page (one request per URL). Each loader returns a list of
# Documents, so flatten them into a single list.
docs = [WebBaseLoader(url).load() for url in urls]
doc_list = [item for sublist in docs for item in sublist]
# Printing doc_list would dump the full text of every page into the notebook;
# a one-line summary is enough to confirm the load worked.
print(f"Loaded {len(doc_list)} documents from {len(urls)} URLs")

# Split into ~500-token chunks (tiktoken-measured) for embedding.
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(chunk_size=500, chunk_overlap=0)
docs_split = text_splitter.split_documents(doc_list)

In [None]:
# Show how many chunks were produced and preview only the first few;
# rendering the whole list floods the output with page text.
print(f"{len(docs_split)} chunks")
docs_split[:3]

In [None]:
from langchain_huggingface import HuggingFaceEmbeddings

# Embedding model shared by indexing (add_documents) and querying (the retriever).
EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"
embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)

In [None]:
# `langchain.vectorstores` is a deprecated re-export of the community package;
# import from langchain_community directly (already used elsewhere in this notebook).
from langchain_community.vectorstores import Cassandra

# session=None / keyspace=None: presumably resolved from the global connection
# configured by cassio.init earlier in the notebook.
astra_vector_store = Cassandra(embedding=embeddings,
                               table_name="qa_mini_demo",
                               session=None,
                               keyspace=None)

In [None]:
import hashlib

from langchain.indexes.vectorstore import VectorStoreIndexWrapper

# Deterministic ids (SHA-256 of source URL + chunk text) make this cell safe to
# re-run. Writes with the same id should overwrite the existing row rather than
# add duplicate chunks, which would otherwise come back as repeated retrieval hits.
chunk_ids = [
    hashlib.sha256(f"{doc.metadata.get('source', '')}\n{doc.page_content}".encode("utf-8")).hexdigest()
    for doc in docs_split
]
astra_vector_store.add_documents(docs_split, ids=chunk_ids)
# These are text chunks, not headlines.
print("Inserted %i chunks." % len(docs_split))
astra_vector_index = VectorStoreIndexWrapper(vectorstore=astra_vector_store)

In [None]:
# Retriever over the Astra table, followed by a quick smoke-test query
# (the result is displayed as the cell output).
retriver = astra_vector_store.as_retriever()
smoke_test_query = "What is agent"
retriver.invoke(smoke_test_query)

In [None]:
from typing import Literal
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field

In [None]:
class RouteQuery(BaseModel):
    """Route a user query to the most relevant datasource."""

    # Limited to the two values route_question() branches on. The description
    # is part of the schema the LLM sees, so its wording matters.
    datasource: Literal["vectorstore", "wiki_search"] = Field(
        ...,
        description="Given a user question, choose to route it to wikipedia or vectorstore.",
    )

In [None]:
from langchain_groq import ChatGroq
import os

groq_api_key = os.getenv("GROQ_API_KEY")
if not groq_api_key:
    raise RuntimeError("GROQ_API_KEY is not set; add it to your .env file.")

# Use Groq's canonical (lowercase) model id rather than relying on the API
# tolerating a mixed-case spelling.
llm = ChatGroq(groq_api_key=groq_api_key, model_name="llama-3.3-70b-versatile")
llm

In [None]:
# Have the LLM reply with a RouteQuery instance instead of free text.
structured_llm_router = llm.with_structured_output(schema=RouteQuery)

In [None]:
# Routing instructions. The topics listed must match the documents loaded from
# `urls`, and the fallback name must match the RouteQuery literal "wiki_search".
system = """You are an expert at routing a user question to a vectorstore or wikipedia.
The vectorstore contains documents related to agents, prompt engineering, and adversarial attacks.
Use the vectorstore for questions on these topics. Otherwise, use wiki_search."""

route_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("human", "{question}"),
    ]
)

# prompt -> structured LLM: invoke({"question": ...}) returns a RouteQuery.
question_router = route_prompt | structured_llm_router

In [None]:
# An in-corpus topic: expected to route to the vectorstore.
agent_route = question_router.invoke({"question": "What is agent ?"})
print(agent_route)

In [None]:
# An off-corpus topic: expected to route to wiki_search.
celebrity_route = question_router.invoke({"question": "Who is sharukh khan ?"})
print(celebrity_route)

In [None]:
from langchain_community.utilities import WikipediaAPIWrapper
from langchain_community.tools import WikipediaQueryRun

# Keep Wikipedia answers short: only the top matching page, capped at 200 characters.
WIKI_TOP_K = 1
WIKI_MAX_CHARS = 200

api_wrapper = WikipediaAPIWrapper(top_k_results=WIKI_TOP_K, doc_content_chars_max=WIKI_MAX_CHARS)
wiki = WikipediaQueryRun(api_wrapper=api_wrapper)

In [None]:
from typing import List
from typing_extensions import TypedDict

from langchain_core.documents import Document


class GraphState(TypedDict):
    """
    Represents the state of our graph.

    Attributes:
        question: the user's question.
        generation: LLM generation (declared, but not set by any node in this notebook).
        documents: Document objects produced by the retrieve or wiki_search node
            (previously annotated as List[str], which did not match what the nodes store).
    """

    question: str
    generation: str
    documents: List[Document]

In [None]:
from langchain.schema import Document

def retrive(state):
    """
    Graph node: fetch documents relevant to the question from the vector store.

    Args:
        state (dict): Current graph state; only "question" is read.

    Returns:
        dict: {"documents": <retriever results>, "question": <unchanged question>}.
    """
    print("----Retrieve----")
    query = state["question"]
    return {"documents": retriver.invoke(query), "question": query}

In [None]:
def wiki_search(state):
    """
    Graph node: look up the user's question on Wikipedia.

    Args:
        state (dict): Current graph state; only "question" is read.

    Returns:
        dict: {"documents": [Document], "question": question}. The Wikipedia text
        is wrapped in a single-element list so "documents" has the same list
        shape as the retrieve node's output. It replaces any previous value;
        nothing is appended.
    """
    print("----wikipedia----")

    question = state["question"]
    print(question)

    # The tool returns the page summary as text; wrap it as a Document.
    wiki_text = wiki.invoke({"query": question})

    return {"documents": [Document(page_content=wiki_text)], "question": question}

In [None]:
def route_question(state):
    """
    Conditional-edge function: choose the first node to run for a question.

    Args:
        state (dict): Current graph state; only "question" is read.

    Returns:
        str | None: "wiki_search" or "vectorstore" (the keys of the conditional
        edge mapping), or None if the router reports any other datasource.
    """
    print("--- ROUTE QUESTION ---")
    choice = question_router.invoke({"question": state["question"]}).datasource

    banners = {
        "wiki_search": "---Route Question to Wiki Search---",
        "vectorstore": "---ROUTE Question to RAG---",
    }
    banner = banners.get(choice)
    if banner is None:
        return None
    print(banner)
    return choice

In [None]:
from langgraph.graph import END, StateGraph, START

In [None]:
workflow = StateGraph(GraphState)

# Two leaf nodes: the Wikipedia lookup and the vector-store retrieval.
for node_name, node_fn in (("wiki_search", wiki_search), ("retrieve", retrive)):
    workflow.add_node(node_name, node_fn)

# route_question's return value selects which leaf runs first.
workflow.add_conditional_edges(
    START,
    route_question,
    {"wiki_search": "wiki_search", "vectorstore": "retrieve"},
)

# Each branch produces the final answer on its own, so both go straight to END.
for leaf_name in ("retrieve", "wiki_search"):
    workflow.add_edge(leaf_name, END)

app = workflow.compile()

In [None]:
from IPython.display import Image, display

# The diagram is optional. Rendering to PNG can fail (e.g. the default renderer
# needs network access), so report the reason instead of silently hiding it.
try:
    display(Image(app.get_graph().draw_mermaid_png()))
except Exception as exc:
    print(f"Could not render graph diagram ({type(exc).__name__}: {exc})")

In [None]:
from pprint import pprint

inputs = {
    "question": "What is agent ?"
}

# Stream the graph node by node. Keep the last node's state update explicitly
# instead of relying on the loop variable leaking out of the loop.
final_update = None
for output in app.stream(inputs):
    for key, value in output.items():
        pprint(f"Node '{key}':")
        final_update = value
    # print (not pprint): pprint would show the quoted repr '\n---\n'
    # instead of a visual separator.
    print("\n---\n")

if final_update is None:
    raise RuntimeError("The graph produced no output for this question.")

documents = final_update["documents"]
# The retrieve node returns a list; also tolerate a single bare Document.
if not isinstance(documents, list):
    documents = [documents]

if documents:
    # WebBaseLoader metadata includes a page "description" when the page has one;
    # fall back to the whole metadata dict otherwise.
    metadata = documents[0].metadata
    pprint(metadata.get("description", metadata))
else:
    print("No documents were returned.")

In [None]:
from pprint import pprint

inputs = {
    "question": "Avengers"
}

# Stream the graph node by node and keep the last node's state update explicitly.
final_update = None
for output in app.stream(inputs):
    for key, value in output.items():
        pprint(f"Node '{key}':")
        final_update = value
    # print (not pprint) so the separator renders as a line, not a quoted repr.
    print("\n---\n")

# This question is expected to be routed to wiki_search, so this is presumably
# the Wikipedia summary.
if final_update is None:
    print("The graph produced no output for this question.")
else:
    pprint(final_update["documents"])