In [None]:
# Shell escape: run the NVIDIA `nvidia-smi` tool so its GPU status report is captured in the notebook output.
!nvidia-smi

In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import skimage.transform as st
import os
import gc
import warnings
from sklearn.metrics import classification_report, average_precision_score
from Parse_TFrecords import *
from define_model import *
from load_data import *
from utilities import *
 
# Record the TensorFlow version in the notebook output so runs can be reproduced.
print(tf.__version__)

# Optional toggle: silence Python warnings and TensorFlow's native INFO/WARNING logs.
# NOTE(review): TF_CPP_MIN_LOG_LEVEL is typically only honoured when set *before*
# `import tensorflow` — confirm it has the intended effect at this position.
# warnings.filterwarnings("ignore")
# os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

# Optional toggle: make only the first physical GPU visible to TensorFlow.
# gpus = tf.config.list_physical_devices(device_type='GPU')
# tf.config.set_visible_devices(devices=gpus[0], device_type='GPU')

# Pretrain model

In [None]:
# TFRecord file used for pretraining (presumably CheXpert chest X-rays, judging
# by the file name — confirm against the data preparation step).
files = '../Data/Chexpert_cxr.tfrecords'

BATCH_SIZE = 32

# Input pipeline: read records -> parse -> shuffle examples -> batch.
train_dataset = (
    tf.data.TFRecordDataset(
        files, compression_type=None, buffer_size=BATCH_SIZE*10, num_parallel_reads=32)
    .map(parse_TFrecord_pretrain)
    # Shuffle individual examples *before* batching. Shuffling after `.batch()`
    # only reorders whole batches, so the same 32 consecutive records would be
    # grouped together in every epoch. The buffer holds the same number of
    # examples as the previous 256-batch buffer did (256 * BATCH_SIZE).
    .shuffle(256 * BATCH_SIZE)
    # drop_remainder=True keeps every batch at exactly BATCH_SIZE examples.
    .batch(BATCH_SIZE, drop_remainder=True)
)

In [None]:
# Weights of the best epoch (highest training AUC) are written under this prefix.
checkpoint_filepath = 'checkpoints/AUC/checkpoint_pretrain_Dnet121'
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_best_only=True,
    save_weights_only=True,
    monitor='auc',
    mode='max')

# fit() below receives no validation data, so both the early-stopping and the
# checkpoint callbacks monitor the *training* AUC.
callback = [tf.keras.callbacks.LearningRateScheduler(scheduler),
            tf.keras.callbacks.EarlyStopping(mode='max', patience=3, monitor='auc'),
            model_checkpoint_callback]

# Name the metric explicitly: an unnamed AUC metric may be auto-suffixed by
# Keras ('auc_1', 'auc_2', ...) when this cell is re-run in the same kernel,
# which would leave the 'auc' monitors above without a matching log key.
metric = tf.keras.metrics.AUC(name='auc')

model = define_model_pretrain(archi='Dnet121')

model.summary()

# NOTE(review): the loss treats the model outputs as logits (from_logits=True)
# while AUC by default thresholds its inputs as probabilities — confirm what
# define_model_pretrain actually outputs.
model.compile(loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
              metrics=[metric])

# Shuffling is done inside the tf.data pipeline; fit()'s `shuffle` argument is
# ignored for Dataset inputs, so it is not passed here.
model.fit(train_dataset, epochs=5, callbacks=callback)

# Transfer learning

In [None]:
# Per-class example counts (presumably positives/negatives of the training
# split — verify against the data preparation step).
pos, neg = 27297, 119133
total = pos + neg

# Inverse-frequency weights, scaled by total/2 so that both classes carry the
# same summed weight and the overall loss magnitude stays comparable.
half_total = total / 2.0
weight_for_0 = (1 / neg) * half_total
weight_for_1 = (1 / pos) * half_total

# Mapping from class index to loss weight.
class_weight = {0: weight_for_0, 1: weight_for_1}

print('Weight for class 0: {:.2f}'.format(class_weight[0]))
print('Weight for class 1: {:.2f}'.format(class_weight[1]))

In [None]:
BATCH_SIZE = 32


def _make_dataset(record_file, parse_fn, shuffle_buffer=None):
    """Build a batched tf.data pipeline from a single TFRecord file.

    Args:
        record_file: path of the TFRecord file to read.
        parse_fn: per-record parsing function applied with ``Dataset.map``.
        shuffle_buffer: when given, examples are shuffled with a buffer of
            this size before batching; ``None`` leaves the reader's order
            untouched.

    Returns:
        A ``tf.data.Dataset`` yielding batches of at most ``BATCH_SIZE``
        parsed examples (the last batch may be smaller).
    """
    dataset = tf.data.TFRecordDataset(
        record_file, buffer_size=BATCH_SIZE*10, compression_type=None, num_parallel_reads=32
    ).map(parse_fn)
    if shuffle_buffer is not None:
        dataset = dataset.shuffle(shuffle_buffer)
    return dataset.batch(BATCH_SIZE)


# Training split: the shuffle buffer spans `total` examples (defined with the
# class weights), so every epoch sees a full reshuffle.
record_file_train = 'copd_train_new.tfrecords'
train_dataset = _make_dataset(record_file_train, parse_TFrecord_train, shuffle_buffer=total)

# Validation split: not shuffled. Validation loss and AUC are aggregated over
# the whole split and do not depend on example order, so a shuffle buffer here
# would only cost memory and start-up time on every validation pass.
record_file_val = 'copd_val_new.tfrecords'
val_dataset = _make_dataset(record_file_val, parse_TFrecord_train)

# Test split: must stay unshuffled so predictions line up with y_test.
record_file_test = 'copd_test_new.tfrecords'
test_dataset = _make_dataset(record_file_test, parse_TFrecord_test)

y_test = get_data_label('test')

## Different loss functions

In [None]:
archi = 'Dnet121'

# Weights of the best epoch (highest validation AUC) are written under this prefix.
checkpoint_filepath = 'checkpoints/AUC/checkpoint_KLDiv_Dnet121'
monitor_ = 'val_auc'

model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_best_only=True,
    save_weights_only=True,
    monitor=monitor_,
    mode='max')

callback = [tf.keras.callbacks.LearningRateScheduler(scheduler),
            tf.keras.callbacks.EarlyStopping(mode='max', patience=3, monitor=monitor_),
            model_checkpoint_callback]

model = load_model_from_pretrain(archi)

# Loss under test in this cell; swap in one of the alternatives to compare.
# loss_func = binary_focal_loss
# loss_func = tf.keras.losses.CategoricalHinge()
loss_func = tf.keras.losses.KLDivergence()
# loss_func = tf.keras.losses.BinaryCrossentropy(from_logits=True)

# The explicit name keeps the logged keys stable ('auc' / 'val_auc'), which the
# callbacks above rely on.
metric = tf.keras.metrics.AUC(name='auc')

model.compile(loss=loss_func,
              optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
              metrics=[metric])

# Shuffling is done inside the tf.data pipeline; fit()'s `shuffle` argument is
# ignored for Dataset inputs, so it is not passed here.
model.fit(train_dataset, epochs=20, validation_data=val_dataset,
          callbacks=callback, class_weight=class_weight)

# Evaluate the best checkpoint rather than the last-epoch weights: early
# stopping only halts after `patience` non-improving epochs, so the in-memory
# weights at this point are generally not the best ones.
model.load_weights(checkpoint_filepath)

y_preds = model.predict(test_dataset)

test_CI(y_preds, y_test)

# Average precision does not depend on the decision threshold: report it once.
print(average_precision_score(y_test, y_preds, average=None))

# NOTE(review): the Youden / G-mean thresholds are selected on the test labels
# themselves, which makes the thresholded reports optimistic — consider tuning
# them on the validation split instead.
for method in ('Youden', 'G-mean', None):
    # `None` stands for the fixed default cut-off of 0.5.
    thresh = 0.5 if method is None else get_thresh(y_test, y_preds, method)
    print('thresh:', thresh)
    print(classification_report(y_test, np.where(y_preds >= thresh, 1, 0)))

gc.collect()

In [None]:
# Rebuild the architecture and restore the weights stored under the BCE
# checkpoint prefix, then score the held-out test set.
archi = 'Dnet121'
checkpoint_filepath = 'checkpoints/AUC/checkpoint_BCE_Dnet121'

model = define_model(archi)
model.load_weights(checkpoint_filepath)

y_preds = model.predict(test_dataset)

test_CI(y_preds, y_test)

# Threshold-independent summary first.
print(average_precision_score(y_test, y_preds, average=None))

# Then one classification report per decision threshold: two data-driven
# criteria followed by the fixed 0.5 cut-off (represented by `None`).
for method in ('Youden', 'G-mean', None):
    if method is None:
        thresh = 0.5
    else:
        thresh = get_thresh(y_test, y_preds, method)
    print('thresh:', thresh)

    y_hat = np.where(y_preds >= thresh, 1, 0)
    print(classification_report(y_test, y_hat))

gc.collect()