# In [None]:
#!/usr/bin/env python
# coding: utf-8

import cassio
import os
from dotenv import load_dotenv

# Pull settings from a local .env file into the process environment.
load_dotenv()

ASTRA_DB_TOKEN = os.getenv('ASTRA_DB_TOKEN')
ASTRA_DB_ID = os.getenv('ASTRA_DB_ID')

# Fail early with a readable message: without these two values the database
# connection cannot be established, and the error would otherwise only
# surface later, after documents were downloaded and embedded.
_missing_astra_settings = [
    name
    for name, value in (
        ("ASTRA_DB_TOKEN", ASTRA_DB_TOKEN),
        ("ASTRA_DB_ID", ASTRA_DB_ID),
    )
    if not value
]
if _missing_astra_settings:
    raise RuntimeError(
        "Missing required environment variable(s): "
        + ", ".join(_missing_astra_settings)
    )

# Initialise cassio with the Astra credentials and keep the session handle
# so it can be handed to the vector store further down.
cassio.init(token=ASTRA_DB_TOKEN, database_id=ASTRA_DB_ID)
session = cassio.get_session()

# Keyspace used for the vector table. It can be overridden through the
# ASTRA_DB_KEYSPACE environment variable; the placeholder default is kept so
# existing setups that edit this line keep working.
keyspace = os.getenv("ASTRA_DB_KEYSPACE", "your_keyspace")


from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader

# Blog posts that make up the knowledge base for the vector store.
urls = [
    "https://lilianweng.github.io/posts/2023-06-23-agent/",
    "https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/",
    "https://lilianweng.github.io/posts/2023-10-25-adv-attack-llm/",
]

# One loader per URL; every load() call yields its own list of documents.
docs = [WebBaseLoader(page_url).load() for page_url in urls]

# Flatten the per-URL lists into a single list of documents.
docs_list = [
    document
    for documents_of_page in docs
    for document in documents_of_page
]

# Token-based splitter: chunks of 512 tokens with no overlap between them.
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    chunk_size=512, chunk_overlap=0
)

doc_splits = text_splitter.split_documents(docs_list)


from langchain_huggingface import HuggingFaceEmbeddings

# The .env file has already been loaded at the top of the script, so the
# environment can be read directly here.
HF_API_KEY = os.getenv('HF_API_KEY')
# NOTE: HF_API_KEY is read above but is not passed to the embedding object.
embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")

# Import from langchain_community (already a dependency of this file); the
# `langchain.vectorstores.*` path is a deprecated re-export.
from langchain_community.vectorstores import Cassandra

# Vector store backed by the "qa_table" table in the configured keyspace.
astra_vector_store = Cassandra(
    embedding=embeddings,
    table_name="qa_table",
    session=session,
    keyspace=keyspace,
)

from langchain.indexes.vectorstore import VectorStoreIndexWrapper

# Embed and insert every chunk, then report how many were submitted.
astra_vector_store.add_documents(doc_splits)
print(f"Inserted {len(doc_splits)} documents")
astra_vector_index = VectorStoreIndexWrapper(vectorstore=astra_vector_store)

# Retriever used by the `retrieve` graph node.
retriever = astra_vector_store.as_retriever()

from typing import Literal
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_groq import ChatGroq

# The .env file has already been loaded at the top of the script, so the
# key is read straight from the environment.
GROQ_API_KEY = os.getenv("GROQ_API_KEY")


# Structured-output schema for the router: a single field naming the
# datasource a question should be sent to.
# Documented with `#` comments rather than a class docstring because a
# docstring may become part of the schema handed to the LLM.
class RouteQuery(BaseModel):
    datasource: Literal['vectorstore', 'wikisearch'] = Field(
        ...,
        description="Given the user query, choose to route it to wikipedia or vectorstore."
    )


# NOTE(review): confirm this model identifier (and its capitalisation) is
# accepted by the Groq API and still available.
llm = ChatGroq(
    groq_api_key=GROQ_API_KEY,
    model_name="Llama-3.1-70B-Versatile",
)
# LLM wrapper whose output is parsed into a RouteQuery instance.
structured_llm_router = llm.with_structured_output(RouteQuery)

# System instructions describing what the vector store covers and when to
# fall back to Wikipedia.
system = """
You are an expert at routing a user question to a vectorstore or wikipedia.
The vectorstore contains documents related to agents, prompt engineering, and adversarial attacks.
Use the vectorstore for questions on these topics, otherwise perform a wikipedia search.
"""

route_prompt = ChatPromptTemplate.from_messages(
    [
        ('system', system),
        ('human', "{question}"),
    ]
)

# Prompt piped into the structured LLM; invoked with {"question": ...}.
ques_router = route_prompt | structured_llm_router


from langchain_community.utilities import WikipediaAPIWrapper
from langchain_community.tools import WikipediaQueryRun

# Wikipedia client limited to the single best hit, with the returned text
# capped at 256 characters.
api_wrapper = WikipediaAPIWrapper(
    top_k_results=1,
    doc_content_chars_max=256,
)
# Tool wrapper around the client; used by the `wiki_search` graph node.
wiki = WikipediaQueryRun(
    api_wrapper=api_wrapper,
)

from typing import List
from typing_extensions import TypedDict

class GraphState(TypedDict):
    """State dictionary passed between the nodes of the routing graph."""

    # The user's question being routed and answered.
    question: str
    # Generated answer text; no node visible in this file writes it.
    generation: str
    # Text snippets collected by the retrieve / wiki_search nodes.
    documents: List[str]

def retrieve(state):
    """Fetch documents relevant to the question from the vector store.

    Args:
        state: Current graph state; must contain the ``question`` key.

    Returns:
        A partial state update with ``documents`` (the page text of every
        retrieved document) and ``question`` (passed through unchanged).
    """
    print(f"Retrieving from vectorstore:")
    question = state['question']
    # Retrievers expose the Runnable interface: `invoke` is the query entry
    # point; they have no `retrieve` method.
    documents = retriever.invoke(question)
    return {
        "documents": [doc.page_content for doc in documents],
        # Key must match the GraphState field name ("question").
        "question": question,
    }

def wiki_search(state):
    """Answer the question with a Wikipedia lookup.

    Args:
        state: Current graph state; must contain the ``question`` key.

    Returns:
        A partial state update with ``documents`` (a one-element list holding
        the tool's result) and ``question`` (passed through unchanged).
    """
    print(f"Performing Wiki Search:")
    question = state['question']
    results = wiki.invoke({"query": question})
    return {
        "documents": [results],
        # Key must match the GraphState field name ("question").
        "question": question,
    }

def route_question(state):
    """Decide which graph node should handle the question.

    Args:
        state: Current graph state; must contain the ``question`` key.

    Returns:
        ``"wiki_search"`` or ``"retrieve"``, the name of the next node.

    Raises:
        ValueError: If the router produced no result or an unknown
            datasource, instead of silently returning ``None``.
    """
    print(f"Routing question")
    question = state["question"]
    source = ques_router.invoke({"question": question})
    # The router may yield nothing if the model did not produce structured
    # output; treat that the same as an unknown datasource.
    datasource = getattr(source, "datasource", None)
    if datasource == "wikisearch":
        print(f"Routing question to Wiki Search")
        return "wiki_search"
    if datasource == "vectorstore":
        print(f"Routing question to Vectorstore (RAG SYSTEM)")
        return "retrieve"
    raise ValueError(f"Router returned an unexpected datasource: {datasource!r}")


from langgraph.graph import END, StateGraph, START

# Build the routing graph over GraphState with two worker nodes.
workflow = StateGraph(GraphState)
workflow.add_node("wiki_search", wiki_search)
workflow.add_node("retrieve", retrieve)

# Entry point: route_question returns a label, and the mapping below turns
# that label into the name of the node to run next.
workflow.add_conditional_edges(
    START,
    route_question,
    {
        "wiki_search": "wiki_search",
        "retrieve": "retrieve"
    }
)

# Both worker nodes finish the run; no generation step is wired in here.
workflow.add_edge("retrieve", END)
workflow.add_edge("wiki_search", END)

# Compile the graph definition into the object exposed as `app`.
app = workflow.compile()
