In [None]:
import numpy as np
import tensorflow as tf
from sklearn.model_selection import KFold
from tensorflow.keras.callbacks import EarlyStopping, LearningRateScheduler
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Assumes earlier cells define tiles_img, tiles_mask, lr_schedule and
# alexnet_segmentation_model (compiled with the dice_coef / iou_coef metrics).

# Parameters
K = 5
random_state = 42
batch_size = 32
epochs = 100
size = 256

# Normalize all images once to avoid repeated division. Casting to float32
# first halves the memory of the float64 result that `/ 255.0` would produce.
tiles_img_norm = np.asarray(tiles_img, dtype=np.float32) / 255.0

# Add the trailing channel axis only when it is missing, so re-running this
# cell does not keep stacking new axes onto the global `tiles_mask`.
# NOTE(review): assumes the raw masks are (N, H, W) — confirm.
if tiles_mask.ndim == 3:
    tiles_mask = tiles_mask[..., np.newaxis]


def _paired_augmented_batches(images, masks, datagen, batch_size, rng,
                              brightness_range=(0.8, 1.2)):
    """Yield shuffled (image, mask) batches forever, augmenting each pair identically.

    ``ImageDataGenerator.flow(x, y)`` only transforms ``x``; ``y`` is passed
    through untouched as a label, so augmented images would be paired with
    un-augmented masks. Here one geometric transform is drawn per sample from
    ``datagen`` and applied to both the image and its mask.

    Args:
        images: Image array, one sample per leading index, scaled to [0, 1].
        masks: Mask array aligned sample-for-sample with ``images``.
        datagen: ImageDataGenerator holding the geometric augmentation ranges.
        batch_size: Maximum number of samples per yielded batch.
        rng: ``numpy.random.Generator`` used for shuffling and brightness.
        brightness_range: ``(low, high)`` multiplicative brightness factor,
            applied to the images only.

    Yields:
        ``(x_batch, y_batch)`` float32 arrays; the last batch of each pass
        over the data may be smaller than ``batch_size``.
    """
    # Nearest-neighbour resampling keeps warped mask values inside the
    # original label set instead of blending them at object borders.
    mask_warper = ImageDataGenerator(interpolation_order=0)
    low, high = brightness_range
    factor_shape = (1,) * (images.ndim - 1)
    n_samples = len(images)
    while True:
        order = rng.permutation(n_samples)
        for start in range(0, n_samples, batch_size):
            idx = order[start:start + batch_size]
            x_batch = np.empty((len(idx),) + images.shape[1:], dtype=np.float32)
            y_batch = np.empty((len(idx),) + masks.shape[1:], dtype=np.float32)
            for j, i in enumerate(idx):
                params = datagen.get_random_transform(images[i].shape)
                x_batch[j] = datagen.apply_transform(images[i], params)
                # Photometric parameters must never be applied to a mask.
                mask_params = dict(params, brightness=None, channel_shift_intensity=None)
                y_batch[j] = mask_warper.apply_transform(
                    masks[i].astype(np.float32), mask_params)
            # Keras' brightness_range converts through an 8-bit PIL image, which
            # does not preserve data already scaled to [0, 1]; scale the pixels
            # directly instead (the same multiplication PIL's Brightness does).
            factors = rng.uniform(low, high, size=(len(idx),) + factor_shape)
            x_batch = np.clip(x_batch * factors, 0.0, 1.0).astype(np.float32)
            yield x_batch, y_batch


# Cross-validation setup
kf = KFold(n_splits=K, shuffle=True, random_state=random_state)

# Per-fold evaluation metrics
val_accuracies = []
val_losses = []
val_dices = []
val_ious = []

for fold, (train_index, val_index) in enumerate(kf.split(tiles_img_norm), start=1):
    print(f"\n===== Fold {fold}/{K} =====")

    # Release models/layers built by previous folds so memory does not
    # accumulate across the K training runs.
    tf.keras.backend.clear_session()

    # Split the data
    X_train_cv, X_val_cv = tiles_img_norm[train_index], tiles_img_norm[val_index]
    y_train_cv, y_val_cv = tiles_mask[train_index], tiles_mask[val_index]

    # Geometric augmentation shared by each image and its mask; the brightness
    # jitter is applied to the images only, inside the batch generator.
    datagen_cv = ImageDataGenerator(
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True,
        rotation_range=20,
        width_shift_range=0.2,
        height_shift_range=0.2,
    )
    train_batches_cv = _paired_augmented_batches(
        X_train_cv,
        y_train_cv,
        datagen_cv,
        batch_size,
        rng=np.random.default_rng(random_state + fold),
        brightness_range=(0.8, 1.2),
    )

    # Create model
    model_cv = alexnet_segmentation_model(input_size=(size, size, 3), freeze_encoder=True)

    # Callbacks
    lr_scheduler_cv = LearningRateScheduler(lr_schedule)
    early_stop_cv = EarlyStopping(monitor='val_loss', patience=5, verbose=1, mode='min')

    # Train. The generator is infinite, so Keras needs the epoch length
    # (same number of batches per epoch as `datagen.flow` would give).
    history_cv = model_cv.fit(
        train_batches_cv,
        steps_per_epoch=int(np.ceil(len(X_train_cv) / batch_size)),
        epochs=epochs,
        validation_data=(X_val_cv, y_val_cv),
        callbacks=[lr_scheduler_cv, early_stop_cv],
        verbose=1,
    )

    # Evaluate.
    # NOTE(review): assumes the model is compiled with metrics ordered
    # [accuracy, dice_coef, iou_coef] after the loss — confirm.
    loss, accuracy, dice, iou = model_cv.evaluate(X_val_cv, y_val_cv, verbose=0)
    val_losses.append(loss)
    val_accuracies.append(accuracy)
    val_dices.append(dice)
    val_ious.append(iou)

    print(f"Fold {fold} — Loss: {loss:.4f}, Acc: {accuracy:.4f}, Dice: {dice:.4f}, IoU: {iou:.4f}")

# Final summary (mean ± population standard deviation across folds)
print("\n==== Cross-validation results ====")
print(f"Avg Accuracy: {np.mean(val_accuracies):.4f} ± {np.std(val_accuracies):.4f}")
print(f"Avg Dice:     {np.mean(val_dices):.4f} ± {np.std(val_dices):.4f}")
print(f"Avg IoU:      {np.mean(val_ious):.4f} ± {np.std(val_ious):.4f}")
print(f"Avg Loss:     {np.mean(val_losses):.4f} ± {np.std(val_losses):.4f}")
