# Implementing semantic cache to improve a RAG system


In this notebook, we will explore a typical RAG solution where we will utilize an open-source model and the vector database Chroma DB. **However, we will integrate a semantic cache system that will store various user queries and decide whether to generate the prompt enriched with information from the vector database or the cache.**

A semantic caching system aims to identify similar or identical user requests. When a matching request is found, the system retrieves the corresponding information from the cache, reducing the need to fetch it from the original source.

Because the comparison takes the semantic meaning of the requests into account, they don't have to be identical for the system to recognize them as the same question. They can be phrased differently or contain errors, whether typographical or in sentence structure, and the system can still identify that the user is requesting the same information.

For instance, queries like **What is the capital of France?**, **Tell me the name of the capital of France?**, and **What The capital of France is?** all convey the same intent and should be identified as the same question.

While the model's response may differ depending on how each request is phrased, the information retrieved from the vector database should be the same. This is why I'm placing the cache system between the user and the vector database, not between the user and the Large Language Model.

Most tutorials that guide you through creating a RAG system are designed for a single user and meant to operate in a testing environment: in other words, within a notebook, interacting with a local vector database and either making API calls or using a locally stored model.

This architecture quickly becomes insufficient when attempting to transition one of these models to production, where they might encounter from tens to thousands of recurrent requests.

One way to enhance performance is through one or multiple semantic caches. This cache retains the results of previous requests, and before resolving a new request, it checks if a similar one has been received before. If so, instead of re-executing the process, it retrieves the information from the cache.
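The check-then-fetch flow described above can be sketched in a few lines. This is a minimal illustration, not the implementation built later in the notebook: `toy_embed` is a hypothetical stand-in that counts words instead of calling a real sentence encoder, and `TinyCache`, `min_similarity`, and `slow_lookup` are names invented for the example.

```python
import math
from collections import Counter

def toy_embed(text):
    """Hypothetical stand-in for a sentence encoder: lowercase word counts."""
    return Counter(text.lower().replace("?", "").split())

def cosine(a, b):
    """Cosine similarity between two word-count vectors."""
    dot = sum(a[k] * b[k] for k in a)
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
    return dot / norm if norm else 0.0

class TinyCache:
    def __init__(self, min_similarity=0.7):
        self.min_similarity = min_similarity
        self.entries = []  # list of (embedding, stored_result)

    def get_or_compute(self, query, compute):
        """Return (result, was_cache_hit)."""
        emb = toy_embed(query)
        best = max(self.entries, key=lambda e: cosine(emb, e[0]), default=None)
        if best is not None and cosine(emb, best[0]) >= self.min_similarity:
            return best[1], True          # similar request seen before
        result = compute(query)           # expensive path (e.g. vector database)
        self.entries.append((emb, result))
        return result, False

calls = []

def slow_lookup(query):
    calls.append(query)
    return "Paris"

cache = TinyCache()
first = cache.get_or_compute("What is the capital of France?", slow_lookup)
second = cache.get_or_compute("What the capital of France is?", slow_lookup)
```

The second request is worded differently, yet it is served from the cache and `slow_lookup` runs only once. A real encoder is needed to catch paraphrases that do not share the same words.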

In a RAG system, there are two points that are time consuming:

* Retrieving the information used to construct the enriched prompt.
* Calling the Large Language Model to obtain the response.

In both points, a semantic cache system can be implemented, and we could even have two caches, one for each point.

Placing it at the model's response point may lead to a loss of influence over the obtained response. Our cache system could consider "Explain the French Revolution in 10 words" and "Explain the French Revolution in a hundred words" as the same query. If our cache system stores model responses, users might think that their instructions are not being followed accurately.

But both requests will require the same information to enrich the prompt. This is the main reason why I chose to place the semantic cache system between the user's request and the retrieval of information from the vector database.

However, this is a design decision. Depending on the type of responses and system requests, it can be placed at one point or another. It's evident that caching model responses would yield the most time savings, but as I've already explained, it comes at the cost of losing user influence over the response.
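To make the trade-off concrete, here is a small sketch of the placement chosen in this notebook: only the retrieval step is cached, while the model always receives the user's original wording. Everything in it is a stand-in invented for illustration: `topic_of` crudely imitates semantic matching by keeping the text before " in ", `retrieve` replaces the vector database, and `generate` replaces the LLM call.

```python
def make_pipeline(retrieve, generate, cache_lookup, cache_store):
    """Build an answer function that caches retrieval but never the model output."""
    def answer(question):
        context = cache_lookup(question)
        if context is None:
            context = retrieve(question)
            cache_store(question, context)
        # The prompt always carries the user's own instructions.
        return generate(f"Relevant context: {context}\n\nThe user's question: {question}")
    return answer

retrieval_calls = []
store = {}

def topic_of(question):
    # Crude stand-in for semantic matching: keep the text before " in ".
    return question.lower().split(" in ")[0]

def retrieve(question):
    retrieval_calls.append(question)
    return "facts about the French Revolution"

def generate(prompt):
    # Stand-in for the LLM: returns the question line it received.
    return prompt.splitlines()[-1]

answer = make_pipeline(
    retrieve,
    generate,
    cache_lookup=lambda q: store.get(topic_of(q)),
    cache_store=lambda q, ctx: store.__setitem__(topic_of(q), ctx),
)

brief = answer("Explain the French Revolution in 10 words")
detailed = answer("Explain the French Revolution in a hundred words")
```

Both requests share one retrieval, but each reaches the model with its own length instruction intact.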


# Import and load the libraries.
To start, we need to install the necessary Python packages.
* **[sentence transformers](https://www.sbert.net/)**. This library is necessary to transform the sentences into fixed-length vectors, also known as embeddings.
* **[xformers](https://github.com/facebookresearch/xformers)**. A package that provides libraries and utilities to facilitate working with transformer models. We need to install it in order to avoid an error when we work with the model and embeddings.
* **[chromadb](https://www.trychroma.com/)**. This is our vector database. ChromaDB is easy to use and open source, and it is one of the most widely used vector databases for storing embeddings.
* **[accelerate](https://github.com/huggingface/accelerate)**. Necessary to run the model on a GPU.

In [4]:
# !pip install -q transformers==4.38.1
# !pip install -q accelerate==0.27.2
# !pip install -q sentence-transformers==2.5.1
# !pip install -q xformers==0.0.24
# !pip install -q chromadb==0.4.24
# !pip install -q datasets==2.17.1

In [2]:
import numpy as np
import pandas as pd

# Load the Dataset
Because we are working in a free environment with only a few GB of memory, I limit the number of rows taken from the dataset with the variable `MAX_ROWS`.



In [3]:
# Log in to Hugging Face. This is mandatory in order to use the Gemma model,
# and recommended for accessing public models and datasets.
from getpass import getpass
if 'hf_key' not in locals():
  hf_key = getpass("Your Hugging Face API Key: ")
!huggingface-cli login --token $hf_key

Your Hugging Face API Key: ··········
Token will not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.
Token is valid (permission: read).
Your token has been saved to /root/.cache/huggingface/token
Login successful


In [4]:
from datasets import load_dataset

data = load_dataset("keivalya/MedQuad-MedicalQnADataset", split='train')

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.



In [5]:
data = data.to_pandas()
data["id"]=data.index
data.head(10)

| | qtype | Question | Answer | id |
|---|---|---|---|---|
| 0 | susceptibility | Who is at risk for Lymphocytic Choriomeningiti... | LCMV infections can occur after exposure to fr... | 0 |
| 1 | symptoms | What are the symptoms of Lymphocytic Choriomen... | LCMV is most commonly recognized as causing ne... | 1 |
| 2 | susceptibility | Who is at risk for Lymphocytic Choriomeningiti... | Individuals of all ages who come into contact ... | 2 |
| 3 | exams and tests | How to diagnose Lymphocytic Choriomeningitis (... | During the first phase of the disease, the mos... | 3 |
| 4 | treatment | What are the treatments for Lymphocytic Chorio... | Aseptic meningitis, encephalitis, or meningoen... | 4 |
| 5 | prevention | How to prevent Lymphocytic Choriomeningitis (L... | LCMV infection can be prevented by avoiding co... | 5 |
| 6 | information | What is (are) Parasites - Cysticercosis ? | Cysticercosis is an infection caused by the la... | 6 |
| 7 | susceptibility | Who is at risk for Parasites - Cysticercosis? ? | Cysticercosis is an infection caused by the la... | 7 |
| 8 | exams and tests | How to diagnose Parasites - Cysticercosis ? | If you think that you may have cysticercosis, ... | 8 |
| 9 | treatment | What are the treatments for Parasites - Cystic... | Some people with cysticercosis do not need to ... | 9 |


In [6]:
MAX_ROWS = 15000
DOCUMENT="Answer"
TOPIC="qtype"


ChromaDB requires that each record has a unique identifier. The `id` column created above from the DataFrame index serves this purpose.


In [7]:
# Because this is just a sample, we select a small portion of the dataset.
subset_data = data.head(MAX_ROWS)

# Import and configure the Vector Database
I'm going to use ChromaDB, a popular open-source vector database.

First we import ChromaDB, and after that the **Settings** class from the **chromadb.config** module. This class allows us to change the settings of the ChromaDB system and customize its behavior.

In [8]:
import chromadb
from chromadb.config import Settings

Now we only need to indicate the path where the vector database will be stored.

In [9]:
chroma_client = chromadb.PersistentClient(path="/path/to/persist/directory")

# Filling and Querying the ChromaDB Database
The data in ChromaDB is stored in collections. If the collection already exists, we need to delete it before creating it again.

In the next lines, we are creating the collection by calling the ***create_collection*** function in the ***chroma_client*** created above.

In [10]:
collection_name = "news_collection"
# Delete the collection if it already exists, wherever it appears in the list
if collection_name in [c.name for c in chroma_client.list_collections()]:
    chroma_client.delete_collection(name=collection_name)

collection = chroma_client.create_collection(name=collection_name)


It's time to add the data to the collection. Using the ***add*** function we need to provide, at least, ***documents***, ***metadatas*** and ***ids***.
* In **documents** we store the long text; it is a different column in each dataset.
* In **metadatas**, we can provide a list of topics.
* In **ids** we need to provide a unique identifier for each row. It MUST be unique! I'm creating the IDs from a range covering the selected rows.
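The three arguments listed above must stay aligned row by row. A minimal sketch of how they could be assembled and checked before calling `add` follows; the helper name `build_collection_payload` and the sample rows are invented for the example, and plain dictionaries stand in for DataFrame rows.

```python
def build_collection_payload(rows, document_key, topic_key):
    """Build aligned documents, metadatas and ids lists from a list of row dicts."""
    documents = [row[document_key] for row in rows]
    metadatas = [{topic_key: row[topic_key]} for row in rows]
    # Derive ids from the actual number of rows so the lengths always match.
    ids = [f"id{i}" for i in range(len(rows))]
    if not (len(set(ids)) == len(documents) == len(metadatas)):
        raise ValueError("documents, metadatas and ids must be aligned and unique")
    return {"documents": documents, "metadatas": metadatas, "ids": ids}

sample_rows = [
    {"Answer": "placeholder answer one", "qtype": "symptoms"},
    {"Answer": "placeholder answer two", "qtype": "treatment"},
]
payload = build_collection_payload(sample_rows, "Answer", "qtype")
# With a real collection this could be passed as: collection.add(**payload)
```

Deriving the ids from `len(rows)` rather than from a fixed constant avoids a length mismatch if the dataset has fewer rows than expected.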


In [12]:
collection.add(
    documents=subset_data[DOCUMENT].tolist(),
    metadatas=[{TOPIC: topic} for topic in subset_data[TOPIC].tolist()],
    ids=[f"id{x}" for x in range(len(subset_data))],  # one id per selected row
)

/root/.cache/chroma/onnx_models/all-MiniLM-L6-v2/onnx.tar.gz: 100%|██████████| 79.3M/79.3M [00:08<00:00, 10.2MiB/s]


Once we have the information inside the database we can query it and ask for data that matches our needs. The search is done over the content of the documents, and it doesn't look for an exact word or phrase. The results are based on the similarity between the search terms and the content of the documents.

The metadata is not used in the similarity search, but it can be used to filter or refine the results.

Let's define a function to query the ChromaDB Database.

In [13]:
def query_database(query_text, n_results=10):
    results = collection.query(query_texts=query_text, n_results=n_results )
    return results
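As mentioned above, metadata can be used to narrow the results. Chroma's `query` accepts a `where` argument for this; the sketch below passes a filter on the `qtype` field. The helper `query_by_topic` is hypothetical, and `RecordingCollection` is only a stand-in that records the call so the snippet runs without a database.

```python
def query_by_topic(collection, query_text, topic, n_results=5):
    """Similarity query restricted to documents whose 'qtype' metadata equals `topic`."""
    return collection.query(
        query_texts=[query_text],
        n_results=n_results,
        where={"qtype": topic},
    )

class RecordingCollection:
    """Stand-in for a Chroma collection: remembers the arguments it was called with."""
    def __init__(self):
        self.last_kwargs = None

    def query(self, **kwargs):
        self.last_kwargs = kwargs
        return {"documents": [[]], "metadatas": [[]], "ids": [[]]}

fake = RecordingCollection()
query_by_topic(fake, "How is cysticercosis treated?", "treatment")
```

With the notebook's real `collection` object, the same helper would return only answers tagged as `treatment`.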

## Creating the semantic cache system
To implement the cache system, we will use Faiss, a library that allows storing embeddings in memory. It's quite similar to what Chroma does, but without its persistence.

For this purpose, we will create a class called semantic_cache that will work with its own encoder and provide the necessary functions for the user to perform queries.

In this class, we first query Faiss (the cache), and if the distance to the closest stored question is within the specified threshold, it will return the result from the cache. Otherwise, it will fetch the result from the Chroma database.

The cache is stored in a .json file.

In [14]:
!pip install -q faiss-cpu==1.8.0

In [15]:
import faiss
from sentence_transformers import SentenceTransformer
import time
import json

This function initializes the semantic cache.

It employs the `IndexFlatL2` index, which might not be the fastest but is ideal for small datasets. Depending on the characteristics of the data intended for the cache and the expected dataset size, another index such as HNSW or IVF could be utilized.
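To see why a flat index suits small caches, it helps to look at what an exhaustive search does conceptually: it compares the query against every stored vector. The pure-Python sketch below illustrates the idea only; it is not how Faiss is implemented, and `flat_l2_search` is a name made up for the example.

```python
def flat_l2_search(stored, query, k=1):
    """Exhaustive search: squared L2 distance from `query` to every stored vector.

    Returns the k closest entries as (squared_distance, row_index) pairs.
    """
    scored = [
        (sum((q - s) ** 2 for q, s in zip(query, vec)), idx)
        for idx, vec in enumerate(stored)
    ]
    scored.sort()
    return scored[:k]

stored = [[0.0, 0.0], [1.0, 1.0], [3.0, 4.0]]
nearest = flat_l2_search(stored, [0.9, 1.2])
```

The cost grows linearly with the number of cached questions, which is the reason approximate indexes such as HNSW or IVF become attractive for large caches.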

In [16]:
def init_cache():
  # all-mpnet-base-v2 produces 768-dimensional embeddings
  index = faiss.IndexFlatL2(768)
  if index.is_trained:
      print('Index trained')

  # Initialize Sentence Transformer model
  encoder = SentenceTransformer('all-mpnet-base-v2')

  return index, encoder

In the `retrieve_cache` function, the .json file is retrieved from disk in case there is a need to reuse the cache across sessions.

In [17]:
def retrieve_cache(json_file):
      try:
          with open(json_file, 'r') as file:
              cache = json.load(file)
      except FileNotFoundError:
          cache = {'questions': [], 'embeddings': [], 'answers': [], 'response_text': []}

      return cache

The `store_cache` function saves the file containing the cache data to disk.

In [18]:
def store_cache(json_file, cache):
  with open(json_file, 'w') as file:
        json.dump(cache, file)
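A quick round trip shows how the two helpers fit together. The sketch restates them so it runs on its own, and writes to a temporary directory; the file name and the sample values are made up. Note that embeddings have to be stored as plain lists, because NumPy arrays are not JSON serializable.

```python
import json
import os
import tempfile

def retrieve_cache(json_file):
    """Load the cache from disk, or start an empty one if the file is missing."""
    try:
        with open(json_file, "r") as file:
            return json.load(file)
    except FileNotFoundError:
        return {"questions": [], "embeddings": [], "answers": [], "response_text": []}

def store_cache(json_file, cache):
    """Write the whole cache to disk as JSON."""
    with open(json_file, "w") as file:
        json.dump(cache, file)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "demo_cache.json")
    cache = retrieve_cache(path)                  # file does not exist yet
    cache["questions"].append("What is LCMV?")
    cache["embeddings"].append([0.1, 0.2, 0.3])   # a list, not a NumPy array
    cache["answers"].append({"documents": [["placeholder text"]]})
    cache["response_text"].append("placeholder text")
    store_cache(path, cache)
    reloaded = retrieve_cache(path)
```

The four lists are kept the same length, so position `i` in each of them refers to the same cached question.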

These functions will be used within the `semantic_cache` class, which includes the search function and its initialization function.

Even though the `ask` function has a substantial amount of code, its purpose is quite straightforward. It looks in the cache for the closest question to the one just made by the user.

Afterward, it checks whether that question is within the specified threshold. If so, it directly returns the response from the cache; otherwise, it calls the `query_database` function to retrieve the data from ChromaDB.

I've used Euclidean distance instead of cosine similarity, which is widely employed in vector comparisons. This choice is based on the fact that Euclidean (L2) distance is the metric of the Faiss `IndexFlatL2` index used here; note that this index reports squared L2 distances. Cosine similarity can also be calculated, but doing so adds complexity that may not significantly contribute to the final result.
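For unit-length vectors the two metrics are tightly linked: the squared L2 distance equals `2 - 2 * cosine_similarity`. The sketch below checks this numerically and converts a squared-distance threshold into the equivalent cosine similarity. It assumes the embeddings are normalized to length 1, which should be verified for the encoder in use; the vectors and helper names are made up for the example.

```python
import math

def normalize(v):
    """Scale a vector to unit length."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]

def squared_l2(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))

def cosine_similarity(a, b):
    # Plain dot product: only valid as cosine similarity for unit-length inputs.
    return sum(x * y for x, y in zip(a, b))

a = normalize([1.0, 2.0, 3.0])
b = normalize([2.0, 1.0, 3.5])
lhs = squared_l2(a, b)
rhs = 2 - 2 * cosine_similarity(a, b)   # matches lhs up to rounding error

def min_cosine_for(squared_l2_threshold):
    """Cosine similarity equivalent to a squared-L2 threshold (unit vectors only)."""
    return 1 - squared_l2_threshold / 2

equivalent = min_cosine_for(0.35)
```

Under that assumption, the 0.35 threshold used in this notebook corresponds to a minimum cosine similarity of about 0.825.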


In [19]:
class semantic_cache:
  def __init__(self, json_file="cache_file.json", threshold=0.35):
      # Initialize Faiss index with Euclidean distance
      self.index, self.encoder = init_cache()

      # Set Euclidean distance threshold
      # a distance of 0 means identical sentences
      # We only return from cache sentences under this threshold
      self.euclidean_threshold = threshold

      self.json_file = json_file
      self.cache = retrieve_cache(self.json_file)

      # Re-index embeddings loaded from disk so that Faiss row ids
      # line up with the positions in the cache lists
      if self.cache['embeddings']:
          self.index.add(np.array(self.cache['embeddings'], dtype='float32'))

  def ask(self, question: str) -> str:
      # Method to retrieve an answer from the cache or generate a new one
      start_time = time.time()
      try:
          #First we obtain the embeddings corresponding to the user question
          embedding = self.encoder.encode([question])

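          # Search for the nearest neighbor in the index
          # (IndexFlatL2 is exhaustive, so there is no nprobe to tune)
          D, I = self.index.search(embedding, 1)

          # I is -1 when the index is empty; D holds squared L2 distances
          if D[0][0] >= 0:
              if I[0][0] >= 0 and D[0][0] <= self.euclidean_threshold: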
                  row_id = int(I[0][0])

                  print('Answer recovered from Cache. ')
                  print(f'{D[0][0]:.3f} smaller than {self.euclidean_threshold}')
                  print(f'Found cache in row: {row_id} with score {D[0][0]:.3f}')
                  print(f'response_text: ' + self.cache['response_text'][row_id])

                  end_time = time.time()
                  elapsed_time = end_time - start_time
                  print(f"Time taken: {elapsed_time:.3f} seconds")
                  return self.cache['response_text'][row_id]

          # Handle the case when there are not enough results
          # or Euclidean distance is not met, asking to chromaDB.
          answer  = query_database([question], 1)
          response_text = answer['documents'][0][0]

          self.cache['questions'].append(question)
          self.cache['embeddings'].append(embedding[0].tolist())
          self.cache['answers'].append(answer)
          self.cache['response_text'].append(response_text)

          print('Answer recovered from ChromaDB. ')
          print(f'response_text: {response_text}')

          self.index.add(embedding)
          store_cache(self.json_file, self.cache)
          end_time = time.time()
          elapsed_time = end_time - start_time
          print(f"Time taken: {elapsed_time:.3f} seconds")

          return response_text
      except Exception as e:
          # Chain the original exception so the traceback is preserved
          raise RuntimeError(f"Error during 'ask' method: {e}") from e

### Testing the semantic_cache class.

In [20]:
# Initialize the cache.
cache = semantic_cache('4cache_file.json')

Index trained



In [21]:
results = cache.ask("How work a vaccine?")

Answer recovered from ChromaDB. 
response_text: Summary : Shots may hurt a little, but the diseases they can prevent are a lot worse. Some are even life-threatening. Immunization shots, or vaccinations, are essential. They protect against things like measles, mumps, rubella, hepatitis B, polio, tetanus, diphtheria, and pertussis (whooping cough). Immunizations are important for adults as well as children.    Your immune system helps your body fight germs by producing substances to combat them. Once it does, the immune system "remembers" the germ and can fight it again. Vaccines contain germs that have been killed or weakened. When given to a healthy person, the vaccine triggers the immune system to respond and thus build immunity.     Before vaccines, people became immune only by actually getting a disease and surviving it. Immunizations are an easier and less risky way to become immune.     NIH: National Institute of Allergy and Infectious Diseases
Time taken: 0.655 seconds


As expected, this response has been obtained from ChromaDB. The class then stores it in the cache.

Now, if we send a second question that is quite different, the response should also be retrieved from ChromaDB. This is because the question stored previously is so dissimilar that it would surpass the specified threshold in terms of Euclidean distance.

In [22]:
results = cache.ask("Explain briefly what is a Periodic Paralyses")

Answer recovered from ChromaDB. 
response_text: Familial periodic paralyses are a group of inherited neurological disorders caused by mutations in genes that regulate sodium and calcium channels in nerve cells. They are characterized by episodes in which the affected muscles become slack, weak, and unable to contract. Between attacks, the affected muscles usually work as normal.
                
The two most common types of periodic paralyses are: Hypokalemic periodic paralysis is characterized by a fall in potassium levels in the blood. In individuals with this mutation attacks often begin in adolescence and are triggered by strenuous exercise, high carbohydrate meals, or by injection of insulin, glucose, or epinephrine. Weakness may be mild and limited to certain muscle groups, or more severe and affect the arms and legs. Attacks may last for a few hours or persist for several days. Some patients may develop chronic muscle weakness later in life. Hyperkalemic periodic paralysis is ch

Perfect, the semantic cache system is behaving as expected.

Let's proceed to test it with a question very similar to the one we just asked.

In this case, the response should come directly from the cache without the need to access the ChromaDB database.

In [23]:
results = cache.ask("Briefly explain me what is a periodic paralyses")

Answer recovered from Cache. 
0.018 smaller than 0.35
Found cache in row: 1 with score 0.018
response_text: Familial periodic paralyses are a group of inherited neurological disorders caused by mutations in genes that regulate sodium and calcium channels in nerve cells. They are characterized by episodes in which the affected muscles become slack, weak, and unable to contract. Between attacks, the affected muscles usually work as normal.
                
The two most common types of periodic paralyses are: Hypokalemic periodic paralysis is characterized by a fall in potassium levels in the blood. In individuals with this mutation attacks often begin in adolescence and are triggered by strenuous exercise, high carbohydrate meals, or by injection of insulin, glucose, or epinephrine. Weakness may be mild and limited to certain muscle groups, or more severe and affect the arms and legs. Attacks may last for a few hours or persist for several days. Some patients may develop chronic muscle w

The two questions are so similar that their Euclidean distance is truly minimal, almost as if they were identical.

Now, let's try another question, this time a bit more distinct, and observe how the system behaves.

In [24]:
question_def = "Write in 20 words what is a periodic paralyses"
results = cache.ask(question_def)

Answer recovered from Cache. 
0.220 smaller than 0.35
Found cache in row: 1 with score 0.220
response_text: Familial periodic paralyses are a group of inherited neurological disorders caused by mutations in genes that regulate sodium and calcium channels in nerve cells. They are characterized by episodes in which the affected muscles become slack, weak, and unable to contract. Between attacks, the affected muscles usually work as normal.
                
The two most common types of periodic paralyses are: Hypokalemic periodic paralysis is characterized by a fall in potassium levels in the blood. In individuals with this mutation attacks often begin in adolescence and are triggered by strenuous exercise, high carbohydrate meals, or by injection of insulin, glucose, or epinephrine. Weakness may be mild and limited to certain muscle groups, or more severe and affect the arms and legs. Attacks may last for a few hours or persist for several days. Some patients may develop chronic muscle w

We observe that the Euclidean distance has increased, but it still remains within the specified threshold. Therefore, it continues to return the response directly from the cache.

# Loading the model and creating the prompt
Time to use the **transformers** library, the best-known library from [Hugging Face](https://huggingface.co/) for working with language models.

We are importing:
* **AutoTokenizer**: a utility class for tokenizing text inputs so that they are compatible with various pre-trained language models.
* **AutoModelForCausalLM**: provides an interface to pre-trained language models designed for language generation tasks using causal language modeling (e.g., GPT models), such as the model used in this notebook, **Gemma-2b-it**.

The model selected is [Gemma-2b-it](https://huggingface.co/google/gemma-2b-it).

Feel free to test [different models](https://huggingface.co/models?pipeline_tag=text-generation&sort=trending); search for NLP models trained for text generation.


In [26]:
import torch
from torch import cuda
# On a Mac with Apple Silicon the device must be 'mps'
# device = torch.device('mps')  # to use with Apple Silicon
device = f'cuda:{cuda.current_device()}' if cuda.is_available() else 'cpu'

In [27]:
from transformers import AutoTokenizer, AutoModelForCausalLM

#model_id = "databricks/dolly-v2-3b"
model_id = "google/gemma-2b-it"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id,
                                             device_map=device,
                                             torch_dtype=torch.bfloat16)



## Creating the extended prompt
To create the prompt we use the result of querying the `semantic_cache` class and the question introduced by the user.

The prompt has two parts: the **relevant context**, which is the information recovered from the database, and the **user's question**.

We only need to put the two parts together to create the prompt and then send it to the model.
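Putting the two parts together can be wrapped in a small helper so the same format is reused for every request. This is only a sketch: `build_prompt` and its `max_context_chars` parameter are hypothetical additions, included to show one way of keeping very long contexts in check.

```python
def build_prompt(context, question, max_context_chars=2000):
    """Join retrieved context and the user's question into a single prompt string."""
    trimmed = context[:max_context_chars]  # hypothetical guard against very long contexts
    return f"Relevant context: {trimmed}\n\nThe user's question: {question}"

demo_prompt = build_prompt(
    "Familial periodic paralyses are a group of inherited neurological disorders.",
    "Write in 20 words what is a periodic paralyses",
)
```

The helper keeps the user's question untouched, so instructions such as a requested length still reach the model.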

In [28]:
prompt_template = f"Relevant context: {results}\n\n The user's question: {question_def}"
prompt_template

"Relevant context: Familial periodic paralyses are a group of inherited neurological disorders caused by mutations in genes that regulate sodium and calcium channels in nerve cells. They are characterized by episodes in which the affected muscles become slack, weak, and unable to contract. Between attacks, the affected muscles usually work as normal.\n                \nThe two most common types of periodic paralyses are: Hypokalemic periodic paralysis is characterized by a fall in potassium levels in the blood. In individuals with this mutation attacks often begin in adolescence and are triggered by strenuous exercise, high carbohydrate meals, or by injection of insulin, glucose, or epinephrine. Weakness may be mild and limited to certain muscle groups, or more severe and affect the arms and legs. Attacks may last for a few hours or persist for several days. Some patients may develop chronic muscle weakness later in life. Hyperkalemic periodic paralysis is characterized by a rise in po

In [29]:
input_ids = tokenizer(prompt_template, return_tensors="pt").to(model.device)

Now all that remains is to send the prompt to the model and wait for its response!


In [30]:
outputs = model.generate(**input_ids,
                         max_new_tokens=256)
print(tokenizer.decode(outputs[0]))

<bos>Relevant context: Familial periodic paralyses are a group of inherited neurological disorders caused by mutations in genes that regulate sodium and calcium channels in nerve cells. They are characterized by episodes in which the affected muscles become slack, weak, and unable to contract. Between attacks, the affected muscles usually work as normal.
                
The two most common types of periodic paralyses are: Hypokalemic periodic paralysis is characterized by a fall in potassium levels in the blood. In individuals with this mutation attacks often begin in adolescence and are triggered by strenuous exercise, high carbohydrate meals, or by injection of insulin, glucose, or epinephrine. Weakness may be mild and limited to certain muscle groups, or more severe and affect the arms and legs. Attacks may last for a few hours or persist for several days. Some patients may develop chronic muscle weakness later in life. Hyperkalemic periodic paralysis is characterized by a rise in 
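The decoded output starts by repeating the prompt because, for this kind of model, `generate` returns the prompt tokens followed by the newly generated ones. A common way to show only the answer is to drop the first `prompt_length` tokens before decoding. The sketch below illustrates the slicing with plain lists and made-up token ids; with the notebook's objects the equivalent would be slicing `outputs[0]` from `input_ids["input_ids"].shape[1]` onward and decoding with `skip_special_tokens=True`.

```python
def new_tokens_only(output_ids, prompt_length):
    """Return only the tokens generated after the prompt."""
    return output_ids[prompt_length:]

prompt_ids = [2, 101, 102, 103]           # made-up ids standing in for the tokenized prompt
output_ids = prompt_ids + [201, 202, 1]   # prompt followed by the generated continuation
answer_ids = new_tokens_only(output_ids, len(prompt_ids))
```

Only `answer_ids` would then be passed to the tokenizer's `decode` method.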

# Conclusions.
In this test, retrieving from the cache is approximately 50% faster than accessing ChromaDB. In larger projects this difference tends to increase, and improvements of 90-95% are possible.
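To check figures like these on your own data, the elapsed times of a cache miss and a cache hit can be compared directly. The sketch below is one way to do it; `timed` and `improvement_pct` are helper names invented here, and the numbers passed to `improvement_pct` are illustrative, not measurements from this notebook.

```python
import time

def timed(fn, *args):
    """Call fn(*args) and return (result, elapsed_seconds)."""
    start = time.perf_counter()
    result = fn(*args)
    return result, time.perf_counter() - start

def improvement_pct(baseline_seconds, cached_seconds):
    """Percentage of time saved by the cached path relative to the baseline."""
    return 100.0 * (baseline_seconds - cached_seconds) / baseline_seconds

value, elapsed = timed(sum, [1, 2, 3])   # any callable can be timed this way
example = improvement_pct(2.0, 1.0)      # illustrative numbers only
```

Averaging several runs of each path gives a more stable estimate than a single measurement.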

We have very little data in Chroma, and only a single instance of the cache class. Typically, the data behind the cache system is much larger, and retrieving it may involve more than a single query to a vector database, with information sourced from various places.

It's common to have multiple instances of the cache class, usually based on user typology, as questions tend to repeat more among users who share common traits.
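One simple way to organize several caches is a small router that creates one cache per user segment on first use. This is a sketch under assumptions: `CacheRouter` and the segment names are hypothetical, and a plain dictionary stands in for the cache object. With the class from this notebook the factory could be something like `lambda seg: semantic_cache(f"{seg}_cache.json")`.

```python
class CacheRouter:
    """Keeps one cache instance per user segment, created lazily."""
    def __init__(self, cache_factory):
        self._factory = cache_factory
        self._caches = {}

    def for_segment(self, segment):
        if segment not in self._caches:
            self._caches[segment] = self._factory(segment)
        return self._caches[segment]

# A dict stands in for a real cache object in this example.
router = CacheRouter(lambda segment: {"segment": segment, "entries": []})
clinicians = router.for_segment("clinicians")
patients = router.for_segment("patients")
clinicians_again = router.for_segment("clinicians")
```

Each segment gets its own cache, and repeated lookups for the same segment reuse the same instance.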

In summary, we have created a very simple RAG (Retrieval-Augmented Generation) system and enhanced it with a semantic cache layer between the user's question and obtaining the information necessary to create the enriched prompt.