In [1]:
import os, sys
sys.path.append(os.path.abspath("../.."))

In [2]:
import os

import pandas as pd
from dotenv import load_dotenv
import psycopg2
from psycopg2.extras import execute_values
from pyarrow.compute import top_k_unstable
from sentence_transformers import SentenceTransformer
from huggingface_hub import login
import json
import pprint

from RAG.retrieval_agent_hybrid import HybridRetrievalAgent

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
with open("../test_cases/knowledge_base.json") as f:
    knowledge_base = json.load(f)
    stm_data = knowledge_base["STM_data"]
    ltm_data = knowledge_base["LTM_data"]
    hcm_data = knowledge_base["HCM_data"]

stm_data_doc_ids_mapping = {stm['content']:stm['document_id'] for stm in stm_data}
ltm_data_doc_ids_mapping = {ltm_document['value']: ltm_document['document_id'] for ltm_document in ltm_data}
hcm_data_doc_ids_mapping = {hcm_record['description']: hcm_record['document_id'] for hcm_record in hcm_data}

consolidated_mapping = {
    'stm': stm_data_doc_ids_mapping,
    'ltm': ltm_data_doc_ids_mapping,
    'hcm': hcm_data_doc_ids_mapping
}
with open("../test_cases/doc_mapping.json", "w") as f:
    json.dump(consolidated_mapping, f, indent=4)

### Test with the rag_test_cases.json file

In [4]:
test_elderly_id = "87654321-4321-4321-4321-019876543210"
retrieval_agent = HybridRetrievalAgent(test_elderly_id)

INFO:root:Loading embedding model: google/embeddinggemma-300m
INFO:sentence_transformers.SentenceTransformer:Use pytorch device_name: cpu
INFO:sentence_transformers.SentenceTransformer:Load pretrained SentenceTransformer: google/embeddinggemma-300m
INFO:sentence_transformers.SentenceTransformer:14 prompts are loaded, with the keys: ['query', 'document', 'BitextMining', 'Clustering', 'Classification', 'InstructionRetrieval', 'MultilabelClassification', 'PairClassification', 'Reranking', 'Retrieval', 'Retrieval-query', 'Retrieval-document', 'STS', 'Summarization']
INFO:root:embedding model loaded successfully: google/embeddinggemma-300m
INFO:root:Loading CrossEncoder model: BAAI/bge-reranker-base
INFO:sentence_transformers.cross_encoder.CrossEncoder:Use pytorch device: cpu
INFO:root:✅ CrossEncoder model loaded: BAAI/bge-reranker-base


### Perform evaluation for single test memory store retrival


In [5]:
with open("../test_cases/augmented_agentic_rag_test_cases.json", "r") as f:
    test_cases = json.load(f)

from RAG.utils.embedder import CrossEmbedder
shared_cross_encoder = CrossEmbedder("BAAI/bge-reranker-base")

user_query_list = []
results_lst = []
expected_retrievals_lst = []
test_case_id_lst = []
retrieval_type_lst= []
for idx, test_case in enumerate(test_cases):
    print(f"Test Case {idx+1}: {test_case['id']}")
    test_case_id_lst.append(test_case['id'])
    user_query = test_case["conversation"][0]["user_query"]
    user_query_list.append(user_query)
    print("Query:", user_query)
    expected_retrievals = test_case.get("expected_retrieval", {})
    if len(expected_retrievals) > 1:
        test_case_id_lst.pop()  # remove last appended test case id
        user_query_list.pop()
        print("Skipping multi memory store retrieval test case")
        continue
    assert len(expected_retrievals) == 1, "This cell supports single memory store retrieval"
    for key, value in expected_retrievals.items():
        retrieval_type_lst.append(key)
        if key == "ltm":
            print("Retrieving from ltm...")

            retrieved_ltm = retrieval_agent.retrieve_rerank(user_query, mode="long-term", cross_encoder=shared_cross_encoder)
            expected_retrievals_lst.append(value)
            results_lst.append(retrieved_ltm)

        elif key == "stm":
            print("Retrieving from stm...")
            retrieved_stm = retrieval_agent.retrieve_rerank(user_query, mode="short-term", cross_encoder=shared_cross_encoder)
            expected_retrievals_lst.append(value)
            results_lst.append(retrieved_stm)
        elif key == "hcm":
            print("Retrieving from hcm...")
            retrieved_hcm = retrieval_agent.retrieve_rerank(user_query, mode="healthcare", cross_encoder=shared_cross_encoder)
            expected_retrievals_lst.append(value)
            results_lst.append(retrieved_hcm)
        else:
            print(f"Unknown retrieval type: {key}")


INFO:root:Loading CrossEncoder model: BAAI/bge-reranker-base
INFO:sentence_transformers.cross_encoder.CrossEncoder:Use pytorch device: cpu
INFO:root:✅ CrossEncoder model loaded: BAAI/bge-reranker-base


Test Case 1: LTM_1
Query: Who is your daughter-in-law?
Retrieving from ltm...


Batches: 100%|██████████| 1/1 [00:00<00:00,  6.66it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00,  1.72it/s]


Test Case 2: LTM_2
Query: Who usually joins me for birthday celebrations?
Retrieving from ltm...


Batches: 100%|██████████| 1/1 [00:00<00:00, 10.58it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00,  1.81it/s]


Test Case 3: HCM_1
Query: What medications do I take every morning?
Retrieving from hcm...


Batches: 100%|██████████| 1/1 [00:00<00:00,  6.96it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00,  3.42it/s]


Test Case 4: STM_1
Query: What should I have for dinner later?
Retrieving from stm...


Batches: 100%|██████████| 1/1 [00:00<00:00,  7.68it/s]
Batches: 100%|██████████| 1/1 [00:01<00:00,  1.09s/it]


Test Case 5: STM_2
Query: What was I supposed to do in the morning?
Retrieving from stm...


Batches: 100%|██████████| 1/1 [00:00<00:00,  9.19it/s]
Batches: 100%|██████████| 1/1 [00:01<00:00,  1.19s/it]


Test Case 6: STM_3
Query: When was the health talk again?
Retrieving from stm...


Batches: 100%|██████████| 1/1 [00:00<00:00,  8.76it/s]
Batches: 100%|██████████| 1/1 [00:01<00:00,  1.07s/it]


Test Case 7: Temporal_1
Query: What should I eat now?
Retrieving from stm...


Batches: 100%|██████████| 1/1 [00:00<00:00, 10.13it/s]
Batches: 100%|██████████| 1/1 [00:01<00:00,  1.15s/it]


Test Case 8: MultiMem_1
Query: Should I bring someone for my morning walk tomorrow?
Retrieving from ltm...


Batches: 100%|██████████| 1/1 [00:00<00:00,  6.40it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00,  1.90it/s]


Test Case 9: MultiMem_2
Query: Can I still play badminton?
Skipping multi memory store retrieval test case
Test Case 10: MultiMem_3
Query: What’s my usual breakfast?
Retrieving from ltm...


Batches: 100%|██████████| 1/1 [00:00<00:00, 10.65it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00,  2.36it/s]


Test Case 11: MultiMem_4
Query: Can I drink coffee with my medicines?
Skipping multi memory store retrieval test case
Test Case 12: LTM_3
Query: Who is my grandson?
Retrieving from ltm...


Batches: 100%|██████████| 1/1 [00:00<00:00, 12.75it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00,  2.43it/s]


Test Case 13: STM_3
Query: Who did I meet yesterday at the bakery?
Retrieving from stm...


Batches: 100%|██████████| 1/1 [00:00<00:00,  9.45it/s]
Batches: 100%|██████████| 1/1 [00:01<00:00,  1.04s/it]


Test Case 14: HCM_2
Query: What health conditions do I currently have?
Retrieving from hcm...


Batches: 100%|██████████| 1/1 [00:00<00:00, 12.67it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00,  4.81it/s]


Test Case 15: MultiMem_5
Query: Who do I usually meet for my morning walks?
Skipping multi memory store retrieval test case
Test Case 16: MultiMem_6
Query: Should I attend the health talk next week?
Skipping multi memory store retrieval test case
Test Case 17: Temporal_2
Query: What did I plan to buy tomorrow morning?
Retrieving from stm...


Batches: 100%|██████████| 1/1 [00:00<00:00, 11.16it/s]
Batches: 100%|██████████| 1/1 [00:01<00:00,  1.02s/it]


Test Case 18: CrossMem_1
Query: What hobbies do I enjoy that keep me healthy?
Skipping multi memory store retrieval test case
Test Case 19: CrossMem_2
Query: What did I say about my grandson’s school project?
Skipping multi memory store retrieval test case
Test Case 20: MultiMem_7
Query: Do I still enjoy gardening despite my knee surgery?
Skipping multi memory store retrieval test case
Test Case 21: STM_4
Query: What did I enjoy listening to yesterday evening?
Retrieving from stm...


Batches: 100%|██████████| 1/1 [00:00<00:00,  2.21it/s]
Batches: 100%|██████████| 1/1 [00:01<00:00,  1.12s/it]


In [6]:
all_results = {
    "test_case_id" : test_case_id_lst,
    "retrieval_type" : retrieval_type_lst,
    "user_query" : user_query_list,
    "expected_retrievals" : expected_retrievals_lst,
    "results" : results_lst
}

all_results = pd.DataFrame(all_results)


In [7]:
def reattach_doc_id(row):
    with open("../test_cases/doc_mapping.json", "r") as f:
        doc_mapping = json.load(f)
    retrieval_type = row["retrieval_type"]
    results_list = row["results"]
    if retrieval_type == "ltm":
        for result in results_list:
            result['document_id'] = doc_mapping['ltm'].get(result['value'], None)
    elif retrieval_type == "stm":
        for result in results_list:
            result['document_id'] = doc_mapping['stm'].get(result['content'], None)
    elif retrieval_type == "hcm":
        for result in results_list:
            result['document_id'] = doc_mapping['hcm'].get(result['description'], None)
    return results_list

all_results['results'] = all_results.apply(reattach_doc_id, axis=1)

In [8]:
def extract_retrieved_ids(row):
    results = row['results']
    retrieved_ids = [res.get('document_id', None) for res in results if res.get('document_id', None) is not None]
    return retrieved_ids

def extract_reference_ids(row):
    expected_retrievals = row['expected_retrievals']
    reference_ids = [ref.get('document_id', None) for ref in expected_retrievals if ref.get('document_id', None) is not None]
    return reference_ids

all_results['retrieved_ids'] = all_results.apply(extract_retrieved_ids, axis=1)
all_results['reference_ids'] = all_results.apply(extract_reference_ids, axis=1)


### Compute context precision and context recall

#### Context Precision (ID BASED)

    Context Precision = (Number of retrieved_Context IDs found in reference Context IDs) / (Total number of retrieved context IDs)

#### Context Recall (ID BASED):
    Context Recall = (Number of reference context IDs found in retrieved context IDs) / (Total number of reference context IDs)


Referencing RAGAS evaluation template:
https://docs.ragas.io/en/stable/concepts/metrics/available_metrics/context_recall/
https://docs.ragas.io/en/stable/concepts/metrics/available_metrics/context_precision/

In [9]:
### Calculate Context Precision and Context Recall

def calculate_context_precision(row):
    retrieved_ids = set(row['retrieved_ids'])
    reference_ids = set(row['reference_ids'])
    if not retrieved_ids:
        return 0.0
    true_positives = len(retrieved_ids.intersection(reference_ids))
    precision = true_positives / len(retrieved_ids)
    return precision

def calculate_context_recall(row):
    retrieved_ids = set(row['retrieved_ids'])
    reference_ids = set(row['reference_ids'])
    if not reference_ids:
        return 0.0
    true_positives = len(retrieved_ids.intersection(reference_ids))
    recall = true_positives / len(reference_ids)
    return recall

all_results['context_precision'] = all_results.apply(calculate_context_precision, axis=1)
all_results['context_recall'] = all_results.apply(calculate_context_recall, axis=1)
all_results['num_retrieved'] = all_results['retrieved_ids'].apply(lambda x: len(x) if x else 0)

display(all_results)

Unnamed: 0,test_case_id,retrieval_type,user_query,expected_retrievals,results,retrieved_ids,reference_ids,context_precision,context_recall,num_retrieved
0,LTM_1,ltm,Who is your daughter-in-law?,"[{'document_id': 'LTM001', 'key': 'Daughter-in...","[{'id': 742e0dad-2fc7-4e1b-85ac-89824f895bcb, ...","[LTM_001, LTM_012, LTM_011, LTM_005, LTM_009, ...",[LTM001],0.0,0.0,8
1,LTM_2,ltm,Who usually joins me for birthday celebrations?,"[{'document_id': 'LTM005', 'key': 'family_rela...","[{'id': 658e74b4-cb89-40d6-b00c-110f7735fba0, ...","[LTM_005, LTM_009, LTM_008, LTM_011, LTM_010, ...",[LTM005],0.0,0.0,8
2,HCM_1,hcm,What medications do I take every morning?,"[{'document_id': 'HCM_001', 'type': 'medicatio...","[{'id': 2a635240-5a0a-43e8-bd23-fbea0173bc6b, ...","[HCM_001, HCM_002, HCM_005, HCM_004, HCM_003, ...","[HCM_001, HCM_002]",0.333333,1.0,6
3,STM_1,stm,What should I have for dinner later?,"[{'document_id': 'STM014', 'timestamp': '2025-...","[{'id': 95d4cf0e-bf89-49a4-acf5-66c810751bd8, ...","[STM_032, STM_042, STM_053, STM_040, STM_011, ...",[STM014],0.0,0.0,8
4,STM_2,stm,What was I supposed to do in the morning?,"[{'document_id': 'STM_011', 'timestamp': '2025...","[{'id': b5d08317-22f7-423f-95a9-acf88fc3dbba, ...","[STM_007, STM_037, STM_053, STM_040, STM_043, ...",[STM_011],0.0,0.0,8
5,STM_3,stm,When was the health talk again?,"[{'document_id': 'STM_052', 'timestamp': '2025...","[{'id': c7614e77-46de-447f-8144-48e53e50c8a3, ...","[STM_052, STM_048, STM_031, STM_035, STM_054, ...",[STM_052],0.125,1.0,8
6,Temporal_1,stm,What should I eat now?,"[{'document_id': 'STM_010', 'timestamp': '2025...","[{'id': 95d4cf0e-bf89-49a4-acf5-66c810751bd8, ...","[STM_032, STM_053, STM_051, STM_052, STM_043, ...","[STM_010, STM_014, STM_003]",0.0,0.0,8
7,MultiMem_1,ltm,Should I bring someone for my morning walk tom...,"[{'document_id': 'LTM_004', 'category': 'lifes...","[{'id': e8163365-e914-4376-8381-0318268c509d, ...","[LTM_009, LTM_005, LTM_008, LTM_011, LTM_010, ...","[LTM_004, LTM_003]",0.0,0.0,8
8,MultiMem_3,ltm,What’s my usual breakfast?,"[{'document_id': 'LTM_008', 'category': 'lifes...","[{'id': ddcdb531-51fb-463f-b346-b8fb755c8503, ...","[LTM_014, LTM_008, LTM_010, LTM_005, LTM_011, ...",[LTM_008],0.125,1.0,8
9,LTM_3,ltm,Who is my grandson?,"[{'document_id': 'LTM_011', 'category': 'famil...","[{'id': 17036b38-b468-4658-bee8-cb6b801852fb, ...","[LTM_002, LTM_001, LTM_009, LTM_008, LTM_010, ...",[LTM_011],0.125,1.0,8


#### Calculate average context precision and context recall

In [10]:
average_context_precision = all_results['context_precision'].mean()
average_context_recall = all_results['context_recall'].mean()

print(f"Average context precision: {average_context_precision}")
print(f"Average context recall: {average_context_recall}")

Average context precision: 0.09226190476190475
Average context recall: 0.5


In [None]:
import itertools
import numpy as np
from tqdm import tqdm

# Define search space for key parameters
param_grid = {
    "top_k_retrieval": [10, 15, 20, 25, 30],
    "sim_threshold": [0.2, 0.3, 0.4],
    "alpha_retrieval": [0.3, 0.5, 0.7],
    "alpha_MMR": [0.6, 0.75, 0.9],
    "beta_recency": [0.05, 0.1, 0.2],
}

# Store results
grid_results = []

# Generate all combinations of parameters
param_combinations = list(itertools.product(
    param_grid["top_k_retrieval"],
    param_grid["sim_threshold"],
    param_grid["alpha_retrieval"],
    param_grid["alpha_MMR"],
    param_grid["beta_recency"],
))

print(f"Testing {len(param_combinations)} parameter combinations...")

for params in tqdm(param_combinations):
    top_k, sim_th, alpha_r, alpha_mmr, beta_r = params

    context_precisions = []
    context_recalls = []

    for _, test_case in enumerate(test_cases):
        expected_retrievals = test_case.get("expected_retrieval", {})
        if len(expected_retrievals) != 1:
            continue  # skip multi-memory test cases
        key = list(expected_retrievals.keys())[0]
        user_query = test_case["conversation"][0]["user_query"]

        if key == "ltm":
            results = retrieval_agent.retrieve_rerank(
                user_query, mode="long-term",
                top_k_retrieval=top_k,
                sim_threshold=sim_th,
                alpha_retrieval=alpha_r,
                alpha_MMR=alpha_mmr,
                beta_recency=beta_r, cross_encoder=shared_cross_encoder
            )
        elif key == "stm":
            results = retrieval_agent.retrieve_rerank(
                user_query, mode="short-term",
                top_k_retrieval=top_k,
                sim_threshold=sim_th,
                alpha_retrieval=alpha_r,
                alpha_MMR=alpha_mmr,
                beta_recency=beta_r, cross_encoder=shared_cross_encoder
            )
        elif key == "hcm":
            results = retrieval_agent.retrieve_rerank(
                user_query, mode="healthcare",
                top_k_retrieval=top_k,
                sim_threshold=sim_th,
                alpha_retrieval=alpha_r,
                alpha_MMR=alpha_mmr,
                beta_recency=beta_r, cross_encoder=shared_cross_encoder
            )
        else:
            continue

        # Extract retrieved IDs
        retrieved_ids = [r.get("document_id", None) for r in results if r.get("document_id")]
        reference_ids = [ref.get("document_id", None) for ref in list(expected_retrievals.values())[0] if ref.get("document_id")]

        # Compute precision and recall
        if not retrieved_ids or not reference_ids:
            continue
        retrieved_set, reference_set = set(retrieved_ids), set(reference_ids)
        tp = len(retrieved_set & reference_set)
        precision = tp / len(retrieved_set) if retrieved_set else 0
        recall = tp / len(reference_set) if reference_set else 0

        context_precisions.append(precision)
        context_recalls.append(recall)

    # Compute averages
    avg_precision = np.mean(context_precisions) if context_precisions else 0
    avg_recall = np.mean(context_recalls) if context_recalls else 0
    f1 = 2 * avg_precision * avg_recall / (avg_precision + avg_recall + 1e-8)

    grid_results.append({
        "top_k_retrieval": top_k,
        "sim_threshold": sim_th,
        "alpha_retrieval": alpha_r,
        "alpha_MMR": alpha_mmr,
        "beta_recency": beta_r,
        "avg_precision": avg_precision,
        "avg_recall": avg_recall,
        "f1_score": f1
    })

# Convert results to DataFrame for inspection
grid_df = pd.DataFrame(grid_results)
best_config = grid_df.loc[grid_df["f1_score"].idxmax()]
display(grid_df.sort_values("f1_score", ascending=False).head(10))

print("\n🏆 Best Configuration:")
print(best_config)


Testing 405 parameter combinations...


Batches: 100%|██████████| 1/1 [00:00<00:00,  7.41it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00,  3.83it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 10.78it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00,  3.44it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00,  9.32it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00,  2.86it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00,  9.03it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00,  2.74it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00,  9.21it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00,  2.35it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00,  9.55it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00,  1.88it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00,  7.23it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00,  2.32it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00,  8.35it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00,  2.88it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 10.29it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00,  3.34it/s]
Batches: 1