# Sample Notebook for Zero-Shot Inference with CheXzero
This notebook walks through how to use CheXzero to perform zero-shot inference on a chest x-ray image dataset.

## Import Libraries

In [33]:
import os
import numpy as np
import pandas as pd
from pathlib import Path
from typing import List, Tuple, Optional

import sys
sys.path.append('../')

from eval import evaluate, bootstrap
from zero_shot import make, make_true_labels, run_softmax_eval

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Directories and Constants

In [39]:
## Define Zero Shot Labels and Templates

# ----- DIRECTORIES ------ #
cxr_filepath: str = '../data/chexpert_test.h5' # filepath of chest x-ray images (.h5)
cxr_true_labels_path: Optional[str] = '../data/groundtruth.csv' # (optional for evaluation) if labels are provided, provide path
model_dir: str = '../checkpoints/sample_model' # where pretrained models are saved (.pt) 
predictions_dir: Path = Path('../predictions_chexpert') # where to save predictions
cache_dir: str = predictions_dir / "cached" # where to cache ensembled predictions

context_length: int = 77

# ------- LABELS ------  #
# Define labels to query each image | will return a prediction for each label
cxr_labels: List[str] = ['Atelectasis','Cardiomegaly', 
                                      'Consolidation', 'Edema', 'Enlarged Cardiomediastinum', 'Fracture', 'Lung Lesion',
                                      'Lung Opacity', 'No Finding','Pleural Effusion', 'Pleural Other', 'Pneumonia', 
                                      'Pneumothorax', 'Support Devices']

# ---- TEMPLATES ----- # 
# Define set of templates | see Figure 1 for more details                        
cxr_pair_template: Tuple[str] = ("{}", "no {}")

# ----- MODEL PATHS ------ #
# If using ensemble, collect all model paths
model_paths = []
for subdir, dirs, files in os.walk(model_dir):
    for file in files:
        full_dir = os.path.join(subdir, file)
        model_paths.append(full_dir)
        
print(model_paths)

['../checkpoints/sample_model/best_64_0.0001_original_16000_0.861.pt']


## Run Inference

In [40]:
## Run the model on the data set using ensembled models
def ensemble_models(
    model_paths: List[str], 
    cxr_filepath: str, 
    cxr_labels: List[str], 
    cxr_pair_template: Tuple[str], 
    cache_dir: str = None, 
    save_name: str = None,
) -> Tuple[List[np.ndarray], np.ndarray]: 
    """
    Given a list of `model_paths`, ensemble model and return
    predictions. Caches predictions at `cache_dir` if location provided.

    Returns a list of each model's predictions and the averaged
    set of predictions.
    """

    predictions = []
    model_paths = sorted(model_paths) # ensure consistency of 
    for path in model_paths: # for each model
        model_name = Path(path).stem

        # load in model and `torch.DataLoader`
        model, loader = make(
            model_path=path, 
            cxr_filepath=cxr_filepath, 
        ) 
        
        # path to the cached prediction
        if cache_dir is not None:
            if save_name is not None: 
                cache_path = Path(cache_dir) / f"{save_name}_{model_name}.npy"
            else: 
                cache_path = Path(cache_dir) / f"{model_name}.npy"

        # if prediction already cached, don't recompute prediction
        if cache_dir is not None and os.path.exists(cache_path): 
            print("Loading cached prediction for {}".format(model_name))
            y_pred = np.load(cache_path)
        else: # cached prediction not found, compute preds
            print("Inferring model {}".format(path))
            y_pred = run_softmax_eval(model, loader, cxr_labels, cxr_pair_template)
            if cache_dir is not None: 
                Path(cache_dir).mkdir(exist_ok=True, parents=True)
                np.save(file=cache_path, arr=y_pred)
        predictions.append(y_pred)
    
    # compute average predictions
    y_pred_avg = np.mean(predictions, axis=0)
    
    return predictions, y_pred_avg

In [41]:
predictions, y_pred_avg = ensemble_models(
    model_paths=model_paths, 
    cxr_filepath=cxr_filepath, 
    cxr_labels=cxr_labels, 
    cxr_pair_template=cxr_pair_template, 
    cache_dir=cache_dir,
)

Inferring model ../checkpoints/sample_model/best_64_0.0001_original_16000_0.861.pt


100%|██████████| 14/14 [00:02<00:00,  6.73it/s]
100%|██████████| 500/500 [01:53<00:00,  4.41it/s]
100%|██████████| 14/14 [00:02<00:00,  6.74it/s]
100%|██████████| 500/500 [01:53<00:00,  4.41it/s]


In [42]:
# save averaged preds
pred_name = "chexpert_preds.npy" # add name of preds
predictions_dir = predictions_dir / pred_name
np.save(file=predictions_dir, arr=y_pred_avg)

## (Optional) Evaluate Results
If ground truth labels are available, compute AUC on each pathology to evaluate the performance of the zero-shot model. 

In [45]:
# make test_true
test_pred = y_pred_avg
test_true = make_true_labels(cxr_true_labels_path=cxr_true_labels_path, cxr_labels=cxr_labels)

print(test_pred, test_true, cxr_labels)
print(test_true.shape)
print(test_pred.shape)
# evaluate model
cxr_results = evaluate(test_pred, test_true, cxr_labels)

# boostrap evaluations for 95% confidence intervals
bootstrap_results = bootstrap(test_pred, test_true, cxr_labels)

[[0.49887753 0.5016514  0.50170755 ... 0.5033554  0.49789727 0.50017124]
 [0.5018854  0.48913094 0.49798122 ... 0.5010932  0.49787554 0.4995102 ]
 [0.50100905 0.50335056 0.502935   ... 0.5044339  0.4982046  0.49964315]
 ...
 [0.5014328  0.48918235 0.4984985  ... 0.4997017  0.49687812 0.49928623]
 [0.5008477  0.50330585 0.5060644  ... 0.50820667 0.49923638 0.50100505]
 [0.5029276  0.49219155 0.5031646  ... 0.50474435 0.50265706 0.49729177]] [[0 1 0 ... 0 0 1]
 [1 0 0 ... 0 0 0]
 [1 1 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 1]
 [1 1 0 ... 0 0 1]
 [1 0 0 ... 0 0 1]] ['Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema', 'Enlarged Cardiomediastinum', 'Fracture', 'Lung Lesion', 'Lung Opacity', 'No Finding', 'Pleural Effusion', 'Pleural Other', 'Pneumonia', 'Pneumothorax', 'Support Devices']
(500, 14)
(500, 14)


  0%|          | 0/1000 [00:00<?, ?it/s]

In [44]:
# display AUC with confidence intervals
bootstrap_results[1]

Unnamed: 0,Atelectasis_auc,Cardiomegaly_auc,Consolidation_auc,Edema_auc,Enlarged Cardiomediastinum_auc,Fracture_auc,Lung Lesion_auc,Lung Opacity_auc,No Finding_auc,Pleural Effusion_auc,Pleural Other_auc,Pneumonia_auc,Pneumothorax_auc,Support Devices_auc
mean,0.7627,0.8982,0.8838,0.8904,0.8987,0.5718,0.6512,0.8989,0.415,0.9165,0.71,0.7111,0.6421,0.5434
lower,0.7204,0.8667,0.8149,0.8546,0.8721,0.1476,0.4347,0.8691,0.3456,0.8866,0.5291,0.4983,0.4642,0.491
upper,0.8039,0.9233,0.9396,0.9228,0.9233,0.9506,0.8579,0.9225,0.4881,0.9404,0.9719,0.8804,0.815,0.5906
