# Building and Evaluating the Retrieval Augmented Generation (RAG) Flow 

## PART 1: Building the RAG Pipeline

### Import Libraries and Read Data Source

In [1]:
from IPython.display import display, HTML
import time
import os
import json
import random
import pandas as pd
import minsearch
from tqdm.auto import tqdm
import google.auth
from google.oauth2 import service_account
import vertexai
from vertexai.generative_models import GenerativeModel
from dotenv import load_dotenv, find_dotenv
_ = load_dotenv(find_dotenv())

# Vertex AI configuration for Gemini. GCP_PROJECT_ID is required; the
# service-account key path and the region can be overridden through the
# environment (GCP_CREDENTIALS_PATH / GCP_LOCATION) so a machine-specific key
# file does not have to be hard-coded. The previous values remain the defaults.
PROJECT_ID = os.environ['GCP_PROJECT_ID']
GCP_LOCATION = os.environ.get('GCP_LOCATION', 'us-central1')
CREDENTIALS_PATH = os.environ.get(
    'GCP_CREDENTIALS_PATH',
    "../pacific-ethos-428312-n5-eb4864ff3add.json",
)
credentials = service_account.Credentials.from_service_account_file(CREDENTIALS_PATH)
vertexai.init(project=PROJECT_ID, credentials=credentials, location=GCP_LOCATION)

# Seed Python's `random` module so the random boost-parameter search later in
# the notebook gives the same result under Restart & Run All.
RANDOM_SEED = 42
random.seed(RANDOM_SEED)

# Path to the article data (JSON Lines: one record per line)
file_path = '../data/bq-results-20240829-041517-1724904953827.jsonl'

# Read the JSONL file directly into a Pandas DataFrame
df = pd.read_json(file_path, lines=True)

### Index Data with Minsearch

In [14]:
# Full-corpus search index: free-text matching over the article fields listed in
# text_fields, with `id` also registered as a keyword (exact-match) field.
index = minsearch.Index(
    keyword_fields=['id'],
    text_fields=['abstract', 'authors', 'keywords', 'organization_affiliated', 'title', 'id'],
)

In [15]:
# index.fit receives one dict per DataFrame row (column name -> value).
documents = df.to_dict('records')
index.fit(documents)

<minsearch.Index at 0x1fc70c61a60>

### Retrieval Function

In [16]:
def search(query, boost=None, num_results=10):
    """Retrieve articles relevant to ``query`` from the global minsearch ``index``.

    Args:
        query: Free-text question to search for.
        boost: Optional per-field boost weights passed as ``boost_dict``;
            ``None`` (the default) means no boosting, as before.
        num_results: Maximum number of documents to return (default 10).

    Returns:
        Whatever ``index.search`` returns (the document dicts that
        ``build_prompt`` formats into the context).
    """
    if boost is None:
        boost = {}  # fresh dict per call, never a shared mutable default

    return index.search(
        query=query,
        filter_dict={},
        boost_dict=boost,
        num_results=num_results,
    )

In [17]:
#Tests
# query = 'Which articles discuss correlation between smoking and pregnancy?'
# results = search(query)
# query = "Provide a summary of insights found in articles that discuss pregnant women exposured to smoking"
# results = search(query)

### Prompt Building

In [2]:
# Answer-generation prompt. `{question}` and `{context}` are filled in by
# build_prompt; the context is entry_template rendered once per search result.
prompt_template = """
You're an experienced biomedical researcher. Answer the QUESTION based only on the CONTEXT from our Biomedical Research database.
Use only the facts from the CONTEXT when answering the QUESTION. Your answer must be an accurate summary and not an exact copy of the text. 
However, article titles, authors, keywords, and organizations must be exact from the CONTEXT. 
Do NOT include any article that does NOT exist in the CONTEXT.
Do NOT include anything that does NOT answer the QUESTION.
Do NOT repeat ANYTHING that you have previously said in your response.

QUESTION: {question}

CONTEXT:
{context}
""".strip()

# One context entry per retrieved document. Placeholders must be keys of each
# search-result dict: str.format(**doc) raises KeyError for a missing key and
# ignores extra keys such as `id`.
entry_template = """
abstract: {abstract}
authors: {authors} 
keywords: {keywords} 
organization_affiliated: {organization_affiliated} 
title: {title}
""".strip()

def build_prompt(query, search_results):
    """Render the RAG prompt for ``query`` with the retrieved documents as context.

    Each document is formatted with ``entry_template`` and followed by a blank
    line. The joined context and the question are substituted into
    ``prompt_template``, and the result is stripped of surrounding whitespace.
    """
    context = "".join(entry_template.format(**doc) + "\n\n" for doc in search_results)
    return prompt_template.format(question=query, context=context).strip()

In [19]:
#Test
# print(build_prompt(query, results))
# prompt = build_prompt(query, results)

### Calling LLM

In [3]:
def llm(prompt, model="gemini-1.5-flash-001"):
    """Send ``prompt`` to a Vertex AI Gemini model and return the response text.

    ``model`` is the model name; a new ``GenerativeModel`` client is created on
    every call.
    """
    client = GenerativeModel(model)
    return client.generate_content(prompt).text

In [21]:
#Test
# answer = llm(prompt)
# print(answer)

### RAG Pipeline

In [24]:
def rag(query, model='gemini-1.5-flash-001'):
    """Answer ``query`` end to end: ``search`` -> ``build_prompt`` -> ``llm``."""
    prompt = build_prompt(query, search(query))
    return llm(prompt, model=model)

### Sanity Check

In [25]:
# End-to-end smoke test of the un-boosted RAG pipeline on a single question.
question = ('Provide a summary of insights found in articles that discuss '
            'correlation between pregnancy and smoking')
answer = rag(question)
print(answer)

The article "Interaction of glutathione S-transferase polymorphisms and tobacco smoking during pregnancy in susceptibility to autism spectrum disorders" by Vanja Mandic-Maravic et al. discusses a possible correlation between maternal smoking during pregnancy and the risk of Autism Spectrum Disorders (ASD).  However, the study found that maternal smoking during pregnancy did not increase the risk of ASD. 



## PART 2: Evaluating Retrieval

### Read and Join Data Source and Ground Truth Data

In [2]:
# Re-read the article corpus (same JSONL file as in Part 1) for the retrieval evaluation.
df = pd.read_json(path_or_buf=file_path, lines=True)

In [5]:
# Ground-truth retrieval set: each row pairs an article `id` with a question about it.
ground_truth_df = pd.read_csv(filepath_or_buffer='../data/ground-truth-retrieval.csv')
ground_truth_df.head()

Unnamed: 0,id,question
0,c3ea29df-6683-4443-a2c7-3f027137c1d8,What are some systemic factors that influence ...
1,c3ea29df-6683-4443-a2c7-3f027137c1d8,What types of systemic factors should be consi...
2,c3ea29df-6683-4443-a2c7-3f027137c1d8,What is the role of systemic factors in determ...
3,c3ea29df-6683-4443-a2c7-3f027137c1d8,How do systemic factors impact treatment plann...
4,c3ea29df-6683-4443-a2c7-3f027137c1d8,"In the context of periodontitis reassessment, ..."


In [6]:
# One dict per ground-truth row ({'id': ..., 'question': ...}); `evaluate` iterates over these.
ground_truth = ground_truth_df.to_dict('records')
ground_truth[0]

{'id': 'c3ea29df-6683-4443-a2c7-3f027137c1d8',
 'question': 'What are some systemic factors that influence treatment decisions for residual periodontal probing depths?'}

In [7]:
# Keep only the articles that have ground-truth questions.
in_ground_truth = df['id'].isin(set(ground_truth_df['id']))
df_sample = df.loc[in_ground_truth].reset_index(drop=True)
print('Evaluation data shape:', df_sample.shape)
print('Number of unique ids:', df_sample['id'].nunique(dropna=False))

Evaluation data shape: (200, 6)
Number of unique ids: 200


### Index Evaluation Data with Minsearch

In [8]:
# Rebuild the index over the evaluation subset only, with the same field setup as Part 1.
documents = df_sample.to_dict('records')
index = minsearch.Index(
    keyword_fields=['id'],
    text_fields=['abstract', 'authors', 'keywords', 'organization_affiliated', 'title', 'id'],
)
index.fit(documents)

<minsearch.Index at 0x21b4ef918e0>

### Implement the Evaluation Functions

In [18]:
def hit_rate(relevance_total):
    """Fraction of queries whose result list contains at least one relevant document.

    Args:
        relevance_total: One list per query; entry ``k`` is True when the
            ``k``-th returned document is the expected one.

    Returns:
        Hit rate in [0, 1]. Returns 0.0 for an empty input instead of raising
        ZeroDivisionError.
    """
    if not relevance_total:
        return 0.0

    hits = sum(1 for line in relevance_total if True in line)
    return hits / len(relevance_total)

def mrr(relevance_total):
    """Mean Reciprocal Rank over queries.

    Each query scores ``1 / rank`` (1-based) of its FIRST relevant document, or
    0 if none of its results is relevant. Scores are averaged over all queries.

    Args:
        relevance_total: One list per query; entry ``k`` is True when the
            ``k``-th returned document is the expected one.

    Returns:
        MRR in [0, 1]; 0.0 for an empty input.
    """
    if not relevance_total:
        return 0.0

    total_score = 0.0
    for line in relevance_total:
        for rank, is_relevant in enumerate(line, start=1):
            if is_relevant:
                total_score += 1 / rank
                # Only the first relevant hit counts. Previously every relevant
                # hit was added, so a single query could score above 1.
                break

    return total_score / len(relevance_total)

def minsearch_search(query, boost=None):
    """Search the global minsearch ``index`` (top 10 results, no filters).

    Args:
        query: Free-text question.
        boost: Optional per-field boost weights passed as ``boost_dict``;
            ``None`` means no boosting. Accepting it here gives this definition
            the same signature as the redefinition in the parameter-search
            section, so cells behave the same whichever one is active.

    Returns:
        Whatever ``index.search`` returns (the documents ``evaluate`` scores by ``'id'``).
    """
    if boost is None:
        boost = {}

    return index.search(
        query=query,
        filter_dict={},
        boost_dict=boost,
        num_results=10,
    )

def evaluate(ground_truth, search_function):
    """Score a retrieval function against ground-truth records.

    ``search_function`` is called with each whole ground-truth record and must
    return documents that have an ``'id'`` key. A returned document is relevant
    when its id equals the record's id.

    Returns:
        dict with ``'hit_rate'`` and ``'mrr'`` computed over all records.
    """
    def _relevance(record):
        expected_id = record['id']
        return [doc['id'] == expected_id for doc in search_function(record)]

    relevance_total = [_relevance(record) for record in tqdm(ground_truth)]

    return {
        'hit_rate': hit_rate(relevance_total),
        'mrr': mrr(relevance_total),
    }

### Compute the Evaluation Metrics

In [8]:
evaluate(ground_truth, search_function=lambda record: minsearch_search(record['question']))

  0%|          | 0/1000 [00:00<?, ?it/s]

{'hit_rate': 0.959, 'mrr': 0.8424214285714289}

### Finding the Best Boost Parameters

In [17]:
def simple_optimize(param_ranges, objective_function, n_iterations=10, rng=None, maximize=True):
    """Random search over a box of parameters.

    Each iteration draws one value per parameter: an integer from
    ``randint(min_val, max_val)`` when both bounds are ints, otherwise a float
    from ``uniform(min_val, max_val)``. The objective is evaluated on every draw.

    Args:
        param_ranges: Mapping ``name -> (min_val, max_val)``.
        objective_function: Callable taking the parameter dict and returning a score.
        n_iterations: Number of random draws to evaluate.
        rng: Optional ``random.Random`` instance for reproducible draws. Defaults
            to the module-level ``random`` functions, as before.
        maximize: True (default, previous behavior) keeps the highest score;
            False keeps the lowest.

    Returns:
        ``(best_params, best_score)``. If ``n_iterations`` is 0 this is
        ``(None, -inf)``, or ``(None, inf)`` when minimizing.
    """
    draw = rng if rng is not None else random
    best_params = None
    best_score = float('-inf') if maximize else float('inf')

    for _ in range(n_iterations):
        # Generate random parameters
        current_params = {}
        for param, (min_val, max_val) in param_ranges.items():
            if isinstance(min_val, int) and isinstance(max_val, int):
                current_params[param] = draw.randint(min_val, max_val)
            else:
                current_params[param] = draw.uniform(min_val, max_val)

        current_score = objective_function(current_params)

        # Strict comparison: on ties the earliest draw is kept.
        improved = current_score > best_score if maximize else current_score < best_score
        if improved:
            best_score = current_score
            best_params = current_params

    return best_params, best_score

def minsearch_search(query, boost=None):
    """Query the global minsearch ``index`` with optional per-field boosts.

    No boosting is applied when ``boost`` is None. No filters are applied, and
    the top 10 results are returned.
    """
    search_kwargs = dict(
        query=query,
        filter_dict={},
        boost_dict={} if boost is None else boost,
        num_results=10,
    )
    return index.search(**search_kwargs)

def objective(boost_params):
    """Objective for ``simple_optimize``: retrieval MRR with ``boost_params`` as field boosts."""
    scores = evaluate(
        ground_truth,
        lambda record: minsearch_search(record['question'], boost_params),
    )
    return scores['mrr']

In [17]:
# Search each boostable field's weight uniformly in [0, 3]; `id` is not boosted here.
BOOST_FIELDS = ['abstract', 'authors', 'keywords', 'organization_affiliated', 'title']
param_ranges = {field: (0.0, 3.0) for field in BOOST_FIELDS}
simple_optimize(param_ranges, objective, n_iterations=20)

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

({'abstract': 2.3804731710209435,
  'authors': 0.033858106315256986,
  'keywords': 0.5177450395210756,
  'organization_affiliated': 1.3341455155330277,
  'title': 0.1961060756966928},
 0.9655884920634924)

### Compute the Evaluation Metrics with Optimized Parameters

In [20]:
def minsearch_improved(query):
    """Search the global ``index`` with the tuned field boosts (top 10, no filters).

    The weights are the best ``simple_optimize`` result above, rounded to two
    decimals.
    """
    tuned_boost = dict(
        abstract=2.38,
        authors=0.03,
        keywords=0.52,
        organization_affiliated=1.33,
        title=0.20,
    )
    return index.search(query=query, filter_dict={}, boost_dict=tuned_boost, num_results=10)

In [19]:
evaluate(ground_truth, search_function=lambda record: minsearch_improved(record['question']))

  0%|          | 0/1000 [00:00<?, ?it/s]

{'hit_rate': 0.99, 'mrr': 0.9655884920634924}

## PART 3: Evaluating RAG

### Prompt Building

In [4]:
# LLM-as-a-judge prompt: asks the model to classify how relevant a generated
# answer is to its question. Filled via .format(question=..., answer_llm=...);
# the doubled braces {{ }} are str.format escapes that keep the literal JSON
# skeleton in the rendered prompt.
prompt2_template = """
You are an expert evaluator for a Biomedical research question answering system.
Your task is to analyze the relevance of the answer to the given question.
Based on the relevance of the generated answer, you will classify it
as "NON_RELEVANT", "PARTLY_RELEVANT", or "RELEVANT".

Here is the data for evaluation:

Question: {question}
Generated Answer: {answer_llm}

Please analyze the content and context of the generated answer in relation to the question
and provide your evaluation in parsable JSON without using code blocks:

{{
  "Relevance": "NON_RELEVANT" | "PARTLY_RELEVANT" | "RELEVANT",
  "Explanation": "[Provide a brief explanation for your evaluation]"
}}
""".strip()

### Create a Random Sample of Questions from the Ground Truth Dataset

In [9]:
# Size of the ground-truth question set as (rows, columns).
ground_truth_df.shape

(1000, 2)

In [10]:
# Sample one random question for each article id. A fixed random_state makes the
# sample, and every RAG evaluation built on it, reproducible across re-runs.
# GroupBy.sample replaces groupby().apply(lambda x: x.sample(...)), which newer
# pandas warns about because apply also receives the grouping column.
SAMPLE_SEED = 42
sample_q_df = (
    ground_truth_df
    .groupby('id')
    .sample(n=1, random_state=SAMPLE_SEED)
    .reset_index(drop=True)
)

display(sample_q_df.head())

Unnamed: 0,id,question
0,00aade8a-7177-4b9e-8120-84770f278638,What are the challenges associated with using ...
1,01459f11-c3da-4d57-8e5e-0c9874459f54,What is the relationship between the geographi...
2,0183843c-ce40-44be-a2c2-ae8b2110b8a4,What are the specific network metrics that are...
3,0213f1ca-8b6c-462f-9688-d5cc0f3919ef,What is the significance of the finding that *...
4,028c6a25-fcb6-446f-9d87-b8b718c0e4a2,Were there any other drugs besides ranitidine ...


In [11]:
# Expect exactly one question per article id.
n_unique_ids = sample_q_df['id'].nunique(dropna=False)
print('Questions data shape:', sample_q_df.shape)
print('Number of unique ids:', n_unique_ids)

Questions data shape: (200, 2)
Number of unique ids: 200


### Update RAG to use Boosted Search

In [12]:
def rag(query, model='gemini-1.5-flash-001'):
    """RAG with boosted retrieval: ``minsearch_improved`` -> ``build_prompt`` -> ``llm``."""
    retrieved = minsearch_improved(query)
    return llm(build_prompt(query, retrieved), model=model)

### Sanity Check

In [39]:
# Sanity check: answer one sampled question with boosted RAG, then have the LLM judge it.
id0 = sample_q_df.loc[0, 'id']
q0 = sample_q_df.loc[0, 'question']
answer = rag(q0)
prompt = prompt2_template.format(question=q0, answer_llm=answer)
evaluation_llm = llm(prompt)
# Strip an optional Markdown code fence (```json ... ```) around the judge's reply.
# The previous .replace('json', '').replace('`', '') also deleted those
# characters inside the explanation text itself.
json_string = evaluation_llm.strip()
if json_string.startswith('```'):
    json_string = json_string.strip('`').strip()
    if json_string.lower().startswith('json'):
        json_string = json_string[4:]
evaluation = json.loads(json_string)
print("Question:", q0)
print("Answer:", answer)
print("Evaluation:", evaluation)

Question: What are the challenges associated with using conventional battery technologies in epidermal electronic systems?
Answer: The physical bulk, large mass, and high mechanical modulus of conventional battery technologies hinder efforts to achieve epidermal characteristics. Near-field power transfer schemes offer only a limited operating distance.  

Evaluation: {'Relevance': 'RELEVANT', 'Explanation': 'The answer directly addresses the question by identifying key challenges of conventional batteries in epidermal electronics: bulkiness, weight, and stiffness. It also mentions limited operating distance of near-field power transfer, which is relevant to the topic of power sources in epidermal systems.'}


### Generate Answers and Evaluations for the Sampled Records

In [13]:
def handle_rate_limit_error(seconds=60):
    """Back off after a Vertex AI quota error by sleeping ``seconds`` (default 60).

    The evaluation loops call this when an exception message contains
    "Quota exceeded", then retry.
    """
    print(f"Rate limit exceeded. Sleeping for {seconds} seconds...")
    time.sleep(seconds)

#### Model 1: Gemini 1.5 Flash 001

In [None]:
# Answer every sampled question with boosted RAG (Gemini 1.5 Flash) and have the LLM judge it.
# Only "Quota exceeded" errors are retried (after a back-off); any other error propagates.
sample_q = sample_q_df.to_dict(orient='records')
evaluations = []
for record in tqdm(sample_q):
    question = record['question']
    answer = None
    while True:  # Retry loop
        try:
            if answer is None:
                # Generate the answer only once: a quota error on the judge call
                # below must not replace it with a freshly generated answer.
                answer = rag(question)
            prompt = prompt2_template.format(
                question=question,
                answer_llm=answer
            )
            evaluation_llm = llm(prompt)
            break  # Exit the retry loop if successful
        except Exception as e:
            if "Quota exceeded" in str(e):
                handle_rate_limit_error()
            else:
                raise  # keep the original traceback
    # Strip an optional ```json fence; unlike .replace('json', ''), this leaves the
    # explanation text untouched.
    json_string = evaluation_llm.strip()
    if json_string.startswith('```'):
        json_string = json_string.strip('`').strip()
        if json_string.lower().startswith('json'):
            json_string = json_string[4:]
    evaluation = json.loads(json_string)
    evaluations.append((record, answer, evaluation))

In [43]:
# Number of questions evaluated so far; compare with len(sample_q_df) to see whether the loop above finished.
len(evaluations)

111

In [45]:
# Resume an interrupted run: evaluate only the sampled questions whose ids are not
# already in `evaluations`. This replaces the hard-coded row offset [111:], which
# was only valid for one particular earlier run.
done_ids = {rec['id'] for rec, _, _ in evaluations}
sample_q = [rec for rec in sample_q_df.to_dict(orient='records') if rec['id'] not in done_ids]
for record in tqdm(sample_q):
    question = record['question']
    answer = None
    while True:  # Retry loop
        try:
            if answer is None:
                # Generate the answer only once: a quota error on the judge call
                # below must not replace it with a freshly generated answer.
                answer = rag(question)
            prompt = prompt2_template.format(
                question=question,
                answer_llm=answer
            )
            evaluation_llm = llm(prompt)
            break  # Exit the retry loop if successful
        except Exception as e:
            if "Quota exceeded" in str(e):
                handle_rate_limit_error()
            else:
                raise  # keep the original traceback
    # Strip an optional ```json fence; unlike .replace('json', ''), this leaves the
    # explanation text untouched.
    json_string = evaluation_llm.strip()
    if json_string.startswith('```'):
        json_string = json_string.strip('`').strip()
        if json_string.lower().startswith('json'):
            json_string = json_string[4:]
    evaluation = json.loads(json_string)
    evaluations.append((record, answer, evaluation))

  0%|          | 0/89 [00:00<?, ?it/s]

Rate limit exceeded. Sleeping for 60 seconds...
Rate limit exceeded. Sleeping for 60 seconds...
Rate limit exceeded. Sleeping for 60 seconds...
Rate limit exceeded. Sleeping for 60 seconds...
Rate limit exceeded. Sleeping for 60 seconds...
Rate limit exceeded. Sleeping for 60 seconds...
Rate limit exceeded. Sleeping for 60 seconds...
Rate limit exceeded. Sleeping for 60 seconds...
Rate limit exceeded. Sleeping for 60 seconds...
Rate limit exceeded. Sleeping for 60 seconds...
Rate limit exceeded. Sleeping for 60 seconds...
Rate limit exceeded. Sleeping for 60 seconds...
Rate limit exceeded. Sleeping for 60 seconds...
Rate limit exceeded. Sleeping for 60 seconds...
Rate limit exceeded. Sleeping for 60 seconds...
Rate limit exceeded. Sleeping for 60 seconds...
Rate limit exceeded. Sleeping for 60 seconds...
Rate limit exceeded. Sleeping for 60 seconds...
Rate limit exceeded. Sleeping for 60 seconds...
Rate limit exceeded. Sleeping for 60 seconds...
Rate limit exceeded. Sleeping for 60 sec

In [46]:
# Should now equal the number of sampled questions (one evaluation per question).
len(evaluations)

200

In [47]:
# Flatten each (record, answer, judge-evaluation) tuple into one row.
df_eval = pd.DataFrame(
    [
        {
            'answer': answer_text,
            'id': rec['id'],
            'question': rec['question'],
            'relevance': judged['Relevance'],
            'explanation': judged['Explanation'],
        }
        for rec, answer_text, judged in evaluations
    ],
    columns=['answer', 'id', 'question', 'relevance', 'explanation'],
)

In [50]:
df_eval['relevance'].value_counts()

relevance
RELEVANT           132
PARTLY_RELEVANT     68
Name: count, dtype: int64

In [48]:
df_eval['relevance'].value_counts(normalize=True)

relevance
RELEVANT           0.66
PARTLY_RELEVANT    0.34
Name: proportion, dtype: float64

In [49]:
df_eval.to_csv(path_or_buf='../data/rag-eval-gemini-1.5-flash-001.csv', index=False)

#### Model 2: Gemini 1.0 Pro

In [22]:
# Answer every sampled question with boosted RAG on Gemini 1.0 Pro; the judge call
# uses llm()'s default model. Only "Quota exceeded" errors are retried.
sample_q = sample_q_df.to_dict(orient='records')
evaluations = []
for record in tqdm(sample_q):
    question = record['question']
    answer = None
    while True:  # Retry loop
        try:
            if answer is None:
                # Generate the answer only once: a quota error on the judge call
                # below must not replace it with a freshly generated answer.
                answer = rag(question, model='gemini-1.0-pro')
            prompt = prompt2_template.format(
                question=question,
                answer_llm=answer
            )
            evaluation_llm = llm(prompt)
            break  # Exit the retry loop if successful
        except Exception as e:
            if "Quota exceeded" in str(e):
                handle_rate_limit_error()
            else:
                raise  # keep the original traceback
    # Strip an optional ```json fence; unlike .replace('json', ''), this leaves the
    # explanation text untouched.
    json_string = evaluation_llm.strip()
    if json_string.startswith('```'):
        json_string = json_string.strip('`').strip()
        if json_string.lower().startswith('json'):
            json_string = json_string[4:]
    evaluation = json.loads(json_string)
    evaluations.append((record, answer, evaluation))

  0%|          | 0/200 [00:00<?, ?it/s]

Rate limit exceeded. Sleeping for 60 seconds...
Rate limit exceeded. Sleeping for 60 seconds...
Rate limit exceeded. Sleeping for 60 seconds...
Rate limit exceeded. Sleeping for 60 seconds...
Rate limit exceeded. Sleeping for 60 seconds...
Rate limit exceeded. Sleeping for 60 seconds...
Rate limit exceeded. Sleeping for 60 seconds...
Rate limit exceeded. Sleeping for 60 seconds...
Rate limit exceeded. Sleeping for 60 seconds...
Rate limit exceeded. Sleeping for 60 seconds...
Rate limit exceeded. Sleeping for 60 seconds...
Rate limit exceeded. Sleeping for 60 seconds...
Rate limit exceeded. Sleeping for 60 seconds...
Rate limit exceeded. Sleeping for 60 seconds...
Rate limit exceeded. Sleeping for 60 seconds...
Rate limit exceeded. Sleeping for 60 seconds...
Rate limit exceeded. Sleeping for 60 seconds...
Rate limit exceeded. Sleeping for 60 seconds...
Rate limit exceeded. Sleeping for 60 seconds...
Rate limit exceeded. Sleeping for 60 seconds...
Rate limit exceeded. Sleeping for 60 sec

In [23]:
# Should equal the number of sampled questions (one evaluation per question).
len(evaluations)

200

In [24]:
# Flatten each (record, answer, judge-evaluation) tuple into one row.
df_eval = pd.DataFrame(
    [
        {
            'answer': answer_text,
            'id': rec['id'],
            'question': rec['question'],
            'relevance': judged['Relevance'],
            'explanation': judged['Explanation'],
        }
        for rec, answer_text, judged in evaluations
    ],
    columns=['answer', 'id', 'question', 'relevance', 'explanation'],
)

In [25]:
df_eval['relevance'].value_counts()

relevance
RELEVANT           132
PARTLY_RELEVANT     53
NON_RELEVANT        15
Name: count, dtype: int64

In [26]:
df_eval['relevance'].value_counts(normalize=True)

relevance
RELEVANT           0.660
PARTLY_RELEVANT    0.265
NON_RELEVANT       0.075
Name: proportion, dtype: float64

In [27]:
df_eval.to_csv(path_or_buf='../data/rag-eval-gemini-1.0-pro.csv', index=False)