In [None]:
# List the GPUs on this host and their current memory use (the config cell below
# restricts TensorFlow to a single one of them).
!nvidia-smi

In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import skimage.transform as st
import os
import gc
import warnings
from sklearn.metrics import classification_report, average_precision_score
from Parse_TFrecords import *
from define_model import *
from load_data import *
from utilities import *
 
print(tf.__version__)

# Suppress all Python-level warnings to keep cell output readable.
# NOTE(review): this also hides potentially useful warnings; narrow it if debugging.
warnings.filterwarnings("ignore")
# NOTE(review): TensorFlow reads TF_CPP_MIN_LOG_LEVEL when its native runtime loads,
# i.e. at `import tensorflow`. Set after the import (as here) it may not silence the
# C++ logs; move this line above the imports for a reliable effect.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

# Index of the physical GPU to use. Falls back to GPU 0 when the host has fewer GPUs,
# and to CPU when it has none, instead of crashing with an IndexError.
GPU_INDEX = 1

gpus = tf.config.list_physical_devices(device_type='GPU')
if len(gpus) > GPU_INDEX:
    # Must run before any op initialises the GPUs, otherwise TF raises RuntimeError.
    tf.config.set_visible_devices(devices=gpus[GPU_INDEX], device_type='GPU')
elif gpus:
    print(f'Only {len(gpus)} GPU(s) found; using GPU 0 instead of GPU {GPU_INDEX}.')
    tf.config.set_visible_devices(devices=gpus[0], device_type='GPU')
else:
    print('No GPU found; running on CPU.')

In [None]:
# Label counts (1 = positive, 0 = negative). NOTE(review): hard-coded; presumably the
# training-split counts (`total` is used as the training shuffle buffer below) —
# re-derive them if the split changes.
pos, neg = 27297, 119133
total = pos + neg

# "Balanced" weighting: w_c = (1 / n_c) * (total / 2), so n_c * w_c == total / 2 for
# both classes, i.e. each class contributes the same total weight to the loss.
weight_for_0, weight_for_1 = ((1 / count) * (total / 2.0) for count in (neg, pos))

class_weight = dict(enumerate((weight_for_0, weight_for_1)))

print('Weight for class 0: {:.2f}'.format(weight_for_0))
print('Weight for class 1: {:.2f}'.format(weight_for_1))

## Demographic attributes

In [None]:
# Keep only the weights from the epoch with the highest validation precision-at-recall.
# The monitored key is 'val_' + the metric's `name` given at compile time.
checkpoint_filepath = 'checkpoints/PrecisionAtRecall/checkpoint_demo'
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    checkpoint_filepath,
    monitor='val_precision_at_recall',
    mode='max',
    save_weights_only=True,
    save_best_only=True,
)

# Callbacks for model.fit: per-epoch LR schedule (`scheduler` is provided by one of the
# star imports above), stop after 4 epochs without improvement, and checkpointing.
# NOTE(review): EarlyStopping does not restore the best weights here; load them from
# `checkpoint_filepath` before evaluating.
callback = [
    tf.keras.callbacks.LearningRateScheduler(scheduler),
    tf.keras.callbacks.EarlyStopping(monitor='val_precision_at_recall', mode='max', patience=4),
    model_checkpoint_callback,
]

In [None]:
BATCH_SIZE = 32

# Let tf.data pick the degree of parallelism / prefetch depth at runtime.
AUTOTUNE = tf.data.AUTOTUNE

# Training pipeline.
# - `buffer_size` of TFRecordDataset is a read-buffer size in BYTES; the original
#   value of 32 (bytes) throttled I/O, so the library default is used instead.
# - The full shuffle runs on the serialized records *before* parsing, so the
#   shuffle buffer holds compact records rather than parsed (presumably decoded
#   image) elements.
# - Parsing runs in parallel and batches are prefetched to overlap input with training.
record_file_train = 'copd_train_new.tfrecords'
train_dataset = (
    tf.data.TFRecordDataset(record_file_train, compression_type=None, num_parallel_reads=32)
    .shuffle(total)
    .map(parse_TFrecord_train_demo, num_parallel_calls=AUTOTUNE)
    .batch(BATCH_SIZE)
    .prefetch(AUTOTUNE)
)

# Validation pipeline: no shuffle — the aggregated validation loss/metrics do not
# depend on element order, and shuffling only cost memory.
record_file_val = 'copd_val_new.tfrecords'
val_dataset = (
    tf.data.TFRecordDataset(record_file_val, compression_type=None, num_parallel_reads=32)
    .map(parse_TFrecord_train_demo, num_parallel_calls=AUTOTUNE)
    .batch(BATCH_SIZE)
    .prefetch(AUTOTUNE)
)

In [None]:
model = define_model_demo('Dnet121')

model.summary()

# Precision at the operating point where recall = 0.55. Its `name` yields the
# 'val_precision_at_recall' key that the checkpoint / early-stopping callbacks monitor.
metric = tf.keras.metrics.PrecisionAtRecall(recall=0.55, name='precision_at_recall')

# NOTE(review): `from_logits=True` makes the loss treat model outputs as logits, but
# PrecisionAtRecall requires predictions in [0, 1]. Confirm what define_model_demo's
# final layer emits and make the two agree (sigmoid output -> from_logits=False).
# The Adam learning rate is presumably overridden each epoch by `scheduler`.
model.compile(loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
              metrics=[metric])

# `shuffle` is not passed: Keras ignores it for tf.data inputs, and train_dataset
# already shuffles itself. The History object is kept for inspecting learning curves.
history = model.fit(train_dataset, epochs=20, validation_data=val_dataset,
                    callbacks=callback, class_weight=class_weight)

gc.collect()

In [None]:
# Evaluate with the best checkpointed weights, not the last-epoch weights: the
# EarlyStopping callback does not restore the best epoch, and training may run all
# 20 epochs past it. `checkpoint_filepath` is the ModelCheckpoint path set above.
model.load_weights(checkpoint_filepath)

test_image, test_demo, y_test = get_test_data_demo()

y_preds = model.predict([test_image, test_demo])

In [None]:
# Project helper reporting metrics (presumably with confidence intervals) on the full test set.
test_CI(y_preds, y_test)

# Decision threshold chosen by get_thresh's 'Youden' criterion on the test predictions.
thresh = get_thresh(y_test, y_preds, 'Youden')
print(thresh)

# Binarise the scores at the threshold, then report per-class precision/recall/F1
# and the threshold-free average precision.
y_pred_labels = np.where(y_preds >= thresh, 1, 0)
print(classification_report(y_test, y_pred_labels))
print(average_precision_score(y_test, y_preds, average=None))

gc.collect()

In [None]:
# Per-group evaluation by Gender: rebuild the test set for each group index and
# report the same metrics as the overall evaluation.
N_GENDER_GROUPS = 2

for group in range(N_GENDER_GROUPS):
    # Label the output so each report can be attributed to its group.
    print(f'===== Gender, types={group} =====')

    test_image, test_demo, y_test = get_test_data_demo(category='Gender', types=group)

    y_preds = model.predict([test_image, test_demo])

    test_CI(y_preds, y_test)

    # Same 'Youden' criterion as the overall evaluation, so group results are comparable.
    # NOTE(review): the threshold is tuned on this group's own test labels; for a
    # fairness comparison consider reusing the overall `thresh` instead.
    thresh = get_thresh(y_test, y_preds, 'Youden')
    print(thresh)

    print(classification_report(y_test, np.where(y_preds >= thresh, 1, 0)))
    print(average_precision_score(y_test, y_preds, average=None))

    # Release this group's arrays before loading the next group.
    del test_image, test_demo, y_test, y_preds
    gc.collect()

In [None]:
# Per-group evaluation by Race: rebuild the test set for each group index and
# report the same metrics as the overall evaluation.
N_RACE_GROUPS = 5

for group in range(N_RACE_GROUPS):
    # Label the output so each report can be attributed to its group.
    print(f'===== Race, types={group} =====')

    test_image, test_demo, y_test = get_test_data_demo(category='Race', types=group)

    y_preds = model.predict([test_image, test_demo])

    test_CI(y_preds, y_test)

    # Same 'Youden' criterion as the overall evaluation, so group results are comparable.
    # NOTE(review): the threshold is tuned on this group's own test labels; for a
    # fairness comparison consider reusing the overall `thresh` instead.
    thresh = get_thresh(y_test, y_preds, 'Youden')
    print(thresh)

    print(classification_report(y_test, np.where(y_preds >= thresh, 1, 0)))
    print(average_precision_score(y_test, y_preds, average=None))

    # Release this group's arrays before loading the next group.
    del test_image, test_demo, y_test, y_preds
    gc.collect()

In [None]:
# Per-group evaluation by Age: rebuild the test set for each group index and
# report the same metrics as the overall evaluation.
N_AGE_GROUPS = 4

for group in range(N_AGE_GROUPS):
    # Label the output so each report can be attributed to its group.
    print(f'===== Age, types={group} =====')

    test_image, test_demo, y_test = get_test_data_demo(category='Age', types=group)

    y_preds = model.predict([test_image, test_demo])

    test_CI(y_preds, y_test)

    # Same 'Youden' criterion as the overall evaluation, so group results are comparable.
    # NOTE(review): the threshold is tuned on this group's own test labels; for a
    # fairness comparison consider reusing the overall `thresh` instead.
    thresh = get_thresh(y_test, y_preds, 'Youden')
    print(thresh)

    print(classification_report(y_test, np.where(y_preds >= thresh, 1, 0)))
    print(average_precision_score(y_test, y_preds, average=None))

    # Release this group's arrays before loading the next group.
    del test_image, test_demo, y_test, y_preds
    gc.collect()