In [None]:
# List the GPUs on this host; the next cell makes only a subset of them visible by index.
!nvidia-smi

In [None]:
import os
# Silence TensorFlow's C++ logging; this only takes effect if set before tensorflow is imported.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import warnings
warnings.filterwarnings("ignore")

# NOTE(review): these star imports supply names used throughout the notebook but defined
# nowhere in it (parse_TFrecord_*, define_model, load_model_from_pretrain, scheduler,
# get_thresh, test_CI, and presumably `st`); which module provides which can't be told here.
from Parse_TFrecords import *
from define_model import *
from load_label import *
from utilities import *
import gc
import tensorflow as tf
import cv2 as cv
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

import numpy as np
from sklearn.metrics import classification_report, average_precision_score

# Index of the first physical GPU this notebook may use; every GPU from this index onward
# is made visible. Visibility and memory growth must be configured before TensorFlow
# initialises the GPUs, so this runs before any model or dataset is built.
FIRST_GPU_INDEX = 5

gpus = tf.config.experimental.list_physical_devices('GPU')
selected_gpus = gpus[FIRST_GPU_INDEX:]
if not selected_gpus:
    # Slicing past the end yields [] and set_visible_devices([]) hides every GPU, so work
    # would silently fall back to the CPU. Printed because warnings are filtered above.
    print('WARNING: {} GPU(s) found, none at index >= {}; TensorFlow will run on CPU'.format(
        len(gpus), FIRST_GPU_INDEX))
tf.config.set_visible_devices(selected_gpus, device_type='GPU')
# Allocate GPU memory on demand rather than reserving it all when the device is initialised.
for gpu in selected_gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

print(tf.__version__, gpus)

## Pre-train on CheXpert (6-output model)

In [None]:
files = '../Data/Chexpert_cxr.tfrecords'

BATCH_SIZE = 32
# Shuffle buffer size in *examples*: the same 8192 examples the old 256-batch buffer held.
SHUFFLE_BUFFER = 256 * BATCH_SIZE

# Shuffle individual records *before* batching so batch composition changes every epoch;
# shuffling after .batch() only reorders the same fixed batches. Shuffling before the parse
# step also means the buffer holds serialized records rather than parsed tensors.
# NOTE: TFRecordDataset's `buffer_size` is the read buffer in bytes, not a record count.
train_dataset = (tf.data.TFRecordDataset(
    files, compression_type=None, buffer_size=BATCH_SIZE*10, num_parallel_reads=32)
.shuffle(SHUFFLE_BUFFER)
.map(parse_TFrecord_pretrain)
.batch(BATCH_SIZE, drop_remainder=True))

In [None]:
archi = 'Xception'

# Keep only the weights from the epoch with the lowest *training* loss (no validation split here).
checkpoint_filepath = 'checkpoints_new/checkpoint_pretrain_{i}'.format(i=archi)
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_best_only=True,
    save_weights_only=True,
    monitor='loss',
    mode='min')

callback = [tf.keras.callbacks.LearningRateScheduler(scheduler),
            model_checkpoint_callback]

# Six output nodes for the pre-training label set.
model = define_model(archi, nodes=6)

# NOTE(review): the loss treats model outputs as logits (from_logits=True), while the 'AUC'
# metric thresholds predictions in [0, 1]. If define_model's head is linear, use
# tf.keras.metrics.AUC(from_logits=True); if it ends in a sigmoid, the loss should use
# from_logits=False. Confirm against define_model.
model.compile(loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3), metrics=['AUC'])

# Keras ignores fit(shuffle=...) for tf.data inputs, so any shuffling must live in the pipeline.
model.fit(train_dataset, epochs=5, callbacks=callback)

del model
gc.collect()

## Fine-tune on MIMIC COPD labels

In [None]:
BATCH_SIZE = 32

# COPD-positive / COPD-negative counts per training cohort. Only the active cohort is
# assigned; substitute another pair to weight a different cohort.
#   All:   pos = 57925, neg = 212124
#   Emory: pos = 25294, neg = 120259
#   MIMIC (active):
pos, neg = 32631, 91865
total = pos + neg

# Balanced weighting: each class is scaled so it contributes total / 2 effective samples,
# i.e. weight_c = total / (2 * count_c).
weight_for_0, weight_for_1 = [(1 / n_class) * (total / 2.0) for n_class in (neg, pos)]

class_weight = dict(zip((0, 1), (weight_for_0, weight_for_1)))
print('Weight for class 0: {:.2f}'.format(weight_for_0))
print('Weight for class 1: {:.2f}'.format(weight_for_1))

In [None]:
# TFRecord-backed train/validation pipelines for the MIMIC COPD split.
# NOTE(review): the next cell rebuilds both datasets from in-memory arrays and overwrites
# `train_dataset` / `val_dataset`, so these pipelines are only used if that cell is skipped.
# `buffer_size` here is the TFRecord read buffer in bytes; validation reuses the training parser.
record_file_train = 'tfrecords/copd_mimic_train.tfrecords'
record_file_val = 'tfrecords/copd_mimic_val.tfrecords'

train_dataset = tf.data.TFRecordDataset(
    record_file_train,
    compression_type=None,
    buffer_size=BATCH_SIZE,
    num_parallel_reads=32,
)
train_dataset = train_dataset.map(parse_TFrecord_train).shuffle(BATCH_SIZE).batch(BATCH_SIZE)

val_dataset = tf.data.TFRecordDataset(
    record_file_val,
    compression_type=None,
    buffer_size=BATCH_SIZE,
    num_parallel_reads=32,
)
val_dataset = val_dataset.map(parse_TFrecord_train).batch(BATCH_SIZE)

In [None]:
def get_data_label(dataset='mimic', split='test', return_demo=False):
    """Read every record of ``tfrecords/copd_{dataset}_{split}.tfrecords`` into memory.

    Args:
        dataset: cohort name used in the file name (e.g. 'mimic', 'emory').
        split: split name used in the file name (e.g. 'train', 'val', 'test').
        return_demo: if True, return demographics instead of images.

    Returns:
        ``(labels, demographics)`` when ``return_demo`` is True, otherwise
        ``(labels, images)``.

        - ``labels`` is an array of each record's ``COPD`` int64 feature, in file order.
        - ``demographics`` is an object array of ``{'Age', 'Gender', 'Race'}`` dicts built
          from the ``age``/``gender``/``race`` int64 features.
        - ``images`` holds each ``jpg_bytes`` JPEG decoded as grayscale, passed through
          ``st.resize(..., (256, 256, 1))`` and cast to float32.
    """
    filename = 'tfrecords/copd_{a}_{b}.tfrecords'.format(a=dataset, b=split)

    y = []
    img = []
    demo = []

    for raw_record in tf.data.TFRecordDataset(filename):
        example = tf.train.Example()
        example.ParseFromString(raw_record.numpy())
        feature = example.features.feature

        y.append(feature['COPD'].int64_list.value[0])

        if return_demo:
            demo.append({'Age': feature['age'].int64_list.value[0],
                         'Gender': feature['gender'].int64_list.value[0],
                         'Race': feature['race'].int64_list.value[0]})
            # Images are not returned in this mode, so skip the costly decode + resize.
            continue

        # np.frombuffer replaces the deprecated binary mode of np.fromstring.
        nparr = np.frombuffer(feature['jpg_bytes'].bytes_list.value[0], np.uint8)
        img_np = cv.imdecode(nparr, cv.IMREAD_GRAYSCALE)
        # NOTE(review): `st` is not imported in this notebook; presumably skimage.transform
        # arriving through a star import — confirm.
        img.append(np.float32(st.resize(img_np, (256, 256, 1))))

    if return_demo:
        return np.array(y), np.array(demo)
    return np.array(y), np.array(img)
    
# Load the MIMIC train/val splits into memory and wrap them as tf.data pipelines
# (these replace the TFRecord-backed datasets built in the previous cell).
y_train, X_train = get_data_label(dataset = 'mimic', split = 'train', return_demo = False)
y_val, X_val = get_data_label(dataset = 'mimic', split = 'val', return_demo = False)

# Label-permutation control. When True, the labels are shuffled independently of the images,
# which destroys the image-label pairing, so a model trained on them is expected to perform at
# chance. This appears to be deliberate (the fine-tune checkpoint name ends in `_shuffle`) and
# stays the default to preserve that experiment; set to False for a real training run.
SHUFFLE_LABELS = True
if SHUFFLE_LABELS:
    np.random.shuffle(y_train)
    np.random.shuffle(y_val)

# Training order is randomised with a small buffer (BATCH_SIZE examples), i.e. only locally.
# Validation loss/AUC do not depend on order, so the validation set is left unshuffled.
train_dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train)).shuffle(BATCH_SIZE).batch(BATCH_SIZE)
val_dataset = tf.data.Dataset.from_tensor_slices((X_val, y_val)).batch(BATCH_SIZE)

# Drop the notebook's references to the NumPy arrays now that the datasets are built.
del y_train, X_train, y_val, X_val
gc.collect()

In [None]:
archi = 'Xception'
checkpoint_filepath = 'checkpoints_mimic/checkpoint_BCE_{i}_shuffle'.format(i=archi)

# Keep the weights with the lowest validation loss.
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_best_only=True,
    save_weights_only=True,
    monitor='val_loss',
    mode='min')

# Stop once validation loss has not improved for 3 epochs.
callback = [tf.keras.callbacks.LearningRateScheduler(scheduler),
            tf.keras.callbacks.EarlyStopping(mode='min', patience=3, monitor='val_loss'),
            model_checkpoint_callback]

# Presumably the backbone with the weights saved by the pre-training cell.
base_model = load_model_from_pretrain(archi)

# The in-memory images are single-channel, so batches are (batch, 256, 256, 1). Repeat the
# channel three times before the backbone (assumed to expect 3-channel input — the evident
# intent of the former UpSampling3D(size=(1, 1, 3)), which only accepts 5-D input and
# raises on these 4-D batches).
model = tf.keras.Sequential()
model.add(tf.keras.layers.Lambda(tf.image.grayscale_to_rgb))
model.add(base_model)

# NOTE(review): loss expects logits (from_logits=True) while the 'AUC' metric thresholds
# predictions in [0, 1]; confirm whether the backbone's head applies a sigmoid and align
# the loss / use tf.keras.metrics.AUC(from_logits=True) accordingly.
model.compile(loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3), metrics=['AUC'])

# fit(shuffle=...) is ignored for tf.data inputs; shuffling happens in train_dataset.
model.fit(train_dataset, epochs=5, validation_data=val_dataset, callbacks=callback, class_weight=class_weight)

del model
gc.collect()

In [None]:
BATCH_SIZE = 32

# Held-out Emory test split: an unshuffled pipeline for model.predict, plus the labels read
# from the same file.
record_file_test = 'tfrecords/copd_emory_test.tfrecords'
test_dataset = tf.data.TFRecordDataset(
    record_file_test,
    compression_type=None,
    buffer_size=BATCH_SIZE,
    num_parallel_reads=32,
).map(parse_TFrecord_test).batch(BATCH_SIZE)

# NOTE(review): get_data_label returns a (labels, images) tuple, so the labels are y_test[0]
# and y_test[1] keeps every decoded test image in memory although only y_test[0] is indexed.
y_test = get_data_label(split='test', dataset='emory', return_demo=False)

In [None]:
# Number of Emory test labels (shape of the label array).
np.shape(y_test[0])

In [None]:
# Score `y_preds` against the Emory test labels (y_test[0]) at the Youden-optimal threshold.
# NOTE(review): with the recipe below commented out, no cell above this one defines `y_preds`;
# on a fresh kernel this raises NameError, otherwise it silently reuses whatever predictions a
# previously executed cell left behind. The recipe loads a checkpoint saved from the
# Sequential(UpSampling3D, base_model) fine-tune wrapper into define_model(archi) — confirm the
# two architectures match before re-enabling it.
# checkpoint_filepath = 'checkpoints_mimic/checkpoint_BCE_{i}_shuffle'.format(i=archi)

# model = define_model(archi)

# model.load_weights(checkpoint_filepath)

# y_preds = model.predict(test_dataset)

thresh = get_thresh(y_test[0], y_preds, 'Youden')

test_CI(y_preds, y_test[0], thresh)

gc.collect()

In [None]:
archis = ['Xception', 'DenseNet121', 'ResNet50V2', 'MobileNetV2']

# MIMIC held-out test split, built here instead of reusing `test_dataset`: the test-data cell
# points that at the Emory split, so reusing it would write Emory predictions into the
# `mimic_*` files below. The path follows get_data_label's naming scheme.
mimic_test_dataset = (tf.data.TFRecordDataset(
    'tfrecords/copd_mimic_test.tfrecords', buffer_size=BATCH_SIZE, compression_type=None, num_parallel_reads=32)
.map(parse_TFrecord_test)
.batch(BATCH_SIZE))
# Labels only (element 0 of the returned tuple), read from the same unshuffled file.
y_mimic_test = get_data_label(dataset='mimic', split='test', return_demo=False)[0]

for archi in archis:
    checkpoint_filepath = 'checkpoints/checkpoints_merged/checkpoint_BCE_{i}'.format(i=archi)

    model = define_model(archi)
    model.load_weights(checkpoint_filepath)

    y_preds = model.predict(mimic_test_dataset)

    # NOTE(review): the Youden threshold is chosen on the same test data it is evaluated on,
    # which biases threshold-dependent metrics upward.
    thresh = get_thresh(y_mimic_test, y_preds, 'Youden')
    test_CI(y_preds, y_mimic_test, thresh)

    # Release this architecture's model before building the next one.
    del model
    gc.collect()

    outfile = 'preds/merged/test_preds/mimic_{i}_preds.npy'.format(i=archi)
    np.save(outfile, y_preds)

In [None]:
archis = ['Xception', 'DenseNet121', 'ResNet50V2', 'MobileNetV2']

# Evaluate each merged-cohort checkpoint on `test_dataset` (the Emory test split) and save
# the raw predictions. `y_test` is the tuple returned by get_data_label, so its labels are
# y_test[0]; passing the whole tuple would hand the metrics code (labels, images).
for archi in archis:
    checkpoint_filepath = 'checkpoints/checkpoints_merged/checkpoint_BCE_{i}'.format(i=archi)

    model = define_model(archi)
    model.load_weights(checkpoint_filepath)

    y_preds = model.predict(test_dataset)

    # NOTE(review): threshold is tuned on the evaluation data itself (optimistic bias).
    thresh = get_thresh(y_test[0], y_preds, 'Youden')
    test_CI(y_preds, y_test[0], thresh)

    # Release this architecture's model before building the next one.
    del model
    gc.collect()

    outfile = 'preds/merged/test_preds/emory_{i}_preds.npy'.format(i=archi)
    np.save(outfile, y_preds)

In [None]:
# Overall test metrics for the predictions currently in `y_preds` — presumably the last model
# evaluated above on the Emory `test_dataset`. Labels and demographics are loaded here so this
# cell and the fairness cells below no longer depend on variables from a deleted cell.
# NOTE(review): assumes the Emory test records carry 'gender', 'race' and 'age' features.
y_label, y_demo = get_data_label(dataset='emory', split='test', return_demo=True)
assert len(y_label) == len(y_preds), 'labels and predictions come from different splits'

thresh = get_thresh(y_label, y_preds, 'Youden')

test_CI(y_preds, y_label, thresh)

gc.collect()

## Fairness analysis (metrics by race, gender and age group)

In [None]:
race_list = ['White', 'Black', 'Latino', 'Others', 'Asian']

# Metrics per race subgroup. Assumes the integer 'Race' code in y_demo indexes race_list —
# TODO confirm against the code that wrote the TFRecords.
# NOTE(review): the Youden threshold is re-optimised inside each subgroup, which can mask
# between-group differences at a shared operating point; consider also reporting metrics at
# the overall threshold.
for race_num, race in enumerate(race_list):
    print(race)

    idx = [i for i, demo in enumerate(y_demo) if demo['Race'] == race_num]
    temp_df = y_preds[idx]
    temp_label = y_label[idx]

    # Youden's J (TPR - FPR) is undefined without both positives and negatives.
    present = np.unique(temp_label)
    if len(present) < 2:
        print('  skipped: {} samples, classes present: {}'.format(len(idx), present.tolist()))
        continue

    thresh = get_thresh(temp_label, temp_df, 'Youden')

    test_CI(temp_df, temp_label, thresh)

    gc.collect()

In [None]:
gender_list = ['Female', 'Male']

# Metrics per gender subgroup. Assumes the integer 'Gender' code in y_demo indexes
# gender_list (0 = Female, 1 = Male) — TODO confirm against the TFRecord writer.
# NOTE(review): the threshold is re-optimised inside each subgroup (see the race cell note).
for gender_num, gender in enumerate(gender_list):
    print(gender)

    idx = [i for i, demo in enumerate(y_demo) if demo['Gender'] == gender_num]
    temp_df = y_preds[idx]
    temp_label = y_label[idx]

    # Youden's J (TPR - FPR) is undefined without both positives and negatives.
    present = np.unique(temp_label)
    if len(present) < 2:
        print('  skipped: {} samples, classes present: {}'.format(len(idx), present.tolist()))
        continue

    thresh = get_thresh(temp_label, temp_df, 'Youden')

    test_CI(temp_df, temp_label, thresh)

    gc.collect()

In [None]:
age_list = ['0-40', '40-60', '60-80', '80-']

# Metrics per age group. 'Age' is compared against the position in age_list, so it is assumed
# to hold an age-bin index (0..3) rather than years — TODO confirm against the TFRecord writer.
# NOTE(review): the threshold is re-optimised inside each subgroup, which can mask differences
# between groups at a shared operating point.
for age_num, age in enumerate(age_list):
    print(age)

    idx = [i for i, demo in enumerate(y_demo) if demo['Age'] == age_num]
    temp_df = y_preds[idx]
    temp_label = y_label[idx]

    # Youden's J (TPR - FPR) is undefined without both positives and negatives.
    present = np.unique(temp_label)
    if len(present) < 2:
        print('  skipped: {} samples, classes present: {}'.format(len(idx), present.tolist()))
        continue

    thresh = get_thresh(temp_label, temp_df, 'Youden')

    test_CI(temp_df, temp_label, thresh)

    gc.collect()