In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
from pathlib import Path
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import TypeAlias, Union
import re
import sys


# Add the base_agent directory to sys.path
base_agent_dir = "/vol/biomedic3/bglocker/ugproj2324/nns20/cxr-agent/base_agent"
sys.path.insert(0, str(base_agent_dir))

from pathology_sets import Pathologies
from pathology_detector import CheXagentVisionTransformerPathologyDetector
from phrase_grounder import BioVilTPhraseGrounder
from generation_engine import GenerationEngine, Llama3Generation, CheXagentLanguageModelGeneration, GeminiFlashGeneration


pathology_dict_type: TypeAlias = dict[set]

In [2]:
DEVICE = None #"cuda:1"  
PATHOLOGY_DETECTION_THRESHOLD = 0.4
PHRASE_GROUNDING_THRESHOLD = 0.2

USER_PROMPT = "Just list the findings on the chest x-ray, nothing else. If there are no findings, just say that."
IGNORE_PATHOLOGIES = {} #{"Support Devices"}
DO_NOT_LOCALISE = {"Cardiomegaly"}

### Generate and save outputs of agents

In [3]:
#VinDr paths
vinDr_test_ground_truth_path = Path("/vol/biomedic3/bglocker/ugproj2324/nns20/datasets/VinDr-CXR/test_set_three_splits/VinDr_test_test_split_with_one_hot_labels.csv")

vindr_pathologies = ["Aortic enlargement", "Atelectasis", "Calcification", "Cardiomegaly",
            "Clavicle fracture", "Consolidation", "Emphysema", "Enlarged PA",
            "ILD", "Infiltration", "Lung Opacity", "Lung cavity", "Lung cyst",
            "Mediastinal shift","Nodule/Mass", "Pleural effusion", "Pleural thickening",
            "Pneumothorax", "Pulmonary fibrosis","Rib fracture", "Other lesion",
            "No finding"] 

# CheXpert paths
cheXpert_test_ground_truth_path = Path("/vol/biomedic3/bglocker/ugproj2324/nns20/datasets/CheXpert/test.csv")
cheXpert_test_path = Path("/vol/biodata/data/chest_xray/CheXpert-v1.0-small/CheXpert-v1.0-small/test")

# CheXpert pathologies
cheXpert_pathologies = ['No Finding','Enlarged Cardiomediastinum','Cardiomegaly','Lung Opacity',
        'Lung Lesion','Edema','Consolidation','Pneumonia','Atelectasis','Pneumothorax',
        'Pleural Effusion','Pleural Other','Fracture','Support Devices']

In [4]:
pathology_detector = CheXagentVisionTransformerPathologyDetector(pathologies=Pathologies.CHEXPERT, device=DEVICE)
phrase_grounder = BioVilTPhraseGrounder(detection_threshold=PHRASE_GROUNDING_THRESHOLD, device = DEVICE)
l3 = Llama3Generation(device = DEVICE)
cheXagent_lm = CheXagentLanguageModelGeneration(pathology_detector.processor, pathology_detector.model, pathology_detector.generation_config, pathology_detector.device, pathology_detector.dtype)
gemini = GeminiFlashGeneration()

GPU 0: NVIDIA GeForce RTX 4090, Free memory: 24177 MB
GPU 1: NVIDIA GeForce RTX 4090, Free memory: 24203 MB
Selecting GPU 1 with 24203 MB free memory, Device = cuda:1


Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]

CheXagent Model loaded


The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'BertTokenizer'. 
The class this function is called from is 'CXRBertTokenizer'.
You are using a model of type bert to instantiate a model of type cxr-bert. This is not supported for all configurations of models and can yield errors.
  return self.fget.__get__(instance, owner)()
Some weights of the model checkpoint at microsoft/BiomedVLP-BioViL-T were not used when initializing CXRBertModel: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
- This IS expected if you are initializing CXRBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing CXRBertModel from the checkpoint of a model that you expect to 

Using downloaded and verified file: /tmp/biovil_t_image_model_proj_size_128.pt
GPU 0: NVIDIA GeForce RTX 4090, Free memory: 23788 MB
GPU 1: NVIDIA GeForce RTX 4090, Free memory: 6991 MB
Selecting GPU 0 with 23788 MB free memory, Device = cuda:0
GPU 0: NVIDIA GeForce RTX 4090, Free memory: 23216 MB
GPU 1: NVIDIA GeForce RTX 4090, Free memory: 6991 MB
Selecting GPU 0 with 23216 MB free memory, Device = cuda:0


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [17]:
model_outputs = defaultdict(dict)
with open(cheXpert_test_ground_truth_path) as chexpert_test_file:
    for index, line in enumerate(chexpert_test_file):
        if index == 0:
            continue
        image_id = line.strip().split(",")[0]
        image_path = cheXpert_test_path / image_id
        pathology_confidences, localised_pathologies, chexagent_e2e = GenerationEngine.detect_and_localise_pathologies(
            image_path=image_path,
            pathology_detector=pathology_detector,
            phrase_grounder=phrase_grounder,
            pathology_detection_threshold = PATHOLOGY_DETECTION_THRESHOLD,
            ignore_pathologies=IGNORE_PATHOLOGIES,
            do_not_localise=DO_NOT_LOCALISE,
            prompt_for_chexagent_lm_output= USER_PROMPT,
        )
        
        model_outputs[image_id]['chexagent'] = chexagent_e2e

        gemini_system_prompt, gemini_image_context_prompt = gemini.generate_prompts(pathology_confidences, localised_pathologies)
        model_outputs[image_id]['gemini_agent'] = gemini.generate_model_output(gemini_system_prompt, gemini_image_context_prompt, user_prompt=USER_PROMPT)

        def run_chexagent_lm(pathology_confidences, localised_pathologies):
            system_prompt, image_context_prompt = GenerationEngine.generate_prompts(pathology_confidences, localised_pathologies)
            return cheXagent_lm.generate_model_output(system_prompt, image_context_prompt, user_prompt=USER_PROMPT)

        def run_llama3_agent(pathology_confidences, localised_pathologies):
            system_prompt, image_context_prompt = l3.generate_prompts(pathology_confidences, localised_pathologies)
            return l3.generate_model_output(system_prompt, image_context_prompt, user_prompt=USER_PROMPT)
        
        def run_gemini_agent(pathology_confidences, localised_pathologies):
            system_prompt, image_context_prompt = gemini.generate_prompts(pathology_confidences, localised_pathologies)
            return gemini.generate_model_output(system_prompt, image_context_prompt, user_prompt=USER_PROMPT)

        with ThreadPoolExecutor() as executor:       
            # Run cheXagent_lm and l3 in parallel
            future_chexagent_lm = executor.submit(run_chexagent_lm, pathology_confidences, localised_pathologies)
            future_llama3_agent = executor.submit(run_llama3_agent, pathology_confidences, localised_pathologies)

            for future in as_completed([future_chexagent_lm, future_llama3_agent]):
                if future == future_chexagent_lm:
                    model_outputs[image_id]['chexagent_agent'] = future.result()
                else:
                    model_outputs[image_id]['llama3_agent'] = future.result()

        print(model_outputs)
        break
        

  [torch.tensor(pixel_values) for pixel_values in encoding_image_processor["pixel_values"]]
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


defaultdict(<class 'dict'>, {'patient64741/study1/view1_frontal.jpg': {'chexagent': 'The endotracheal tube is in a standard position. The Swan-Ganz catheter terminates in the right pulmonary artery. The left internal jugular central venous catheter terminates in the mid SVC. The nasogastric tube terminates in the stomach. The cardiomediastinal silhouette is stable. There is no pneumothorax. There is no pneumothorax. There is no pneumothorax. There is no pneumothorax. There is no pneumothorax. There is no pneumothorax. There is no pneumothorax. There is no pneumothorax. There is no pneumothorax. There is no pneumothorax. There is no pneumothorax. There is no pneumothorax. There is no pneumothorax. There is no pneumothorax. There is no pneumothorax. There is no pneumothorax. There is no pneumothorax. There is no pneumothorax. There is no pneumothorax. There is no pneumothorax. There is no pneumothorax. There is no pneumothorax. There is no pneumothorax. There is no pneumothorax. There is

In [19]:
cheXpert_output_folder = Path("/vol/biomedic3/bglocker/ugproj2324/nns20/cxr-agent/base_agent/evaluation/cheXpert")
print(model_outputs.keys())

def dump_outputs_to_files(model_outputs: dict, output_folder: Path):
    for image_id, model_dict in model_outputs.items():
        print(image_id, model_dict)
        print(type(image_id), type(model_dict))
        
        for model_name, model_output in model_dict.items():
            # replace any new lines in the model output with full stops
            model_output = model_output.replace("\n", ".")
            with open(output_folder / f"{model_name}.txt", "a") as output_file:
                output_file.write(f"{image_id},{model_output}")
    
dump_outputs_to_files(model_outputs, cheXpert_output_folder)

dict_keys(['patient64741/study1/view1_frontal.jpg'])
patient64741/study1/view1_frontal.jpg {'chexagent': 'The endotracheal tube is in a standard position. The Swan-Ganz catheter terminates in the right pulmonary artery. The left internal jugular central venous catheter terminates in the mid SVC. The nasogastric tube terminates in the stomach. The cardiomediastinal silhouette is stable. There is no pneumothorax. There is no pneumothorax. There is no pneumothorax. There is no pneumothorax. There is no pneumothorax. There is no pneumothorax. There is no pneumothorax. There is no pneumothorax. There is no pneumothorax. There is no pneumothorax. There is no pneumothorax. There is no pneumothorax. There is no pneumothorax. There is no pneumothorax. There is no pneumothorax. There is no pneumothorax. There is no pneumothorax. There is no pneumothorax. There is no pneumothorax. There is no pneumothorax. There is no pneumothorax. There is no pneumothorax. There is no pneumothorax. There is no p

In [2]:
def find_pathologies_in_sentences(text, pathologies):
    sentences = re.split(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s', text)
    detected_pathologies = set()

    negation_keywords = r"\bno\b|\bnot\b|\bwithout\b|\babsent\b"

    for sentence in sentences:
        lower_sentence = sentence.lower()
        pathologies_in_sentence = set()

        for pathology in pathologies:
            if pathology.lower() in lower_sentence:
                pathologies_in_sentence.add(pathology)

        # Check for negation in the sentence
        negation_match = re.search(negation_keywords, lower_sentence)

        # Extract pathologies in 'no xxx or yyy' construction
        if negation_match:
            ns = negation_match.end()
            to_end_sentence = lower_sentence[ns:]
            # Look for immediate pathologies after 'no' possibly linked by 'or'/'and'
            excluded_pathologies = [p for p in pathologies_in_sentence if p.lower() in to_end_sentence]

            # Remove the pathologies that are excluded by negation
            pathologies_in_sentence.difference_update(excluded_pathologies)

        detected_pathologies.update(pathologies_in_sentence)

    return detected_pathologies    

def build_image_id_to_pathology_dict(file_path, pathologies, header = False,pathologies_check = True, sentences = False ) -> pathology_dict_type:
    image_ids_to_pathologies = defaultdict(set)
    
    pathologies = [pathology.lower() for pathology in pathologies]
    with open(file_path, 'r') as f:
        if header:
            header = f.readline() # skip header
        for line in f.readlines():

            if not sentences:
                line = line.strip().split(',')
                for pathology in line[1:]:
                    if pathologies_check and pathology.lower() in pathologies:
                        image_ids_to_pathologies[line[0]].add(pathology.lower().strip())
                    else:
                        image_ids_to_pathologies[line[0]].add(pathology.lower().strip())
            
            else:
                image_id = line.split(',')[0]
                text = (',').join(line.split(',')[1:])
                detected_pathologies = find_pathologies_in_sentences(text, pathologies)
                if len(detected_pathologies) == 0:
                    detected_pathologies.add('no finding')
                image_ids_to_pathologies[image_id] = detected_pathologies
                
    return image_ids_to_pathologies

In [6]:
test = "No findings..."
find_pathologies_in_sentences(test, cheXpert_pathologies)

{'No Finding'}

In [8]:
def normalize_labels(labels, synonym_mapping):
    normalized_labels = set()
    for label in labels:
        found = False
        for synonyms in synonym_mapping:
            if label in synonyms:
                normalized_labels.add(frozenset(synonyms))  # Adding the frozenset of synonyms
                found = True
                break
        if not found:
            normalized_labels.add(label)
    return normalized_labels


def contains_no_finding(label_set, synonym_mappings):
    for synonyms in synonym_mappings:
        if "no finding" in synonyms:
            if frozenset(synonyms) in label_set:
                return True
    return False


def convert_probabilities_to_labels(file_path: Path,pathologies, threshold = True, threshold_val = 0.5, top_k = False, k = 1) -> pathology_dict_type:
    image_ids_to_pathologies = defaultdict(set)
    pathologies = [pathology.lower() for pathology in pathologies]

    with open(file_path, 'r') as f:
        # header contains mapping of pathology index to pathology name
        header = f.readline()
        pathology_names = header.strip().split(',')[1:]
        for line in f.readlines():
            line = line.strip().split(',')
            image_id = line[0]
            probabilities = np.array([float(prob) for prob in line[1:]])
            for i, prob in enumerate(probabilities):
                pathology = pathology_names[i].lower()
                if pathology not in pathologies:
                    continue

                if threshold:
                    if prob >= threshold_val:
                        image_ids_to_pathologies[image_id].add(pathology)
                elif top_k:
                    if i in np.argsort(probabilities)[-k:]:
                        if probabilities[i] >= threshold_val: # threshold_val should be 0.5 when threshold flag is False
                            image_ids_to_pathologies[image_id].add(pathology)
                else:
                    return ValueError("Invalid arguments for convert_probabilities_to_labels")
            
            if len(image_ids_to_pathologies[image_id]) == 0:
                image_ids_to_pathologies[image_id].add("no finding")

    return image_ids_to_pathologies


def calculate_accuracy_metrics(gt_path: Path, preds: Union[Path, pathology_dict_type], gt_pathologies: set, pred_pathologies = None , synonym_mappings = [], threshold = True, threshold_val = 0.5, k = 1):
    """
    Pre-reqs: assume header of probability files will have the pathology names (i.e. image_id, pathology1, pathology2, ...)
    
    Args:
    gt_path: Path to the ground truth file with probabilities
    preds: Path to the predictions file with probabilities or a dictionary of image_id to pathologies
    gt_pathologies: Set of pathologies to consider in the ground truth, all others will be ignored
    pred_pathologies: Set of pathologies to consider in the predictions, all others will be ignored
    synonym_mappings: List of lists of synonyms for each pathology
    threshold: Flag to determine whether to apply a threshold to the probabilities
    threshold_val: Threshold value to apply to the probabilities
    k: Number of top k pathologies to consider (only if threshold is False)
    """
    
    ground_truth = convert_probabilities_to_labels(gt_path,pathologies = gt_pathologies)
    if pred_pathologies is None:
        pred_pathologies = gt_pathologies
    
    if isinstance(preds, Path):
        predictions = convert_probabilities_to_labels(preds, pathologies = pred_pathologies, threshold = threshold, threshold_val = threshold_val, top_k = True, k = k)
    else:
        predictions = preds # Since CheXagent does not give probability bounds on outputs
        
    # Normalize each set of pathologies in the ground truth and predictions using the provided synonym mappings
    normalized_ground_truth = {image_id: normalize_labels(gt_labels, synonym_mappings) for image_id, gt_labels in ground_truth.items()}
    normalized_predictions = {image_id: normalize_labels(pred_labels, synonym_mappings) for image_id, pred_labels in predictions.items()}

    exact_matches = 0
    exact_no_finding = 0
    exact_one_pathology = 0
    exact_multiple_pathologies = 0

    total_no_finding = 0
    total_one_pathology = 0
    total_multiple_pathologies = 0

    correct_matches = 0
    correct_no_finding = 0
    correct_pathology = 0

    correct_top_k_pathology_match = 0

    for image_id, ground_truth_labels in normalized_ground_truth.items():
        pred_labels = normalized_predictions.get(image_id, set())

        is_no_finding = "no finding" in ground_truth_labels or contains_no_finding(ground_truth_labels, synonym_mappings)
        if is_no_finding:
            total_no_finding += 1
        elif len(ground_truth_labels) == 1:
            total_one_pathology += 1
        else:  # Assuming any non-empty set of labels greater than 1 is 'multiple pathologies'
            total_multiple_pathologies += 1
        
        # Calculate exact matches (Metric 1)
        if ground_truth_labels == pred_labels:
            exact_matches += 1
            # Update metrics for matches (Metric 2)
            if "no finding" in ground_truth_labels or contains_no_finding(ground_truth_labels, synonym_mappings):
                exact_no_finding += 1
            elif len(ground_truth_labels) == 1:
                exact_one_pathology += 1
            else:
                exact_multiple_pathologies += 1
        
        # Calculate individual correct matches (Metric 3)
        matched_pathologies = ground_truth_labels.intersection(pred_labels)
        correct_matches += len(matched_pathologies)
        if len(matched_pathologies) > 0:
            correct_top_k_pathology_match += 1
        
        # Update metrics for matches (Metric 4)
        for pathology in matched_pathologies:
            if "no finding" in ground_truth_labels or contains_no_finding({pathology}, synonym_mappings):
                correct_no_finding += 1
            else:
                correct_pathology += 1
    
    # Calculate metrics
    n = len(ground_truth)
    exact_matches_percentage = exact_matches / n
    exact_no_finding_percentage = exact_no_finding / total_no_finding if total_no_finding > 0 else 0
    exact_one_pathology_percentage = exact_one_pathology / total_one_pathology if total_one_pathology > 0 else 0
    exact_multiple_pathologies_percentage = exact_multiple_pathologies / total_multiple_pathologies if total_multiple_pathologies > 0 else 0
    
    print(f"\n\nDataset Characteristics:")
    print(f"No finding proportion: {total_no_finding/n:.2f}")
    print(f"One pathology proportion: {total_one_pathology/n:.2f}")
    print(f"Multiple pathologies proportion: {total_multiple_pathologies/n:.2f}")
    
    print(f"\nExact Matches Characteristics:")
    print(f"Exact matches: {exact_matches_percentage:.2f}")
    print(f"Exact no finding: {exact_no_finding_percentage:.2f}")
    print(f"Exact one pathology: {exact_one_pathology_percentage:.2f}")
    print(f"Exact multiple pathologies: {exact_multiple_pathologies_percentage:.2f}")
    
    total_num_of_pathologies = sum([len(pathologies) for pathologies in ground_truth.values()])
    correct_matches_percentage = correct_matches / total_num_of_pathologies
    correct_no_finding_percentage = correct_no_finding / total_no_finding if total_no_finding > 0 else 0
    correct_pathology_percentage = correct_pathology / (total_num_of_pathologies - total_no_finding) if total_num_of_pathologies - total_no_finding > 0 else 0
    
    print(f"\nIndividual Pathology Characteristics:")
    print(f"Correct matches: {correct_matches_percentage:.2f}")
    print(f"Correct no finding: {correct_no_finding_percentage:.2f}")
    print(f"Correct pathology: {correct_pathology_percentage:.2f}")

    print(f"\nTop K Pathology Match Characteristics:")
    print(f"Top K Pathology Match: {correct_top_k_pathology_match/n:.2f}")

    return exact_matches_percentage, exact_no_finding_percentage, exact_one_pathology_percentage, exact_multiple_pathologies_percentage, correct_matches_percentage, correct_no_finding_percentage, correct_pathology_percentage