In [None]:
# Show the GPU inventory / current utilisation before choosing devices in the next cell.
!nvidia-smi

In [None]:
import os
# Silence TensorFlow's C++ logging; this must be set before tensorflow is imported.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import warnings
warnings.filterwarnings("ignore")

from Parse_TFrecords import *
from define_model import *
from load_label import *
from utilities import *
import gc
import tensorflow as tf
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

import numpy as np
from sklearn.metrics import classification_report, average_precision_score

# from libauc.losses import AUCMLoss 
# from libauc.optimizers import PESG 

# Seed the global RNGs so dataset shuffling, weight initialisation and any
# numpy-based resampling in the helpers are more repeatable across runs.
SEED = 42
tf.random.set_seed(SEED)
np.random.seed(SEED)

# Index of the first physical GPU this notebook may use (presumably lower
# indices are reserved for other jobs on a shared host). Set to 0 to use all.
FIRST_GPU_INDEX = 5

gpus = tf.config.experimental.list_physical_devices('GPU')
selected_gpus = gpus[FIRST_GPU_INDEX:]
if gpus and not selected_gpus:
    # The slice is [] on hosts with <= FIRST_GPU_INDEX GPUs; passing [] to
    # set_visible_devices hides every GPU and training silently runs on CPU.
    print('Only {} GPU(s) found; ignoring FIRST_GPU_INDEX={} and using all of them.'
          .format(len(gpus), FIRST_GPU_INDEX))
    selected_gpus = gpus

tf.config.set_visible_devices(selected_gpus, device_type='GPU')
for gpu in selected_gpus:
    # Grow GPU memory on demand instead of reserving the whole device up front.
    tf.config.experimental.set_memory_growth(gpu, True)

print(tf.__version__, 'visible GPUs:', selected_gpus)

## Pre-train on CheXpert

In [None]:
files = '../Data/Chexpert_cxr.tfrecords'

BATCH_SIZE = 32
# Shuffle buffer measured in *examples*. 256 * BATCH_SIZE holds as many examples
# as the previous "shuffle 256 whole batches" setup did.
SHUFFLE_BUFFER = 256 * BATCH_SIZE

# - TFRecordDataset's buffer_size is a read buffer in BYTES, not a record count,
#   so the library default is used instead of BATCH_SIZE*10 (= 320 bytes).
# - Examples are shuffled BEFORE batching (and before parsing, on the compact
#   serialized records). Shuffling after .batch() only permutes whole batches,
#   so every batch would contain the same examples in every epoch.
train_dataset = (tf.data.TFRecordDataset(
    files, compression_type=None, num_parallel_reads=32)
.shuffle(SHUFFLE_BUFFER)
.map(parse_TFrecord_pretrain, num_parallel_calls=tf.data.experimental.AUTOTUNE)
.batch(BATCH_SIZE, drop_remainder=True)
.prefetch(tf.data.experimental.AUTOTUNE))

In [None]:
archi = 'Xception'

checkpoint_filepath = 'checkpoints_new/checkpoint_pretrain_{i}'.format(i=archi)
# No validation data is passed to fit() here, so keep the epoch with the lowest
# training loss.
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_best_only=True,
    save_weights_only=True,
    monitor='loss',
    mode='min')

callback = [tf.keras.callbacks.LearningRateScheduler(scheduler),
            model_checkpoint_callback]

# 6 output units — presumably one per pre-training label parsed by
# parse_TFrecord_pretrain; TODO confirm.
model = define_model(archi, nodes=6)

# The loss treats model outputs as logits, so the AUC metric must as well: the
# default AUC buckets scores with thresholds in [0, 1], which misreports AUC on
# raw logits. (AUC(from_logits=...) requires TF >= 2.5; the metric keeps the
# default name 'auc', exactly as the string 'AUC' produced.)
model.compile(loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
              metrics=[tf.keras.metrics.AUC(from_logits=True)])

# fit(shuffle=...) is ignored for tf.data inputs; shuffling happens in the pipeline.
model.fit(train_dataset, epochs=5, callbacks=callback)

# Release the model and Keras' global graph state before fine-tuning builds a new one.
del model
tf.keras.backend.clear_session()
gc.collect()

## Fine-tune on MIMIC COPD labels

In [None]:
# NOTE(review): batch size 1 gives any BatchNorm layers in the backbone
# single-sample statistics during fine-tuning — confirm this is intended.
BATCH_SIZE = 1

# Positive / negative label counts per cohort (presumably of each training
# split — confirm). 'all' is exactly emory + mimic.
CLASS_COUNTS = {
    'all':   {'pos': 57925, 'neg': 212124},
    'emory': {'pos': 25294, 'neg': 120259},
    'mimic': {'pos': 32631, 'neg': 91865},
}

# Must match the cohort of the records fine-tuned on below (copd_mimic_*).
# Using another cohort's counts gives class weights for the wrong prevalence.
COHORT = 'mimic'

pos = CLASS_COUNTS[COHORT]['pos']
neg = CLASS_COUNTS[COHORT]['neg']
total = pos + neg

# Inverse-frequency weights, scaled so that each class contributes total/2
# weighted samples overall (the rarer class gets the larger weight).
weight_for_0 = (1 / neg) * (total / 2.0)
weight_for_1 = (1 / pos) * (total / 2.0)

class_weight = {0: weight_for_0, 1: weight_for_1}
print('Weight for class 0: {:.2f}'.format(weight_for_0))
print('Weight for class 1: {:.2f}'.format(weight_for_1))

In [None]:
# Shuffle buffer measured in examples. The previous .shuffle(BATCH_SIZE) with
# BATCH_SIZE = 1 was a no-op, so the training data was never shuffled.
# Raise this value if memory allows.
SHUFFLE_BUFFER = 1024

# TFRecordDataset's buffer_size is a read buffer in BYTES; BATCH_SIZE bytes
# throttled I/O, so the library default is used instead.
record_file_train = 'tfrecords/copd_mimic_train.tfrecords'
train_dataset = (tf.data.TFRecordDataset(
    record_file_train, compression_type=None, num_parallel_reads=32)
.shuffle(SHUFFLE_BUFFER)
.map(parse_TFrecord_train, num_parallel_calls=tf.data.experimental.AUTOTUNE)
.batch(BATCH_SIZE)
.prefetch(tf.data.experimental.AUTOTUNE))

# Validation data is not shuffled; metrics do not depend on its order.
record_file_val = 'tfrecords/copd_mimic_val.tfrecords'
val_dataset = (tf.data.TFRecordDataset(
    record_file_val, compression_type=None, num_parallel_reads=32)
.map(parse_TFrecord_train, num_parallel_calls=tf.data.experimental.AUTOTUNE)
.batch(BATCH_SIZE)
.prefetch(tf.data.experimental.AUTOTUNE))

In [None]:
archi = 'Xception'
checkpoint_filepath = 'checkpoints_mimic/checkpoint_BCE_{i}'.format(i=archi)

# Early stopping and checkpointing track ONE quantity, so the saved weights are
# from the epoch that early stopping also considers best. The previous
# `monitor_` is not defined anywhere in this notebook (it could only arrive via
# `from utilities import *`). Because it was paired with mode='min', it must be
# loss-like, so val_loss — already used by the checkpoint — is the consistent choice.
MONITOR = 'val_loss'

model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_best_only=True,
    save_weights_only=True,
    monitor=MONITOR,
    mode='min')

callback = [tf.keras.callbacks.LearningRateScheduler(scheduler),
            tf.keras.callbacks.EarlyStopping(mode='min', patience=3, monitor=MONITOR),
            model_checkpoint_callback]

model = load_model_from_pretrain(archi)

# The loss treats outputs as logits, so AUC must too (requires TF >= 2.5). The
# metric keeps the default name 'auc'.
model.compile(loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
              metrics=[tf.keras.metrics.AUC(from_logits=True)])

# fit(shuffle=...) is ignored for tf.data inputs; shuffling happens in the pipeline.
model.fit(train_dataset, epochs=5, validation_data=val_dataset,
          callbacks=callback, class_weight=class_weight)

# Release the model and Keras' global graph state before evaluation rebuilds it.
del model
tf.keras.backend.clear_session()
gc.collect()

In [None]:
BATCH_SIZE = 32
# Same directory as the train/val records ('tfrecords/').
record_file_test = 'tfrecords/copd_mimic_test.tfrecords'

# Not shuffled: predictions are matched to labels and demographics by position,
# so the record order must be preserved. A parallel map keeps element order by
# default. buffer_size (a BYTE count) is left at the library default.
test_dataset = (tf.data.TFRecordDataset(
    record_file_test, compression_type=None, num_parallel_reads=32)
.map(parse_TFrecord_test, num_parallel_calls=tf.data.experimental.AUTOTUNE)
.batch(BATCH_SIZE)
.prefetch(tf.data.experimental.AUTOTUNE))

# NOTE(review): y_test is not used by the cells below, and this one-argument
# call differs from the other get_data_label call in this notebook — confirm
# the signature or remove.
y_test = get_data_label('test')

In [None]:
checkpoint_filepath = 'checkpoints_mimic/checkpoint_BCE_{i}'.format(i=archi)

# Rebuild the architecture named in the checkpoint path, then load its best weights.
model = define_model(archi)

model.load_weights(checkpoint_filepath)

# Labels and demographics for the split that test_dataset reads. `split` was not
# defined anywhere in this notebook.
split = 'test'
y_label, y_demo = get_data_label('mimic', split, True)

y_preds = model.predict(test_dataset)
# Later cells pair y_preds with y_label / y_demo by position.
assert len(y_preds) == len(y_label), (
    'prediction/label count mismatch: {} vs {}'.format(len(y_preds), len(y_label)))

outfile = 'preds/mimic/test_preds/fusion_{i}_preds.npy'.format(i=archi)
os.makedirs(os.path.dirname(outfile), exist_ok=True)
# Save the predictions just computed (`preds` was an undefined name).
np.save(outfile, y_preds)

In [None]:
# Whole-test-set evaluation: choose an operating point with get_thresh's
# 'Youden' method, then report metrics at that threshold via test_CI
# (presumably with confidence intervals — see utilities).
# NOTE: argument order differs — get_thresh takes (labels, preds) while
# test_CI takes (preds, labels).
thresh = get_thresh(y_label, y_preds, 'Youden')

test_CI(y_preds, y_label, thresh)

gc.collect()

## Fairness analysis by race, gender and age

In [None]:
# Assumes this list's order matches the integer 'Race' codes in y_demo — confirm.
race_list = ['White', 'Black', 'Latino', 'Others', 'Asian']

# Score every subgroup at ONE operating point chosen on the whole test set: the
# threshold a deployed model would actually use. Re-optimising the Youden
# threshold inside each subgroup gives every group its own best-case
# sensitivity/specificity and hides the between-group gaps this analysis measures.
overall_thresh = get_thresh(y_label, y_preds, 'Youden')
all_preds = np.asarray(y_preds)
all_labels = np.asarray(y_label)

for race_num, race in enumerate(race_list):
    print(race)

    # Positions of test samples whose encoded race equals this list index.
    idx = [i for i, demo in enumerate(y_demo) if demo['Race'] == race_num]
    group_preds = all_preds[idx]
    group_labels = all_labels[idx]

    # Ranking metrics such as AUC (presumably part of test_CI) are undefined
    # unless both classes are present, so skip instead of crashing.
    if np.unique(group_labels).size < 2:
        print('  skipped: {} sample(s), fewer than two label classes'.format(len(idx)))
        continue

    test_CI(group_preds, group_labels, overall_thresh)

    gc.collect()

In [None]:
# Assumes this list's order matches the integer 'Gender' codes in y_demo — confirm.
gender_list = ['Female', 'Male']

# Score every subgroup at ONE operating point chosen on the whole test set, so
# the groups are compared under the same decision rule rather than each at its
# own re-optimised Youden threshold.
overall_thresh = get_thresh(y_label, y_preds, 'Youden')
all_preds = np.asarray(y_preds)
all_labels = np.asarray(y_label)

for gender_num, gender in enumerate(gender_list):
    print(gender)

    # Positions of test samples whose encoded gender equals this list index.
    idx = [i for i, demo in enumerate(y_demo) if demo['Gender'] == gender_num]
    group_preds = all_preds[idx]
    group_labels = all_labels[idx]

    # Ranking metrics such as AUC (presumably part of test_CI) are undefined
    # unless both classes are present, so skip instead of crashing.
    if np.unique(group_labels).size < 2:
        print('  skipped: {} sample(s), fewer than two label classes'.format(len(idx)))
        continue

    test_CI(group_preds, group_labels, overall_thresh)

    gc.collect()

In [None]:
# Assumes this list's order matches the integer 'Age' bin codes in y_demo — confirm.
age_list = ['0-40', '40-60', '60-80', '80-']

# Score every subgroup at ONE operating point chosen on the whole test set, so
# the groups are compared under the same decision rule rather than each at its
# own re-optimised Youden threshold.
overall_thresh = get_thresh(y_label, y_preds, 'Youden')
all_preds = np.asarray(y_preds)
all_labels = np.asarray(y_label)

for age_num, age in enumerate(age_list):
    print(age)

    # Positions of test samples whose encoded age bin equals this list index.
    idx = [i for i, demo in enumerate(y_demo) if demo['Age'] == age_num]
    group_preds = all_preds[idx]
    group_labels = all_labels[idx]

    # Ranking metrics such as AUC (presumably part of test_CI) are undefined
    # unless both classes are present, so skip instead of crashing.
    if np.unique(group_labels).size < 2:
        print('  skipped: {} sample(s), fewer than two label classes'.format(len(idx)))
        continue

    test_CI(group_preds, group_labels, overall_thresh)

    gc.collect()