# Import

In [None]:
import tensorflow as tf
import numpy as np
import cv2 as cv
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, f1_score, roc_curve, auc, roc_auc_score, classification_report, accuracy_score, silhouette_score

In [None]:
from tensorflow import keras
from tensorflow.keras.applications.densenet import DenseNet121
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Activation, Flatten, Dropout, BatchNormalization, GlobalAveragePooling2D, Input
from tensorflow.keras.callbacks import ModelCheckpoint, CSVLogger, LearningRateScheduler, ReduceLROnPlateau, EarlyStopping, TensorBoard
from tensorflow.keras.preprocessing import image
from tensorflow.keras.optimizers import Adadelta, Adam, SGD, Adagrad

In [None]:
# Restrict TensorFlow to the first physical GPU and enable on-demand memory
# growth on it instead of reserving the whole card up front.
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    first_gpu = gpus[0]
    try:
        tf.config.experimental.set_visible_devices(first_gpu, 'GPU')
        tf.config.experimental.set_memory_growth(first_gpu, True)
        logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as e:
        # Device configuration was rejected by the runtime; report and continue.
        print(e)

In [None]:
# TFRecord files for each data split.
train_filename = './data/new_mimic_train.tfrecords'
val_filename = './data/new_mimic_val.tfrecords'
test_filename = './data/new_mimic_test.tfrecords'

# Image geometry and input-pipeline sizes.
IMAGE_WIDTH = 256
IMAGE_HEIGHT = 256
BUFFER_SIZE = 1000  # shuffle buffer size
BATCH_SIZE = 16

# Penalty variant passed to train_step ('mmd' or anything else for the
# mean-difference penalty) and the attribute whose groups are compared
# ('race', 'age', or anything else for gender).
current_mode = 'mmd'
current_group = 'race'

In [None]:
# Finding columns stored in each record as float32 scalars.
label_list = ['Pneumothorax', 'Pneumonia', 'Pleural Effusion', 'No Finding', 'Atelectasis', 'Lung Opacity','Enlarged Cardiomediastinum', 'Edema', 'Consolidation', 'Cardiomegaly']

# Schema used to parse one serialized tf.train.Example: the encoded JPEG,
# int64 metadata columns, then one float32 column per finding.
feature_description = {'jpg_bytes': tf.io.FixedLenFeature([], tf.string)}
feature_description.update(
    (column, tf.io.FixedLenFeature([], tf.int64))
    for column in ('race', 'age', 'gender', 'subject_id')
)
feature_description.update(
    (column, tf.io.FixedLenFeature([], tf.float32))
    for column in reversed(label_list)
)

# Preprocess

In [None]:
def _parse(example):
    """Decode one serialized example into a dict of tensors keyed by feature name."""
    return tf.io.parse_single_example(serialized=example, features=feature_description)
def read_tfrecord(example):
    """Turn one serialized record into an (image, label, group) triple.

    Args:
        example: Serialized example following ``feature_description``.

    Returns:
        img: The decoded 3-channel JPEG.
        labels: One-element list holding an int64 scalar: 0 when
            'No Finding' equals 1, otherwise 1.
        groups: One-element list holding an int64 group id derived from the
            attribute selected by the module-level ``current_group``.
    """
    record = _parse(example)
    img = tf.image.decode_jpeg(record['jpg_bytes'], channels=3)

    zero = tf.constant(0, tf.int64)
    one = tf.constant(1, tf.int64)
    labels = [tf.where(tf.equal(record['No Finding'], 1), zero, one)]

    if current_group == 'age':
        # Age codes 0 and 1 share group 0; every higher code is shifted down by one.
        age = record['age']
        groups = [tf.where(age <= 1, zero, tf.cast(age - 1, tf.int64))]
    elif current_group == 'race':
        # Race code 4 is remapped to group 2; other codes are used as they are.
        race = record['race']
        groups = [tf.where(tf.equal(race, 4), tf.constant(2, tf.int64), tf.cast(race, tf.int64))]
    else:
        groups = [tf.cast(record['gender'], tf.int64)]
    return img, labels, groups

def _fixup_shape(images, labels, groups):
    """Pin static shapes on one parsed element and return it unchanged otherwise.

    The image is declared as (IMAGE_HEIGHT, IMAGE_WIDTH, 3); the label and the
    group id are each declared as a length-1 vector.
    """
    static_shapes = (
        (images, [IMAGE_HEIGHT, IMAGE_WIDTH, 3]),
        (labels, [1]),
        (groups, [1]),
    )
    for tensor, shape in static_shapes:
        tensor.set_shape(shape)
    return images, labels, groups

In [None]:
def load_dataset(filename, *, shuffle=True, drop_remainder=True):
    """Build the batched input pipeline for one TFRecord file.

    Args:
        filename: Path (or list of paths) handed to ``tf.data.TFRecordDataset``.
        shuffle: When True (the default), shuffle elements with a buffer of
            ``BUFFER_SIZE`` before batching. Pass False for a deterministic
            order, e.g. when evaluating.
        drop_remainder: When True (the default), drop the final batch if it has
            fewer than ``BATCH_SIZE`` elements. Pass False to keep every
            example, e.g. when evaluating.

    Returns:
        A ``tf.data.Dataset`` yielding (images, labels, groups) batches.
    """
    dataset = tf.data.TFRecordDataset(filename)
    dataset = dataset.map(read_tfrecord, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    dataset = dataset.map(_fixup_shape)
    if shuffle:
        dataset = dataset.shuffle(BUFFER_SIZE)
    dataset = dataset.batch(BATCH_SIZE, drop_remainder=drop_remainder)
    return dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)


In [None]:
# One input pipeline per split, all built the same way.
dataset_train, dataset_val, dataset_test = (
    load_dataset(path) for path in (train_filename, val_filename, test_filename)
)

In [None]:
# Preview the first 16 images of a single validation batch; the images occupy
# the first 16 cells of a 4x8 subplot grid.
plt.figure(figsize=(10, 10))
for img, label, domain in dataset_val.take(1):
    for n in range(16):
        ax = plt.subplot(4, 8, n + 1)
        ax.imshow(img[n])
        ax.axis("off")

# Classifier

In [None]:
# Binary classifier: DenseNet121 backbone (ImageNet weights, global average
# pooling) followed by a single sigmoid unit.
# Keras image inputs are (height, width, channels).
input_shape = (IMAGE_HEIGHT, IMAGE_WIDTH, 3)
input_img = Input(input_shape)
x = tf.keras.applications.densenet.preprocess_input(input_img)
# The backbone is built on its own input rather than with `input_tensor=x`.
# Building it on `x` makes the backbone's graph start at `input_img` and include
# the preprocessing ops, so calling `base_model(x)` afterwards would apply the
# preprocessing a second time to already-preprocessed pixels.
# NOTE: weights saved from the previous wiring may not map onto this layout.
base_model = DenseNet121(weights='imagenet', include_top=False, input_shape=input_shape, pooling='avg')
x = base_model(x)
predictions = Dense(1, activation="sigmoid")(x)
model = Model(inputs=input_img, outputs=predictions)

In [None]:
# Print the layer-by-layer summary of the assembled classifier.
model.summary()

# Train

In [None]:
# Objective and optimizer shared by the train and validation steps.
loss_object = tf.keras.losses.BinaryCrossentropy()
optimizer = tf.keras.optimizers.SGD(learning_rate=0.0001)

# Running accuracy per phase.
train_accuracy = tf.keras.metrics.BinaryAccuracy(name='train_accuracy')
val_accuracy = tf.keras.metrics.BinaryAccuracy(name='val_accuracy')

# Running mean of the per-batch loss per phase.
train_loss = tf.keras.metrics.Mean(name='train_lp_loss')
val_loss = tf.keras.metrics.Mean(name='val_lp_loss')

In [None]:
def my_cdist(x1, x2):
    """Pairwise squared Euclidean distances between the rows of two matrices.

    Uses the expansion ||a||^2 - 2 a.b + ||b||^2, where ``x1`` and ``x2`` are
    rank-2 tensors whose rows are the points.

    Returns:
        A tensor with one row per row of ``x1`` and one column per row of
        ``x2``, clipped to the range [1e-30, 1e30].
    """
    sq_norm_1 = tf.reduce_sum(tf.pow(x1, 2), axis=-1, keepdims=True)
    sq_norm_2 = tf.reduce_sum(tf.pow(x2, 2), axis=-1, keepdims=True)

    cross_term = -2 * tf.matmul(x1, tf.transpose(x2, perm=[1, 0]))

    # Same summation order as before: (||b||^2 + cross) + ||a||^2.
    sq_dist = (tf.transpose(sq_norm_2, perm=[1, 0]) + cross_term) + sq_norm_1

    return tf.clip_by_value(sq_dist, clip_value_min=1e-30, clip_value_max=1e30)

def gaussian_kernel(x, y, gamma=(0.001, 0.01, 0.1, 1, 10, 100, 1000)):
    """Average of Gaussian kernels exp(-g * d^2) over several bandwidths.

    Args:
        x: First set of points, passed to ``my_cdist``.
        y: Second set of points, passed to ``my_cdist``.
        gamma: Sequence of kernel coefficients ``g``. The default is an
            immutable tuple so it cannot be altered between calls.

    Returns:
        A tensor shaped like ``my_cdist(x, y)`` holding the mean kernel value
        for every pair of points.
    """
    D = my_cdist(x, y)

    K = tf.zeros_like(D)

    for g in gamma:
        K = tf.math.add(K, tf.math.exp(tf.math.multiply(D, -g)))

    return K/len(gamma)

def mmd(x, y):
    """Maximum mean discrepancy between samples ``x`` and ``y``.

    Computed as mean k(x, x) + mean k(y, y) - 2 * mean k(x, y) with
    ``gaussian_kernel`` as k; every mean runs over all pairs.
    """
    # https://stats.stackexchange.com/questions/276497/maximum-mean-discrepancy-distance-distribution
    within_x = tf.reduce_mean(gaussian_kernel(x, x))
    within_y = tf.reduce_mean(gaussian_kernel(y, y))
    between = tf.reduce_mean(gaussian_kernel(x, y))
    return within_x + within_y - 2 * between

In [None]:
# Group ids that train_step compares against the whole batch, chosen by the
# selected attribute; anything other than 'race' or 'age' gets two groups.
group = {
    'race': [0, 1, 2],
    'age': [0, 1, 2, 3],
}.get(current_group, [0, 1])
    
@tf.function
def train_step(images, labels, groups, mode='mmd'):
    """Run one optimization step on a batch, with a group-matching penalty.

    The loss is the binary cross-entropy plus, for every group id in the
    module-level ``group`` list that occurs in the batch, a penalty comparing
    that group's predictions with the predictions of the whole batch.

    Args:
        images: Batch of input images.
        labels: Batch of binary labels.
        groups: Batch of integer group ids, one per example.
        mode: 'mmd' penalizes the MMD between the group's predictions and all
            predictions; any other value penalizes the absolute difference of
            their means.

    Side effects:
        Applies gradients to ``model`` through ``optimizer`` and updates the
        ``train_loss`` (total loss, penalty included) and ``train_accuracy``
        metrics.
    """
    # NOTE(review): the model is called without an explicit `training`
    # argument; confirm whether the backbone's normalization layers are meant
    # to run in inference mode during training.
    penalty = 0.0
    with tf.GradientTape() as tape:
        pred = model(images)
        loss = loss_object(labels, pred)
        # Same for every group, so it is computed once.
        all_trg = pred[:, -1]
        # Flatten to one id per example; unlike an axis-less squeeze this stays
        # a vector even when the batch holds a single example.
        group_ids = tf.reshape(groups, [-1])
        for target in group:
            grp_trg = tf.boolean_mask(all_trg, tf.equal(group_ids, target))
            # Skip groups absent from the batch: a mean over zero elements is NaN.
            if tf.size(grp_trg) > 0:
                if mode == 'mmd':
                    penalty += mmd(tf.expand_dims(grp_trg, -1), tf.expand_dims(all_trg, -1))
                else:
                    penalty += tf.math.abs(tf.reduce_mean(grp_trg) - tf.reduce_mean(pred))
        # Both modes weight the penalty by 1.
        total_loss = loss + penalty

    variables = model.trainable_variables
    grads = tape.gradient(total_loss, variables)
    optimizer.apply_gradients(zip(grads, variables))

    train_loss(total_loss)
    train_accuracy(labels, pred)
    

In [None]:
@tf.function
def val_step(images, labels, domains):
    """Evaluate one batch and fold the result into the validation metrics.

    Args:
        images: Batch of input images.
        labels: Batch of binary labels.
        domains: Unused; accepted so a whole dataset batch can be passed through.

    Side effects:
        Updates ``val_loss`` (plain binary cross-entropy, no penalty) and
        ``val_accuracy``.
    """
    predictions = model(images)
    val_loss.update_state(loss_object(labels, predictions))
    val_accuracy.update_state(labels, predictions)
        

In [None]:
def val(domain=True):
    """Run ``val_step`` on every batch of ``dataset_val``; ``domain`` is unused."""
    for images, labels, groups in tqdm(dataset_val):
        val_step(images, labels, groups)
        
def test(domain=True):
    """Run ``val_step`` on every batch of ``dataset_test``; ``domain`` is unused.

    Results accumulate in the same ``val_loss`` / ``val_accuracy`` metrics that
    validation uses.
    """
    for images, labels, groups in tqdm(dataset_test):
        val_step(images, labels, groups)

def reset_metrics(target):
    """Clear the accumulated loss and accuracy metrics of one phase.

    Args:
        target: 'train' or 'val'; any other value leaves all metrics untouched.
    """
    if target == 'train':
        phase_metrics = (train_loss, train_accuracy)
    elif target == 'val':
        phase_metrics = (val_loss, val_accuracy)
    else:
        phase_metrics = ()

    for metric in phase_metrics:
        metric.reset_states()


In [None]:
# Training driver: up to `epochs` passes over the training set, validating after
# each pass, checkpointing on improved validation accuracy, and stopping early
# once `count` exceeds 10.
epochs = 30
prev = 0   # best validation accuracy seen so far
count = 0  # early-stop counter: set to 1 on improvement, incremented otherwise
for epoch in range(epochs):
    for batch in tqdm(dataset_train):
        train_step(*batch, mode = current_mode)

    # Metrics accumulate over the whole epoch; report them before resetting.
    print("Training: Epoch {} :\t Accuracy : {:.3%}, loss : {}"
          .format(epoch, train_accuracy.result(), train_loss.result()))
    
    reset_metrics('train')
    
    val()
    val_acc = val_accuracy.result()
    print("Val: Accuracy : {:.3%}, loss : {}"
          .format(val_accuracy.result(), val_loss.result()))
    
    # Keep only the weights of the best epoch so far; the save path is derived
    # from the current penalty mode and group attribute.
    if val_acc > prev:
        prev = val_acc
        print('save acc: {}'.format(val_acc))
        model.save_weights("./model/" + current_mode + "/distmatch_" + current_group)
        count = 1
    else:
        count += 1
    
    if count > 10:
        break
        
    # Validation metrics are cleared only after they have been read above.
    reset_metrics('val')