In [2]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
from pathlib import Path
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import TypeAlias, Union
import re
import sys


# Add the base_agent directory to sys.path
base_agent_dir = "/vol/biomedic3/bglocker/ugproj2324/nns20/cxr-agent/base_agent"
sys.path.insert(0, str(base_agent_dir))

from pathology_sets import Pathologies
from pathology_detector import CheXagentVisionTransformerPathologyDetector
from phrase_grounder import BioVilTPhraseGrounder
from generation_engine import GenerationEngine, Llama3Generation, CheXagentLanguageModelGeneration, GeminiFlashGeneration


pathology_dict_type: TypeAlias = dict[set]

In [2]:
DEVICE = None #"cuda:1"  
PATHOLOGY_DETECTION_THRESHOLD = 0.4
PHRASE_GROUNDING_THRESHOLD = 0.2

USER_PROMPT = "Just list the findings on the chest x-ray, nothing else. If there are no findings, just say that."
IGNORE_PATHOLOGIES = {} #{"Support Devices"}
DO_NOT_LOCALISE = {"Cardiomegaly"}

In [3]:
def find_pathologies_in_sentences(text, pathologies):
    sentences = re.split(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s', text)
    detected_pathologies = set()

    negation_keywords = r"\bno\b|\bnot\b|\bwithout\b|\babsent\b"

    for sentence in sentences:
        lower_sentence = sentence.lower()
        pathologies_in_sentence = set()
 
        for pathology in pathologies:
            if pathology.lower() in lower_sentence:
                pathologies_in_sentence.add(pathology)

        # Check for negation in the sentence
        negation_match = re.search(negation_keywords, lower_sentence)

        # Extract pathologies in 'no xxx or yyy' construction
        if negation_match:
            ns = negation_match.end()
            to_end_sentence = lower_sentence[ns:]
            # Look for immediate pathologies after 'no' possibly linked by 'or'/'and'
            excluded_pathologies = [p for p in pathologies_in_sentence if p.lower() in to_end_sentence]

            # Remove the pathologies that are excluded by negation
            pathologies_in_sentence.difference_update(excluded_pathologies)

        detected_pathologies.update(pathologies_in_sentence)

    return detected_pathologies    

def build_image_id_to_pathology_dict(file_path, pathologies, header = False,pathologies_check = True, sentences = False ) -> pathology_dict_type:
    image_ids_to_pathologies = defaultdict(set)
    
    pathologies = [pathology.lower() for pathology in pathologies]
    with open(file_path, 'r') as f:
        if header:
            header = f.readline() # skip header
        for line in f.readlines():

            if not sentences:
                line = line.strip().split(',')
                for pathology in line[1:]:
                    if pathologies_check and pathology.lower() in pathologies:
                        image_ids_to_pathologies[line[0]].add(pathology.lower().strip())
                    else:
                        image_ids_to_pathologies[line[0]].add(pathology.lower().strip())
            
            else:
                image_id = line.split(',')[0]
                text = (',').join(line.split(',')[1:])
                detected_pathologies = find_pathologies_in_sentences(text, pathologies)
                if len(detected_pathologies) == 0:
                    detected_pathologies.add('no finding')
                image_ids_to_pathologies[image_id] = detected_pathologies
                
    return image_ids_to_pathologies

In [4]:
def normalize_labels(labels, synonym_mapping):
    normalized_labels = set()
    for label in labels:
        found = False
        for synonyms in synonym_mapping:
            if label in synonyms:
                normalized_labels.add(frozenset(synonyms))  # Adding the frozenset of synonyms
                found = True
                break
        if not found:
            normalized_labels.add(label)
    return normalized_labels


def contains_no_finding(label_set, synonym_mappings):
    for synonyms in synonym_mappings:
        if "no finding" in synonyms:
            if frozenset(synonyms) in label_set:
                return True
    return False


def convert_probabilities_to_labels(file_path: Path,pathologies, threshold = True, threshold_val = 0.5, top_k = False, k = 1) -> pathology_dict_type:
    image_ids_to_pathologies = defaultdict(set)
    pathologies = [pathology.lower() for pathology in pathologies]

    with open(file_path, 'r') as f:
        # header contains mapping of pathology index to pathology name
        header = f.readline()
        pathology_names = header.strip().split(',')[1:]
        for line in f.readlines():
            line = line.strip().split(',')
            image_id = line[0]
            probabilities = np.array([float(prob) for prob in line[1:]])
            for i, prob in enumerate(probabilities):
                pathology = pathology_names[i].lower()
                if pathology not in pathologies:
                    continue

                if threshold:
                    if prob >= threshold_val:
                        image_ids_to_pathologies[image_id].add(pathology)
                elif top_k:
                    if i in np.argsort(probabilities)[-k:]:
                        if probabilities[i] >= threshold_val: # threshold_val should be 0.5 when threshold flag is False
                            image_ids_to_pathologies[image_id].add(pathology)
                else:
                    return ValueError("Invalid arguments for convert_probabilities_to_labels")
            
            if len(image_ids_to_pathologies[image_id]) == 0:
                image_ids_to_pathologies[image_id].add("no finding")

    return image_ids_to_pathologies


def calculate_accuracy_metrics(gt_path: Path, preds: Union[Path, pathology_dict_type], gt_pathologies: set, pred_pathologies = None , synonym_mappings = [], threshold = True, threshold_val = 0.5, k = 1, ignore_patholgies = {}):
    """
    Pre-reqs: assume header of probability files will have the pathology names (i.e. image_id, pathology1, pathology2, ...)
    
    Args:
    gt_path: Path to the ground truth file with probabilities
    preds: Path to the predictions file with probabilities or a dictionary of image_id to pathologies
    gt_pathologies: Set of pathologies to consider in the ground truth, all others will be ignored
    pred_pathologies: Set of pathologies to consider in the predictions, all others will be ignored
    synonym_mappings: List of lists of synonyms for each pathology
    threshold: Flag to determine whether to apply a threshold to the probabilities
    threshold_val: Threshold value to apply to the probabilities
    k: Number of top k pathologies to consider (only if threshold is False)
    """
    
    ground_truth = convert_probabilities_to_labels(gt_path,pathologies = gt_pathologies)
    if pred_pathologies is None:
        pred_pathologies = gt_pathologies
    
    if isinstance(preds, Path):
        predictions = convert_probabilities_to_labels(preds, pathologies = pred_pathologies, threshold = threshold, threshold_val = threshold_val, top_k = True, k = k)
    else:
        predictions = preds # Since CheXagent does not give probability bounds on outputs
        
    # Normalize each set of pathologies in the ground truth and predictions using the provided synonym mappings
    normalized_ground_truth = {image_id: normalize_labels(gt_labels, synonym_mappings) for image_id, gt_labels in ground_truth.items()}
    normalized_predictions = {image_id: normalize_labels(pred_labels, synonym_mappings) for image_id, pred_labels in predictions.items()}
    
    # remove ignore pathologies in normalised ground truth and predictions
    for image_id, gt_labels in normalized_ground_truth.items():
        normalized_ground_truth[image_id] = gt_labels.difference(ignore_patholgies)

    for image_id, pred_labels in normalized_predictions.items():
        normalized_predictions[image_id] = pred_labels.difference(ignore_patholgies)

    exact_matches = 0
    exact_no_finding = 0
    exact_one_pathology = 0
    exact_multiple_pathologies = 0

    total_no_finding = 0
    total_one_pathology = 0
    total_multiple_pathologies = 0

    correct_matches = 0
    correct_no_finding = 0
    correct_pathology = 0

    correct_top_k_pathology_match = 0

    for image_id, ground_truth_labels in normalized_ground_truth.items():
        pred_labels = normalized_predictions.get(image_id, set())

        is_no_finding = "no finding" in ground_truth_labels or contains_no_finding(ground_truth_labels, synonym_mappings)
        if is_no_finding:
            total_no_finding += 1
        elif len(ground_truth_labels) == 1:
            total_one_pathology += 1
        else:  # Assuming any non-empty set of labels greater than 1 is 'multiple pathologies'
            total_multiple_pathologies += 1
        
        # Calculate exact matches (Metric 1)
        if ground_truth_labels == pred_labels:
            exact_matches += 1
            # Update metrics for matches (Metric 2)
            if "no finding" in ground_truth_labels or contains_no_finding(ground_truth_labels, synonym_mappings):
                exact_no_finding += 1
            elif len(ground_truth_labels) == 1:
                exact_one_pathology += 1
            else:
                exact_multiple_pathologies += 1
        
        # Calculate individual correct matches (Metric 3)
        matched_pathologies = ground_truth_labels.intersection(pred_labels)
        correct_matches += len(matched_pathologies)
        if len(matched_pathologies) > 0:
            correct_top_k_pathology_match += 1
        
        # Update metrics for matches (Metric 4)
        for pathology in matched_pathologies:
            if "no finding" in ground_truth_labels or contains_no_finding({pathology}, synonym_mappings):
                correct_no_finding += 1
            else:
                correct_pathology += 1
    
    # Calculate metrics
    n = len(ground_truth)
    exact_matches_percentage = exact_matches / n
    exact_no_finding_percentage = exact_no_finding / total_no_finding if total_no_finding > 0 else 0
    exact_one_pathology_percentage = exact_one_pathology / total_one_pathology if total_one_pathology > 0 else 0
    exact_multiple_pathologies_percentage = exact_multiple_pathologies / total_multiple_pathologies if total_multiple_pathologies > 0 else 0
    
    print(f"\n\nDataset Characteristics:")
    print(f"No finding proportion: {total_no_finding/n:.2f}")
    print(f"One pathology proportion: {total_one_pathology/n:.2f}")
    print(f"Multiple pathologies proportion: {total_multiple_pathologies/n:.2f}")
    
    print(f"\nExact Matches Characteristics:")
    print(f"Exact matches: {exact_matches_percentage:.2f}")
    print(f"Exact no finding: {exact_no_finding_percentage:.2f}")
    print(f"Exact one pathology: {exact_one_pathology_percentage:.2f}")
    print(f"Exact multiple pathologies: {exact_multiple_pathologies_percentage:.2f}")
    
    total_num_of_pathologies = sum([len(pathologies) for pathologies in ground_truth.values()])
    correct_matches_percentage = correct_matches / total_num_of_pathologies
    correct_no_finding_percentage = correct_no_finding / total_no_finding if total_no_finding > 0 else 0
    correct_pathology_percentage = correct_pathology / (total_num_of_pathologies - total_no_finding) if total_num_of_pathologies - total_no_finding > 0 else 0
    
    print(f"\nIndividual Pathology Characteristics:")
    print(f"Correct matches: {correct_matches_percentage:.2f}")
    print(f"Correct no finding: {correct_no_finding_percentage:.2f}")
    print(f"Correct pathology: {correct_pathology_percentage:.2f}")

    print(f"\nTop K Pathology Match Characteristics:")
    print(f"Top K Pathology Match: {correct_top_k_pathology_match/n:.2f}")

    return exact_matches_percentage, exact_no_finding_percentage, exact_one_pathology_percentage, exact_multiple_pathologies_percentage, correct_matches_percentage, correct_no_finding_percentage, correct_pathology_percentage

In [5]:
#VinDr paths
vinDr_test_ground_truth_path = Path("/vol/biomedic3/bglocker/ugproj2324/nns20/datasets/VinDr-CXR/test_set_three_splits/VinDr_test_test_split_with_one_hot_labels.csv")

vindr_pathologies = ["Aortic enlargement", "Atelectasis", "Calcification", "Cardiomegaly",
            "Clavicle fracture", "Consolidation", "Emphysema", "Enlarged PA",
            "ILD", "Infiltration", "Lung Opacity", "Lung cavity", "Lung cyst",
            "Mediastinal shift","Nodule/Mass", "Pleural effusion", "Pleural thickening",
            "Pneumothorax", "Pulmonary fibrosis","Rib fracture", "Other lesion",
            "No finding"] 

# CheXpert paths
cheXpert_test_ground_truth_path = Path("/vol/biomedic3/bglocker/ugproj2324/nns20/datasets/CheXpert/test.csv")
cheXpert_test_path = Path("/vol/biodata/data/chest_xray/CheXpert-v1.0-small/CheXpert-v1.0-small/test")

# CheXpert pathologies
cheXpert_pathologies = ['No Finding','Enlarged Cardiomediastinum','Cardiomegaly','Lung Opacity',
        'Lung Lesion','Edema','Consolidation','Pneumonia','Atelectasis','Pneumothorax',
        'Pleural Effusion','Pleural Other','Fracture','Support Devices']

### CheXpert Evals

In [20]:
cheXpert_folder = Path("/vol/biomedic3/bglocker/ugproj2324/nns20/cxr-agent/base_agent/evaluation/cheXpert")

# ignore_patholgies = {"support devices"}
ignore_patholgies = {}
# for each file in the cheXpert folder, calculate the metrics

for index, file in enumerate(cheXpert_folder.iterdir()):
  
    print(f"\n{str(file).split('/')[-1]}")
    image_id_to_pathology_dict = build_image_id_to_pathology_dict(file, cheXpert_pathologies, header = False, pathologies_check = True, sentences = True)
    # print(image_id_to_pathology_dict)
    calculate_accuracy_metrics(cheXpert_test_ground_truth_path, image_id_to_pathology_dict, cheXpert_pathologies, synonym_mappings = [], threshold = False, k = 1, ignore_patholgies = ignore_patholgies)



chexagent.txt


Dataset Characteristics:
No finding proportion: 0.25
One pathology proportion: 0.18
Multiple pathologies proportion: 0.57

Exact Matches Characteristics:
Exact matches: 0.25
Exact no finding: 1.00
Exact one pathology: 0.00
Exact multiple pathologies: 0.00

Individual Pathology Characteristics:
Correct matches: 0.18
Correct no finding: 1.00
Correct pathology: 0.09

Top K Pathology Match Characteristics:
Top K Pathology Match: 0.39

chexagent_agent.txt


Dataset Characteristics:
No finding proportion: 0.25
One pathology proportion: 0.18
Multiple pathologies proportion: 0.57

Exact Matches Characteristics:
Exact matches: 0.25
Exact no finding: 1.00
Exact one pathology: 0.00
Exact multiple pathologies: 0.00

Individual Pathology Characteristics:
Correct matches: 0.10
Correct no finding: 1.00
Correct pathology: 0.00

Top K Pathology Match Characteristics:
Top K Pathology Match: 0.25

llama3_agent.txt


Dataset Characteristics:
No finding proportion: 0.25
One pathology propo

### VinDr Evals

In [7]:
vinDr_folder = Path("/vol/biomedic3/bglocker/ugproj2324/nns20/cxr-agent/base_agent/evaluation/VinDr")

# ignore_patholgies = {"support devices"}
ignore_patholgies = {}
# for each file in the cheXpert folder, calculate the metrics

for index, file in enumerate(vinDr_folder.iterdir()):
  
    print(f"\n{str(file).split('/')[-1]}")
    image_id_to_pathology_dict = build_image_id_to_pathology_dict(file, vindr_pathologies, header = False, pathologies_check = True, sentences = True)
    print(len(image_id_to_pathology_dict))
    # print(image_id_to_pathology_dict)
    calculate_accuracy_metrics(vinDr_test_ground_truth_path, image_id_to_pathology_dict, vindr_pathologies, synonym_mappings = [], threshold = False, k = 1, ignore_patholgies = ignore_patholgies)



chexagent.txt
449


Dataset Characteristics:
No finding proportion: 0.68
One pathology proportion: 0.12
Multiple pathologies proportion: 0.20

Exact Matches Characteristics:
Exact matches: 0.68
Exact no finding: 0.99
Exact one pathology: 0.02
Exact multiple pathologies: 0.01

Individual Pathology Characteristics:
Correct matches: 0.51
Correct no finding: 0.99
Correct pathology: 0.03

Top K Pathology Match Characteristics:
Top K Pathology Match: 0.69

llama3_agent.txt
448


Dataset Characteristics:
No finding proportion: 0.68
One pathology proportion: 0.12
Multiple pathologies proportion: 0.20

Exact Matches Characteristics:
Exact matches: 0.61
Exact no finding: 0.86
Exact one pathology: 0.21
Exact multiple pathologies: 0.00

Individual Pathology Characteristics:
Correct matches: 0.54
Correct no finding: 0.86
Correct pathology: 0.21

Top K Pathology Match Characteristics:
Top K Pathology Match: 0.71

gemini_agent.txt
448


Dataset Characteristics:
No finding proportion: 0.68
One pathol

In [8]:
vinDr_double_check_file = Path("/vol/biomedic3/bglocker/ugproj2324/nns20/cxr-agent/cheXagent/evaluation/VinDr/chexagent/temperature_0.5_2/vindr_chexagent_agent_findings_0")
image_id_to_pathology_dict = build_image_id_to_pathology_dict(vinDr_double_check_file, vindr_pathologies, header = False, pathologies_check = True, sentences = True)
calculate_accuracy_metrics(vinDr_test_ground_truth_path, image_id_to_pathology_dict, vindr_pathologies, synonym_mappings = [], threshold = False, k = 1, ignore_patholgies = ignore_patholgies)




Dataset Characteristics:
No finding proportion: 0.68
One pathology proportion: 0.12
Multiple pathologies proportion: 0.20

Exact Matches Characteristics:
Exact matches: 0.68
Exact no finding: 1.00
Exact one pathology: 0.02
Exact multiple pathologies: 0.01

Individual Pathology Characteristics:
Correct matches: 0.51
Correct no finding: 1.00
Correct pathology: 0.03

Top K Pathology Match Characteristics:
Top K Pathology Match: 0.70


(0.6822222222222222,
 0.9967320261437909,
 0.017857142857142856,
 0.011363636363636364,
 0.5147058823529411,
 0.9967320261437909,
 0.032679738562091505)

In [36]:
# # TODO Figure out why difference
collected_image_ids = set()
with open("/vol/biomedic3/bglocker/ugproj2324/nns20/cxr-agent/base_agent/evaluation/VinDr/sample.txt") as f:
    for line in f.readlines():
        image_id = line.strip()
        collected_image_ids.add(image_id)

all_image_ids = set() 
with open("/vol/biomedic3/bglocker/ugproj2324/nns20/datasets/VinDr-CXR/test_set_three_splits/VinDr_test_test_split.txt") as f:
    for line in f.readlines():
        image_id = line.strip()
        all_image_ids.add(image_id)

missing_image_ids = all_image_ids.difference(collected_image_ids)
print(f"Missing image ids: {missing_image_ids}")

other_way_around = collected_image_ids.difference(all_image_ids)
print(f"Other way around: {other_way_around}")

Missing image ids: {'71e9666b94e09e7485ab6350036faea6', 'a384e99cde85590d5d4273d65b921282'}
Other way around: set()


### Image Text Reasoning

In [22]:
image_id_to_ground_truth_location = {}
ground_truth_itr_path = Path("/vol/biomedic3/bglocker/ugproj2324/nns20/datasets/CheXbench/image_text_reasoning_task")
with open(ground_truth_itr_path, 'r') as f:
    for line in f.readlines():
        image_id = line.split(',')[2]
        question = line.split(',')[3]
        answer = line.split(',')[4]
        option_0 = line.split(',')[5]
        option_1 = line.split(',')[6]

        if answer == 0:
            image_id_to_ground_truth_location[image_id] = option_0
        else:
            image_id_to_ground_truth_location[image_id] = option_1


cheXbench_collected_results = Path("/vol/biomedic3/bglocker/ugproj2324/nns20/cxr-agent/base_agent/evaluation/cheXbench")
specific_pathologies = ["opacity","atelectasis","effusion"]

for index, file in enumerate(cheXbench_collected_results.iterdir()):
    if file.suffix != ".txt":
        continue

    total = 0
    correct_location = 0
    correct_pathology = 0

    for line in file.open('r').readlines():
        image_id = line.split(',')[0]
        predicted_location = (',').join(line.split(',')[1:])

        for pathology in specific_pathologies:
            if pathology in image_id_to_ground_truth_location[image_id].lower() and pathology in predicted_location.lower():
                correct_pathology += 1

        if "left" in image_id_to_ground_truth_location[image_id].lower() and "left" in predicted_location.lower():
            correct_location += 1
        
        if "right" in image_id_to_ground_truth_location[image_id].lower() and "right" in predicted_location.lower():
            correct_location += 1
        
        total += 1
        
    print(f"\n{file.name}")
    print(f"Pathology Accuracy: {correct_pathology/total:.2f}")
    print(f"Location Accuracy: {correct_location/total:.2f}")


chexagent.txt
Pathology Accuracy: 0.99
Location Accuracy: 0.48

gemini_agent.txt
Pathology Accuracy: 0.78
Location Accuracy: 0.28

chexagent_agent.txt
Pathology Accuracy: 0.99
Location Accuracy: 0.53

llama3_agent.txt
Pathology Accuracy: 0.81
Location Accuracy: 0.17
