In [None]:
# Check which NVIDIA GPUs are present (and their current load) before TensorFlow is configured below.
!nvidia-smi

In [None]:
from models import *
from evaluation import *
from load_data import *
from resnet3d import *
import numpy as np
import tensorflow as tf
import warnings
import os

print(tf.__version__)
# Suppresses every Python warning for the rest of the kernel session.
warnings.filterwarnings("ignore")

# Restrict TensorFlow to the first GPU and let it allocate GPU memory on demand
# instead of reserving the whole card up front. Guarded so the notebook still
# runs on a machine without a GPU (indexing gpus[0] would raise IndexError).
gpus = tf.config.list_physical_devices(device_type='GPU')
if gpus:
    tf.config.set_visible_devices(devices=gpus[0], device_type='GPU')
    tf.config.experimental.set_memory_growth(gpus[0], True)

In [None]:
seed = 2021
# Only interpreters started after this line (e.g. subprocesses) pick up
# PYTHONHASHSEED; hash randomisation of the running kernel is already fixed.
os.environ['PYTHONHASHSEED'] = str(seed)

# Seed Python's `random`, TensorFlow and NumPy's global RNG.
import random
random.seed(seed)
tf.random.set_seed(seed)
np.random.seed(seed)

# Disease label names, in the same order get_data builds the disease label vector.
Labels_diseases = [
    'Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema',
    'Enlarged Cardiomediastinum', 'Fracture', 'Lung Lesion', 'Lung Opacity',
    'No Finding', 'Pleural Effusion', 'Pleural Other', 'Pneumonia',
    'Pneumothorax', 'Support Devices',
]

In [None]:
import io
class CPU_Unpickler(pickle.Unpickler):
    """``pickle.Unpickler`` that restores pickled torch storages onto the CPU.

    Every reference to ``torch.storage._load_from_bytes`` is redirected to
    ``torch.load(..., map_location='cpu')``, so storages saved from a GPU can be
    restored on a CPU-only machine. All other globals are resolved by the
    standard unpickler.

    Security: unpickling can execute arbitrary code. Only use this on trusted
    files such as this project's own prediction dumps.
    """

    def find_class(self, module, name):
        # Swap in a CPU-mapped loader for raw torch storage payloads.
        if (module, name) == ('torch.storage', '_load_from_bytes'):
            return lambda b: torch.load(io.BytesIO(b), map_location='cpu')
        return super().find_class(module, name)

## 2D: CheXpert DenseNet — ROC curves by race, age and gender

In [None]:
# Disease fields read from each TFRecord, in the order of the disease label vector.
_DISEASE_FIELDS = (
    'Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema',
    'Enlarged Cardiomediastinum', 'Fracture', 'Lung Lesion', 'Lung Opacity',
    'No Finding', 'Pleural Effusion', 'Pleural Other', 'Pneumonia',
    'Pneumothorax', 'Support Devices',
)

# Tasks get_data knows how to label.
_TASKS = ('disease', 'race', 'age', 'gender')


def _one_hot(position, size):
    """Return a list of ``size`` zeros with a single 1 at ``position``."""
    vector = [0] * size
    vector[position] = 1
    return vector


def _encode_label(feature, task, race, age, gender):
    """Return the target vector for one record under ``task``.

    ``feature`` is the record's ``example.features.feature`` map; ``race``,
    ``age`` and ``gender`` are the integer codes already read from it.
    """
    if task == 'disease':
        # Only a stored value of exactly 1 counts as positive; any other value becomes 0.
        return [1 if feature[name].float_list.value[0] == 1 else 0
                for name in _DISEASE_FIELDS]
    if task == 'race':
        # Codes 0 and 1 get their own slot; every other code shares the last one.
        return _one_hot({0: 0, 1: 1}.get(race, 2), 3)
    if task == 'age':
        # Buckets 0-2 get their own slot; every other code shares the last one.
        return _one_hot({0: 0, 1: 1, 2: 2}.get(age, 3), 4)
    if task == 'gender':
        return _one_hot(0 if gender == 0 else 1, 2)
    raise NameError('Wrong task')


def get_data(aug_method='_rotation90', dataset='mimic', data_split='test', task='disease', return_demo=False):
    """Read labels and demographics from ``data/{dataset}_{data_split}{aug_method}.tfrecords``.

    Args:
        aug_method: suffix appended to the file name (``''`` for none).
        dataset: dataset prefix of the file name. For ``'mimic'`` every age
            code above 0 is shifted down by one (presumably to align MIMIC's
            age buckets with the other dataset's — TODO confirm).
        data_split: split name used in the file name, e.g. ``'test'``.
        task: ``'disease'`` (14 binary labels), ``'race'`` (3-way one-hot),
            ``'age'`` (4-way one-hot) or ``'gender'`` (2-way one-hot).
        return_demo: also return the ``[race, gender, age]`` codes per record.

    Returns:
        ``(X, y)`` or ``(X, y, demo)`` as NumPy arrays with one entry per
        record. Images are not decoded, so ``X`` holds a 0 placeholder per record.

    Raises:
        NameError: if ``task`` is not one of the supported tasks. This is
            checked before the file is read, so it is raised even when the
            file contains no records.
    """
    # NOTE(review): reseeds NumPy's global RNG as a side effect; nothing in this
    # function draws random numbers.
    np.random.seed(2021)

    if task not in _TASKS:
        raise NameError('Wrong task')

    X, y, demo = [], [], []
    filename = 'data/{dataset}_{data_split}{aug_method}.tfrecords'.format(
        dataset=dataset, data_split=data_split, aug_method=aug_method)

    for raw_record in tf.data.TFRecordDataset(filename):
        example = tf.train.Example()
        example.ParseFromString(raw_record.numpy())
        feature = example.features.feature

        race = feature['race'].int64_list.value[0]
        age = feature['age'].int64_list.value[0]
        if dataset == 'mimic' and age > 0:
            age -= 1
        gender = feature['gender'].int64_list.value[0]

        demo.append([race, gender, age])
        y.append(_encode_label(feature, task, race, age, gender))
        # Image decoding (from the 'jpg_bytes' feature) is disabled; the 0
        # placeholder keeps X aligned with y and demo.
        X.append(0)

    if return_demo:
        return np.array(X), np.array(y), np.array(demo)
    return np.array(X), np.array(y)
    
# Load the CheXpert test split (no augmentation suffix in the file name)
# together with the per-record demographic codes.
aug_method, dataset, task = '', 'Chexpert', 'disease'

X_test, y_test, demo = get_data(
    aug_method=aug_method,
    dataset=dataset,
    data_split='test',
    task=task,
    return_demo=True,
)

In [None]:
file_name = 'predictions/model_densenet_Chexpert_ERM_proposed_on_original'

# The `with` block closes the file on exit; no explicit close() is needed.
# NOTE: unpickling can run arbitrary code — only load trusted prediction files.
with open(file_name, "rb") as fp:
    y_preds = CPU_Unpickler(fp).load()

best_thresh = np.loadtxt('thresh/model_densenet_Chexpert_ERM_proposed_thresh.txt')

# Every subgroup below indexes y_test and y_preds with the same row positions,
# so the two must line up one-to-one.
assert len(y_preds) == len(y_test), 'predictions and test labels differ in length'


def _subgroup(column, code):
    """Return (row indices, labels, predictions) for rows where demo[:, column] == code."""
    idx = np.where(demo[:, column] == code)[0]
    return idx, y_test[idx], y_preds[idx]


# demo columns are [race, gender, age], as built by get_data.
# Race codes: 0 = White, 1 = Black, 4 = Asian.
idx_white, y_test_white, y_preds_white = _subgroup(0, 0)
idx_black, y_test_black, y_preds_black = _subgroup(0, 1)
idx_asian, y_test_asian, y_preds_asian = _subgroup(0, 4)

# Gender codes: 0 = male, 1 = female.
idx_m, y_test_male, y_preds_male = _subgroup(1, 0)
idx_f, y_test_female, y_preds_female = _subgroup(1, 1)

# Age-bucket codes 0-3 (plotted as 0-40, 40-60, 60-80 and 80+).
idx_age0, y_test_age0, y_preds_age0 = _subgroup(2, 0)
idx_age1, y_test_age1, y_preds_age1 = _subgroup(2, 1)
idx_age2, y_test_age2, y_preds_age2 = _subgroup(2, 2)
idx_age3, y_test_age3, y_preds_age3 = _subgroup(2, 3)

In [None]:
# Indices into Labels_diseases to plot; Fracture (5), Lung Lesion (6),
# Pleural Other (10) and Support Devices (13) are not plotted.
target_labels = [0, 1, 2, 3, 4, 7, 8, 9, 11, 12]


def _plot_roc_panel(ax, groups, threshold, title):
    """Draw one ROC panel with a curve and an operating point per subgroup.

    Args:
        ax: matplotlib Axes to draw on.
        groups: sequence of ``(name, color, y_true, y_score)`` tuples, one per
            subgroup, all for the same label.
        threshold: decision threshold; each curve gets an X marker at the ROC
            point chosen by ``find_nearest`` (presumably the point whose
            threshold is closest — confirm against the evaluation module).
        title: panel title.
    """
    # Compute each ROC curve once and reuse it for both the line and its marker.
    rocs = []
    for name, color, y_true, y_score in groups:
        fpr, tpr, thresh = roc_curve(y_true, y_score)
        rocs.append((color, fpr, tpr, thresh))
        ax.plot(fpr, tpr, linestyle='solid', color=color,
                label='{} AUC={:.3f}'.format(name, auc(fpr, tpr)))

    # Markers are drawn after all curves so the legend lists every curve before
    # every marker; with ncol=2 each marker lands on the same row as its curve.
    for color, fpr, tpr, thresh in rocs:
        idx = find_nearest(thresh, threshold)
        ax.plot(fpr[idx], tpr[idx], marker='X', color=color, markersize=8,
                label='TPR={:.3f} FPR={:.3f}'.format(tpr[idx], fpr[idx]))

    ax.set_title(title, fontsize=15)
    ax.set_xlabel('False Positive Rate', fontsize=12)
    ax.set_ylabel('True Positive Rate', fontsize=12)
    ax.legend(loc='lower right', ncol=2)
    ax.plot([0, 1], [0, 1], color='black', linestyle='--')  # chance diagonal


# Figures are written to AUC_curve/<prediction file base name>/.
path = 'AUC_curve/{}/'.format(os.path.basename(file_name))
os.makedirs(path, exist_ok=True)

for target_label in target_labels:
    disease = Labels_diseases[target_label]
    thr = best_thresh[target_label]
    c = target_label

    # Spacing is left to constrained_layout alone. Adding subplots_adjust on top
    # of it conflicts; depending on the matplotlib version, one of the two is
    # dropped with a warning.
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(24, 8), constrained_layout=True, sharey=True, dpi=300)
    fig.suptitle(disease, fontsize=20)

    _plot_roc_panel(ax1, [
        ('White', 'b', y_test_white[:, c], y_preds_white[:, c]),
        ('Black', 'g', y_test_black[:, c], y_preds_black[:, c]),
        ('Asian', 'y', y_test_asian[:, c], y_preds_asian[:, c]),
    ], thr, 'Race')

    _plot_roc_panel(ax2, [
        ('0-40  ', 'r', y_test_age0[:, c], y_preds_age0[:, c]),
        ('40-60', 'c', y_test_age1[:, c], y_preds_age1[:, c]),
        ('60-80', 'm', y_test_age2[:, c], y_preds_age2[:, c]),
        ('80+  ', 'g', y_test_age3[:, c], y_preds_age3[:, c]),
    ], thr, 'Age')

    _plot_roc_panel(ax3, [
        ('Male', 'b', y_test_male[:, c], y_preds_male[:, c]),
        ('Female', 'r', y_test_female[:, c], y_preds_female[:, c]),
    ], thr, 'Gender')

    fig.savefig(path + '{}_ROC.jpg'.format(disease), bbox_inches='tight', transparent=True)
    plt.show()
    # Free the 300-dpi figure before building the next one.
    plt.close(fig)

## 3D: ADNI AD-vs-CN 3D CNN — ROC curves by age and gender

In [None]:
# ADNI metadata: keep AD and CN subjects (drop MCI) from the test split and
# encode the columns used below as integers:
#   Group: CN -> 0, AD -> 1      Sex: F -> 0, M -> 1
#   Age:   <= 75 -> 0, else 1    Race: < 1 -> 0, else 1
data_path = '../../../mnt/usb/kuopc/ADNI_B1/MPR__GradWarp__B1_Correction_crop/'

df = pd.read_csv('data_new.csv')
df = df[(df['Group'] != 'MCI') & (df['Split'] == 'test')].assign(
    Group=lambda d: d['Group'].replace(['CN', 'AD'], [0, 1]),
    Sex=lambda d: d['Sex'].replace(['F', 'M'], [0, 1]),
    Age=lambda d: np.where(d['Age'] <= 75, 0, 1),
    Race=lambda d: np.where(d['Race'] < 1, 0, 1),
)

In [None]:
file_name = 'predictions/3D_CNN_AD_CN_on_original'

# The `with` block closes the file on exit; no explicit close() is needed.
# NOTE: unpickling can run arbitrary code — only load trusted prediction files.
with open(file_name, "rb") as fp:
    y_preds = CPU_Unpickler(fp).load()

best_thresh = np.loadtxt('thresh/3D_CNN_AD_CN_thresh.txt')

# AD-vs-CN targets (CN -> 0, AD -> 1) in df row order. Predictions are indexed
# with the same positions, so both must have one entry per row.
ad_labels = df['Group'].values
assert len(y_preds) == len(ad_labels), 'predictions and df rows differ in length'

# The data cell encodes Sex as F -> 0 and M -> 1, so male rows are Sex == 1.
sex_codes = df['Sex'].values
idx_m = np.where(sex_codes == 1)[0]
idx_f = np.where(sex_codes == 0)[0]

y_test_male, y_test_female = ad_labels[idx_m], ad_labels[idx_f]
y_preds_male, y_preds_female = y_preds[idx_m], y_preds[idx_f]

# Age was binarised in the data cell: Age <= 75 -> 0, otherwise 1.
age_codes = df['Age'].values
idx_age0 = np.where(age_codes == 0)[0]
idx_age1 = np.where(age_codes == 1)[0]

y_test_age0, y_test_age1 = ad_labels[idx_age0], ad_labels[idx_age1]
y_preds_age0, y_preds_age1 = y_preds[idx_age0], y_preds[idx_age1]


In [None]:
# NOTE(review): not used by the AD plot below; kept so any later cell that
# expects this name still finds it.
target_labels = [0, 1, 2, 3, 4, 7, 8, 9, 11, 12]

# Spacing is left to constrained_layout alone. Adding subplots_adjust on top
# of it conflicts; depending on the matplotlib version, one of the two is
# dropped with a warning.
fig, (ax2, ax3) = plt.subplots(1, 2, figsize=(24, 8), constrained_layout=True, sharey=True, dpi=300)
fig.suptitle('AD', fontsize=20)

# (axes, title, [(legend name, color, y_true, y_score), ...]) for each panel.
panels = [
    (ax2, 'Age', [
        ('0-75', 'r', y_test_age0, y_preds_age0),
        ('75+', 'c', y_test_age1, y_preds_age1),
    ]),
    (ax3, 'Gender', [
        ('Male', 'b', y_test_male, y_preds_male),
        ('Female', 'r', y_test_female, y_preds_female),
    ]),
]

for ax, title, groups in panels:
    # Compute each ROC once. A curve and its operating point must come from the
    # same (fpr, tpr, thresh) triple, so each marker is looked up on its own
    # curve's thresholds.
    rocs = []
    for name, color, y_true, y_score in groups:
        fpr, tpr, thresh = roc_curve(y_true, y_score)
        rocs.append((color, fpr, tpr, thresh))
        ax.plot(fpr, tpr, linestyle='solid', color=color,
                label='{} AUC={:.3f}'.format(name, auc(fpr, tpr)))

    # Markers are drawn after all curves so the legend lists every curve before
    # every marker; with ncol=2 each marker lands on the same row as its curve.
    for color, fpr, tpr, thresh in rocs:
        idx = find_nearest(thresh, best_thresh)
        ax.plot(fpr[idx], tpr[idx], marker='X', color=color, markersize=8,
                label='TPR={:.3f} FPR={:.3f}'.format(tpr[idx], fpr[idx]))

    ax.set_title(title, fontsize=15)
    ax.set_xlabel('False Positive Rate', fontsize=12)
    ax.set_ylabel('True Positive Rate', fontsize=12)
    ax.legend(loc='lower right', ncol=2)
    ax.plot([0, 1], [0, 1], color='black', linestyle='--')  # chance diagonal

# Figures are written to AUC_curve/<prediction file base name>/.
path = 'AUC_curve/{}/'.format(os.path.basename(file_name))
os.makedirs(path, exist_ok=True)

fig.savefig(path + 'AD_ROC.jpg', bbox_inches='tight', transparent=True)
plt.show()
plt.close(fig)