# Day 1: Pathology Identification with off-the-shelf LLMs

- Task: Identify pathologies from radiology reports

## Details

- Input: Raw radiology report sections (findings sections)
- Output: Predicted pathology labels (the columns in `DISEASE_COLUMNS`; multi-label classification)

In [None]:
import os
from pathlib import Path

# URL:PORT must be identical to what is set in LM Studio!
HOST_URL = "http://localhost:1235/v1"

# model name as served by LM Studio
MODEL = 'unsloth/medgemma-4b-it-gguf/medgemma-4b-it-q4_k_s.gguf'

# path to logged results (plain LLM, no retrieval)
Y_PRED_LLM_CACHED = Path('log') / 'y_pred' / 'y_pred_llm_no_rag.csv'

# path to logged results of the RAG-augmented run; create_or_load_prediction()
# reads this name when rag_flag=True, so it has to exist to avoid a NameError.
# NOTE(review): the file name is chosen to mirror the no-RAG one - adjust it if
# another notebook already writes the RAG cache somewhere else.
Y_PRED_LLM_W_RAG_CACHED = Path('log') / 'y_pred' / 'y_pred_llm_w_rag.csv'

# create parent folder if not existing (both caches share the same folder)
Y_PRED_LLM_CACHED.parent.mkdir(parents=True, exist_ok=True)

## Load Data Splits

In [None]:
import json
import pandas as pd

## Load splits
def load_test_splits():
    """Read the test inputs and test labels from the ``data`` folder.

    Prints the shape of both tables as a quick sanity check.

    Returns:
        tuple: ``(X_test, y_test)`` as read from ``data/X_test.csv`` and
        ``data/y_test.csv``.
    """
    data_dir = Path("data")
    features, labels = (
        pd.read_csv(data_dir / file_name)
        for file_name in ("X_test.csv", "y_test.csv")
    )
    print(f"X_test dim:\t{features.shape}\ty_test dim:\t{labels.shape}")
    return features, labels

X_test, y_test = load_test_splits()

In [None]:
# Preview the first rows of the test inputs
X_test.head()

In [None]:
# Show the 'section_findings' text of the first test row
X_test['section_findings'].values[0]

In [None]:
# Preview the first rows of the test labels
y_test.head()

In [None]:
# Label set used by the prompt, the parsing helpers and the evaluation.
# NOTE(review): assumes every column of y_test is a pathology label (no id
# column in y_test.csv) - confirm against the data file.
DISEASE_COLUMNS = y_test.columns

## Evaluation Function
For a multi-class, multi-label problem (where true-negative (TN) counts are not informative), suitable metrics are:
- precision (fraction of predicted positives that are correct: TP/(TP + FP))
- recall (fraction of actual positives that are recovered: TP/(TP + FN))
- F1 (harmonic mean of precision and recall)

There are three distinct strategies on how to combine per class performance:
1. micro - global pooling of TP, FP, FN over all classes (global picture, biased towards majority classes)
2. macro - unweighted mean of the per-class scores (every class counts equally, so sensitive to minority classes)
3. weighted - mean of the per-class scores weighted by class support (biased towards large classes, moderate impact of minority classes)

In [None]:
from sklearn.metrics import f1_score, precision_score, recall_score

def compute_scores(y_true:pd.DataFrame, y_pred:pd.DataFrame, average:str='micro'):
    """Compute F1, precision and recall for multi-label predictions.

    Args:
        y_true: Ground-truth table containing all DISEASE_COLUMNS.
        y_pred: Prediction table containing all DISEASE_COLUMNS; missing
            values are counted as 0 (label absent).
        average: Averaging strategy handed to the sklearn metrics:
            'micro' (global pooling of TP, FP, FN), 'macro' (per class
            scores are averaged) or 'weighted' (per class scores are
            weighted and averaged).

    Returns:
        pd.DataFrame: One row with the columns ``<average>-F1``,
        ``<average>-Precision`` and ``<average>-Recall``.
    """
    # Same column order on both sides and plain integer dtype
    truth = y_true[DISEASE_COLUMNS].astype(int)
    predicted = y_pred[DISEASE_COLUMNS].astype("Int64").fillna(0).astype(int)

    metric_functions = (
        ("F1", f1_score),
        ("Precision", precision_score),
        ("Recall", recall_score),
    )
    return pd.DataFrame({
        f"{average}-{metric_name}": [metric(truth, predicted, average=average)]
        for metric_name, metric in metric_functions
    })


In [None]:
import re
import json
import warnings


def _balanced_end(text: str, start: int) -> int:
    """Return the index just past the bracket closing ``text[start]``, or -1.

    Tracks nesting of ``{}`` / ``[]`` and ignores brackets inside
    double-quoted strings. Used only to skip over a block that looked like
    JSON but could not be decoded.
    """
    closer_for = {"{": "}", "[": "]"}
    expected = []
    in_string = False
    escaped = False
    for i in range(start, len(text)):
        ch = text[i]
        if in_string:
            if escaped:
                escaped = False
            elif ch == "\\":
                escaped = True
            elif ch == '"':
                in_string = False
            continue
        if ch == '"':
            in_string = True
        elif ch in closer_for:
            expected.append(closer_for[ch])
        elif ch in "}]":
            if not expected or expected.pop() != ch:
                return -1
            if not expected:
                return i + 1
    return -1


def extract_json_or_list(text_with_json: str):
    """Extract the last JSON object or array embedded in free text.

    The text is scanned for ``{`` / ``[`` and each candidate is decoded with
    the real JSON parser, so nested structures such as ``{"a": {"b": 1}}``
    are handled (a non-greedy regex would cut them at the first closing
    bracket). Blocks that look like JSON but do not decode are skipped as a
    whole, so their inner fragments are not mistaken for an answer.

    Args:
        text_with_json: Raw model output that may contain JSON.

    Returns:
        The decoded dict or list of the last decodable top-level block, or
        ``None`` (after a warning) if there is none.
    """
    if not isinstance(text_with_json, str):
        warnings.warn(f"Could not extract JSON/list block: {text_with_json}")
        return None

    decoder = json.JSONDecoder()
    parsed = None
    found = False
    last_failure = None  # (error, start index) of the last undecodable block

    pos = 0
    length = len(text_with_json)
    while pos < length:
        if text_with_json[pos] not in "{[":
            pos += 1
            continue
        try:
            candidate, end = decoder.raw_decode(text_with_json, pos)
        except json.JSONDecodeError as e:
            last_failure = (e, pos)
            # Jump past the whole invalid block when its end can be found,
            # otherwise just move on by one character.
            block_end = _balanced_end(text_with_json, pos)
            pos = block_end if block_end > pos else pos + 1
            continue
        parsed, found = candidate, True
        pos = end

    if found:
        return parsed
    if last_failure is None:
        warnings.warn(f"Could not extract JSON/list block: {text_with_json}")
        return None
    error, start = last_failure
    warnings.warn(
        f"Could not decode JSON/list: {error}\nRaw block: {text_with_json[start:]}"
    )
    return None

def cleanse_to_multihot(json_or_list, all_labels=DISEASE_COLUMNS):
    """Convert a parsed model answer into a ``{label: 0/1}`` dict.

    Two answer formats are accepted:
      * dict: label -> 0/1, True/False or a string such as "1"/"true"/"yes"
      * list: the labels that are present

    Exact label matches are used first; if there is none, labels are also
    matched ignoring case and surrounding whitespace, because generated
    answers frequently vary in both.

    Args:
        json_or_list: Decoded JSON answer (dict or list).
        all_labels: Labels to produce an entry for.

    Returns:
        dict with one 0/1 entry per label in ``all_labels``, or ``None``
        (after a warning) if the answer is neither a dict nor a list.
    """
    def _norm(name):
        # Comparison key that ignores case and surrounding whitespace
        return str(name).strip().casefold()

    def _to_bit(value):
        if isinstance(value, (int, float, bool)):
            return 1 if value else 0
        if isinstance(value, str):
            return 1 if value.strip().lower() in {'1', 'true', 'yes'} else 0
        return 0

    # Case 1: 0/1 or True/False dict
    if isinstance(json_or_list, dict):
        by_normalized_key = {_norm(key): value for key, value in json_or_list.items()}
        filtered_pred = {}
        for label in all_labels:
            if label in json_or_list:
                value = json_or_list[label]
            else:
                value = by_normalized_key.get(_norm(label), 0)
            filtered_pred[label] = _to_bit(value)
        return filtered_pred

    # Case 2: list of strings = present labels only
    if isinstance(json_or_list, list):
        present = {_norm(item) for item in json_or_list if isinstance(item, str)}
        return {
            label: 1 if (label in json_or_list or _norm(label) in present) else 0
            for label in all_labels
        }

    # Unrecognized
    warnings.warn("Unknown prediction format. Returning None.")
    return None


In [None]:
from openai import OpenAI
import os
from datetime import datetime
from pathlib import Path
from typing import List


class GenerativeLLMClassifier():
    """Prompt-based multi-label classifier talking to an OpenAI-compatible server.

    Prompts are written to ``log/<model file name>/<timestamp>_(rag|no_rag)/``;
    the sub-folder is created on the first call of :meth:`run`.
    """

    def __init__(self, model:str='llm'):
        """Set up log location and API client.

        Args:
            model: Model identifier; only its last path component is used,
                as the name of the log folder.
                NOTE(review): the request in run() uses the global MODEL, not
                this argument - confirm whether that is intended.
        """
        self.timestamp = datetime.now().strftime("%Y%m%d_%H%M")
        model_name = model.split("/")[-1]
        self.log_path = Path('log') / model_name

        # The run-specific sub-folder is appended lazily in run()
        self.log_path_set = False
        self.client = OpenAI(
            base_url=HOST_URL,
            api_key='dummy'
        )

    def build_prompt(self, query_text, similar_examples=None, k:int=5):
        """Build the user prompt, optionally augmented with retrieved examples.

        Args:
            query_text: Text to classify.
            similar_examples: Optional iterable of ``(text, labels)`` pairs,
                where ``labels`` is aligned with DISEASE_COLUMNS.
            k: Maximum number of examples to include.

        Returns:
            str: The prompt text.
        """
        # Materialize first: iterators such as zip() are always truthy, so
        # checking them directly would add the examples header even when no
        # example was retrieved.
        examples = [] if similar_examples is None else list(similar_examples)

        prompt = f"""You are a radiology AI assistant. Classify the following medical text for pathologies.
        ### Task
        Determine which of these pathologies are present: [{', '.join(DISEASE_COLUMNS)}]
        """

        if examples:
            # Add retrieved examples
            prompt += "### Similar Examples from Training Data:"
            for i, (text, labels) in enumerate(examples[:k]):
                positive_labels = [label for label, value in zip(DISEASE_COLUMNS, labels) if value == 1]
                prompt += f"""
            Example {i+1}:
            Text: "{text}"
            Present pathologies: {', '.join(positive_labels) if positive_labels else 'No Finding'}
            """

        prompt += f"""
        ### Your Task
        Text to classify: "{query_text}"
        
        Return JSON with 0/1 for each pathology:
        """
        return prompt

    def run(self, text_id, query, vectorstore=None, k=5):
        """Classify one text and return the raw model answer.

        Args:
            text_id: Identifier used in the name of the prompt log file.
            query: Text to classify.
            vectorstore: Optional object offering
                ``retrieve_similar_cases(query, k=k)``; if given, the
                retrieved cases are added to the prompt.
            k: Number of similar cases to retrieve / include.

        Returns:
            str: Message content of the first completion choice (an empty
            string if the server returned no content).
        """
        if not self.log_path_set:
            path_str = f"{self.timestamp}_rag" if vectorstore else f"{self.timestamp}_no_rag"
            self.log_path = self.log_path / path_str
            self.log_path.mkdir(parents=True, exist_ok=True)
            self.log_path_set = True

        if vectorstore:
            similar_texts, similar_labels, _scores = vectorstore.retrieve_similar_cases(query, k=k)
            user_prompt = self.build_prompt(query, zip(similar_texts, similar_labels), k)
        else:
            user_prompt = self.build_prompt(query)

        # Log prompt; explicit UTF-8 so report text with non-ASCII characters
        # does not depend on the platform default encoding
        with open(self.log_path / f"{text_id}_prompt.log", 'w', encoding="utf-8") as f:
            f.write(user_prompt)

        # Generate a prompt completion
        system_prompt = "You are a clinical NLP assistant specialized in radiology."
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]
        completion = self.client.chat.completions.create(
            model=MODEL, # only mandatory if you serve multiple models in LM Studio!
            messages=messages,
        )

        content = completion.choices[0].message.content
        # The content field is optional in the API response
        return content if content is not None else ""
        

# Shared classifier instance; MODEL determines the name of its log folder
llm_classifier = GenerativeLLMClassifier(model=MODEL)

## Baseline

In [None]:
import warnings
import pandas as pd

def classify_all(classifier, X_test, y_test, acc, vectorstore=None):
    """
    Appends new predictions for X_test/y_test rows with indices not in acc, to acc DataFrame.

    Rows whose completion cannot be parsed are skipped (they are not added to
    the result and will be retried on the next call).

    Args:
        classifier: LLM classifier instance (with .run method and .log_path)
        X_test (pd.DataFrame): Test set with 'section_findings' column
        y_test (pd.DataFrame): True labels DataFrame for test set
        acc (pd.DataFrame): Accumulator DataFrame of prior predictions (index = text_id);
            None is treated as an empty accumulator
        vectorstore: Optional retrieval model

    Returns:
        pd.DataFrame: Updated accumulator DataFrame (with new predictions appended)

    Raises:
        RuntimeError: If parsing/cleansing a completion fails unexpectedly
            (the original exception is chained as the cause).
    """
    if acc is None:
        acc = pd.DataFrame(columns=DISEASE_COLUMNS)

    new_indices = []
    new_preds = []

    for text_id, row in X_test.iterrows():
        if text_id in acc.index:
            print(f"Skipping text_id {text_id}: already in accumulator.")
            continue
        print(f"Processing text_id {text_id} ...")

        text = row["section_findings"]
        label_row = y_test.loc[text_id]   # use .loc for index alignment
        active_labels = label_row[label_row == 1].index.tolist()
        print(f"Text: {text}")
        print(f"Active labels: {active_labels}")

        completion = classifier.run(text_id, text, vectorstore)
        if completion is None:
            # A response without content would otherwise crash the log write
            completion = ""
        logfile = classifier.log_path / f"{text_id}_completion.log"
        with open(logfile, "w", encoding="utf-8") as f:
            f.write(completion)

        try:
            json_or_list = extract_json_or_list(completion)
            y_pred_row = (
                None if json_or_list is None
                else cleanse_to_multihot(json_or_list, DISEASE_COLUMNS)
            )
        except Exception as e:
            raise RuntimeError(f"Error processing completion {text_id}: {e}") from e

        if y_pred_row is None:
            print(f"Skipping text_id {text_id}: could not parse completion.")
            continue
        new_indices.append(text_id)
        new_preds.append(y_pred_row)

    if not new_preds:
        # Nothing to add; hand back a copy so the caller's frame is not aliased
        return acc.copy()

    y_pred_new = pd.DataFrame(new_preds, columns=DISEASE_COLUMNS, index=new_indices)
    if len(acc) == 0:
        # Concatenating with an empty frame would turn the 0/1 columns into
        # object dtype and triggers a pandas FutureWarning
        return y_pred_new
    return pd.concat([acc, y_pred_new])



In [None]:
def create_or_load_prediction(rag_flag:bool=False):
    """Return cached predictions if a cache file exists, else an empty table.

    Args:
        rag_flag: Select the RAG cache (Y_PRED_LLM_W_RAG_CACHED) instead of
            the plain LLM cache (Y_PRED_LLM_CACHED). The RAG constant is only
            looked up when this flag is True, so it must be defined by then.

    Returns:
        pd.DataFrame: Cached predictions with the first CSV column as index,
        or an empty frame with the DISEASE_COLUMNS columns.
    """
    if rag_flag:
        cache_path = Y_PRED_LLM_W_RAG_CACHED
    else:
        cache_path = Y_PRED_LLM_CACHED

    # Guard clause: no cache yet -> start with an empty prediction table
    if not Path(cache_path).exists():
        return pd.DataFrame(columns=DISEASE_COLUMNS)
    return pd.read_csv(cache_path, index_col=0)

In [None]:
# Restrict computation to the first n items to save time
# Restrict computation to the first n items to save time
n = 5

# Start from the cached predictions (or an empty table if there is no cache)
y_pred1 = create_or_load_prediction(rag_flag=False)
# Classify the first n test rows; rows already in y_pred1 are skipped
y_pred1 = classify_all(llm_classifier, X_test.head(n), y_test, y_pred1)


# Display the accumulated predictions
y_pred1

In [None]:
# Ensure that y_test is indexed the same as y_pred
# Select the ground-truth rows for exactly the ids that have a prediction,
# in the same order as y_pred1
y_test_subset = y_test.loc[y_pred1.index]

# Calculate scores
scores = compute_scores(y_test_subset, y_pred1, average='micro')

# Display the one-row score table
scores

In [None]:
# Keep the index (text_id) in the file: create_or_load_prediction() reads the
# cache with index_col=0, and classify_all() relies on that index to skip
# already processed rows. Writing with index=False would drop the ids and make
# the first label column the index on reload.
y_pred1.to_csv(Y_PRED_LLM_CACHED)