# Docs to Read

https://python.langchain.com/api_reference/openai/chat_models/langchain_openai.chat_models.base.ChatOpenAI.html

https://python.langchain.com/v0.1/docs/use_cases/tool_use/prompting/

https://python.langchain.com/docs/how_to/tools_model_specific/

https://python.langchain.com/docs/tutorials/qa_chat_history/

In [None]:
# Install the LangChain / LangGraph, Ollama, Pinecone and PDF-parsing packages used below.
# NOTE(review): `dotenv` (package python-dotenv) is imported in the next cell but is not
# installed here — presumably already present in the environment; confirm.
%pip install --quiet --upgrade langchain-text-splitters langchain-community langgraph
%pip install --quiet langchain-ollama langchain-pinecone
%pip install --quiet pypdf

In [2]:
import os
from dotenv import load_dotenv

# Load variables from a local .env file into the process environment.
load_dotenv()

# MariTalk: chat model identifier and API credential.
MARITALK_LLM_MODEL = "sabia-3"
MARITALK_API_KEY = os.environ.get("MARITALK_API_KEY")

# LangSmith: API key and tracing switch. They are not referenced elsewhere in the
# visible cells; presumably LangChain reads them from the environment itself.
LANGSMITH_API_KEY = os.environ.get("LANGSMITH_API_KEY")
LANGSMITH_TRACING = os.environ.get("LANGSMITH_TRACING")

# Pinecone: API key for the vector index client.
PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY")

In [None]:
from langchain_ollama import OllamaEmbeddings
from langchain_pinecone import PineconeVectorStore
from pinecone import Pinecone

# Embedding model served by Ollama (must be running and have the model pulled).
embeddings = OllamaEmbeddings(model="nomic-embed-text")

# Name of the Pinecone index that stores the chunk vectors.
index_name = "nomic-embed-text-capiara-algorithm-mentor"

# Pinecone client and a handle to the index above.
pinecone = Pinecone(api_key=PINECONE_API_KEY)
index = pinecone.Index(index_name)

# LangChain vector-store wrapper: embeds documents/queries with `embeddings`
# and reads/writes vectors through `index`.
vector_store = PineconeVectorStore(index=index, embedding=embeddings)

In [None]:
# Import from the split-out packages installed at the top of the notebook; the
# `langchain.document_loaders` / `langchain.text_splitter` paths are deprecated
# re-export shims that also require the `langchain` package itself.
from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter

# PDF to ingest (path is relative to the notebook's working directory).
document_to_load = "../docs/normas_atividade_extensao.pdf"

loader = PyPDFLoader(document_to_load)
docs = loader.load()

# Split into ~1000-character chunks; the 200-character overlap keeps some
# context shared between neighbouring chunks.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
all_splits = text_splitter.split_documents(docs)

print(f"Split into {len(all_splits)} sub-documents")

In [None]:
# Important: embedding every chunk is RAM intensive, and how much memory it needs
# depends on the embedding model; consider a smaller model if this is a problem.
# Remember to run Ollama in the background and make sure the model is downloaded.
import uuid

# Index chunks under deterministic IDs derived from each chunk's source, its
# position in `all_splits` and its text. Without explicit ids the store assigns
# random ones, so every re-run of this cell added a full duplicate of the
# document to the index (and retrieval then returned repeated passages).
# With stable ids a re-run writes to the same vector IDs instead.
document_ids = vector_store.add_documents(
    documents=all_splits,
    ids=[
        str(
            uuid.uuid5(
                uuid.NAMESPACE_URL,
                f"{doc.metadata.get('source', '')}#{i}:{doc.page_content}",
            )
        )
        for i, doc in enumerate(all_splits)
    ],
)

print(f"Document IDs: {document_ids}")

In [6]:
from langchain_community.chat_models import ChatMaritalk

# OpenAI-style function schema describing the single `retrieve` tool.
tools_schema = [
    dict(
        type="function",
        function=dict(
            name="retrieve",
            description="Retrieve relevant documents from the knowledge base given a user query.",
            parameters=dict(
                type="object",
                properties=dict(
                    query=dict(
                        type="string",
                        description="The user's question or query to search in the vector database.",
                    ),
                ),
                required=["query"],
                additionalProperties=False,
            ),
        ),
    ),
]

# MariTalk chat model shared by the routing and answer-generation nodes.
# NOTE(review): tool calls are actually detected by parsing the JSON reply that the
# routing prompt asks for (parse_tool_call); it is unclear whether ChatMaritalk
# forwards this `tools` argument to the API at all — confirm.
llm = ChatMaritalk(
    model=MARITALK_LLM_MODEL,
    api_key=MARITALK_API_KEY,
    temperature=0.2,
    max_tokens=10000,
    tools=tools_schema,
)

In [7]:
import json
import uuid
from langchain_core.tools import tool
from typing_extensions import TypedDict, List

from typing import Annotated

from langgraph.graph.message import add_messages


# Define the state for the graph
class MessagesState(TypedDict):
    """Graph state shared by every node.

    ``messages`` carries LangGraph's ``add_messages`` reducer, so each node's
    ``{"messages": [...]}`` update is appended to (or merged by message id into)
    the running conversation. Without a reducer LangGraph overwrites the key on
    every update, which left ``generate`` with only the ToolMessage and no user
    question to answer.
    """

    messages: Annotated[List, add_messages]


# LangChain tool: returns (text for the ToolMessage, raw Documents as the artifact).
@tool(response_format="content_and_artifact")
def retrieve(query: str):
    """Retrieve relevant documents from the vector store based on a user question."""
    # The docstring above is exposed as the tool description; keep it verbatim.
    print(f"🔍 [TOOL] Calling 'retrieve' with query: {query}")

    hits = vector_store.similarity_search(query, k=3)
    print(f"📚 [TOOL] Documents found: {len(hits)}")

    sections = []
    for position, hit in enumerate(hits, start=1):
        print(f"    📘 Doc {position}: {hit.page_content[:80]}...")
        sections.append(f"Source: {hit.metadata}\nContent: {hit.page_content}")

    return "\n\n".join(sections), hits


# Parse the tool call from the LLM response
def parse_tool_call(response):
    """Extract a LangChain-style tool call from the model's JSON reply.

    The routing prompt asks the model to answer with
    ``{"tool_call": {"function": "retrieve", "arguments": {...}}}`` when it
    wants a tool. This parses that reply (optionally wrapped in a Markdown
    code fence), renames ``function`` -> ``name`` and ``arguments`` -> ``args``,
    decodes ``args`` when the model sent it as a JSON string, and adds a unique
    ``id`` when missing so the ToolNode can match its ToolMessage to the call.

    Args:
        response: A message-like object with a ``content`` attribute.

    Returns:
        The tool-call dict (``name``, ``args``, ``id`` plus any extra keys), or
        ``None`` when the reply is plain text or not a well-formed tool call.
    """
    import re

    content = getattr(response, "content", None)
    # Non-string content (e.g. a list of content blocks) cannot hold the JSON protocol.
    if not isinstance(content, str):
        return None

    text = content.strip()
    # Models often wrap JSON in a ```json ... ``` fence; unwrap it first.
    fence = re.fullmatch(r"```(?:json)?\s*(.*?)\s*```", text, flags=re.DOTALL | re.IGNORECASE)
    if fence:
        text = fence.group(1)

    # A prose answer is the normal "no tool" case: don't try (and noisily fail)
    # to parse it as JSON.
    if not text.startswith("{"):
        return None

    try:
        parsed = json.loads(text)
    except json.JSONDecodeError as e:
        print("⛔️ [ERROR] Error converting to JSON:", e)
        return None

    call = parsed.get("tool_call")
    if not isinstance(call, dict):
        return None

    # Convert 'function' to 'name' and 'arguments' to 'args' (LangChain ToolCall keys)
    if "function" in call:
        call["name"] = call.pop("function")
    if "arguments" in call:
        call["args"] = call.pop("arguments")

    # OpenAI-style replies encode the arguments object as a JSON string.
    args = call.get("args", {})
    if isinstance(args, str):
        try:
            args = json.loads(args)
        except json.JSONDecodeError as e:
            print("⛔️ [ERROR] Error converting to JSON:", e)
            return None
    # A call without a tool name or with non-object args cannot be executed.
    if not isinstance(call.get("name"), str) or not isinstance(args, dict):
        return None
    call["args"] = args

    # Inject a unique id if it doesn't exist
    if "id" not in call:
        call["id"] = str(uuid.uuid4())
    return call

In [8]:
from langchain_core.messages import HumanMessage, SystemMessage


# Decide whether to call the tool or respond directly
def query_or_respond(state: MessagesState):
    """Ask the LLM whether the latest user turn needs document retrieval.

    The model is prompted to reply with a bare JSON object when it wants the
    `retrieve` tool. `parse_tool_call` turns that JSON into a tool call that is
    attached to the AI message, so `tools_condition` can route to the ToolNode.
    Otherwise `tool_calls` is set to an empty list.

    Args:
        state: Current graph state; `state["messages"]` is the conversation so far.

    Returns:
        A state update `{"messages": [response]}` with the model's AI message.
    """
    # TODO: improve the routing prompt (the domain description is still vague).
    # Built from adjacent literals so the prompt sent to the model carries no
    # source-code indentation or stray escaped blank lines.
    system_instructions = (
        "You are a helpful assistant with access to a specialized document database "
        "containing information related to university files and related subjects.\n"
        "ONLY call the tool 'retrieve' (by returning a JSON object as specified below) "
        "if the user's query is clearly about this domain.\n"
        "If the user's query is about general topics or subjects not related to this "
        "domain, answer directly without calling any tool.\n\n"
        "When calling the tool, respond ONLY with a JSON object in the following format "
        "and NOTHING else:\n\n"
        '{"tool_call": {"function": "retrieve", "arguments": {"query": "<your query>"}}}\n\n'
        "If no external specialized information is required, answer directly."
    )
    # Prompt the LLM with system instructions and the conversation history
    new_messages = [SystemMessage(system_instructions)] + state["messages"]

    print("🤖 [LLM] Generating response [Validating the need for tool call]")
    response = llm.invoke(new_messages)

    print("📥 [LLM] Response from model:")
    response.pretty_print()

    # Best-effort debug dump; ensure_ascii=False keeps accented text readable.
    try:
        print(
            "\n\n🛃 [DEBUG] Raw data from model response:\n",
            json.dumps(response.model_dump(), indent=4, ensure_ascii=False),
        )
    except Exception as e:
        print("⛔️ [ERROR] Error printing JSON response:", e)

    # Attach the parsed tool call (if any) so tools_condition can route on it.
    tool_call = parse_tool_call(response)
    if tool_call:
        print("🔧 [LLM] Detected tool call:")
        print("   - Tool:", tool_call.get("name"))
        print("   - Args:", tool_call.get("args"))
        response.tool_calls = [tool_call]
    else:
        print("🚧 [LLM] No tool call detected.")
        response.tool_calls = []
    return {"messages": [response]}

# Generate response using the tool's content
def generate(state: MessagesState):
    """Answer the user's question grounded in the latest retrieval output.

    Takes the trailing run of ToolMessages (the most recent `tools` step) and
    inlines their content into a RAG system prompt. The LLM then sees that
    prompt plus the human/system turns and any AI turns without tool calls.

    Returns:
        A state update `{"messages": [response]}` with the final AI message.
    """
    print("🛠️ [GENERATE] Generating final response using tools:")
    history = state["messages"]

    # Walk back from the end while messages are tool outputs; everything from
    # `cut` onward is the latest batch of ToolMessages, in original order.
    cut = len(history)
    while cut > 0 and history[cut - 1].type == "tool":
        cut -= 1
    tool_messages = history[cut:]

    if tool_messages:
        print(f"📦 [GENERATE] ToolMessages: {len(tool_messages)}")
    else:
        print("❌ [ERROR] ToolMessage not found. Cannot generate final response.")

    # Fixed RAG instructions followed by the retrieved context.
    docs_content = "\n\n".join(t.content for t in tool_messages)
    system_message_content = (
        "You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. "
        "If you don't know the answer, say that you don't know. Use three sentences maximum and keep the answer concise.\n\n"
    ) + docs_content

    # Keep only human/system turns and AI turns that did not request a tool.
    conversation_messages = []
    for m in history:
        if m.type in ("human", "system"):
            conversation_messages.append(m)
        elif m.type == "ai" and not getattr(m, "tool_calls", []):
            conversation_messages.append(m)

    response = llm.invoke([SystemMessage(system_message_content), *conversation_messages])

    print("✅ [FINAL RESPONSE] RAG as tool response:")
    response.pretty_print()
    return {"messages": [response]}

In [9]:
from langgraph.prebuilt import ToolNode, tools_condition
from langgraph.graph import StateGraph, END

# Graph shape:
#   query_or_respond --(tool call)--> tools --> generate --> END
#   query_or_respond --(no tool call)--> END
builder = StateGraph(MessagesState)
tool_node = ToolNode([retrieve])

builder.add_node("query_or_respond", query_or_respond)
builder.add_node("tools", tool_node)
builder.add_node("generate", generate)

builder.set_entry_point("query_or_respond")

# tools_condition inspects the last message: "tools" when it carries tool_calls,
# END otherwise; the mapping translates those results into target nodes.
builder.add_conditional_edges(
    "query_or_respond",
    tools_condition,
    {"tools": "tools", END: END},
)
builder.add_edge("tools", "generate")
builder.add_edge("generate", END)

graph = builder.compile()

In [None]:
from IPython.display import Image, display

# Visualize the graph: render the compiled graph as a Mermaid diagram PNG and
# show it inline in the notebook.
# NOTE(review): draw_mermaid_png() may render through a remote Mermaid service by
# default (requires network access) — confirm against the installed langchain_core.
display(Image(graph.get_graph().draw_mermaid_png()))

In [None]:
if __name__ == "__main__":
    # Sample user question (Portuguese) used to exercise the graph end to end.
    user_input = "O que é langchain"
    initial_state = dict(messages=[HumanMessage(content=user_input)])

    print("🚀 Initializing graph")

    # stream_mode="values" yields the whole state after each step; print only
    # the newest message of that snapshot.
    for step in graph.stream(initial_state, stream_mode="values"):
        print("\n🦜 [STEP] New graph step")
        step["messages"][-1].pretty_print()