## 1. Installing libraries & Imports

In [1]:
%%capture
!pip install transformers==4.38.1
!pip install accelerate==0.27.2
!pip install sentence-transformers==2.5.1
!pip install xformers==0.0.24
!pip install chromadb==0.4.24
!pip install datasets==2.17.1
!pip install faiss-cpu==1.8.0

In [2]:
import time
import json

import numpy as np
import pandas as pd
import chromadb
from datasets import load_dataset
from torch import cuda, torch
import faiss
from sentence_transformers import SentenceTransformer

## 2. Load the Dataset

In [3]:
from getpass import getpass

if 'hf_key' not in locals():
  hf_key = getpass("Your HuggingFace token:")
!huggingface-cli login --token $hf_key

Your HuggingFace token:··········
Token will not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.
Token is valid (permission: read).
Your token has been saved to /root/.cache/huggingface/token
Login successful


In [4]:
# MedQuad question/answer pairs from the Hugging Face Hub (train split only).
data = load_dataset(path="keivalya/MedQuad-MedicalQnADataset", split="train")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Downloading readme:   0%|          | 0.00/233 [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/22.5M [00:00<?, ?B/s]

Generating train split: 0 examples [00:00, ? examples/s]

In [5]:
# Convert the Hugging Face `Dataset` into a pandas DataFrame. The isinstance
# guard keeps this cell idempotent: re-running it on the already-converted
# DataFrame would otherwise fail, because DataFrame has no `to_pandas` method.
if not isinstance(data, pd.DataFrame):
    data = data.to_pandas()
# Expose the positional row index as an explicit `id` column.
data["id"] = data.index
data.head(5)

Unnamed: 0,qtype,Question,Answer,id
0,susceptibility,Who is at risk for Lymphocytic Choriomeningiti...,LCMV infections can occur after exposure to fr...,0
1,symptoms,What are the symptoms of Lymphocytic Choriomen...,LCMV is most commonly recognized as causing ne...,1
2,susceptibility,Who is at risk for Lymphocytic Choriomeningiti...,Individuals of all ages who come into contact ...,2
3,exams and tests,How to diagnose Lymphocytic Choriomeningitis (...,"During the first phase of the disease, the mos...",3
4,treatment,What are the treatments for Lymphocytic Chorio...,"Aseptic meningitis, encephalitis, or meningoen...",4


In [6]:
# Number of dataset rows indexed into ChromaDB.
MAX_ROWS = 15_000
# Column whose text is stored (and embedded) as the ChromaDB document.
DOCUMENT = "Answer"
# Column stored as per-document metadata.
TOPIC = "qtype"

In [7]:
# Keep only the first MAX_ROWS rows for indexing.
subset_data = data.iloc[:MAX_ROWS]

## 3. Configure the vector database

In [8]:
# ChromaDB client that keeps its data on disk under `path`.
# NOTE(review): "/chroma/" is an absolute path at the filesystem root; presumably
# fine on Colab, but it may lack write permission on other machines — confirm.
chroma_client = chromadb.PersistentClient(path="/chroma/")

## 4. Filling and querying ChromaDB

In [9]:
collection_name = "news_collection"

# Start from an empty collection on every run: drop it if a previous run already
# created it in the persistent store. Every existing collection is checked, not
# just the first one, otherwise create_collection raises when the target
# collection exists but is not first in the list.
if any(c.name == collection_name for c in chroma_client.list_collections()):
    chroma_client.delete_collection(name=collection_name)

collection = chroma_client.create_collection(name=collection_name)

In [10]:
# Store each answer as a document with its question type as metadata.
# Ids are generated from the actual subset length (not MAX_ROWS) so the ids,
# documents and metadatas lists always have the same length, even when the
# dataset has fewer than MAX_ROWS rows.
collection.add(
    documents=subset_data[DOCUMENT].tolist(),
    metadatas=[{TOPIC: topic} for topic in subset_data[TOPIC].tolist()],
    ids=[f"id{x}" for x in range(len(subset_data))],
)

/root/.cache/chroma/onnx_models/all-MiniLM-L6-v2/onnx.tar.gz: 100%|██████████| 79.3M/79.3M [00:01<00:00, 69.9MiB/s]


In [11]:
def query_database(query_text, n_results=10):
    """Run a similarity query against the global ChromaDB ``collection``.

    Args:
        query_text: text, or list of texts, forwarded as ``query_texts``.
        n_results: number of nearest documents to return per query text.

    Returns:
        Whatever ``collection.query`` returns, unchanged.
    """
    return collection.query(n_results=n_results, query_texts=query_text)

## 5. Creating the semantic cache system

In [12]:
def init_cache(model_name="all-mpnet-base-v2"):
    """Create the FAISS index and sentence encoder used by the semantic cache.

    The index dimension is read from the encoder instead of being hard-coded,
    so the two can never disagree (768 for the default ``all-mpnet-base-v2``).

    Args:
        model_name: SentenceTransformer model used to embed questions.

    Returns:
        ``(index, encoder)``: a ``faiss.IndexFlatL2`` sized to the encoder's
        embedding width, and the SentenceTransformer encoder.
    """
    # Initialize SentenceTransformer model first: its output width sizes the index.
    encoder = SentenceTransformer(model_name)
    dim = encoder.get_sentence_embedding_dimension()
    if dim is None:
        # The encoder may not report its size (the method can return None);
        # measure it from a single embedding instead.
        dim = len(encoder.encode(["dimension probe"])[0])

    # How to choose index for a particular task https://github.com/facebookresearch/faiss/wiki/Guidelines-to-choose-an-index
    # IndexFlatL2 is an exact (brute-force) L2 index.
    index = faiss.IndexFlatL2(dim)
    if index.is_trained:
        print("Index trained")

    return index, encoder

def retrieve_cache(json_file):
    """Load a persisted semantic cache from *json_file*.

    Returns:
        The decoded JSON object, or — when the file does not exist yet — a fresh
        empty cache made of four parallel lists: ``questions``, ``embeddings``,
        ``answers`` and ``response_text``.
    """
    try:
        handle = open(json_file, "r")
    except FileNotFoundError:
        return {"questions": [], "embeddings": [], "answers": [], "response_text": []}
    with handle:
        return json.load(handle)

def store_cache(json_file, cache):
    """Persist *cache* to *json_file* as JSON.

    The data is written to a sibling temporary file first and then moved over
    *json_file* with ``os.replace``. If serialisation or writing fails midway
    (e.g. a non-JSON-serialisable value), the previous cache file is left
    intact instead of being truncated.

    Raises:
        Whatever ``json.dump`` or the file operations raise; the temporary file
        is removed before re-raising.
    """
    import os

    tmp_path = f"{json_file}.tmp"
    try:
        with open(tmp_path, "w") as file:
            json.dump(cache, file)
        os.replace(tmp_path, json_file)
    except BaseException:
        # Don't leave a half-written temp file behind.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

In [13]:
class SemanticCache:
    """Semantic cache placed in front of the ChromaDB ``collection``.

    Each question is embedded with the encoder from ``init_cache`` and compared
    with previously asked questions in the FAISS index from ``init_cache``.
    When the closest one is within the threshold its stored response is
    returned; otherwise ChromaDB is queried and the new question/answer pair is
    appended to the cache and written to ``json_file``.

    Invariant: entry ``i`` of every list in ``self.cache`` corresponds to row
    ``i`` of ``self.index``.
    """

    def __init__(self, json_file="cache_file.json", threshold=0.35):
        """
        Args:
            json_file: path of the JSON file the cache is loaded from and saved to.
            threshold: largest distance, as reported by the FAISS index, at which
                a cached response is reused. ``faiss.IndexFlatL2`` reports
                *squared* Euclidean distances, so the threshold applies to the
                squared distance; 0 means identical embeddings.
        """
        # Initialize the Faiss index and the sentence encoder.
        self.index, self.encoder = init_cache()

        # Only cached answers whose distance is <= this threshold are reused.
        self.euclidean_threshold = threshold

        self.json_file = json_file
        self.cache = retrieve_cache(self.json_file)

        # Re-index embeddings persisted by earlier sessions. Without this the
        # index starts empty while the cache lists do not, so FAISS row ids no
        # longer line up with cache rows and the wrong response is returned.
        stored_embeddings = self.cache.get("embeddings", [])
        if stored_embeddings:
            self.index.add(np.asarray(stored_embeddings, dtype="float32"))

    @staticmethod
    def _print_elapsed(start_time):
        """Print the wall-clock time elapsed since *start_time*."""
        print(f"Time taken: {time.time() - start_time:.3f} seconds")

    def ask(self, question: str) -> str:
        """Answer *question* from the cache when possible, otherwise from ChromaDB.

        Returns:
            The cached ``response_text`` of the nearest previous question if its
            distance is within the threshold, else the top ChromaDB document
            (which is then added to the cache and persisted).

        Raises:
            RuntimeError: wrapping any error raised while encoding, searching,
                querying or persisting; the original exception is chained.
        """
        start_time = time.time()
        try:
            # First we obtain the embeddings corresponding to the user question
            embedding = self.encoder.encode([question])

            # Nearest neighbour among previously cached questions.
            D, I = self.index.search(embedding, 1)
            distance = float(D[0][0])
            row_id = int(I[0][0])

            # A negative row id means no neighbour was found (e.g. empty index).
            if distance >= 0 and row_id >= 0 and distance <= self.euclidean_threshold:
                response_text = self.cache["response_text"][row_id]
                print("Answer recovered from Cache. ")
                print(f"{distance:.3f} smaller than {self.euclidean_threshold}")
                print(f"Found cache in row: {row_id} with score {distance:.3f}")
                print(f"response_text: {response_text}")
                self._print_elapsed(start_time)
                return response_text

            # Cache miss (no neighbour, or distance above threshold): ask ChromaDB.
            answer = query_database([question], 1)
            response_text = answer["documents"][0][0]

            self.cache["questions"].append(question)
            self.cache["embeddings"].append(embedding[0].tolist())
            self.cache["answers"].append(answer)
            self.cache["response_text"].append(response_text)

            print("Answer recovered from ChromaDB. ")
            print(f"response_text: {response_text}")

            # Appending here keeps index row i aligned with cache entry i.
            self.index.add(embedding)
            store_cache(self.json_file, self.cache)
            self._print_elapsed(start_time)

            return response_text

        except Exception as e:
            raise RuntimeError(f"Error during 'ask' method: {e}") from e

## 6. Testing out the semantic cache class

In [14]:
# Build the semantic cache, loaded from and persisted to "4cache.json".
cache = SemanticCache(json_file="4cache.json")

Index trained


modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/10.6k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/571 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/438M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/363 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

1_Pooling/config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

In [15]:
# First lookup; the output below shows it was answered from ChromaDB.
results = cache.ask(question="How do vaccines work?")

Answer recovered from ChromaDB. 
response_text: Summary : Shots may hurt a little, but the diseases they can prevent are a lot worse. Some are even life-threatening. Immunization shots, or vaccinations, are essential. They protect against things like measles, mumps, rubella, hepatitis B, polio, tetanus, diphtheria, and pertussis (whooping cough). Immunizations are important for adults as well as children.    Your immune system helps your body fight germs by producing substances to combat them. Once it does, the immune system "remembers" the germ and can fight it again. Vaccines contain germs that have been killed or weakened. When given to a healthy person, the vaccine triggers the immune system to respond and thus build immunity.     Before vaccines, people became immune only by actually getting a disease and surviving it. Immunizations are an easier and less risky way to become immune.     NIH: National Institute of Allergy and Infectious Diseases
Time taken: 1.184 seconds


In [16]:
# The text returned here is reused as the context of the LLM prompt below.
results = cache.ask(question="Explain briefly what is a Sydenham chorea")

Answer recovered from ChromaDB. 
response_text: Chorea-acanthocytosis is one of a group of conditions called the neuroacanthocytoses that involve neurological problems and abnormal red blood cells. The condition is characterized by involuntary jerking movements (chorea), abnormal star-shaped red blood cells (acanthocytosis), and involuntary tensing of various muscles (dystonia), such as those in the limbs, face, mouth, tongue, and throat. Chorea-acanthocytosis is caused by mutations in the VPS13A gene and is inherited in an autosomal recessive manner. There are currently no treatments to prevent or slow the progression of chorea-acanthocytosis; treatment is symptomatic and supportive.
Time taken: 0.171 seconds


## 7. Loading the model

In [17]:
# `cuda` and `torch` are already imported in the imports cell at the top of the
# notebook; re-importing them here was redundant.
# Use the current CUDA device when a GPU is available, otherwise the CPU.
device = f"cuda:{cuda.current_device()}" if cuda.is_available() else "cpu"

In [19]:
from transformers import AutoModelForCausalLM, AutoTokenizer

# Gemma 2B ("-it" checkpoint), loaded in bfloat16 onto `device` from the previous cell.
model_id = "google/gemma-2b-it"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map=device,
)

tokenizer_config.json:   0%|          | 0.00/2.16k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/888 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/627 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/13.5k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.95G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/67.1M [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/137 [00:00<?, ?B/s]

## 8. Creating the extended prompt

In [20]:
question_def = "Write in 20 words what is a Sydenham chorea."
# Prompt = text retrieved by the last `cache.ask` call (`results`) + the question.
prompt_template = "Relevant context: {}\n\n The user's question: {}".format(results, question_def)

In [21]:
# Tokenize the prompt into PyTorch tensors and move them to `device`.
# Despite its name, `input_ids` holds the tokenizer's whole output mapping
# (it is unpacked with ** into `model.generate`), not only the token ids.
input_ids = tokenizer(text=prompt_template, return_tensors="pt").to(device)

In [22]:
# Generate up to 256 new tokens. The decoded text echoes the prompt and keeps
# special tokens such as <bos>/<eos>, as the output below shows.
outputs = model.generate(max_new_tokens=256, **input_ids)
print(tokenizer.decode(token_ids=outputs[0]))

<bos>Relevant context: Chorea-acanthocytosis is one of a group of conditions called the neuroacanthocytoses that involve neurological problems and abnormal red blood cells. The condition is characterized by involuntary jerking movements (chorea), abnormal star-shaped red blood cells (acanthocytosis), and involuntary tensing of various muscles (dystonia), such as those in the limbs, face, mouth, tongue, and throat. Chorea-acanthocytosis is caused by mutations in the VPS13A gene and is inherited in an autosomal recessive manner. There are currently no treatments to prevent or slow the progression of chorea-acanthocytosis; treatment is symptomatic and supportive.

 The user's question: Write in 20 words what is a Sydenham chorea.

The context does not provide any information about a Sydenham chorea, so I cannot answer this question from the provided context.<eos>
