You can also follow along on Google Colab!

<a target="_blank" href="https://colab.research.google.com/github/MadryLab/context-cite/blob/main/notebooks/rag_langchain_example.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

# Running `ContextCite` with a RAG using LangChain

In this notebook, we'll show a quick example of how to use `ContextCite` with a RAG chain built with the `langchain` library. **If running in Colab, be sure to change to a GPU runtime!** Thanks to Bagatur Askaryan for helpful feedback!

In [None]:
!pip install -qU context-cite langchain-community langchain-openai langchain-core langchain-text-splitters faiss-gpu

In [2]:
!wget https://raw.githubusercontent.com/MadryLab/context-cite/main/assets/solar_eclipse.txt

--2024-05-05 18:56:13--  https://raw.githubusercontent.com/MadryLab/context-cite/main/assets/solar_eclipse.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 24030 (23K) [text/plain]
Saving to: ‘solar_eclipse.txt’


2024-05-05 18:56:13 (158 MB/s) - ‘solar_eclipse.txt’ saved [24030/24030]



In [3]:
import os
import torch as ch
from langchain_community.document_loaders import TextLoader
from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings
from langchain_text_splitters import CharacterTextSplitter
from langchain_core.runnables import RunnablePassthrough, chain
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

from context_cite import ContextCiter

[nltk_data] Downloading package punkt to
[nltk_data]     /mnt/xfs/home/bencw/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [4]:
os.environ["OPENAI_API_KEY"] = "YOUR_OPENAI_API_KEY"  # Replace with your OpenAI key (must be a string; assigning None raises a TypeError)

Let's start with a langchain RAG chain that does not involve `ContextCite`.

# A simple RAG chain (without ContextCite)

First, we'll load a model and tokenizer, which we'll reuse later with ContextCite.

In [5]:
MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
PROMPT_TEMPLATE = "Context: {context}\n\nQuery: {query}"
GENERATE_KWARGS = {"max_new_tokens": 512, "do_sample": False}

model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=ch.float16).cuda()
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=model.device, **GENERATE_KWARGS)

Next, we'll create a RAG chain using a local `txt` file (a Wikipedia article about the April 8, 2024 solar eclipse) as our "database" to keep things simple.

In [6]:
loader = TextLoader("solar_eclipse.txt")
documents = loader.load()
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
texts = text_splitter.split_documents(documents)
embeddings = OpenAIEmbeddings()
db = FAISS.from_documents(texts, embeddings)
retriever = db.as_retriever()
messages = [{"role": "user", "content": PROMPT_TEMPLATE}]
chat_prompt_template = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
prompt = PromptTemplate.from_template(chat_prompt_template)
llm = HuggingFacePipeline(pipeline=pipe)

def format_docs(docs):
    return "\n\n".join([d.page_content for d in docs])

chain = (
    {"context": retriever | format_docs, "query": RunnablePassthrough()}
    | prompt
    | llm
    | StrOutputParser()
)

Created a chunk of size 1508, which is longer than the specified 1000
Created a chunk of size 1399, which is longer than the specified 1000
Created a chunk of size 1273, which is longer than the specified 1000
Created a chunk of size 1490, which is longer than the specified 1000
Created a chunk of size 1275, which is longer than the specified 1000
Created a chunk of size 1385, which is longer than the specified 1000
Created a chunk of size 1012, which is longer than the specified 1000
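
The warnings above arise because `CharacterTextSplitter` splits on a separator (blank lines, `"\n\n"`, by default) and then merges neighboring pieces up to `chunk_size`; it does not cut inside a piece, so a single paragraph longer than 1000 characters stays whole. The following is a simplified pure-Python sketch of that behavior, not LangChain's actual implementation. The helper name `split_on_separator` and the toy text are assumptions made for illustration.

```python
def split_on_separator(text, chunk_size=1000, separator="\n\n"):
    """Greedy paragraph merging, loosely mimicking CharacterTextSplitter.

    Simplified sketch: pieces are never cut in the middle, so a single
    piece longer than chunk_size becomes an oversized chunk.
    """
    pieces = [p for p in text.split(separator) if p.strip()]
    chunks, current = [], ""
    for piece in pieces:
        candidate = piece if not current else current + separator + piece
        if len(candidate) <= chunk_size or not current:
            # Either the piece still fits, or it starts a new chunk
            # (which we accept even if it is too long on its own).
            current = candidate
        else:
            chunks.append(current)
            current = piece
    if current:
        chunks.append(current)
    return chunks


# Hand-written toy text: two short paragraphs and one long one.
toy_text = "a" * 30 + "\n\n" + "b" * 30 + "\n\n" + "c" * 150
toy_chunks = split_on_separator(toy_text, chunk_size=100)
print([len(c) for c in toy_chunks])  # the last chunk exceeds chunk_size
```

If oversized chunks are a problem for your retriever, a splitter that falls back to finer separators (such as `RecursiveCharacterTextSplitter`) is a common alternative.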


In [7]:
query = "Where was the longest duration of totality for this solar eclipse?"
output = chain.invoke(query)
response = output.split("<|assistant|>\n")[-1]
print(response)

The longest duration of totality for this solar eclipse was 4 minutes and 28 seconds near the Mexican town of Nazas, Durango.
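
The `split(...)[-1]` idiom used above keeps only the text after the last assistant marker, and returns the whole string unchanged when the marker is absent. Here is a minimal sketch on a hand-written string. The sample text only imitates a chat-formatted prompt followed by a completion and is an assumption, and `extract_response` is a hypothetical helper name.

```python
ASSISTANT_MARKER = "<|assistant|>\n"


def extract_response(output, marker=ASSISTANT_MARKER):
    """Return the text after the last assistant marker (or all of it)."""
    return output.split(marker)[-1]


# Hand-written sample imitating a chat-formatted prompt plus completion.
sample_output = (
    "<|user|>\nContext: (retrieved text)\n\n"
    "Query: Where was totality longest?</s>\n"
    "<|assistant|>\nNear Nazas, Durango."
)
print(extract_response(sample_output))
print(extract_response("No marker here."))
```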


# Adding in `ContextCite`

Now, we'll add ContextCite by replacing the `prompt` and `llm` steps with a custom `Runnable` class that wraps `ContextCiter`. This class takes care of formatting the context and query, and of running generation with our LLM.

In [8]:
from langchain_core.runnables.base import Runnable, Input, Output

In [9]:
class ContextCiteRunnable(Runnable):
    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer

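    def invoke(self, context_and_query: Input, config=None, **kwargs) -> Output:
        # The chain passes a dict with "context" and "query" keys; `config` is unused here.
        context = context_and_query["context"]
        query = context_and_query["query"]
        # Use the model and tokenizer stored on the instance rather than the globals.
        cc = ContextCiter(self.model, self.tokenizer, context, query)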
        return cc.response, cc.get_attributions(as_dataframe=True, top_k=8)

In [10]:
cc_runnable = ContextCiteRunnable(model, tokenizer)
cc_chain = (
    {"context": retriever | format_docs, "query": RunnablePassthrough()}
    | cc_runnable
)
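
To make the data flow concrete: the dictionary at the head of the chain sends the same query to both branches, and the next step then receives a dict with `context` and `query` keys. The sketch below mimics this with plain Python functions. `fake_retrieve`, `fan_out`, and `fake_cite` are hypothetical stand-ins, not LangChain or ContextCite APIs, and the placeholder scores carry no meaning.

```python
def fake_retrieve(query):
    """Stand-in for `retriever | format_docs`: returns joined chunk text."""
    chunks = ["First retrieved chunk.", "Second retrieved chunk."]
    return "\n\n".join(chunks)


def fan_out(query):
    """Mimic the dict step: each key gets its own function of the input."""
    return {"context": fake_retrieve(query), "query": query}


def fake_cite(context_and_query):
    """Stand-in for ContextCiteRunnable.invoke: returns a 2-tuple."""
    context = context_and_query["context"]
    query = context_and_query["query"]
    response = f"(answer to {query!r})"
    # One placeholder score per chunk; real scores come from ContextCite,
    # which partitions the context into sources on its own.
    attributions = [(0.0, source) for source in context.split("\n\n")]
    return response, attributions


demo_response, demo_attributions = fake_cite(fan_out("Where was totality longest?"))
print(demo_response)
print(demo_attributions)
```

Because the last step returns a tuple, the chain's result can be unpacked directly into a response and its attributions.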

In [11]:
response, attributions = cc_chain.invoke(query)

Attributed: The longest duration of totality for this solar eclipse was 4 minutes and 28 seconds near the Mexican town of Nazas, Durango.</s>



In [12]:
attributions

| | Score | Source |
|---|---|---|
| 0 | 45.271 | With a magnitude of 1.0566, the eclipse's longest duration of totality was 4 minutes and 28 seconds near the Mexican town of Nazas, Durango. |
| 1 | 2.482 | This gave the eclipse a wider path of totality and more maximum time in totality (4 min 28 s) compared to the total eclipse in 2017 (2 min 40 s), which had a magnitude of 1.0306. |
| 2 | 1.328 | Later, the total solar eclipse was visible from North America, starting from the west coast of Mexico then ascending in a northeasterly direction through Mexico, the United States, and Canada, before ending in the Atlantic Ocean about 700 kilometers southwest of Ireland. |
| 3 | 0.827 | TOP: Solar prominences as seen from Third Connecticut Lake, New Hampshire - MIDDLE: Solar activity 08 April 2024 imaged by NASA Solar Dynamics Observatory AIA 304 telescope. |
| 4 | 0.638 | A solar eclipse occurs when the Moon passes between Earth and the Sun, thereby obscuring the Sun. |
| 5 | 0.621 | - BOTTOM: National Solar Observatory GONG telescope movie of solar activity in H-Alpha for the day of the April 8, 2024 eclipse, showing how prominences hardly changed during the eclipse. |
| 6 | 0.375 | Animation of the eclipse path (including the path of totality) |
| 7 | 0.0 | [2][3] |
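
One way to use these scores is to keep only the sources whose score clears a threshold. The sketch below does this on a few rows of the attribution output above, copied by hand into a plain list with the source text abbreviated. The threshold of 1.0 and the helper name `top_sources` are arbitrary assumptions, not part of ContextCite.

```python
# (score, abbreviated source) pairs copied from the attribution output above.
rows = [
    (45.271, "With a magnitude of 1.0566, the eclipse's longest duration ..."),
    (2.482, "This gave the eclipse a wider path of totality ..."),
    (1.328, "Later, the total solar eclipse was visible from North America ..."),
    (0.827, "TOP: Solar prominences as seen from Third Connecticut Lake ..."),
    (0.0, "[2][3]"),
]


def top_sources(rows, min_score=1.0):
    """Keep rows scoring at least min_score, highest score first."""
    kept = [row for row in rows if row[0] >= min_score]
    return sorted(kept, key=lambda row: row[0], reverse=True)


for score, source in top_sources(rows):
    print(f"{score:7.3f}  {source}")
```

With the real DataFrame returned by `get_attributions(as_dataframe=True)`, the same idea would be a filter on the `Score` column.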
