In [11]:
import os
import sys
sys.path.append(os.path.join(os.getcwd(), os.pardir, os.pardir))
# from src.logger import Logger
# from src.models import QuestionModel
from src.prompts import *
from src.db.vector_store import load_config, load_vector_store

from langchain_core.tools import tool
from langchain_qdrant import Qdrant
from langchain_openai import OpenAIEmbeddings
from langchain.schema import (
    SystemMessage,    # Generic system message
    HumanMessage,     # HumanMessage is a message from the human
    AIMessage,        # AIMessage is a message from the AI
)
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import MessagesState, END, START, StateGraph
from langgraph.prebuilt import ToolNode

from typing import Annotated, Literal, TypedDict

In [2]:

# Load the project configuration (Qdrant connection settings, embedding model, data paths).
config = load_config()
# NOTE(review): this prints the whole config, including the Qdrant cluster URL,
# into the saved notebook output — consider clearing it before sharing.
print(config)
# Build the embedding client from the model name stored under the 'embeddings' section.
embeddings = OpenAIEmbeddings(model=config.get('embeddings').get('model'))

{'qdrant': {'collection_name': 'pymol-assistant', 'url': 'https://f01bac7d-144a-44f8-b643-17d9c6db29f0.us-east4-0.gcp.cloud.qdrant.io:6333', 'vector_config': {'size': 768, 'distance': 'cosine'}, 'vector_name': 'content', 'content_payload_key': 'content'}, 'embeddings': {'model': 'text-embedding-3-small'}, 'data': {'to_process': '../../prep/data/pymol-docs.csv', 'parquet_file': '../../prep/data/pymol-docs.parquet'}}


In [3]:
# Build the vector store handle from the loaded config and the embedding client.
# The helper prints the Qdrant URL and collection name it uses (see the cell output).
qdrant = load_vector_store(config=config, embeddings=embeddings)

QDRANT_URL:  https://f01bac7d-144a-44f8-b643-17d9c6db29f0.us-east4-0.gcp.cloud.qdrant.io:6333
QDRANT_COLLECTION_NAME:  pymol-assistant


In [5]:
# RAG toolkit
@tool
def query_pymol_docs(query: str, qdrant: Qdrant, top_k: int = 5) -> str:
    """
    Query the PyMOL documentation using Qdrant

    Args:
        query (str): The query to search for
        qdrant (Qdrant): The Qdrant instance
        top_k (int, optional): The number of results to return. Defaults to 5.

    Returns:
        str: One line per hit, formatted as
            "* [SIM=<score>] <page content> [<metadata>]", in the order
            returned by the similarity search. Empty string when nothing matches.
    """
    # NOTE(review): `qdrant` is part of the tool's call signature, so a
    # model-generated tool call would have to provide it, which an LLM cannot do.
    # The vector store probably needs to be bound outside the tool arguments —
    # confirm how this tool is meant to be invoked from the graph.

    # Get the top k results for the query, as (document, score) pairs
    results = qdrant.similarity_search_with_score(query=query, k=top_k)

    # `.3f` = three decimals; the previous `3f` only set a minimum field width
    # and therefore printed the default six decimals.
    return "".join(
        f"* [SIM={score:.3f}] {doc.page_content} [{doc.metadata}]\n"
        for doc, score in results
    )

In [8]:
# results = query_pymol_docs("how to load a pdb file", qdrant=qdrant, top_k=4)
# print(results)

In [None]:
# Tools the agent may call; currently only the PyMOL documentation search.
tools = [query_pymol_docs]
# Prebuilt node that wraps the tools; it is added to the graph as the "tools" node.
tool_node = ToolNode(tools=tools, name='tools')

In [None]:
# Define the function that determines whether to continue or not
def should_continue(state: MessagesState) -> Literal["tools", END]:
    """Pick the next hop after the agent node has produced a message.

    Looks only at the most recent message in ``state['messages']``:

    * if it carries tool calls, route to the ``"tools"`` node;
    * otherwise return ``END`` so the reply goes back to the user.

    The latest message is expected to expose a ``tool_calls`` attribute.
    """
    latest = state['messages'][-1]
    return "tools" if latest.tool_calls else END

# Define the function that calls the model
def call_model(state: MessagesState):
    """Run the chat model on the conversation so far.

    Returns a state update whose ``"messages"`` entry is a one-element list
    holding the model's reply; a list is used because it gets added to the
    existing message list.

    NOTE(review): relies on a module-level ``model`` that is not defined in
    the cells shown here — it must exist before the graph is invoked.
    """
    reply = model.invoke(state['messages'])
    return {"messages": [reply]}

## Construct the Nodes and the Edges

In [None]:
# Build the agent graph on top of the shared message state
workflow = StateGraph(MessagesState)

# Register both nodes of the agent <-> tools loop
workflow.add_node("agent", call_model)
workflow.add_node("tools", tool_node)

# Execution always begins at the `agent` node
workflow.add_edge(START, "agent")

# Whatever the tools produce is always handed back to `agent`
workflow.add_edge("tools", "agent")

# Once `agent` has run, `should_continue` decides where to go next:
# either the "tools" node or END
workflow.add_conditional_edges("agent", should_continue)

In [None]:
# Initialize the checkpointer that persists graph state between graph runs.
# MemorySaver comes from langgraph.checkpoint.memory, so the state is presumably
# held in process memory only and lost on kernel restart — TODO confirm.
checkpointer = MemorySaver()

In [None]:
# Compile the graph into a runnable app; the checkpointer argument is optional
# and is passed here so that state is saved between runs.
# NOTE(review): with a checkpointer attached, each invocation presumably needs a
# `thread_id` in its run config — confirm against the LangGraph documentation.
app = workflow.compile(checkpointer=checkpointer)