# Semantic Cache to Improve RAG Systems

A common production pitfall in RAG systems is handling many similar requests as if each one were new. A call to a RAG system can be slow and expensive because it involves a database search, possibly agentic or reflective steps, and LLM queries. In other production applications, a cache layer is commonly placed between the database and the business modules to store frequently accessed data and improve performance. RAG systems also receive questions that are asked frequently. By placing a cache layer between the user input and the database, we can reduce the number of database calls and improve performance.
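
To make the idea concrete, here is a minimal sketch of a conventional cache layer keyed on the exact question string. `slow_rag_pipeline` and `ExactMatchCache` are hypothetical names used only for illustration and are not part of this notebook. The sketch shows why exact matching is not enough for natural-language questions: a paraphrase is treated as a brand-new key.

```python
# Hypothetical stand-in for an expensive retrieval + generation pipeline.
def slow_rag_pipeline(question: str) -> str:
    return f"answer to: {question}"


class ExactMatchCache:
    """Cache layer that only recognizes identical question strings."""

    def __init__(self, backend):
        self.backend = backend
        self.store = {}
        self.hits = 0
        self.misses = 0

    def ask(self, question: str) -> str:
        if question in self.store:
            self.hits += 1
            return self.store[question]
        # Miss: call the expensive backend and remember the answer.
        self.misses += 1
        answer = self.backend(question)
        self.store[question] = answer
        return answer


exact_cache = ExactMatchCache(slow_rag_pipeline)
exact_cache.ask("How do vaccines work?")     # miss: first time we see it
exact_cache.ask("How do vaccines work?")     # hit: identical string
exact_cache.ask("How does a vaccine work?")  # miss: a paraphrase is a new key
print(exact_cache.hits, exact_cache.misses)
```

A semantic cache replaces the dictionary lookup with a nearest-neighbour search over question embeddings, so that the paraphrase in the last call can also be served from the cache.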

## Requirements

In this notebook we will be using ChromaDB and HuggingFace Models and Transformers to create our RAG System.

In [1]:
!pip install -q transformers==4.38.1
!pip install -q accelerate==0.27.2
!pip install -q sentence-transformers==2.5.1
!pip install -q xformers==0.0.24
!pip install -q chromadb==0.4.24
!pip install -q datasets==2.17.1

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
fastai 2.7.13 requires torch<2.2,>=1.10, but you have torch 2.2.0 which is incompatible.
torchaudio 2.1.2+cu121 requires torch==2.1.2+cu121, but you have torch 2.2.0 which is incompatible.
torchvision 0.16.2+cu121 requires torch==2.1.2+cu121, but you have torch 2.2.0 which is incompatible.


## Imports

In [3]:
import numpy as np
import pandas as pd

In [5]:
from getpass import getpass
if 'hf_key' not in locals():
  hf_key = getpass("Your Hugging Face API Key: ")
!huggingface-cli login --token $hf_key


Token will not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.
Token is valid (permission: read).
Your token has been saved to C:\Users\tuvshno\.cache\huggingface\token
Login successful


## Load Dataset

We are going to use a medical dataset from the paper "A Question-Entailment Approach to Question Answering".

The data consists of questions, answers, and question types. 

In [6]:
from datasets import load_dataset

data = load_dataset("keivalya/MedQuad-MedicalQnADataset", split="train")

In [7]:
data

Dataset({
    features: ['qtype', 'Question', 'Answer'],
    num_rows: 16407
})

In [8]:
data = data.to_pandas()
data.head()

| | qtype | Question | Answer |
|---|---|---|---|
| 0 | susceptibility | Who is at risk for Lymphocytic Choriomeningiti... | LCMV infections can occur after exposure to fr... |
| 1 | symptoms | What are the symptoms of Lymphocytic Choriomen... | LCMV is most commonly recognized as causing ne... |
| 2 | susceptibility | Who is at risk for Lymphocytic Choriomeningiti... | Individuals of all ages who come into contact ... |
| 3 | exams and tests | How to diagnose Lymphocytic Choriomeningitis (... | During the first phase of the disease, the mos... |
| 4 | treatment | What are the treatments for Lymphocytic Chorio... | Aseptic meningitis, encephalitis, or meningoen... |


ChromaDB requires every record to have a unique identifier, so we create one from the row index.

In [9]:
data["id"] = data.index
data.head(10)

| | qtype | Question | Answer | id |
|---|---|---|---|---|
| 0 | susceptibility | Who is at risk for Lymphocytic Choriomeningiti... | LCMV infections can occur after exposure to fr... | 0 |
| 1 | symptoms | What are the symptoms of Lymphocytic Choriomen... | LCMV is most commonly recognized as causing ne... | 1 |
| 2 | susceptibility | Who is at risk for Lymphocytic Choriomeningiti... | Individuals of all ages who come into contact ... | 2 |
| 3 | exams and tests | How to diagnose Lymphocytic Choriomeningitis (... | During the first phase of the disease, the mos... | 3 |
| 4 | treatment | What are the treatments for Lymphocytic Chorio... | Aseptic meningitis, encephalitis, or meningoen... | 4 |
| 5 | prevention | How to prevent Lymphocytic Choriomeningitis (L... | LCMV infection can be prevented by avoiding co... | 5 |
| 6 | information | What is (are) Parasites - Cysticercosis ? | Cysticercosis is an infection caused by the la... | 6 |
| 7 | susceptibility | Who is at risk for Parasites - Cysticercosis? ? | Cysticercosis is an infection caused by the la... | 7 |
| 8 | exams and tests | How to diagnose Parasites - Cysticercosis ? | If you think that you may have cysticercosis, ... | 8 |
| 9 | treatment | What are the treatments for Parasites - Cystic... | Some people with cysticercosis do not need to ... | 9 |


In [10]:
data.shape


(16407, 4)

## Import ChromaDB

ChromaDB is a popular open-source vector database. In RAG systems, a vector database stores documents as embedding vectors. To find context relevant to a user query, we embed the query and search the database for the nearest document vectors.

ChromaDB can also persist the vectors to disk, so they don't have to be recomputed every time the notebook runs.

In [27]:
import chromadb

In [28]:
chroma_client = chromadb.PersistentClient(path="chromadb")

## Filtering and Querying the ChromaDB Database

ChromaDB uses a system called `collections` for each dataset. 

We want to create the collection and delete any old ones if they exist.

In [30]:
collection_name = "medical_data"
for collection in chroma_client.list_collections():
    if collection_name == collection.name:
        chroma_client.delete_collection(name=collection_name)
collection = chroma_client.create_collection(name=collection_name)

Now we want to add our data to the collection, where ChromaDB computes the embeddings with its default embedding function. Because the dataset is large, we add it in batches. Each document we add has an `Answer` as its text, a `qtype` as metadata, and an `id`.

In [31]:
batch_size = 5000

for i in range(0, data.shape[0], batch_size):
    # Slices exclude the stop index, so i:i+batch_size covers exactly one batch.
    # iloc also clamps the final, shorter batch to the end of the DataFrame.
    subset_data = data.iloc[i:i + batch_size]
    collection.add(
        documents=subset_data["Answer"].tolist(),
        metadatas=[{"qtype": qtype} for qtype in subset_data["qtype"].tolist()],
        ids=[f"id{row_id}" for row_id in subset_data["id"].tolist()],
    )
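
Batch boundaries are easy to get wrong by one row, because Python slices already exclude the stop index. As a quick sanity check, the following sketch computes the `(start, stop)` pairs for a dataset of this size and confirms that they cover every row exactly once. `batch_ranges` is a hypothetical helper written for this illustration, not something the notebook defines.

```python
def batch_ranges(n_rows: int, batch_size: int):
    """Yield (start, stop) pairs that cover range(n_rows) with no gaps."""
    for start in range(0, n_rows, batch_size):
        # min() clamps the last batch to the end of the data.
        yield start, min(start + batch_size, n_rows)


ranges = list(batch_ranges(16407, 5000))
print(ranges)

# Each slice data.iloc[start:stop] holds stop - start rows.
covered = sum(stop - start for start, stop in ranges)
print(covered)
```

If `covered` is smaller than the number of rows, some documents never reach the collection.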

Now we need a method that will query the chromadb collection. We can do this with a `query_database` function that takes in the `query_text` and the number of results `n_results`.

In [32]:
def query_database(query_text, n_results=10):
    results = collection.query(query_texts=query_text, n_results=n_results)
    return results
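
The query result is a dictionary of nested lists: the outer list has one entry per query text, and each inner list has one entry per returned result. The sketch below uses a hand-written mock of that shape (the values are made up for illustration, not real output) to show how the top document is pulled out. `top_document` is a hypothetical helper.

```python
# Hand-written mock of a query result for ONE query text and n_results=2.
# The values are illustrative only.
mock_results = {
    "ids": [["id12", "id7"]],
    "documents": [["first matching answer", "second matching answer"]],
    "metadatas": [[{"qtype": "treatment"}, {"qtype": "symptoms"}]],
    "distances": [[0.41, 0.58]],
}


def top_document(results, query_index=0):
    """Return the best-ranked document for one query, or None if there is none."""
    docs = results["documents"][query_index]
    return docs[0] if docs else None


print(top_document(mock_results))
```

This is the same `["documents"][0][0]` access pattern that the cache uses later to read the response text.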

## Creating a Semantic Search Cache

To build the cache, we are going to use FAISS, one of the most popular libraries for similarity search.

With FAISS, we can add the embedding vectors of previously asked questions to an index and then use the embedding of a new question as a query vector to find similar vectors in that index.

In [66]:
!pip install -q faiss-cpu==1.8.0



In [33]:
import faiss
from sentence_transformers import SentenceTransformer
import time
import json

We are going to use `IndexFlatL2`, which compares the query vector against every vector in the index using L2 (Euclidean) distance and reports the squared distances. It is simple and exact, and although it does not scale well to very large indexes, it is a good fit for a small cache like ours.

![IndexFlatL2](https://cdn.sanity.io/images/vr8gru94/production/ea951a4be3acf9d379cc6f922be1468b37b7f9e5-1280x720.png)
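
Conceptually, a flat index is a brute-force search. The following pure-Python sketch mimics that behaviour on tiny 2-dimensional vectors so the mechanics are visible. It is only an illustration: FAISS does the same comparison in optimized native code, and `squared_l2` and `flat_search` are hypothetical names.

```python
def squared_l2(a, b):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def flat_search(stored, query, k=1):
    """Compare the query with every stored vector and return the k nearest.

    Returns (distances, indices), both ordered from nearest to farthest.
    """
    scored = sorted((squared_l2(vec, query), i) for i, vec in enumerate(stored))
    top = scored[:k]
    return [d for d, _ in top], [i for _, i in top]


stored_vectors = [[0.0, 0.0], [1.0, 0.0], [0.0, 2.0]]
distances, indices = flat_search(stored_vectors, [0.9, 0.1], k=2)
print(distances, indices)
```

Every query touches every stored vector, which is why the cost grows linearly with the size of the index.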

We are also going to be using the SentenceTransformer `all-mpnet-base-v2` in order to create vector embeddings with a dimension of `768`.

In [34]:
def init_cache():
    index = faiss.IndexFlatL2(768)
    if index.is_trained:
        print("index trained")
    
    # Initialize Sentence Transformer Model
    encoder = SentenceTransformer("all-mpnet-base-v2")
    
    return index, encoder

We are going to save and retrieve the cache by saving it locally in a JSON File.

In [35]:
def retrieve_cache(json_file):
    try:
        with open(json_file, "r") as file:
            cache = json.load(file)
    except FileNotFoundError:
        cache = {"questions": [], "embeddings": [], "answers": [], "response_text": []}
    
    return cache

In [36]:
def store_cache(json_file, cache):
    with open(json_file, "w") as file:
        json.dump(cache, file)

This is the code for the `semantic_cache` class.

We initialize the `index` for caching and the `encoder` for creating embeddings.

`euclidean_threshold` holds the maximum distance at which a cached question still counts as a match.

`json_file` and `cache` hold the path of the local JSON file and the cache contents loaded from it.

`ask()` takes a `question` from the user and looks for a similar one in the cache. If it doesn't find a close enough match, it queries the ChromaDB database, saves the result to the cache, and returns the `response_text`.

`index.nprobe` is the number of nearby Voronoi cells to search in a cell-based (IVF) index. `IndexFlatL2` has no cells and always compares against every vector, so the setting has no effect here.

`index.search` returns a distances array and an indices array.

`D` is the distances array, which holds the distance from the query vector to each returned vector.

`I` is the indices array, which holds the positions of the returned vectors inside the FAISS index.
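
The hit-or-miss decision can be sketched in isolation. The function below takes `D` and `I` shaped like the search output for one query and one neighbour (nested sequences stand in for the arrays here) and returns the cached row to reuse, if any. `cache_hit_row` is a hypothetical helper, and the sample values are chosen for illustration.

```python
def cache_hit_row(D, I, threshold):
    """Return the cached row id for a hit, or None for a miss.

    D and I are nested sequences shaped (n_queries, k), like the arrays
    returned by a nearest-neighbour search with one query and k=1.
    """
    distance, row_id = D[0][0], I[0][0]
    if row_id < 0:            # -1 marks "no neighbour found" (empty index)
        return None
    if distance > threshold:  # nearest entry is too far away to reuse
        return None
    return int(row_id)


print(cache_hit_row([[0.028]], [[1]], threshold=0.35))    # close match
print(cache_hit_row([[0.9]], [[0]], threshold=0.35))      # too far away
print(cache_hit_row([[3.4e38]], [[-1]], threshold=0.35))  # empty index
```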

In [37]:
class semantic_cache:
    def __init__(self, json_file="cache_file.json", threshold=0.35):
        # Initialize the FAISS index (L2 distance) and the sentence encoder
        self.index, self.encoder = init_cache()

        # Set the distance threshold
        # A distance of 0 means identical sentences
        # We only accept matches at or under this threshold
        self.euclidean_threshold = threshold

        # Load the cache from its JSON file (empty if the file does not exist)
        self.json_file = json_file
        self.cache = retrieve_cache(json_file)

        # Re-add saved embeddings so FAISS row ids line up with the
        # positions in the cache lists when an existing file is reused
        if self.cache["embeddings"]:
            self.index.add(np.array(self.cache["embeddings"], dtype="float32"))

    def ask(self, question: str) -> str:
        # Return a cached response for a similar question, or query ChromaDB
        start_time = time.time()
        try:
            embedding = self.encoder.encode([question])  # Encode the question

            # nprobe only matters for cell-based (IVF) indexes; a flat index ignores it
            self.index.nprobe = 8
            # Distances and indices arrays for the single nearest cached question
            D, I = self.index.search(embedding, 1)

            # FAISS returns index -1 when there is no neighbour (empty index),
            # so require a valid index and a distance under the threshold
            if I[0][0] >= 0 and D[0][0] <= self.euclidean_threshold:
                row_id = int(I[0][0])  # Position of the match in the cache lists
                response_text = self.cache["response_text"][row_id]

                print("Answer recovered from Cache. ")
                print(f"{D[0][0]:.3f} smaller than {self.euclidean_threshold}")
                print(f"Found cache in row: {row_id} with score {D[0][0]:.3f}")
                print(f"response_text: {response_text}")

                elapsed_time = time.time() - start_time
                print(f"Time taken: {elapsed_time:.3f} seconds")
                return response_text

            # Cache miss: no neighbour yet, or the nearest one is too far away,
            # so ask ChromaDB for the single best document
            answer = query_database([question], 1)
            response_text = answer["documents"][0][0]

            # Add the response to the cache
            self.cache["questions"].append(question)
            self.cache["embeddings"].append(embedding[0].tolist())
            self.cache["answers"].append(answer)
            self.cache["response_text"].append(response_text)

            print("Answer recovered from ChromaDB. ")
            print(f"response_text: {response_text}")

            # Keep the index and the JSON file in sync with the cache lists
            self.index.add(embedding)
            store_cache(self.json_file, self.cache)
            elapsed_time = time.time() - start_time
            print(f"Time taken: {elapsed_time:.3f} seconds")

            return response_text
        except Exception as e:
            raise RuntimeError(f"Error during 'ask' method: {e}") from e
        

## Testing

Now let's initialize the semantic cache and save it to a specific JSON file.

In [38]:
# Initialize the cache.
cache = semantic_cache("4cache.json")

index trained


We are going to ask the cache a question. If the cache cannot find a match with a distance under the threshold, it will ask the Chroma database instead and then save the response to the cache.

In [39]:
results = cache.ask("How do vaccines work?")

Answer recovered from ChromaDB. 
response_text: Summary : Shots may hurt a little, but the diseases they can prevent are a lot worse. Some are even life-threatening. Immunization shots, or vaccinations, are essential. They protect against things like measles, mumps, rubella, hepatitis B, polio, tetanus, diphtheria, and pertussis (whooping cough). Immunizations are important for adults as well as children.    Your immune system helps your body fight germs by producing substances to combat them. Once it does, the immune system "remembers" the germ and can fight it again. Vaccines contain germs that have been killed or weakened. When given to a healthy person, the vaccine triggers the immune system to respond and thus build immunity.     Before vaccines, people became immune only by actually getting a disease and surviving it. Immunizations are an easier and less risky way to become immune.     NIH: National Institute of Allergy and Infectious Diseases
Time taken: 0.093 seconds


Notice that because we don't have anything in the cache yet, it asked the ChromaDB database first.

In [40]:
results = cache.ask("Explain briefly what is a Sydenham chorea")

Answer recovered from ChromaDB. 
response_text: Sydenham chorea (SD) is a neurological disorder of childhood resulting from infection via Group A beta-hemolytic streptococcus (GABHS), the bacterium that causes rheumatic fever. SD is characterized by rapid, irregular, and aimless involuntary movements of the arms and legs, trunk, and facial muscles. It affects girls more often than boys and typically occurs between 5 and 15 years of age. Some children will have a sore throat several weeks before the symptoms begin, but the disorder can also strike up to 6 months after the fever or infection has cleared. Symptoms can appear gradually or all at once, and also may include uncoordinated movements, muscular weakness, stumbling and falling, slurred speech, difficulty concentrating and writing, and emotional instability. The symptoms of SD can vary from a halting gait and slight grimacing to involuntary movements that are frequent and severe enough to be incapacitating. The random, writhing mo

This question is also completely different from the one already in the cache, so it went to the ChromaDB database as well.

In [41]:
results = cache.ask("Briefly explain me what is a Sydenham chorea.")


Answer recovered from Cache. 
0.028 smaller than 0.35
Found cache in row: 1 with score 0.028
response_text: Sydenham chorea (SD) is a neurological disorder of childhood resulting from infection via Group A beta-hemolytic streptococcus (GABHS), the bacterium that causes rheumatic fever. SD is characterized by rapid, irregular, and aimless involuntary movements of the arms and legs, trunk, and facial muscles. It affects girls more often than boys and typically occurs between 5 and 15 years of age. Some children will have a sore throat several weeks before the symptoms begin, but the disorder can also strike up to 6 months after the fever or infection has cleared. Symptoms can appear gradually or all at once, and also may include uncoordinated movements, muscular weakness, stumbling and falling, slurred speech, difficulty concentrating and writing, and emotional instability. The symptoms of SD can vary from a halting gait and slight grimacing to involuntary movements that are frequent and

This question is similar to a question in the cache, so the cache returned the stored response.

Notice that querying the cache took less time than querying the ChromaDB database.

In [42]:
question_def = "Write in 20 words what is a Sydenham chorea."
results = cache.ask(question_def)


Answer recovered from Cache. 
0.228 smaller than 0.35
Found cache in row: 1 with score 0.228
response_text: Sydenham chorea (SD) is a neurological disorder of childhood resulting from infection via Group A beta-hemolytic streptococcus (GABHS), the bacterium that causes rheumatic fever. SD is characterized by rapid, irregular, and aimless involuntary movements of the arms and legs, trunk, and facial muscles. It affects girls more often than boys and typically occurs between 5 and 15 years of age. Some children will have a sore throat several weeks before the symptoms begin, but the disorder can also strike up to 6 months after the fever or infection has cleared. Symptoms can appear gradually or all at once, and also may include uncoordinated movements, muscular weakness, stumbling and falling, slurred speech, difficulty concentrating and writing, and emotional instability. The symptoms of SD can vary from a halting gait and slight grimacing to involuntary movements that are frequent and

Again, the response comes from the cache and arrives faster.
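
The scores printed above are easier to interpret as cosine similarities. The sketch below rests on two assumptions that should be checked for your own setup: that the encoder returns unit-length embeddings, and that the index reports squared L2 distances. Under those assumptions, the squared distance `d` between two vectors and their cosine similarity are related by `d = 2 - 2 * cos`, so `cos = 1 - d / 2`. `cosine_from_squared_l2` is a hypothetical helper.

```python
def cosine_from_squared_l2(distance: float) -> float:
    """Cosine similarity implied by a squared L2 distance between unit vectors."""
    return 1.0 - distance / 2.0


# The two cache-hit scores seen above, plus the default threshold.
for d in (0.028, 0.228, 0.35):
    print(f"squared L2 {d:.3f} -> cosine {cosine_from_squared_l2(d):.3f}")
```

Read this way, the threshold is a minimum similarity: raising it accepts looser paraphrases, at the risk of returning the cached answer for a question that is actually different.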

## Initialize the Model

Now that we have a way to grab context for a query question efficiently, we want to ask the LLM the question.

We are going to be using Google's `gemma-2b-it` model.

To do this, we will use the HuggingFace `transformers` library, in particular:

`AutoTokenizer`, to convert our prompt into tokens the model accepts

`AutoModelForCausalLM`, to load the pretrained weights of the model

In [43]:
import torch

# Use the current GPU if one is available, otherwise fall back to the CPU
device = f"cuda:{torch.cuda.current_device()}" if torch.cuda.is_available() else "cpu"
device

'cuda:0'

In [44]:
from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = "google/gemma-2b-it"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device, torch_dtype=torch.bfloat16)


To query the model, we build a prompt that combines the context retrieved from the cache or database with the user query.

In [45]:
prompt_template = f"Relevant Context: {results}\n\n The user's question: {question_def}"
prompt_template

"Relevant Context: Sydenham chorea (SD) is a neurological disorder of childhood resulting from infection via Group A beta-hemolytic streptococcus (GABHS), the bacterium that causes rheumatic fever. SD is characterized by rapid, irregular, and aimless involuntary movements of the arms and legs, trunk, and facial muscles. It affects girls more often than boys and typically occurs between 5 and 15 years of age. Some children will have a sore throat several weeks before the symptoms begin, but the disorder can also strike up to 6 months after the fever or infection has cleared. Symptoms can appear gradually or all at once, and also may include uncoordinated movements, muscular weakness, stumbling and falling, slurred speech, difficulty concentrating and writing, and emotional instability. The symptoms of SD can vary from a halting gait and slight grimacing to involuntary movements that are frequent and severe enough to be incapacitating. The random, writhing movements of chorea are caused 

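The f-string above is evaluated once, for one specific context and question. For repeated use, the same layout can be wrapped in a small function. This is a minimal sketch: `PROMPT_TEMPLATE` and `build_prompt` are hypothetical names, and the shortened context string is only an example.

```python
PROMPT_TEMPLATE = "Relevant Context: {context}\n\n The user's question: {question}"


def build_prompt(context: str, question: str) -> str:
    """Fill the template with retrieved context and the user's question."""
    return PROMPT_TEMPLATE.format(context=context, question=question)


demo_prompt = build_prompt(
    context="Sydenham chorea (SD) is a neurological disorder of childhood.",
    question="Write in 20 words what is a Sydenham chorea.",
)
print(demo_prompt)
```
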
In order to send it to the LLM, we need to tokenize our inputs using the tokenizer.

In [47]:
input_ids = tokenizer(prompt_template, return_tensors="pt").to(device)

We then generate a response with the model and convert the output tokens back into text using the tokenizer's `decode` method.

In [51]:
outputs = model.generate(**input_ids, max_new_tokens=256)
print(tokenizer.decode(outputs[0]))

<bos>Relevant Context: Sydenham chorea (SD) is a neurological disorder of childhood resulting from infection via Group A beta-hemolytic streptococcus (GABHS), the bacterium that causes rheumatic fever. SD is characterized by rapid, irregular, and aimless involuntary movements of the arms and legs, trunk, and facial muscles. It affects girls more often than boys and typically occurs between 5 and 15 years of age. Some children will have a sore throat several weeks before the symptoms begin, but the disorder can also strike up to 6 months after the fever or infection has cleared. Symptoms can appear gradually or all at once, and also may include uncoordinated movements, muscular weakness, stumbling and falling, slurred speech, difficulty concentrating and writing, and emotional instability. The symptoms of SD can vary from a halting gait and slight grimacing to involuntary movements that are frequent and severe enough to be incapacitating. The random, writhing movements of chorea are cau

And just like that, we receive an output based on the context and the user query!

## Conclusion

As you can see, answering from the cache reduced retrieval time by about 50% compared with querying the ChromaDB database.

This was a simple integration of a cache layer. In production, there would likely be several cache instances, possibly split by user type, and much more data in the database.

The cache layer could be improved further by adding quantization or different indexing methods to the FAISS index.

In the end, we created a simple RAG System enhanced with a semantic cache layer.