# Imports

In [None]:
import os
import numpy as np
from datetime import datetime
import pytz
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
from tensorflow import keras as tfk
from tensorflow.keras import layers as tfkl
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from tensorflow.keras.optimizers import Lion
from tensorflow.keras.models import load_model

import folding as fold

# Stamp the cell output with the current time in the Europe/Rome zone.
rome_tz = pytz.timezone("Europe/Rome")
now = datetime.now(rome_tz).strftime("Day: %Y-%m-%d - Time: %H:%M:%S\n")
# The ANSI escape sequences style the text and then reset the terminal style.
print(f"\033[1;94m{now}\033[0m")

# Dataset

In [None]:
# Stamp the cell output with the current time in the Europe/Rome zone.
rome_tz = pytz.timezone("Europe/Rome")
now = datetime.now(rome_tz).strftime("Day: %Y-%m-%d - Time: %H:%M:%S\n")
print(f"\033[1;94m{now}\033[0m")

# Archive used for training/validation. The "maximisation" archive on the
# commented line is an alternative source that is currently disabled.
data = np.load("/kaggle/input/bloodcells-augmented/dataset_augmented.npz")
# data = np.load("/kaggle/input/bloodcells-maximisation/dataset_maximisation.npz")

# Separate archive that is only read by the final evaluation cell.
test = np.load("/kaggle/input/bloodcells-evaluation/dataset_evaluation.npz")

# Pull the image and label arrays out of each archive.
images = data['images']
labels = data['labels']
test_images = test['images']
test_labels = test['labels']

# Set model and parameters

In [None]:
# Stamp the cell output with the current time in the Europe/Rome zone.
rome_tz = pytz.timezone("Europe/Rome")
now = datetime.now(rome_tz).strftime("Day: %Y-%m-%d - Time: %H:%M:%S\n")
print(f"\033[1;94m{now}\033[0m")

# --- Output of the run -------------------------------------------------------
# Directory prefix prepended to every saved file, and the name of the final model.
model_dir = ""
model_name = 'BrainRot'

# --- Backbone and classification head ----------------------------------------
# Constructor of the pre-trained backbone and the layer name used to look it up
# inside the assembled model.
base_model_name = "convnext_small"
base_model_type = tfk.applications.ConvNeXtSmall
# Widths of the two dense layers stacked on top of the backbone.
neurons_first_dense_layer, neurons_second_dense_layer = 128, 32

# --- Pre-training stage (backbone frozen) ------------------------------------
batch_size_pretrain = 64
learning_rate_pretrain = 1e-4

# --- Fine-tuning stage ---------------------------------------------------------
# Number of fine-tuning iterations.
tuning_steps = 3
# Value passed as `size` to fold.split_set for the validation split.
val_size = 0.2
batch_size_tuning = 64
# Learning rate of the first iteration; each following iteration multiplies it
# by learning_rate_multiplier once more.
learning_rate_tuning = 5e-5
learning_rate_multiplier = 1 / 6

# Model build

In [None]:
# Stamp the cell output with the current time in the Europe/Rome zone.
rome_tz = pytz.timezone("Europe/Rome")
now = datetime.now(rome_tz).strftime("Day: %Y-%m-%d - Time: %H:%M:%S\n")
print(f"\033[1;94m{now}\033[0m")

# Problem definition: image shape fed to the network and number of target classes
num_classes = 8
input_shape = [96, 96, 3]

# Backbone initialised with ImageNet weights, built without its classification top
base_model = base_model_type(
    weights="imagenet",
    include_top=False,
    input_shape=input_shape,
    classes=num_classes,
    classifier_activation="softmax",
)

# Wire the network: backbone called with training=False -> global average
# pooling -> two (Dense + Dropout) stages -> softmax classifier
inputs = tfkl.Input(shape=input_shape)
backbone_features = base_model(inputs, training=False)
x = tfkl.GlobalAveragePooling2D()(backbone_features)
for dense_units in (neurons_first_dense_layer, neurons_second_dense_layer):
    x = tfkl.Dense(dense_units, activation='gelu')(x)
    x = tfkl.Dropout(0.2)(x)
outputs = tfkl.Dense(num_classes, activation='softmax')(x)

# Assemble the end-to-end model
model = tf.keras.Model(inputs=inputs, outputs=outputs)

# Initial compilation; the training cells compile the model again with their
# own learning rates before fitting.
model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer=Lion(learning_rate=1e-5),
    metrics=['accuracy'],
)

# Print the layer overview of the assembled model
model.summary()
# model.get_layer(base_model_name).summary()

# Pre-training

In [None]:
# Stamp the cell output with the current time in the Europe/Rome zone.
rome_tz = pytz.timezone("Europe/Rome")
now = datetime.now(rome_tz).strftime("Day: %Y-%m-%d - Time: %H:%M:%S\n")
print(f"\033[1;94m{now}\033[0m")

# Split the dataset with the project helper; its first return value is used as
# the validation part and the second as the training part.
validation, training = fold.split_set(data, starting=0, size=val_size)
val_images = validation['images']
val_labels = validation['labels']
train_images = training['images']
train_labels = training['labels']

# Write the model to disk whenever the validation loss reaches a new minimum
checkpoint = ModelCheckpoint(
    filepath=model_dir + "Topping.keras",
    monitor='val_loss',
    mode='min',
    save_best_only=True,
    verbose=1,
)

# Only the head is trained here: mark every top-level layer trainable, then
# freeze each layer that belongs to the backbone.
for layer in model.layers:
    layer.trainable = True
backbone = model.get_layer(base_model_name)
for layer in backbone.layers:
    layer.trainable = False

# Compile again so the new trainable flags and the pre-training learning rate
# take effect
model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer=Lion(learning_rate=learning_rate_pretrain),
    metrics=['accuracy'],
)

# Fit the head; stop after 3 epochs without improvement and roll back to the
# best weights seen
early_stopping = EarlyStopping(patience=3, restore_best_weights=True)
model.fit(
    x=train_images,
    y=train_labels,
    validation_data=(val_images, val_labels),
    batch_size=batch_size_pretrain,
    epochs=20,
    callbacks=[early_stopping, checkpoint],
)

# Clear keras session to avoid memory build up
tf.keras.backend.clear_session()
print("Finish!")

# Fine tuning

In [None]:
now = datetime.now(pytz.timezone("Europe/Rome")).strftime("Day: %Y-%m-%d - Time: %H:%M:%S\n")
print(f"\033[1;94m{now}\033[0m")

# Clear keras session to avoid memory build up
tf.keras.backend.clear_session()

# Resume from the checkpoint written by the pre-training cell
model = load_model(model_dir + "Topping.keras")

# Progressive fine-tuning: each iteration unfreezes a larger tail of the base
# model (ITERATION/tuning_steps of its layers) and trains with a smaller
# learning rate on a different train/validation split.
for ITERATION in range(1, tuning_steps + 1):
    print(f"Iteration {ITERATION}/{tuning_steps}")

    # Reset to a known state: every top-level layer trainable (this includes
    # the base model container, which must stay trainable or none of its
    # weights are exposed to the optimizer), then freeze each base-model layer.
    for layer in model.layers:
        layer.trainable = True
    base_model_layers = model.get_layer(base_model_name).layers
    for layer in base_model_layers:
        layer.trainable = False
    total_layers = len(base_model_layers)

    # Dataset split: shift the validation window by val_size each iteration,
    # wrapping around at 1
    starting_image = (val_size * ITERATION) % 1
    validation, training = fold.split_set(data, starting=starting_image, size=val_size)

    # Decide which layers to unfreeze
    train_layer = int(total_layers * ITERATION / tuning_steps)
    first_unfrozen = total_layers - train_layer
    print(f"Unfreezing layers {first_unfrozen}-{total_layers}")

    # Unfreeze only the LAST train_layer layers. Slicing from an explicit start
    # index (rather than a negative index) keeps train_layer == 0 meaning
    # "unfreeze nothing" instead of "unfreeze everything".
    for layer in base_model_layers[first_unfrozen:]:
        layer.trainable = True

    # Recompile so the new trainable flags take effect; the learning rate is
    # multiplied by learning_rate_multiplier once per completed iteration
    model.compile(
        optimizer=Lion(
            learning_rate=learning_rate_tuning * (learning_rate_multiplier ** (ITERATION - 1))
        ),
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )

    # Fine-tune the head together with the currently unfrozen base-model layers.
    # NOTE(review): with epochs=1 the EarlyStopping patience of 3 can never be
    # reached; kept as-is so raising `epochs` later behaves as intended.
    model.fit(
        x=training['images'],
        y=training['labels'],
        validation_data=(validation['images'], validation['labels']),
        epochs=1,
        batch_size=batch_size_tuning,
        callbacks=[EarlyStopping(patience=3, restore_best_weights=True)]
    )

# Persist the fine-tuned model under its final name
model.save(model_dir + model_name + ".keras")
print("Finish!")

# Model Test

In [None]:
# Stamp the cell output with the current time in the Europe/Rome zone.
rome_tz = pytz.timezone("Europe/Rome")
now = datetime.now(rome_tz).strftime("Day: %Y-%m-%d - Time: %H:%M:%S\n")
print(f"\033[1;94m{now}\033[0m")

# Reload the saved final model from disk and score it on the evaluation arrays
saved_model_path = model_dir + model_name + ".keras"
model = load_model(saved_model_path)
evaluation = model.evaluate(test_images, test_labels, verbose=1)
test_loss, test_accuracy = evaluation

# Report the results, one line per item
print(
    f"Main model: {model_name}",
    f"  Test Loss: {test_loss:.4f}",
    f"  Test Accuracy: {test_accuracy:.4f}",
    sep="\n",
)