# Day 2: Pathology Identification with off-the-shelf LLMs and RAG Augmentation

- Task: Identify pathologies from radiology reports

## Details

- Input: Raw radiology report sections (findings sections) and RAG-retrieved examples
- Output: Predicted pathology DISEASE_COLUMNS (multi-label classification)

In [20]:
import os
from pathlib import Path

# URL:PORT must be identical to what is set in LM Studio!
HOST_URL = "http://localhost:1235"

# model name as served by LM Studio for classification and 
# inference of embeddings for vectorstore
MODEL = "lmstudio-community/medgemma-4b-it-GGUF"

# model supporting embedding endpoint for vectorized embeddings
EMBED_MODEL = "amsaravi/MedEmbed-large-v0.1.gguf"

# path to logged results
Y_PRED_LLM_W_RAG_CACHED = Path('log') / 'y_pred' / 'y_pred_llm_rag.csv'

# create parent folder if not existing
Y_PRED_LLM_W_RAG_CACHED.parent.mkdir(parents=True, exist_ok=True)

## Load Data Splits

In [21]:
import json
import pandas as pd

## Load splits
def load_test_splits():
    data_path = Path("data")
    X_test = pd.read_csv(data_path / "X_test.csv")
    y_test = pd.read_csv(data_path / "y_test.csv")
    print(f"X_test dim:\t{X_test.shape}\ty_test dim:\t{y_test.shape}")
    return X_test, y_test

## Load splits
# TODO: generalize load_test_splits to take the split () as an argument and return either set
def load_data_splits(split:str):
    assert split in ['test', 'train']
    data_path = Path("data")
    X = pd.read_csv(data_path / f"X_{split}.csv")
    y = pd.read_csv(data_path / f"y_{split}.csv")
    print(f"X_{split} dim:\t{X.shape}\ty_{split} dim:\t{y.shape}")
    return X, y

X_test, y_test = load_test_splits()

X_test, y_test = load_data_splits(split='test')
X_train, y_train = load_data_splits(split='train')

X_test dim:	(9603, 2)	y_test dim:	(9603, 14)
X_test dim:	(9603, 2)	y_test dim:	(9603, 14)
X_train dim:	(38601, 2)	y_train dim:	(38601, 14)


In [22]:
X_test.head()

Unnamed: 0,deid_patient_id,section_findings
0,patient04528,Unchanged position of the left upper extremity...
1,patient04986,12 mm focal density in the region of the anter...
2,patient05496,A single AP upright view of the chest taken on...
3,patient05496,"Compared to prior chest x-ray on 1-16-2019, PA..."
4,patient05496,Compared to prior chest x-ray on February 15th...


In [23]:
X_test['section_findings'].values[0]

'Unchanged position of the left upper extremity PICC line. Again seen \nare surgical clips projecting over the right hemithorax. The \ncardiomediastinal silhouette is stable in appearance. Increased \nstranding opacities are noted in the left retrocardiac region. Subtle \nstranding opacities in the right upper lung zone are unchanged.. \nThere are no pleural or significant bony abnormalities. Absence of \nthe right breast shadow compatible with prior mastectomy.'

In [24]:
y_test.head()

Unnamed: 0,Enlarged Cardiomediastinum,Cardiomegaly,Lung Opacity,Lung Lesion,Edema,Consolidation,Pneumonia,Atelectasis,Pneumothorax,Pleural Effusion,Pleural Other,Fracture,Support Devices,No Finding
0,0,0,1,0,0,0,0,0,0,0,0,0,1,0
1,0,0,0,0,0,0,0,0,0,0,0,0,0,0
2,0,1,0,0,1,0,0,0,0,0,0,0,0,0
3,0,1,1,0,0,0,0,0,0,1,0,0,0,0
4,0,1,1,0,0,0,0,0,0,1,0,0,0,0


In [25]:
DISEASE_COLUMNS = y_test.columns

## Evaluation Function
For a multi-class, multi-label problem (where true negative (TN) counts are not sensible) suitable metrics are
- precision (fraction of correctly capture TPs: TP/(TP + FP))
- recall (fraction of recalled TPs: TP/P)
- F1 (harmonic mean of precision, recall)

There are three distinct strategies on how to combine per class performance:
1. micro - global pooling of TP, FP, FN (global picture, bias towards majority classes)
2. macro - per class scores are averaged (no bias, minority class sensitivity)
3. weighted - per class scores are weighted and averaged (bias towards large classes, moderate impact of minor classes)

In [26]:
from sklearn.metrics import f1_score, precision_score, recall_score

def compute_scores(y_true:pd.DataFrame, y_pred:pd.DataFrame, average:str='micro'):
    # average methods: 
    #  micro - global pooling of TP, FP, FN
    #  macro - per class scores are averaged 
    #  weighted - per class scores are weighted and averaged
    # Ensure identically ordered columns and numerical type
    y_true = y_true[DISEASE_COLUMNS].astype(int)
    y_pred = y_pred[DISEASE_COLUMNS].astype(int)


    f1 = f1_score(y_true, y_pred, average=average)
    precision = precision_score(y_true, y_pred, average=average)
    recall = recall_score(y_true, y_pred, average=average)
    return pd.DataFrame({f"{average}-F1": [f1], 
                        f"{average}-Precision": [precision],
                        f"{average}-Recall": [recall]})


## Build Vectorstore from Train Set
### Client for Embedding Retrieval

In [27]:
import requests
import numpy as np
from pathlib import Path

class LMStudioEmbeddingClient:
    def __init__(self, base_url=HOST_URL, model_name=EMBED_MODEL):
        self.base_url = base_url.rstrip("/")
        self.model_name = model_name
        self.batch_id_state = Path("current_batch_id.txt")


    def encode_batch(self, texts):
        """One-shot encoding for small lists (e.g. single query)."""
        if isinstance(texts, str):
            texts = [texts]

        payload = {"model": self.model_name, "input": texts}
        resp = requests.post(f"{self.base_url}/v1/embeddings", json=payload)
        resp.raise_for_status()
        data = resp.json()
        emb = np.array(
            [item["embedding"] for item in data["data"]],
            dtype="float32",
        )
        return emb  # shape (len(texts), dim)

    def encode(self, texts, batch_size=128):
        """
        Generator: yields one np.ndarray of embeddings per batch.
        Resumes from the last completed batch using current_batch_id.txt.
        """
        # Normalize input
        if isinstance(texts, str):
            texts = [texts]

        # Determine starting batch index
        start_batch_id = 0
        if self.batch_id_state.exists():
            with open(self.batch_id_state, "r") as f:
                line = f.readline().strip()
                if line:
                    start_batch_id = int(line) + 1  # resume AFTER last finished batch

        indices = list(range(0, len(texts), batch_size))

        for batch_id, start in enumerate(indices):
            if batch_id < start_batch_id:
                continue  # skip already processed batches

            end = start + batch_size
            print(f"INFO\tProcessing batch {batch_id} ({start}:{end}) "
                  f"of total {len(texts)} items.")

            batch = texts[start:end]
            payload = {"model": self.model_name, "input": batch}
            resp = requests.post(f"{self.base_url}/v1/embeddings", json=payload)
            resp.raise_for_status()
            data = resp.json()
            batch_emb = np.array(
                [item["embedding"] for item in data["data"]],
                dtype="float32",
            )

            # Persist last completed batch id
            with open(self.batch_id_state, "w") as f:
                f.write(str(batch_id))

            yield batch_emb


### Vectorstore

In [28]:

import faiss
import pickle
import os
from pathlib import Path
import numpy as np
from tqdm.notebook import tqdm


class Vectorstore:
    def __init__(self, X_train, y_train, embedding_client, cache_dir="vectorstore"):
        """
        Initialize Vectorstore with training data and an embedding client.

        Args:
            X_train (pd.DataFrame): Must have columns ['deid_patient_id', 'section_findings'].
            y_train (np.ndarray or pd.DataFrame): Multi-hot encoded or equivalent label matrix.
            embedding_client: Object with an .encode(list[str]) -> np.ndarray method (e.g. LMStudioEmbeddingClient).
        
            cache_dir (str): Directory for cached FAISS index and metadata.
        """
        self.X_train = X_train
        self.y_train = y_train
        self.embedding_client = embedding_client

        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(exist_ok=True, parents=True)

        # File paths
        self.index_file = self.cache_dir / "faiss_index.bin"
        self.metadata_file = self.cache_dir / "metadata.pkl"

        # Placeholders
        self.faiss_index = None
        self.train_texts = None
        self.train_labels = None

        # Build vectorstore
        # self.batch_id = 0
        
        if self.faiss_index is None:
            self.build_vector_store(self.X_train["section_findings"].tolist(), self.y_train)
            

    def retrieve_similar_cases(self, query_text, k=5):
        """Retrieve k most similar cases; build index if not available."""
    
        # One-shot embedding for a single query
        query_embedding = self.embedding_client.encode_batch([query_text])  # (1, dim)
        query_embedding = np.asarray(query_embedding, dtype="float32")
        faiss.normalize_L2(query_embedding)
    
        scores, indices = self.faiss_index.search(query_embedding, k)
    
        similar_texts = [self.train_texts[i] for i in indices[0]]
        similar_labels = [self.train_labels.iloc[i] for i in indices[0]]
        similar_scores = scores[0]
    
        return similar_texts, similar_labels, similar_scores

    

    def build_vector_store(self, texts, labels, force_rebuild=False, batch_size=128):
        """Build or extend FAISS index incrementally from batches, persisting after each batch."""
        if self._cache_exists() and not force_rebuild:
            print("Loading FAISS index from cache and appending...")
            self._load_cache()
        else:
            print("Building new FAISS index from reference set...")
            self.faiss_index = None
            self.train_texts = []
            self.train_labels = None
    
        # Append new texts/labels to metadata
        self.train_texts.extend(texts)
        if self.train_labels is None:
            self.train_labels = labels
        else:
            # self.train_labels = self.train_labels.append(labels, ignore_index=True)
            self.train_labels = pd.concat([self.train_labels, labels], ignore_index=True)
    
        # Stream embeddings in batches and add to FAISS, persisting each time
        for batch_emb in self.embedding_client.encode(texts, batch_size=batch_size):
            faiss.normalize_L2(batch_emb)
    
            if self.faiss_index is None:
                dim = batch_emb.shape[1]
                self.faiss_index = faiss.IndexFlatIP(dim)
    
            self.faiss_index.add(batch_emb)
    
            # Persist index + metadata after each batch
            self._save_cache()
            print(f"STATUS\tPersisted index with {self.faiss_index.ntotal} vectors.")
    
        print(f"STATUS\tIndex now has {len(self.train_texts)} entries.")
    
    
    def _cache_exists(self):
        return self.index_file.exists() and self.metadata_file.exists()


    def _save_cache(self):
        """Save FAISS index and metadata; embeddings are stored inside faiss_index."""
        faiss.write_index(self.faiss_index, str(self.index_file))
        metadata = {
            "train_texts": self.train_texts,
            "train_labels": self.train_labels,
            "ntotal": self.faiss_index.ntotal,
            # optionally store model id instead of path if you switched to LM Studio
            # "embedding_model_id": self.embedding_model_id,
        }
        with open(self.metadata_file, "wb") as f:
            pickle.dump(metadata, f)


    def _load_cache(self):
        self.faiss_index = faiss.read_index(str(self.index_file))
        with open(self.metadata_file, 'rb') as f:
            metadata = pickle.load(f)
        self.train_texts = metadata['train_texts']
        self.train_labels = metadata['train_labels']
        print(f"STATUS\tLoaded FAISS index with {len(self.train_texts)} embeddings.")

# LM Studio model id as configured in the app
embedding_client = LMStudioEmbeddingClient(
    base_url=HOST_URL,  # LM Studio API URL
    model_name=EMBED_MODEL,
)

vectorstore = Vectorstore(
    X_train,
    y_train,
    embedding_client=embedding_client,
    cache_dir="vectorstore"
)


Loading FAISS index from cache and appending...
STATUS	Loaded FAISS index with 77202 embeddings.
STATUS	Index now has 115803 entries.


## Label Extractor

In [29]:
import re
import json
import warnings


def extract_json_or_list(text_with_json: str):
    # Regex matches both lists ([...]) and dicts ({...})
    json_rx = re.compile(r"(\{.*?\}|\[.*?\])", re.DOTALL)
    matches = json_rx.findall(text_with_json)
    if not matches:
        warnings.warn(f"Could not extract JSON/list block: {text_with_json}")
        return None
    last_json = matches[-1]
    # Try to parse as JSON
    try:
        parsed = json.loads(last_json)
        return parsed
    except json.JSONDecodeError as e:
        warnings.warn(
            f"Could not decode JSON/list: {e}\nRaw block: {last_json}"
        )
        return None

def cleanse_to_multihot(json_or_list, all_labels=DISEASE_COLUMNS):
    # Case 1: 0/1 or True/False dict
    if isinstance(json_or_list, dict):
        filtered_pred = {}
        for label in all_labels:
            value = json_or_list.get(label, 0)
            if isinstance(value, (int, float, bool)):
                filtered_pred[label] = 1 if value else 0
            elif isinstance(value, str):
                filtered_pred[label] = 1 if value.lower() in {'1', 'true', 'yes'} else 0
            else:
                filtered_pred[label] = 0
        return filtered_pred
    # Case 2: list of strings = present labels only
    elif isinstance(json_or_list, list):
        return {label: 1 if label in json_or_list else 0 for label in all_labels}
    # Unrecognized
    else:
        warnings.warn("Unknown prediction format. Returning None.")
        return None


In [46]:
from openai import OpenAI
import os
from datetime import datetime
from pathlib import Path
from typing import List


class GenerativeLLMClassifier():
    def __init__(self, model:str='llm'):
        
        self.timestamp = datetime.now().strftime("%Y%m%d_%H%M")
        # self.log_path = Path("..") / "log"
        model_name = model.split("/")[-1]
        self.log_path = Path('log') / model_name

        self.log_path_set = False
        self.client = OpenAI(
            base_url=f"{HOST_URL}/v1",
            api_key='dummy'
        )

    
    def build_prompt(self, query_text, similar_examples=None, k:int=5):
        if similar_examples is None:
            similar_examples = []
        # Create prompt optionally RAG-augmented with retrieved examples
     
        prompt = f"""You are a radiology AI assistant. Classify the following medical text for pathologies.
        ### Task
        Determine which of these pathologies are present: [{', '.join(DISEASE_COLUMNS)}]
        """
    
        if similar_examples:
            # Add retrieved examples
            prompt += "### Similar Examples from Training Data:"
            similar_examples = list(similar_examples)
            for i, (text, labels) in enumerate(similar_examples[:k]):
                positive_labels = [label for label, value in zip(DISEASE_COLUMNS, labels) if value == 1]
                prompt += f"""
            Example {i+1}:
            Text: "{text}"
            Present pathologies: {', '.join(positive_labels) if positive_labels else 'No Finding'}
            """

        prompt += f"""
        ### Your Task
        Text to classify: "{query_text}"
        
        Return JSON with 0/1 for each pathology:
        """
        return prompt
    
    
    def run(self, text_id, query, vectorstore=None, k=5):
        if self.log_path_set == False:
            path_str = f"{self.timestamp}_rag" if vectorstore else f"{self.timestamp}_no_rag"
            self.log_path = self.log_path / path_str
            self.log_path.mkdir(parents=True, exist_ok=True)
            self.log_path_set = True
        if vectorstore:
            similar_texts, similar_labels, scores = vectorstore.retrieve_similar_cases(query, k=k)
            user_prompt = self.build_prompt(query, zip(similar_texts, similar_labels), k)
        else:
            user_prompt = self.build_prompt(query)

        # Log prompt
        with open(self.log_path / f"{text_id}_prompt.log", 'w') as f:
            f.write(user_prompt)

        # Generate a prompt completion
        system_prompt = "You are a clinical NLP assistant specialized in radiology."
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]
        # params = self._cfg["params"] or {}
        completion = self.client.chat.completions.create(
            model=MODEL, # only mandatory if you serve multiple models in LM Studio!
            messages=messages, #**params
        )
        
        return completion.choices[0].message.content
        

llm_classifier = GenerativeLLMClassifier(model=MODEL)

## Classifier with RAG

In [47]:
import warnings
import pandas as pd

def classify_all(classifier, X_test, y_test, acc, vectorstore=None):
    """
    Appends new predictions for X_test/y_test rows with indices not in acc, to acc DataFrame.

    Args:
        classifier: LLM classifier instance (with .run method)
        X_test (pd.DataFrame): Test set with 'section_findings' column
        y_test (pd.DataFrame): True labels DataFrame for test set
        acc (pd.DataFrame): Accumulator DataFrame of prior predictions (index = text_id)
        vectorstore: Optional retrieval model

    Returns:
        pd.DataFrame: Updated accumulator DataFrame (with new predictions appended)
    """

    new_indices = []
    new_preds = []

    if acc is None:
        acc = pd.DataFrame(columns=DISEASE_COLUMNS)

    for text_id, row in X_test.iterrows():
        if text_id in acc.index:
            print(f"Skipping text_id {text_id}: already in accumulator.")
            continue
        print(f"Processing text_id {text_id} ...")
        
        text = row["section_findings"]
        label_row = y_test.loc[text_id]   # use .loc for index alignment
        active_labels = label_row[label_row == 1].index.tolist()
        print(f"Text: {text}")
        print(f"Active labels: {active_labels}")
        
        completion = classifier.run(text_id, text, vectorstore)
        logfile = classifier.log_path / f"{text_id}_completion.log"
        with open(logfile, "w") as f:
            f.write(completion)
        try:
            json_or_list = extract_json_or_list(completion)
            if json_or_list is None:
                print(f"Skipping text_id {text_id}: could not parse completion.")
                continue
            y_pred_row = cleanse_to_multihot(json_or_list, DISEASE_COLUMNS)
            new_indices.append(text_id)
            new_preds.append(y_pred_row)

        except Exception as e:
            raise RuntimeError(f"Error processing completion {text_id}: {e}")

    y_pred_new = pd.DataFrame(new_preds, columns=DISEASE_COLUMNS, index=new_indices)
    acc_updated = pd.concat([acc, y_pred_new])
    return acc_updated



In [48]:
def create_or_load_prediction(rag_flag:bool=False):
    cache_path = Y_PRED_LLM_W_RAG_CACHED if rag_flag else Y_PRED_LLM_CACHED
    if Path(cache_path).exists():
        return pd.read_csv(cache_path, index_col=0)
    else:
        return pd.DataFrame(columns=DISEASE_COLUMNS)

In [49]:
# Restrict computation to the first n items to save time
n = 15

y_pred1 = create_or_load_prediction(rag_flag=True)
y_pred1 = classify_all(llm_classifier, X_test.head(n), y_test, y_pred1, vectorstore)


y_pred1

Processing text_id 0 ...
Text: Unchanged position of the left upper extremity PICC line. Again seen 
are surgical clips projecting over the right hemithorax. The 
cardiomediastinal silhouette is stable in appearance. Increased 
stranding opacities are noted in the left retrocardiac region. Subtle 
stranding opacities in the right upper lung zone are unchanged.. 
There are no pleural or significant bony abnormalities. Absence of 
the right breast shadow compatible with prior mastectomy.
Active labels: ['Lung Opacity', 'Support Devices']
Processing text_id 1 ...
Text: 12 mm focal density in the region of the anterior right fifth rib. 
Lungs are otherwise clear and the cardiac silhouette is not enlarged.
Active labels: []
Processing text_id 2 ...
Text: A single AP upright view of the chest taken on 22 JUNE 2006 
demonstrates sternal wires and mediastinal clips in place.  Cardiac 
silhouette is enlarged but not significantly changed from the prior 
study.  The aorta is ectatic and tortuous

KeyboardInterrupt: 

In [None]:
# Ensure that y_test is indexed the same as y_pred
y_test_subset = y_test.loc[y_pred1.index]

# Calculate scores
scores = compute_scores(y_test_subset, y_pred1, average='micro')

scores

In [None]:
y_pred1.to_csv(Y_PRED_LLM_CACHED, index=False)