# Retrieval-Augmented Generation (RAG)

## 1. Install and load necessary libraries

In [None]:
! pip install chromadb langchain langchain_community sentence_transformers xmltodict accelerate rank_bm25 Bio


In [None]:
! pip install -U bitsandbytes



In [None]:
from huggingface_hub import login
login(token="INSERT HERE YOUR HF TOKEN")

In [None]:
# Libraries for models and RAG (HF, LangChain, ChromaDB)
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import chromadb
import chromadb.config
from langchain.vectorstores import Chroma
from langchain_community.embeddings.sentence_transformer import (
    SentenceTransformerEmbeddings,
)
from langchain_text_splitters import NLTKTextSplitter
from sentence_transformers import SentenceTransformer, CrossEncoder
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.retrievers import BM25Retriever, EnsembleRetriever

# NLTK libraries for stopwords
import nltk
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
import string

# Library to retrieve PMC articles
from Bio import Entrez, Medline


# Various libraries
import json
import re
import time
import urllib.request
import requests
from time import sleep
import xmltodict
import numpy as np

In [None]:
Entrez.email = 'INSERT YOUR EMAIL HERE' # Insert the email to use Entrez and retrieve PMC articles

## 2. Models loading

Embedding model (bge-base-en) to create the embeddings of the retrieved articles

In [None]:
embedding_model = SentenceTransformer("BAAI/bge-base-en")


Load the same model using HuggingFaceEmbeddings (to be used during the RAG process)

In [None]:
rag_model = HuggingFaceEmbeddings(model_name="BAAI/bge-base-en")


Load the reranker (mxbai-rerank-large-v1)

In [None]:
rerank_model = CrossEncoder(model_name = 'mixedbread-ai/mxbai-rerank-large-v1')


Load the model to be used for generation (Mistral-7B-Instruct-v0.2)

In [None]:
llm = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-Instruct-v0.2", load_in_4bit=True,
    device_map='auto' # let accelerate place the model on the available device(s)
)
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.



In [None]:
llm_generation_pipeline = pipeline(
    model=llm,
    tokenizer=tokenizer,
    task="text-generation",
    return_full_text=False,
    max_new_tokens=2000)

## 3. Functions to process the text

Download the NLTK stopword list and the Punkt tokenizer data

In [None]:
nltk.download('stopwords')
nltk.download('punkt')
nltk.download('punkt_tab')

[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Unzipping corpora/stopwords.zip.
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.
[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt_tab.zip.


True

The user query determines what is searched in the PMC database. We therefore remove stopwords and punctuation from the query, so that common question words (such as "What") and symbols (such as the question mark) are not included in the database search.

In [None]:
def remove_stopwords_and_punctuation(sentence):

    """ Function that takes a sentence (string) and returns it without stopwords and punctuation """

    stop_words = set(stopwords.words('english'))
    word_tokens = word_tokenize(sentence)

    filtered_sentence = [word for word in word_tokens if word.lower() not in stop_words and word not in string.punctuation]

    filtered_string_sentence = ' '.join(filtered_sentence)

    return filtered_string_sentence


Examples:

In [None]:
sentence = 'What is the role of RNase L in host-pathogen interaction and immune signaling?'
result = remove_stopwords_and_punctuation(sentence)
print(result)

role RNase L host-pathogen interaction immune signaling


In [None]:
sentence = 'What is the role of STAT3 in the field of Inborn Errors of Immunity?'
result = remove_stopwords_and_punctuation(sentence)
print(result)

role STAT3 field Inborn Errors Immunity


We create the text splitter to divide article paragraphs into chunks of at most 2000 characters (with a 100-character overlap) when they are too long for the embedding model.

In [None]:
text_splitter = NLTKTextSplitter(chunk_size=2000, chunk_overlap = 100)
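
To make the effect of `chunk_size` and `chunk_overlap` concrete, here is a minimal pure-Python sketch of sentence-based chunking. It is not the `NLTKTextSplitter` implementation: the regex sentence split and the function name `split_into_chunks` are assumptions made for illustration, and a single sentence longer than `chunk_size` is kept whole.

```python
import re

def split_into_chunks(text, chunk_size=2000, chunk_overlap=100):
    """Greedily pack sentences into chunks of at most chunk_size characters."""
    # Naive split after '.', '!' or '?' followed by whitespace
    # (a real sentence tokenizer handles abbreviations etc. better).
    sentences = [s for s in re.split(r"(?<=[.!?])\s+", text.strip()) if s]
    chunks, current = [], []
    for sentence in sentences:
        candidate = " ".join(current + [sentence])
        if current and len(candidate) > chunk_size:
            # The current chunk is full: store it and start a new one.
            chunks.append(" ".join(current))
            # Carry over trailing sentences that fit in the overlap budget.
            overlap, total = [], 0
            for previous in reversed(current):
                if total + len(previous) > chunk_overlap:
                    break
                overlap.insert(0, previous)
                total += len(previous) + 1
            current = overlap
        current.append(sentence)
    if current:
        chunks.append(" ".join(current))
    return chunks

demo_text = " ".join(f"Sentence {i}." for i in range(10))
demo_chunks = split_into_chunks(demo_text, chunk_size=40, chunk_overlap=12)
for demo_chunk in demo_chunks:
    print(len(demo_chunk), demo_chunk)
```

Consecutive chunks share their boundary sentence, so a statement cut at a chunk border still appears with some of its context in the next chunk.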

## 4. RAG System

First we define a function to call the API

In [None]:
def call_api(url):
    """Function that takes a URL and returns the raw content of the web page"""
    time.sleep(1)
    url = url.replace(' ', '+')
    print(url)

    req = urllib.request.Request(url)
    with urllib.request.urlopen(req) as response:
        call = response.read()

    return call
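
Remote APIs occasionally drop connections, so a call like the one above may benefit from retries. The following is a minimal sketch of retrying with exponential backoff; `call_with_retries` and its parameters are hypothetical names introduced here, not part of the notebook's pipeline.

```python
import time

def call_with_retries(fetch, url, max_attempts=3, base_delay=1.0, sleep=time.sleep):
    """Call fetch(url), retrying on network errors with exponentially growing delays."""
    for attempt in range(1, max_attempts + 1):
        try:
            return fetch(url)
        except OSError:
            # urllib.error.URLError and connection resets are both OSError subclasses.
            if attempt == max_attempts:
                raise  # give up after the last attempt
            sleep(base_delay * 2 ** (attempt - 1))  # 1s, 2s, 4s, ...
```

Under these assumptions it could wrap the function defined above as `call_with_retries(call_api, url)`.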

How the RAG system works:
1. The question is processed using the function defined earlier (stopwords and punctuation are removed)
2. The processed question is passed as query to search on the NCBI PMC database (only the abstracts and the PMC IDs are retrieved)
3. We try to retrieve the full texts of the articles using the BioC API (only Open Access articles are available). If one article is not available we keep only its abstract
4. We chunk the paragraphs of the articles (if needed), we create their embeddings and we store them in a temporary Chroma DB collection
5. Chunks are retrieved from the collection with an ensemble of two retrievers: a keyword retriever (BM25) and a semantic retriever (similarity of embeddings computed with the embedding model)
6. The reranker reranks them and the top chunks are given as context to the LLM that generates the answer.
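
Step 5 merges two ranked lists into one. LangChain's `EnsembleRetriever` is documented as using weighted Reciprocal Rank Fusion for this; the sketch below illustrates that idea in plain Python. The function name `weighted_rrf` and the constant `c=60` are assumptions for illustration, not a copy of the library code.

```python
from collections import defaultdict

def weighted_rrf(rankings, weights, c=60):
    """Fuse several ranked lists: each list adds weight / (rank + c) to a document's score."""
    scores = defaultdict(float)
    for ranking, weight in zip(rankings, weights):
        for rank, doc in enumerate(ranking, start=1):
            scores[doc] += weight / (rank + c)
    # Documents with the highest fused score come first.
    return sorted(scores, key=scores.get, reverse=True)

keyword_ranking = ["a", "b", "c"]   # e.g. the order returned by BM25
semantic_ranking = ["c", "a", "d"]  # e.g. the order returned by embedding similarity
fused = weighted_rrf([keyword_ranking, semantic_ranking], weights=[0.6, 0.4])
print(fused)  # ['a', 'c', 'b', 'd']
```

Documents found by both retrievers ("a" and "c") rise to the top, which is the reason for combining keyword and semantic search before reranking.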


In [None]:
def rag_system(query):

  """
  Function that performs the RAG process given a question/query
  input:
    query - str: the question for the LLM
  outputs:
    answer_base - str: the answer generated without RAG
    answer_rag - str: the answer generated with RAG
    final_docs - list: the chunks of the articles that have been given as context to the LLM
    selected_ids - list: the IDs of those chunks (PMC ID of the article plus a chunk-number suffix)

  """

  search_term = remove_stopwords_and_punctuation(query)

  # Search PubMed Central (PMC) for articles
  handle = Entrez.esearch(db="pmc", term=search_term, retmax=60) # we retrieve at most 60 articles
  record = Entrez.read(handle)
  id_list = record["IdList"] # we extract the PMC IDs

  # Fetch the text
  handle = Entrez.efetch(db="pmc", id=id_list, rettype="medline", retmode="text")
  records = Medline.parse(handle)

  # We create a dictionary in which the keys are the PMC IDs while the values are lists containing the paragraphs of the articles.
  articles_dictionary = {}
  for record in records:
    articles_dictionary[record.get("PMC", "?")] = [record.get("AB", "?")]

  # We create a string with all the PMC IDs to search for the articles' full texts using the BioC API
  ids_dictionary = ",".join(articles_dictionary.keys())

  url_full_text = f"https://www.ncbi.nlm.nih.gov/research/bionlp/RESTful/pmcoa.cgi/BioC_json/{ids_dictionary}/ascii"
  call_pmc=call_api(url_full_text) # we call the api
  json_pmc = json.loads(call_pmc)

  # For the articles that we have found we keep the paragraphs which are longer than 300 characters (so that they have enough information)
  for article in json_pmc:
    paragraphs = []
    for result in article['documents'][0]['passages']:
      if result["infons"]['type']=='paragraph' and result['infons']['section_type'] not in ['AUTH_CONT','COMP_INT'] and len(result['text'])>300:
        paragraphs.append(result['text'])
    articles_dictionary[article['documents'][0]['id']]=paragraphs # we substitute the abstract with all the paragraphs (if we found the article)

  # We remove the articles that contain no text (no paragraphs, or only the "?" placeholder for a missing abstract)
  articles_dictionary = {key: value for key, value in articles_dictionary.items() if value and value != ["?"]}

  # We create a list of chunks. The list contains sublists of the form [pmc_id, chunk_text] so afterwards we can keep track of what chunk has been used.
  chunks_tot = []
  for key, value in articles_dictionary.items():
    chunks = []
    for el in value:
      if len(el)>2000: # if a paragraph is longer than 2000 characters we chunk it.
        chunk = text_splitter.split_text(el)
      else:
        chunk = [el]
      chunks.extend(chunk)
    for i in range(len(chunks)):
      chunks_tot.append([f"{key}-{i+1}", chunks[i]])

  # a suffix is added to the PMC IDs to indicate which part of the article is used.
  chunks_list = [chunk[1] for chunk in chunks_tot]
  pmc_ids = [chunk[0] for chunk in chunks_tot]

  # We create a chroma collection
  persistent_client = chromadb.PersistentClient(path="")
  collection = persistent_client.get_or_create_collection("temp_collection") # create a temporary collection

  # embed the chunks and transform the embeddings to a list
  embedd = embedding_model.encode(chunks_list, normalize_embeddings=True)
  emb_list = np.array(embedd).tolist()

  # add the chunks and their embeddings to the db
  collection.add(
          embeddings = emb_list,
          documents = chunks_list,
          ids = pmc_ids
        )

  db = Chroma(
        client=persistent_client,
        collection_name="temp_collection",
        embedding_function=rag_model,
    )

  full_collection = collection.get()

  # create the keywords retriever
  keywords_retriever = BM25Retriever.from_texts(chunks_list)
  keywords_retriever.k = 50 # set top k documents in keywords retriever
  # create the semantic retriever
  semantic_retriever = db.as_retriever(
      search_type="similarity",
      search_kwargs={'k': 50}) # search top_k most similar documents to the query (using embedding model)
  # create the ensemble retriever
  ensemble_retriever = EnsembleRetriever(
    retrievers=[keywords_retriever, semantic_retriever],
    weights=[0.6, 0.4])

  # define the query for the retrieval
  docs = ensemble_retriever.get_relevant_documents(search_term) # retrieve relevant documents
  query_with_docs = [(search_term, doc.page_content) for doc in docs] # create pairs of (query, document)
  scores = rerank_model.predict(query_with_docs) # creates similarity score between query and each document
  ranking = sorted(list(zip(docs, scores)), key=lambda x: x[1], reverse = True) # sort according to scores
  rer_k=6 # number of chunks to provide as context
  final_docs = [doc.page_content for doc, _ in ranking[:rer_k]] # take the content of the top chunks (fewer if fewer were retrieved)

  # create a single string with the selected chunks
  fin_docs=f"Document: {final_docs[0]}"
  for i in final_docs[1:]:
    fin_docs+=f'\n\nDocument: {i}'

  selected_ids = []
  for doc in final_docs:
    for i in range(len(full_collection['documents'])):
      if doc==full_collection['documents'][i]:
        selected_ids.append(full_collection['ids'][i])

  # delete the temporary collection
  persistent_client.delete_collection('temp_collection')

  prompt_rag=f"""Answer the question using the following context:
{fin_docs}

Analyze carefully the context and come with an exhaustive and complete answer to the following question.

Question: {query}

Answer:

"""
  answer_rag = llm_generation_pipeline(prompt_rag)[0]['generated_text']

  answer_base = llm_generation_pipeline(query+"\n\n")[0]['generated_text']

  output = f"""Answer without RAG:
{answer_base}



Answer with RAG:
{answer_rag}



Retrieved abstracts (PMC IDs {','.join(selected_ids)}):

{fin_docs}
"""

  print(output) # print the full output

  return answer_base, answer_rag, final_docs, selected_ids
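
Because each returned ID carries a chunk-number suffix (for example `PMC123-2`), several IDs can point to the same article. The helper below is a small sketch for recovering the distinct article IDs; `article_ids_from_chunk_ids` is a hypothetical name and the example IDs are made up.

```python
def article_ids_from_chunk_ids(chunk_ids):
    """Strip the '-<chunk number>' suffix and keep each article ID once, in order."""
    article_ids = [chunk_id.rsplit("-", 1)[0] for chunk_id in chunk_ids]
    # dict.fromkeys removes duplicates while preserving first-seen order.
    return list(dict.fromkeys(article_ids))

example_ids = ["PMC1-3", "PMC2-1", "PMC1-7"]  # made-up chunk IDs
print(article_ids_from_chunk_ids(example_ids))  # ['PMC1', 'PMC2']
```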



In [None]:
print(rag_system("What is the role of RNase L in host-pathogen interaction and immune signaling?"))

https://www.ncbi.nlm.nih.gov/research/bionlp/RESTful/pmcoa.cgi/BioC_json/PMC11559068,PMC11551408,PMC11549945,PMC11536464,PMC11528449,PMC11519428,PMC11494902,PMC11435178,PMC11414379,PMC11392916,PMC11389394,PMC11578319,PMC11545081,PMC11508327,PMC11479892,PMC11472168,PMC11447406,PMC11413781,PMC11406996,PMC11374765,PMC11370967,PMC11343454,PMC11360189,PMC11305385,PMC11358953,PMC11359028,PMC11353914,PMC11328406,PMC11276256,PMC11281336,PMC11280125,PMC10402074,PMC11229848,PMC11210920,PMC11189941,PMC11201578,PMC11209522,PMC11141478,PMC11171315,PMC11209158,PMC11234502,PMC11125849,PMC11093365,PMC11126147,PMC11125669,PMC11075854,PMC11125802,PMC11092344,PMC11083791,PMC11058653,PMC11054756,PMC11006213,PMC11018982,PMC11257101,PMC11018997,PMC11048616,PMC10968337,PMC10939522,PMC10925002,PMC10914904/ascii




Answer without RAG:
RNase L is a ribonuclease enzyme that plays a crucial role in the host immune response against various viruses. It is primarily involved in the degradation of viral RNA, thereby inhibiting viral replication. RNase L is activated through several signaling pathways, including the interferon (IFN) system, which is the major antiviral defense mechanism in eukaryotic cells.

The activation of RNase L occurs through a complex process involving multiple proteins and post-translational modifications. The enzyme is kept in an inactive form in the cytoplasm of the cell by binding to an inhibitor protein, RNase L inhibitor (RLI). Upon viral infection, IFN production is triggered, leading to the activation of several signaling pathways, including the double-stranded RNA (dsRNA) sensor, melanoma differentiation-associated gene 5 (MDA-5), and Toll-like receptor 3 (TLR3). These sensors detect viral RNA and initiate a cascade of signaling events, ultimately leading to the activatio

In [None]:
print(rag_system("What is the role of STAT3 in the field of Inborn Errors of Immunity?"))

https://www.ncbi.nlm.nih.gov/research/bionlp/RESTful/pmcoa.cgi/BioC_json/PMC11565648,PMC11559068,PMC11554218,PMC11552499,PMC11493558,PMC11431085,PMC11413684,PMC11555470,PMC11522874,PMC11505985,PMC11432681,PMC11396297,PMC11374039,PMC11322596,PMC11345217,PMC11301933,PMC11351677,PMC7616361,PMC11277571,PMC11263878,PMC11211259,PMC11109925,PMC11106885,PMC11121590,PMC11094350,PMC11039790,PMC11011008,PMC11002082,PMC10991685,PMC10972870,PMC11191067,PMC10961426,PMC10873437,PMC10792788,PMC10800401,PMC10772138,PMC10792266,PMC10748506,PMC10740682,PMC10707526,PMC10683363,PMC10703449,PMC10704817,PMC10670380,PMC10672531,PMC10850682,PMC10645063,PMC10589750,PMC10621649,PMC10574116,PMC10615875,PMC10544547,PMC10605219,PMC10600960,PMC10621644,PMC10526783,PMC10538643,PMC10560124,PMC10505736,PMC10470049/ascii




Answer without RAG:
Inborn Errors of Immunity (IEI) are genetic disorders that impair the development or function of the immune system. STAT3 (Signal Transducer and Activator of Transcription 3) is a transcription factor that plays a crucial role in the immune response, particularly in the JAK-STAT signaling pathway.

STAT3 is involved in the regulation of various immune functions, including the production of interferons, inflammatory cytokines, and the differentiation of T helper 17 (Th17) cells. In IEI, mutations in STAT3 can lead to impaired immune function and increased susceptibility to infections.

One specific IEI caused by STAT3 mutations is known as Loss of Function (LOF) STAT3 deficiency. This disorder is characterized by a combined immunodeficiency, which affects both the adaptive and innate immune systems. Patients with LOF STAT3 deficiency have a reduced ability to produce interferons and Th17 cells, leading to recurrent bacterial and viral infections.

Moreover, gain-of-f

In [None]:
print(rag_system("What possible genetic diagnosis would you hypothesize in a young patient with signs of vasculitis and anemia? And without anemia?"))

https://www.ncbi.nlm.nih.gov/research/bionlp/RESTful/pmcoa.cgi/BioC_json/PMC11586569,PMC11412324,PMC11395540,PMC11586549,PMC11559324,PMC11391898,PMC11384718,PMC11235923,PMC11235737,PMC11157553,PMC11094575,PMC11094350,PMC11043565,PMC11191887,PMC10913204,PMC11176220,PMC10917318,PMC10727565,PMC10775241,PMC10695392,PMC10671046,PMC10672691,PMC10810974,PMC10580657,PMC10594506,PMC10483180,PMC10661818,PMC10586544,PMC10323075,PMC10264219,PMC10371271,PMC10242122,PMC10198255,PMC10198253,PMC10092953,PMC10015523,PMC10000335,PMC9870665,PMC10953327,PMC9701571,PMC9628617,PMC9688279,PMC9585390,PMC9708457,PMC9569470,PMC9466322,PMC9902739,PMC9592650,PMC9536822,PMC9421640,PMC9426798,PMC9379246,PMC9366260,PMC9316835,PMC9429973,PMC9152586,PMC9073161,PMC9017428,PMC8965239,PMC8987778/ascii


