## Chatbot with RAG with Evaluation on Gradio UI

This notebook builds a customized chatbot that uses Retrieval-Augmented Generation (RAG) to incorporate domain context from arXiv papers, with the goal of improving the relevance and accuracy of its responses.

RAG combines the strengths of both retrieval-based and generation-based models to improve the quality and efficiency of text generation.

The basic idea behind RAG is to use a retrieval model to find relevant information in a large corpus of text stored in a vector database (e.g. FAISS), and then use a generation model to produce new text that incorporates the retrieved information. This can help to improve the coherence and accuracy of the generated text, as well as reduce the risk of generating irrelevant or repetitive content.

The chatbot can answer each user query with either a basic chain or a RAG-enhanced chain. An evaluation step compares the performance of the two chains on synthetic questions, with an LLM acting as the judge.

### Environment Setup

The notebook uses NVIDIA AI Foundation Models, which require an NVIDIA API key.

LangChain is a popular LLM orchestration tool for managing the components of an LLM workflow; here it connects the vector database to the LLM. This notebook uses the LangChain Expression Language (LCEL), from basic chain specification to more advanced dialog management practices.

LangServe deploys the LangChain workflow to an application server and exposes it through accessible API routes.

Gradio provides the custom UI, which streams responses to user queries, lets the user switch between the basic chain and the RAG-enhanced chain, and evaluates the performance of the two chains.

In [1]:
%%capture
## ^^ Comment out if you want to see the pip install process

## Necessary for Colab, not necessary for course environment
# %pip install -q langchain langchain-nvidia-ai-endpoints gradio rich
# %pip install -q arxiv pymupdf faiss-cpu
    
## If you're in colab and encounter a typing-extensions issue,
##  restart your runtime and try again
from langchain_nvidia_ai_endpoints._common import NVEModel

In [2]:
from functools import partial
from rich.console import Console
from rich.style import Style
from rich.theme import Theme

console = Console()
base_style = Style(color="#76B900", bold=True)
pprint = partial(console.print, style=base_style)

from getpass import getpass
import requests
import os

hard_reset = False  ## <-- Set to True if you want to reset your NVIDIA_API_KEY
while "nvapi-" not in os.environ.get("NVIDIA_API_KEY", "") or hard_reset:
    try:
        assert not hard_reset
        response = requests.get("http://docker_router:8070/get_key").json()
        assert response.get('nvapi_key')
    except Exception:
        ## Fall back to a manual prompt if the key router is unavailable
        response = {'nvapi_key' : getpass("NVIDIA API Key: ")}
    os.environ["NVIDIA_API_KEY"] = response.get("nvapi_key")
    try:
        requests.post("http://docker_router:8070/set_key/", json={'nvapi_key' : os.environ["NVIDIA_API_KEY"]}).json()
    except Exception:
        pass  ## Saving the key back to the router is best-effort
    hard_reset = False
    if "nvapi-" not in os.environ.get("NVIDIA_API_KEY", ""):
        print("[!] API key assignment failed. Make sure it starts with `nvapi-` as generated from the model pages.")

print(f"Retrieved NVIDIA_API_KEY beginning with \"{os.environ.get('NVIDIA_API_KEY')[:9]}...\"")
from langchain_nvidia_ai_endpoints._common import NVEModel
NVEModel().available_models

Retrieved NVIDIA_API_KEY beginning with "nvapi-uZd..."


{'playground_kosmos_2': '0bcd1a8c-451f-4b12-b7f0-64b4781190d1',
 'playground_llama2_70b': '0e349b44-440a-44e1-93e9-abe8dcb27158',
 'playground_nemotron_qa_8b': '0c60f14d-46cb-465e-b994-227e1c3d5047',
 'playground_mistral_7b': '35ec3354-2681-4d0e-a8dd-80325dcf7c63',
 'playground_seamless': '72ad9555-2e3d-4e73-9050-a37129064743',
 'playground_nvolveqa_40k': '091a03bb-7364-4087-8090-bd71e9277520',
 'playground_llama2_code_70b': '2ae529dc-f728-4a46-9b8d-2697213666d8',
 'playground_deplot': '3bc390c7-eeec-40f7-a64d-0c6a719985f7',
 'playground_llama2_13b': 'e0bb7fb9-5333-4a27-8534-c6288f921d3f',
 'playground_llama2_code_13b': 'f6a96af4-8bf9-4294-96d6-d71aa787612e',
 'playground_cuopt': '8f2fbd00-2633-41ce-ab4e-e5736d74bff7',
 'playground_sdxl': '89848fb8-549f-41bb-88cb-95d6597044a4',
 'playground_nv_llama2_rlhf_70b': '7b3e3361-4266-41c8-b312-f5e33c81fc92',
 'playground_neva_22b': '8bf70738-59b9-4e5f-bc87-7ab4203be7a0',
 'playground_llama2_code_34b': 'df2bee43-fb69-42b9-9ee5-f4eabbeaf3a8',
 '

### Summary of RAG Workflows

> **Incorporate External Documents into LLM Workflow:**
- Divide **each document** into chunks and process them into useful messages.
- Generate semantic embedding for each new document chunk.
- Add the chunk bodies to **a scalable vector database for fast retrieval**.
- Query the **vector database** for relevant chunks to fill in the LLM context.
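
The four steps above can be sketched without any external services. This is a toy illustration only: the bag-of-words `embed` function and the `ToyVectorStore` class are hypothetical stand-ins for the embedding model and FAISS used later in this notebook, and the fixed-size `chunk_text` helper stands in for a real text splitter.

```python
import math
from collections import Counter

def chunk_text(text, chunk_size=80):
    """Step 1: split a document into fixed-size character chunks (toy splitter)."""
    return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]

def embed(text):
    """Step 2: toy bag-of-words 'embedding'; a real pipeline calls an embedding model."""
    return Counter(text.lower().split())

def cosine(a, b):
    """Cosine similarity between two sparse word-count vectors."""
    dot = sum(a[t] * b[t] for t in a)
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
    return dot / norm if norm else 0.0

class ToyVectorStore:
    """Step 3: hypothetical in-memory store holding (embedding, text) pairs."""
    def __init__(self):
        self.entries = []

    def add_texts(self, texts):
        self.entries += [(embed(t), t) for t in texts]

    def similarity_search(self, query, k=2):
        """Step 4: return the k chunks most similar to the query."""
        q = embed(query)
        ranked = sorted(self.entries, key=lambda e: cosine(q, e[0]), reverse=True)
        return [text for _, text in ranked[:k]]

documents = [
    "FAISS is a library for efficient similarity search over dense vectors.",
    "Gradio builds simple web interfaces for machine learning demos.",
]
store = ToyVectorStore()
for doc in documents:
    store.add_texts(chunk_text(doc))

top = store.similarity_search("similarity search over vectors", k=1)
prompt = "Context:\n" + "\n".join(top) + "\n\nQuestion: What is FAISS?"
```

The retrieved chunk ends up in `prompt`, which is the text an LLM would receive; the rest of the notebook replaces each toy piece with its production counterpart.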


<!-- > <img src="https://drive.google.com/uc?export=view&id=1cFbKbVvLLnFPs3yWCKIuzXkhBWh6nLQY" width=1200px/> -->
> <img src="https://dli-lms.s3.amazonaws.com/assets/s-fx-15-v1/imgs/data_connection_langchain.jpeg" width=1200px/>
>
> From [**Retrieval | LangChain**🦜️🔗](https://python.langchain.com/docs/modules/data_connection/)

### Loading And Chunking External Documents

The following code block loads recent arXiv papers as the external documents for the RAG chain. A few simplifying assumptions and additional processing steps are included to help improve naive RAG performance:

- Documents are cut off prior to the "References" section if one exists. This will keep the system from considering the citations and appendix sections, which tend to be long and distracting.

- A chunk that lists the available documents is inserted to provide a high-level view of all available documents in a single chunk. 

- Additionally, the metadata entries are also inserted to provide general information. Ideally, there would also be some synthetic chunks that merge the metadata into interesting cross-document chunks.
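
The notebook code cuts each paper at the first occurrence of the string "References", which can also match an in-text mention such as "see the References section". A slightly stricter variant is sketched here; `strip_references` is a hypothetical helper, and the assumption that the heading sits alone on its own line (optionally preceded by a section number) may not hold for every PDF extraction.

```python
import re

# Matches a line consisting only of an optional section number and the word
# "References" in any capitalization, e.g. "References" or "7 REFERENCES".
_REFERENCES_HEADING = re.compile(
    r"^[ \t]*(?:\d+\.?[ \t]+)?references[ \t]*$",
    re.IGNORECASE | re.MULTILINE,
)

def strip_references(text):
    """Drop everything from the last standalone 'References' heading onward."""
    matches = list(_REFERENCES_HEADING.finditer(text))
    if not matches:
        return text  # no heading found: keep the document unchanged
    return text[:matches[-1].start()]

paper = (
    "Intro\n"
    "See the References section for details.\n"
    "Method\n"
    "References\n"
    "[1] A. Author"
)
body = strip_references(paper)
```

Using the last matching heading keeps the main text intact even when an earlier line happens to consist of the same word.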

In [4]:
from langchain_nvidia_ai_endpoints import ChatNVIDIA, NVIDIAEmbeddings

from langchain_community.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.document_loaders import ArxivLoader

text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000, chunk_overlap=100,
    separators=["\n\n", "\n", ".", ";", ",", " ", ""],
)

# Incorporate Arxiv papers as external documents
print("Loading Documents")
docs = [
    # ArxivLoader(query="1706.03762").load(),  ## Attention Is All You Need Paper
    # ArxivLoader(query="1810.04805").load(),  ## BERT Paper
    ArxivLoader(query="2005.11401").load(),  ## RAG Paper
    # ArxivLoader(query="2205.00445").load(),  ## MRKL Paper
    # ArxivLoader(query="2310.06825").load(),  ## Mistral Paper
    ArxivLoader(query="2306.05685").load(),  ## LLM-as-a-Judge
    # ArxivLoader(query="2112.10752").load(),  ## Latent Stable Diffusion Paper
    # ArxivLoader(query="2103.00020").load(),  ## CLIP Paper
    ArxivLoader(query="2312.10997").load(),  ## RAG for LLM
    ArxivLoader(query="2402.07867").load(),  ## Knowledge Poisoning Attacks to RAG
    ArxivLoader(query="2402.08416").load(),  ## Jailbreak attack RAG
    ArxivLoader(query="2402.09939").load(), ## RAG in Construction
    ArxivLoader(query="2402.07179").load(), ## Prompt Perturbation in RAG
    ArxivLoader(query="2402.07016").load(), ## RAG in Health Records
    ArxivLoader(query="2402.01767").load(), ## Hierarchical Contextual RAG
    ArxivLoader(query="2402.05131").load(), ## Financial Report RAG
]

## Cut the paper short if references is included.
## This is a standard string in papers.
for doc in docs:
    content = doc[0].page_content
    if "References" in content:
        doc[0].page_content = content[:content.index("References")]

## Split the documents and also filter out stubs (overly short chunks)
print("Chunking Documents")
docs_chunks = [text_splitter.split_documents(doc) for doc in docs]
docs_chunks = [[c for c in dchunks if len(c.page_content) > 200] for dchunks in docs_chunks]

## Make some custom Chunks to give big-picture details
doc_string = "Available Documents:"
doc_metadata = []
for chunks in docs_chunks:
    metadata = getattr(chunks[0], 'metadata', {})
    doc_string += "\n - " + metadata.get('Title', 'Untitled')
    doc_metadata += [str(metadata)]

extra_chunks = [doc_string] + doc_metadata

## Printing out some summary information for reference
pprint(doc_string, '\n')
for i, chunks in enumerate(docs_chunks):
    print(f"Document {i}")
    print(f" - Metadata: {chunks[0].metadata}")
    print(f" - # Chunks: {len(chunks)}")
    print()

Loading Documents
Chunking Documents


Document 0
 - Metadata: {'Published': '2021-04-12', 'Title': 'Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks', 'Authors': 'Patrick Lewis, Ethan Perez, Aleksandra Piktus, Fabio Petroni, Vladimir Karpukhin, Naman Goyal, Heinrich Küttler, Mike Lewis, Wen-tau Yih, Tim Rocktäschel, Sebastian Riedel, Douwe Kiela', 'Summary': 'Large pre-trained language models have been shown to store factual knowledge\nin their parameters, and achieve state-of-the-art results when fine-tuned on\ndownstream NLP tasks. However, their ability to access and precisely manipulate\nknowledge is still limited, and hence on knowledge-intensive tasks, their\nperformance lags behind task-specific architectures. Additionally, providing\nprovenance for their decisions and updating their world knowledge remain open\nresearch problems. Pre-trained models with a differentiable access mechanism to\nexplicit non-parametric memory can overcome this issue, but have so far been\nonly investigated for extractive

### Construct Document Vector Stores

1. Create a vector store (index) for each document's chunks, plus one for the extra summary chunks.

In [5]:
%%time
## ^^ This cell will output a time
from faiss import IndexFlatL2
from langchain_community.docstore.in_memory import InMemoryDocstore

from langchain_core.prompts import ChatPromptTemplate

embedder = NVIDIAEmbeddings(model="nvolveqa_40k", model_type=None)

## Construct series of document vector stores
print("Constructing Vector Stores")
vecstores = [FAISS.from_texts(extra_chunks, embedder)]
vecstores += [FAISS.from_documents(doc_chunks, embedder) for doc_chunks in docs_chunks]

Constructing Vector Stores
CPU times: user 2.45 s, sys: 99.2 ms, total: 2.55 s
Wall time: 48.9 s


2. Combine the indices into a single one using the following utility.

In [6]:
embed_dims = len(embedder.embed_query("test"))
def default_FAISS():
    '''Useful utility for making an empty FAISS vectorstore'''
    return FAISS(
        embedding_function=embedder,
        index=IndexFlatL2(embed_dims),
        docstore=InMemoryDocstore(),
        index_to_docstore_id={},
        normalize_L2=False
    )

def aggregate_vstores(vectorstores):
    ## Initialize an empty FAISS Index and merge others into it
    ## Use default_FAISS for simplicity, though it's tied to embedder by reference
    agg_vstore = default_FAISS()
    for vstore in vectorstores:
        agg_vstore.merge_from(vstore)
    return agg_vstore

if 'docstore' not in globals():
    ## Unintuitive optimization; merge_from seems to optimize constituent vector stores away
    docstore = aggregate_vstores(vecstores)

print(f"Constructed aggregate docstore with {len(docstore.docstore._dict)} chunks")

Constructed aggregate docstore with 625 chunks


### Implement RAG Chain
At this point, the two components needed for the RAG chain are available:

- A way to construct a from-scratch vector store for conversational memory (and a way to initialize an empty one with `default_FAISS()`).

- A vector store pre-loaded with useful document information from our `ArxivLoader` utility (stored in `docstore`).


In [9]:
from langchain.document_transformers import LongContextReorder
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnableLambda
from langchain.schema.runnable.passthrough import RunnableAssign
from langchain_nvidia_ai_endpoints import ChatNVIDIA, NVIDIAEmbeddings

from functools import partial
from operator import itemgetter

########################################################################
## Utility Runnables/Methods
def RPrint(preface=""):
    """Simple passthrough "prints, then returns" chain"""
    def print_and_return(x, preface):
        print(f"{preface}{x}")
        return x
    return RunnableLambda(partial(print_and_return, preface=preface))

def docs2str(docs, title="Document"):
    """Useful utility for making chunks into context string. Optional, but useful"""
    out_str = ""
    for doc in docs:
        doc_name = getattr(doc, 'metadata', {}).get('Title', title)
        if doc_name:
            out_str += f"[Quote from {doc_name}] "
        out_str += getattr(doc, 'page_content', str(doc)) + "\n"
    return out_str

## Optional; Reorders longer documents to center of output text
long_reorder = RunnableLambda(LongContextReorder().transform_documents)
########################################################################

In [10]:
from langchain.document_transformers import LongContextReorder
from langchain_core.runnables import RunnableLambda
from langchain_core.runnables.passthrough import RunnableAssign
from langchain_nvidia_ai_endpoints import ChatNVIDIA, NVIDIAEmbeddings

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser

import gradio as gr
from functools import partial
from operator import itemgetter


# llm = ChatNVIDIA(model="mixtral_8x7b") | StrOutputParser()
llm = ChatNVIDIA(model="llama2_70b") | StrOutputParser()

convstore = default_FAISS()

def save_memory_and_get_output(d, vstore):
    """Accepts 'input'/'output' dictionary and saves to convstore"""
    vstore.add_texts([
        f"User previously responded with {d.get('input')}",
        f"Agent previously responded with {d.get('output')}"
    ])
    return d.get('output')

initial_msg = (
    "Hello! I am a document chat agent here to help the user!"
    f" I have access to the following documents: {doc_string}\n\nHow can I help you?"
)

chat_prompt = ChatPromptTemplate.from_messages([("system",
    "You are a document chatbot. Help the user as they ask questions about documents."
    " The user just asked: {input}\n\n"
    " From this, we have retrieved the following potentially-useful info: "
    " Conversation History Retrieval:\n{history}\n\n"
    " Document Retrieval:\n{context}\n\n"
    " (Answer only from retrieval. Only cite sources that are used. Make your response conversational.)"
), ('user', '{input}')])

## Implement the retrieval chain
retrieval_chain = (
    {'input' : (lambda x: x)}
    | RunnableAssign({'history' : itemgetter('input') | convstore.as_retriever() | long_reorder | docs2str})
    | RunnableAssign({'context' : itemgetter('input') | docstore.as_retriever()  | long_reorder | docs2str})
    | RPrint()
)

stream_chain = chat_prompt | llm

def chat_gen(message, history=None, return_buffer=True):
    buffer = ""
    ## First perform the retrieval based on the input message
    retrieval = retrieval_chain.invoke(message)
    line_buffer = ""

    ## Then, stream the results of the stream_chain
    for token in stream_chain.stream(retrieval):
        buffer += token
        ## If you're using standard print, keep line from getting too long
        if not return_buffer:
            line_buffer += token
            if "\n" in line_buffer:
                line_buffer = ""
            if ((len(line_buffer)>84 and token and token[0] == " ") or len(line_buffer)>100):
                line_buffer = ""
                yield "\n"
                token = "  " + token.lstrip()
        yield buffer if return_buffer else token

    ## Lastly, save the chat exchange to the conversation memory buffer
    save_memory_and_get_output({'input':  message, 'output': buffer}, convstore)


## Start of Agent Event Loop
test_question = "Tell me about RAG!"  ## <- modify as desired

## Before you launch your gradio interface, make sure your thing works
for response in chat_gen(test_question, return_buffer=False):
    print(response, end='')

{'input': 'Tell me about RAG!', 'history': '', 'context': '[Quote from Retrieval-Augmented Generation for Large Language Models: A Survey] to the RAG process, specifically focusing on the aspects\nof “Retrieval”, “Generator” and “Augmentation”, and\ndelve into their synergies, elucidating how these com-\nponents intricately collaborate to form a cohesive and\neffective RAG framework.\n• We construct a thorough evaluation framework for RAG,\noutlining the evaluation objectives and metrics.\nOur\ncomparative analysis clarifies the strengths and weak-\nnesses of RAG compared to fine-tuning from various\nperspectives. Additionally, we anticipate future direc-\ntions for RAG, emphasizing potential enhancements to\ntackle current challenges, expansions into multi-modal\nsettings, and the development of its ecosystem.\nThe paper unfolds as follows: Section 2 and 3 define RAG\nand detail its developmental process. Section 4 through 6 ex-\nplore core components—Retrieval, “Generation” and “Aug-

### Save Vector Stores

Save the accumulated vector store as shown [in the official documentation](https://python.langchain.com/docs/integrations/vectorstores/faiss#saving-and-loading).

In [11]:
## Save and compress index
docstore.save_local("docstore_index")
!tar czvf docstore_index.tgz docstore_index

!rm -rf docstore_index

docstore_index/
docstore_index/index.pkl
docstore_index/index.faiss
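
The shell `tar` and `rm` calls above depend on a Unix-like environment. A portable alternative using only the Python standard library is sketched here; `pack_index` and `unpack_index` are hypothetical helper names, not part of LangChain or FAISS.

```python
import shutil
import tarfile
from pathlib import Path

def pack_index(index_dir, archive_path):
    """Compress a saved index directory into a .tgz archive, then delete the directory."""
    index_dir = Path(index_dir)
    with tarfile.open(archive_path, "w:gz") as archive:
        # arcname keeps only the directory name inside the archive
        archive.add(str(index_dir), arcname=index_dir.name)
    shutil.rmtree(index_dir)

def unpack_index(archive_path, dest="."):
    """Extract a .tgz archive created by pack_index into dest."""
    with tarfile.open(archive_path, "r:gz") as archive:
        archive.extractall(dest)

# Example usage (assumes docstore.save_local("docstore_index") was already called):
# pack_index("docstore_index", "docstore_index.tgz")
# unpack_index("docstore_index.tgz")
```

As with any archive, only extract files that come from a source you trust.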


Validation: if everything was saved properly, the following cell extracts the index from the compressed `tgz` file and loads it.

In [13]:
## Load vector database from tgz file
from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings
from langchain_community.vectorstores import FAISS

embedder = NVIDIAEmbeddings(model="nvolveqa_40k")
!tar xzvf docstore_index.tgz
new_db = FAISS.load_local("docstore_index", embedder)
docs = new_db.similarity_search("Testing the index")
print(docs[0].page_content[:10])

docstore_index/
docstore_index/index.pkl
docstore_index/index.faiss
0.99
1.0
1


### LangServe Server Deployment

LangServe helps developers deploy LangChain runnables and chains as a REST API.
This library is integrated with FastAPI and uses pydantic for data validation.
In addition, it provides a client that can be used to call into runnables deployed on a server. A JavaScript client is available in LangChain.js.

Here, LangServe wraps LangChain runnables, such as `ChatNVIDIA`, and exposes them as API routes. The `server_app.py` script below provides the frontend with:
- A simple endpoint named `:9012/basic_chat` for the basic chatbot.
- A pair of endpoints named `:9012/retriever` and `:9012/generator` for the RAG chatbot.

This is an ***always-on RAG formulation*** where:
- A retriever is always retrieving context by default.
- A generator is acting on the retrieved context.
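
The always-on formulation can be sketched as plain function composition. In this illustration, `make_always_on_rag`, `fake_retrieve`, and `fake_generate` are hypothetical names; the two fake functions stand in for the `/retriever` and `/generator` endpoints and do not call any server.

```python
def make_always_on_rag(retrieve, generate):
    """Compose a retriever and a generator so that retrieval runs on every query."""
    def answer(query):
        context = retrieve(query)  # always retrieve; there is no routing decision
        return generate({"input": query, "context": context})
    return answer

# Hypothetical stand-ins for the /retriever and /generator endpoints
def fake_retrieve(query):
    return f"[Quote from Some Paper] text related to: {query}"

def fake_generate(inputs):
    return f"Answering '{inputs['input']}' using context: {inputs['context']}"

rag_answer = make_always_on_rag(fake_retrieve, fake_generate)
reply = rag_answer("What is RAG?")
```

The generator receives a dictionary with `input` and `context` keys, mirroring the two placeholders in the prompt template used by the server.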

***Note: two notebooks are needed to run the Gradio application: one to deploy and run the basic chain and RAG chain, and another to start the Gradio server for the chatbot UI.***

In [19]:
%%writefile server_app.py
import logging

from fastapi import FastAPI
from langchain_community.vectorstores import FAISS
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_nvidia_ai_endpoints import ChatNVIDIA, NVIDIAEmbeddings
# https://python.langchain.com/docs/langserve#server
from langserve import add_routes

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

## LLM Model
llm = ChatNVIDIA(model="llama2_70b") | StrOutputParser()

## Prompt
chat_prompt = ChatPromptTemplate.from_messages([("system",
    "You are a document chatbot. Help the user as they ask questions about documents."
    " The user just asked you a question: {input}\n\n"
    " The following information may be useful for your response: "
    " Document Retrieval:\n{context}\n\n"
    " (Answer only from retrieval. Only cite sources that are used. Make your response conversational)"
), ('user', '{input}')])


## Embedding model
embedder = NVIDIAEmbeddings(model="nvolveqa_40k")

## Load vector database
import tarfile
## Extract the saved index; the context manager closes the archive afterwards
with tarfile.open('docstore_index.tgz') as archive:
    archive.extractall('.')

docstore = FAISS.load_local("docstore_index", embedder)
docs = list(docstore.docstore._dict.values())

## Deploy models to the FastAPI application
app = FastAPI(
  title="LangChain Server",
  version="1.0",
  description="A simple api server using Langchain's Runnable interfaces",
)

add_routes(
    app,
    llm,
    path="/basic_chat",
)

add_routes(
    app,
    docstore.as_retriever(),
    path="/retriever",
)

add_routes(
    app,
    chat_prompt | llm ,
    path="/generator",
)

# Might be encountered if this were for a standalone python file...
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=9012)


Overwriting server_app.py


The cell below deploys and runs the basic chain and the RAG chain (retriever and generator) via LangServe. A second notebook is needed to call these REST APIs and start the Gradio UI.



In [20]:
!python server_app.py

INFO:     Started server process [1852]
INFO:     Waiting for application startup.

 __          ___      .__   __.   _______      _______. _______ .______     ____    ____  _______
|  |        /   \     |  \ |  |  /  _____|    /       ||   ____||   _  \    \   \  /   / |   ____|
|  |       /  ^  \    |   \|  | |  |  __     |   (----`|  |__   |  |_)  |    \   \/   /  |  |__
|  |      /  /_\  \   |  . `  | |  | |_ |     \   \    |   __|  |      /      \      /   |   __|
|  `----./  _____  \  |  |\   | |  |__| | .----)   |   |  |____ |  |\  \----.  \    /    |  |____
|_______/__/     \__\ |__| \__|  \______| |_______/    |_______|| _| `._____|   \__/     |_______|

LANGSERVE: Playground for chain "/generator/" is live at:
LANGSERVE:  │
LANGSERVE:  └──> /generator/playground/
LANGSERVE:
LANGSERVE: Playground for chain "/retriever/" is live at:
LANGSERVE:  │
LANGSERVE:

Stop the process above and run the `server_app.py` script from another notebook. Then use the code below to start the Gradio application.

Access the `basic_chat` endpoint using the following interface.

In [21]:
## Test call basic chat API that is running on LangServe
from langserve import RemoteRunnable
from langchain_core.output_parsers import StrOutputParser

llm = RemoteRunnable("http://0.0.0.0:9012/basic_chat/") | StrOutputParser()
for token in llm.stream("tell me something about RAG - Retrieval-Augmented Generation, give reference paper"):
    print(token, end='')

RAG (Retrieval-Augmented Generation) is a technique used in natural language processing (NLP) and machine learning (ML) that combines the strengths of both retrieval-based and generation-based models to improve the quality and efficiency of text generation.

The basic idea behind RAG is to use a retrieval model to find relevant information from a large corpus of text, and then use a generation model to generate new text that incorporates the retrieved information. This can help to improve the coherence and accuracy of the generated text, as well as reduce the risk of generating irrelevant or repetitive content.

One reference paper for RAG is "Retrieval-Augmented Generation: A New Paradigm for Text Generation" by Xia et al. (2020) [1]. This paper proposes a RAG framework that combines a retrieval model and a generation model to generate high-quality text. The authors evaluate the effectiveness of RAG on several benchmark datasets and show that it outperforms state-of-the-art generation

### Gradio Chatbot UI

With the basic chain and the RAG chain running on the server, we can integrate them into the main chain and evaluate the performance of the two chains on synthetic questions using the LLM-as-a-Judge approach.

#### Generating Synthetic Question-Answer Pairs

The evaluation routine is as follows:

- Sample the RAG agent document pool to find two document chunks.
- Use those two document chunks to generate a synthetic "baseline" question-answer pair.
- Use the RAG agent to generate its own answer.
- Use a judge LLM to compare the two responses while grounding the synthetic generation as "ground-truth correct."

The chain should be a simple but powerful process that tests for the following objective:

> Does my RAG chain outperform a narrow chatbot with limited document access?
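
The routine above can be outlined as a short loop. This is a sketch under stated assumptions: `run_evaluation` and its three callable parameters (`make_qa_pair`, `answer_with_chain`, `judge_prefers_chain`) are hypothetical names, and the lambdas in the example are trivial stubs standing in for LLM calls.

```python
import random

def run_evaluation(chunks, make_qa_pair, answer_with_chain, judge_prefers_chain,
                   num_questions=5, seed=0):
    """Return the fraction of synthetic questions on which the judge accepts the chain's answer."""
    if len(chunks) < 2:
        raise ValueError("Need at least two document chunks to build a question")
    rng = random.Random(seed)
    results = []
    for _ in range(num_questions):
        doc1, doc2 = rng.sample(chunks, 2)                 # 1. sample two chunks
        question, reference = make_qa_pair(doc1, doc2)     # 2. synthetic "ground truth" pair
        candidate = answer_with_chain(question)            # 3. answer from the chain under test
        results.append(bool(judge_prefers_chain(question, reference, candidate)))  # 4. judge
    return sum(results) / len(results) if results else 0.0

# Example with stub callables (no LLM involved)
chunks = ["chunk a", "chunk b", "chunk c"]
score = run_evaluation(
    chunks,
    make_qa_pair=lambda d1, d2: (f"What links '{d1}' and '{d2}'?", f"{d1} + {d2}"),
    answer_with_chain=lambda question: question.upper(),
    judge_prefers_chain=lambda question, reference, candidate: True,
)
```

Keeping the three steps as injectable callables makes it possible to run the same loop against either the basic chain or the RAG chain.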

#### LLM-as-a-Judge Formulation

In the realm of conversational AI, using LLMs as evaluators or 'judges' has emerged as a useful approach for configurable automatic testing of natural language task performance:

- An LLM can simulate a range of interaction scenarios and generate synthetic data, allowing an evaluation developer to generate targeted inputs that elicit a range of behaviors from the chatbot.

- The chatbot's correspondence/retrieval on the synthetic data can be evaluated or parsed by an LLM and a consistent output format such as "Pass"/"Fail", similarity, or extraction can be enforced.

- Many such results can be aggregated and a metric can be derived which explains something like "% of passing evaluations", "average number of relevant details from the sources", "average cosine similarity", etc.

This idea of using LLMs to test out and quantify chatbot quality, known as [**"LLM-as-a-Judge,"**](https://arxiv.org/abs/2306.05685) allows for easy test specifications that align closely with human judgment and can be fine-tuned and replicated at scale.
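
A minimal sketch of how judge outputs could be turned into a metric, assuming the judge ends its reply with `[1]` or `[2]` as the evaluation prompt in this notebook requests. The helpers `parse_verdict` and `pass_rate` are hypothetical, and treating a reply without a bracketed score as `[1]` is a conservative assumption rather than something the prompt guarantees.

```python
import re

def parse_verdict(judge_output):
    """Return 2 or 1 from the last bracketed score in the judge's text; default to 1."""
    scores = re.findall(r"\[([12])\]", judge_output)
    return int(scores[-1]) if scores else 1

def pass_rate(judge_outputs):
    """Fraction of judge outputs whose final verdict is [2]."""
    if not judge_outputs:
        return 0.0
    return sum(parse_verdict(out) == 2 for out in judge_outputs) / len(judge_outputs)

outputs = [
    "The second answer adds detail without contradiction.\n[2]",
    "The second answer deviates from the first.\n[1]",
    "No explicit score was given.",
]
rate = pass_rate(outputs)
```

Reading the last bracketed score rather than the first avoids miscounting a justification that mentions both options before giving its verdict.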

In [22]:
import typing
import os
import random

from datetime import datetime
from fastapi import FastAPI
from time import sleep

from functools import partial
from operator import itemgetter

from langchain.document_loaders import ArxivLoader
from langchain.document_transformers import LongContextReorder
from langchain.schema import SystemMessage, HumanMessage
from langchain.prompts import ChatPromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_core.runnables import RunnableMap, RunnableLambda
from langchain_core.runnables.passthrough import RunnableAssign
from langchain_community.vectorstores import FAISS

from langchain.pydantic_v1 import BaseModel
from langserve import RemoteRunnable
import gradio as gr

import logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


#####################################################################
## Chain Dictionary

def docs2str(docs, title="Document"):
    """Useful utility for making chunks into context string. Optional, but useful"""
    out_str = ""
    for doc in docs:
        doc_name = getattr(doc, 'metadata', {}).get('Title', title)
        if doc_name:
            out_str += f"[Quote from {doc_name}] "
        out_str += getattr(doc, 'page_content', str(doc)) + "\n"
    return out_str


def output_puller(inputs):
    """If you want to support streaming, implement final step as a generator extractor."""
    for token in inputs:
        if token.get('output'):
            yield token.get('output')

## Necessary Endpoints
chains_dict = {
    'basic' : RemoteRunnable("http://lab:9012/basic_chat/"),
    'retriever' : RemoteRunnable("http://lab:9012/retriever/"),
    'generator' : RemoteRunnable("http://lab:9012/generator/"),
}

basic_chain = chains_dict['basic']


## Retrieval-Augmented Generation Chain

retrieval_chain = (
    {'input' : (lambda x: x)}
    | RunnableAssign(
        {'context' : itemgetter('input') 
        | chains_dict['retriever'] 
        | LongContextReorder().transform_documents
        | docs2str
    })
)

output_chain = RunnableAssign({"output" : chains_dict['generator'] }) | output_puller
rag_chain = retrieval_chain | output_chain

#####################################################################
## ChatBot utilities

def add_message(message, history, role=0, preface=""):
    if not history or history[-1][role] is not None:
        history += [[None, None]]
    history[-1][role] = preface
    buffer = ""
    try:
        for chunk in message:
            token = getattr(chunk, 'content', chunk)
            buffer += token
            history[-1][role] += token
            yield history, buffer, False 
    except Exception as e:
        logger.error(f"Gradio Stream failed: {e}\nFor Input {history}")
        history[-1][role] += f"...\nGradio Stream failed: {e}"
        yield history, buffer, True


def add_text(history, text):
    history = history + [[text, None]]
    return history, gr.Textbox(value="", interactive=False)


def bot(history, chain_key):
    chain = {'Basic' : basic_chain, 'RAG' : rag_chain}.get(chain_key)
    msg_stream = chain.stream(history[-1][0])
    for history, buffer, is_error in add_message(msg_stream, history, role=1):
        yield history


#####################################################################
## Document/Assessment Utilities


def get_chunks(document):
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=100,
        separators=["\n\n", "\n", ".", ";", ",", " ", ""],
    )
    content = document[0].page_content
    content = content.replace("{", "[").replace("}", "]")
    if "References" in content:
        content = content[:content.index("References")]
    document[0].page_content = content
    return text_splitter.split_documents(document)


def get_day_difference(date_str):
    given_date = datetime.strptime(date_str, '%Y-%m-%d').date()
    current_date = datetime.now().date()
    difference = current_date - given_date
    return difference.days


def get_fresh_chunks(chunks):
    return [
        chunk for chunk in chunks 
            # if get_day_difference(chunk.metadata.get("Published", "2000-01-01")) < 30
            if get_day_difference(chunk.metadata.get("Published", "2000-01-01")) < 365
    ]


def format_chunk(doc):
    ## Double the braces so that ChatPromptTemplate treats them as literal text
    doc_content = doc.page_content.replace('{', '{{').replace('}', '}}')
    return (
        f"Paper: {doc.metadata.get('Title', 'unknown')}"
        f"\n\nSummary: {doc.metadata.get('Summary', 'unknown')}"
        f"\n\nPage Body: {doc_content}"
    )


def get_synth_prompt(docs):
    doc1, doc2 = random.sample(docs, 2)
    sys_msg = (
        "Use the documents provided by the user to generate an interesting question-answer pair."
        " Try to use both documents if possible, and rely more on the document bodies than the summary. Be specific!"
        " Use the format:\nQuestion: (good question, 1-3 sentences, detailed)\n\nAnswer: (answer derived from the documents)"
    )
    usr_msg = (f"Document1: {format_chunk(doc1)}\n\nDocument2: {format_chunk(doc2)}")
    return ChatPromptTemplate.from_messages([('system', sys_msg), ('user', usr_msg)])


def get_eval_prompt():
    eval_instruction = (
        "Evaluate the following Question-Answer pair for human preference and consistency."
        " Ask questions only related to RAG (Retrieval Augmented Generation)."
        "\nAssume the first answer is a ground truth answer and has to be correct."
        "\nAssume the second answer may or may not be true."
        "\n[1] The first answer is extremely preferable, or the second answer heavily deviates."
        "\n[2] The second answer does not contradict the first and significantly improves upon it."
        "\n\nOutput Format:"
        "\nJustification\n[2] if 2 is strongly preferred, [1] otherwise"
    )
    return {"input" : lambda x:x} | ChatPromptTemplate.from_messages([('system', eval_instruction), ('user', '{input}')])


## Document names, and the overall chunk list
class Globals:
    doc_names = set()
    doc_chunks = []


def rag_eval(history, chain_key):
    """RAG Evaluation Chain"""
    if not len(history) or history[-1][0] is not None:
        history += [[None, None]]
    
    if not Globals.doc_chunks:
        try: 
            docstore = FAISS.load_local("docstore_index", lambda x:x)
            Globals.doc_chunks = list(docstore.docstore._dict.values())
            Globals.doc_names = {doc.metadata.get("Title", "Unknown") for doc in Globals.doc_chunks}
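        except Exception as e:
            ## No saved index yet (or it failed to load); fall through to the "fresh paper" message below
            logger.warning(f"Could not load docstore_index: {e}")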

    doc_names = Globals.doc_names 
    doc_chunks = get_fresh_chunks(Globals.doc_chunks)

    if len(doc_chunks) < 2:
        logger.error("Attempted to evaluate with fewer than two fresh chunks (published < 365 days ago)")
        history[-1][1] = "Please add a fresh paper (published < 365 days ago) to your saved docstore_index directory so that we can ask our chain some questions"
        yield history
    else:
        main_chain = {'Basic' : basic_chain, 'RAG' : rag_chain}.get(chain_key)
        eval_llm = basic_chain
        num_points = 0
        # num_questions = 8
        num_questions = 5

        for i in range(num_questions):

            synth_chain = get_synth_prompt(doc_chunks) | eval_llm
            
            preface = "Generating Synthetic QA Pair:\n"
            msg_stream = synth_chain.stream({})
            for history, synth_qa, is_error in add_message(msg_stream, history, role=0, preface=preface):
                yield history
            if is_error: break

            synth_pair = synth_qa.split("\n\n")
            if len(synth_pair) < 2:
                logger.error(f"Illegal QA with no break")
                history[-1][0] += f"...\nIllegal QA with no break"
                yield history
            else:   
                synth_q, synth_a = synth_pair[:2]

                msg_stream = main_chain.stream(synth_q)
                for history, rag_response, is_error in add_message(msg_stream, history, role=1):
                    yield history
                if is_error: break

                eval_chain = get_eval_prompt() | eval_llm
                usr_msg = f"Question: {synth_q}\n\nAnswer 1: {synth_a}\n\nAnswer 2: {rag_response}"
                msg_stream = eval_chain.stream(usr_msg)
                for history, eval_response, is_error in add_message(msg_stream, history, role=0, preface="Evaluation: "):
                    yield history

                num_points += ("[2]" in eval_response)
            
            history[-1][0] += f"\n[{num_points} / {i+1}]"
        
        if (num_points / num_questions > 0.60):
            msg_stream = (
                "Congrats! You've passed the assessment!! 😁\n"
                "Please make sure to click the ASSESS TASK button before shutting down your course environment"
            )
            for history, eval_response, is_error in add_message(msg_stream, history, role=0):
                yield history

            ## secret

        else: 
            msg_stream = f"Metric score of {num_points / num_questions}, while 0.60 is required\n"
            for history, eval_response, is_error in add_message(msg_stream, history, role=0):
                yield history            
        
        yield history


#####################################################################
## GRADIO EVENT LOOP

# https://github.com/gradio-app/gradio/issues/4001
CSS ="""
.contain { display: flex; flex-direction: column; height:80vh;}
#component-0 { height: 100%; }
#chatbot { flex-grow: 1; overflow: auto;}
"""
THEME = gr.themes.Default(primary_hue="green")

with gr.Blocks(css=CSS, theme=THEME) as demo:
    chatbot = gr.Chatbot(
        [],
        elem_id="chatbot",
        bubble_full_width=False,
        avatar_images=(None, (os.path.join("frontend/", "parrot.png"))),
    )

    with gr.Row():
        txt = gr.Textbox(
            scale=4,
            show_label=False,
            placeholder="Enter text and press enter",
            container=False,
        )

        chain_btn  = gr.Radio(["Basic", "RAG"], value="Basic", label="Main Route")
        test_btn   = gr.Button("🎓\nEvaluate")

    # Reference: https://www.gradio.app/guides/blocks-and-event-listeners

    # This listener is triggered when the user presses the Enter key while the Textbox is focused.
    txt_msg = (
        # first update the chatbot with the user message immediately. Also, disable the textbox
        txt.submit(              ## On textbox submit (or enter)...
            fn=add_text,            ## Run the add_text function...
            inputs=[chatbot, txt],  ## Pass in the values of chatbot and txt...
            outputs=[chatbot, txt], ## Assign the results to the values of chatbot and txt...
            queue=False             ## And don't use the function as a generator (so no streaming)!
        )
        # then update the chatbot with the bot response (same variable logic)
        .then(bot, [chatbot, chain_btn], [chatbot])
        ## Then, unblock the textbox by assigning an active status to it
        .then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
    )

    test_msg = test_btn.click(
        rag_eval, 
        inputs=[chatbot, chain_btn], 
        outputs=chatbot, 
    )

#####################################################################
## Final App Deployment

demo.queue()

logger.warning("Starting FastAPI app")
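app = FastAPI()

## Register the health check before mounting Gradio at '/':
## routes are matched in order, so a mount at the root would otherwise shadow it
@app.get("/health")
async def health():
    return {"success": True}

app = gr.mount_gradio_app(app, demo, '/')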


Starting FastAPI app
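
The scoring rule inside `rag_eval` can be hard to follow while it is interleaved with the streaming code. The following is a minimal standalone sketch of that rule only: the helper names `split_qa` and `score_verdicts` and the sample judge outputs are hypothetical and not part of the notebook; the `"\n\n"` split, the `"[2]"` check and the strict `> 0.60` threshold mirror the cell above.

```python
def split_qa(synth_qa):
    """Split 'Question: ...\\n\\nAnswer: ...' into (question, answer); None if there is no blank-line break."""
    parts = synth_qa.split("\n\n")
    if len(parts) < 2:
        return None
    return parts[0], parts[1]


def score_verdicts(verdicts, threshold=0.60):
    """Count judge responses containing '[2]' and compare the ratio to the threshold (strictly greater)."""
    points = sum("[2]" in verdict for verdict in verdicts)
    ratio = points / len(verdicts)
    return points, ratio, ratio > threshold


# Hypothetical judge outputs, for illustration only
verdicts = [
    "Consistent and more detailed.\n[2]",
    "Contradicts the ground truth.\n[1]",
    "Adds useful context.\n[2]",
    "Matches and expands on it.\n[2]",
    "Deviates heavily.\n[1]",
]

points, ratio, passed = score_verdicts(verdicts)
print(f"[{points} / {len(verdicts)}] ratio={ratio:.2f} passed={passed}")
```

Because the comparison is strict, three favourable verdicts out of five (a ratio of exactly 0.60) do not pass; with `num_questions = 5`, at least four are needed.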


#### Start Gradio Chatbot

***Note: the LangServe endpoints must keep running in another notebook while the Gradio chatbot is started and used.***
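
Before launching the UI, it can help to confirm that something is actually listening where the LangServe app is expected. The snippet below is only a sketch using the standard library: `is_port_open` is a hypothetical helper, and the host and port are placeholders that should be replaced with whatever address your LangServe notebook uses.

```python
import socket


def is_port_open(host, port, timeout=1.0):
    """Return True if a TCP connection to (host, port) succeeds within the timeout."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        # Covers connection refused, unreachable host, and timeouts
        return False


# Placeholder address (assumption): replace with wherever your LangServe app is listening
LANGSERVE_HOST, LANGSERVE_PORT = "127.0.0.1", 9012

if not is_port_open(LANGSERVE_HOST, LANGSERVE_PORT):
    print(f"Nothing is listening on {LANGSERVE_HOST}:{LANGSERVE_PORT}; start the LangServe notebook first.")
```

This only checks that the port accepts connections; it does not verify that the chain routes themselves respond correctly.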

In [None]:
# chatbot = gr.Chatbot(value = [[None, initial_msg]])
# demo = gr.ChatInterface(chat_gen, chatbot=chatbot).queue()

demo.queue()

try:
    demo.launch(debug=True, share=True, show_api=False)
    demo.close()
except Exception as e:
    demo.close()
    print(e)
    raise e

Running on local URL:  http://127.0.0.1:7861
Running on public URL: https://12883ff5ec98be9ee2.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)
