In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
from skimage import io
from sklearn.model_selection import train_test_split
import sklearn.metrics as metrics
import cv2 as cv
import numpy as np
import seaborn as sns
import random

print("Tensor Flow Version: {}".format(tf.__version__))

In [None]:
SEED = 123
BATCH_SIZE = 16
IMAGE_SHAPE = (224, 224)
DATASET_DIR = "../../dataset/preprocessed-datasets/preprocessed_roi"
BUFFER_SIZE = 1000
NUM_CLASSES = 1

# Raw pathology strings from the metadata CSV -> binary target (0 = benign, 1 = malignant).
LABEL_MAP = {
    "BENIGN": 0,
    "BENIGN_WITHOUT_CALLBACK": 0,
    "MALIGNANT": 1
}


CHECKPOINT = "./checkpoints/checkpoint_pathology_ensemble"

# Seed Python's `random`, NumPy and TensorFlow in one call. Without this, the
# augmentation seeds drawn with `random.randint` and the TF/Keras random ops
# differ on every kernel restart (SEED was otherwise only used for the split).
tf.keras.utils.set_random_seed(SEED)

## Load data

In [None]:
metadata = pd.read_csv("../../dataset/preprocessed-datasets/preprocessed_roi_metadata.csv")

metadata["label_encoded"] = metadata.label.apply(lambda label: LABEL_MAP[label])


def _stratum_sizes(keys: pd.Series) -> pd.Series:
    """Row-aligned count of rows sharing each key within the same ``dataset`` split."""
    return keys.groupby([metadata["dataset"], keys]).transform("size")


# Per-row stratification key "<label>_<shape>_<margin>". Series.astype(str) is
# element-wise; str(Series) would stringify the whole column into ONE string and
# give every row the same key, silently disabling stratification.
label_key = metadata["label_encoded"].astype(str)
stratifier = (
    label_key
    + "_" + metadata["shape"].astype(str)
    + "_" + metadata["margin"].astype(str)
)
# train_test_split(stratify=...) rejects strata with fewer than 2 rows. Rare
# combinations fall back to a label-only key, and anything still unique falls
# back to a shared "rare" key. Sizes are counted per split because only the
# train rows are stratified. Raise min_stratum_size if train_test_split
# complains that the test size is smaller than the number of classes.
min_stratum_size = 2
stratifier = stratifier.where(_stratum_sizes(stratifier) >= min_stratum_size, label_key)
stratifier = stratifier.where(_stratum_sizes(stratifier) >= min_stratum_size, "rare")
metadata["stratyfier"] = stratifier

metadata.label_encoded.value_counts().plot.bar()
metadata.head()

## Split dataset

In [None]:
# Hold out 20% of the rows marked dataset == 'train' for validation, stratified
# on the "stratyfier" key; rows marked dataset == 'test' are kept for the end.
labelled_train_df = metadata.query("dataset == 'train'")
train_df, val_df = train_test_split(
    labelled_train_df,
    test_size=0.2,
    random_state=SEED,
    stratify=labelled_train_df["stratyfier"],
)
test_df = metadata.query("dataset == 'test'")

print(len(train_df), len(val_df), len(test_df))

## Dataset generator

In [None]:
tf.get_logger().setLevel('ERROR')

def show_batch(dataset, keys=("original", "margin", "shape")):
    """Plot the first batch of a multi-input dataset, one figure per image view.

    Args:
        dataset: batched ``tf.data.Dataset`` yielding ``(inputs, labels)``, where
            ``inputs`` is a dict of image batches keyed by view name.
        keys: which entries of ``inputs`` to plot (one figure each). Defaults to
            the three views used in this notebook.
    """
    for batch, labels in dataset.take(1):
        n_images = len(labels)
        # Two rows; round the column count up so an odd batch size still fits
        # (BATCH_SIZE // 2 columns ran out of subplot slots for odd sizes).
        n_cols = max(1, -(-n_images // 2))
        for key in keys:
            plt.figure(figsize=(15, 5))
            for i, (image, label) in enumerate(zip(batch[key], labels)):
                plt.subplot(2, n_cols, i + 1)
                plt.imshow(image)
                plt.title(f"{key}: {int(label)}")
                plt.axis("off")
            plt.show()

In [None]:
def apply_aug(image, seed):
    """Randomly flip, rotate, translate and zoom ``image``.

    Every random op and layer receives the same ``seed``. ``augmentation`` calls
    this once per view with one shared seed, presumably so the views of a sample
    receive matching transforms. NOTE(review): that relies on identically seeded
    ops drawing identical values; confirm for the TF/Keras version in use.

    Returns:
        The augmented image (or batch of images).
    """
    flipped = tf.image.random_flip_up_down(
        tf.image.random_flip_left_right(image, seed=seed),
        seed=seed,
    )
    rotate = tf.keras.layers.RandomRotation(factor=1.0, fill_mode="constant", seed=seed)
    rotated = rotate(flipped)
    translate = tf.keras.layers.RandomTranslation(height_factor=0.1, width_factor=0.1, fill_mode="constant", seed=seed)
    shifted = translate(rotated)
    zoom = tf.keras.layers.RandomZoom((-0.3, 0.3), fill_mode="constant", seed=seed)
    return zoom(shifted)

def augmentation(inputs, labels):
    """Augment every view of a multi-input element with one shared seed.

    ``inputs`` is updated in place (keys "original", "margin", "shape") and
    returned together with the untouched ``labels``. The seed comes from Python's
    ``random`` when this function body runs; inside ``Dataset.map`` that happens
    at trace time, so one seed is fixed per pipeline.
    """
    shared_seed = random.randint(0, 1000)
    for view in ("original", "margin", "shape"):
        inputs[view] = apply_aug(inputs[view], shared_seed)
    return inputs, labels

def augmentation_single(inputs, labels):
    """Augment a single-view element; ``labels`` pass through unchanged."""
    return apply_aug(inputs, random.randint(0, 1000)), labels

In [None]:
def load_and_preprocess_image(data, channels=3):
    """Read and decode one PNG crop, optionally masked, then resize it to IMAGE_SHAPE.

    Args:
        data: 3-element string tensor ``(image path, mask path, kind)``. The paths
            are relative to DATASET_DIR; create_data_generator builds these records.
        channels: number of channels to decode the image with.

    Returns:
        The resized image. Nearest-neighbour resizing keeps the decoded dtype (uint8).

    ``kind`` selects the view. "normal" leaves the crop untouched. Otherwise the
    single-channel mask is read and bitwise-ANDed into the image ("shape"), or its
    bitwise inverse is ANDed in ("margin"). Presumably the mask is 0/255, so
    "shape" keeps the lesion and "margin" keeps its surroundings. TODO confirm.
    """
    image_path, mask_path, kind = data[0], data[1], data[2]
    decoded = tf.image.decode_png(tf.io.read_file(DATASET_DIR + "/" + image_path), channels=channels)
    # When traced by Dataset.map, AutoGraph turns these tensor comparisons into
    # graph conditionals, so the mask file is only read for masked views.
    if kind != "normal":
        roi_mask = tf.image.decode_png(tf.io.read_file(DATASET_DIR + "/" + mask_path), channels=1)
        if kind == "shape":
            decoded = tf.bitwise.bitwise_and(decoded, roi_mask)
        elif kind == "margin":
            decoded = tf.bitwise.bitwise_and(decoded, tf.bitwise.invert(roi_mask))
    return tf.image.resize(decoded, size=IMAGE_SHAPE, method="nearest")

def load_image(data):
    """Load one ``(image path, mask path, kind)`` record as float32 scaled to [0, 1]."""
    return tf.cast(load_and_preprocess_image(data), tf.float32) / 0xff

def create_data_generator(df, type: str = "normal"):
    """Build a dataset of decoded, resized, [0, 1]-scaled images for the rows of ``df``.

    Args:
        df: metadata rows with ``cropped_img`` and ``cropped_mask_img`` paths
            (relative to DATASET_DIR).
        type: view to build: "normal", "shape" or "margin" (see
            load_and_preprocess_image). The name shadows the builtin but is kept
            for caller compatibility.

    Returns:
        An unbatched ``tf.data.Dataset`` of float32 images, in the row order of ``df``.
    """
    # Column-wise zip instead of iterrows(): same (image, mask, type) triples in
    # the same order, without materialising a Series per row.
    records = [
        (image_path, mask_path, type)
        for image_path, mask_path in zip(df["cropped_img"], df["cropped_mask_img"])
    ]
    dataset = tf.data.Dataset.from_tensor_slices(records)
    # Decode files in parallel. tf.data keeps the output order deterministic by
    # default, which matters because these datasets are later zip()-ed with labels.
    return dataset.map(load_image, num_parallel_calls=tf.data.AUTOTUNE)

# One unbatched image dataset per (view, split); "normal" is the default view.
original_train_ds, original_val_ds, original_test_ds = (
    create_data_generator(split_df) for split_df in (train_df, val_df, test_df)
)

margin_train_ds, margin_val_ds, margin_test_ds = (
    create_data_generator(split_df, "margin") for split_df in (train_df, val_df, test_df)
)

shape_train_ds, shape_val_ds, shape_test_ds = (
    create_data_generator(split_df, "shape") for split_df in (train_df, val_df, test_df)
)

In [None]:
# Encoded targets, row-aligned with the image datasets built from the same frames.
train_labels, val_labels, test_labels = (
    tf.data.Dataset.from_tensor_slices(split_df.label_encoded)
    for split_df in (train_df, val_df, test_df)
)

In [None]:
def _to_named_inputs(original, margin, shape):
    """Pack three row-aligned image tensors into the dict keyed by model input name."""
    return {"original": original, "margin": margin, "shape": shape}


def _zip_views(original_ds, margin_ds, shape_ds, labels_ds):
    """Zip the three image views with their labels into ``(inputs_dict, label)`` pairs."""
    inputs_ds = tf.data.Dataset.zip((original_ds, margin_ds, shape_ds)).map(_to_named_inputs)
    return tf.data.Dataset.zip((inputs_ds, labels_ds))


# Shuffle individual samples BEFORE batching. Shuffling after .batch() only
# permuted whole batches, so every epoch reused the same BATCH_SIZE groupings.
train_ds = (
    _zip_views(original_train_ds, margin_train_ds, shape_train_ds, train_labels)
    .shuffle(BUFFER_SIZE, seed=SEED)
    .batch(BATCH_SIZE)
    .map(augmentation)
    .prefetch(tf.data.AUTOTUNE)
)

val_ds = _zip_views(original_val_ds, margin_val_ds, shape_val_ds, val_labels).batch(BATCH_SIZE)

test_ds = _zip_views(original_test_ds, margin_test_ds, shape_test_ds, test_labels).batch(BATCH_SIZE)

show_batch(train_ds)

In [None]:
def show_confusion_matrix(cm, labels, norm_axis, title):
    """Draw a row- or column-normalised confusion-matrix heatmap.

    Args:
        cm: 2-D confusion matrix with rows = true label and columns = prediction
            (tf.math.confusion_matrix layout). A tf.Tensor, or anything
            ``np.asarray`` accepts.
        labels: tick labels, in class-index order.
        norm_axis: 1 normalises each row (per-class recall); 0 normalises each
            column (per-class precision).
        title: plot title.
    """
    counts = np.asarray(cm, dtype=float)
    # keepdims=True broadcasts correctly for either axis, so no newaxis special case.
    totals = counts.sum(axis=norm_axis, keepdims=True)
    # A class that never occurs (row) or is never predicted (column) has a zero
    # total. Show it as NaN (a blank cell) without a divide-by-zero warning.
    with np.errstate(divide="ignore", invalid="ignore"):
        normalized = counts / totals
    sns.heatmap(
        normalized,
        xticklabels=labels,
        yticklabels=labels,
        annot=True,
        fmt='.2f',
        vmin=0.0,
        vmax=1.0,
    )
    plt.xlabel('Prediction')
    plt.ylabel('Label')
    plt.title(title)
    plt.show()
    
def plot_history(history):
    """Plot accuracy and loss curves over one or more consecutive training runs.

    Args:
        history: iterable of Keras ``History`` objects. Their per-epoch values are
            concatenated in order, so a multi-stage schedule shows as one curve.
    """
    curves = {"accuracy": [], "val_accuracy": [], "loss": [], "val_loss": []}
    for run in history:
        for metric, values in curves.items():
            values.extend(run.history[metric])

    # (subplot position, metric, train label, val label, y label, title)
    panels = (
        (1, "accuracy", 'Training accuracy', 'Validation accuracy', 'Accuracy', 'Training and Validation accuracy'),
        (2, "loss", 'Training Loss', 'Validation Loss', 'Loss', 'Training and Validation Loss'),
    )
    plt.figure(figsize=(15, 8))
    for position, metric, train_label, val_label, y_label, panel_title in panels:
        plt.subplot(2, 2, position)
        plt.plot(curves[metric], label=train_label)
        plt.plot(curves["val_" + metric], label=val_label)
        plt.legend()
        plt.ylabel(y_label)
        plt.xlabel('epoch')
        plt.title(panel_title)
    plt.show()
    
def create_confusion_matrix(model, dataset, threshold=0.5):
    """Predict on ``dataset`` and build a 2x2 confusion matrix.

    Predictions and labels are collected in ONE pass over ``dataset``, so they
    stay aligned even if the dataset reshuffles on each iteration. Iterating once
    for ``predict`` and again for the labels would pair them arbitrarily.

    Args:
        model: Keras model with a single sigmoid output.
        dataset: batched ``tf.data.Dataset`` of ``(inputs, labels)``.
        threshold: probability above which a prediction counts as class 1.

    Returns:
        ``(confusion_matrix, probabilities, labels)``:
        - ``confusion_matrix`` is a 2x2 tf.Tensor (rows = true label,
          columns = predicted label).
        - ``probabilities`` are the raw model outputs as a NumPy array.
        - ``labels`` are the ground-truth labels as a NumPy array.

    Raises:
        ValueError: if ``dataset`` yields no batches.
    """
    probability_batches, label_batches = [], []
    for inputs, batch_labels in dataset:
        probability_batches.append(np.asarray(model.predict_on_batch(inputs)))
        label_batches.append(np.asarray(batch_labels))
    if not probability_batches:
        raise ValueError("create_confusion_matrix: dataset yielded no batches")

    probabilities = np.concatenate(probability_batches)
    labels = np.concatenate(label_batches)
    predicted = (probabilities > threshold).astype(labels.dtype).ravel()
    return tf.math.confusion_matrix(labels, predicted, num_classes=2), probabilities, labels

def draw_roc_curve(labels, predictions):
    """Plot the ROC curve of binary ``labels`` against scores ``predictions``.

    Args:
        labels: ground-truth 0/1 labels.
        predictions: predicted scores/probabilities; an (N, 1) array is flattened.

    Returns:
        The area under the ROC curve.
    """
    scores = np.asarray(predictions).ravel()
    fpr, tpr, _ = metrics.roc_curve(labels, scores)
    auc = metrics.auc(fpr, tpr)
    # Open a fresh figure. plt.figure(1) re-activates figure #1 if it is still
    # open (e.g. outside the inline backend) and draws over it.
    plt.figure()
    plt.plot([0, 1], [0, 1], 'k--')
    plt.plot(fpr, tpr, label='AUC = {:.3f}'.format(auc))
    plt.xlabel('False positive rate')
    plt.ylabel('True positive rate')
    plt.title('ROC curve')
    plt.legend(loc='best')
    plt.show()
    return auc

In [None]:
def create_resnet(inputs: tf.keras.Input, trainable: bool = False):
    """Run ``inputs`` through an ImageNet-pretrained ResNet50V2 and global-average-pool it.

    Args:
        inputs: symbolic Keras input of shape IMAGE_SHAPE + (3,).
        trainable: whether the backbone's weights are updated during training.

    Returns:
        The pooled feature tensor (no classification head).

    NOTE(review): images arrive scaled to [0, 1] (see load_image), whereas Keras
    documents resnet_v2.preprocess_input for these ImageNet weights. Confirm the
    mismatch is intended.
    """
    backbone = tf.keras.applications.ResNet50V2(
        include_top=False,
        weights='imagenet',
        input_shape=IMAGE_SHAPE + (3,),
    )
    backbone.trainable = trainable
    return tf.keras.layers.GlobalAveragePooling2D()(backbone(inputs))

In [None]:
def create_submodel(input_name):
    """Build and compile a single-view classifier: ResNet50V2 features -> sigmoid.

    Args:
        input_name: name of the Keras input. The model is named
            ``<input_name>_model``, which train_submodel uses in its checkpoint path.

    Returns:
        The compiled Keras model. Its summary is printed as a side effect.
    """
    image_input = tf.keras.Input(shape=IMAGE_SHAPE + (3,), name=input_name)
    pooled = create_resnet(image_input, trainable=True)
    probability = tf.keras.layers.Dense(1, activation="sigmoid")(pooled)

    submodel = tf.keras.Model(inputs=image_input, outputs=probability)
    submodel._name = input_name + "_model"

    optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
    loss = tf.keras.losses.BinaryCrossentropy(label_smoothing=0.25)
    tracked_metrics = [
        "accuracy",
        tf.keras.metrics.Precision(),
        tf.keras.metrics.Recall(),
        tf.keras.metrics.AUC(),
    ]
    submodel.compile(optimizer=optimizer, loss=loss, metrics=tracked_metrics)

    submodel.summary()
    return submodel

def train_submodel(model, train_ds, val_ds, learning_rate = 1e-4, epochs = 30):
    """Resume ``model`` from its checkpoint (if any) and fit it, keeping the best weights.

    The checkpoint path is CHECKPOINT + "_" + the model's name. When called
    several times in a row (e.g. a learning-rate schedule), each call restores the
    best weights saved so far. It only overwrites them when validation loss beats
    the restored weights' validation loss.

    Args:
        model: Keras model with a single sigmoid output.
        train_ds: batched ``(inputs, labels)`` training dataset.
        val_ds: batched ``(inputs, labels)`` validation dataset (drives checkpointing).
        learning_rate: Adam learning rate for this stage.
        epochs: number of epochs for this stage.

    Returns:
        The ``History`` object returned by ``model.fit``.
    """
    checkpoint = CHECKPOINT + "_" + model._name

    # Compile first so the restored model can be evaluated right away.
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        loss=tf.keras.losses.BinaryCrossentropy(label_smoothing=0.25),
        metrics=[
            "accuracy",
            tf.keras.metrics.Precision(),
            tf.keras.metrics.Recall(),
            tf.keras.metrics.AUC()
        ],
    )

    best_val_loss = None
    # Depending on settings, ModelCheckpoint writes a directory/file at this path
    # or a TF checkpoint prefix with a ".index" file. Check for either instead of
    # a bare except, which also hid corrupt or incompatible checkpoints.
    if tf.io.gfile.exists(checkpoint) or tf.io.gfile.exists(checkpoint + ".index"):
        model.load_weights(checkpoint)
        # Seed ModelCheckpoint with the restored weights' score. Otherwise its
        # "best" starts at +inf and the first epoch of this stage would overwrite
        # a better checkpoint from an earlier stage.
        best_val_loss = model.evaluate(val_ds, verbose=0, return_dict=True)["loss"]
        print(f"Restored {checkpoint} (val_loss={best_val_loss:.4f})")
    else:
        print(f"No weights found at {checkpoint}; training from scratch")

    callbacks = [
        tf.keras.callbacks.ModelCheckpoint(
            checkpoint,
            save_best_only=True,
            monitor='val_loss',
            mode='min',
            initial_value_threshold=best_val_loss,
        ),
    #     tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=5, start_from_epoch=5)
    ]

    return model.fit(
        train_ds,
        validation_data=val_ds,
        epochs=epochs,
        callbacks=callbacks,
    )

In [None]:
def create_single_datasets(train_ds, val_ds, test_ds):
    """Pair one image view with its labels and build train/val/test pipelines.

    Args:
        train_ds: unbatched training images for one view (e.g. ``shape_train_ds``).
        val_ds: unbatched validation images for the same view.
        test_ds: unbatched test images for the same view.
        Each must be row-aligned with the module-level ``train_labels`` /
        ``val_labels`` / ``test_labels``.

    Returns:
        ``(train, val, test)`` batched datasets of ``(image, label)``. Only the
        training set is cached, shuffled and augmented.
    """
    # Use the streams passed in. A hard-coded global stream made every submodel
    # train on the same view regardless of the arguments.
    train = (
        tf.data.Dataset.zip((train_ds, train_labels))
        # Cache the decoded images once; everything after this runs every epoch.
        .cache()
        # Shuffle samples AFTER cache (so the order changes each epoch) and
        # BEFORE batch (so batch composition changes too).
        .shuffle(BUFFER_SIZE, seed=SEED)
        .batch(BATCH_SIZE)
        .map(augmentation_single)
        .prefetch(tf.data.AUTOTUNE)
    )
    # Evaluation sets stay unshuffled: order does not affect the metrics, and a
    # fixed order keeps repeated passes aligned with the labels.
    val = tf.data.Dataset.zip((val_ds, val_labels)).batch(BATCH_SIZE)
    test = tf.data.Dataset.zip((test_ds, test_labels)).batch(BATCH_SIZE)
    return train, val, test

## Margin submodel

In [None]:
margin_model = create_submodel(input_name="margin")

In [None]:
temp_margin_train_ds, temp_margin_val_ds, temp_margin_test_ds = create_single_datasets(
    margin_train_ds, margin_val_ds, margin_test_ds
)

# Training stages as (learning rate, epochs); train_submodel reloads the
# checkpoint at the start of each stage.
lrs = ((1e-4, 30), (1e-5, 30))

margin_history = [
    train_submodel(
        margin_model,
        temp_margin_train_ds,
        temp_margin_val_ds,
        learning_rate=stage_lr,
        epochs=stage_epochs,
    )
    for stage_lr, stage_epochs in lrs
]

In [None]:
plot_history(margin_history)

# Reload the best checkpoint written during training before scoring on the test split.
margin_model.load_weights(CHECKPOINT + "_margin_model")
margin_model.evaluate(temp_margin_test_ds)

cm, predictions, labels = create_confusion_matrix(margin_model, temp_margin_test_ds)
for norm_axis, title in ((1, "Recall"), (0, "Precision")):
    show_confusion_matrix(cm, ["Benign", "Malignant"], norm_axis, title)

draw_roc_curve(labels, predictions)

## Shape submodel

In [None]:
shape_model = create_submodel(input_name="shape")

In [None]:
temp_shape_train_ds, temp_shape_val_ds, temp_shape_test_ds = create_single_datasets(
    shape_train_ds, shape_val_ds, shape_test_ds
)

# Training stages as (learning rate, epochs).
lrs = ((1e-4, 30), (1e-5, 30))

shape_history = [
    train_submodel(
        shape_model,
        temp_shape_train_ds,
        temp_shape_val_ds,
        learning_rate=stage_lr,
        epochs=stage_epochs,
    )
    for stage_lr, stage_epochs in lrs
]

In [None]:
plot_history(shape_history)

# Score the best saved shape-view weights on the test split.
shape_model.load_weights(CHECKPOINT + "_shape_model")
shape_model.evaluate(temp_shape_test_ds)

class_names = ["Benign", "Malignant"]
cm, predictions, labels = create_confusion_matrix(shape_model, temp_shape_test_ds)
show_confusion_matrix(cm, class_names, 1, "Recall")
show_confusion_matrix(cm, class_names, 0, "Precision")

draw_roc_curve(labels, predictions)

## Original submodel

In [None]:
original_model = create_submodel(input_name="original")

In [None]:
temp_original_train_ds, temp_original_val_ds, temp_original_test_ds = create_single_datasets(
    original_train_ds, original_val_ds, original_test_ds
)

# Training stages as (learning rate, epochs).
lrs = ((1e-4, 30), (1e-5, 30))

original_history = [
    train_submodel(
        original_model,
        temp_original_train_ds,
        temp_original_val_ds,
        learning_rate=stage_lr,
        epochs=stage_epochs,
    )
    for stage_lr, stage_epochs in lrs
]

In [None]:
plot_history(original_history)

# Score the best saved original-view weights on the test split.
original_model.load_weights(CHECKPOINT + "_original_model")
original_model.evaluate(temp_original_test_ds)

cm, predictions, labels = create_confusion_matrix(original_model, temp_original_test_ds)
for norm_axis, title in ((1, "Recall"), (0, "Precision")):
    show_confusion_matrix(cm, ["Benign", "Malignant"], norm_axis, title)

draw_roc_curve(labels, predictions)

In [None]:
# Restore each view's best checkpoint, then freeze it so only the new fusion head trains.
original_model.load_weights(CHECKPOINT + "_original_model")
margin_model.load_weights(CHECKPOINT + "_margin_model")
shape_model.load_weights(CHECKPOINT + "_shape_model")

for view, submodel in (("original", original_model), ("margin", margin_model), ("shape", shape_model)):
    submodel.trainable = False
    # Keras names every ResNet50V2 backbone "resnet50v2", and a functional model
    # rejects two layers with the same name, so the ensemble below would fail to
    # build. Give each nested backbone a fixed, view-specific name (fixed, so
    # re-running this cell is idempotent).
    # NOTE(review): assumes weights are restored by structure (TF checkpoint
    # format), not by layer name.
    for layer in submodel.layers:
        if isinstance(layer, tf.keras.Model):
            layer._name = view + "_backbone"

# layers[-2] of each submodel is the pooled backbone output (the input to its
# sigmoid Dense). Concatenate those features from the three views.
x = tf.keras.layers.concatenate([
    original_model.layers[-2].output,
    margin_model.layers[-2].output,
    shape_model.layers[-2].output
])
x = tf.keras.layers.Dense(512, activation="relu")(x)
x = tf.keras.layers.Dropout(0.3)(x)
x = tf.keras.layers.Dense(256, activation="relu")(x)
x = tf.keras.layers.Dropout(0.3)(x)
x = tf.keras.layers.Dense(128, activation="relu")(x)
x = tf.keras.layers.Dropout(0.3)(x)

outputs = tf.keras.layers.Dense(1, activation="sigmoid")(x)


model = tf.keras.Model(
    inputs=[
        original_model.input,
        margin_model.input,
        shape_model.input
    ],
    outputs=outputs
)

# train_submodel checkpoints to CHECKPOINT + "_" + name, while the evaluation
# and fine-tuning cells use CHECKPOINT + "_full_model". Naming the model
# "full_model" (matching the submodels' "<view>_model" pattern) keeps them on
# the same checkpoint.
model._name = "full_model"

model.summary()

In [None]:
# Train the fusion head (submodels frozen above) in stages of (learning rate, epochs).
lrs = ((1e-4, 10), (1e-5, 10))

full_history = []
for stage_lr, stage_epochs in lrs:
    full_history.append(
        train_submodel(model, train_ds, val_ds, learning_rate=stage_lr, epochs=stage_epochs)
    )

In [None]:
plot_history(full_history)

# Reload the best saved ensemble weights and score them on the test split.
model.load_weights(CHECKPOINT + "_full_model")
model.evaluate(test_ds)

cm, predictions, labels = create_confusion_matrix(model, test_ds)
for norm_axis, title in ((1, "Recall"), (0, "Precision")):
    show_confusion_matrix(cm, ["Benign", "Malignant"], norm_axis, title)

draw_roc_curve(labels, predictions)

In [None]:
# Unfreeze every view so the whole ensemble is fine-tuned end to end.
for submodel in (original_model, margin_model, shape_model):
    submodel.trainable = True

# Recompile so the changed trainable flags take effect. The very small learning
# rate limits how far the pretrained weights move.
fine_tune_metrics = [
    "accuracy",
    tf.keras.metrics.Precision(),
    tf.keras.metrics.Recall(),
    tf.keras.metrics.AUC()
]
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-7),
    loss=tf.keras.losses.BinaryCrossentropy(label_smoothing=0.25),
    metrics=fine_tune_metrics,
)

model.summary()

In [None]:
# Seed the checkpoint with the current weights' validation loss. Fine-tuning then
# only overwrites "_full_model" when it actually improves on them; a fresh
# ModelCheckpoint starts from +inf and always saves its first epoch.
current_val_loss = model.evaluate(val_ds, verbose=0, return_dict=True)["loss"]

callbacks = [
    tf.keras.callbacks.ModelCheckpoint(
        CHECKPOINT + "_full_model",
        save_best_only=True,
        monitor='val_loss',
        mode='min',
        initial_value_threshold=current_val_loss,
    ),
]

# Select checkpoints on the validation split. Using test_ds here leaked the test
# set into model selection and biased the final test metrics.
fine_history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=20,
    callbacks=callbacks,
)

In [None]:
# Score the best checkpoint kept during fine-tuning on the test split.
model.load_weights(CHECKPOINT + "_full_model")
model.evaluate(test_ds)

class_names = ["Benign", "Malignant"]
cm, predictions, labels = create_confusion_matrix(model, test_ds)
show_confusion_matrix(cm, class_names, 1, "Recall")
show_confusion_matrix(cm, class_names, 0, "Precision")

draw_roc_curve(labels, predictions)

In [None]:
model.save(filepath="./exported_model")