In [None]:
# Show the GPUs visible on this machine (shell command run through IPython's `!` escape).
!nvidia-smi

In [None]:
from transformation import *
from utilities import *
from calculate_disparity import *
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import tensorflow as tf
import cv2 as cv
import warnings
import skimage.transform as st
import gc
import os

print(tf.__version__)
# Silence all Python warnings for the rest of the session (this also hides
# deprecation notices raised by the libraries used below).
warnings.filterwarnings("ignore")

# Restrict TensorFlow to the first physical GPU. On a machine without a GPU the
# list is empty, so it must not be indexed unconditionally (IndexError).
gpus = tf.config.list_physical_devices(device_type='GPU')
if gpus:
    tf.config.set_visible_devices(devices=gpus[0], device_type='GPU')
else:
    print('No GPU detected; TensorFlow will run on the CPU.')

In [None]:
# Fix the random seeds used by NumPy and TensorFlow so repeated runs are comparable.
seed = 2021
np.random.seed(seed)
tf.random.set_seed(seed)
# Python reads PYTHONHASHSEED at interpreter start-up, so setting it here only
# affects processes launched from this kernel, not the kernel itself.
os.environ['PYTHONHASHSEED']=str(seed)

# The 14 findings, in the same order in which get_data() appends the label columns.
Labels_diseases = [
    'Atelectasis',
    'Cardiomegaly',
    'Consolidation',
    'Edema',
    'Enlarged Cardiomediastinum',
    'Fracture',
    'Lung Lesion',
    'Lung Opacity',
    'No Finding',
    'Pleural Effusion',
    'Pleural Other',
    'Pneumonia',
    'Pneumothorax',
    'Support Devices',
]

In [None]:
def get_data(dataset='mimic', data_split='test', types='race', feature_type=-1, random_aug=False,
             image_size=(224, 224)):
    """Load images and their 14 binary finding labels from one TFRecord split.

    Args:
        dataset: 'mimic' or 'chexpert'. Selects which TFRecord files are read.
        data_split: 'train' or 'val'; any other value selects the test file.
        types: Attribute used for sub-group filtering. 'race' uses the record's
            `race` field, 'age' uses the `age` field (positive values are
            decremented by one, so raw codes 0 and 1 share group 0), and any
            other value uses the `gender` field.
        feature_type: -1 keeps every record; otherwise only records whose
            selected attribute equals this value are kept.
        random_aug: When true, read the `*_transformed.tfrecords` files.
        image_size: (height, width) every decoded image is resized to.
            Defaults to (224, 224), the size this function always used.

    Returns:
        (X, y): `X` is an array of float32 grayscale images resized to
        `image_size`; `y` is an array with one row of 14 zeros/ones per image,
        in the order of `Labels_diseases`. A label is 1 only when the stored
        value equals 1; every other stored value maps to 0.

    Raises:
        ValueError: if `dataset` is not recognised or an image cannot be decoded.

    Notes:
        Records whose race code is not 0, 1 or 4 are always skipped, whatever
        `types` is (the notebook treats these codes as White, Black and Asian).
    """
    # Re-seeds NumPy's global RNG on every call; kept because callers may rely on it.
    np.random.seed(2021)

    # Feature keys of the 14 findings; the order defines the column order of y.
    label_names = (
        'Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema',
        'Enlarged Cardiomediastinum', 'Fracture', 'Lung Lesion', 'Lung Opacity',
        'No Finding', 'Pleural Effusion', 'Pleural Other', 'Pneumonia',
        'Pneumothorax', 'Support Devices',
    )

    augmented = bool(random_aug)
    if dataset == 'mimic':
        directory, prefix = 'data', 'mimic'
    elif dataset == 'chexpert':
        # The augmented CheXpert files used to be unreachable: the plain
        # filenames were assigned unconditionally afterwards (missing `else`).
        directory = 'data' if augmented else '../Data'
        prefix = 'Chexpert'
    else:
        # Previously this fell through and failed later on an unbound `filename`.
        raise ValueError("Unknown dataset {!r}; expected 'mimic' or 'chexpert'".format(dataset))

    split = data_split if data_split in ('train', 'val') else 'test'
    suffix = '_transformed' if augmented else ''
    filename = '{}/{}_{}{}.tfrecords'.format(directory, prefix, split, suffix)

    X = []
    y = []
    for raw_record in tf.data.TFRecordDataset(filename):
        example = tf.train.Example()
        example.ParseFromString(raw_record.numpy())
        features = example.features.feature

        race = features['race'].int64_list.value[0]
        if race not in (0, 1, 4):
            continue

        if types == 'race':
            feature = race
        elif types == 'age':
            age = features['age'].int64_list.value[0]
            feature = age - 1 if age > 0 else age
        else:
            feature = features['gender'].int64_list.value[0]

        if feature_type != -1 and feature != feature_type:
            continue

        sub_y = [1 if features[name].float_list.value[0] == 1 else 0 for name in label_names]

        # np.frombuffer replaces the deprecated binary mode of np.fromstring.
        nparr = np.frombuffer(features['jpg_bytes'].bytes_list.value[0], np.uint8)
        img_np = cv.imdecode(nparr, cv.IMREAD_GRAYSCALE)
        if img_np is None:
            # cv.imdecode signals failure by returning None instead of raising.
            raise ValueError('Could not decode the image stored in a record of {}'.format(filename))

        X.append(np.float32(st.resize(img_np, image_size)))
        y.append(sub_y)

    return np.array(X), np.array(y)

# Radiological labels Detection

In [None]:
INPUT_SHAPE = (128, 128, 1)
def define_model_diseases():
    """Build and compile the DenseNet121-based multi-label classifier.

    The single-channel input of shape `INPUT_SHAPE` is repeated to three
    channels, passed through an ImageNet-initialised DenseNet121 backbone
    (no top, global max pooling) and mapped to 14 sigmoid outputs, one per
    entry of `Labels_diseases`.

    Returns:
        A compiled `tf.keras.Model` trained with binary cross-entropy and Adam
        (learning rate 1e-3), reporting AUC as its metric.
    """
    # Named `image_input` so the builtin `input` is not shadowed.
    image_input = tf.keras.layers.Input(shape=INPUT_SHAPE)

    # Repeats the last axis three times to give the backbone a 3-channel image.
    # NOTE(review): UpSampling3D is documented for 5D inputs; confirm it accepts
    # this 4D image tensor in the TensorFlow version in use.
    reshape_layer = tf.keras.layers.UpSampling3D(size=(1,1,3))(image_input)

    # Spatial size follows INPUT_SHAPE so the two cannot drift apart.
    backbone_shape = tuple(INPUT_SHAPE[:2]) + (3,)
    base_model = tf.keras.applications.densenet.DenseNet121(
            include_top=False, weights='imagenet', input_shape=backbone_shape, pooling='max')(reshape_layer)

    pred_layer = tf.keras.layers.Dense(14, activation='sigmoid')(base_model)

    model = tf.keras.Model(inputs=image_input, outputs=pred_layer)

    # The output layer already applies a sigmoid, so the loss receives
    # probabilities and must not be told to expect logits.
    model.compile(loss=tf.keras.losses.BinaryCrossentropy(from_logits=False),
                 optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3), metrics=['AUC'])

    return model

# Build a fresh model and restore weights saved earlier.
# NOTE(review): this needs 'checkpoints/model_mimic_128' to exist already; the
# training cell further down is what writes to that path.
model = define_model_diseases()

model.load_weights('checkpoints/model_mimic_128')

In [None]:
# Print the layer-by-layer summary of the current `model`.
model.summary()

In [None]:
# Train the classifier on MIMIC, keeping only the weights with the lowest validation loss.
checkpoint_filepath = 'checkpoints/model_mimic_128'

# Writes weights (not the full model) and only when `val_loss` improves.
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_best_only=True,
    save_weights_only=True,
    monitor='val_loss',
    mode='min')

# `scheduler` is not defined in this notebook; presumably it comes from one of
# the star imports at the top — TODO confirm.
# Early stopping ends training after 4 epochs without a `val_loss` improvement.
callback = [tf.keras.callbacks.LearningRateScheduler(scheduler),
            tf.keras.callbacks.EarlyStopping(mode='min', patience=4, monitor='val_loss'),
            model_checkpoint_callback]

# mimic
# NOTE(review): get_data resizes images to 224x224 while INPUT_SHAPE is
# (128, 128, 1) — confirm the arrays match what the model expects.
X_train, y_train = get_data(dataset='mimic', data_split='train', types='race', feature_type=-1, random_aug=False)
X_val, y_val = get_data(dataset='mimic', data_split='val', types='race', feature_type=-1, random_aug=False)

# chexpert
# X_train, y_train = get_data(dataset='chexpert', data_split='train', types='race', feature_type=-1, random_aug=True)
# X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.2, random_state=2021)

# A new, untrained model replaces whatever `model` held before this cell.
model = define_model_diseases()

# `BATCH_SIZE` is not defined in this notebook; presumably it comes from one of
# the star imports — TODO confirm.
model.fit(X_train, y_train, epochs=15, validation_data=(X_val, y_val), callbacks=callback, batch_size=BATCH_SIZE)

# y_preds = model.predict(X_val)

# best_thresh = cal_best_thresh(y_val, y_preds)

# np.savetxt('mimic_proposed_thresh.txt', [best_thresh])

# del X_train, y_train
# del X_val, y_val
# gc.collect()

In [None]:
# Evaluate the 'checkpoints/model_mimic_128' weights on the whole MIMIC test split.
X_test, y_test = get_data(dataset='mimic', data_split='test', types='race', feature_type=-1, random_aug=False)

model = define_model_diseases()

model.load_weights('checkpoints/model_mimic_128')
y_preds = model.predict(X_test)
# `plot_roc` and `test` are not defined in this notebook; presumably they come
# from the star imports at the top — TODO confirm.
plot_roc(y_test, y_preds, 'ROC', Labels_diseases)
print(test(y_preds, y_test))

In [None]:
# Replace `model` with a complete model loaded from an HDF5 file (not just weights).
model = tf.keras.models.load_model('model_mimic_128.h5')

In [None]:
# Print the layer-by-layer summary of the current `model`.
model.summary()

In [None]:
# Evaluate one saved checkpoint on the whole MIMIC test split. The commented
# blocks below are the alternative checkpoints; enable one at a time.
X_test, y_test = get_data(dataset='mimic', data_split='test', types='race', feature_type=-1, random_aug=False)

model = define_model_diseases()

model.load_weights('checkpoints/model_mimic_baseline')
y_preds = model.predict(X_test)
plot_roc(y_test, y_preds, 'ROC', Labels_diseases)
print(test(y_preds, y_test))

# model.load_weights('checkpoints/model_mimic_proposed')
# y_preds = model.predict(X_test)
# plot_roc(y_test, y_preds, 'ROC', Labels_diseases)
# print(test(y_preds, y_test))

# model.load_weights('checkpoints/model_chexpert_baseline')
# y_preds = model.predict(X_test)
# plot_roc(y_test, y_preds, 'ROC', Labels_diseases)
# print(test(y_preds, y_test))

# model.load_weights('checkpoints/model_chexpert_proposed')
# y_preds = model.predict(X_test)
# plot_roc(y_test, y_preds, 'ROC', Labels_diseases)
# print(test(y_preds, y_test))

# Drop the model reference and force a garbage-collection pass; the next
# section builds its own model.
del model 
gc.collect()

# Calculate TPR Disparity

In [None]:
# Choose the checkpoint and the matching threshold file used by every
# disparity cell below. `best_thresh` is indexed per label (best_thresh[i])
# later on. Keep the weights and the threshold file of the same run together.
model = define_model_diseases()

model.load_weights('checkpoints/model_mimic_baseline')
best_thresh = np.loadtxt('mimic_baseline_thresh.txt')

# model.load_weights('checkpoints/model_mimic_proposed')
# best_thresh = np.loadtxt('mimic_proposed_thresh.txt')

# model.load_weights('checkpoints/model_chexpert_baseline')
# best_thresh = np.loadtxt('chexpert_baseline_thresh.txt')

# model.load_weights('checkpoints/model_chexpert_proposed')
# best_thresh = np.loadtxt('chexpert_proposed_thresh.txt')

## Race

In [None]:
# Load the MIMIC test split once per race code (0, 1 and 4).
random_aug = False
X_test_white, y_test_white = get_data(
    dataset='mimic', data_split='test', types='race', feature_type=0, random_aug=random_aug)
X_test_black, y_test_black = get_data(
    dataset='mimic', data_split='test', types='race', feature_type=1, random_aug=random_aug)
X_test_asia, y_test_asia = get_data(
    dataset='mimic', data_split='test', types='race', feature_type=4, random_aug=random_aug)

In [None]:
# Score every race subgroup with the currently loaded model, in the same order as before.
y_preds_white, y_preds_black, y_preds_asia = [
    model.predict(subgroup_images)
    for subgroup_images in (X_test_white, X_test_black, X_test_asia)
]

In [None]:
# Per-label true-positive rate of each race subgroup at the shared threshold.
# `get_tpr` is not defined in this notebook; presumably it comes from the star
# imports at the top — TODO confirm.
tprs_white = [get_tpr(y_test_white[:, i], y_preds_white[:, i], best_thresh[i]) for i in range(14)]
tprs_black = [get_tpr(y_test_black[:, i], y_preds_black[:, i], best_thresh[i]) for i in range(14)]
tprs_asia = [get_tpr(y_test_asia[:, i], y_preds_asia[:, i], best_thresh[i]) for i in range(14)]

# Disparity of a label = sum of absolute deviations of the subgroup TPRs from
# their median.
disparity = []
for label_tprs in zip(tprs_white, tprs_black, tprs_asia):
    center = np.median(list(label_tprs))
    disparity.append(sum(np.abs(group_tpr - center) for group_tpr in label_tprs))

# Row 0 holds the disparities, the following rows the per-subgroup TPRs.
result = [disparity, tprs_white, tprs_black, tprs_asia]

In [None]:
# Print the `test` summary for each race subgroup: White, Black, Asian.
for subgroup_preds, subgroup_labels in (
    (y_preds_white, y_test_white),
    (y_preds_black, y_test_black),
    (y_preds_asia, y_test_asia),
):
    print(test(subgroup_preds, subgroup_labels))

In [None]:
# One ROC figure per finding, overlaying the three race subgroups and marking
# the operating point closest to the label's decision threshold.
# `roc_curve`, `auc` and `find_nearest` are not imported explicitly in this
# notebook; presumably they come from the star imports — TODO confirm.
race_groups = [
    ('White', 'b', y_test_white, y_preds_white),
    ('Black', 'g', y_test_black, y_preds_black),
    ('Asian', 'y', y_test_asia, y_preds_asia),
]

# NOTE(review): the directory says "proposed" while the model cell above loads
# the baseline weights — confirm the figures land in the intended folder.
roc_output_dir = 'MIMICproposed_MIMICnoaug'
# savefig does not create missing directories.
os.makedirs(roc_output_dir, exist_ok=True)

for target_label in range(14):
    # A new figure per label: one shared figure made before the loop only
    # gives the first plot the requested size.
    fig = plt.figure(figsize=(8,6))

    # Compute each subgroup's ROC once and reuse it for the curve and the marker.
    group_rocs = []
    for group_name, color, y_true, y_score in race_groups:
        fpr, tpr, thresholds = roc_curve(y_true[:, target_label], y_score[:, target_label])
        group_rocs.append((group_name, color, fpr, tpr, thresholds))

    # Curves first, markers second, so the two legend columns stay
    # "curves | operating points".
    for group_name, color, fpr, tpr, thresholds in group_rocs:
        roc_auc = auc(fpr, tpr)
        plt.plot(fpr, tpr, linestyle='solid', color=color, label='{} AUC={:.3f}'.format(group_name, roc_auc))

    for group_name, color, fpr, tpr, thresholds in group_rocs:
        idx = find_nearest(thresholds, best_thresh[target_label])
        plt.plot(fpr[idx], tpr[idx], marker='X', color=color, markersize=12, label='TPR={:.3f} FPR={:.3f}'.format(tpr[idx], fpr[idx]))

    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('{} ROC'.format(Labels_diseases[target_label]))
    plt.legend(loc="lower right", ncol=2)
    plt.plot([0,1], [0,1], color='black', linestyle='--')
    plt.savefig('{}/{} ROC.jpg'.format(roc_output_dir, Labels_diseases[target_label]))

    plt.show()
    # Release the figure so 14 of them do not pile up in memory.
    plt.close(fig)

## Age

In [None]:
# Load the MIMIC test split once per age group (0-3). The ROC plots below
# label these groups '0-40', '40-60', '60-80' and '80-'.
random_aug = False

X_test_age0, y_test_age0 = get_data(
    dataset='mimic', data_split='test', types='age', feature_type=0, random_aug=random_aug)
X_test_age1, y_test_age1 = get_data(
    dataset='mimic', data_split='test', types='age', feature_type=1, random_aug=random_aug)
X_test_age2, y_test_age2 = get_data(
    dataset='mimic', data_split='test', types='age', feature_type=2, random_aug=random_aug)
X_test_age3, y_test_age3 = get_data(
    dataset='mimic', data_split='test', types='age', feature_type=3, random_aug=random_aug)

In [None]:
# Score every age subgroup with the currently loaded model, in the same order as before.
y_preds_age0, y_preds_age1, y_preds_age2, y_preds_age3 = [
    model.predict(subgroup_images)
    for subgroup_images in (X_test_age0, X_test_age1, X_test_age2, X_test_age3)
]

In [None]:
# Per-label true-positive rate of each age subgroup at the shared threshold.
# `get_tpr` is not defined in this notebook; presumably it comes from the star
# imports at the top — TODO confirm.
tprs_age0 = [get_tpr(y_test_age0[:, i], y_preds_age0[:, i], best_thresh[i]) for i in range(14)]
tprs_age1 = [get_tpr(y_test_age1[:, i], y_preds_age1[:, i], best_thresh[i]) for i in range(14)]
tprs_age2 = [get_tpr(y_test_age2[:, i], y_preds_age2[:, i], best_thresh[i]) for i in range(14)]
tprs_age3 = [get_tpr(y_test_age3[:, i], y_preds_age3[:, i], best_thresh[i]) for i in range(14)]

# Disparity of a label = sum of absolute deviations of the subgroup TPRs from
# their median.
disparity = []
for label_tprs in zip(tprs_age0, tprs_age1, tprs_age2, tprs_age3):
    center = np.median(list(label_tprs))
    disparity.append(sum(np.abs(group_tpr - center) for group_tpr in label_tprs))

# Row 0 holds the disparities, the following rows the per-subgroup TPRs.
result = [disparity, tprs_age0, tprs_age1, tprs_age2, tprs_age3]

In [None]:
# Print the `test` summary for each age subgroup, youngest group first.
for subgroup_preds, subgroup_labels in (
    (y_preds_age0, y_test_age0),
    (y_preds_age1, y_test_age1),
    (y_preds_age2, y_test_age2),
    (y_preds_age3, y_test_age3),
):
    print(test(subgroup_preds, subgroup_labels))

In [None]:
# One ROC figure per finding, overlaying the four age subgroups and marking
# the operating point closest to the label's decision threshold.
# `roc_curve`, `auc` and `find_nearest` are not imported explicitly in this
# notebook; presumably they come from the star imports — TODO confirm.
age_groups = [
    ('0-40', 'r', y_test_age0, y_preds_age0),
    ('40-60', 'c', y_test_age1, y_preds_age1),
    ('60-80', 'm', y_test_age2, y_preds_age2),
    ('80-', 'g', y_test_age3, y_preds_age3),
]

roc_output_dir = 'MIMICbaseline_MIMICnoaug'
# savefig does not create missing directories.
os.makedirs(roc_output_dir, exist_ok=True)

for target_label in range(14):
    # A new figure per label: one shared figure made before the loop only
    # gives the first plot the requested size.
    fig = plt.figure(figsize=(8,6))

    # Compute each subgroup's ROC once and reuse it for the curve and the marker.
    group_rocs = []
    for group_name, color, y_true, y_score in age_groups:
        fpr, tpr, thresholds = roc_curve(y_true[:, target_label], y_score[:, target_label])
        group_rocs.append((group_name, color, fpr, tpr, thresholds))

    # Curves first, markers second, so the two legend columns stay
    # "curves | operating points".
    for group_name, color, fpr, tpr, thresholds in group_rocs:
        roc_auc = auc(fpr, tpr)
        plt.plot(fpr, tpr, linestyle='solid', color=color, label='{} AUC={:.3f}'.format(group_name, roc_auc))

    for group_name, color, fpr, tpr, thresholds in group_rocs:
        idx = find_nearest(thresholds, best_thresh[target_label])
        plt.plot(fpr[idx], tpr[idx], marker='X', color=color, markersize=12, label='TPR={:.3f} FPR={:.3f}'.format(tpr[idx], fpr[idx]))

    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('{} ROC'.format(Labels_diseases[target_label]))
    plt.legend(loc="lower right", ncol=2)
    plt.plot([0,1], [0,1], color='black', linestyle='--')
    plt.savefig('{}/{} ROC age.jpg'.format(roc_output_dir, Labels_diseases[target_label]))

    plt.show()
    # Release the figure so 14 of them do not pile up in memory.
    plt.close(fig)

## Gender

In [None]:
# Load the MIMIC test split once per gender code (0 is used as male, 1 as female).
random_aug = False

X_test_male, y_test_male = get_data(
    dataset='mimic', data_split='test', types='gender', feature_type=0, random_aug=random_aug)
X_test_female, y_test_female = get_data(
    dataset='mimic', data_split='test', types='gender', feature_type=1, random_aug=random_aug)

In [None]:
# Score both gender subgroups with the currently loaded model, male first.
y_preds_male, y_preds_female = [
    model.predict(subgroup_images)
    for subgroup_images in (X_test_male, X_test_female)
]

In [None]:
# Per-label true-positive rate of each gender subgroup at the shared threshold.
# `get_tpr` is not defined in this notebook; presumably it comes from the star
# imports at the top — TODO confirm.
tprs_male = [get_tpr(y_test_male[:, i], y_preds_male[:, i], best_thresh[i]) for i in range(14)]
tprs_female = [get_tpr(y_test_female[:, i], y_preds_female[:, i], best_thresh[i]) for i in range(14)]

# With two groups the disparity is simply the absolute TPR difference.
disparity = [
    np.abs(male_tpr - female_tpr)
    for male_tpr, female_tpr in zip(tprs_male, tprs_female)
]

# Row 0 holds the disparities, the following rows the per-subgroup TPRs.
result = [disparity, tprs_male, tprs_female]

In [None]:
# Print the `test` summary for each gender subgroup, male first.
for subgroup_preds, subgroup_labels in (
    (y_preds_male, y_test_male),
    (y_preds_female, y_test_female),
):
    print(test(subgroup_preds, subgroup_labels))

In [None]:
# One ROC figure per finding, overlaying the two gender subgroups and marking
# the operating point closest to the label's decision threshold.
# `roc_curve`, `auc` and `find_nearest` are not imported explicitly in this
# notebook; presumably they come from the star imports — TODO confirm.
gender_groups = [
    ('Male', 'b', y_test_male, y_preds_male),
    ('Female', 'r', y_test_female, y_preds_female),
]

# NOTE(review): the directory says "proposed" while the model cell above loads
# the baseline weights — confirm the figures land in the intended folder.
roc_output_dir = 'MIMICproposed_MIMICnoaug'
# savefig does not create missing directories.
os.makedirs(roc_output_dir, exist_ok=True)

for target_label in range(14):
    # A new figure per label: one shared figure made before the loop only
    # gives the first plot the requested size.
    fig = plt.figure(figsize=(8,6))

    # Compute each subgroup's ROC once and reuse it for the curve and the marker.
    group_rocs = []
    for group_name, color, y_true, y_score in gender_groups:
        fpr, tpr, thresholds = roc_curve(y_true[:, target_label], y_score[:, target_label])
        group_rocs.append((group_name, color, fpr, tpr, thresholds))

    # Curves first, markers second, so the two legend columns stay
    # "curves | operating points".
    for group_name, color, fpr, tpr, thresholds in group_rocs:
        roc_auc = auc(fpr, tpr)
        plt.plot(fpr, tpr, linestyle='solid', color=color, label='{} AUC={:.3f}'.format(group_name, roc_auc))

    for group_name, color, fpr, tpr, thresholds in group_rocs:
        idx = find_nearest(thresholds, best_thresh[target_label])
        plt.plot(fpr[idx], tpr[idx], marker='X', color=color, markersize=12, label='TPR={:.3f} FPR={:.3f}'.format(tpr[idx], fpr[idx]))

    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('{} ROC'.format(Labels_diseases[target_label]))
    plt.legend(loc="lower right", ncol=2)
    plt.plot([0,1], [0,1], color='black', linestyle='--')
    plt.savefig('{}/{} ROC gender.jpg'.format(roc_output_dir, Labels_diseases[target_label]))

    plt.show()
    # Release the figure so 14 of them do not pile up in memory.
    plt.close(fig)