In [23]:
import os, sys
sys.path.append(os.path.abspath("../.."))

In [24]:
import os

import pandas as pd
from dotenv import load_dotenv
import psycopg2
from psycopg2.extras import execute_values
from pyarrow.compute import top_k_unstable
from sentence_transformers import SentenceTransformer
from huggingface_hub import login
import json
import pprint

from RAG.retrieval_agent import RetrievalAgent

In [25]:
with open("../test_cases/knowledge_base.json") as f:
    knowledge_base = json.load(f)
    stm_data = knowledge_base["STM_data"]
    ltm_data = knowledge_base["LTM_data"]
    hcm_data = knowledge_base["HCM_data"]

stm_data_doc_ids_mapping = {stm['content']:stm['document_id'] for stm in stm_data}
ltm_data_doc_ids_mapping = {ltm_document['value']: ltm_document['document_id'] for ltm_document in ltm_data}
hcm_data_doc_ids_mapping = {hcm_record['description']: hcm_record['document_id'] for hcm_record in hcm_data}

consolidated_mapping = {
    'stm': stm_data_doc_ids_mapping,
    'ltm': ltm_data_doc_ids_mapping,
    'hcm': hcm_data_doc_ids_mapping
}
with open("../test_cases/doc_mapping.json", "w") as f:
    json.dump(consolidated_mapping, f, indent=4)

### Test with the rag_test_cases.json file

In [26]:
test_elderly_id = "87654321-4321-4321-4321-019876543210"
retrieval_agent = RetrievalAgent(test_elderly_id)

INFO:root:Loading embedding model: google/embeddinggemma-300m
INFO:sentence_transformers.SentenceTransformer:Use pytorch device_name: cpu
INFO:sentence_transformers.SentenceTransformer:Load pretrained SentenceTransformer: google/embeddinggemma-300m
INFO:sentence_transformers.SentenceTransformer:14 prompts are loaded, with the keys: ['query', 'document', 'BitextMining', 'Clustering', 'Classification', 'InstructionRetrieval', 'MultilabelClassification', 'PairClassification', 'Reranking', 'Retrieval', 'Retrieval-query', 'Retrieval-document', 'STS', 'Summarization']
INFO:root:embedding model loaded successfully: google/embeddinggemma-300m


### Perform evaluation for single test memory store retrival


In [27]:
with open("../test_cases/augmented_agentic_rag_test_cases.json", "r") as f:
    test_cases = json.load(f)

user_query_list = []
results_lst = []
expected_retrievals_lst = []
test_case_id_lst = []
retrieval_type_lst= []
for idx, test_case in enumerate(test_cases):
    print(f"Test Case {idx+1}: {test_case['id']}")
    test_case_id_lst.append(test_case['id'])
    user_query = test_case["conversation"][0]["user_query"]
    user_query_list.append(user_query)
    print("Query:", user_query)
    expected_retrievals = test_case.get("expected_retrieval", {})
    if len(expected_retrievals) > 1:
        test_case_id_lst.pop()  # remove last appended test case id
        user_query_list.pop()
        print("Skipping multi memory store retrieval test case")
        continue
    assert len(expected_retrievals) == 1, "This cell supports single memory store retrieval"
    for key, value in expected_retrievals.items():
        retrieval_type_lst.append(key)
        if key == "ltm":
            print("Retrieving from ltm...")
            retrieved_ltm = retrieval_agent.retrieve_similar_ltm(user_query)
            expected_retrievals_lst.append(value)
            results_lst.append(retrieved_ltm)
        elif key == "stm":
            print("Retrieving from stm...")
            retrieved_stm = retrieval_agent.retrieve_similar_stm(user_query)
            expected_retrievals_lst.append(value)
            results_lst.append(retrieved_stm)
        elif key == "hcm":
            print("Retrieving from hcm...")
            retrieved_hcm = retrieval_agent.retrieve_similar_health(user_query)
            expected_retrievals_lst.append(value)
            results_lst.append(retrieved_hcm)
        else:
            print(f"Unknown retrieval type: {key}")


Test Case 1: LTM_1
Query: Who is your daughter-in-law?
Retrieving from ltm...


Batches: 100%|██████████| 1/1 [00:00<00:00,  6.51it/s]


Test Case 2: LTM_2
Query: Who usually joins me for birthday celebrations?
Retrieving from ltm...


Batches: 100%|██████████| 1/1 [00:00<00:00, 11.65it/s]


Test Case 3: HCM_1
Query: What medications do I take every morning?
Retrieving from hcm...


Batches: 100%|██████████| 1/1 [00:00<00:00,  7.43it/s]

Test Case 4: STM_1





Query: What should I have for dinner later?
Retrieving from stm...


Batches: 100%|██████████| 1/1 [00:00<00:00,  7.70it/s]

Test Case 5: STM_2
Query: What was I supposed to do in the morning?
Retrieving from stm...



Batches: 100%|██████████| 1/1 [00:00<00:00,  7.50it/s]


Test Case 6: STM_3
Query: When was the health talk again?
Retrieving from stm...


Batches: 100%|██████████| 1/1 [00:00<00:00,  5.44it/s]


Test Case 7: Temporal_1
Query: What should I eat now?
Retrieving from stm...


Batches: 100%|██████████| 1/1 [00:00<00:00,  7.26it/s]


Test Case 8: MultiMem_1
Query: Should I bring someone for my morning walk tomorrow?
Retrieving from ltm...


Batches: 100%|██████████| 1/1 [00:00<00:00,  6.90it/s]

Test Case 9: MultiMem_2





Query: Can I still play badminton?
Skipping multi memory store retrieval test case
Test Case 10: MultiMem_3
Query: What’s my usual breakfast?
Retrieving from ltm...


Batches: 100%|██████████| 1/1 [00:00<00:00,  9.36it/s]


Test Case 11: MultiMem_4
Query: Can I drink coffee with my medicines?
Skipping multi memory store retrieval test case
Test Case 12: LTM_3
Query: Who is my grandson?
Retrieving from ltm...


Batches: 100%|██████████| 1/1 [00:00<00:00,  9.44it/s]


Test Case 13: STM_3
Query: Who did I meet yesterday at the bakery?
Retrieving from stm...


Batches: 100%|██████████| 1/1 [00:00<00:00,  4.75it/s]


Test Case 14: HCM_2
Query: What health conditions do I currently have?
Retrieving from hcm...


Batches: 100%|██████████| 1/1 [00:00<00:00,  7.40it/s]


Test Case 15: MultiMem_5
Query: Who do I usually meet for my morning walks?
Skipping multi memory store retrieval test case
Test Case 16: MultiMem_6
Query: Should I attend the health talk next week?
Skipping multi memory store retrieval test case
Test Case 17: Temporal_2
Query: What did I plan to buy tomorrow morning?
Retrieving from stm...


Batches: 100%|██████████| 1/1 [00:00<00:00,  7.41it/s]


Test Case 18: CrossMem_1
Query: What hobbies do I enjoy that keep me healthy?
Skipping multi memory store retrieval test case
Test Case 19: CrossMem_2
Query: What did I say about my grandson’s school project?
Skipping multi memory store retrieval test case
Test Case 20: MultiMem_7
Query: Do I still enjoy gardening despite my knee surgery?
Skipping multi memory store retrieval test case
Test Case 21: STM_4
Query: What did I enjoy listening to yesterday evening?
Retrieving from stm...


Batches: 100%|██████████| 1/1 [00:00<00:00,  8.53it/s]


In [28]:
all_results = {
    "test_case_id" : test_case_id_lst,
    "retrieval_type" : retrieval_type_lst,
    "user_query" : user_query_list,
    "expected_retrievals" : expected_retrievals_lst,
    "results" : results_lst
}

all_results = pd.DataFrame(all_results)


In [29]:
def reattach_doc_id(row):
    with open("../test_cases/doc_mapping.json", "r") as f:
        doc_mapping = json.load(f)
    retrieval_type = row["retrieval_type"]
    results_list = row["results"]
    if retrieval_type == "ltm":
        for result in results_list:
            result['document_id'] = doc_mapping['ltm'].get(result['value'], None)
    elif retrieval_type == "stm":
        for result in results_list:
            result['document_id'] = doc_mapping['stm'].get(result['content'], None)
    elif retrieval_type == "hcm":
        for result in results_list:
            result['document_id'] = doc_mapping['hcm'].get(result['description'], None)
    return results_list

all_results['results'] = all_results.apply(reattach_doc_id, axis=1)

In [30]:
def extract_retrieved_ids(row):
    results = row['results']
    retrieved_ids = [res.get('document_id', None) for res in results if res.get('document_id', None) is not None]
    return retrieved_ids

def extract_reference_ids(row):
    expected_retrievals = row['expected_retrievals']
    reference_ids = [ref.get('document_id', None) for ref in expected_retrievals if ref.get('document_id', None) is not None]
    return reference_ids

all_results['retrieved_ids'] = all_results.apply(extract_retrieved_ids, axis=1)
all_results['reference_ids'] = all_results.apply(extract_reference_ids, axis=1)


### Compute context precision and context recall

#### Context Precision (ID BASED)

    Context Precision = (Number of retrieved_Context IDs found in reference Context IDs) / (Total number of retrieved context IDs)

#### Context Recall (ID BASED):
    Context Recall = (Number of reference context IDs found in retrieved context IDs) / (Total number of reference context IDs)


Referencing RAGAS evaluation template:
https://docs.ragas.io/en/stable/concepts/metrics/available_metrics/context_recall/
https://docs.ragas.io/en/stable/concepts/metrics/available_metrics/context_precision/

In [31]:
### Calculate Context Precision and Context Recall

def calculate_context_precision(row):
    retrieved_ids = set(row['retrieved_ids'])
    reference_ids = set(row['reference_ids'])
    if not retrieved_ids:
        return 0.0
    true_positives = len(retrieved_ids.intersection(reference_ids))
    precision = true_positives / len(retrieved_ids)
    return precision

def calculate_context_recall(row):
    retrieved_ids = set(row['retrieved_ids'])
    reference_ids = set(row['reference_ids'])
    if not reference_ids:
        return 0.0
    true_positives = len(retrieved_ids.intersection(reference_ids))
    recall = true_positives / len(reference_ids)
    return recall

all_results['context_precision'] = all_results.apply(calculate_context_precision, axis=1)
all_results['context_recall'] = all_results.apply(calculate_context_recall, axis=1)
all_results['num_retrieved'] = all_results['retrieved_ids'].apply(lambda x: len(x) if x else 0)


display(all_results)

Unnamed: 0,test_case_id,retrieval_type,user_query,expected_retrievals,results,retrieved_ids,reference_ids,context_precision,context_recall,num_retrieved
0,LTM_1,ltm,Who is your daughter-in-law?,"[{'document_id': 'LTM001', 'key': 'Daughter-in...","[{'category': 'family', 'key': 'Ganesh Menon D...","[LTM_002, LTM_012, LTM_001, LTM_005, LTM_007]",[LTM001],0.0,0.0,5
1,LTM_2,ltm,Who usually joins me for birthday celebrations?,"[{'document_id': 'LTM005', 'key': 'family_rela...","[{'category': 'family', 'key': 'family_relatio...","[LTM_005, LTM_001, LTM_012, LTM_013, LTM_007]",[LTM005],0.0,0.0,5
2,HCM_1,hcm,What medications do I take every morning?,"[{'document_id': 'HCM_001', 'type': 'medicatio...","[{'record_type': 'medication', 'description': ...","[HCM_001, HCM_002, HCM_004, HCM_003, HCM_005]","[HCM_001, HCM_002]",0.4,1.0,5
3,STM_1,stm,What should I have for dinner later?,"[{'document_id': 'STM014', 'timestamp': '2025-...",[{'content': 'I want to eat salad for lunch to...,"[STM_032, STM_009, STM_021, STM_046, STM_020]",[STM014],0.0,0.0,5
4,STM_2,stm,What was I supposed to do in the morning?,"[{'document_id': 'STM_011', 'timestamp': '2025...",[{'content': 'I made toast and eggs for breakf...,"[STM_007, STM_001, STM_015, STM_047, STM_049]",[STM_011],0.0,0.0,5
5,STM_3,stm,When was the health talk again?,"[{'document_id': 'STM_052', 'timestamp': '2025...",[{'content': 'I heard there will be a health t...,"[STM_052, STM_001, STM_037, STM_016, STM_023]",[STM_052],0.2,1.0,5
6,Temporal_1,stm,What should I eat now?,"[{'document_id': 'STM_010', 'timestamp': '2025...",[{'content': 'I want to eat salad for lunch to...,"[STM_032, STM_009, STM_046, STM_001, STM_051]","[STM_010, STM_014, STM_003]",0.0,0.0,5
7,MultiMem_1,ltm,Should I bring someone for my morning walk tom...,"[{'document_id': 'LTM_004', 'category': 'lifes...","[{'category': 'lifestyle', 'key': 'routine', '...","[LTM_009, LTM_004, LTM_014, LTM_008, LTM_007]","[LTM_004, LTM_003]",0.2,0.5,5
8,MultiMem_3,ltm,What’s my usual breakfast?,"[{'document_id': 'LTM_008', 'category': 'lifes...","[{'category': 'lifestyle', 'key': 'routine', '...","[LTM_008, LTM_004, LTM_014, LTM_001, LTM_012]",[LTM_008],0.2,1.0,5
9,LTM_3,ltm,Who is my grandson?,"[{'document_id': 'LTM_011', 'category': 'famil...","[{'category': 'family', 'key': 'Ganesh Menon D...","[LTM_002, LTM_001, LTM_012, LTM_005, LTM_007]",[LTM_011],0.0,0.0,5


#### Calculate average context precision and context recall

In [32]:
average_context_precision = all_results['context_precision'].mean()
average_context_recall = all_results['context_recall'].mean()

print(f"Average context precision: {average_context_precision}")
print(f"Average context recall: {average_context_recall}")

Average context precision: 0.14285714285714285
Average context recall: 0.5357142857142857
