# RAG with MLflow and LangChain

This notebook follows the MLflow tutorial [MLflow LangChain Retriever](https://mlflow.org/docs/latest/llms/langchain/notebooks/langchain-retriever.html).

In [1]:
import os
import mlflow
import requests

from bs4 import BeautifulSoup
from pathlib import Path

from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.messages import AIMessage, HumanMessage
from langchain.chains.history_aware_retriever import create_history_aware_retriever
from langchain.chains.retrieval import create_retrieval_chain
from langchain.chains.retrieval_qa.base import RetrievalQA
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.document_loaders import TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings

from dotenv import load_dotenv
load_dotenv()

True

In [2]:
# Fail fast when the key is missing. No API key is passed explicitly to the
# Google model constructors in this notebook, so they presumably read it from
# the environment — TODO confirm.
assert ("GOOGLE_API_KEY" in os.environ), "Please set your GOOGLE_API_KEY environment variable."

# Setup

In [3]:
# Setup Models
# Chat model shared by the RetrievalQA chain and the chat RAG chain below.
google_llm = ChatGoogleGenerativeAI(
    model="gemini-1.5-flash-002",
    temperature=0,
)

# Embedding model used when the FAISS index is built.
# NOTE(review): task_type is pinned to "retrieval_document"; retrievers created
# from the index presumably embed queries with this same object — confirm that
# embedding queries with the document task type is intended.
google_llm_embeddings = GoogleGenerativeAIEmbeddings(
    model='models/text-embedding-004',
    task_type="retrieval_document",
)

In [4]:
# Setup MLflow
## Tracking URI
# Assumes an MLflow tracking server is already listening on localhost:5000.
mlflow.set_tracking_uri("http://localhost:5000")

## Create a new experiment
# Selects the experiment by name; later runs in this notebook are recorded under it.
experiment = mlflow.set_experiment("LangChain RAG Tracing")

## Enable LangChain autologging
# Must run before any chain is invoked so those invocations are captured.
mlflow.langchain.autolog(
    silent=True,
    log_traces=True,
    log_models=True,
    log_model_signatures=True,
    log_input_examples=True,
    )


# FAISS Database Creation

In [5]:
# Scraping Federal Documents
def fetch_federal_document(url, div_class, timeout=30):
    """
    Scrapes the text of one <div> from a web page.

    Used in this notebook to pull transcripts of federal documents from archives.gov.

    Args:
    url (str): URL of the webpage to scrape.
    div_class (str): CSS class of the <div> that holds the transcript.
    timeout (float, optional): Seconds to wait for the server before giving up. Default is 30.

    Returns:
    str: The text of the first matching <div>, one text fragment per line.
    On failure (non-200 status, or no matching <div>) a human-readable message
    is returned instead, so callers that persist the result may end up storing
    that message.

    Raises:
    requests.exceptions.RequestException: on connection errors or when the timeout elapses.
    """
    # Without a timeout, requests waits indefinitely on an unresponsive server.
    response = requests.get(url, timeout=timeout)
    if response.status_code != 200:
        return f"Failed to retrieve the webpage. Status code: {response.status_code}"

    # Parse the page and locate the transcript container by its CSS class.
    soup = BeautifulSoup(response.text, "html.parser")
    transcript_section = soup.find("div", class_=div_class)
    if not transcript_section:
        return "Transcript section not found."

    return transcript_section.get_text(separator="\n", strip=True)

In [6]:
# Document Fetching and Saving
def fetch_and_save_documents(url_list, doc_path, div_class="col-sm-9"):
    """
    Fetches documents from given URLs and appends them to a single text file.

    All documents are fetched before the file is opened, so a fetch that raises
    part-way through the URL list does not leave a partially written file behind.

    Args:
        url_list (list): List of URLs to fetch documents from.
        doc_path (str): Path to the file where documents will be saved.
        div_class (str, optional): CSS class of the <div> holding each document's
            text. Default is "col-sm-9".
    """
    documents = [fetch_federal_document(url, div_class) for url in url_list]
    if not documents:
        # Nothing to write: do not create an empty file.
        return

    with open(doc_path, "a") as file:
        for document in documents:
            file.write(document)
            # Separate documents so the last line of one does not run into
            # the first line of the next.
            file.write("\n")

In [7]:
# FAISS Database Creation
def create_faiss_database(document_path, database_save_directory, chunk_size=500, chunk_overlap=10):
    """
    Build a FAISS vector store from a text file and persist it to disk.

    Args:
        document_path (str): Text file whose contents will be indexed.
        database_save_directory (str): Directory the FAISS index is written to.
        chunk_size (int, optional): Chunk size handed to the text splitter. Default is 500.
        chunk_overlap (int, optional): Chunk overlap handed to the text splitter. Default is 10.

    Returns:
        The FAISS vector store that was created and saved.
    """
    # Read the file, then cut it into overlapping chunks.
    loaded_docs = TextLoader(document_path).load()
    splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    chunks = splitter.split_documents(loaded_docs)

    # Embed the chunks with the notebook-level embedding model and index them.
    vector_store = FAISS.from_documents(chunks, google_llm_embeddings)

    # Persist the index so it can be reloaded from disk later.
    vector_store.save_local(database_save_directory)
    return vector_store

In [8]:
# Local data directory that holds the scraped text and the FAISS index
directory = Path('..', 'data')
# Create it up front: opening the document file in append mode does not
# create missing parent directories.
directory.mkdir(parents=True, exist_ok=True)

# Document Paths and FAISS Index Directory
doc_path = os.path.join(directory, "docs.txt")
persist_dir = os.path.join(directory, "faiss_index_langchain_simple_rag")

# URLs of the Federal documents to scrape
url_listings = [
    "https://www.archives.gov/milestone-documents/act-establishing-yellowstone-national-park#transcript",
    "https://www.archives.gov/milestone-documents/sherman-anti-trust-act#transcript",
]

# Fetch and save documents from the URLs (skipped when the file already exists)
if not os.path.exists(doc_path):
    print("Fetching and saving documents...")
    fetch_and_save_documents(url_listings, doc_path)

# Create a FAISS database from the saved documents
faiss_index = create_faiss_database(doc_path, persist_dir)

# Run Demo Examples

## RetrievalQA Chain

### Setup the Chain

In [9]:
# Setup RetrievalQA
# Pairs the Gemini chat model with a retriever over the FAISS index built above.
retrievalQA = RetrievalQA.from_llm(llm=google_llm, retriever=faiss_index.as_retriever())

In [10]:
def load_retriever(persist_directory):
    """
    Rebuild a retriever from a FAISS index saved on disk.

    Passed to MLflow as the `loader_fn` when the RetrievalQA chain is logged.

    Args:
        persist_directory (str): Directory containing the saved FAISS index.

    Returns:
        A retriever over the loaded FAISS vector store.
    """
    # Same embedding configuration as the one used when the index was built.
    embedding_model = GoogleGenerativeAIEmbeddings(
        task_type="retrieval_document",
        model='models/text-embedding-004',
    )
    # SECURITY: the flag below opts in to deserializing the saved index; its
    # name marks this as unsafe for untrusted files, so only load index
    # directories you created yourself.
    # (the flag may be required for langchain_community >= 0.0.27)
    return FAISS.load_local(
        persist_directory,
        embedding_model,
        allow_dangerous_deserialization=True,
    ).as_retriever()


# Log the RetrievalQA chain as an MLflow model inside a fresh run.
# MLflow is handed `load_retriever` plus the index directory, presumably so it
# can rebuild the retriever when the model is loaded — TODO confirm.
with mlflow.start_run() as run:
    logged_model = mlflow.langchain.log_model(
        retrievalQA,
        artifact_path="retrieval_qa",
        loader_fn=load_retriever,
        persist_dir=persist_dir,
    )

2024/11/11 13:59:14 INFO mlflow.tracking._tracking_service.client: 🏃 View run smiling-fish-215 at: http://localhost:5000/#/experiments/584542048209027571/runs/6507f4194a0b454a9c420fbfc29f6568.
2024/11/11 13:59:14 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: http://localhost:5000/#/experiments/584542048209027571.


In [11]:
# FIXME: loading the logged model back with mlflow.pyfunc currently fails, so the call is left commented out.
# loaded_model = mlflow.pyfunc.load_model(logged_model.model_uri)

In [12]:
def print_formatted_response(response_list, max_line_length=80):
    """
    Formats and prints responses with a maximum line length for better readability.

    Each response is re-flowed word by word: runs of whitespace collapse to a
    single space and a new line is started when the next word would not fit.
    Printed lines carry no trailing whitespace. A word longer than the limit is
    printed on a line of its own rather than being split.

    Args:
    response_list (list): A list of strings representing responses.
    max_line_length (int): Maximum number of characters in a line. Defaults to 80.
    """
    for response in response_list:
        current = ""
        for word in response.split():
            if not current:
                # First word of a line always goes in, even if it is too long;
                # this avoids emitting an empty line before an over-long word.
                current = word
            elif len(current) + 1 + len(word) < max_line_length:
                # Strict "<" reserves one column, which keeps line breaks where
                # they fell when a trailing space was counted toward the limit.
                current = f"{current} {word}"
            else:
                print(current)
                current = word
        # Flush the last line; an empty response still prints one blank line.
        print(current)

### Examples

In [13]:
# Ask a question about the indexed documents; the chain returns a dict and the
# answer text is read from its 'result' key.
answer1 = retrievalQA.invoke({"query": "What does the document say about trespassers?"})

print_formatted_response([answer1['result']])

The document states that all persons who locate, settle upon, or occupy a 
public park or pleasuring-ground, except as otherwise provided, shall be 
considered trespassers and removed. Additionally, it mentions that a park 
warden will remove trespassers after the passage of the act. 


In [14]:
# Ask about a term the retrieved context may not cover, to see how the chain
# responds when it cannot find supporting text.
answer2 = retrievalQA.invoke({"query": "What is a bridle-path and can I use one at Yellowstone?"})

print_formatted_response([answer2['result']])

I'm sorry, but this text does not contain information about bridle paths or 
whether they exist in Yellowstone National Park. I cannot answer your question. 


## Chat with RAG

### Setup the Chain

In [15]:
# System instruction for the question-rewriting step: turn a follow-up question
# into a standalone one, without answering it.
contextualize_q_system_prompt = (
    "Given a chat history and the latest user question "
    "which might reference context in the chat history, "
    "formulate a standalone question which can be understood "
    "without the chat history. Do NOT answer the question, "
    "just reformulate it if needed and otherwise return it as is."
)

# Prompt layout: system instruction, then the prior messages supplied under
# "chat_history", then the new user input supplied under "input".
contextualize_q_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", contextualize_q_system_prompt),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}"),
    ]
)

In [16]:
# Retriever that is given the chat model and the question-rewriting prompt in
# addition to the FAISS retriever, so follow-up questions can be contextualized
# before retrieval.
history_aware_retriever = create_history_aware_retriever(
    google_llm, faiss_index.as_retriever(), contextualize_q_prompt
)

In [17]:
# System instruction for the answering step. The "{context}" placeholder is
# where the retrieved documents are expected to be inserted.
system_prompt = (
    "You are an assistant for question-answering tasks. "
    "Use the following pieces of retrieved context to answer "
    "the question. If you don't know the answer, say that you "
    "don't know. Use three sentences maximum and keep the "
    "answer concise."
    "\n\n"
    "{context}"
)

# Same layout as the question-rewriting prompt: system, chat history, user input.
qa_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system_prompt),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}"),
    ]
)

In [18]:
# Answering chain: the chat model driven by the QA prompt.
question_answer_chain = create_stuff_documents_chain(google_llm, qa_prompt)
# Full chat RAG chain: history-aware retrieval followed by the answering chain.
rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)

In [19]:
# This code doesn't work with this rag_chain object; I need to dig into the documentation. This is one of my reasons for not using these wrappers in production!
# # def load_retriever(persist_directory):
# #     embeddings = GoogleGenerativeAIEmbeddings(
# #         model='models/text-embedding-004',
# #         task_type="retrieval_document",
# #     )
# #     vectorstore = FAISS.load_local(
# #         persist_directory,
# #         embeddings,
# #         # you may need to add the line below
# #         # for langchain_community >= 0.0.27
# #         allow_dangerous_deserialization=True,
# #     )
# #     return vectorstore.as_retriever()


# # with mlflow.start_run() as run:
# #     logged_model = mlflow.langchain.log_model(
# #         rag_chain,
# #         artifact_path="rag_chain_chat",
# #         loader_fn=load_retriever,
# #         persist_dir=persist_dir,
# #     )

### Examples

In [20]:
# First turn: start with an empty history.
chat_history = []
question = "What is the document say about the Yellowstone?"
# The chain returns a dict; the generated text is read from its 'answer' key.
ai_msg_1 = rag_chain.invoke({"input": question, "chat_history": chat_history})
print_formatted_response([ai_msg_1['answer']])

The document describes a tract of land near the headwaters of the Yellowstone 
River in Montana and Wyoming, reserving it as a public park. This land is 
described by its boundaries, starting at the junction of Gardiner's River and 
the Yellowstone River. The act withdraws the land from settlement, occupancy, 
or sale. 


In [21]:
# Record the first exchange as the chat history. The list is rebuilt from
# scratch rather than extended in place so the cell is idempotent: re-running
# it does not append the same turn a second time.
chat_history = [
    HumanMessage(content=question),
    AIMessage(content=ai_msg_1["answer"]),
]

# Follow-up question whose "it" only makes sense given the first exchange.
second_question = "Can I buy it from the Fedral Government?"
ai_msg_2 = rag_chain.invoke({"input": second_question, "chat_history": chat_history})

print_formatted_response([ai_msg_2["answer"]])

No, the document explicitly states that the land is withdrawn from sale under 
the laws of the United States. 
