In [19]:
import os
import numpy as np
import pandas as pd
from pathlib import Path
from typing import List, Tuple, Optional

import sys
sys.path.append('../')

from eval import evaluate, bootstrap
from zero_shot import make, make_true_labels, run_softmax_eval

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [28]:
## Define Zero Shot Labels and Templates
# Notebook-wide configuration: input/output locations, the labels to query,
# and the positive/negative prompt templates.

# ----- DIRECTORIES ------ #
cxr_filepath: str = '../data_padchest/padchest.h5' # filepath of chest x-ray images (.h5)
cxr_true_labels_path: Optional[str] = '../data_padchest/groundtruth.csv' # (optional for evaluation) if labels are provided, provide path
model_dir: str = '../checkpoints/sample_model' # where pretrained models are saved (.pt)
predictions_dir: Path = Path('../predictions') # where to save predictions
# Caching is currently disabled. To enable it, uncomment the next line and
# pass `cache_dir=cache_dir` when calling `ensemble_models`.
# cache_dir: Path = predictions_dir / "cached" # where to cache ensembled predictions

# Not referenced by any other visible cell.
# NOTE(review): presumably the text-token context length expected by the
# model/tokenizer in `zero_shot` -- confirm before changing.
context_length: int = 77

# ------- LABELS ------  #
# Define labels to query each image | will return a prediction for each label
# NOTE(review): these 35 entries look like PadChest CSV column headers
# (identifiers, DICOM metadata, report fields) rather than radiological
# findings. Confirm this is the intended label set: every label below is
# turned into a text prompt via `cxr_pair_template`.
cxr_labels: List[str] = [
    'ImageID', 'ImageDir', 'StudyDate_DICOM', 'StudyID', 'PatientID',
    'PatientBirth', 'PatientSex_DICOM', 'ViewPosition_DICOM', 'Projection',
    'MethodProjection', 'Pediatric', 'Modality_DICOM', 'Manufacturer_DICOM',
    'PhotometricInterpretation_DICOM', 'PixelRepresentation_DICOM',
    'PixelAspectRatio_DICOM', 'SpatialResolution_DICOM', 'BitsStored_DICOM',
    'WindowCenter_DICOM', 'WindowWidth_DICOM', 'Rows_DICOM',
    'Columns_DICOM', 'XRayTubeCurrent_DICOM', 'Exposure_DICOM',
    'ExposureInuAs_DICOM', 'ExposureTime', 'RelativeXRayExposure_DICOM',
    'ReportID', 'Report', 'MethodLabel', 'Labels', 'Localizations',
    'LabelsLocalizationsBySentence', 'labelCUIS', 'LocalizationsCUIS',
]

# ---- TEMPLATES ----- #
# Define set of templates | see Figure 1 for more details
# A (positive, negative) pair of format strings; `{}` is the label placeholder.
# Annotated as a 2-tuple: `Tuple[str]` would denote a tuple of exactly one str.
cxr_pair_template: Tuple[str, str] = ("{}", "no {}")

# ----- MODEL PATHS ------ #
# Walk `model_dir` recursively and record the path of every file found;
# each entry is treated as one checkpoint of the ensemble.
model_paths = [
    os.path.join(root, filename)
    for root, _subdirs, filenames in os.walk(model_dir)
    for filename in filenames
]

print(model_paths)

['../checkpoints/sample_model/best_64_0.0001_original_16000_0.861.pt']


In [29]:
## Run the model on the data set using ensembled models
def ensemble_models(
    model_paths: List[str],
    cxr_filepath: str,
    cxr_labels: List[str],
    cxr_pair_template: Tuple[str, str],
    cache_dir: Optional[str] = None,
    save_name: Optional[str] = None,
) -> Tuple[List[np.ndarray], np.ndarray]:
    """
    Run every model in `model_paths` on the dataset and ensemble the results.

    Each checkpoint is evaluated with `run_softmax_eval` and the per-model
    predictions are averaged element-wise.

    Args:
        model_paths: Checkpoint paths. They are processed in sorted order so
            the order of the returned per-model list is reproducible.
        cxr_filepath: Path of the image file handed to `make`.
        cxr_labels: Labels to query for each image.
        cxr_pair_template: (positive, negative) prompt templates.
        cache_dir: Optional directory for cached predictions. When given,
            each model's prediction is loaded from
            `<cache_dir>/[<save_name>_]<model stem>.npy` if that file exists,
            and written there otherwise.
        save_name: Optional prefix for the cache file names. The cache key
            contains only this prefix and the checkpoint stem, so use a
            distinct `save_name` (or `cache_dir`) per dataset / label set to
            avoid picking up stale predictions.

    Returns:
        A tuple `(predictions, y_pred_avg)`: the list of each model's
        predictions (in sorted-path order) and their element-wise mean.

    Raises:
        ValueError: If `model_paths` is empty.
    """
    # Sort so the ensemble order does not depend on directory listing order.
    model_paths = sorted(model_paths)
    if not model_paths:
        raise ValueError(
            "`model_paths` is empty; at least one model checkpoint is required."
        )

    predictions = []
    for path in model_paths:  # evaluate every checkpoint, not only the first
        model_name = Path(path).stem

        # path to the cached prediction (None when caching is disabled)
        cache_path = None
        if cache_dir is not None:
            if save_name is not None:
                cache_path = Path(cache_dir) / f"{save_name}_{model_name}.npy"
            else:
                cache_path = Path(cache_dir) / f"{model_name}.npy"

        if cache_path is not None and cache_path.exists():
            # prediction already cached: skip loading the model entirely
            print("Loading cached prediction for {}".format(model_name))
            y_pred = np.load(cache_path)
        else:
            # cached prediction not found: load the model and run inference
            print("Inferring model {}".format(path))
            model, loader = make(
                model_path=path,
                cxr_filepath=cxr_filepath,
            )
            y_pred = run_softmax_eval(model, loader, cxr_labels, cxr_pair_template)
            if cache_path is not None:
                cache_path.parent.mkdir(exist_ok=True, parents=True)
                np.save(file=cache_path, arr=y_pred)
        predictions.append(y_pred)

    # element-wise average across models
    y_pred_avg = np.mean(predictions, axis=0)

    return predictions, y_pred_avg

In [30]:
# Run inference with the checkpoints collected in `model_paths`.
# `cache_dir` is not passed, so it defaults to None: no cached prediction is
# looked up and nothing is written to disk by this call.
# Result: `predictions` is the list of per-model outputs and `y_pred_avg` is
# their mean.
predictions, y_pred_avg = ensemble_models(
    model_paths=model_paths, 
    cxr_filepath=cxr_filepath, 
    cxr_labels=cxr_labels, 
    cxr_pair_template=cxr_pair_template, 
)

Inferring model ../checkpoints/sample_model/best_64_0.0001_original_16000_0.861.pt


100%|██████████| 35/35 [00:05<00:00,  6.84it/s]
100%|██████████| 137/137 [00:30<00:00,  4.49it/s]
100%|██████████| 35/35 [00:05<00:00,  6.82it/s]
100%|██████████| 137/137 [00:30<00:00,  4.49it/s]


In [33]:
# Use the ensemble-averaged predictions as the test-set predictions
# (this cell builds `test_pred`; no ground-truth labels are loaded here).
test_pred = y_pred_avg

# Show the array's shape, then its contents, each on its own line.
print(test_pred.shape, test_pred, sep="\n")

(137, 35)
[[0.4973931  0.49869898 0.4940695  ... 0.4977452  0.49521396 0.4964329 ]
 [0.49503478 0.49806723 0.49632862 ... 0.49879998 0.49670997 0.4982326 ]
 [0.4974887  0.49925068 0.4928193  ... 0.49716708 0.49543622 0.49562556]
 ...
 [0.49731663 0.49919915 0.49148542 ... 0.49771944 0.49660465 0.49806395]
 [0.49884188 0.49894333 0.49120635 ... 0.49745607 0.4958816  0.49686214]
 [0.49982184 0.49753514 0.4892039  ... 0.49646258 0.49799892 0.49466464]]
