In [1]:
%pip install -q langchain langchain-community langchain-nvidia-ai-endpoints gradio rich
%pip install -q arxiv pymupdf faiss-cpu

Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


In [2]:
import os
from dotenv import load_dotenv

# Load environment variables from .env file
load_dotenv()

# Set NVIDIA API key from environment variable
# If not found in environment, this will be None
api_key = os.environ.get("NVIDIA_API_KEY")

# Print a message based on whether the key was found
if api_key:
    print("NVIDIA API key loaded successfully")
else:
    raise ValueError("NVIDIA API key not found in .env file. Please set NVIDIA_API_KEY in your environment variables.")


NVIDIA API key loaded successfully


In [3]:
from langchain_nvidia_ai_endpoints import ChatNVIDIA

# Get all available models
all_models = ChatNVIDIA.get_available_models()

# Display only the first 20 models
print(f"First 20 of {len(all_models)} available models:")
for i, model in enumerate(all_models[:20]):
    print(f"{i+1}. {model}")

First 20 of 88 available models:
1. id='institute-of-science-tokyo/llama-3.1-swallow-70b-instruct-v0.1' model_type='chat' client='ChatNVIDIA' endpoint=None aliases=None supports_tools=False supports_structured_output=True base_model=None
2. id='nvidia/llama-3.1-nemotron-51b-instruct' model_type='chat' client='ChatNVIDIA' endpoint=None aliases=None supports_tools=False supports_structured_output=False base_model=None
3. id='qwen/qwen2.5-coder-32b-instruct' model_type='chat' client='ChatNVIDIA' endpoint=None aliases=None supports_tools=False supports_structured_output=False base_model=None
4. id='google/gemma-7b' model_type='chat' client='ChatNVIDIA' endpoint=None aliases=['ai-gemma-7b', 'playground_gemma_7b', 'gemma_7b'] supports_tools=False supports_structured_output=False base_model=None
5. id='yentinglin/llama-3-taiwan-70b-instruct' model_type='chat' client='ChatNVIDIA' endpoint=None aliases=None supports_tools=False supports_structured_output=False base_model=None
6. id='microsoft/k

In [4]:
###################################### Document Summarization  ##################################

from langchain_core.runnables import RunnableLambda
from langchain_core.runnables.passthrough import RunnableAssign
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser, PydanticOutputParser
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.document_loaders import UnstructuredFileLoader, ArxivLoader

from pydantic import BaseModel, Field
from typing import List
from IPython.display import clear_output
from functools import partial
from rich.console import Console
from rich.style import Style
from rich.theme import Theme

# Define the Pydantic model for document summarization
class DocumentSummaryBase(BaseModel):
    """Base class for document summarization results."""
    running_summary: str = Field(default="", description="The running summary of the document. Update, don't override.")
    main_ideas: List[str] = Field(default_factory=list, description="Maximum 3 important points from the document.")
    loose_ends: List[str] = Field(default_factory=list, description="Maximum 3 open questions or areas needing clarification.")

# Set up console formatting with NVIDIA brand color
console = Console()
nvidia_green = Style(color="#76B900", bold=True)
pprint = partial(console.print, style=nvidia_green)

# Document Loading 
# For local files (commented out as an option)
# loader = UnstructuredFileLoader("your_document.pdf")
# docs = loader.load()

# Load document from Arxiv (GraphRAG paper)
loader = ArxivLoader(query="2404.16130")  # GraphRAG paper
docs = loader.load()

# Configure text splitter for document chunking
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1200, 
    chunk_overlap=100,
    separators=["\n\n", "\n", ".", ";", ",", " ", ""],
)

# Optional preprocessing (commented out)
# def preprocess_text(text):
#     # Remove special characters, normalize whitespace, etc.
#     return text.replace("...", ".")

# Split documents into manageable chunks
docs_split = text_splitter.split_documents(docs)
print(f"Document split into {len(docs_split)} chunks")

# Create summarization prompt
summary_prompt = ChatPromptTemplate.from_messages([
    ("system", """
    You are a technical document summarizer tasked with generating a running summary of a research document.
    You are given chunks of text from a document and an existing knowledge base containing:
    1. A running summary of the document so far
    2. Main ideas (max 3)
    3. Loose ends or questions (max 3)
    
    Your goal is to update this knowledge base with new information from the current document chunk.
    
    IMPORTANT INSTRUCTIONS:
    - DO NOT lose information when updating the knowledge base
    - Integrate new information with existing running_summary
    - Update main_ideas and loose_ends as needed, but keep them limited to 3 items each
    - Follow the format instructions exactly
    - If nothing new is present in the chunk, return the existing knowledge base unchanged
    
    {format_instructions}
    """),
    ("user", "{input}")
])




Document split into 83 chunks


In [5]:
def RExtract(pydantic_class, llm, prompt):
    '''
    Runnable Extraction module
    Returns a knowledge dictionary populated by slot-filling extraction
    '''
    parser = PydanticOutputParser(pydantic_object=pydantic_class)
    instruct_merge = RunnableAssign({'format_instructions' : lambda x: parser.get_format_instructions()})
    def preparse(string):
        if '{' not in string: string = '{' + string
        if '}' not in string: string = string + '}'
        string = (string
            .replace("\\_", "_")
            .replace("\n", " ")
            .replace(r"\]", "]")
            .replace(r"\[", "[")
        )
        # print(string)  ## Good for diagnostics
        return string
    return instruct_merge | prompt | llm | preparse | parser


In [6]:
latest_summary = ""

def RSummarizer(knowledge, llm, prompt, verbose=False):
    def summarize_docs(docs):
        parse_chain = RunnableAssign({'info_base' : RExtract(knowledge.__class__, llm, prompt)})
        state = {'info_base' : knowledge}

        global latest_summary  ## If your loop crashes, you can check out the latest_summary

        for i, doc in enumerate(docs):
            state['input'] = doc.page_content
            state = parse_chain.invoke(state)

            assert 'info_base' in state
            if verbose:
                print(f"Considered {i+1} documents")
                pprint(state['info_base'])
                latest_summary = state['info_base']
                clear_output(wait=True)

        return state['info_base']
    return RunnableLambda(summarize_docs)

instruct_model = ChatNVIDIA(model="mistralai/mistral-7b-instruct-v0.3").bind(max_tokens=4096)
instruct_llm = instruct_model | StrOutputParser()

## Take the first 10 document chunks and accumulate a DocumentSummaryBase
summarizer = RSummarizer(DocumentSummaryBase(), instruct_llm, summary_prompt, verbose=True)
summary = summarizer.invoke(docs_split[:10])

Considered 10 documents


In [None]:
###################################### Retrieval-Augmented Generation (RAG) ##################################

import json
import gradio as gr
from langchain_nvidia_ai_endpoints import ChatNVIDIA, NVIDIAEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.document_loaders import ArxivLoader, UnstructuredPDFLoader
from langchain_community.docstore.in_memory import InMemoryDocstore
from langchain.document_transformers import LongContextReorder
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnableLambda
from langchain.schema.runnable.passthrough import RunnableAssign
from faiss import IndexFlatL2
from operator import itemgetter
from functools import partial
from rich.console import Console
from rich.style import Style
from rich.theme import Theme

console = Console()
base_style = Style(color="#76B900", bold=True)
pprint = partial(console.print, style=base_style)

# NVIDIAEmbeddings.get_available_models()
embedder = NVIDIAEmbeddings(model="nvidia/nv-embed-v1", truncate="END")

# ChatNVIDIA.get_available_models()
instruct_llm = ChatNVIDIA(model="mistralai/mixtral-8x22b-instruct-v0.1")

## Utility Runnables/Methods
def RPrint(preface=""):
    """Simple passthrough "prints, then returns" chain"""
    def print_and_return(x, preface):
        if preface: print(preface, end="")
        pprint(x)
        return x
    return RunnableLambda(partial(print_and_return, preface=preface))

def docs2str(docs, title="Document"):
    """Useful utility for making chunks into context string. Optional, but useful"""
    out_str = ""
    for doc in docs:
        doc_name = getattr(doc, 'metadata', {}).get('Title', title)
        if doc_name:
            out_str += f"[Quote from {doc_name}] "
        out_str += getattr(doc, 'page_content', str(doc)) + "\n"
    return out_str

## Optional; Reorders longer documents to center of output text
long_reorder = RunnableLambda(LongContextReorder().transform_documents)

## Data Loading/Text Splitting
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000, chunk_overlap=100,
    separators=["\n\n", "\n", ".", ";", ",", " "],
)

print("Loading Documents")
docs = [
    ArxivLoader(query="1706.03762").load(),  ## Attention Is All You Need Paper
    ArxivLoader(query="1810.04805").load(),  ## BERT Paper
    ArxivLoader(query="2005.11401").load(),  ## RAG Paper
    ArxivLoader(query="2205.00445").load(),  ## MRKL Paper
    ArxivLoader(query="2310.06825").load(),  ## Mistral Paper
    ArxivLoader(query="2306.05685").load(),  ## LLM-as-a-Judge
    ## Some longer papers
    # ArxivLoader(query="2210.03629").load(),  ## ReAct Paper
    # ArxivLoader(query="2112.10752").load(),  ## Latent Stable Diffusion Paper
    # ArxivLoader(query="2103.00020").load(),  ## CLIP Paper
]

## Cut the paper short if references is included.
## This is a standard string in papers.
for doc in docs:
    content = json.dumps(doc[0].page_content)
    if "References" in content:
        doc[0].page_content = content[:content.index("References")]

## Split the documents and also filter out stubs (overly short chunks)
print("Chunking Documents")
docs_chunks = [text_splitter.split_documents(doc) for doc in docs]
docs_chunks = [[c for c in dchunks if len(c.page_content) > 200] for dchunks in docs_chunks]

## Make some custom Chunks to give big-picture details
doc_string = "Available Documents:"
doc_metadata = []
for chunks in docs_chunks:
    metadata = getattr(chunks[0], 'metadata', {})
    doc_string += "\n - " + metadata.get('Title')
    doc_metadata += [str(metadata)]

extra_chunks = [doc_string] + doc_metadata

## Printing out some summary information for reference
pprint(doc_string, '\n')
for i, chunks in enumerate(docs_chunks):
    print(f"Document {i}")
    print(f" - # Chunks: {len(chunks)}")
    print(f" - Metadata: ")
    pprint(chunks[0].metadata)
    print()

Loading Documents
Chunking Documents


Document 0
 - # Chunks: 35
 - Metadata: 



Document 1
 - # Chunks: 45
 - Metadata: 



Document 2
 - # Chunks: 46
 - Metadata: 



Document 3
 - # Chunks: 40
 - Metadata: 



Document 4
 - # Chunks: 21
 - Metadata: 



Document 5
 - # Chunks: 44
 - Metadata: 





In [None]:
%%time
## Constructing Your Document Vector Stores
print("Constructing Vector Stores")
vecstores = [FAISS.from_texts(extra_chunks, embedder)]
vecstores += [FAISS.from_documents(doc_chunks, embedder) for doc_chunks in docs_chunks]

embed_dims = len(embedder.embed_query("test"))
def default_FAISS():
    '''Useful utility for making an empty FAISS vectorstore'''
    return FAISS(
        embedding_function=embedder,
        index=IndexFlatL2(embed_dims),
        docstore=InMemoryDocstore(),
        index_to_docstore_id={},
        normalize_L2=False
    )

def aggregate_vstores(vectorstores):
    ## Initialize an empty FAISS Index and merge others into it
    ## We'll use default_faiss for simplicity, though it's tied to your embedder by reference
    agg_vstore = default_FAISS()
    for vstore in vectorstores:
        agg_vstore.merge_from(vstore)
    return agg_vstore

## Unintuitive optimization; merge_from seems to optimize constituent vector stores away
docstore = aggregate_vstores(vecstores)

print(f"Constructed aggregate docstore with {len(docstore.docstore._dict)} chunks")

Constructing Vector Stores
Constructed aggregate docstore with 238 chunks
CPU times: user 557 ms, sys: 30.6 ms, total: 588 ms
Wall time: 16.6 s


In [None]:
# NVIDIAEmbeddings.get_available_models()
embedder = NVIDIAEmbeddings(model="nvidia/nv-embed-v1", truncate="END")
# ChatNVIDIA.get_available_models()
instruct_llm = ChatNVIDIA(model="mistralai/mixtral-8x7b-instruct-v0.1")
# instruct_llm = ChatNVIDIA(model="meta/llama-3.1-8b-instruct")

convstore = default_FAISS()

def save_memory_and_get_output(d, vstore):
    """Accepts 'input'/'output' dictionary and saves to convstore"""
    vstore.add_texts([
        f"User previously responded with {d.get('input')}",
        f"Agent previously responded with {d.get('output')}"
    ])
    return d.get('output')

initial_msg = (
    "Hello! I am a document chat agent here to help the user!"
    f" I have access to the following documents: {doc_string}\n\nHow can I help you?"
)

chat_prompt = ChatPromptTemplate.from_messages([("system",
    "You are a document chatbot. Help the user as they ask questions about documents."
    " User messaged just asked: {input}\n\n"
    " From this, we have retrieved the following potentially-useful info: "
    " Conversation History Retrieval:\n{history}\n\n"
    " Document Retrieval:\n{context}\n\n"
    " (Answer only from retrieval. Only cite sources that are used. Make your response conversational.)"
), ('user', '{input}')])

## Implement the Stream Chain
stream_chain = chat_prompt| RPrint() | instruct_llm | StrOutputParser()

## Implement the RAG Chain
retrieval_chain = (
    {'input' : (lambda x: x)}
    | RunnableAssign({'history' : itemgetter('input') | convstore.as_retriever() | long_reorder | docs2str})
    | RunnableAssign({'context' : itemgetter('input') | docstore.as_retriever()  | long_reorder | docs2str})
    | RPrint()
)

def chat_gen(message, history=[], return_buffer=True):
    buffer = ""
    ## First perform the retrieval based on the input message
    retrieval = retrieval_chain.invoke(message)
    line_buffer = ""

    ## Then, stream the results of the stream_chain
    for token in stream_chain.stream(retrieval):
        buffer += token
        ## If you're using standard print, keep line from getting too long
        yield buffer if return_buffer else token

    ## Lastly, save the chat exchange to the conversation memory buffer
    save_memory_and_get_output({'input':  message, 'output': buffer}, convstore)


## Start of Agent Event Loop
test_question = "Tell me about RAG!"  ## <- modify as desired

## Before you launch your gradio interface, make sure your thing works
for response in chat_gen(test_question, return_buffer=False):
    print(response, end='')

RAG stands for Retrieval-Augmented Generation. It is a method that combines a retriever model, which identifies relevant documents from a large corpus, with a generator model, which produces answers to natural language processing (NLP) tasks. The retriever model is based on the DPR (Dense Passage Retrieval) model, and the generator model is based on a transformer model.

RAG can be employed in a variety of knowledge-intensive NLP tasks and is trained on a single Wikipedia dump, which is split into disjoint 100-word chunks to make a total of 21M documents.

There are two ways that RAG can be used, RAG-Sequence and RAG-Token. RAG-Sequence uses the same document to predict each target token, while RAG-Token can predict each target token based on a different document.

There are also two decoding procedures that can be used with RAG, Thorough Decoding and Fast Decoding. Thorough decoding runs additional forward passes for longer output sequences, while Fast Decoding makes an approximation 

In [None]:
chatbot = gr.Chatbot(value = [[None, initial_msg]])
demo = gr.ChatInterface(chat_gen, chatbot=chatbot).queue()

try:
    demo.launch(debug=True, share=True, show_api=False)
    demo.close()
except Exception as e:
    demo.close()
    print(e)
    raise e

  chatbot = gr.Chatbot(value = [[None, initial_msg]])


Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
* Running on public URL: https://66d7f153f0437a03c7.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


Keyboard interruption in main thread... closing server.
Killing tunnel 127.0.0.1:7860 <> https://66d7f153f0437a03c7.gradio.live
Closing server running on port: 7860
