# Retrieval-Augmented Generation (RAG)

In [None]:
%pip install --quiet --upgrade langchain-text-splitters langchain-community langgraph
%pip install --quiet langchain-ollama langchain-pinecone
%pip install --quiet bs4

In [2]:
import os
from dotenv import load_dotenv

load_dotenv()

# Credentials come from the process environment, which load_dotenv() has just
# populated from a local .env file. Any key that is not set is None.

# MariTalk: API key plus the chat model name used by the LLM cell.
MARITALK_API_KEY = os.environ.get('MARITALK_API_KEY')
MARITALK_LLM_MODEL = "sabia-3"

# LangSmith: these names are not referenced by the other cells shown here.
# NOTE(review): presumably LangSmith tracing reads them from the environment
# directly; they are kept for visibility.
LANGSMITH_API_KEY = os.environ.get('LANGSMITH_API_KEY')
LANGSMITH_TRACING = os.environ.get('LANGSMITH_TRACING')

# Pinecone: API key used to open the vector index.
PINECONE_API_KEY = os.environ.get('PINECONE_API_KEY')

In [4]:
from langchain_community.chat_models import ChatMaritalk

# Chat model served by the MariTalk API. Each reply is limited by
# max_tokens=1000.
llm = ChatMaritalk(
    max_tokens=1000,
    api_key=MARITALK_API_KEY,
    model=MARITALK_LLM_MODEL,
)

In [None]:
from langchain_ollama import OllamaEmbeddings
from langchain_pinecone import PineconeVectorStore
from pinecone import Pinecone

# Embeddings are computed by a local Ollama server using "nomic-embed-text".
# (Alternative used before: model "llama3" together with the index
# "ollama-llama3-capiara-algorithm-mentor"; model and index go in pairs.)
embeddings = OllamaEmbeddings(model="nomic-embed-text")

# NOTE(review): presumably this Pinecone index must already exist and have
# been created with the same vector dimension as the embedding model above.
index_name = "nomic-embed-text-capiara-algorithm-mentor"

# Connect to Pinecone and open a handle to the index.
pinecone = Pinecone(api_key=PINECONE_API_KEY)
index = pinecone.Index(index_name)

# LangChain vector store: embeds text with `embeddings`, stores/queries `index`.
vector_store = PineconeVectorStore(index=index, embedding=embeddings)

In [None]:
import bs4
from langchain_community.document_loaders import WebBaseLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter

# Load the blog post, keeping only the post title, header and body.
# An argument-less bs4.SoupStrainer() applies no filtering, so navigation,
# footer and other page chrome would otherwise end up in the indexed text.
loader = WebBaseLoader(
    web_paths=("https://lilianweng.github.io/posts/2023-06-23-agent/",),
    bs_kwargs=dict(
        parse_only=bs4.SoupStrainer(
            class_=("post-content", "post-title", "post-header")
        )
    ),
)
docs = loader.load()

# One URL was requested, so exactly one Document is expected.
assert len(docs) == 1, f"expected 1 document, got {len(docs)}"
# Total number of characters
print(f"Total characters: {len(docs[0].page_content)}")

# Chunk the document into pieces of up to 1000 characters. Consecutive chunks
# overlap by 200 characters so text near a boundary is not lost entirely.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
all_splits = text_splitter.split_documents(docs)

print(f"Split into {len(all_splits)} sub-documents")

In [None]:
# Important: embedding every chunk is RAM-intensive, and how much memory it
# needs depends on the embedding model; consider a smaller model if needed.
# Remember to run Ollama in the background and make sure the model is downloaded.
import uuid

# Index chunks under deterministic IDs derived from the source URL and the
# chunk's position. Re-running this cell then overwrites the same vectors
# instead of adding duplicate copies under fresh random IDs.
# NOTE: vectors written by earlier runs under random IDs are not removed.
chunk_ids = [
    str(uuid.uuid5(uuid.NAMESPACE_URL, f"{doc.metadata.get('source', '')}#{i}"))
    for i, doc in enumerate(all_splits)
]
document_ids = vector_store.add_documents(documents=all_splits, ids=chunk_ids)
print(f"Document IDs: {document_ids}")

# Using LangGraph to Tie Together Retrieval and Generation

In [10]:
from typing import List, TypedDict
from langgraph.graph import START, StateGraph
from langchain_core.documents import Document
from langchain import hub

# Define prompt for question-answering.
# Pulled from the LangChain Hub when this cell runs, so it needs network access.
# The generate step fills its "question" and "context" variables.
prompt = hub.pull("rlm/rag-prompt")


# Application state shared by the graph steps.
class State(TypedDict):
    """State passed from the retrieve step to the generate step."""

    question: str  # the user's question (graph input)
    context: list[Document]  # chunks written by retrieve
    answer: str  # model reply written by generate


# Application step 1: look up chunks similar to the question.
def retrieve(state: State):
    """Search the vector store for the question and store hits as context.

    Uses the vector store's default number of results.
    """
    question = state["question"]
    matches = vector_store.similarity_search(question)
    return {"context": matches}


def generate(state: State):
    """Answer the question with the LLM, using the retrieved chunks as context.

    The chunks' text is joined with blank lines and inserted into the hub
    prompt; only the text content of the model's reply is stored.
    """
    context_text = "\n\n".join([chunk.page_content for chunk in state["context"]])
    prompt_value = prompt.invoke(
        {"question": state["question"], "context": context_text}
    )
    reply = llm.invoke(prompt_value)
    return {"answer": reply.content}


# Wire the steps as a linear pipeline retrieve -> generate, starting at
# "retrieve", then compile it into a runnable graph.
graph_builder = StateGraph(State)
graph_builder.add_sequence([retrieve, generate])
graph_builder.add_edge(START, "retrieve")
graph = graph_builder.compile()

In [None]:
# Run the retrieve -> generate graph on a sample question and print the answer.
question = "What is Task Decomposition?"
response = graph.invoke({"question": question})
print(response["answer"])

# Retrieval as a Tool Call

The `retrieve` function below is defined as a LangChain tool. It fetches the chunks in the vector store most relevant to a query and formats them as text. Because it uses `response_format="content_and_artifact"`, it returns two things: the serialized text, which the model sees, and the raw `Document` objects, which are kept as an artifact. Binding the tool to the language model lets the model decide, on each turn, whether to call it. If it does, the retrieved context is passed to a final generation step that produces a concise, context-aware answer; if not, the model replies directly. This requires a chat model that supports tool calling.

In [None]:
# Running this may fix issues with the plugin
%pip uninstall pinecone-plugin-inference

In [None]:
from langgraph.graph import MessagesState, StateGraph

# Define state for application with messages.
# With MessagesState, nodes return new messages, which are appended to the
# conversation instead of replacing it.
# Nodes and edges are added in a later cell. Re-run this cell before re-running
# that one, because LangGraph rejects adding the same node name twice.
graph_builder = StateGraph(MessagesState)

In [15]:
from langchain_core.tools import tool

# Tool the model can call to fetch context from the vector store.
@tool(response_format="content_and_artifact")
def retrieve(query: str):
    """Retrieve information related to a query."""
    # The docstring above is the tool description shown to the model, and
    # `query` is the argument name in the tool schema; keep both stable.
    # Returns (content, artifact): the text the model sees, plus the two raw
    # Documents attached to the resulting ToolMessage as its artifact.
    hits = vector_store.similarity_search(query, k=2)
    formatted = [
        f"Source: {hit.metadata}\nContent: {hit.page_content}" for hit in hits
    ]
    return "\n\n".join(formatted), hits

In [16]:
from langchain_core.messages import SystemMessage
from langgraph.prebuilt import ToolNode


# Node: let the model either answer directly or request the retrieve tool.
def query_or_respond(state: MessagesState):
    """Ask the tool-bound model for the next message (answer or tool call).

    The returned message is appended to the conversation by MessagesState
    rather than overwriting earlier messages.
    """
    # NOTE(review): relies on the configured chat model supporting bind_tools /
    # tool calling; not every model does, so confirm for ChatMaritalk.
    model_with_retriever = llm.bind_tools([retrieve])
    ai_message = model_with_retriever.invoke(state["messages"])
    return {"messages": [ai_message]}


# Node that executes the model's `retrieve` tool calls. It is added to the
# graph under the name "tools", which the graph's routing refers to.
tools = ToolNode(tools=[retrieve], name="tools")


# Node: answer the question using the ToolMessages from the latest tool run.
def generate(state: MessagesState):
    """Generate the final answer from the most recent tool results.

    Collects the trailing run of ToolMessages, which is the output of the tool
    call that led here, and puts their content into a system prompt. It then
    asks the LLM to answer, passing the human/system turns and any AI turns
    that are not tool calls.

    Returns:
        {"messages": [response]}, appended to the conversation by MessagesState.
    """
    messages = state["messages"]

    # The tool results are the consecutive "tool" messages at the end of the
    # history: walk backwards until the first non-tool message, then restore
    # chronological order.
    recent_tool_messages = []
    for message in reversed(messages):
        if message.type != "tool":
            break
        recent_tool_messages.append(message)
    tool_messages = recent_tool_messages[::-1]

    # Build the system prompt line by line, so it does not inherit the source
    # code's indentation or stray blank lines.
    docs_content = "\n\n".join(message.content for message in tool_messages)
    system_message_content = "\n".join(
        [
            "You are an assistant for question-answering tasks.",
            "Use the following pieces of retrieved context to answer the question.",
            "If you don't know the answer, say that you don't know. "
            "Use three sentences maximum and keep the answer concise.",
            "",
            docs_content,
        ]
    )

    # Keep human/system turns and AI answers. Drop the AI messages that only
    # request tool calls, and the ToolMessages themselves; their content is
    # already in the system prompt.
    conversation_messages = [
        message
        for message in messages
        if message.type in ("human", "system")
        or (message.type == "ai" and not message.tool_calls)
    ]
    # Named so it does not shadow the module-level `prompt` pulled from the hub.
    prompt_messages = [SystemMessage(system_message_content)] + conversation_messages

    # Invoke the LLM with the assembled messages.
    response = llm.invoke(prompt_messages)
    return {"messages": [response]}

In [17]:
from langgraph.graph import END
from langgraph.prebuilt import tools_condition

# Register the nodes under explicit names. These are the same names add_node
# would derive from the function names and from the ToolNode's name.
graph_builder.add_node("query_or_respond", query_or_respond)
graph_builder.add_node("tools", tools)
graph_builder.add_node("generate", generate)

# Two possible flows:
#   query_or_respond -> tools -> generate -> END   (model requested a tool call)
#   query_or_respond -> END                        (model answered directly)
graph_builder.set_entry_point("query_or_respond")
graph_builder.add_conditional_edges(
    "query_or_respond",  # source node
    tools_condition,  # "tools" if the last AI message has tool calls, else END
    {END: END, "tools": "tools"},  # routing result -> destination node
)
graph_builder.add_edge("tools", "generate")
graph_builder.add_edge("generate", END)

# Compile the wired builder into a runnable graph.
graph = graph_builder.compile()

In [None]:
from IPython.display import Image, display

# Render the compiled graph as a Mermaid diagram and show the PNG inline.
# NOTE(review): draw_mermaid_png may use an external rendering service by
# default; confirm before running offline.
mermaid_png = graph.get_graph().draw_mermaid_png()
display(Image(mermaid_png))

In [None]:
# Tool calling is not universal, but is supported by many popular LLM providers.
# A list of chat models that support tool calling is available here:
# https://python.langchain.com/docs/integrations/chat/

# The user turn to send. An empty message gives the model nothing to answer
# or retrieve for, so fail fast instead of streaming an empty turn.
input_message = "What is Task Decomposition?"
if not input_message.strip():
    raise ValueError("input_message must be a non-empty question")

# stream_mode="values" yields the full state after each step; print the
# newest message each time.
for step in graph.stream(
    {"messages": [{"role": "user", "content": input_message}]},
    stream_mode="values",
):
    step["messages"][-1].pretty_print()