In [None]:
from modularized import Model
import os
import jsonlines

# Make sure a Together API key is available to the rest of the notebook.
# setdefault keeps a key that is already exported in the environment instead of
# clobbering it; "random" is only a placeholder fallback, not a usable credential.
os.environ.setdefault("TOGETHER_API_KEY", "random")

Select Dataset

In [None]:
# Bucket patient IDs by question level ('category') and by number of notes.

level1_note1_pids = []
level1_note2_pids = []
level2_note1_pids = []
level2_note2_pids = []
level2_note3_pids = []

# category -> num_notes -> destination list.
# Any (category, num_notes) combination not listed here is ignored.
pid_buckets = {
    'level1': {
        1: level1_note1_pids,
        2: level1_note2_pids,
    },
    'level2': {
        1: level2_note1_pids,
        2: level2_note2_pids,
        3: level2_note3_pids,
    },
}

with jsonlines.open('EHRNoteQA.jsonl') as reader:
    # Walk every QA record and file its patient ID into the matching bucket
    for obj in reader:
        pid = obj["patient_id"]
        buckets_for_level = pid_buckets.get(obj['category'])
        if buckets_for_level is None:
            continue
        destination = buckets_for_level.get(obj['num_notes'])
        if destination is not None:
            destination.append(pid)

# Report how many patient IDs landed in each bucket, one count per line
for bucket in (
    level1_note1_pids,
    level1_note2_pids,
    level2_note1_pids,
    level2_note2_pids,
    level2_note3_pids,
):
    print(len(bucket))


In [None]:
# Build the train / validation / test splits from the PID buckets.

def _take_split(split_index):
    """Return the PIDs of split number `split_index` (0 = train, 1 = validation, 2 = test).

    Each split takes 2 PIDs from each of the four 1-note / 2-note buckets and
    4 PIDs from the 3-note bucket, i.e. 4 patients with 1 note, 4 with 2 notes
    and 4 with 3 notes when the buckets are large enough.
    """
    pair_start = 2 * split_index
    quad_start = 4 * split_index
    return (
        level1_note1_pids[pair_start:pair_start + 2]
        + level2_note1_pids[pair_start:pair_start + 2]
        + level1_note2_pids[pair_start:pair_start + 2]
        + level2_note2_pids[pair_start:pair_start + 2]
        + level2_note3_pids[quad_start:quad_start + 4]
    )

train_set = _take_split(0)
validation_set = _take_split(1)
test_set = _take_split(2)
dataset = train_set + validation_set + test_set

# Show each split followed by its size, then the combined dataset
for split in (train_set, validation_set, test_set, dataset):
    print(split)
    print(len(split))


Load model 

In [None]:
# Model system components
# Data:       Loads initial patient database from file + preprocesses by chunking notes
# Embedder:   Loads embedding model + can embed data
# Storage:    Populates vector databases w/ patient data + can query indices
# Generator:  Loads LLM model + can generate responses (untested!)

# The dataset location can be overridden with the CLINICAL_RAG_DATA_PATH environment
# variable so the notebook runs on other machines; the original local path stays the default.
data_path = os.environ.get(
    "CLINICAL_RAG_DATA_PATH",
    "C:/Users/sharp/Documents/Research/Adelaide/clinical_RAG/dataset.csv",
)
embed_model = "ClinicalBERT"
llm_model = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
# Raises KeyError if the key was never set (the first cell of the notebook sets it)
api_key = os.environ["TOGETHER_API_KEY"]

model = Model(data_path, embed_model, api_key, llm_model)

Dataset Overview 

In [None]:
# Print dataset stats by calling print_stats() on the model's data component
# (what exactly is reported is defined by print_stats, which is not shown in this notebook)
model.data.print_stats()

In [None]:
# Calculates precision, recall, and F1 given per-sample counts of retrieved + relevant (TP),
# retrieved (TP+FP), and relevant (TP+FN) chunks, then prints the averages over all samples.
def print_metrics(correct_counts, total_retrieved_counts, total_correct_counts, total_counts):
    """Print macro-averaged precision, recall and F1 (as percentages).

    The four input lists are echoed first, then one (precision, recall, F1)
    triple is computed per sample and the per-sample values are averaged.

    Args:
        correct_counts: per-sample number of retrieved chunks that are relevant (TP).
        total_retrieved_counts: per-sample number of retrieved chunks (TP+FP).
        total_correct_counts: per-sample number of relevant chunks (TP+FN).
        total_counts: per-sample total number of chunks; only echoed, not used
            in any metric.

    Returns:
        None. All results are written to stdout.

    Edge cases:
        * nothing retrieved        -> precision is 100 (0/0 treated as perfect)
        * nothing relevant, TP = 0 -> recall is 100 (0/0 treated as perfect)
        * nothing relevant, TP > 0 -> inconsistent counts: "error" is printed and
          recall is counted as 0 for that sample
        * precision + recall == 0  -> F1 is 0 (avoids a 0/0 division)
        * no samples at all        -> a notice is printed and no averages are computed
    """
    print(correct_counts)
    print(total_retrieved_counts)
    print(total_correct_counts)
    print(total_counts)

    num_samples = len(correct_counts)

    # Averages are undefined without samples; bail out instead of dividing by zero.
    if num_samples == 0:
        print("\nNo samples to evaluate.\n")
        return

    metrics = []

    for i in range(num_samples):

        # Get counts for sample
        correct = correct_counts[i]
        total_retrieved = total_retrieved_counts[i]
        total_correct = total_correct_counts[i]

        # If no chunks retrieved -> 100 precision (0/0)
        # If chunks retrieved    -> calculate precision (?/#)
        if total_retrieved == 0:
            prec = 100
        else:
            prec = (correct/total_retrieved) * 100

        # If no chunks relevant, and none retrieved as relevant -> 100 recall (0/0)
        # If no chunks relevant, yet some counted as relevant   -> error (#/0)
        # If chunks relevant                                    -> calculate recall (?/#)
        if total_correct == 0:
            if correct == 0:
                rec = 100
            else:
                print("error")
                # Counts contradict each other; score the sample as 0 recall rather
                # than reusing a stale value or failing on an unbound variable.
                rec = 0
        else:
            rec = (correct/total_correct) * 100

        # Harmonic mean is 0/0 when both precision and recall are 0; define F1 as 0 then.
        if prec + rec == 0:
            f1 = 0
        else:
            f1 = 2*(prec * rec)/(prec + rec)

        # Record this sample's precision, recall, f1
        metrics.append((prec, rec, f1))

    avg_p = sum(m[0] for m in metrics) / num_samples
    avg_r = sum(m[1] for m in metrics) / num_samples
    avg_f1 = sum(m[2] for m in metrics) / num_samples

    print(f"\nAverage precision: {avg_p}")
    print(f"Average recall: {avg_r}")
    print(f"F1: {avg_f1}\n")

In [None]:
def evaluate_set(pids, threshold_eval, threshold_search, qa_path='EHRNoteQA.jsonl'):
    """Run retrieval for every QA of the given patients and print retrieval metrics.

    For each record in the QA file whose "patient_id" is in `pids`, the question
    is used to query the notebook-level `model.storage` index, and the retrieved
    chunks are scored with `evaluate_retrieved` / `evaluate_relevant`. The
    per-question counts are then passed to `print_metrics` twice: once with the
    counts computed over all answer choices and once with the counts computed
    for the correct answer only.

    Relies on the globals `model`, `jsonlines` and `print_metrics`.

    Args:
        pids: collection of patient IDs to evaluate (tested with `in`).
        threshold_eval: passed through to `evaluate_retrieved` and
            `evaluate_relevant`.
        threshold_search: passed through to `model.storage.search_index`.
        qa_path: path of the JSON-lines QA file; defaults to the file the
            notebook has always used.

    Returns:
        None. All results are printed.
    """
    tp_all = []
    tp_correct = []

    retrieved = []
    total_chunks = []

    relevant_all = []
    relevant_correct = []

    with jsonlines.open(qa_path) as reader:
        # Cycle through QAs corresponding with subset
        for obj in reader:
            pid = obj["patient_id"]
            if pid not in pids:
                continue

            query = obj["question"]
            answer = obj["answer"]
            answer_choices = {
                'A': obj["choice_A"],
                'B': obj["choice_B"],
                'C': obj["choice_C"],
                'D': obj["choice_D"],
                'E': obj["choice_E"],
            }

            # Step 1:
            # Retrieve patient data using query
            retrieved_chunks = model.storage.search_index(pid, query, threshold_search)

            # Evaluate (per the original note, evaluate_retrieved also prints by default).
            # `hits_all` was previously named `all`, which shadowed the built-in.
            hits_all, hits_correct, total_retrieved = model.storage.evaluate_retrieved(
                pid, retrieved_chunks, answer_choices, answer, threshold_eval
            )
            total_all, total_correct, total = model.storage.evaluate_relevant(
                pid, answer_choices, answer, threshold_eval
            )

            tp_all.append(hits_all)
            tp_correct.append(hits_correct)
            retrieved.append(total_retrieved)
            relevant_all.append(total_all)
            relevant_correct.append(total_correct)
            total_chunks.append(total)

            # Step 2 (TODO):
            # Generate output using retrieved patient data and query
            # (confirm if this works first in normal pipeline)
            # answer = model.generator.generate_response_from_chunks(query, retrieved)
            # print("== RESPONSE:==")
            # print(answer)

            # Evaluate generated output

    print('Retrieval Metrics if all answers are considered relevant:\n')
    print_metrics(tp_all, retrieved, relevant_all, total_chunks)

    print('Retrieval Metrics if only correct answer is considered relevant:\n')
    print_metrics(tp_correct, retrieved, relevant_correct, total_chunks)


        
        




In [None]:
# TODO: need to tune after doing some data exploration
# Previously tried: threshold_eval=0.85, threshold_search=55
# Current setting, applied to the train split first and the validation split second:
for pid_subset in (train_set, validation_set):
    evaluate_set(pid_subset, 0.80, 60)

In [None]:
# Evaluate the test split with threshold_eval=0.80 and threshold_search=60
evaluate_set(test_set, .80, 60)

Generation 