In [None]:
import os
import warnings
warnings.filterwarnings("ignore")

from Parse_TFrecords import *
from define_model import *
from load_label import *
from utilities import *
import gc
import tensorflow as tf
import numpy as np

# Pin this notebook to a single physical GPU and cap the memory TensorFlow
# may allocate on it.
GPU_INDEX = 1               # position in the list returned by list_physical_devices
GPU_MEMORY_LIMIT_MB = 1024  # memory_limit passed to the logical device configuration

gpus = tf.config.list_physical_devices(device_type='GPU')
if len(gpus) <= GPU_INDEX:
    # Same exception type as the bare `gpus[1]` lookup, but with an actionable message.
    raise IndexError(
        'GPU index {i} requested but only {n} GPU(s) are visible to TensorFlow'.format(
            i=GPU_INDEX, n=len(gpus)))

try:
    tf.config.set_logical_device_configuration(
            gpus[GPU_INDEX],
            [tf.config.LogicalDeviceConfiguration(memory_limit=GPU_MEMORY_LIMIT_MB)])
    tf.config.set_visible_devices(devices=gpus[GPU_INDEX], device_type='GPU')
except RuntimeError as err:
    # TensorFlow refuses to change device configuration once the runtime has
    # been initialised (e.g. something already touched the GPU in this kernel).
    # Report it instead of aborting the cell; restart the kernel to apply the limit.
    print('GPU configuration was not applied: {e}'.format(e=err))


# Fine-tune every backbone in `archis` with the same recipe and keep the
# best-by-validation-loss weights of each one in its own checkpoint file.
#
# NOTE(review): this cell reads `train_dataset`, `val_dataset` and `class_weight`,
# which are defined in cells further down the notebook. On a fresh kernel those
# cells must be run first (or moved above this one).
archis = ['DenseNet121', 'ResNet50V2', 'MobileNetV2', 'InceptionResNetV2', 'Xception', 'EfficientNetV2S',  'InceptionV3']

# Quantity tracked by both checkpointing and early stopping (lower is better).
monitor_ = 'val_loss'

for archi in archis:
    checkpoint_filepath = 'checkpoints_new/checkpoint_BCE_{i}'.format(i=archi)

    # Only the weights of the best epoch (lowest validation loss) are written.
    model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_filepath,
        save_best_only=True,
        save_weights_only=True,
        monitor=monitor_,
        mode='min')

    # NOTE(review): with patience=3 and epochs=3 below, early stopping can never
    # trigger (at most two non-improving epochs can follow the first) - confirm
    # whether a longer run or a smaller patience was intended.
    callback = [tf.keras.callbacks.LearningRateScheduler(scheduler),
                tf.keras.callbacks.EarlyStopping(mode='min', patience=3, monitor=monitor_),
                model_checkpoint_callback]

    model = load_model_from_pretrain(archi)

    # NOTE(review): from_logits=True assumes the model's output layer has no
    # sigmoid, while the default 'AUC' metric expects probabilities in [0, 1] -
    # verify against load_model_from_pretrain.
    loss_func = tf.keras.losses.BinaryCrossentropy(from_logits=True)

    # `metrics` is passed as a list, the form documented by Keras.
    model.compile(loss=loss_func,
                  optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
                  metrics=['AUC'])

    # `shuffle` is not passed: Keras ignores it for tf.data.Dataset inputs, so
    # shuffling is whatever the dataset pipeline itself does.
    model.fit(train_dataset, epochs=3, validation_data=val_dataset,
              callbacks=callback, class_weight=class_weight)

    # Release the finished model before building the next backbone. Dropping the
    # Python reference alone leaves Keras' global state behind, which keeps
    # growing when many models are created in a loop.
    del model
    tf.keras.backend.clear_session()
    gc.collect()

In [None]:
# Mini-batch size used when batching the tf.data pipelines.
BATCH_SIZE = 32

# Counts of positive / negative examples the class weights are derived from.
pos, neg = 91553, 329838
total = pos + neg

# Weight each class inversely to its frequency. Scaling by total / 2 makes the
# weighted example count (neg * weight_for_0 + pos * weight_for_1) equal `total`.
_half_total = total / 2.0
weight_for_0 = (1 / neg) * _half_total
weight_for_1 = (1 / pos) * _half_total

# Mapping from class label to its loss weight.
class_weight = dict(zip((0, 1), (weight_for_0, weight_for_1)))

In [None]:
def _build_tfrecord_dataset(record_file, parse_fn, shuffle_buffer=None):
    """Build a batched input pipeline from a single TFRecord file.

    Args:
        record_file: path of the TFRecord file to read.
        parse_fn: function mapped over every serialized record.
        shuffle_buffer: if given, size of the shuffle buffer; the dataset is
            left in file order when None.

    Returns:
        A tf.data.Dataset yielding batches of BATCH_SIZE parsed records.
    """
    # buffer_size is the reader's buffer in *bytes*, not a record count, so it
    # is left at None to let TensorFlow pick its default instead of 32 bytes.
    dataset = tf.data.TFRecordDataset(
        record_file, buffer_size=None, compression_type=None, num_parallel_reads=32)
    if shuffle_buffer is not None:
        # Shuffle the serialized records before parsing, so the buffer holds
        # raw records rather than up to `shuffle_buffer` parsed examples.
        dataset = dataset.shuffle(shuffle_buffer)
    # Prefetch so input preparation overlaps with the training step.
    return dataset.map(parse_fn).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)


# Training split: shuffled with a buffer of `total` records.
record_file_train = 'copd_merged_train.tfrecords'
train_dataset = _build_tfrecord_dataset(
    record_file_train, parse_TFrecord_train, shuffle_buffer=total)

# Validation split: parsed with the same function as the training split, in file order.
record_file_val = 'copd_merged_val.tfrecords'
val_dataset = _build_tfrecord_dataset(record_file_val, parse_TFrecord_train)

# Test split: uses its own parse function and is kept in file order.
# NOTE(review): presumably this order has to line up with `y_test` below - confirm.
record_file_test = 'copd_merged_test.tfrecords'
test_dataset = _build_tfrecord_dataset(record_file_test, parse_TFrecord_test)

y_test = get_data_label('test')