# Load Packages and init path

## load the packages

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tensorflow.keras import models, layers
from tensorflow.keras import optimizers
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.utils import image_dataset_from_directory

## set the work path

In [None]:
# Folder layout: everything is resolved relative to root_path.
# souce_path is where the raw data is read from; target_path is where
# the prepared scenario dataset is written.
root_path = '../'
scenario_path = 'data/cat_8_binary'

souce_path = f'{root_path}data/raw_data'
target_path = f'{root_path}{scenario_path}'
print('souce_path:' + souce_path)
print('target_path:' + target_path)

# Load Raw Data

In [None]:
# Read the ground-truth label table, preview its first rows, report its
# shape, and chart the per-column totals of columns 1..8 (positional slice).
ground_truth_csv = f'{souce_path}/ISIC_2019_Training_GroundTruth.csv'
data = pd.read_csv(ground_truth_csv)
display(data.head())
print(f'the shape of the processed data setis {data.shape}')
column_totals = data.sum()
column_totals[1:9].plot.bar();

In [None]:
# Label columns of the ground-truth table that are treated as categories.
CATEGORIES = ['MEL', 'NV', 'BCC', 'AK', 'BKL', 'DF', 'VASC', 'SCC']

categories = CATEGORIES
# For every category keep only the rows whose value in that column is positive.
data_by_cat = {name: data[data[name] > 0] for name in categories}
for cat, subset in data_by_cat.items():
    print(f'shape of data_by_cat[{cat}]: {subset.shape}')

# Helper functions

## Helper functions for image preprocessing

In [None]:
from PIL import Image

def square_image(image: Image)->Image:
    """
    Center-crop an image to a square.

    The side of the square equals the shorter of the two image dimensions;
    the crop box is centred along the longer one.

    Args:
        image (Image): image object exposing ``size`` and ``crop``.

    Returns:
        Image: whatever ``image.crop`` returns for the centred square box.
    """
    width, height = image.size
    side = min(width, height)

    # Integer offsets that centre the square inside the original frame.
    x0 = (width - side) // 2
    y0 = (height - side) // 2

    return image.crop((x0, y0, x0 + side, y0 + side))


def detect_black_coners(image: "Image.Image") -> int:
    """
    Estimate the radius of the circular field of view when the image has
    dark corners; return -1 when no dark corners are found.

    The image is converted to grayscale and four pixels, one near each
    corner, are probed at increasing diagonal offsets (1, 4, 7, ...) up to a
    quarter of the side. The scan stops at the first offset where the four
    probes together are brighter than a fixed threshold. This is a coarse
    heuristic and may need tuning for better accuracy.

    Args:
        image (Image.Image): image to inspect. Non-square images are
            centre-cropped with ``square_image`` before probing.

    Returns:
        int: ``int((side // 2 - offset) * sqrt(2) * 0.98)`` where ``offset``
        is the last probed offset, or -1 when the very first probe is
        already bright or the image is too small to probe at all.
    """
    width, height = image.size
    if width != height:
        image = square_image(image)
        # Re-read the size: the probes below must use the geometry of the
        # cropped square, not of the original frame.
        width, height = image.size
    side = width

    scope = side // 4          # largest diagonal offset to search
    step = 3                   # stride between probes, for speed
    margin = 0.98              # shrink factor applied to the estimated radius
    brightness_threshold = 300  # summed gray level of the four probes

    gray = image.convert('L')
    offset = None  # stays None when the image is too small to probe
    for offset in range(1, scope, step):
        near, far = offset, side - offset
        # getpixel takes (x, y); probe all four corners symmetrically.
        probes = ((near, near), (far, near), (near, far), (far, far))
        if sum(gray.getpixel(xy) for xy in probes) > brightness_threshold:
            break

    # offset == 1 means the outermost probe was already bright: no dark corner.
    if offset is None or offset <= 1:
        return -1

    return int((side // 2 - offset) * np.sqrt(2) * margin)

def remove_black_corners(image, radius) -> Image:
    """
    Crop an image to the centred square of side ``int(radius * sqrt(2))``.

    Args:
        image: image object exposing ``size`` and ``crop``. Only the width
            is used for both axes, so the image is presumably expected to
            be square already (verify against callers).
        radius: radius of the circular region whose dark corners should be
            cut away.

    Returns:
        image: the result of ``image.crop`` on the centred box.
    """
    width, _ = image.size

    # Side of the square inscribed in a circle of the given radius.
    inner_side = int(radius * np.sqrt(2))
    inset = (width - inner_side) // 2
    far_edge = width - inset

    return image.crop((inset, inset, far_edge, far_edge))

def process_and_save_image(source_file, target_file, argument=False):
    """
    Load one image, normalise it, and write it (plus optional extra copies).

    Pipeline: centre-crop to a square, cut away dark corners when
    ``detect_black_coners`` reports a positive radius, resize to 400x400
    with bilinear resampling, and save as ``<target_file>.jpg``.

    Args:
        source_file: path of the image to read.
        target_file: output path WITHOUT extension; ``.jpg`` is appended.
        argument: when true, also save extra variants of the processed
            image (presumably "augment"): three successive 90-degree
            rotations as ``<target_file>_0.jpg`` .. ``_2.jpg`` and a
            left-right flip of the last rotation as
            ``<target_file>_flip.jpg``.
    """
    # The context manager releases the source file handle once all
    # derived images have been written.
    with Image.open(source_file) as source_image:
        # Square_Crop
        image = square_image(source_image)

        # Detect and remove corner edges
        radius = detect_black_coners(image)
        if radius > 0:
            image = remove_black_corners(image, radius=radius)

        # resize image
        image = image.resize((400, 400), resample=Image.BILINEAR)

        # export image file
        image.save(f'{target_file}.jpg')
        print(f'file {target_file} saved...')

        if argument:
            for i in range(3):
                image = image.rotate(90)
                image.save(f'{target_file}_{i}.jpg')
                print(f'file {target_file}_{i}.jpg saved...')

            # The flipped variant used to be computed and then discarded
            # while the log claimed it was saved; write it under its own
            # name so it does not overwrite a rotated copy.
            flipped = image.transpose(Image.FLIP_LEFT_RIGHT)
            flipped.save(f'{target_file}_flip.jpg')
            print(f'file {target_file}_flip.jpg saved...')
        

## Helper function(s) for visualization

In [None]:
def plot_history(history, title='', axs=None, exp_name=""):
    """
    Draw train/validation curves for loss, accuracy, recall and f1_metric.

    A new 1x4 figure is created on every call. ``title``, ``axs`` and
    ``exp_name`` are accepted but not used.

    Args:
        history: object whose ``history`` attribute maps metric names
            ('loss', 'val_loss', 'accuracy', ...) to per-epoch values.

    Returns:
        tuple: the four axes, in the order loss, accuracy, recall, f1_metric.
    """
    _, axes = plt.subplots(1, 4, figsize=(20, 4))

    # (metric key, panel title, train label, val label, y-limits, draw 0.9 guide)
    panels = (
        ('loss', 'Loss', 'train loss', 'val loss', (0.0, 100.0), False),
        ('accuracy', 'Accuracy', 'train accuracy', 'val accuracy', (0.5, 1.0), True),
        ('recall', 'recall', 'train', 'val', (0.5, 1.0), True),
        ('f1_metric', 'f1_metric', 'train f1_metric', 'val f1_metric', (0.5, 1.0), True),
    )

    for ax, (key, heading, train_label, val_label, y_range, with_guide) in zip(axes, panels):
        ax.plot(history.history[key], label=train_label)
        ax.plot(history.history[f'val_{key}'], label=val_label)
        ax.set_title(heading)
        if with_guide:
            # Reference line at 0.9 on the score panels.
            ax.axhline(y=0.9, color='green', linestyle='--')
        ax.set_ylim(*y_range)
        ax.legend()

    return tuple(axes)

# Data Preprocessing 

In [None]:
def create_folder(folder_path):
    """
    Make ``folder_path`` an empty, freshly created directory.

    Anything already present at that path is removed recursively first,
    then the directory (and any missing parents) is created.
    """
    import os
    import shutil

    # Start from a clean slate: drop the previous folder and its contents.
    already_present = os.path.exists(folder_path)
    if already_present:
        shutil.rmtree(folder_path)

    os.makedirs(folder_path)

In [None]:
##########################################################################
# Call this function only when you need to re-prepare all the data
##########################################################################

from sklearn.model_selection import train_test_split

def _export_split(rows, cat, split, binary, augment):
    """Process and save every image listed in ``rows`` into the ``split`` folder."""
    for image_name in rows['image']:
        source_file = f"{souce_path}/{cat}/{image_name}.jpg"
        if binary:
            target_file = f"{target_path}/{split}/{image_name}"
        else:
            target_file = f"{target_path}/{split}/{cat}/{image_name}"
        process_and_save_image(source_file, target_file, augment)


def data_prep(
    categories=CATEGORIES,
    binary=True,
    n_max=1000,
    test_size=0.2,
    argument_if_needed=False,
    random_state=42,
):
    """
    Rebuild the train/val/test image folders under ``target_path``.

    WARNING: ``target_path`` is deleted and recreated on every call.

    For each category the rows in ``data_by_cat`` are capped at ``n_max``
    (random sample), split into train/test and then train/val using the
    same ``test_size`` fraction, and every image is run through
    ``process_and_save_image``.

    Args:
        categories: category names to export; each must be a key of
            ``data_by_cat`` and a sub-folder of the raw-data folder.
        binary: when true, images of all categories go straight into
            ``train/``, ``val/`` and ``test/``; otherwise each split gets
            one sub-folder per category.
        n_max: maximum number of rows kept per category.
        test_size: fraction passed to both ``train_test_split`` calls.
        argument_if_needed: when true, categories with fewer than
            ``n_max // 4`` rows get the extra image variants for their
            train and val images (never for test images).
        random_state: seed used for the sampling and for both splits so
            that repeated runs produce the same partition. The splits used
            to be unseeded, which made the exported dataset change on
            every run even though the sampling was seeded.
    """
    # Create Folders
    create_folder(target_path)
    for split in ('train', 'val', 'test'):
        if binary:
            create_folder(f'{target_path}/{split}')
        else:
            for cat in categories:
                create_folder(f'{target_path}/{split}/{cat}')

    separator = '======================================================'
    for cat in categories:
        data_all = data_by_cat[cat]
        # Decide on the extra variants from the full category size, before capping.
        argument_needed = argument_if_needed and data_all.shape[0] < n_max // 4
        if data_all.shape[0] > n_max:
            data_all = data_all.sample(n_max, random_state=random_state)

        data_train, data_test = train_test_split(
            data_all, test_size=test_size, random_state=random_state)
        data_train, data_val = train_test_split(
            data_train, test_size=test_size, random_state=random_state)

        _export_split(data_train, cat, 'train', binary, argument_needed)
        print(separator)
        _export_split(data_val, cat, 'val', binary, argument_needed)
        print(separator)
        # Test images are always exported without extra variants.
        _export_split(data_test, cat, 'test', binary, False)

# Rebuild the dataset with one folder per category, at most 5000 images
# per category, a 15% hold-out fraction, and extra variants for small
# categories.
data_prep(
    categories=['MEL', 'NV', 'BCC', 'AK', 'BKL', 'DF', 'VASC', 'SCC'],
    binary=False,
    n_max=5000,
    test_size=0.15,
    argument_if_needed=True,
)

# Modeling

## preprocess input data

In [None]:
def preprocess_input_data(target_path, preprocess_input_method):
    """
    Build the train, validation and test datasets from ``target_path``.

    Each split is read from ``<target_path>/train``, ``/val`` and ``/test``
    with labels inferred from the folder structure (categorical mode),
    images sized 400x400 and batches of 32. ``preprocess_input_method`` is
    mapped over the image tensor of every batch; labels pass through
    unchanged. The train and val loaders use seed 123; the test loader is
    created without a seed.

    Returns:
        tuple: ``(train_data, val_data, test_data)``.
    """
    def _load_split(split_name, **loader_options):
        dataset = image_dataset_from_directory(
            target_path + '/' + split_name,
            labels='inferred',
            label_mode='categorical',
            image_size=(400, 400),
            batch_size=32,
            **loader_options)
        return dataset.map(lambda x, y: (preprocess_input_method(x), y))

    train_data = _load_split('train', seed=123)
    val_data = _load_split('val', seed=123)
    test_data = _load_split('test')
    return train_data, val_data, test_data

## Build model

In [None]:
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, EarlyStopping
from tensorflow.keras import regularizers

import keras.backend as K

def f1_metric(y_true, y_pred):
    """
    F1 score computed over all elements of the given tensors.

    Values are clipped to [0, 1] and rounded before counting, so each
    element counts as 0 or 1. ``K.epsilon()`` is added to every denominator
    to avoid division by zero.
    """
    def _rounded_count(values):
        return K.sum(K.round(K.clip(values, 0, 1)))

    eps = K.epsilon()
    hits = _rounded_count(y_true * y_pred)   # true positives
    actual = _rounded_count(y_true)          # possible positives
    flagged = _rounded_count(y_pred)         # predicted positives

    precision = hits / (flagged + eps)
    recall = hits / (actual + eps)
    return 2*(precision*recall)/(precision+recall+eps)

def load_model(pre_trained_model):
    """
    Instantiate ``pre_trained_model`` as a headless ImageNet backbone.

    The constructor is called with ``weights='imagenet'``,
    ``include_top=False`` and ``input_shape=(400, 400, 3)``.
    """
    backbone_options = {
        'weights': 'imagenet',
        'include_top': False,
        'input_shape': (400, 400, 3),
    }
    return pre_trained_model(**backbone_options)

def set_nontrainable_layers(model, n_layers=50):
    """
    Freeze ``model`` by setting ``model.trainable`` to False and return it.

    ``n_layers`` is accepted but not used: the whole model is frozen.
    """
    frozen = model
    frozen.trainable = False
    return frozen

def add_last_layers(model, n_class=None):
    '''Take a pre-trained model, set its parameters as non-trainable, and add additional trainable layers on top.

    Args:
        model: pre-trained backbone; it is frozen via ``set_nontrainable_layers``.
        n_class: number of output units. The output activation is softmax
            when ``n_class > 2`` and sigmoid otherwise. When omitted, the
            notebook-level global ``n_class`` is used, which is what this
            function always did before the parameter existed.

    Returns:
        A ``Sequential`` model: input (400, 400, 3) -> random zoom ->
        frozen backbone -> flatten -> Dense(128, relu, L1/L2) -> output layer.
    '''
    if n_class is None:
        # Legacy single-argument call: keep reading the module-level value.
        try:
            n_class = globals()['n_class']
        except KeyError:
            raise NameError("name 'n_class' is not defined") from None

    head_activation = 'softmax' if n_class > 2 else 'sigmoid'
    model = models.Sequential(
        [
            layers.Input(shape=(400, 400, 3)),
            layers.RandomZoom(0.25),
            set_nontrainable_layers(model),
            layers.Flatten(),
            layers.Dense(128, activation='relu', kernel_regularizer=regularizers.l1_l2(l1=0.001, l2=0.001)),
#             layers.Dropout(0.125),
            layers.Dense(n_class, activation=head_activation)
        ]
    )
    return model

def build_model(pre_trained_model, n_class=2):
    """
    Build and compile the transfer-learning classifier.

    Args:
        pre_trained_model: backbone constructor handed to ``load_model``.
        n_class: number of output classes. It now drives BOTH the output
            layer and the loss; previously the output layer silently used
            the global ``n_class`` instead, so the two could disagree.

    Returns:
        The compiled model (Adam, learning rate 1e-6; categorical
        cross-entropy when ``n_class > 2``, binary cross-entropy otherwise;
        metrics accuracy, Recall and ``f1_metric``).
    """
    model = load_model(pre_trained_model)
    model = add_last_layers(model, n_class=n_class)

    opt = optimizers.legacy.Adam(learning_rate=1e-6)
    loss = 'categorical_crossentropy' if n_class > 2 else 'binary_crossentropy'
    model.compile(loss=loss,
                  optimizer=opt,
                  metrics=['accuracy', 'Recall', f1_metric])
    return model


## Instance model and input data

In [None]:
from tensorflow.keras.applications import ResNet50, resnet50
import os

# Backbone and its matching input preprocessing.
pre_processing_method = resnet50.preprocess_input
pre_trained_model = ResNet50

##############################
path='../data/cat_7_binary'
n_class=2
##############################

train_data, val_data, test_data = preprocess_input_data(
    target_path = path, 
    preprocess_input_method = pre_processing_method
)
model = build_model(pre_trained_model, n_class)

# Width of the one-hot label tensor and number of training batches.
num_classes = train_data.element_spec[1].shape[1]
num_batchs = len(train_data)

# The scenario name is later used as a file name for the checkpoint, so use
# only the last folder of ``path``; embedding the whole relative path
# ('../data/...') put '..' and extra directory levels into that file name.
scenario_name = f'{pre_trained_model.__name__}_{os.path.basename(os.path.normpath(path))}'
print(scenario_name)
model.summary()

## Train the model

In [None]:
import timeit
# Timer reference taken before the callbacks are built and training starts;
# it is subtracted from a second reading after fit() to get the elapsed time.
start_time = timeit.default_timer()

class MyModelCheckpoint(ModelCheckpoint):
    """
    Checkpoint callback that writes the monitored value into the file name.

    At the end of an epoch the model is saved whenever the monitored metric
    is strictly greater than the best value seen so far (higher is always
    treated as better). The parent class's ``on_epoch_end`` is not called,
    so only the files written here are produced.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Best monitored value seen so far; None until the first epoch ends.
        self.best_monitored_value = None

    def on_epoch_end(self, epoch, logs=None):
        """Save the model as ``<stem>_<monitor>_<value>.h5`` on improvement."""
        import os

        # ``logs`` may be None or lack the monitored metric (e.g. no
        # validation data); skip the epoch instead of crashing training.
        logs = logs or {}
        monitored_value = logs.get(self.monitor)
        if monitored_value is None:
            return

        # Check if the current value is better than the best value so far
        if self.best_monitored_value is None or monitored_value > self.best_monitored_value:
            self.best_monitored_value = monitored_value

            # Modify the filename by appending the best monitored metric value.
            # splitext drops whatever extension the template has rather than
            # assuming it is exactly three characters long.
            filepath = self.filepath.format(epoch=epoch, **logs)
            stem, _ = os.path.splitext(filepath)
            filepath = f'{stem}_{self.monitor}_{self.best_monitored_value:.4f}.h5'

            # Save the model with the updated filepath
            self.model.save(filepath, overwrite=True)
        
# Keep the full model with the best validation recall.
mcp = ModelCheckpoint(
    f"../models/{scenario_name}.h5",
    save_weights_only=False,
    monitor='val_recall',
    mode='max',
    verbose=0,
    save_best_only=True
)
# Cut the learning rate by 10x after 2 epochs without val_loss improvement.
lr = ReduceLROnPlateau(
    monitor="val_loss",
    factor=0.1,
    patience=2,
    verbose=1,
    min_lr=0
)
# Stop after 10 epochs without val_accuracy improvement and roll back to
# the best weights.
es = EarlyStopping(
    monitor = 'val_accuracy', 
    mode = 'max', 
    patience = 10, 
    verbose = 1, 
    restore_best_weights = True
)
num_epochs = 100

# train_data is already batched by image_dataset_from_directory, so no
# batch_size is passed here: Keras documents that batch_size must not be
# specified for dataset inputs.
history = model.fit(
    train_data,
    epochs=num_epochs,
    callbacks=[es, lr, mcp],
    validation_data=val_data,
    verbose=1
)

end_time = timeit.default_timer()
execution_time = end_time - start_time

In [None]:
# scenario_name already identifies the dataset, so the path is not appended
# a second time.
print(f'scenario: {scenario_name}, training time(min):{execution_time//60}')
# Hand-written record of the settings used for this run; kept in sync with
# add_last_layers / build_model (the zoom factor and the regulariser had
# drifted from the code).
print("""
layers.RandomZoom(0.25)
dense_layer = layers.Dense(128, activation='relu', kernel_regularizer=regularizers.l1_l2(l1=0.001, l2=0.001))
opt = optimizers.legacy.Adam(learning_rate=1e-6)
""")
model.evaluate(test_data)
plot_history(history);