In [None]:
import torch
import numpy as np
import pandas as pd
import tta_fns
import zero_shot
import reliability_diagrams as rd
from typing import Optional, List
from pathlib import Path
import matplotlib.pyplot as plt
%matplotlib inline

# ------- PATHS ------  #
# Ground-truth labels CSV; optional, only needed when evaluating predictions.
cxr_true_labels_path: Optional[str] = 'data/groundtruth.csv'
# Directory where the pretrained model checkpoints (.pt) are saved.
model_dir: str = 'checkpoints/chexzero_weights'
# Directory where predictions are saved.
predictions_dir: Path = Path('predictions')
# Directory where ensembled predictions are cached. The value is produced by
# `Path / str`, so it is a Path and is annotated as such (it was marked `str`).
cache_dir: Path = predictions_dir / "cached"

context_length: int = 77

# ------- LABELS ------  #
# Define labels to query each image | will return a prediction for each label
cxr_labels: List[str] = [
    'Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema', 'Enlarged Cardiomediastinum',
    'Fracture', 'Lung Lesion', 'Lung Opacity', 'No Finding', 'Pleural Effusion', 'Pleural Other',
    'Pneumonia', 'Pneumothorax', 'Support Devices',
]

In [None]:
# Load the cached prediction array from the 'no_tta' subfolder of the configured
# cache directory, rather than repeating that directory as a string literal.
test_pred = np.load(cache_dir / 'no_tta' / 'best_128_5e-05_original_22000_0.855.npy')

In [None]:
# Check the dimensions of the loaded prediction array; later cells index it as
# test_pred[:, i] with i running over the entries of cxr_labels.
test_pred.shape

In [None]:

# Build the ground-truth labels by passing the configured CSV path and label
# list to zero_shot.make_true_labels. Later cells index the result as
# test_true[:, i].
# NOTE(review): presumably one column per entry of cxr_labels — confirm.
test_true = zero_shot.make_true_labels(cxr_true_labels_path=cxr_true_labels_path, cxr_labels=cxr_labels)

In [None]:
# Binarise the scores: anything strictly above the cut-off becomes label 1.
threshold = 0.5
pred_lab = (test_pred > threshold).astype(int)
# Confidence in the predicted label: the score itself where the label is 1 and
# its complement (1 - score) where the label is 0.
confidence = np.where(pred_lab == 0, 1 - test_pred, test_pred)

In [None]:
# Display the confidence array: test_pred where the thresholded label is 1,
# and 1 - test_pred where it is 0.
confidence

In [None]:
# Display the raw prediction scores loaded from the cached .npy file.
test_pred


In [None]:
# Display the 0/1 labels obtained by thresholding test_pred.
pred_lab

In [None]:
# Select column 0 (the first entry of cxr_labels) for a single-label diagram.
y_true = test_true[:,0]
y_pred = pred_lab[:,0]
# Use the confidence of the *predicted* label (score where the label is 1,
# 1 - score where it is 0) instead of the raw positive-class score, so each
# sample's confidence refers to the same label as y_pred.
# NOTE(review): assumes rd bins samples by confidence and measures accuracy as
# y_true == y_pred within each bin — confirm against reliability_diagrams.
y_conf = confidence[:,0]

In [None]:
def set_style():
    """Apply the notebook's shared matplotlib look: seaborn style, size-12 text.

    Matplotlib 3.6 renamed its bundled seaborn style sheets to
    ``seaborn-v0_8*`` and later releases no longer accept the bare
    ``"seaborn"`` name, so the new name is used whenever the installed
    version lists it and the old one only as a fallback.
    """
    style_name = "seaborn-v0_8" if "seaborn-v0_8" in plt.style.available else "seaborn"
    plt.style.use(style_name)
    plt.rc("font", size=12)
    plt.rc("axes", labelsize=12)
    plt.rc("xtick", labelsize=12)
    plt.rc("ytick", labelsize=12)
    plt.rc("legend", fontsize=12)

In [None]:
# for i in range(14):
#     y_true = test_true[:,i]
#     y_pred = pred_lab[:,i]
#     y_conf = test_pred[:,i]
#     set_style()
#     fig = rd.reliability_diagram(y_true, y_pred, y_conf, title=cxr_labels[i])
    

In [None]:

# Apply the shared plotting style through the helper defined above instead of
# repeating its body line for line.
set_style()

title = 'plot'

In [None]:
# Plot the reliability diagram for the single label selected earlier (column 0),
# titled with that label's name so the figure can be read on its own.
fig = rd.reliability_diagram(y_true, y_pred, y_conf, title=cxr_labels[0])

In [None]:
# Prints the literal string "r" (no placeholders, so a plain literal suffices).
print("r")

In [None]:
# Collect, for every label, the three arrays handed to rd.reliability_diagrams.
# NOTE(review): the name `dict` shadows the builtin; it is kept only because the
# plotting cell further down passes this variable by that name.
dict = {
    label: {
        "true_labels" : test_true[:,i],
        "pred_labels" : pred_lab[:,i],
        # Confidence of the predicted label (score where the label is 1,
        # 1 - score where it is 0) rather than the raw positive-class score.
        # NOTE(review): assumes rd scores accuracy as true == pred per
        # confidence bin — confirm against reliability_diagrams.
        "confidences" : confidence[:,i],
    }
    # Iterate over the label list itself instead of a hard-coded count of 14.
    for i, label in enumerate(cxr_labels)
}
# Print a short summary instead of dumping every array.
print(f"{len(dict)} labels: {list(dict)}")

In [None]:

# threshold = 0.5
# pred_lab = (test_pred > threshold).astype(int)
# confidence = np.copy(test_pred)
# confidence[pred_lab == 0] = 1 - confidence[pred_lab == 0]
# y_true = test_true[:,0]
# y_pred = pred_lab[:,0]
# y_conf = test_pred[:,0]

# dict = {
#     "true_labels" : y_true,
#     "pred_labels" : y_pred,
#     "confidences" : y_conf
# }

# dict2 = {"chexzero": dict}

# Draw all per-label diagrams in one call, requesting 7 columns and 2 rows and
# turning on the bin-importance option.
rd.reliability_diagrams(
    dict,
    num_cols=7,
    num_rows=2,
    draw_bin_importance=True,
)