In [2]:
! pip install langchain_community tiktoken langchain-openai langchainhub chromadb langchain langgraph

Collecting langchain_community
  Downloading langchain_community-0.0.36-py3-none-any.whl (2.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m23.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting tiktoken
  Downloading tiktoken-0.6.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m76.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting langchain-openai
  Downloading langchain_openai-0.1.6-py3-none-any.whl (34 kB)
Collecting langchainhub
  Downloading langchainhub-0.1.15-py3-none-any.whl (4.6 kB)
Collecting chromadb
  Downloading chromadb-0.5.0-py3-none-any.whl (526 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m526.8/526.8 kB[0m [31m25.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting langchain
  Downloading langchain-0.1.17-py3-none-any.whl (867 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m867.6/867.6 kB

In [3]:
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores import Chroma
from langchain_openai import OpenAIEmbeddings
from langchain.tools.retriever import create_retriever_tool
from langchain.docstore.document import Document
import os

import openai as OpenAI
from google.colab import userdata

# LangGraph Retrieval Agent

## Retriever

In [4]:
def read_text_files(directory_path):
    """Load every ``.txt`` file located directly inside ``directory_path``.

    The scan is not recursive. Entries whose name ends in ``.txt`` but that are
    not regular files (for example a sub-directory) are skipped instead of
    crashing on ``open``.

    Args:
        directory_path: Path of the directory to scan.

    Returns:
        dict: Maps each file name to that file's full text, decoded as UTF-8.
        Keys are inserted in sorted file-name order so that anything derived
        from the iteration order is reproducible from run to run.
    """
    texts = {}
    # os.listdir makes no ordering guarantee, so sort for deterministic results.
    for filename in sorted(os.listdir(directory_path)):
        file_path = os.path.join(directory_path, filename)
        if not filename.endswith('.txt') or not os.path.isfile(file_path):
            continue
        with open(file_path, 'r', encoding='utf-8') as file:
            texts[filename] = file.read()
    return texts

In [5]:
# Look the OpenAI API key up by name through Colab's `userdata` helper, so the
# key value itself never appears in the notebook source.
Openai_API_key=userdata.get('Openai_API_key')

In [19]:
# Read data
file_path = '/content/drive/MyDrive/vector_embedding'  # Replace with your actual text file path
text = read_text_files(file_path)

# Splitter built from a tiktoken encoder: 1000-unit chunks with a 50-unit overlap.
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    chunk_size=1000, chunk_overlap=50
)

doc_splits = []
doc_ids = []

for filename, content in text.items():
    # Split each file into chunks and remember which file every chunk came from.
    tokenized_texts = text_splitter.split_text(content)
    for chunk_index, chunk in enumerate(tokenized_texts):
        doc_splits.append(Document(page_content=chunk, metadata={'source_document': filename}))
        # One stable id per (file, position) pair.
        doc_ids.append(f"{filename}-{chunk_index}")

# Fail early with a readable message instead of an opaque vector-store error.
if not doc_splits:
    raise ValueError(f"No .txt content found under {file_path!r}; nothing to index.")

print(f"Number of chunks: {len(doc_splits)}")
# Only report the second chunk when there is one, to avoid an IndexError on tiny corpora.
if len(doc_splits) > 1:
    print(f"Source document of 2nd chunk: {doc_splits[1].metadata['source_document']}")

# Add to vectorDB.
# Explicit ids make re-running this cell overwrite the same entries rather than
# append a second copy of every chunk to the existing "rag-chroma" collection.
vectorstore = Chroma.from_documents(
    documents=doc_splits,
    ids=doc_ids,
    collection_name="rag-chroma",
    embedding=OpenAIEmbeddings(openai_api_key=Openai_API_key),
)
# Similarity search returning the 3 closest chunks per query.
retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 3})


Number of chunks: 7
Source document of 2nd chunk: diagnosis.txt


In [20]:
# Display the text of the first chunk as a quick sanity check of the split.
doc_splits[0].page_content

"Duchenne Muscular Dystrophy Diagnosis\nIn diagnosing any form of muscular dystrophy, a doctor usually begins by taking a patient and family history and performing a physical examination. Doctors may find pseudohypertrophy, lumbar spine deviation, gait abnormalities, and several grades of diminished muscle reflexes.\nMuch can be learned from these observations, including the pattern of weakness. A patient’s history and physical go a long way toward making a diagnosis, even before any complicated diagnostic tests are done.\nCardiomyopathy in patients with DMD may be associated with conduction abnormalities as well. A doctor may observe characteristic changes in an electrocardiogram. Also, structural changes in the heart, as valvular heart disease (specially affecting the mitral valve when it occurs) can be detected by echocardiography. Therefore, electrocardiogram, noninvasive imaging with echocardiography, or cardiac MRI are essential, along with consultation with a cardiologist. \n\nC

## Retrieval Tool

In [8]:
from langchain.tools.retriever import create_retriever_tool
from langgraph.prebuilt import ToolExecutor

# Wrap the vector-store retriever as a named tool the agent model can request.
tool = create_retriever_tool(
    retriever,
    name="retrieve_text",
    description="Search and return relevant information from documents about the user's questions to generate the accurate answers.",
)

# `tools` is what the agent node advertises to the model; the executor is what
# the retrieve node uses to actually run the requested tool.
tools = [tool]
tool_executor = ToolExecutor(tools)

### Agent state
##### We define a graph whose state object is passed to each node. The state is a list of messages, and each node in the graph appends to it.

In [10]:
import operator
from typing import Annotated, Sequence, TypedDict

from langchain_core.messages import BaseMessage


class AgentState(TypedDict):
    """State object handed to every graph node: a running list of messages."""

    # The operator.add annotation marks how updates are merged: the list a node
    # returns under "messages" is added onto the existing list, not substituted.
    messages: Annotated[Sequence[BaseMessage], operator.add]

In [43]:
import json
import operator
from typing import Annotated, Sequence, TypedDict

from langchain import hub
from langchain.output_parsers import PydanticOutputParser
from langchain.prompts import PromptTemplate
from langchain.tools.render import format_tool_to_openai_function
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_core.messages import BaseMessage, FunctionMessage
from langchain.output_parsers.openai_tools import PydanticToolsParser
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_openai import ChatOpenAI
from langgraph.prebuilt import ToolInvocation
from langchain_core.messages import HumanMessage
from langchain_core.output_parsers import StrOutputParser

### Edges


def should_retrieve(state):
    """
    Route the graph after the agent node has produced a message.

    Only the newest message is inspected. If it carries a "function_call" entry
    in its additional kwargs, the agent asked for the retriever and the graph
    should continue; otherwise the agent is finished.

    Args:
        state (messages): The current state

    Returns:
        str: "continue" to run retrieval, or "end" to stop
    """
    newest = state["messages"][-1]
    wants_tool = "function_call" in newest.additional_kwargs

    if wants_tool:
        print("---DECISION: RETRIEVE---")
        return "continue"

    # No function call requested, so there is nothing left to do.
    print("---DECISION: DO NOT RETRIEVE / DONE---")
    return "end"


def grade_documents(state):
    """
    Determines whether the retrieved documents are relevant to the question.

    A chat model is forced to call a ``grade`` tool whose arguments hold a
    yes/no score plus an explanation. The first message in the state is used as
    the question and the last message's content as the retrieved documents.

    Args:
        state (messages): The current state

    Returns:
        str: "yes" if the grader judged the documents relevant, otherwise "no"
    """

    print("---CHECK RELEVANCE---")

    # Data model. The forced tool_choice below refers to this schema by the
    # name "grade", so the class name and that string must stay in sync.
    class grade(BaseModel):
        """Binary score for relevance check on retrieved documents with explanation."""

        binary_score: str = Field(description="Documents are relevant to the question, 'yes' or 'no'")
        explanation: str = Field(description="Explanation of why the document is or is not relevant to the question")

    # LLM
    model = ChatOpenAI(temperature=0, model="gpt-3.5-turbo-0125", streaming=True, api_key=Openai_API_key)

    # Tool
    grade_tool_oai = convert_to_openai_tool(grade)

    # LLM with tool and enforce invocation
    llm_with_tool = model.bind(
        tools=[grade_tool_oai],
        tool_choice={"type": "function", "function": {"name": "grade"}},
    )

    # Parser
    parser_tool = PydanticToolsParser(tools=[grade])

    # Prompt
    prompt = PromptTemplate(
        template="""You are a grader assessing relevance of a retrieved document to a user question. \n
        Here is the retrieved document: \n\n {context} \n\n
        Here is the user question: {question} \n
        If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. \n
        Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question.\n
        Explain in detail your reasoning based on the presence of keyword(s) or semantic meanings or content logic related to the question.""",
        input_variables=["context", "question"],
    )

    # Chain
    chain = prompt | llm_with_tool | parser_tool

    messages = state["messages"]
    question = messages[0].content
    docs = messages[-1].content

    graded = chain.invoke(
        {"question": question,
         "context": docs}
    )

    # Use names distinct from the `grade` class so the schema is not shadowed.
    # Normalise the score: the model may answer "Yes" or " yes" rather than "yes".
    verdict = graded[0].binary_score.strip().lower()
    reason = graded[0].explanation

    if verdict == "yes":
        print("---DECISION: DOCS RELEVANT---")
        print(reason)
        return "yes"

    # Show the explanation here too, mirroring the relevant branch; the bare
    # score alone does not say why the documents were rejected.
    print("---DECISION: DOCS NOT RELEVANT---")
    print(reason)
    return "no"


### Nodes


def agent(state):
    """
    Run the agent model over the conversation so far.

    The model is given the retriever tool as an OpenAI function, so its reply
    either requests a retrieval (a function call) or answers directly.

    Args:
        state (messages): The current state

    Returns:
        dict: A state update whose "messages" list holds the agent's reply
    """
    print("---CALL AGENT---")
    history = state["messages"]

    llm = ChatOpenAI(temperature=0, streaming=True, model="gpt-3.5-turbo-0125", api_key=Openai_API_key)
    # Describe every available tool in OpenAI function format and attach them.
    function_specs = list(map(format_tool_to_openai_function, tools))
    llm_with_functions = llm.bind_functions(function_specs)

    reply = llm_with_functions.invoke(history)
    # Wrapped in a list because the update is added onto the existing messages.
    return {"messages": [reply]}

def retrieve(state):
    """
    Run the tool that the agent's latest message asked for.

    Args:
        state (messages): The current state

    Returns:
        dict: A state update whose "messages" list holds the tool result,
        wrapped as a FunctionMessage
    """
    print("---EXECUTE RETRIEVAL---")
    # The graph only routes here on "continue", i.e. when the newest message
    # contains a function call, so the key lookup below is expected to succeed.
    call_spec = state["messages"][-1].additional_kwargs["function_call"]

    # The arguments arrive as a JSON string and are decoded into the tool input.
    action = ToolInvocation(
        tool=call_spec["name"],
        tool_input=json.loads(call_spec["arguments"]),
    )

    tool_output = tool_executor.invoke(action)
    function_message = FunctionMessage(content=str(tool_output), name=action.tool)
    print(f"Function message form state retreive: {function_message}")
    print(f"Action for tool exicution: {action}")

    # Wrapped in a list because the update is added onto the existing messages.
    return {"messages": [function_message]}

def rewrite(state):
    """
    Ask the model to rephrase the original question into a better one.

    Args:
        state (messages): The current state

    Returns:
        dict: A state update whose "messages" list holds the model's re-phrased question
    """
    print("---TRANSFORM QUERY---")
    # The first message in the state is taken as the user's original question.
    question = state["messages"][0].content

    instruction = f""" \n
    Look at the input and try to reason about the underlying semantic intent / meaning. \n
    Here is the initial question:
    \n ------- \n
    {question}
    \n ------- \n
    Formulate an improved question: """

    # Query rewriter (plain chat model, no tools bound)
    llm = ChatOpenAI(temperature=0, model="gpt-3.5-turbo-0125", streaming=True, api_key=Openai_API_key)
    improved = llm.invoke([HumanMessage(content=instruction)])
    return {"messages": [improved]}

def generate(state):
    """
    Generate the final answer from the retrieved context.

    The first message in the state is used as the question and the content of
    the last message as the context passed to the model.

    Args:
        state (messages): The current state

    Returns:
         dict: A state update whose "messages" list holds the generated answer text
    """
    print("---GENERATE---")
    messages = state["messages"]
    question = messages[0].content
    docs = messages[-1].content

    # Prompt
    temp = """Answer the question based only on the following context:
{context}

Question: {question}
"""

    prompt = PromptTemplate(template=temp, input_variables=["context","question"])

    # LLM
    llm = ChatOpenAI(model_name="gpt-3.5-turbo-0125", temperature=0, streaming=True, api_key=Openai_API_key)

    # Chain: fill the prompt, call the model, then parse the reply with StrOutputParser.
    rag_chain = prompt | llm | StrOutputParser()

    # Run
    response = rag_chain.invoke({"context": docs, "question": question})
    return {"messages": [response]}

### Graph
###### Start with an agent (`call_model`). The agent decides whether to call a function. If it does, an action calls the tool (the retriever). The agent is then called again with the tool output added to the messages (state).

In [44]:
from langgraph.graph import END, StateGraph

# Create the graph; every node receives and updates an AgentState.
workflow = StateGraph(AgentState)

# Register the nodes the graph moves between.
for node_name, node_fn in (
    ("agent", agent),        # decides whether to call the retriever
    ("retrieve", retrieve),  # runs the retriever tool
    ("rewrite", rewrite),    # re-phrases the question
    ("generate", generate),  # writes the final answer
):
    workflow.add_node(node_name, node_fn)

In [45]:
# Execution starts at the agent node, which decides whether to retrieve.
workflow.set_entry_point("agent")

# After the agent: run the retriever if a function call was requested, else stop.
agent_routes = {
    "continue": "retrieve",
    "end": END,
}
workflow.add_conditional_edges("agent", should_retrieve, agent_routes)

# After retrieval: relevant documents go to answer generation, irrelevant ones
# send the question off to be re-phrased.
retrieve_routes = {
    "yes": "generate",
    "no": "rewrite",
}
workflow.add_conditional_edges("retrieve", grade_documents, retrieve_routes)

# Generating an answer ends the run; a re-phrased question goes back to the agent.
workflow.add_edge("generate", END)
workflow.add_edge("rewrite", "agent")

# Compile
app = workflow.compile()

In [46]:
import pprint
from langchain_core.messages import HumanMessage

# The question is the first (and only) message of the initial state.
user_question = HumanMessage(
    content="What were the key discoveries made in the 1980s that advanced our understanding of the causes of Duchenne Muscular Dystrophy (DMD)?"
)
inputs = {"messages": [user_question]}

# Stream the run and pretty-print each streamed output, keyed by node name.
for output in app.stream(inputs):
    for node_name in output:
        pprint.pprint(f"Output from node '{node_name}':")
        pprint.pprint("---")
        pprint.pprint(output[node_name], indent=2, width=80, depth=None)
    pprint.pprint("\n---\n")

---CALL AGENT---
---DECISION: RETRIEVE---
"Output from node 'agent':"
'---'
{ 'messages': [ AIMessage(content='', additional_kwargs={'function_call': {'arguments': '{"query":"key discoveries Duchenne Muscular Dystrophy 1980s"}', 'name': 'retrieve_text'}}, response_metadata={'finish_reason': 'function_call'}, id='run-2ad827cf-bd4c-4720-ac16-dd2ecff21523-0')]}
'\n---\n'
---EXECUTE RETRIEVAL---
Function message form state retreive: content='Duchenne Muscular Dystrophy (DMD)\n\nDuchenne Muscular Dystrophy (DMD)\n\nWhat is Duchenne muscular dystrophy?\n\nWhat is Duchenne muscular dystrophy?' name='retrieve_text'
Action for tool exicution: tool='retrieve_text' tool_input={'query': 'key discoveries Duchenne Muscular Dystrophy 1980s'}
---CHECK RELEVANCE---
---DECISION: DOCS RELEVANT---
The retrieved document contains the keyword 'Duchenne Muscular Dystrophy' which is directly related to the user question about key discoveries made in the 1980s regarding the causes of Duchenne Muscular Dystroph