In [None]:
# Shell escape: print the nvidia-smi GPU report before TensorFlow is configured.
!nvidia-smi

In [None]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 
import tensorflow as tf
import pickle
import pandas as pd
import numpy as np
import warnings
from models import *
from evaluation import *
from load_data import *

print(tf.__version__)
# Suppress every Python warning for the rest of the notebook.
warnings.filterwarnings("ignore")

# Restrict TensorFlow to the first GPU and let it grow its memory on demand.
# On a machine without a visible GPU the list is empty; skip the pinning
# instead of failing with an IndexError on gpus[0].
gpus = tf.config.list_physical_devices(device_type='GPU')
if gpus:
    tf.config.set_visible_devices(devices=gpus[0], device_type='GPU')
    tf.config.experimental.set_memory_growth(gpus[0], True)
else:
    print('No GPU visible to TensorFlow; GPU configuration skipped.')

In [None]:
# One shared seed for Python hashing, NumPy and TensorFlow.
seed = 2021

os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
tf.random.set_seed(seed)

# Compute disparity

## 3D MRI disparity results

In [None]:
data_path = '../../../mnt/usb/kuopc/ADNI_B1/MPR__GradWarp__B1_Correction_crop/'

# Keep the raw table untouched so this cell can be re-run and inspected.
raw_df = pd.read_csv('data_new.csv')

# Test split only, without the MCI group (binary CN-vs-AD task).
# .copy() yields an independent frame, so the column assignments below do not
# write into a slice of raw_df (pandas "view versus copy" hazard).
keep_rows = (raw_df['Group'] != 'MCI') & (raw_df['Split'] == 'test')
df = raw_df.loc[keep_rows].copy()

# Binary encodings used by the subgroup analysis.
df['Group'] = df['Group'].replace(['CN', 'AD'], [0, 1])
df['Sex'] = df['Sex'].replace(['F', 'M'], [0, 1])
df['Age'] = np.where(df['Age'] <= 75, 0, 1)      # 0: age <= 75, 1: otherwise
df['Race'] = np.where(df['Race'] < 1, 0, 1)      # 0: code below 1, 1: otherwise

### Generate disparities csv

In [None]:
# ---- Configuration ---------------------------------------------------------
metrics = ['AUC', 'BCE', 'ECE', "Error rate", "Precision"]
group = 'race'
testdata = 'original'
group_type = {'race': [0, 1], 'gender': [0, 1], 'age': [0, 1]}
group_name = {'race': ['white', 'others'], 'gender': ['Female', 'Male'], 'age': ['0-75', '75+']}

# Every model is evaluated with and without the proposed augmentation, on the
# original and on the augmented test set: four result files per model.
result_variants = ['_on_original', '_on_aug', '_proposed_on_original', '_proposed_on_aug']

# One row per result file, one column per metric.
result_df = pd.DataFrame(columns=metrics)

for model_name in ['', '_balanced', '_stratified', '_Adv', '_DistMatchMMD', '_DistMatchMean', '_FairALM']:

    results_list = [
        'results/3D_CNN_AD_CN{model_name}{variant}_{group}_results'.format(
            model_name=model_name, variant=variant, group=group)
        for variant in result_variants
    ]

    for result_name in results_list:

        # Each file is read once and reused for all metrics.
        # NOTE: pickle.load can execute arbitrary code; only open result files
        # produced by this project.
        with open(result_name, "rb") as fp:
            dfs = pickle.load(fp)

        # dfs[k] is the score table of subgroup k (one column per metric).
        # +/-inf becomes NaN and every NaN is then filled with 0.
        # NOTE(review): filling with 0 pulls non-finite scores into the mean
        # instead of skipping them -- confirm this is intended.
        group_frames = [
            dfs[k].replace([np.inf, -np.inf], np.nan).fillna(0)
            for k in range(len(group_type[group]))
        ]

        for metric in metrics:
            mean_scores = [frame[metric].mean(skipna=True) for frame in group_frames]

            # Disparity = summed absolute deviation of the subgroup means from
            # their median.
            median = np.nanmedian(mean_scores)
            disparity = sum(np.abs(score - median) for score in mean_scores)

            # There is a single disparity value per file here, so no spread or
            # interval can be derived; only the value itself is reported.
            result_df.at[result_name, metric] = "{:.3f}".format(disparity)

os.makedirs('disparity_results', exist_ok=True)
result_df.to_csv('disparity_results/3D_all_result_{group}.csv'.format(group=group))


### Plot disparities

In [None]:
def legend_without_duplicate_labels(figure):
    """Draw a legend in which every label appears only once.

    Handles and labels are read from the current axes (``plt.gca()``). When a
    label occurs several times, the handle of its last occurrence is kept and
    labels stay in order of first appearance.

    Parameters
    ----------
    figure : object exposing ``legend(handles, labels, **kwargs)``
        The plotting cell below passes the ``matplotlib.pyplot`` module.
    """
    handles, labels = plt.gca().get_legend_handles_labels()
    by_label = dict(zip(labels, handles))
    legend = figure.legend(by_label.values(), by_label.keys(), loc='lower center', bbox_to_anchor=(1.3, 0.5))

    # Matplotlib 3.7 renamed Legend.legendHandles to legend_handles and later
    # removed the old attribute; accept whichever one exists.
    legend_handles = getattr(legend, 'legend_handles', None)
    if legend_handles is None:
        legend_handles = legend.legendHandles

    for legend_handle in legend_handles:
        legend_handle.set_markersize(3)

In [None]:
# ---- Configuration ---------------------------------------------------------
metrics = ['AUC', 'BCE', 'ECE', "Error rate", "Precision"]
groups = ['gender', 'age']          # one panel per sensitive attribute
testdata = 'aug'
group_type = {'race': [0, 1], 'gender': [0, 1], 'age': [0, 1]}

# File-name infix, legend label and colour of each compared model (same order).
model_suffixes = ['', '_balanced', '_stratified', '_Adv',
                  '_DistMatchMMD', '_DistMatchMean', '_FairALM', '_proposed']
model_list = ['Baseline',
              'Balanced', 'Stratified', 'Adversarial learning',
              'DistMatchMMD',
              'DistMatchMean', 'FairALM',
              'Proposed augmentation']
color_list = ['C0', 'C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'black']

# Vertical layout in data units: models are gap_result apart inside a metric
# band, and consecutive metric bands are gap_metrics apart.
gap_result = 5
gap_between_metrics = 20
gap_metrics = gap_result*(len(model_list)-1) + gap_between_metrics
top = int(gap_metrics*(len(metrics)-1) + gap_result*(len(model_list)/2))

plt.rcParams["figure.autolayout"] = True
fig, axs = plt.subplots(nrows=1, ncols=2, sharey=True, figsize=(14, 7), dpi=500)

for num, ax in enumerate(fig.axes):

    group = groups[num]

    results_list = [
        'results/3D_CNN_AD_CN{suffix}_on_{testdata}_{group}_results'.format(
            suffix=suffix, testdata=testdata, group=group)
        for suffix in model_suffixes
    ]

    ax.set_title('{group} disparity'.format(group=group.capitalize()), fontsize=13)
    ax.set_yticks([(i*gap_metrics) for i in range(len(metrics))][::-1], metrics, fontsize=15)

    # Read every result file once per panel and reuse it for all metrics.
    # dfs[k] is the score table of subgroup k; +/-inf is turned into NaN so
    # that the means below skip it.
    # NOTE: pickle.load can execute arbitrary code; only open project files.
    cleaned_results = []
    for result_name in results_list:
        with open(result_name, "rb") as fp:
            dfs = pickle.load(fp)
        cleaned_results.append([
            dfs[k].replace([np.inf, -np.inf], np.nan)
            for k in range(len(group_type[group]))
        ])

    for i, metric in enumerate(metrics):

        # Disparity per model = summed absolute deviation of the subgroup
        # means from their median.
        all_mean_disparity = []
        for group_frames in cleaned_results:
            mean_scores = [frame[metric].mean(skipna=True) for frame in group_frames]
            median = np.nanmedian(mean_scores)
            all_mean_disparity.append(sum(np.abs(score - median) for score in mean_scores))

        metric_top = top - i*gap_metrics

        for idx, disparity in enumerate(all_mean_disparity):
            ax.plot(disparity, metric_top - idx*gap_result, 'o',
                    color=color_list[idx], markersize=3, label=model_list[idx])

        # Dashed vertical guide at the last model's disparity, spanning the
        # whole metric band.
        last_idx = len(all_mean_disparity) - 1
        ax.plot([all_mean_disparity[-1], all_mean_disparity[-1]],
                [metric_top - last_idx*gap_result, metric_top],
                linestyle='--', color='black', linewidth = 0.5)

    # Thin separators between metric bands.
    for i in range(len(metrics)-1):
        ax.axhline(top-((i+1)*gap_metrics)+int(gap_between_metrics/2), linestyle='--', color='black', linewidth = 0.3)

# The helper reads plt.gca(), and nothing in the loop changes the current
# axes, so one call after the loop gives the same legend as one call per panel.
legend_without_duplicate_labels(plt)

os.makedirs('results_imgs', exist_ok=True)
plt.savefig('results_imgs/3Ddisparity_on_{testdata}.jpg'.format(testdata=testdata))


### Plot performance

In [None]:
from matplotlib.lines import Line2D

# ---- Configuration ---------------------------------------------------------
color_list = ['C0', 'C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'black']

metrics = ['AUC', 'BCE', 'ECE', "Error rate", "Precision"]
testdata = 'original'
group_type = {'race': [0, 1], 'gender': [0, 1], 'age': [0, 1]}
# Subgroup display names for the 3D MRI data, defined here so the cell does
# not depend on which earlier cell last assigned group_name.
group_name = {'race': ['white', 'others'], 'gender': ['Female', 'Male'], 'age': ['0-75', '75+']}
groups = ['age', 'gender']          # one row of panels per sensitive attribute

# File-name infix and display name of each compared model (same order).
model_suffixes = ['', '_balanced', '_stratified', '_Adv',
                  '_DistMatchMMD', '_DistMatchMean', '_FairALM', '_proposed']
model_list = ['Baseline',
              'Balanced', 'Stratified', 'Adversarial learning',
              'DistMatchMMD',
              'DistMatchMean', 'FairALM',
              'Proposed augmentation']

# Vertical layout: one 80-unit band per model. The layout is derived from
# model_list rather than from a results_list left over from another cell.
center = [(i*80-17.5) for i in range(len(model_list))][::-1]
start = [(i*80) for i in range(len(model_list))][::-1]
steps = [35, 35]                    # spacing between subgroups, per row

plt.rcParams["figure.autolayout"] = True
fig, axs = plt.subplots(nrows=2, ncols=5, sharey=True, sharex='col', figsize=(13, 11), dpi=400)

# result file -> cleaned per-subgroup score tables; avoids re-reading the same
# pickle for each of the five metric panels of a row.
loaded_results = {}

for num, ax in enumerate(fig.axes):

    row, col = divmod(num, 5)       # row selects the attribute, col the metric
    group = groups[row]
    metric = metrics[col]
    step = steps[row]
    n_subgroups = len(group_type[group])

    results_list = [
        'results/3D_CNN_AD_CN{suffix}_on_{testdata}_{group}_results'.format(
            suffix=suffix, testdata=testdata, group=group)
        for suffix in model_suffixes
    ]

    if row == 0:
        ax.set_title(metric, fontsize=15)

    ax.set_yticks(center, model_list, fontsize=15)

    for idx, result_name in enumerate(results_list):

        if result_name not in loaded_results:
            # NOTE: pickle.load can execute arbitrary code; only open project files.
            with open(result_name, "rb") as fp:
                dfs = pickle.load(fp)
            # +/-inf becomes NaN so the NaN-aware statistics below skip it.
            loaded_results[result_name] = [
                dfs[k].replace([np.inf, -np.inf], np.nan) for k in range(n_subgroups)
            ]

        for k, frame in enumerate(loaded_results[result_name]):

            scores = frame[metric]
            mean_score = scores.mean(skipna=True)

            # NOTE(review): 2.262 is the two-sided 95% t critical value for 9
            # degrees of freedom, but the spread is the plain standard
            # deviation (the original divided by sqrt(1), i.e. by 1) --
            # confirm whether a division by sqrt(n) was intended.
            ci = 2.262 * np.nanstd(scores)
            lower = mean_score - ci
            upper = mean_score + ci

            # An all-NaN column is reported and drawn at 0 (its interval
            # stays NaN and is therefore not drawn).
            if np.isnan(mean_score):
                print(mean_score)
                mean_score = 0

            color = color_list[k]
            y = start[idx] - k*step
            ax.plot([upper, lower], [y, y], color=color, linewidth = 1)
            ax.plot(mean_score, y, 'o', color=color, label=group_name[group][k], markersize=2)

    # Separator below each model band except the last; positioned relative to
    # the last subgroup of the band.
    last_k = n_subgroups - 1
    for i in range(len(results_list)-1):
        ax.axhline(start[i]-last_k*step-20, linestyle='--', color='k', linewidth = 0.3)

    # One subgroup legend per row, attached to the right-most panel.
    if col == 4:
        lines = [Line2D([0], [0], color=color_list[i], label=g)
                 for i, g in enumerate(group_name[group])]
        ax.legend(handles=lines, loc='lower center', bbox_to_anchor=(1.4, 0.5))

os.makedirs('results_imgs', exist_ok=True)
plt.savefig('results_imgs/3Dperformance_{testdata}.jpg'.format(testdata=testdata))


## 2D CXR disparity results

In [None]:
# Shared settings for the 2D chest X-ray cells that follow.
metrics = ['AUC', 'BCE', 'ECE', "Error rate", "Precision"]

Labels_diseases = [
    'Atelectasis',
    'Cardiomegaly',
    'Consolidation',
    'Edema',
    'Enlarged Cardiomediastinum',
    'Fracture',
    'Lung Lesion',
    'Lung Opacity',
    'No Finding',
    'Pleural Effusion',
    'Pleural Other',
    'Pneumonia',
    'Pneumothorax',
    'Support Devices',
]

testdata = 'aug'

# Subgroup codes and their display names, listed in the same order.
group_type = {
    'race': [0, 1, 4],
    'gender': [0, 1],
    'age': [0, 1, 2, 3],
}
group_name = {
    'race': ['White', 'Black', 'Asian'],
    'gender': ['Male', 'Female'],
    'age': ['0-40', '40-60', '60-80', '80+'],
}

groups = ['race', 'age', 'gender']

### Generate disparities csv

In [None]:
group = 'age'

# One row per result file, one column per metric.
result_df = pd.DataFrame(columns=metrics)

# Label indices included in the evaluation.
# presumably positions in Labels_diseases -- TODO confirm
target_label = [0, 1, 2, 3, 4, 7, 8, 9, 11, 12]

# ---- Experiment selection --------------------------------------------------
# Alternative runs (swap the two lists / the file template below):
#   * ['ERM', 'ERM_balanced', 'ERM_stratified', 'Adv', 'DistMatchMMD',
#      'DistMatchMean', 'FairALM'] with variants
#      ['_on_original', '_on_aug', '_proposed_on_original', '_proposed_on_aug']
#   * ['ERM_no_weight', 'ERM_rotation', 'ERM_shear', 'ERM_scaling', 'ERM_fisheye']
#      with variants ['_on_original', '_on_aug']
#   * ResNet:   'results/resnet_mimic_ERM{variant}_{group}_results' (four variants)
#   * CheXpert: 'results/densenet_Chexpert_ERM{variant}_{group}_results' (four variants)
model_names = ['ERM_no_weight', 'ERM_proposed_no_weight']
result_variants = ['_on_original', '_on_aug']

for model_name in model_names:

    results_list = [
        'results/densenet_mimic_{model_name}{variant}_{group}_results'.format(
            model_name=model_name, variant=variant, group=group)
        for variant in result_variants
    ]

    for result_name in results_list:

        # Each file is read once and reused for all metrics.
        # NOTE: pickle.load can execute arbitrary code; only open project files.
        with open(result_name, "rb") as fp:
            dfs = pickle.load(fp)

        # dfs[k][j] is the score table of subgroup k for label j; +/-inf
        # becomes NaN so that the means below skip it.
        n_subgroups = len(group_type[group])
        label_frames = {
            j: [dfs[k][j].replace([np.inf, -np.inf], np.nan) for k in range(n_subgroups)]
            for j in target_label
        }

        for metric in metrics:

            # Per label: summed absolute deviation of the subgroup means from
            # their median.
            all_disparity = []
            for j in target_label:
                mean_scores = [frame[metric].mean(skipna=True) for frame in label_frames[j]]
                median = np.nanmedian(mean_scores)
                all_disparity.append(sum(np.abs(score - median) for score in mean_scores))

            all_mean_disparity = np.nanmean(all_disparity)

            # NOTE(review): 2.262 is the two-sided 95% t critical value for 9
            # degrees of freedom, but the spread is the plain standard
            # deviation across labels (the original divided by sqrt(1)) --
            # confirm whether a division by sqrt(n) was intended.
            ci = 2.262 * np.nanstd(all_disparity)
            all_lower = all_mean_disparity - ci
            all_upper = all_mean_disparity + ci

            result_df.at[result_name, metric] = "{:.3f} [{:.3f} - {:.3f}]".format(all_mean_disparity, all_lower, all_upper)

# Written once after all models are processed (same final file as rewriting it
# after every model).
os.makedirs('disparity_results', exist_ok=True)
result_df.to_csv('disparity_results/all_densenet_no_weight_mimic_{group}.csv'.format(group=group))


### Plot disparities

In [None]:
def legend_without_duplicate_labels(figure):
    """Draw a legend in which every label appears only once.

    Handles and labels are read from the current axes (``plt.gca()``). When a
    label occurs several times, the handle of its last occurrence is kept and
    labels stay in order of first appearance.

    Parameters
    ----------
    figure : object exposing ``legend(handles, labels, **kwargs)``
        The plotting cell below passes the ``matplotlib.pyplot`` module.
    """
    handles, labels = plt.gca().get_legend_handles_labels()
    by_label = dict(zip(labels, handles))
    legend = figure.legend(by_label.values(), by_label.keys(), loc='lower center', bbox_to_anchor=(1.4, 0.5))

    # Matplotlib 3.7 renamed Legend.legendHandles to legend_handles and later
    # removed the old attribute; accept whichever one exists.
    legend_handles = getattr(legend, 'legend_handles', None)
    if legend_handles is None:
        legend_handles = legend.legendHandles

    for legend_handle in legend_handles:
        legend_handle.set_markersize(3)

In [None]:
plt.rcParams["figure.autolayout"] = True
fig, axs = plt.subplots(nrows=1, ncols=3, sharey=True, figsize=(14, 7), dpi=500)
testdata = 'original'

# ---- Experiment selection --------------------------------------------------
# File template, legend label and colour of each compared model (same order).
# Alternatives:
#   * DenseNet/MIMIC, eight models:
#       'results/densenet_mimic_{model}_on_{testdata}_{group}_results' with
#       model in ['ERM', 'ERM_balanced', 'ERM_stratified', 'Adv', 'DistMatchMMD',
#                 'DistMatchMean', 'FairALM', 'ERM_proposed'],
#       model_list = ['Baseline', 'Balanced', 'Stratified', 'Adversarial learning',
#                     'DistMatchMMD', 'DistMatchMean', 'FairALM', 'Proposed augmentation'],
#       color_list = ['C0', 'C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'black']
#   * DenseNet/CheXpert, two models:
#       'results/densenet_Chexpert_ERM_on_{testdata}_{group}_results',
#       'results/densenet_Chexpert_ERM_proposed_on_{testdata}_{group}_results'
result_templates = ['results/resnet_mimic_ERM_on_{testdata}_{group}_results',
                    'results/resnet_mimic_ERM_proposed_on_{testdata}_{group}_results']
model_list = ['Baseline',
              'Proposed augmentation']
color_list = ['C0', 'black']

# Label indices included in the evaluation.
target_label = [0, 1, 2, 3, 4, 7, 8, 9, 11, 12]

# Vertical layout in data units: models are gap_result apart inside a metric
# band, and consecutive metric bands are gap_metrics apart.
gap_result = 5
gap_between_metrics = 20
gap_metrics = gap_result*(len(result_templates)-1) + gap_between_metrics
top = int(gap_metrics*(len(metrics)-1) + gap_result*(len(result_templates)/2))

for num, ax in enumerate(fig.axes):

    group = groups[num]

    results_list = [template.format(testdata=testdata, group=group)
                    for template in result_templates]

    ax.set_title('{group} disparity'.format(group=group.capitalize()), fontsize=13)
    ax.set_yticks([(i*gap_metrics) for i in range(len(metrics))][::-1], metrics, fontsize=15)

    # Read every result file once per panel and reuse it for all metrics.
    # dfs[k][j] is the score table of subgroup k for label j; +/-inf becomes
    # NaN so that the means below skip it.
    # NOTE: pickle.load can execute arbitrary code; only open project files.
    n_subgroups = len(group_type[group])
    cleaned_results = []
    for result_name in results_list:
        with open(result_name, "rb") as fp:
            dfs = pickle.load(fp)
        cleaned_results.append({
            j: [dfs[k][j].replace([np.inf, -np.inf], np.nan) for k in range(n_subgroups)]
            for j in target_label
        })

    for i, metric in enumerate(metrics):

        metric_top = top - i*gap_metrics
        all_mean_disparity = []

        for idx, label_frames in enumerate(cleaned_results):

            # Per label: summed absolute deviation of the subgroup means from
            # their median.
            all_disparity = []
            for j in target_label:
                mean_scores = [frame[metric].mean(skipna=True) for frame in label_frames[j]]
                median = np.nanmedian(mean_scores)
                all_disparity.append(sum(np.abs(score - median) for score in mean_scores))

            mean_disparity = np.nanmean(all_disparity)
            all_mean_disparity.append(mean_disparity)

            # NOTE(review): 2.262 is the two-sided 95% t critical value for 9
            # degrees of freedom, but the spread is the plain standard
            # deviation across labels (the original divided by sqrt(1)) --
            # confirm whether a division by sqrt(n) was intended.
            ci = 2.262 * np.nanstd(all_disparity)

            y = metric_top - idx*gap_result
            ax.plot([mean_disparity + ci, mean_disparity - ci], [y, y],
                    color=color_list[idx], linewidth = 0.8, label=model_list[idx])

        # Markers are drawn after all interval lines of the metric.
        for idx, mean_disparity in enumerate(all_mean_disparity):
            ax.plot(mean_disparity, metric_top - idx*gap_result, 'o', color=color_list[idx], markersize=3)

        # Dashed vertical guide at the last model's disparity, spanning the
        # whole metric band.
        last_idx = len(all_mean_disparity) - 1
        ax.plot([all_mean_disparity[-1], all_mean_disparity[-1]],
                [metric_top - last_idx*gap_result, metric_top],
                linestyle='--', color='black', linewidth = 0.5)

    # Thin separators between metric bands.
    for i in range(len(metrics)-1):
        ax.axhline(top-((i+1)*gap_metrics)+int(gap_between_metrics/2), linestyle='--', color='black', linewidth = 0.3)

# The helper reads plt.gca(), and nothing in the loop changes the current
# axes, so one call after the loop gives the same legend as one call per panel.
legend_without_duplicate_labels(plt)

os.makedirs('results_imgs', exist_ok=True)
plt.savefig('results_imgs/disparity_resnet_on_{testdata}.jpg'.format(testdata=testdata))


### Plot performance

In [None]:
from matplotlib.lines import Line2D

# ---- Configuration ---------------------------------------------------------
color_list = ['C0', 'C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'black']

testdata = 'original'

# File-name infix and display name of each compared model (same order).
model_keys = ['ERM', 'ERM_balanced', 'ERM_stratified', 'Adv',
              'DistMatchMMD', 'DistMatchMean', 'FairALM', 'ERM_proposed']
model_list = ['Baseline',
              'Balanced', 'Stratified', 'Adversarial learning',
              'DistMatchMMD',
              'DistMatchMean', 'FairALM',
              'Proposed augmentation']
# Label indices included in the evaluation.
target_label = [0, 1, 2, 3, 4, 7, 8, 9, 11, 12]


plt.rcParams["figure.autolayout"] = True
fig, axs = plt.subplots(nrows=3, ncols=5, sharey=True, sharex='col', figsize=(13, 11), dpi=400)
groups = ['race', 'age', 'gender']  # one row of panels per sensitive attribute
# Vertical layout: one 80-unit band per model.
center = [(i*80-17.5) for i in range(len(model_list))][::-1]
start = [(i*80) for i in range(len(model_list))][::-1]
steps = [17.5, 11.67, 35]           # spacing between subgroups, per row

# result file -> cleaned score tables indexed [subgroup][label position];
# avoids re-reading the same pickle for each of the five metric panels of a row.
loaded_results = {}

for num, ax in enumerate(fig.axes):

    row, col = divmod(num, 5)       # row selects the attribute, col the metric
    group = groups[row]
    metric = metrics[col]
    step = steps[row]
    n_subgroups = len(group_type[group])

    results_list = [
        'results/densenet_mimic_{model}_on_{testdata}_{group}_results'.format(
            model=model, testdata=testdata, group=group)
        for model in model_keys
    ]

    if row == 0:
        ax.set_title(metric, fontsize=15)

    ax.set_yticks(center, model_list, fontsize=15)

    for idx, result_name in enumerate(results_list):

        if result_name not in loaded_results:
            # NOTE: pickle.load can execute arbitrary code; only open project files.
            with open(result_name, "rb") as fp:
                dfs = pickle.load(fp)
            # +/-inf becomes NaN so the NaN-aware statistics below skip it.
            loaded_results[result_name] = [
                [dfs[k][j].replace([np.inf, -np.inf], np.nan) for j in target_label]
                for k in range(n_subgroups)
            ]

        for k, label_frames in enumerate(loaded_results[result_name]):

            # Mean score per label, then averaged over labels.
            mean_scores = [frame[metric].mean(skipna=True) for frame in label_frames]
            mean_score = np.nanmean(mean_scores)

            # NOTE(review): 2.262 is the two-sided 95% t critical value for 9
            # degrees of freedom, but the spread is the plain standard
            # deviation across labels (the original divided by sqrt(1)) --
            # confirm whether a division by sqrt(n) was intended.
            ci = 2.262 * np.nanstd(mean_scores)

            color = color_list[k]
            y = start[idx] - k*step
            ax.plot([mean_score + ci, mean_score - ci], [y, y], color=color, linewidth = 1)
            ax.plot(mean_score, y, 'o', color=color, label=group_name[group][k], markersize=2)

    # Separator below each model band except the last; positioned relative to
    # the last subgroup of the band.
    last_k = n_subgroups - 1
    for i in range(len(results_list)-1):
        ax.axhline(start[i]-last_k*step-20, linestyle='--', color='k', linewidth = 0.3)

    # One subgroup legend per row, attached to the right-most panel.
    if col == 4:
        lines = [Line2D([0], [0], color=color_list[i], label=g)
                 for i, g in enumerate(group_name[group])]
        ax.legend(handles=lines, loc='lower center', bbox_to_anchor=(1.4, 0.5))

os.makedirs('results_imgs', exist_ok=True)
plt.savefig('results_imgs/performance_{testdata}.jpg'.format(testdata=testdata))


In [None]:
from matplotlib.lines import Line2D

# ---- Configuration ---------------------------------------------------------
testdata = 'original'

color_list = ['C0', 'C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'black']
model_list = ['Baseline', 'Proposed augmentation']

# File templates of the compared models, in model_list order.
# Alternative (DenseNet/CheXpert):
#   'results/densenet_Chexpert_ERM_on_{testdata}_{group}_results',
#   'results/densenet_Chexpert_ERM_proposed_on_{testdata}_{group}_results'
result_templates = ['results/resnet_mimic_ERM_on_{testdata}_{group}_results',
                    'results/resnet_mimic_ERM_proposed_on_{testdata}_{group}_results']

# Label indices included in the evaluation; defined here so the cell does not
# depend on an earlier plotting cell having been run.
target_label = [0, 1, 2, 3, 4, 7, 8, 9, 11, 12]

plt.rcParams["figure.autolayout"] = True
fig, axs = plt.subplots(nrows=3, ncols=5, sharey=True, sharex='col', figsize=(13, 11), dpi=400)
groups = ['race', 'age', 'gender']  # one row of panels per sensitive attribute


# Vertical layout: y position of each model band, its tick, and the spacing
# between subgroups per row.
start = [20, 10]
center = [17.5, 7.5]
steps = [2.5, 1.75, 5]

# result file -> cleaned score tables indexed [subgroup][label position];
# avoids re-reading the same pickle for each of the five metric panels of a row.
loaded_results = {}

for num, ax in enumerate(fig.axes):

    row, col = divmod(num, 5)       # row selects the attribute, col the metric
    group = groups[row]
    metric = metrics[col]
    step = steps[row]
    n_subgroups = len(group_type[group])

    results_list = [template.format(testdata=testdata, group=group)
                    for template in result_templates]

    if row == 0:
        ax.set_title(metric, fontsize=15)

    ax.set_yticks(center, model_list, fontsize=15)

    for idx, result_name in enumerate(results_list):

        if result_name not in loaded_results:
            # NOTE: pickle.load can execute arbitrary code; only open project files.
            with open(result_name, "rb") as fp:
                dfs = pickle.load(fp)
            # +/-inf becomes NaN so the NaN-aware statistics below skip it.
            loaded_results[result_name] = [
                [dfs[k][j].replace([np.inf, -np.inf], np.nan) for j in target_label]
                for k in range(n_subgroups)
            ]

        for k, label_frames in enumerate(loaded_results[result_name]):

            # Mean score per label, then averaged over labels.
            mean_scores = [frame[metric].mean(skipna=True) for frame in label_frames]
            mean_score = np.nanmean(mean_scores)

            # NOTE(review): 2.262 is the two-sided 95% t critical value for 9
            # degrees of freedom, but the spread is the plain standard
            # deviation across labels (the original divided by sqrt(1)) --
            # confirm whether a division by sqrt(n) was intended.
            ci = 2.262 * np.nanstd(mean_scores)

            color = color_list[k]
            y = start[idx] - k*step
            ax.plot([mean_score + ci, mean_score - ci], [y, y], color=color, linewidth = 1)
            ax.plot(mean_score, y, 'o', color=color, label=group_name[group][k], markersize=2)

    # Separator below each model band except the last; positioned relative to
    # the last subgroup of the band.
    last_k = n_subgroups - 1
    for i in range(len(results_list)-1):
        ax.axhline(start[i]-last_k*step-2.5, linestyle='--', color='k', linewidth = 0.3)

    # One subgroup legend per row, attached to the right-most panel.
    if col == 4:
        lines = [Line2D([0], [0], color=color_list[i], label=g)
                 for i, g in enumerate(group_name[group])]
        ax.legend(handles=lines, loc='lower center', bbox_to_anchor=(1.4, 0.5))

os.makedirs('results_imgs', exist_ok=True)
plt.savefig('results_imgs/performance_resnet_{testdata}.jpg'.format(testdata=testdata))


## Task transfer 

In [None]:
def legend_without_duplicate_labels(figure):
    """Draw a small-font legend in which every label appears only once.

    Handles and labels are read from the current axes (``plt.gca()``). When a
    label occurs several times, the handle of its last occurrence is kept and
    labels stay in order of first appearance.

    Parameters
    ----------
    figure : object exposing ``legend(handles, labels, **kwargs)``
        Presumably the ``matplotlib.pyplot`` module, as in the earlier
        plotting cells -- confirm against the caller.
    """
    handles, labels = plt.gca().get_legend_handles_labels()
    by_label = dict(zip(labels, handles))
    legend = figure.legend(by_label.values(), by_label.keys(), loc='lower center', bbox_to_anchor=(1.12, 0.5), fontsize=6)

    # Matplotlib 3.7 renamed Legend.legendHandles to legend_handles and later
    # removed the old attribute; accept whichever one exists.
    legend_handles = getattr(legend, 'legend_handles', None)
    if legend_handles is None:
        legend_handles = legend.legendHandles

    for legend_handle in legend_handles:
        legend_handle.set_markersize(2)

In [None]:
def get_data(dataset='mimic'):
    """Read race, age and gender of every test record as one-hot label arrays.

    Records are read from ``data/{dataset}_test.tfrecords``. Each record
    supplies the first value of its ``race``, ``age`` and ``gender`` int64
    features.

    Encoding:
      * race   -> 3 columns; values other than 0 and 1 fall into column 2
      * age    -> 4 columns; values other than 0, 1 and 2 fall into column 3.
                  For ``dataset == 'mimic'`` a positive age code is shifted
                  down by one before encoding.
      * gender -> 2 columns; any value other than 0 falls into column 1

    Returns
    -------
    tuple of np.ndarray
        ``(race_labels, age_labels, gender_labels)``, one row per record.
    """

    # Reseeds NumPy's global RNG on every call.
    np.random.seed(2021)

    def one_hot(position, width):
        # List of `width` zeros with a single 1 at `position`.
        encoded = [0] * width
        encoded[position] = 1
        return encoded

    race_labels, age_labels, gender_labels = [], [], []

    filename = 'data/{dataset}_test.tfrecords'.format(dataset=dataset)

    for raw_record in tf.data.TFRecordDataset(filename):

        example = tf.train.Example()
        example.ParseFromString(raw_record.numpy())
        features = example.features.feature

        race = features['race'].int64_list.value[0]
        age = features['age'].int64_list.value[0]
        if dataset == 'mimic' and age > 0:
            age -= 1
        gender = features['gender'].int64_list.value[0]

        # Anything outside the explicitly listed codes goes to the last column.
        race_labels.append(one_hot(race if race in (0, 1) else 2, 3))
        age_labels.append(one_hot(age if age in (0, 1, 2) else 3, 4))
        gender_labels.append(one_hot(0 if gender == 0 else 1, 2))

    return np.array(race_labels), np.array(age_labels), np.array(gender_labels)

# Dataset name; it selects the test TFRecord file read by get_data and is
# also interpolated into the prediction file names of the CXR plotting cell.
dataset = 'mimic'

# One-hot race / age / gender labels of that dataset's test split.
race_labels, age_labels, gender_labels = get_data(dataset=dataset)

### Plot CXR task transfer figure

In [None]:
# Line/marker colours: index 0 is used for the baseline, index 1 for the
# proposed model.
color_list = ['C0', 'C1']

archi = 'resnet'

# Pickled prediction files, ordered as (baseline, proposed) pairs for the
# race, age and gender transfer tasks.
results_list = [
    'predictions/model_{archi}_{dataset}_ERM_task_transfer_{task}{variant}'.format(
        archi=archi, dataset=dataset, task=task, variant=variant)
    for task in ('race', 'age', 'gender')
    for variant in ('_on_original', '_proposed_on_aug')
]

# y-axis tick labels, one per task.
model_list = ['Race', 'Age', 'Gender']

# Ground-truth labels per task, in the same order as the file pairs above.
task_labels = [race_labels, age_labels, gender_labels]


plt.figure(figsize=(5, 2), dpi = 300)
plt.title('AUC of CXR task transfer', fontsize=8)
plt.yticks([9, 5, 1], model_list, fontsize=8)
plt.xticks(fontsize=5)
plt.tight_layout()
# Vertical position of each result; each tick above sits midway between the
# two positions of its (baseline, proposed) pair.
loc = [10, 8, 6, 4, 2, 0]


for k, result_path in enumerate(results_list):

    # NOTE: unpickling can execute arbitrary code -- only load prediction
    # files produced by this project.
    with open(result_path, "rb") as fp:
        y_preds = pickle.load(fp)

    # NOTE(review): this needs the two-argument `task_transfer_test` that
    # returns (mean, lower, upper). A later cell of this notebook defines a
    # three-argument function with the same name, so this cell only works
    # while that later definition has not been executed.
    all_mean_score, all_lower, all_upper = task_transfer_test(y_preds, task_labels[k // 2])

    # Even entries are the baseline runs, odd entries the proposed runs.
    label = 'Baseline' if k % 2 == 0 else 'Proposed'

    # Horizontal segment from lower to upper bound, with a dot at the mean.
    plt.plot([all_lower, all_upper], [loc[k], loc[k]], color=color_list[k%2])
    plt.plot(all_mean_score, loc[k], 'o', color=color_list[k%2], label=label, markersize=3)

    print(all_mean_score, all_lower, all_upper)


legend_without_duplicate_labels(plt)

plt.savefig('results_imgs/{dataset}_{archi}_task_transfer_2d.jpg'.format(archi=archi, dataset=dataset))

### Plot MRI task transfer figure

In [None]:
# Metadata table for the MRI experiments.
df = pd.read_csv('data_new.csv')
data_path = '../../../mnt/usb/kuopc/ADNI_B1/MPR__GradWarp__B1_Correction_crop/'

# Keep only test-split rows whose group is not MCI.
keep_rows = (df['Group'] != 'MCI') & (df['Split'] == 'test')
df = df.loc[keep_rows]

# Integer-encode the columns used as labels:
#   Group: CN -> 0, AD -> 1
#   Sex:   F  -> 0, M  -> 1
#   Age:   <= 75 -> 0, everything else -> 1
df['Group'] = df['Group'].replace({'CN': 0, 'AD': 1})
df['Sex'] = df['Sex'].replace({'F': 0, 'M': 1})
df['Age'] = np.where(df['Age'] <= 75, 0, 1)


In [None]:
# Line/marker colours: index 0 is used for the baseline, index 1 for the
# proposed model.
color_list = ['C0', 'C1']

# Pickled prediction files, ordered as (baseline, proposed) pairs for the
# age and gender transfer tasks.
results_list = [
                'predictions/3D_CNN_AD_CN_task_transfer_age_2_on_original',
                'predictions/3D_CNN_AD_CN_proposed_task_transfer_age_2_on_original',
                'predictions/3D_CNN_AD_CN_task_transfer_gender_on_original',
                'predictions/3D_CNN_AD_CN_proposed_task_transfer_gender_on_original']

# y-axis tick labels, one per task.
model_list = ['Age', 'Gender']

# Ground-truth labels per task, in the same order as the file pairs above.
task_labels = [df['Age'].values, df['Sex'].values]


plt.figure(figsize=(5, 2), dpi = 300)
plt.title('AUC of brain MRI task transfer', fontsize=8)
plt.tight_layout()
plt.yticks([5, 3], model_list, fontsize=8)
plt.xticks(fontsize=5)
# Vertical position of each result; each tick above sits midway between the
# two positions of its (baseline, proposed) pair.
loc = [5.5, 4.5, 3.5, 2.5]


for k, result_path in enumerate(results_list):

    # `CPU_Unpickler` is not defined in this part of the notebook; it must
    # already be in scope. Unpickling can execute arbitrary code -- only load
    # prediction files produced by this project.
    with open(result_path, "rb") as fp:
        y_preds = CPU_Unpickler(fp).load()

    # NOTE(review): this needs the two-argument `task_transfer_test` that
    # returns (mean, lower, upper). A later cell of this notebook defines a
    # three-argument function with the same name, so this cell only works
    # while that later definition has not been executed.
    all_mean_score, all_lower, all_upper = task_transfer_test(y_preds, task_labels[k // 2])

    # Even entries are the baseline runs, odd entries the proposed runs.
    label = 'Baseline' if k % 2 == 0 else 'Proposed'

    # Horizontal segment between the two bounds, with a dot at the mean.
    plt.plot([all_upper, all_lower], [loc[k], loc[k]], color=color_list[k%2])
    plt.plot(all_mean_score, loc[k], 'o', color=color_list[k%2], label=label, markersize=3)

    print(all_mean_score, all_lower, all_upper)


legend_without_duplicate_labels(plt)

# `transparent` is a boolean flag; the string "True" only worked because any
# non-empty string is truthy.
# NOTE(review): the file name says "mimic_resnet" although this is the MRI
# figure -- confirm the intended name.
plt.savefig('results_imgs/mimic_resnet_task_transfer_3d.jpg', bbox_inches='tight', transparent=True)

In [None]:
# Column of the disparity tables to summarise.
metric = 'Error rate'

# Row position that is compared against row 0.
target_result = 2

for group in ['race', 'age', 'gender']:

    # Use a dedicated name for the table: `df` holds the MRI test metadata
    # that other cells read (e.g. df['Group']), so it must not be overwritten.
    disparity_df = pd.read_csv('disparity_results/all_result_densenet_{group}.csv'.format(group=group), index_col=0)

    reference_entry = disparity_df[metric].values[0]
    target_entry = disparity_df[metric].values[target_result]
    print(reference_entry, target_entry)

    # Only the first five characters of each entry are parsed as the number
    # (presumably the entries carry extra text after the value -- confirm).
    reference_value = float(reference_entry[:5])
    target_value = float(target_entry[:5])

    # Relative change of the target row with respect to row 0, in percent.
    print(group, ':', np.round(100*(reference_value-target_value)/reference_value, 2), '%')

In [None]:
# Relies on `metric` and `target_result` set in the previous cell.
for group in ['age', 'gender']:

    # Use a dedicated name for the table: `df` holds the MRI test metadata
    # that other cells read (e.g. df['Group']), so it must not be overwritten.
    disparity_df = pd.read_csv('disparity_results/3D_all_result_{group}.csv'.format(group=group), index_col=0)

    reference_entry = disparity_df[metric].values[0]
    target_entry = disparity_df[metric].values[target_result]
    print(reference_entry, target_entry)

    reference_value = float(reference_entry)
    target_value = float(target_entry)

    # Relative change of the target row with respect to row 0, in percent.
    print(group, ':', np.round(100*(reference_value-target_value)/reference_value, 2), '%')
    

In [None]:
def task_transfer_test(y_preds, y_test, best_thresh):
    """Bootstrap the mean error rate and mean ROC AUC of binary predictions.

    NOTE(review): once this cell runs, this definition replaces any
    `task_transfer_test` already in the kernel. The plotting cells earlier in
    this notebook call that name with two arguments and unpack three return
    values, which this version does not support -- consider giving one of the
    two a distinct name.

    Parameters
    ----------
    y_preds : np.ndarray
        Prediction scores, indexable by an integer index array.
    y_test : np.ndarray
        Ground-truth labels aligned with `y_preds`.
    best_thresh : scalar
        Scores strictly above this value are counted as class 1, all others
        as class 0.

    Returns
    -------
    tuple
        ``(mean error rate, mean ROC AUC)`` over the accepted resamples.
        Resamples whose labels contain fewer than two distinct values are
        skipped.
    """

    n_bootstraps = 1000
    rng_seed = 2021  # control reproducibility
    er = []
    auc = []

    # The thresholded predictions do not depend on the resample, so compute
    # them once instead of in every bootstrap iteration.
    y_preds_ = np.where(y_preds > best_thresh, 1, 0)

    rng = np.random.RandomState(rng_seed)
    for _ in range(n_bootstraps):
        # bootstrap by sampling with replacement on the prediction indices
        indices = rng.randint(0, len(y_preds), len(y_preds))

        if len(np.unique(y_test[indices])) < 2:
            # We need at least one positive and one negative sample for ROC AUC
            # to be defined: reject the sample
            continue

        auc.append(roc_auc_score(y_test[indices], y_preds[indices]))

        # Error rate = misclassified samples / all samples of the resample.
        tn, fp, fn, tp = confusion_matrix(y_test[indices], y_preds_[indices]).ravel()
        er.append((fp + fn) / (tn + fp + fn + tp))

    return np.nanmean(er), np.nanmean(auc)

In [None]:
# Error rate and AUC of the baseline densenet race classifier, averaged over
# the three race classes.

# NOTE(review): the `get_data` defined in this notebook only accepts
# `dataset`, so this call raises a TypeError while that definition is the
# active one. It presumably targets a different loader with the same name
# (the notebook star-imports several project modules) -- confirm and call it
# under an unambiguous name.
_, y_test = get_data(aug_method='', dataset='mimic', data_split='test', task='race', return_demo=False, only_label=True)

# One decision threshold per class.
best_thresh = np.loadtxt('thresh/model_densenet_mimic_ERM_race_thresh.txt')

with open('predictions/model_densenet_mimic_ERM_race_on_original', "rb") as fp:   # Unpickling
    y_preds = pickle.load(fp)

# One (error rate, AUC) pair per class column.
ers = [task_transfer_test(y_preds[:, k], y_test[:, k], best_thresh[k]) for k in range(3)]

# Column-wise mean over the classes: [mean error rate, mean AUC].
print(np.mean(ers, axis=0))

In [None]:
# Error rate and AUC of the proposed densenet race classifier, averaged over
# the three race classes.

# NOTE(review): the `get_data` defined in this notebook only accepts
# `dataset`, so this call raises a TypeError while that definition is the
# active one. It presumably targets a different loader with the same name
# (the notebook star-imports several project modules) -- confirm and call it
# under an unambiguous name.
_, y_test = get_data(aug_method='', dataset='mimic', data_split='test', task='race', return_demo=False, only_label=True)

# One decision threshold per class.
best_thresh = np.loadtxt('thresh/model_densenet_mimic_ERM_race_proposed_thresh.txt')

with open('predictions/model_densenet_mimic_ERM_race_proposed_on_original', "rb") as fp:   # Unpickling
    y_preds = pickle.load(fp)

# One (error rate, AUC) pair per class column.
ers = [task_transfer_test(y_preds[:, k], y_test[:, k], best_thresh[k]) for k in range(3)]

# Column-wise mean over the classes: [mean error rate, mean AUC].
print(np.mean(ers, axis=0))

In [None]:
# Error rate and AUC of the baseline 3D CNN on the AD/CN task.
best_thresh = np.loadtxt('thresh/3D_CNN_AD_CN_thresh.txt')

with open('predictions/3D_CNN_AD_CN_on_original', "rb") as fp:   # Unpickling
    y_preds = pickle.load(fp)

# `df` has to be the table that carries the 'Group' label column at this
# point; any cell that rebinds `df` to another table breaks this call.
ers = task_transfer_test(y_preds, df['Group'].values, best_thresh)

# First two returned scores, printed in the order they are returned.
first_score, second_score = ers[0], ers[1]
print(np.mean(first_score), np.mean(second_score))

In [None]:
# Error rate and AUC of the proposed 3D CNN on the AD/CN task.
best_thresh = np.loadtxt('thresh/3D_CNN_AD_CN_proposed_thresh.txt')

with open('predictions/3D_CNN_AD_CN_proposed_on_original', "rb") as fp:   # Unpickling
    y_preds = pickle.load(fp)

# `df` has to be the table that carries the 'Group' label column at this
# point; any cell that rebinds `df` to another table breaks this call.
ers = task_transfer_test(y_preds, df['Group'].values, best_thresh)

# First two returned scores, printed in the order they are returned.
first_score, second_score = ers[0], ers[1]
print(np.mean(first_score), np.mean(second_score))