##### 1. Normalizer Function
- lowercasing: "Converting all text to lowercase is a
simple normalization technique that helps in reducing
the vocabulary size by treating words in uppercase
and lowercase as the same"
- removing punctuation: "Punctuation marks often do
not carry significant meaning in many text processing
tasks. Removing punctuation marks can help simplify
the text and reduce noise in the data."
- spelling check: using 'pyspellchecker'
- medical synonyms
- acronym expansion

Reference: [Systematic Review on Text Normalization Techniques and its Approach to Non-Standard Words](https://www.researchgate.net/publication/374166354_Systematic_Review_on_Text_Normalization_Techniques_and_its_Approach_to_Non-Standard_Words)

In [None]:
import re
from spellchecker import SpellChecker
# Module-level spell checker instance, built with the library defaults;
# normalize_query calls its correction() method for every token.
spell = SpellChecker()

# Lookup table that maps lowercase clinical abbreviations, spelling variants
# and synonyms (keys) to the single term each should be rewritten to (values).
# Several keys deliberately share one value so that different wordings end up
# as the same term.
#
# In the code visible here the only consumer is the expansion step inside
# normalize_query, and that step is commented out, so the table currently has
# no effect on the normalized output.
#
# Entry order matters: sorted_keys is derived from this dict with a stable
# length sort, so keys of equal length keep the order they are written in.
MEDICAL_LEXICON = {
    # == General Medical Acronyms (Common) ==
    "hx": "history",
    "h/o": "history of",
    "dx": "diagnosis",
    "ddx": "differential diagnosis",
    "tx": "treatment",
    "sx": "symptoms",
    "c/o": "complains of",
    "pt": "patient",

    # == General Demographics ==
    # NOTE(review): normalize_query removes '.' before any lookup could run,
    # so the "y.o." key looks unreachable - confirm.
    # NOTE(review): one-letter keys ("m", "f" here; "t", "k" further down)
    # would match any standalone letter under whole-word matching - confirm
    # this is wanted before the expansion step is enabled.
    "yo": "year old",
    "y/o": "year old",
    "y.o.": "year old",
    "m": "male",
    "f": "female",
    
    # == From SYMPTOM_LEXICON (Cell 22) ==
    "pyrexia": "fever",
    "febrile": "fever",
    "cephalgia": "headache",
    "maculopapular": "rash",
    "petechiae": "rash",
    "urticaria": "rash",
    "vesicular": "rash",
    "diarrhoea": "diarrhea",
    "loose stools": "diarrhea",
    "emesis": "vomiting",
    "abd pain": "abdominal pain",
    "stomach pain": "abdominal pain",
    "icterus": "jaundice",
    "muscle pain": "myalgia",
    "joint pain": "arthralgia",
    "shortness of breath": "dyspnea",
    "itch": "pruritus",
    "blood in urine": "hematuria",
    "melaena": "melena",
    "red eyes": "conjunctivitis",
    "eye bulging": "proptosis",
    "hemorrhage": "bleeding",
    "haemorrhage": "bleeding",
    "gum bleeding": "bleeding",
    "altered mental state": "confusion",
    "swollen nodes": "lymphadenopathy",
    "hepatomegaly": "hepatosplenomegaly",
    "splenomegaly": "hepatosplenomegaly",
    "eschar": "ulcer",
    "chancre": "ulcer",
    "plaque": "lesion",
    "nodule": "lesion",
    "papule": "lesion",
    "pustule": "lesion",
    "boil": "itchy boil",
    "furuncle": "itchy boil",
    "myiasis": "itchy boil",

    # == From VITALS_PATTERNS (Cell 22) ==
    # NOTE(review): "saturation" appears both as a key and inside its own
    # replacement text; check for doubled words once expansion is enabled.
    "temp": "temperature",
    "t": "temperature",
    "hr": "heart rate",
    "pulse": "heart rate",
    "bp": "blood pressure",
    "rr": "respiratory rate",
    "spo2": "oxygen saturation",
    "sat": "oxygen saturation",
    "saturation": "oxygen saturation",
    "ht": "height",
    "wt": "weight",

    # == From LAB_PATTERNS (Cell 22) ==
    "hb": "hemoglobin",
    "haemoglobin": "hemoglobin",
    "wbc": "white blood cell",
    "white blood cells": "white blood cell",
    "plt": "platelet",
    "platelets": "platelet",
    "crp": "c-reactive protein",
    "esr": "erythrocyte sedimentation rate",
    "alt": "alanine aminotransferase",
    "ast": "aspartate aminotransferase",
    "na": "sodium",
    "k": "potassium",

    # == From IMAGING_KEYS (Cell 22) ==
    # "mri" maps to itself, i.e. it is listed but never rewritten.
    "cxr": "chest x-ray",
    "u/s": "ultrasound",
    "ct": "ct scan",
    "mri": "mri",

    # == From MICRO_PATTERNS (Cell 22) ==
    "rdt": "rapid diagnostic test",

    # == From DIAG_KEYS (Cell 22) ==
    "definitive diagnosis": "final diagnosis",
    "impression": "provisional diagnosis",
    "differential diagnoses": "differential diagnosis",
    "differentials": "differential diagnosis"
}

# Lexicon keys ordered from longest to shortest, so a multi-word phrase such
# as "loose stools" is tried before any shorter key. Iterating the dict
# yields its keys; the sort is stable, so equal-length keys keep dict order.
sorted_keys = sorted(MEDICAL_LEXICON, key=len, reverse=True)

def normalize_query(query: str) -> str:
    """
    Normalize a free-text user query (task PADT-15).

    Steps, in order:
      1. Lowercase the text.
      2. Remove every character that is not a word character, whitespace,
         hyphen, slash, colon or semicolon.
      3. Collapse runs of whitespace to one space and trim both ends.
      4. Spell-correct each whitespace-separated token with the module-level
         ``spell`` checker; a token is kept as-is when the checker returns
         nothing for it.

    Acronym/synonym expansion from MEDICAL_LEXICON is not applied; that step
    is left commented out at the end of the function.

    Args:
        query: Raw query text.

    Returns:
        The normalized query as a single space-separated string.
    """
    # 1. Lowercasing and punctuation removal.
    text = query.lower()
    # The hyphen must be escaped: written bare right after ``\s`` it is parsed
    # as a range operator and ``re`` rejects the whole pattern with
    # "bad character range \s-/", so the function could never run.
    text = re.sub(r'[^\w\s\-/:;]', '', text)
    text = re.sub(r'\s+', ' ', text).strip()

    # 2. Spelling correction, token by token. ``or word`` keeps the original
    # token when correction() gives back None / an empty result.
    corrected_words = [spell.correction(word) or word for word in text.split()]
    text = ' '.join(corrected_words)

    # 3. Acronym & synonym expansion - currently disabled.
    # for term in sorted_keys:
    #     expansion = MEDICAL_LEXICON[term]
    #     text = re.sub(r'\b' + re.escape(term) + r'\b', expansion, text)

    return text

In [20]:
# Smoke test: two misspelled words and a trailing full stop.
normalize_query("paitent with fevear.")

'patient with fever'

##### 2. Define "Ground Truth"
For each of the 50 questions, find the correct 'case_id' in the Qdrant database that answers that question (this 'ground_truth_id' is the answer key).
Steps:
- Load the ColPali model to embed queries
- Connect to the Qdrant client
- Read dev_set.csv
- Calculate recall@k

In [None]:
import pandas as pd
from qdrant_client import QdrantClient
from transformers import ColPaliForRetrieval, ColPaliProcessor
import torch

# 1. LOAD COLPALI MODEL
# Use the first CUDA device when torch reports one, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Load the checkpoint, switch it to eval mode, then move it to the device.
model = ColPaliForRetrieval.from_pretrained("vidore/colpali-v1.2-hf")
model = model.eval().to(device)
# The matching processor is loaded from the same checkpoint name.
processor = ColPaliProcessor.from_pretrained("vidore/colpali-v1.2-hf")

@torch.no_grad()
def embed_query(query: str):
    """
    Embed a single query string with the module-level ColPali model.

    The query is passed to ``processor`` as a one-element list, the resulting
    batch is moved to ``device`` and fed to ``model``. Gradient tracking is
    switched off by the decorator.

    Returns:
        The first entry of the model output's ``embeddings``, moved to the
        CPU, cast to float and converted with ``tolist()``.
    """
    inputs = processor(text=[query])
    inputs = inputs.to(device)
    outputs = model(**inputs)
    # The batch contains exactly one query, so index 0 is its embedding.
    first_embedding = outputs.embeddings[0]
    return first_embedding.to("cpu").float().tolist()

# 2. CONNECT TO QDRANT
# The server address can be overridden with the QDRANT_HOST / QDRANT_PORT
# environment variables, so the notebook can point at another instance
# without editing code. The defaults are the previously hard-coded values.
import os

QDRANT_HOST = os.environ.get("QDRANT_HOST", "165.22.56.15")
QDRANT_PORT = int(os.environ.get("QDRANT_PORT", "6333"))
client = QdrantClient(host=QDRANT_HOST, port=QDRANT_PORT)
COLLECTION_NAME = "tropical_cases_colpali_cases"
K_VALUE = 5  # This is the 'k' in recall@k

# 3. LOAD DEV SET
# dev_set must provide the "query" and "ground_truth_id" columns read by the
# experiment cells.
try:
    dev_set = pd.read_csv("./dev_set.csv")
except FileNotFoundError as err:
    # Stop here with the message on stderr and a non-zero exit status.
    # The bare exit() used before is the interactive helper installed by the
    # site module; called without an argument it reports success (status 0)
    # even though loading failed.
    raise SystemExit("Error: dev_set.csv not found. Please create it.") from err

def run_evaluation(query_list, ground_truth_list, k=K_VALUE):
    """
    Compute recall@k for a list of queries against the Qdrant collection.

    Each query is embedded with ``embed_query`` and searched in
    ``COLLECTION_NAME``. A query counts as a hit when ``int(truth_id)`` equals
    the ``case_id`` payload value of any of the returned results. One
    diagnostic line is printed per query.

    Args:
        query_list: Query strings to evaluate.
        ground_truth_list: Expected case id for each query, in the same order
            as ``query_list``.
        k: Number of search results requested per query.

    Returns:
        Number of hits divided by the number of queries, or 0.0 when
        ``query_list`` is empty.

    Raises:
        ValueError: If the two lists differ in length.
    """
    total = len(query_list)

    # zip() would silently drop the unmatched tail while ``total`` still
    # counted every query, which would skew the score.
    if total != len(ground_truth_list):
        raise ValueError(
            "query_list and ground_truth_list must have the same length "
            f"(got {total} and {len(ground_truth_list)})"
        )

    # Nothing to evaluate: avoid dividing by zero below.
    if total == 0:
        return 0.0

    score = 0
    for query, truth_id in zip(query_list, ground_truth_list):
        # 1. Embed the query
        query_vector = embed_query(query)

        # 2. Search Qdrant
        # NOTE(review): newer qdrant-client releases appear to deprecate
        # `search` in favour of `query_points` - confirm against the
        # installed version.
        search_results = client.search(
            collection_name=COLLECTION_NAME,
            query_vector=query_vector,
            limit=k,
        )

        # 3. Check for a match
        # NOTE(review): assumes payload case_id values are ints; a string id
        # would never equal int(truth_id) - confirm against the collection.
        retrieved_ids = [hit.payload["case_id"] for hit in search_results]
        found = int(truth_id) in retrieved_ids
        if found:
            score += 1

        print(f"Query: '{query}' -> Found: {found} (Truth: {truth_id}, Got: {retrieved_ids})")

    # 4. Calculate final score
    return score / total

##### 3. Run Experiment A (Baseline)

In [None]:
# --- Pull the query texts and their expected case ids out of the dev set ---
raw_queries, ground_truth = (
    dev_set[column].tolist() for column in ("query", "ground_truth_id")
)

# Baseline run: the queries go to the retriever exactly as written.
print("--- RUNNING EXPERIMENT A: RAW QUERIES ---")
baseline_recall = run_evaluation(raw_queries, ground_truth)
print(f"\n[!] BASELINE RECALL@{K_VALUE}: {baseline_recall:.2%}\n")

##### 4. Run Experiment B (Normalized)
Now run the same test, but with every query passed through the normalize_query function first.

In [None]:
# --- Run every raw query through normalize_query, keeping the order ---
normalized_queries = list(map(normalize_query, raw_queries))

# Same evaluation and ground truth as the baseline, only the queries differ.
print("--- RUNNING EXPERIMENT B: NORMALIZED QUERIES ---")
normalized_recall = run_evaluation(normalized_queries, ground_truth)
print(f"\n[!] NORMALIZED RECALL@{K_VALUE}: {normalized_recall:.2%}\n")

##### 5. Compare Results
Finally, check whether the "Done when" criterion has been met.

In [None]:
# --- COMPARE RESULTS ---
# Print both scores, then a verdict: success only when the normalized run
# scored strictly higher than the baseline (a tie counts as a failure).
print("--- TASK COMPLETE? ---")
print(f"Baseline Recall:   {baseline_recall:.2%}")
print(f"Normalized Recall: {normalized_recall:.2%}")

print(
    "\n✅ SUCCESS: Recall@k improved. Task PADT-15 is complete."
    if normalized_recall > baseline_recall
    else "\n❌ FAILED: Recall@k did not improve. Go back and add more terms to your MEDICAL_LEXICON."
)