# Retrieving Original Documents via Summaries with Weaviate and LangChain

This notebook contains the code for my blog post on [dev.to](https://dev.to/oleh-halytskyi/retrieving-original-documents-via-summaries-with-weaviate-and-langchain-2a68). It demonstrates how to use [Weaviate's cross-references](https://weaviate.io/developers/weaviate/manage-data/cross-references) within the LangChain framework to retrieve original documents by querying their summaries. The approach is inspired by the technique described in [Optimizing RAG Context: Chunking and Summarization for Technical Docs](https://dev.to/oleh-halytskyi/optimizing-rag-context-chunking-and-summarization-for-technical-docs-3pel), where Chroma was used for [querying summaries while retrieving original documents](https://dev.to/oleh-halytskyi/optimizing-rag-context-chunking-and-summarization-for-technical-docs-3pel#querying-summaries-but-retrieving-original-text). Here, Weaviate's cross-references are used instead.

## Preparing the Environment

First, set up a [Conda](https://conda.io/projects/conda/en/latest/user-guide/install/macos.html) environment to manage dependencies and keep the project isolated.

```bash
# Create a new environment called 'rag-env'
conda create -n rag-env python=3.12

# Activate the environment
conda activate rag-env

# Install necessary packages
pip install weaviate-client==4.9.0 \
    langchain==0.3.5 \
    langchain-core==0.3.13 \
    langchain-ollama==0.2.0
```

Additionally, the following are required:

- [Weaviate](https://weaviate.io), which can be installed locally using [Docker](https://weaviate.io/developers/weaviate/installation/docker-compose#starter-docker-compose-file) or a [Kubernetes](https://weaviate.io/developers/weaviate/installation/kubernetes) cluster. It can also be accessed via the [Weaviate Cloud](https://weaviate.io/developers/wcs).
- [Ollama](https://ollama.com), specifically the [mxbai-embed-large](https://ollama.com/library/mxbai-embed-large) model for creating embedding vectors and the [llama3.1](https://ollama.com/library/llama3.1) model for the simple RAG example.

## Initializing the Weaviate Client

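Import the necessary libraries and initialize the Weaviate client with API-key authentication. `connect_to_local` targets a Weaviate instance running on the local machine. For Weaviate Cloud, the client provides `weaviate.connect_to_weaviate_cloud` instead.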

In [1]:
import getpass
import weaviate
from weaviate.classes.init import Auth

# Prompt for the Weaviate API key
WEAVIATE_API_KEY = getpass.getpass()

# Initialize the Weaviate client with authentication
weaviate_client = weaviate.connect_to_local(
    auth_credentials=Auth.api_key(WEAVIATE_API_KEY)
)

# Check if the client is ready
print("Client is Ready?", weaviate_client.is_ready())

Client is Ready? True


## Importing Data with Cross-References

Load the original and summarized documents, create Weaviate collections, and insert the data with cross-references.

**Note:** I took the [chunked_docs.json](./files/chunked_docs.json) and [summarized_docs.json](./files/summarized_docs.json) files from the [Optimizing RAG Context: Chunking and Summarization for Technical Docs](https://dev.to/oleh-halytskyi/optimizing-rag-context-chunking-and-summarization-for-technical-docs-3pel) blog post. They were created by adding the following code to its [Summarization Based on Headers and Chunk Text](https://dev.to/oleh-halytskyi/optimizing-rag-context-chunking-and-summarization-for-technical-docs-3pel#summarization-based-on-headers-and-chunk-text) section:

```python
# Save the chunked and summarized documents to JSON files
import json
chunked_docs_json = [{'page_content': doc.page_content, 'metadata': doc.metadata} for doc in chunked_docs]
with open('files/generated/chunked_docs.json', 'w') as f:
    json.dump(chunked_docs_json, f, indent=4)

summarized_docs_json = [{'page_content': doc.page_content, 'metadata': doc.metadata} for doc in summarized_docs]
with open('files/generated/summarized_docs.json', 'w') as f:
    json.dump(summarized_docs_json, f, indent=4)
```

In the [summarized_docs.json](./files/summarized_docs.json) file, `metadata.id` was renamed to `metadata.doc_id` to avoid a conflict with Weaviate's reserved `id` field.
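
The note above does not say how the rename was done. A minimal sketch of one way to do it programmatically follows. It assumes the summaries have the `{"page_content": ..., "metadata": {...}}` shape written by the saving code. The helper name `rename_metadata_key` and the sample IDs are hypothetical.

```python
import json


def rename_metadata_key(docs, old="id", new="doc_id"):
    """Return copies of docs with metadata[old] renamed to metadata[new]."""
    renamed = []
    for doc in docs:
        metadata = dict(doc.get("metadata", {}))  # shallow copy so the input stays untouched
        if old in metadata:
            metadata[new] = metadata.pop(old)
        renamed.append({**doc, "metadata": metadata})
    return renamed


# Hypothetical sample in the same shape as summarized_docs.json
summaries = [
    {"page_content": "Summary of the range() section.", "metadata": {"id": "a1b2"}},
    {"page_content": "Summary of break/continue.", "metadata": {"id": "c3d4"}},
]
fixed = rename_metadata_key(summaries)
print(json.dumps(fixed[0]["metadata"]))  # {"doc_id": "a1b2"}
```

To apply it to the real file, the function would sit between `json.load` and `json.dump` on `files/summarized_docs.json`.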

In [2]:
import json
from langchain_ollama import OllamaEmbeddings

# Load the chunked and summarized documents
with open("files/chunked_docs.json", "r") as f:
    chunked_docs = json.load(f)

with open("files/summarized_docs.json", "r") as f:
    summarized_docs = json.load(f)

# Define the collection names
collections = {
    "original": "OriginalDocuments",
    "summary": "SummarizedDocuments",
}

# Delete collections if they already exist
for collection in collections.values():
    if collection in weaviate_client.collections.list_all(simple=True):
        weaviate_client.collections.delete(collection)

# Create the collections
original_collection_db = weaviate_client.collections.create(collections["original"])
summary_collection_db = weaviate_client.collections.create(collections["summary"])

# Initialize the Ollama embedding model
ollama_emb = OllamaEmbeddings(model="mxbai-embed-large")

# Insert the documents into the collections
for summarized_doc in summarized_docs:
    summarized_doc_id = summarized_doc["metadata"]["doc_id"]
    original_doc = next((doc for doc in chunked_docs if doc.get("metadata", {}).get("summary_id") == summarized_doc_id), None)

    if original_doc:
        original_uuid = original_collection_db.data.insert(
            {
                "page_content": original_doc["page_content"],
            },
            vector=ollama_emb.embed_query(original_doc["page_content"]),
        )
        summary_collection_db.data.insert(
            {
                "page_content": summarized_doc["page_content"],
            },
            references={"originalDocument": original_uuid},
            vector=ollama_emb.embed_query(summarized_doc["page_content"]),
        )

# Verify the number of documents in the collections
original_count = len(original_collection_db)
summary_count = len(summary_collection_db)
print(f"Number of documents in the original collection: {original_count}")
print(f"Number of documents in the summary collection: {summary_count}")

Number of documents in the original collection: 34
Number of documents in the summary collection: 34
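
The insertion loop pairs each summary with its original chunk: it matches `metadata.summary_id` on the chunk against `metadata.doc_id` on the summary. To do so, it calls `next(...)` over the whole `chunked_docs` list for every summary, and it silently skips a summary with no match. With 34 documents this is harmless, and the matching counts above show nothing was skipped. The sketch below assumes the same JSON shapes and builds a dictionary index once, reporting unmatched summaries explicitly. The function name and sample data are hypothetical.

```python
def pair_summaries_with_originals(summarized_docs, chunked_docs):
    """Pair each summary with the chunk whose metadata.summary_id equals the summary's doc_id."""
    # Build the lookup once instead of scanning chunked_docs for every summary
    by_summary_id = {}
    for doc in chunked_docs:
        summary_id = doc.get("metadata", {}).get("summary_id")
        if summary_id is not None:
            by_summary_id.setdefault(summary_id, doc)  # keep the first match, like next(...)

    pairs, unmatched = [], []
    for summary in summarized_docs:
        doc_id = summary["metadata"]["doc_id"]
        original = by_summary_id.get(doc_id)
        if original is None:
            unmatched.append(doc_id)  # the notebook's loop would skip these without notice
        else:
            pairs.append((summary, original))
    return pairs, unmatched


# Hypothetical sample data in the notebook's JSON shape
sample_chunks = [
    {"page_content": "for n in range(2, 10): ...", "metadata": {"summary_id": "c3d4"}},
    {"page_content": "The range() function ...", "metadata": {"summary_id": "a1b2"}},
]
sample_summaries = [
    {"page_content": "Summary of range().", "metadata": {"doc_id": "a1b2"}},
    {"page_content": "Summary of break/continue.", "metadata": {"doc_id": "c3d4"}},
    {"page_content": "Orphan summary.", "metadata": {"doc_id": "zzzz"}},
]
pairs, unmatched = pair_summaries_with_originals(sample_summaries, sample_chunks)
print(len(pairs), unmatched)  # 2 ['zzzz']
```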


## Querying with Cross-References

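Define a function that runs a hybrid (keyword + vector) search over the summaries and follows each hit's `originalDocument` cross-reference to fetch the original document. Only hits whose score reaches `score_threshold` are kept.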
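
Weaviate's hybrid search fuses a BM25 keyword score and a vector-similarity score into one score, and `score_threshold` is compared against that fused value. The sketch below is a rough illustration, not Weaviate's implementation. It assumes relative score fusion, the default fusion method in recent Weaviate versions: each result set is min-max normalized, then weighted by `alpha` (vector side) and `1 - alpha` (keyword side). The `alpha=0.75` value, the handling of equal scores, and all raw scores are assumptions chosen for illustration.

```python
def min_max_normalize(scores):
    """Rescale raw scores so the best hit becomes 1.0 and the worst 0.0."""
    if not scores:
        return {}
    lo, hi = min(scores.values()), max(scores.values())
    if hi == lo:
        # Assumption for this sketch: a set of identical scores counts as fully relevant
        return {key: 1.0 for key in scores}
    return {key: (value - lo) / (hi - lo) for key, value in scores.items()}


def relative_score_fusion(keyword_scores, vector_scores, alpha=0.75):
    """Combine both result sets; alpha weights the vector side, 1 - alpha the keyword side."""
    keyword = min_max_normalize(keyword_scores)
    vector = min_max_normalize(vector_scores)
    fused = {
        doc_id: alpha * vector.get(doc_id, 0.0) + (1 - alpha) * keyword.get(doc_id, 0.0)
        for doc_id in set(keyword) | set(vector)  # a hit missing from one search gets 0 there
    }
    return sorted(fused.items(), key=lambda item: item[1], reverse=True)


# Hypothetical raw scores for three summaries per search
keyword_scores = {"range_summary": 2.4, "loops_summary": 1.1, "functions_summary": 0.3}  # BM25
vector_scores = {"range_summary": 0.82, "loops_summary": 0.79, "classes_summary": 0.61}  # cosine

fused = relative_score_fusion(keyword_scores, vector_scores)
kept = [doc_id for doc_id, score in fused if score >= 0.8]
print(kept)  # ['range_summary']
```

Under these assumptions, a hit that ranks first in both searches always scores 1.0. A threshold such as 0.8 therefore mostly prunes results relative to the top hit rather than enforcing an absolute relevance level.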

In [3]:
from weaviate.classes.query import QueryReference, MetadataQuery

# Define a function to retrieve documents
def retrieve_documents(query, vector, limit=2, score_threshold=0.8):
    response = summary_collection_db.query.hybrid(
        query,
        vector=vector,
        limit=limit,
        return_references=QueryReference(link_on="originalDocument"),
        return_metadata=MetadataQuery(score=True),
    )

    summary_docs = []
    original_docs = []
    for o in response.objects:
        if o.metadata.score is not None and o.metadata.score >= score_threshold:
            summary_doc = {"page_content": o.properties["page_content"]}
            summary_docs.append(summary_doc)
            for ref_obj in o.references["originalDocument"].objects:
                original_doc = {"page_content": ref_obj.properties["page_content"]}
                original_docs.append(original_doc)

    return summary_docs, original_docs

# Define a query
query = "I want to write a Python script that prints numbers from 1 to 30."
vector = ollama_emb.embed_query(query)
summary_docs, original_docs = retrieve_documents(query, vector)

# Print the summarized and original documents
print("Summarized Documents:")
for i, doc in enumerate(summary_docs, start=1):
    print(f"Summarized Document #{i}")
    print("--------------------")
    print(doc["page_content"])
    print("--------------------")
    print()

print("Original Documents:")
for i, doc in enumerate(original_docs, start=1):
    print(f"Original Document #{i}")
    print("--------------------")
    print(doc["page_content"])
    print("--------------------")
    print()

Summarized Documents:
Summarized Document #1
--------------------
The `break` statement exits the innermost enclosing for or while loop, stopping execution of the loop and continuing with the next statement. This is demonstrated by a nested for loop that prints factors of numbers from 2 to 9, where the break statement stops the loop when a factor is found. The `continue` statement skips the rest of the current iteration in a loop and moves on to the next one, as shown by a for loop that iterates over numbers from 2 to 9, printing even numbers and skipping odd ones.
--------------------

Summarized Document #2
--------------------
The built-in `range()` function generates arithmetic progressions that can be used for iteration over a sequence of numbers. It takes three parameters: start point, end point, and step (default is 1), and returns an iterator that produces the specified range of values. The end point is never part of the generated sequence. To iterate over the indices of a sequ

## Creating a Custom Retriever

Implement a [custom retriever](https://python.langchain.com/docs/how_to/custom_retriever) that integrates with LangChain and leverages Weaviate's cross-references to fetch the original documents based on summary queries.
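
Stripped of Weaviate and LangChain, the retriever follows one pattern: search the summaries, then return the originals they reference. The sketch below reproduces that pattern in memory. It uses plain cosine similarity in place of Weaviate's hybrid score and hand-made 3-dimensional vectors (hypothetical) instead of mxbai-embed-large embeddings. `Summary`, `cosine`, and `retrieve_originals` are illustrative names, not part of either library, and the `original_id` field merely stands in for the `originalDocument` cross-reference.

```python
import math
from dataclasses import dataclass


@dataclass
class Summary:
    text: str
    vector: list[float]
    original_id: str  # stands in for the "originalDocument" cross-reference


def cosine(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def retrieve_originals(query_vector, summaries, originals, k=2, score_threshold=0.8):
    """Search the summaries, then return the original texts they reference."""
    scored = sorted(
        ((cosine(query_vector, s.vector), s) for s in summaries),
        key=lambda pair: pair[0],
        reverse=True,
    )[:k]
    # Follow the reference of every top-k summary that clears the threshold
    return [originals[s.original_id] for score, s in scored if score >= score_threshold]


# Hypothetical data: toy vectors instead of real embeddings
originals = {
    "orig-1": "Full text of the range() tutorial section...",
    "orig-2": "Full text of the break/continue section...",
    "orig-3": "Full text of the classes section...",
}
summaries = [
    Summary("range() generates arithmetic progressions.", [0.9, 0.1, 0.0], "orig-1"),
    Summary("break exits loops; continue skips iterations.", [0.7, 0.6, 0.1], "orig-2"),
    Summary("Classes bundle data and functionality.", [0.0, 0.2, 0.9], "orig-3"),
]
query_vector = [1.0, 0.2, 0.0]
print(retrieve_originals(query_vector, summaries, originals))
```

The summaries are what get ranked, but the caller only ever sees the originals. This is the same division of labor the Weaviate-backed retriever achieves with cross-references.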

In [4]:
from typing import List, Any
from langchain_core.retrievers import BaseRetriever
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from weaviate.classes.query import QueryReference, MetadataQuery

class VectorDBRetrieverCrossReferences(BaseRetriever):
    """A custom retriever that retrieves documents from a Weaviate vector database."""
    summary_collection_db: Any
    ollama_emb: Any
    k: int = 2
    score_threshold: float = 0.8
    return_source_documents: bool = False

    def _get_relevant_documents(self, query: str, *, run_manager: CallbackManagerForRetrieverRun) -> List[Document]:
        """Sync implementation for retriever."""
        vector = self.ollama_emb.embed_query(query)
        response = self.summary_collection_db.query.hybrid(
            query,
            vector=vector,
            limit=self.k,
            return_references=QueryReference(link_on="originalDocument"),
            return_metadata=MetadataQuery(score=True),
        )

        original_docs = []
        for o in response.objects:
            if o.metadata.score is not None and o.metadata.score >= self.score_threshold:
                for ref_obj in o.references["originalDocument"].objects:
                    doc_content = ref_obj.properties["page_content"]
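                    # Note: the import step stores only "page_content", so "source" stays None unless such a property is added
                    metadata = {"source": ref_obj.properties.get("source")} if self.return_source_documents else {}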
                    original_docs.append(Document(page_content=doc_content, metadata=metadata))

        return original_docs

# Initialize the custom retriever
retriever = VectorDBRetrieverCrossReferences(
    summary_collection_db=summary_collection_db,
    ollama_emb=ollama_emb
)

# Retrieve documents using the custom retriever
query = "I want to write a Python script that prints numbers from 1 to 30."
documents = retriever.invoke(query)

# Print the retrieved documents
for i, doc in enumerate(documents, start=1):
    print(f"Document #{i}")
    print("--------------------")
    print(doc.page_content)
    print("--------------------")
    print()

Document #1
--------------------
The [`break`](../reference/simple_stmts.html#break) statement breaks out of the innermost enclosing
[`for`](../reference/compound_stmts.html#for) or [`while`](../reference/compound_stmts.html#while) loop:

```
>>> for n in range(2, 10):
...     for x in range(2, n):
...         if n % x == 0:
...             print(f"{n} equals {x} * {n//x}")
...             break
...
4 equals 2 * 2
6 equals 2 * 3
8 equals 2 * 4
9 equals 3 * 3
```

The [`continue`](../reference/simple_stmts.html#continue) statement continues with the next
iteration of the loop:

```
>>> for num in range(2, 10):
...     if num % 2 == 0:
...         print(f"Found an even number {num}")
...         continue
...     print(f"Found an odd number {num}")
...
Found an even number 2
Found an odd number 3
Found an even number 4
Found an odd number 5
Found an even number 6
Found an odd number 7
Found an even number 8
Found an odd number 9
```
--------------------

Document #2
--------------------
I

## Example of Simple Retrieval-Augmented Generation (RAG)

This section demonstrates how to build a simple RAG pipeline using the custom retriever and a language model.
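
`create_stuff_documents_chain` "stuffs" every retrieved document into the `{context}` variable of the prompt. To my understanding, its defaults render each document as its `page_content` and join the documents with a blank line (`"\n\n"`). The sketch below mimics only that string-building step in plain Python, with a shortened prompt and a hypothetical `stuff_documents` helper, to show what the model receives. Because the retriever returns full original documents rather than summaries, the prompt grows with `k` and with chunk size.

```python
system_prompt = (
    "You are an assistant for answering questions. "
    "Use only the exact information provided in the context.\n\n"
    "Context: \n"
    "{context}"
)


def stuff_documents(page_contents, separator="\n\n"):
    """Join retrieved chunks into one string, the way a 'stuff' chain fills {context}."""
    return separator.join(page_contents)


# Hypothetical retrieved chunks
retrieved = [
    "The break statement breaks out of the innermost enclosing for or while loop.",
    "The built-in range() function generates arithmetic progressions.",
]
rendered = system_prompt.format(context=stuff_documents(retrieved))
print(rendered)
```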

In [5]:
from langchain_ollama.chat_models import ChatOllama
from langchain.chains import create_retrieval_chain
from langchain_core.prompts import ChatPromptTemplate
from langchain.chains.combine_documents import create_stuff_documents_chain

# Initialize the ChatOllama model
llm = ChatOllama(model="llama3.1", temperature=0, num_ctx=16384)

# Define the system prompt template
system_prompt = (
    "You are an assistant for answering questions. "
    "Use only the exact information provided in the context, do not include external knowledge or guesses. "
    "If the answer cannot be inferred from the context, reply: 'I don't know based on the provided context.' "
    "Do not provide answers that are not based on the context, including code examples or references to other libraries. "
    "Format your entire response in valid Markdown, including code snippets and links. "
    "Always adhere to these rules strictly.\n\n"
    "Context: \n"
    "{context}"
)

# Define the chat prompt
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system_prompt),
        ("human", "{input}"),
    ]
)

# Create the question-answer chain
question_answer_chain = create_stuff_documents_chain(llm, prompt)

# Create the retrieval-augmented generation (RAG) chain
rag_chain = create_retrieval_chain(retriever, question_answer_chain)

# Query #1
query = "I want to write a Python script that prints numbers from 1 to 30."
response = rag_chain.invoke({"input": query})

# Print the response
print("Example #1")
print("--------------------")
print(f"Query: {query}")
print(f"Answer: {response['answer']}")
print("--------------------")
print("\n")

# Query #2 (check that only the provided context is used)
query = "I want to write a Go script that prints numbers from 1 to 30."
response = rag_chain.invoke({"input": query})

# Print the response
print("Example #2")
print("--------------------")
print(f"Query: {query}")
print(f"Answer: {response['answer']}")
print("--------------------")

Example #1
--------------------
Query: I want to write a Python script that prints numbers from 1 to 30.
Answer: You can use the `range()` function in Python to generate a sequence of numbers and print them.

Here's how you can do it:

```
for i in range(1, 31):
    print(i)
```

This will print numbers from 1 to 30. The `range()` function generates numbers starting from 0 by default, so we start at 1 and end at 30 (which is exclusive).
--------------------


Example #2
--------------------
Query: I want to write a Go script that prints numbers from 1 to 30.
Answer: I don't know based on the provided context. The given text is about Python programming and does not provide any information about writing a Go script. If you need help with a specific task, I'll be happy to assist you in another way.
--------------------
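
One detail in the first answer is inaccurate. The model says the loop ends "at 30 (which is exclusive)", but the stop argument of `range()` is the excluded value. `range(1, 31)` yields 1 through 30, and 31 is the value left out. The generated code itself is correct, as a quick check shows:

```python
numbers = list(range(1, 31))  # start=1 is included, stop=31 is excluded
print(numbers[0], numbers[-1], len(numbers))  # 1 30 30
```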
