# Sample Notebook for Zero-Shot Inference with CheXzero
This notebook walks through how to use CheXzero to perform zero-shot inference on a chest x-ray image dataset.

## Import Libraries

In [266]:
import os
import numpy as np
import pandas as pd
from pathlib import Path
from typing import List, Tuple, Optional

import sys
sys.path.append('../')

from eval import evaluate, bootstrap
from zero_shot import make, make_true_labels, run_softmax_eval

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Directories and Constants

In [None]:
## Define Zero Shot Labels and Templates

# ----- DIRECTORIES ------ #
cxr_filepath: str = '../test_data/cxr.h5' # filepath of chest x-ray images (.h5)
cxr_true_labels_path: Optional[str] = '../data/groundtruth.csv' # (optional for evaluation) if labels are provided, provide path
model_dir: str = '../checkpoints_train/pt-imp' # where pretrained models are saved (.pt) 
predictions_dir: Path = Path('../predictions-val') # where to save predictions
cache_dir: str = predictions_dir / "cached_val" # where to cache ensembled predictions

context_length: int = 77

# ------- LABELS ------  #
# Define labels to query each image | will return a prediction for each label
cxr_labels: List[str] = ['Atelectasis','Cardiomegaly', 
                                      'Consolidation', 'Edema', 'Enlarged Cardiomediastinum', 'Fracture', 'Lung Lesion',
                                      'Lung Opacity', 'No Finding','Pleural Effusion', 'Pleural Other', 'Pneumonia', 
                                      'Pneumothorax', 'Support Devices']

# ---- TEMPLATES ----- # 
# Define set of templates | see Figure 1 for more details                        
cxr_pair_template: Tuple[str] = ("{}", "no {}")

# ----- MODEL PATHS ------ #
# If using ensemble, collect all model paths
model_paths = ['../checkpoints_train/pt-imp/checkpoint_18000.pt']
# for subdir, dirs, files in os.walk(model_dir):
#     for file in files:
#         full_dir = os.path.join(subdir, file)
#         model_paths.append(full_dir)
        
print(model_paths)

['../checkpoints10gb50e/pt-imp/checkpoint_18000.pt']


## Run Inference

In [268]:
## Run the model on the data set using ensembled models
def ensemble_models(
    model_paths: List[str], 
    cxr_filepath: str, 
    cxr_labels: List[str], 
    cxr_pair_template: Tuple[str], 
    cache_dir: str = None, 
    save_name: str = None,
) -> Tuple[List[np.ndarray], np.ndarray]: 
    """
    Given a list of `model_paths`, ensemble model and return
    predictions. Caches predictions at `cache_dir` if location provided.

    Returns a list of each model's predictions and the averaged
    set of predictions.
    """

    predictions = []
    model_paths = sorted(model_paths) # ensure consistency of 
    print(model_paths)
    for path in model_paths: # for each model
        model_name = Path(path).stem

        # load in model and `torch.DataLoader`
        model, loader = make(
            model_path=path, 
            cxr_filepath=cxr_filepath, 
        ) 
        
        # path to the cached prediction
        if cache_dir is not None:
            if save_name is not None: 
                cache_path = Path(cache_dir) / f"{save_name}_{model_name}.npy"
            else: 
                cache_path = Path(cache_dir) / f"{model_name}.npy"

        # if prediction already cached, don't recompute prediction
        if cache_dir is not None and os.path.exists(cache_path): 
            print("Loading cached prediction for {}".format(model_name))
            y_pred = np.load(cache_path)
        else: # cached prediction not found, compute preds
            print("Inferring model {}".format(path))
            y_pred = run_softmax_eval(model, loader, cxr_labels, cxr_pair_template)
            if cache_dir is not None: 
                Path(cache_dir).mkdir(exist_ok=True, parents=True)
                np.save(file=cache_path, arr=y_pred)
        predictions.append(y_pred)
    
    # compute average predictions
    print(predictions)
    y_pred_avg = np.mean(predictions, axis=0)
    print(y_pred_avg)
    return predictions, y_pred_avg

In [269]:
predictions, y_pred_avg = ensemble_models(
    model_paths=model_paths, 
    cxr_filepath=cxr_filepath, 
    cxr_labels=cxr_labels, 
    cxr_pair_template=cxr_pair_template, 
    cache_dir=cache_dir,
)

['../checkpoints10gb50e/pt-imp/checkpoint_18000.pt']
Loading cached prediction for checkpoint_18000
[array([[0.49611872, 0.49636057, 0.5076487 , ..., 0.51039606, 0.4975381 ,
        0.51275533],
       [0.5018535 , 0.50363845, 0.4970435 , ..., 0.4977175 , 0.49737707,
        0.49912113],
       [0.5029214 , 0.49823767, 0.5070606 , ..., 0.507116  , 0.50782096,
        0.5144976 ],
       ...,
       [0.49379626, 0.49402764, 0.4913916 , ..., 0.4954754 , 0.49808016,
        0.48922092],
       [0.5146677 , 0.5107169 , 0.5075194 , ..., 0.5018823 , 0.51208967,
        0.49710515],
       [0.4972319 , 0.5026027 , 0.49866775, ..., 0.4938344 , 0.5048814 ,
        0.49989307]], dtype=float32)]
[[0.49611872 0.49636057 0.5076487  ... 0.51039606 0.4975381  0.51275533]
 [0.5018535  0.50363845 0.4970435  ... 0.4977175  0.49737707 0.49912113]
 [0.5029214  0.49823767 0.5070606  ... 0.507116   0.50782096 0.5144976 ]
 ...
 [0.49379626 0.49402764 0.4913916  ... 0.4954754  0.49808016 0.48922092]
 [0.51466

In [270]:
# save averaged preds
pred_name = "chexpert_preds.npy" # add name of preds
# pred_name="chx.txt"
predictions_dir = predictions_dir / pred_name
np.save(file=predictions_dir, arr=y_pred_avg)
# np.savetxt(predictions_dir, y_pred_avg, fmt="%.6f")

In [271]:
# predictions_dir = Path("../predictions")
# predictions_dir.mkdir(parents=True, exist_ok=True)  # Ensure directory exists

# pred_file = predictions_dir / "chexpert_preds.txt"  # Save as a .txt file instead of .npy
# np.savetxt(pred_file, y_pred_avg, fmt="%.6f")


In [None]:

import pandas as pd

# Load the mimic-impressions.csv file
impressions_df = pd.read_csv("../test_data/mimic_impressions.csv")

# Extract numeric study_id from filename and convert to float
impressions_df["study_id"] = impressions_df["filename"].str.extract(r"s(\d+)")[0].astype(float)

# Load the second file
data_df = pd.read_csv("../MIMIC_CXR_report_phenotypes.csv", encoding='utf-8', encoding_errors='replace')

# Merge data_df with impressions_df to preserve duplicate occurrences of study_id
filtered_df = impressions_df.merge(data_df, on="study_id", how="inner")
filtered_df = filtered_df.fillna(0)

# Display the first 5 rows
print(filtered_df.iloc[0:5])

# Save the filtered data
filtered_df.to_csv("filtered_data.csv", index=False)


        filename                                         impression  \
0  s52100637.txt  1. Unchanged appearance of mild pulmonary edem...   
1  s52100637.txt  1. Unchanged appearance of mild pulmonary edem...   
2  s52100637.txt  1. Unchanged appearance of mild pulmonary edem...   
3  s52974196.txt           Moderate pulmonary edema.checkpoint_9000   
4  s52974196.txt                          Moderate pulmonary edema.   

     study_id  subject_id                                             report  \
0  52100637.0  10249381.0  FINAL REPORT\n EXAMINATION:  Chest pain\n \n I...   
1  52100637.0  10249381.0  FINAL REPORT\n EXAMINATION:  Chest pain\n \n I...   
2  52100637.0  10249381.0  FINAL REPORT\n EXAMINATION:  Chest pain\n \n I...   
3  52974196.0  10245890.0  FINAL REPORT\n INDICATION:  ___-year-old with ...   
4  52974196.0  10245890.0  FINAL REPORT\n INDICATION:  ___-year-old with ...   

   Atelectasis  Cardiomegaly  Consolidation  Edema  \
0          0.0           1.0          

In [273]:
cxr_true_labels_path: Optional[str] = 'filtered_data.csv'

## (Optional) Evaluate Results
If ground truth labels are available, compute AUC on each pathology to evaluate the performance of the zero-shot model. 

In [274]:
import numpy as np
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, hamming_loss

def evaluate1(y_pred, y_true, cxr_labels):
    """
    Evaluates a multi-label classification model.

    Args:
        y_pred (numpy.ndarray): Predicted probabilities (shape: [num_samples, num_classes]).
        y_true (numpy.ndarray): True labels (binary) (shape: [num_samples, num_classes]).
        cxr_labels (list): List of class labels.

    Returns:
        dict: Dictionary containing accuracy, precision, recall, F1-score, and Hamming Loss.
    """
    num_classes = y_true.shape[1]
    results = {}

    # Convert predictions to binary (threshold 0.5)
    y_pred_binary = (y_pred >= 0.5).astype(int)
    num_all_zero_rows = np.sum(np.all(y_pred_binary == 0, axis=1))

    print("Number of rows where all values are 0:", num_all_zero_rows)
    # Ensure y_true is binary (no unexpected values)
    y_true = (y_true > 0).astype(int)

    # Compute accuracy per class
    class_accuracies = [accuracy_score(y_true[:, i], y_pred_binary[:, i]) for i in range(num_classes)]
    overall_accuracy = accuracy_score(y_true.flatten(), y_pred_binary.flatten())  # Micro accuracy

    # Compute Precision, Recall, and F1-score
    precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred_binary, average=None)

    # Compute Hamming Loss (Lower is better)
    hamming = hamming_loss(y_true, y_pred_binary)

    # Store results
    results["Accuracy per class"] = dict(zip(cxr_labels, class_accuracies))
    results["Overall Accuracy"] = overall_accuracy
    results["Precision"] = dict(zip(cxr_labels, precision))
    results["Recall"] = dict(zip(cxr_labels, recall))
    results["F1-score"] = dict(zip(cxr_labels, f1))
    results["Hamming Loss"] = hamming

    return results


In [275]:
# Generate ground truth labels
test_true = make_true_labels(cxr_true_labels_path=cxr_true_labels_path, cxr_labels=cxr_labels)

# Ensure predictions are in the correct format
test_pred = y_pred_avg

# Print shapes and sample values for debugging
print("Shape of test_true:", test_true.shape)
print("Shape of test_pred:", test_pred.shape)
print("Sample Prediction:", test_pred[0])
print("Sample True Labels:", test_true[0])

# Ensure test_true and test_pred are NumPy arrays
test_true = np.array(test_true)
test_pred = np.array(test_pred)

# Evaluate model
cxr_results = evaluate1(test_pred, test_true, cxr_labels)

# Bootstrap evaluations for 95% confidence intervals
# bootstrap_results = bootstrap(test_pred, test_true, cxr_labels)

# Print evaluation results
print("Evaluation Results:", cxr_results)
# print("Bootstrap Results:", bootstrap_results)


Shape of test_true: (328, 14)
Shape of test_pred: (328, 14)
Sample Prediction: [0.49611872 0.49636057 0.5076487  0.49583548 0.4958477  0.5075441
 0.4888963  0.48862243 0.4973661  0.502445   0.49782807 0.51039606
 0.4975381  0.51275533]
Sample True Labels: [0. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
Number of rows where all values are 0: 6
Evaluation Results: {'Accuracy per class': {'Atelectasis': 0.5152439024390244, 'Cardiomegaly': 0.5548780487804879, 'Consolidation': 0.4847560975609756, 'Edema': 0.6859756097560976, 'Enlarged Cardiomediastinum': 0.5670731707317073, 'Fracture': 0.6310975609756098, 'Lung Lesion': 0.573170731707317, 'Lung Opacity': 0.5914634146341463, 'No Finding': 0.2530487804878049, 'Pleural Effusion': 0.5274390243902439, 'Pleural Other': 0.5060975609756098, 'Pneumonia': 0.375, 'Pneumothorax': 0.4573170731707317, 'Support Devices': 0.46646341463414637}, 'Overall Accuracy': 0.5135017421602788, 'Precision': {'Atelectasis': 0.23076923076923078, 'Cardiomegaly': 0.2313432835

In [276]:
# display AUC with confidence intervals
bootstrap_results[1]

NameError: name 'bootstrap_results' is not defined