In [5]:
import os
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
import keras
from keras.models import load_model
from keras_contrib.applications.nasnet import NASNetLarge, NASNetMobile
from sklearn import metrics as sk_metrics

from main_hex import HEXLoss
from metrics import auc_roc, acc
# Silence every Python warning for the rest of the notebook session.
warnings.simplefilter("ignore")

CHECKPOINTS_PATH = "output/"

# The 14 CheXpert observation labels. This order matches the label columns of
# the CSV (see the per-class counts below) and is assumed to match the order
# of the model's output units.
MODEL_CLASSES = [
    'No Finding',
    'Enlarged Cardiomediastinum',
    'Cardiomegaly',
    'Lung Opacity',
    'Lung Lesion',
    'Edema',
    'Consolidation',
    'Pneumonia',
    'Atelectasis',
    'Pneumothorax',
    'Pleural Effusion',
    'Pleural Other',
    'Fracture',
    'Support Devices',
]

def _load_model(checkpoint_name, compile=True):
    """Load a Keras checkpoint stored under CHECKPOINTS_PATH.

    Args:
        checkpoint_name: file name of the checkpoint inside CHECKPOINTS_PATH.
        compile: forwarded to ``keras.models.load_model``. Pass False to load
            only the architecture and weights without restoring the saved
            training configuration (loss, metrics, optimizer).

    Returns:
        The loaded Keras model.
    """
    # Models in this notebook are compiled with the custom HEXLoss and the
    # custom auc_roc metric. Both names must be resolvable when the saved
    # training config is restored (compile=True), otherwise Keras raises an
    # "Unknown loss function / metric" error.
    custom_objects = {'auc_roc': auc_roc, 'HEXLoss': HEXLoss}
    return load_model(os.path.join(CHECKPOINTS_PATH, checkpoint_name),
                      custom_objects=custom_objects,
                      compile=compile)

# Validation split of CheXpert-small. shuffle=False keeps batches in CSV row
# order, so the i-th prediction corresponds to row i of df_val.
df_val = pd.read_csv('CheXpert-v1.0-small/valid.csv')

# Evaluation-time preprocessing: only multiply pixel values by 1/255
# (maps 8-bit images to [0, 1]); no augmentation.
img_data_gen = keras.preprocessing.image.ImageDataGenerator(rescale=1 / 255)

val_gen = img_data_gen.flow_from_dataframe(
    df_val,
    directory=None,            # 'Path' values are used as-is, no prefix
    x_col='Path',
    y_col=MODEL_CLASSES,       # class_mode='raw' yields these columns as targets
    target_size=(224, 224),
    color_mode='grayscale',
    class_mode='raw',
    batch_size=64,
    shuffle=False,
    interpolation='box',
)

Found 234 validated image filenames.


In [3]:
# Number of positive labels per class in the validation set. Columns are
# selected by name so the result does not depend on where the label columns
# happen to sit in the CSV.
df_val[MODEL_CLASSES].sum(axis=0).astype(int)

No Finding                     38
Enlarged Cardiomediastinum    109
Cardiomegaly                   68
Lung Opacity                  126
Lung Lesion                     1
Edema                          45
Consolidation                  33
Pneumonia                       8
Atelectasis                    80
Pneumothorax                    8
Pleural Effusion               67
Pleural Other                   1
Fracture                        0
Support Devices               107
dtype: int64

In [4]:
print("loading model...")
# Hand-picked checkpoint. compile=False means the model is not compiled after
# loading, so the custom training loss/metrics are not required for inference.
model_file = os.path.join(CHECKPOINTS_PATH, 'model2_hex_MOBILE.02-13.61.h5')
model = load_model(model_file, compile=False)
print(model_file, 'loaded successfully.')

loading model...
output/model2_hex_MOBILE.02-13.61.h5 loaded successfully.


In [75]:
# Dump the weights of any layer named "predictions".
# NOTE(review): this cell printed nothing when run, so the loaded model may not
# have a layer with that exact name -- check the names in model.summary().
for head in (lyr for lyr in model.layers if lyr.name == "predictions"):
    print(head.name)
    print(head.get_weights())

In [33]:
# Print the layer-by-layer architecture (output shapes, parameter counts and
# connections) of the loaded checkpoint.
model.summary()

Model: "NASNet_with_auxiliary"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            (None, 224, 224, 1)  0                                            
__________________________________________________________________________________________________
stem_conv1 (Conv2D)             (None, 111, 111, 32) 288         input_1[0][0]                    
__________________________________________________________________________________________________
stem_bn1 (BatchNormalization)   (None, 111, 111, 32) 128         stem_conv1[0][0]                 
__________________________________________________________________________________________________
activation_1 (Activation)       (None, 111, 111, 32) 0           stem_bn1[0][0]                   
______________________________________________________________________________

separable_conv_2_bn_normal_righ (None, 7, 7, 176)    704         separable_conv_2_normal_right2_11
__________________________________________________________________________________________________
normal_left3_11 (AveragePooling (None, 7, 7, 176)    0           normal_bn_1_11[0][0]             
__________________________________________________________________________________________________
normal_left4_11 (AveragePooling (None, 7, 7, 176)    0           adjust_bn_11[0][0]               
__________________________________________________________________________________________________
normal_right4_11 (AveragePoolin (None, 7, 7, 176)    0           adjust_bn_11[0][0]               
__________________________________________________________________________________________________
separable_conv_2_bn_normal_left (None, 7, 7, 176)    704         separable_conv_2_normal_left5_11[
__________________________________________________________________________________________________
normal_add

In [6]:
print("Predicting...")
# val_gen is a Keras Sequence, so one full pass is len(val_gen) batches. This is
# the default when steps is omitted; it is spelled out here to make explicit that
# every validation image is scored, in CSV order (shuffle=False).
preds = model.predict_generator(val_gen, steps=len(val_gen))

Predicting...



In [56]:
# NASNetMobile for single-channel 224x224 inputs with 14 sigmoid outputs and the
# auxiliary branch enabled; trained from scratch (weights=None).
nasnet_config = dict(
    input_shape=(224, 224, 1),
    dropout=0.5,
    weight_decay=5e-5,
    use_auxiliary_branch=True,
    include_top=True,
    weights=None,
    input_tensor=None,
    pooling=None,
    classes=14,
    activation='sigmoid',
)
backbone = NASNetMobile(**nasnet_config)

In [65]:
backbone.summary()

optimizer = keras.optimizers.Nadam(lr=1e-4, beta_1=0.9, beta_2=0.999)
# loss_weights has one entry per model output; with use_auxiliary_branch=True
# the second output is presumably the auxiliary head (down-weighted to 0.4).
# `acc` and `auc_roc` must be the metric functions imported in the first cell.
backbone.compile(optimizer,
                 loss=HEXLoss,
                 metrics=[acc, auc_roc],
                 loss_weights=[1, 0.4])

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            (None, 224, 224, 1)  0                                            
__________________________________________________________________________________________________
stem_conv1 (Conv2D)             (None, 111, 111, 32) 288         input_1[0][0]                    
__________________________________________________________________________________________________
stem_bn1 (BatchNormalization)   (None, 111, 111, 32) 128         stem_conv1[0][0]                 
__________________________________________________________________________________________________
activation_1 (Activation)       (None, 111, 111, 32) 0           stem_bn1[0][0]                   
__________________________________________________________________________________________________
reduction_

In [8]:
# A multi-output model (e.g. NASNet built with use_auxiliary_branch=True) makes
# Keras return a list of arrays from predict; np.array() on that list would be
# 3-D and rejected by DataFrame. Keep only the first output -- presumably the
# main head, matching loss_weights=[1, 0.4]; confirm against the model outputs.
main_preds = preds[0] if isinstance(preds, list) else preds
df_preds = pd.DataFrame(np.asarray(main_preds), columns=MODEL_CLASSES)

# Ground-truth labels selected by name so they line up column-for-column with
# df_preds regardless of the CSV column layout.
df_truth = df_val[MODEL_CLASSES]
assert len(df_preds) == len(df_truth), 'prediction/label row counts differ'

In [9]:
# Per-image sum of the 14 class scores. Values far above 1 (as below) show the
# classes are scored independently rather than softmax-normalized.
df_preds.sum(axis=1)

0       9.945986
1       8.113791
2       7.689355
3      11.917661
4       9.006002
         ...    
229     9.749418
230    10.832539
231    11.139215
232    10.342036
233    11.722413
Length: 234, dtype: float32

### Evaluating Model

In [10]:
# Per-class ROC evaluation for the five CheXpert competition pathologies:
# print the ROC AUC and draw the ROC curve (with accuracy at a 0.5 threshold)
# for each. `sk_metrics` (sklearn.metrics) and `plt` come from the import cell.
benchmark_classes = ['Cardiomegaly', 'Edema', 'Consolidation', 'Atelectasis', 'Pleural Effusion']
roc_ticks = [round(0.1 * k, 1) for k in range(1, 11)]  # 0.1, 0.2, ..., 1.0


def plot_roc(fpr, tpr, title):
    """Draw one ROC curve on a small standalone figure and display it.

    Args:
        fpr: false-positive rates (1 - specificity) from ``roc_curve``.
        tpr: true-positive rates (sensitivity) from ``roc_curve``.
        title: figure title.
    """
    fig, ax = plt.subplots(figsize=(2, 2), dpi=100)
    ax.set_xlim(0, 1.0)
    ax.set_ylim(0, 1.0)
    ax.set_xticks(roc_ticks)
    ax.set_xticklabels(roc_ticks)
    ax.set_yticks(roc_ticks)
    ax.set_yticklabels(roc_ticks)
    ax.set_xlabel('1 - Specificity')
    ax.set_ylabel('Sensitivity')
    ax.set_title(title)
    ax.plot(fpr, tpr, '-b')
    ax.grid()
    # Render (and release) the figure while the rc_context below is active.
    plt.show()


# Scope the tiny font to these figures instead of mutating the global
# rcParams, which would otherwise leak into every later plot in the kernel.
with plt.rc_context({'font.size': 4}):
    for class_label in benchmark_classes:
        # Index by name so predictions and labels cannot drift out of alignment.
        y_pred = df_preds[class_label].values
        y_true = df_truth[class_label].values
        if np.unique(y_true).size < 2:
            # ROC AUC is undefined with a single label value; sklearn's warning
            # would be hidden by the notebook-wide warnings filter.
            print(class_label, 'skipped: only one label value present')
            continue
        fpr, tpr, _ = sk_metrics.roc_curve(y_true, y_pred, pos_label=1)
        auc = sk_metrics.auc(fpr, tpr)
        print(class_label, 'auc', auc)
        # Named `accuracy`, not `acc`: rebinding `acc` would overwrite the Keras
        # metric function imported at the top, which compile() cells rely on.
        accuracy = sk_metrics.accuracy_score(y_true, (y_pred >= 0.5).astype(int), normalize=True)
        plot_roc(fpr, tpr, '{} ROC, AUC : {:.3f}, Acc : {:.3f}'.format(class_label, auc, accuracy))

Cardiomegaly auc 0.7282955350815025
Edema auc 0.8041152263374486
Consolidation auc 0.8043117744610282
Atelectasis auc 0.6375811688311689
Pleural Effusion auc 0.8162481008132988
