# **TRAINING**

### Initial configurations

In [None]:
# Import necessary libraries
import os
import warnings
import logging
import random
import numpy as np
import tensorflow as tf
from tensorflow import keras as tfk
from tensorflow.keras import layers as tfkl
from tensorflow.keras import Model as tfkModel
import matplotlib.pyplot as plt
import pandas as pd
from keras.utils import register_keras_serializable
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, confusion_matrix, precision_score, recall_score, f1_score
import seaborn as sns
from keras.callbacks import Callback
import IPython.display as display
from PIL import Image
import matplotlib.gridspec as gridspec
import json
import keras_cv
from tqdm import tqdm

In [None]:
# Set the global Keras dtype policy to "mixed_bfloat16" for every layer created after this point
tfk.mixed_precision.set_global_policy("mixed_bfloat16")

In [None]:
# Configure plot display settings (seaborn theme first, then the matplotlib default font size)
sns.set(font_scale=1.4)
sns.set_style('white')
plt.rc('font', size=14)
%matplotlib inline

### Set accelerator

In [None]:
def auto_select_accelerator():
    """Pick a ``tf.distribute`` strategy for the hardware that is available.

    The function tries to resolve, connect to and initialise a TPU cluster and
    wraps it in a ``TPUStrategy``. If any of those steps raises ``ValueError``
    it falls back to the strategy returned by ``tf.distribute.get_strategy()``.
    The number of replicas of the chosen strategy is printed in both cases.

    Returns:
        The selected distribution strategy.

    Reference:
        * https://www.kaggle.com/mgornergoogle/getting-started-with-100-flowers-on-tpu
        * https://www.kaggle.com/xhlulu/ranzcr-efficientnet-tpu-training
    """
    try:
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)
        selected = tf.distribute.TPUStrategy(resolver)
        print("Running on TPU:", resolver.master())
    except ValueError:
        # TPU setup failed: keep whatever strategy is currently active
        selected = tf.distribute.get_strategy()

    replica_count = selected.num_replicas_in_sync
    print(f"Running on {replica_count} replicas")
    return selected

In [None]:
# Setting the correct strategy for TPU / batch sizes
strategy = auto_select_accelerator()
# Count the visible accelerators (physical GPUs, logical TPU devices) for the batch-size logic below
numGPU = len(tf.config.list_physical_devices('GPU'))
numTPU = len(tf.config.list_logical_devices('TPU'))
print("Num GPUs Available: ", numGPU)
print("Num TPUs Available: ", numTPU)

In [None]:
# 32 samples per step by default; on TPU the batch is multiplied by the replica count
batch_size = strategy.num_replicas_in_sync * 32 if numTPU != 0 else 32

print(f"Batch size: {batch_size}")

## **DATA PREPROCESSING**

In [None]:
# Locations of the training and validation .npz archives
train_path = "/kaggle/input/blood-cells-augmented/8_nocleanval_balanced_heavy_full_and_augMix_training_data.npz"
val_path = "/kaggle/input/blood-cells-augmented/8_noclean_balanced_heavy_full_and_augMix_validation_data.npz"
# Only the paths are echoed here; the archives themselves are loaded in the next cell
print(f"reading {train_path}")
print(f"reading {val_path}")

In [None]:
# Load the image / label arrays of both splits.
# The archives are opened as context managers so the underlying file handles
# are closed as soon as the arrays have been read into memory.
# NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can
# execute arbitrary code when the file is untrusted. Keep it only if the
# archives really store pickled objects and come from a trusted source.
with np.load(train_path, allow_pickle=True) as data_train:
    X_train = data_train['images']
    y_train = data_train['labels']

with np.load(val_path, allow_pickle=True) as data_val:
    X_val = data_val['images']
    y_val = data_val['labels']

In [None]:
# Plot 10 random images from X_train on a 2 x 5 grid
plt.figure(figsize=(15, 10))
for i in range(10):
    plt.subplot(2, 5, i + 1)
    # Indices are drawn independently, so the same image may appear more than once
    random_idx = np.random.randint(0, X_train.shape[0])
    plt.imshow(X_train[random_idx])
    # The title shows the argmax of the label vector
    plt.title(f"Label: {np.argmax(y_train[random_idx])}")
    plt.axis('off')
plt.show()

In [None]:
# Plot 10 random images from X_val on a 2 x 5 grid
plt.figure(figsize=(15, 10))
for i in range(10):
    plt.subplot(2, 5, i + 1)
    # Indices are drawn independently, so the same image may appear more than once
    random_idx = np.random.randint(0, X_val.shape[0])
    plt.imshow(X_val[random_idx])
    # The title shows the argmax of the label vector
    plt.title(f"Label: {np.argmax(y_val[random_idx])}")
    plt.axis('off')
plt.show()

In [None]:
# Report the array shapes of both splits, one line per array
for caption, array in (
    ("Training Data Shape:", X_train),
    ("Training Label Shape:", y_train),
    ("Validation Data Shape:", X_val),
    ("Validation Label Shape:", y_val),
):
    print(caption, array.shape)

In [None]:
# Training pipeline: cache -> shuffle -> batch -> repeat indefinitely -> prefetch
train_dataset = (
    tf.data.Dataset.from_tensor_slices((X_train, y_train))
    .cache()
    .shuffle(65536)
    .batch(batch_size)
    .repeat()
    .prefetch(tf.data.AUTOTUNE)
)

# Validation pipeline: same stages with a smaller shuffle buffer and no repeat
val_dataset = (
    tf.data.Dataset.from_tensor_slices((X_val, y_val))
    .cache()
    .shuffle(4096)
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)

### Custom callbacks

In [None]:
# Custom implementation of ReduceLROnPlateau
class CustomReduceLROnPlateau(tf.keras.callbacks.Callback):
    """Scale down the optimizer learning rate when a monitored metric plateaus.

    After ``patience`` consecutive epochs without improvement the learning rate
    is multiplied by ``factor`` (never going below ``min_lr``) and the patience
    counter restarts.

    Args:
        monitor: Key of the metric to read from the epoch logs.
        factor: Multiplier applied to the learning rate on each reduction.
        patience: Number of non-improving epochs tolerated before reducing.
        min_lr: Lower bound for the learning rate.
        verbose: If greater than 0, print a message on every reduction.
        mode: ``'max'`` (higher is better), ``'min'`` (lower is better) or
            ``'auto'``. In ``'auto'`` mode a metric whose name contains
            ``"loss"`` is minimised and every other metric is maximised.

    Raises:
        ValueError: If ``mode`` is not one of ``'auto'``, ``'min'``, ``'max'``.
    """

    def __init__(self, monitor='val_accuracy', factor=0.33, patience=20, min_lr=1e-8, verbose=1,
                 mode='auto'):
        super().__init__()
        if mode not in ('auto', 'min', 'max'):
            raise ValueError(f"mode must be 'auto', 'min' or 'max', got {mode!r}")
        self.monitor = monitor
        self.factor = factor
        self.patience = patience
        self.min_lr = min_lr
        self.verbose = verbose
        # Resolve the comparison direction once; losses must go down, other metrics up
        if mode == 'auto':
            mode = 'min' if 'loss' in monitor else 'max'
        self.mode = mode
        self.wait = 0
        self.best = None
        self.new_lr = None

    def on_train_begin(self, logs=None):
        # The same instance may be passed to several fit() calls; start each run
        # from a clean state so a stale `best` from a previous run is not reused.
        self.wait = 0
        self.best = None

    def _is_improvement(self, current):
        """Return True when ``current`` beats ``self.best`` in the configured direction."""
        if self.mode == 'min':
            return current < self.best
        return current > self.best

    def on_epoch_end(self, epoch, logs=None):
        current = (logs or {}).get(self.monitor)
        if current is None:
            # Without the metric there is nothing to compare; skip instead of crashing
            warnings.warn(
                f"CustomReduceLROnPlateau: metric '{self.monitor}' is not available in the logs; "
                "skipping learning-rate check for this epoch."
            )
            return

        # Initialize best metric if it's the first epoch
        if self.best is None:
            self.best = current
            return

        # Check if the monitored metric has improved
        if self._is_improvement(current):
            self.best = current
            self.wait = 0
            return

        self.wait += 1
        if self.wait < self.patience:
            return

        # Patience is exhausted: reduce the learning rate unless it is already at the floor
        old_lr = float(tf.keras.backend.get_value(self.model.optimizer.learning_rate))
        if old_lr <= self.min_lr:
            return
        self.new_lr = max(old_lr * self.factor, self.min_lr)
        self.model.optimizer.learning_rate.assign(self.new_lr)

        if self.verbose > 0:
            print(f"\nEpoch {epoch + 1}: reducing learning rate to {self.new_lr}.")

        self.wait = 0  # Reset patience counter

In [None]:
# Custom callback class for real-time plotting
class RealTimePlot(Callback):
    """Redraw the loss and accuracy curves in the notebook after every epoch.

    Reads ``loss``, ``val_loss``, ``categorical_accuracy`` and
    ``val_categorical_accuracy`` from the epoch logs.
    """

    def on_train_begin(self, logs=None):
        # Fresh metric histories for this training run
        self.epochs = []
        self.train_loss, self.val_loss = [], []
        self.train_acc, self.val_acc = [], []

        # One figure with two side-by-side axes: loss on the left, accuracy on the right
        self.fig, axes = plt.subplots(1, 2, figsize=(14, 5))
        self.ax_loss, self.ax_acc = axes
        plt.show()

    def on_epoch_end(self, epoch, logs=None):
        # Record this epoch's values
        self.epochs.append(epoch)
        for series, key in (
            (self.train_loss, 'loss'),
            (self.val_loss, 'val_loss'),
            (self.train_acc, 'categorical_accuracy'),
            (self.val_acc, 'val_categorical_accuracy'),
        ):
            series.append(logs[key])

        # Remove the previously displayed figure from the cell output
        display.clear_output(wait=True)

        # Redraw both panels from scratch: loss first, then accuracy
        panels = (
            (self.ax_loss, self.train_loss, self.val_loss, 'Loss'),
            (self.ax_acc, self.train_acc, self.val_acc, 'Accuracy'),
        )
        for axis, train_values, val_values, quantity in panels:
            axis.clear()
            axis.plot(self.epochs, train_values, label=f'Training {quantity}')
            axis.plot(self.epochs, val_values, label=f'Validation {quantity}')
            axis.set_title(f'Training and Validation {quantity}')
            axis.set_xlabel('Epoch')
            axis.set_ylabel(quantity)
            axis.legend()

        # Show the refreshed figure
        display.display(self.fig)
        plt.pause(0.1)

In [None]:
class DisplayLearningRateCallback(tf.keras.callbacks.Callback):
    """Print the optimizer's current learning rate at the end of every epoch."""

    def on_epoch_end(self, epoch, logs=None):
        # Read the learning rate from the optimizer as a plain Python float
        lr = float(tf.keras.backend.get_value(self.model.optimizer.learning_rate))
        # `lr` is already a float, so it is formatted directly (no second get_value call)
        print(f"Epoch {epoch+1} : Learning rate = {lr}")

### Model definition

In [None]:
# Input shape for the model (per-sample shape: everything after the first axis of X_train)
input_shape = X_train.shape[1:]

# Output shape for the model (size of the second axis of y_train)
output_shape = y_train.shape[1]

# Number of full batches per epoch; needed because the training dataset repeats indefinitely
steps_per_epoch = y_train.shape[0] // batch_size

print("Input Shape: ", input_shape)
print("Output Shape: ", output_shape)
print("Steps per epoch: ", steps_per_epoch)

In [None]:
@register_keras_serializable()
class CustomCastLayer(tfk.layers.Layer):
    """Multiply the inputs by 255 and cast the result to ``uint8``."""

    def call(self, inputs):
        scaled = inputs * 255
        return tf.cast(scaled, dtype=tf.uint8)

@register_keras_serializable()
class CustomAugmentLayer(tfk.layers.Layer):
    """Random horizontal and vertical flips, applied only when ``training`` is true.

    Args:
        max_rotation: Maximum rotation in degrees. Stored on the layer as a
            fraction of a full turn (``max_rotation / 360``); ``call`` does not
            currently use it.
        max_zoom: Maximum zoom factor. Stored on the layer; ``call`` does not
            currently use it.
        **kwargs: Forwarded to the base ``Layer`` constructor.
    """

    def __init__(self, max_rotation=30.0, max_zoom=0.2, **kwargs):
        super().__init__(**kwargs)
        # Keep the constructor value so get_config() can return it unchanged
        self._max_rotation_degrees = max_rotation
        self.max_rotation = max_rotation / 360.0
        self.max_zoom = max_zoom

    def call(self, inputs, training=False):
        if training:
            inputs = tf.image.random_flip_up_down(tf.image.random_flip_left_right(inputs))
        return inputs

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created from its config."""
        config = super().get_config()
        config.update({
            "max_rotation": self._max_rotation_degrees,
            "max_zoom": self.max_zoom,
        })
        return config


In [None]:
def build_model(
            shape=input_shape,
            n_labels=output_shape,
            base_model_trainable=False, #standard definitions
            n_dense_layers=1,
            initial_dense_neurons=1024,
            min_neurons=64, # architecture definitions
            include_dropout=True,
            dropout_rate=0.3,
            l2_lambda=4e-3, # against overfitting
            learning_rate=1e-3,
            mult_next_layer = 1/2,
            include_batch_normalization = True):
    """Build and compile a ConvNeXt-XLarge classifier with a dense head.

    Pipeline: input -> cast to uint8 -> random flips (training only) ->
    ConvNeXt-XLarge backbone (ImageNet weights, no top) -> optional batch norm ->
    global average pooling -> ``n_dense_layers`` x (Dense -> optional batch norm ->
    SiLU -> optional dropout) -> softmax output.

    Args:
        shape: Per-sample input shape, used for the input layer and the backbone.
        n_labels: Number of units of the softmax output layer.
        base_model_trainable: Value assigned to the backbone's ``trainable`` flag.
        n_dense_layers: Number of hidden dense layers in the head.
        initial_dense_neurons: Units of the first hidden dense layer.
        min_neurons: Currently unused.
        include_dropout: Whether to add a Dropout layer after each hidden layer.
        dropout_rate: Rate of those Dropout layers.
        l2_lambda: Used both as the L2 kernel-regularisation factor of the hidden
            dense layers and as the AdamW ``weight_decay``.
        learning_rate: AdamW learning rate.
        mult_next_layer: Factor applied to the unit count from one hidden layer
            to the next (the result is truncated to an int).
        include_batch_normalization: Whether to add the BatchNormalization layers.

    Returns:
        The compiled Keras model (categorical focal cross-entropy loss, AdamW
        optimizer, categorical accuracy metric).
    """
    # The input layer (honours the `shape` argument instead of the module-level global)
    inputs = tfkl.Input(shape=shape, name='Input')

    # The two preprocessing / augmentation layers.
    # The augmentation layer is called without a hard-coded `training` flag so
    # Keras supplies it: flips run during fit() but not during validation or predict().
    x = CustomCastLayer()(inputs)
    x = CustomAugmentLayer()(x)

    # The ConvNeXt backbone with include_top=False to take the convolutional part only
    base_model = tfk.applications.ConvNeXtXLarge(
                input_shape=shape,
                weights='imagenet',
                include_top=False
            )

    # Freeze (or not) the backbone; frozen is the Transfer Learning setting
    base_model.trainable = base_model_trainable

    x = base_model(x)
    if include_batch_normalization:
        x = tfkl.BatchNormalization(name="BatchNorm_After_ConvNeXt")(x)  # BatchNorm after ConvNeXt
    x = tfkl.GlobalAveragePooling2D()(x)

    # Hidden layers building
    # NOTE(review): `min_neurons` is accepted but never applied as a lower bound here —
    # confirm whether the unit count was meant to be clamped.
    neurons = initial_dense_neurons
    for k in range(n_dense_layers):
        x = tfkl.Dense(units=neurons, activation=None, name=f'Dense_layer_{k}',
                       kernel_regularizer=tfk.regularizers.L2(l2_lambda))(x)
        if include_batch_normalization:
            x = tfkl.BatchNormalization(name=f'BatchNorm_Dense_layer_{k}')(x)  # BatchNorm in dense layer
        x = tfkl.Activation('silu', name=f'Activation_layer_{k}')(x)  # Apply activation after BatchNorm
        if include_dropout:
            x = tfkl.Dropout(dropout_rate, name=f'Dropout_layer_{k}')(x)
        neurons = int(neurons * mult_next_layer)

    # Output layer (honours the `n_labels` argument instead of the module-level global)
    outputs = tfkl.Dense(n_labels, activation='softmax', name='output_layer')(x)

    # Final model building
    model = tfk.Model(inputs=inputs, outputs=outputs, name='TF-CNN')

    # Loss definition
    loss = tfk.losses.CategoricalFocalCrossentropy(
                                                alpha=0.25,
                                                gamma=2.0,
                                                from_logits=False,
                                                label_smoothing=0.0,
                                                axis=-1,
                                                reduction="sum_over_batch_size",
                                                name="categorical_focal_crossentropy",
                                                dtype=None,
                                            )
    # Metrics definition
    METRICS = [tfk.metrics.CategoricalAccuracy()]
    optimizer = tf.keras.optimizers.AdamW(
                                learning_rate=learning_rate,
                                weight_decay=l2_lambda,
                                beta_1=0.9,
                                beta_2=0.999,
                                epsilon=1e-07,
                                amsgrad=False,
                                use_ema=False,
                                ema_momentum=0.99,
                                name="adamw"
                            )

    # Compile the model
    model.compile(loss=loss, optimizer=optimizer, metrics=METRICS)

    # Return the model
    return model


## **TRANSFER LEARNING**

In [None]:
# Best values found so far (passed to build_model below)
n_dense_layers = 5
initial_dense_neurons = 1943
dropout_rate = 0.4
l2_lambda = 5e-5
learning_rate = 1.59608e-5
mult_next_layer = 0.44626
include_batch_normalization = True

# Upper bound on the number of epochs passed to fit()
epochs = 500

In [None]:
# Build the model inside the distribution strategy scope, with the backbone frozen
with strategy.scope():
    model = build_model(
            base_model_trainable=False,
            n_dense_layers=n_dense_layers,
            initial_dense_neurons=initial_dense_neurons,
            include_dropout=True,
            dropout_rate=dropout_rate,
            l2_lambda=l2_lambda,
            learning_rate=learning_rate,
            mult_next_layer=mult_next_layer,
            include_batch_normalization = include_batch_normalization
        )

# Display a summary of the model architecture (backbone collapsed, trainable column shown)
model.summary(expand_nested=False, show_trainable=True)

In [None]:
# Define the patience value for early stopping
patience = 50

# Create an EarlyStopping callback: stop when val_loss has not decreased for
# `patience` epochs and restore the weights of the best epoch
early_stopping = tfk.callbacks.EarlyStopping(
    monitor='val_loss',
    mode='min',
    patience=patience,
    restore_best_weights=True
)

# Learning-rate reduction also watches val_loss; its patience (30) is shorter than the
# early-stopping patience (50)
lr_reducer = CustomReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=30, min_lr=1e-8)
plot_callback = RealTimePlot()

# Store the callbacks in a list
callbacks = [early_stopping, plot_callback, lr_reducer, DisplayLearningRateCallback()]

In [None]:
# Train the model with the callbacks defined above; only the metrics dict (`.history`) is kept.
# steps_per_epoch is required because train_dataset repeats indefinitely.
# NOTE(review): per the Keras docs `shuffle` is ignored for tf.data inputs — confirm it can be dropped.
history = model.fit(
    train_dataset,
    epochs=epochs,
    steps_per_epoch=steps_per_epoch,
    validation_data=val_dataset,
    shuffle=True,
    callbacks=callbacks
).history

In [None]:
# Plot training and validation loss (y-axis clipped to [0, 2.5])
plt.figure(figsize=(15, 5))
plt.plot(history['loss'], label='Training loss', alpha=.8)
plt.plot(history['val_loss'], label='Validation loss', alpha=.8)
plt.ylim(top=2.5, bottom=0.0)
plt.title('Loss')
plt.legend()
plt.grid(alpha=.3)

# Plot training and validation accuracy in a second figure
plt.figure(figsize=(15, 5))
plt.plot(history['categorical_accuracy'], label='Training accuracy', alpha=.8)
plt.plot(history['val_categorical_accuracy'], label='Validation accuracy', alpha=.8)
plt.title('Accuracy')
plt.legend()
plt.grid(alpha=.3)
plt.show()

In [None]:
# Save the trained model's weights to a fixed filename
with strategy.scope():
    model_weights_filename = 'ADAM_HEAVY_AUG_MODEL.weights.h5'
    model.save_weights(model_weights_filename)

In [None]:
# Class labels passed to evaluations() below
LABELS = [0, 1, 2, 3, 4, 5, 6, 7]

In [None]:
def evaluations(model, ds, y_ds, labels, name):
    """Print classification metrics and plot the confusion matrix for one data split.

    Args:
        model: Object whose ``predict(ds, verbose=0)`` returns per-class scores.
        ds: Inputs handed to ``model.predict``.
        y_ds: Ground-truth label vectors; the class is taken as the argmax over
            the last axis.
        labels: Class names ordered by class index. Used as tick labels, and its
            length fixes the size of the confusion matrix.
        name: Name of the split, used in the printed messages.
    """
    # Predict class probabilities and get predicted classes
    predicted = np.argmax(model.predict(ds, verbose=0), axis=-1)

    # Extract ground truth classes
    ground_truth = np.argmax(y_ds, axis=-1)

    # Calculate and display accuracy
    ds_accuracy = accuracy_score(ground_truth, predicted)
    print(f'Accuracy score over the {name} set: {round(ds_accuracy, 4)}')

    # Calculate and display weighted precision
    ds_precision = precision_score(ground_truth, predicted, average='weighted')
    print(f'Precision score over the {name} set: {round(ds_precision, 4)}')

    # Calculate and display weighted recall
    ds_recall = recall_score(ground_truth, predicted, average='weighted')
    print(f'Recall score over the {name} set: {round(ds_recall, 4)}')

    # Calculate and display weighted F1 score
    ds_f1 = f1_score(ground_truth, predicted, average='weighted')
    print(f'F1 score over the {name} set: {round(ds_f1, 4)}')

    # Compute the confusion matrix over every class index, so its shape always
    # matches `labels` even when a class is absent from this split
    class_indices = np.arange(len(labels))
    cm = confusion_matrix(ground_truth, predicted, labels=class_indices)

    # Cell annotations: the raw counts as strings (kept separate from `labels`,
    # which must stay available for the tick labels)
    annotations = cm.astype(str)

    # Plot the confusion matrix with class labels.
    # sklearn puts true classes on rows and predicted classes on columns, so the
    # heatmap's y-axis is "true" and its x-axis is "predicted".
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=annotations, fmt='', xticklabels=labels, yticklabels=labels, cmap='Blues')
    plt.xlabel('Predicted labels')
    plt.ylabel('True labels')
    plt.show()

In [None]:
# Report metrics and the confusion matrix on the validation arrays
evaluations(model, ds=X_val, y_ds=y_val, labels=LABELS, name='validation')

In [None]:
# Get the weights of the first three hidden dense layers and the output layer
dense_layer_1_weights = model.get_layer('Dense_layer_0').get_weights()
dense_layer_2_weights = model.get_layer('Dense_layer_1').get_weights()
dense_layer_3_weights = model.get_layer('Dense_layer_2').get_weights()
output_layer_weights = model.get_layer('output_layer').get_weights()

# Save the weights to files (one .npz per layer; arrays are stored positionally)
np.savez('dense_layer_1_weights.npz', *dense_layer_1_weights)
np.savez('dense_layer_2_weights.npz', *dense_layer_2_weights)
np.savez('dense_layer_3_weights.npz', *dense_layer_3_weights)
np.savez('output_layer_weights.npz', *output_layer_weights)

## Fine tuning

In [None]:
# Mark the backbone as trainable, then set the flag layer by layer below
model.get_layer('convnext_xlarge').trainable = True  # Base model

# Unfreeze only specific layers
# NOTE(review): layers are selected by substring of their name; confirm that 'stage_3' / 'stage_4'
# actually occur in this backbone's layer names (the count printed below shows how many matched).
trainable_layers_count = 0
for layer in model.get_layer('convnext_xlarge').layers:
    if 'stage_4' in layer.name or 'stage_3' in layer.name:  # Adjust based on architecture
        layer.trainable = True
        trainable_layers_count+=1
    else:
        layer.trainable = False

# Number of backbone layers left trainable
print(trainable_layers_count)

# Overwrite the rate of the model's Dropout layers
def increase_dropout(model, new_rate):
    """Set ``rate`` to ``new_rate`` on every Dropout layer listed in ``model.layers``.

    Only the layers directly in ``model.layers`` are inspected. The model is
    modified in place and also returned.
    """
    dropout_layers = (
        candidate for candidate in model.layers
        if isinstance(candidate, tf.keras.layers.Dropout)
    )
    for dropout in dropout_layers:
        dropout.rate = new_rate
    return model

# NOTE: despite the helper's name, this call sets the dropout rate to 0
model = increase_dropout(model, new_rate=0)

with strategy.scope():
    # Recompile after changing the trainable flags and dropout rates.
    # NOTE(review): no weight_decay is passed to AdamW here, and the loss is plain categorical
    # cross-entropy rather than the focal loss used in build_model — confirm both are intended.
    model.compile(
        optimizer=tf.keras.optimizers.AdamW(learning_rate=1e-5),  # Adjust weight decay here
        loss=tf.keras.losses.CategoricalCrossentropy(),
        metrics=[tf.keras.metrics.CategoricalAccuracy()]
    )

# Train the model, reusing the callback instances from the transfer-learning run.
# Unlike the first fit() call, `history` here is the object returned by fit(), not its `.history` dict.
history = model.fit(
    train_dataset,
    epochs=epochs,
    steps_per_epoch=steps_per_epoch,
    validation_data=val_dataset,
    shuffle=True,
    callbacks=callbacks
)

In [None]:
# Save the fine-tuned model's weights to a fixed filename
with strategy.scope():
    model_weights_filename = 'FINE_ADAM_HEAVY_AUG_MODEL.weights.h5'
    model.save_weights(model_weights_filename)

In [None]:
# Report metrics and the confusion matrix on the validation arrays after fine-tuning
evaluations(model, ds=X_val, y_ds=y_val, labels=LABELS, name='validation')