In [None]:
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import tifffile
from scipy import stats
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split

import tensorflow as tf
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, Conv2DTranspose, Concatenate, BatchNormalization, Activation
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import LearningRateScheduler, ModelCheckpoint, EarlyStopping
from tensorflow.keras.preprocessing.image import ImageDataGenerator


In [None]:
def dice_coef(y_true, y_pred, smooth=1):
    """Smoothed Dice coefficient computed over the whole batch.

    Both inputs are flattened and cast to float32 before combining, so integer
    masks (e.g. the uint8 tiles produced by ``load_data``) can be passed directly
    instead of raising a dtype-mismatch error on ``y_true * y_pred``.

    Args:
        y_true: ground-truth mask tensor/array (any shape).
        y_pred: predicted probabilities with the same number of elements.
        smooth: additive smoothing in numerator and denominator; keeps the value
            defined (== 1) when both masks are empty.

    Returns:
        Scalar tensor ``(2 * sum(t * p) + smooth) / (sum(t) + sum(p) + smooth)``.
    """
    y_true_f = tf.cast(tf.keras.backend.flatten(y_true), tf.float32)
    y_pred_f = tf.cast(tf.keras.backend.flatten(y_pred), tf.float32)
    intersection = tf.keras.backend.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (tf.keras.backend.sum(y_true_f) + tf.keras.backend.sum(y_pred_f) + smooth)

def iou_coef(y_true, y_pred, smooth=1):
    """Smoothed intersection-over-union (Jaccard index) over the whole batch.

    Inputs are flattened and cast to float32 so integer masks can be combined
    with float predictions.

    Args:
        y_true: ground-truth mask tensor/array (any shape).
        y_pred: predicted probabilities with the same number of elements.
        smooth: additive smoothing in numerator and denominator.

    Returns:
        Scalar tensor ``(|t * p| + smooth) / (|t| + |p| - |t * p| + smooth)``.
    """
    y_true_f = tf.cast(tf.keras.backend.flatten(y_true), tf.float32)
    y_pred_f = tf.cast(tf.keras.backend.flatten(y_pred), tf.float32)
    intersection = tf.keras.backend.sum(y_true_f * y_pred_f)
    union = tf.keras.backend.sum(y_true_f) + tf.keras.backend.sum(y_pred_f) - intersection
    return (intersection + smooth) / (union + smooth)


In [None]:
# Function to split images into tiles
def split_image_into_tiles(image_path, mask_path, tile_size, size):
    """Cut an image/mask pair into a grid of tiles, each resized to ``size`` x ``size``.

    Tiles come from a non-overlapping ``tile_size`` grid, visited column by column
    (x outer, y inner). Tiles on the right/bottom border can be smaller than
    ``tile_size``; like all tiles they are resized to ``size`` x ``size``.

    Args:
        image_path: TIFF image path, read with ``tifffile.imread`` and indexed
            as (rows, cols[, channels]).
        mask_path: matching TIFF mask path; if it is 3-D only channel 0 is used.
        tile_size: edge length (pixels) of the grid cells cut from the image.
        size: edge length each tile is resized to.

    Returns:
        ``(tiles_img, tiles_mask)`` stacked arrays; masks are 0/1 ``uint8``.
    """
    img = tifffile.imread(image_path)
    mask = tifffile.imread(mask_path)
    if mask.ndim == 3:
        mask = mask[:, :, 0]
    # Binarise up front and resize masks with nearest-neighbour: bilinear
    # interpolation followed by a `> 0` threshold turned every pixel touching the
    # foreground into foreground whenever a tile was actually rescaled.
    mask = (mask > 0).astype(np.uint8)

    tiles_img, tiles_mask = [], []
    for x in range(0, img.shape[1], tile_size):
        for y in range(0, img.shape[0], tile_size):
            # No explicit channel slice, so single-channel (2-D) images work too.
            tile_img = img[y:y + tile_size, x:x + tile_size]
            tile_mask = mask[y:y + tile_size, x:x + tile_size]

            tiles_img.append(cv2.resize(tile_img, (size, size)))
            tiles_mask.append(cv2.resize(tile_mask, (size, size), interpolation=cv2.INTER_NEAREST))

    return np.array(tiles_img), np.array(tiles_mask)

# Load dataset
def load_data(image_dir, mask_dir, tile_size=256, size=256):
    """Tile every ``<name>.TIF`` in ``image_dir`` that has a ``<name>_mask.TIF`` in ``mask_dir``.

    Images without a matching mask, and files not ending in ``.TIF``
    (case-sensitive), are skipped.

    Args:
        image_dir: directory containing the ``.TIF`` images.
        mask_dir: directory containing the ``_mask.TIF`` masks.
        tile_size: grid cell size passed to ``split_image_into_tiles``.
        size: output tile edge length passed to ``split_image_into_tiles``.

    Returns:
        ``(images, masks)``: tiles of all pairs concatenated in sorted image-filename
        order. Both are empty arrays if no pair is found.
    """
    image_ext, mask_suffix = ".TIF", "_mask.TIF"
    available_masks = set(os.listdir(mask_dir))  # O(1) membership checks

    images, masks = [], []
    for image_filename in sorted(os.listdir(image_dir)):
        if not image_filename.endswith(image_ext):
            continue
        # Swap only the trailing extension; str.replace would also rewrite any
        # ".TIF" that appears earlier in the filename.
        mask_filename = image_filename[:-len(image_ext)] + mask_suffix
        if mask_filename not in available_masks:
            continue
        img_path = os.path.join(image_dir, image_filename)
        mask_path = os.path.join(mask_dir, mask_filename)
        img, mask = split_image_into_tiles(img_path, mask_path, tile_size, size)
        images.extend(img)
        masks.extend(mask)
    return np.array(images), np.array(masks)

# Dataset locations (relative to this notebook) and tile edge length in pixels.
image_dir, mask_dir = "../../datasets/images", "../../datasets/masks"
size = 256

# Tile every image/mask pair; tiles are not resized (tile_size == size).
tiles_img, tiles_mask = load_data(image_dir, mask_dir, tile_size=size, size=size)

# 80/20 train/test split, then 10% of the training portion held out for validation.
# NOTE(review): tiles of the same scene can land in different splits; neighbouring
# tiles are correlated, so test scores may be optimistic.
X_train, X_test, y_train, y_test = train_test_split(
    tiles_img, tiles_mask, test_size=0.2, random_state=42
)
X_train, X_val, y_train, y_val = train_test_split(
    X_train, y_train, test_size=0.1, random_state=42
)

# Masks come back as (N, H, W); give each split a trailing channel axis.
y_train, y_val, y_test = (m[..., np.newaxis] for m in (y_train, y_val, y_test))


In [None]:
from tensorflow.keras.applications import DenseNet121
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Conv2D, Conv2DTranspose, Concatenate
from tensorflow.keras.optimizers import Adam
import os
import numpy as np
import matplotlib.pyplot as plt
import cv2
import tifffile
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, BatchNormalization, Conv2DTranspose
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, LearningRateScheduler
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from sklearn.model_selection import train_test_split



# Define DenseNet U-Net model
def densenet_unet_model(input_size=(size, size, 3), freeze_encoder=True):
    """Build, compile and summarise a U-Net-style model on a DenseNet121 encoder.

    The ImageNet-pretrained DenseNet121 (no classification head) is the encoder.
    Four decoder stages each upsample x2 with ``Conv2DTranspose``, concatenate an
    encoder skip connection, then apply Conv2D -> BatchNorm -> ReLU. A 1x1 sigmoid
    convolution produces a one-channel probability map, resized to the input size.

    Args:
        input_size: ``(height, width, channels)`` of the input images.
        freeze_encoder: if True, every DenseNet121 layer is made non-trainable so
            only the decoder is optimised.

    Returns:
        A ``Model`` compiled with Adam(lr=1e-4), binary cross-entropy and the
        metrics ``accuracy``, ``dice_coef`` and ``iou_coef``; its output is
        ``(batch, height, width, 1)``. ``model.summary()`` is printed.
    """
    densenet_base = DenseNet121(weights='imagenet', include_top=False, input_shape=input_size)

    # Encoder: deepest DenseNet feature map.
    encoder_output = densenet_base.get_layer('conv5_block16_concat').output

    if freeze_encoder:
        for layer in densenet_base.layers:
            layer.trainable = False

    # Decoder: (filters, encoder skip layer) for each upsampling stage, deepest first.
    decoder_stages = [
        (512, 'conv4_block24_concat'),
        (256, 'conv3_block12_concat'),
        (128, 'conv2_block6_concat'),
        (64, 'conv1/conv'),
    ]
    x = encoder_output
    for filters, skip_layer_name in decoder_stages:
        x = Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding='same')(x)
        x = Concatenate()([x, densenet_base.get_layer(skip_layer_name).output])
        # NOTE(review): this Conv2D already applies ReLU, so the stage is
        # ReLU -> BatchNorm -> ReLU; the usual order is a linear conv, then BN -> ReLU.
        # Kept unchanged so trained checkpoints/results stay comparable.
        x = Conv2D(filters, (3, 3), activation='relu', padding='same')(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)

    output = Conv2D(1, (1, 1), activation='sigmoid', padding='same')(x)

    # The last skip ('conv1/conv') is the encoder's stem convolution, so the map can
    # be smaller than the input; resize it to the input's own height/width (not the
    # notebook-global `size`) so any `input_size` produces matching masks. A Keras
    # layer is used instead of calling tf.image.resize on a symbolic tensor.
    output = tf.keras.layers.Resizing(input_size[0], input_size[1], interpolation='bilinear')(output)

    model = Model(inputs=densenet_base.input, outputs=output)

    model.compile(
        optimizer=Adam(learning_rate=1e-4),
        loss='binary_crossentropy',
        metrics=['accuracy', dice_coef, iou_coef]
    )

    model.summary()

    return model
# Usage
size = 256
# No extra model.summary() here: densenet_unet_model already prints it.
model = densenet_unet_model(input_size=(size, size, 3), freeze_encoder=True)


In [None]:
# LR schedule
def lr_schedule(epoch):
    """Step decay: start at 1e-4 and multiply by 0.9 once per 10 completed epochs."""
    completed_decades = epoch // 10
    return 1e-4 * 0.9 ** completed_decades

lr_scheduler = LearningRateScheduler(lr_schedule)


class _PairedAugmentSequence(tf.keras.utils.Sequence):
    """Augmented (image, mask) batches that share one random transform per sample.

    ``ImageDataGenerator.flow(x, y)`` treats ``y`` as labels and never transforms
    it, so flips/rotations/shifts/zooms applied to an image were not applied to
    its mask. Here each sample gets a single ``get_random_transform`` draw: the
    full transform followed by ``standardize`` (e.g. ``rescale``) is applied to
    the image, and only the geometric part to the mask, which is then
    re-binarised at 0.5 (interpolation blurs its edges).
    """

    # Intensity-only transform parameters; they must not be applied to masks.
    _INTENSITY_KEYS = ("brightness", "channel_shift_intensity")

    def __init__(self, generator, x, y, batch_size=32, shuffle=True, seed=None):
        super().__init__()
        self.generator = generator
        self.x = x
        self.y = y
        self.batch_size = batch_size
        self.shuffle = shuffle
        self._rng = np.random.default_rng(seed)
        self._order = np.arange(len(x))
        if shuffle:
            self._rng.shuffle(self._order)

    def __len__(self):
        return int(np.ceil(len(self.x) / self.batch_size))

    def __getitem__(self, idx):
        batch_ids = self._order[idx * self.batch_size:(idx + 1) * self.batch_size]
        images, masks = [], []
        for i in batch_ids:
            # astype copies, so in-place standardize() never touches the source data.
            image = self.x[i].astype(np.float32)
            mask = self.y[i].astype(np.float32)
            params = self.generator.get_random_transform(image.shape)
            image = self.generator.standardize(self.generator.apply_transform(image, params))
            geometric = {k: v for k, v in params.items() if k not in self._INTENSITY_KEYS}
            mask = self.generator.apply_transform(mask, geometric)
            images.append(image)
            masks.append((mask > 0.5).astype(np.float32))
        return np.stack(images), np.stack(masks)

    def on_epoch_end(self):
        if self.shuffle:
            self._rng.shuffle(self._order)


class SegmentationDataGenerator(ImageDataGenerator):
    """``ImageDataGenerator`` whose ``flow(x, y)`` augments masks together with images.

    With ``y`` given, ``flow`` returns a Keras ``Sequence`` (so ``model.fit`` needs
    no ``steps_per_epoch``); without ``y`` it behaves like the parent class.
    """

    def flow(self, x, y=None, batch_size=32, shuffle=True, seed=None, **kwargs):
        if y is None:
            return super().flow(x, y, batch_size=batch_size, shuffle=shuffle, seed=seed, **kwargs)
        if kwargs:
            raise TypeError(f"Options not supported for paired image/mask flow: {sorted(kwargs)}")
        return _PairedAugmentSequence(self, x, y, batch_size=batch_size, shuffle=shuffle, seed=seed)


# Data augmentation. rescale=1/255 runs after the random transform, so the
# brightness shift operates on the raw 0-255 pixels.
datagen = SegmentationDataGenerator(rescale=1./255,
                                    shear_range=0.2,
                                    zoom_range=0.2,
                                    horizontal_flip=True,
                                    rotation_range=20,
                                    width_shift_range=0.2,
                                    height_shift_range=0.2,
                                    brightness_range=[0.8, 1.2])

# Callbacks
checkpointer = ModelCheckpoint("best_densenet_unet.h5", monitor="val_dice_coef", mode="max",
                               save_best_only=True, verbose=1)
# restore_best_weights: the test evaluation below uses `model` directly, so without
# it those scores would come from the final epoch (up to `patience` epochs past the
# best one) rather than from the best val_dice_coef epoch.
earlyStopping = EarlyStopping(monitor="val_dice_coef", patience=5, mode="max",
                              restore_best_weights=True, verbose=1)


In [None]:
# Train on augmented batches; validation images get the same 1/255 scaling that
# the generator's `rescale` applies to training images.
history = model.fit(
    datagen.flow(X_train, y_train, batch_size=32),
    validation_data=(X_val / 255.0, y_val),
    epochs=50,
    callbacks=[lr_scheduler, earlyStopping, checkpointer],
)


In [None]:
from sklearn.metrics import precision_score, recall_score, f1_score, roc_auc_score
from sklearn.utils import resample



def bootstrap_confidence_interval(y_true, y_pred, metric_fn, n_bootstraps=50, alpha=0.95,
                                  random_state=None):
    """Bootstrap mean, standard deviation and percentile confidence interval of a metric.

    Entries along the first axis of ``y_true``/``y_pred`` are resampled with
    replacement ``n_bootstraps`` times. ``metric_fn`` is evaluated on each resample
    with both arrays flattened. For 1-D inputs this resamples individual elements
    (e.g. pixels). Passing per-tile arrays such as ``(N, H, W, 1)`` instead
    resamples whole tiles, which respects the correlation between neighbouring
    pixels.

    Args:
        y_true: ground-truth array (or array-like).
        y_pred: labels or scores, whatever ``metric_fn`` expects; same length as
            ``y_true`` along the first axis.
        metric_fn: callable ``metric_fn(y_true, y_pred) -> float``.
        n_bootstraps: number of bootstrap resamples.
        alpha: confidence LEVEL (0.95 -> 95% interval), not a significance level.
        random_state: seed for reproducible resampling; ``None`` draws from NumPy's
            global random state.

    Returns:
        ``(mean, std, (lower, upper))`` of the bootstrap distribution.

    Raises:
        ValueError: if ``y_true`` and ``y_pred`` differ in length.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    n = len(y_true)
    if len(y_pred) != n:
        raise ValueError(f"y_true and y_pred differ in length: {n} vs {len(y_pred)}")

    draw = np.random.randint if random_state is None else np.random.default_rng(random_state).integers

    samples = []
    for _ in range(n_bootstraps):
        indices = draw(0, n, n)
        samples.append(metric_fn(y_true[indices].reshape(-1), y_pred[indices].reshape(-1)))

    samples = np.array(samples)
    mean_val = np.mean(samples)
    std_val = np.std(samples)
    lower = np.percentile(samples, ((1 - alpha) / 2) * 100)
    upper = np.percentile(samples, (alpha + (1 - alpha) / 2) * 100)

    return mean_val, std_val, (lower, upper)


# --- Evaluate model ---
X_test_scaled = X_test / 255.0  # same 1/255 scaling the training generator applies
loss, acc, dice, iou = model.evaluate(X_test_scaled, y_test)
print(f"Test Loss: {loss:.4f}, Accuracy: {acc:.4f}, Dice: {dice:.4f}, IoU: {iou:.4f}")

# --- Predictions ---
y_pred = model.predict(X_test_scaled)
y_pred_bin = (y_pred > 0.5).astype(np.uint8)

# Flatten to per-pixel vectors
y_true_flat = y_test.flatten()
y_pred_flat = y_pred_bin.flatten()
y_pred_probs = y_pred.flatten()

# --- Metrics ---
metrics = {
    "Accuracy": lambda yt, yp: np.mean(yt == yp),
    "Precision": lambda yt, yp: precision_score(yt, yp),
    "Recall": lambda yt, yp: recall_score(yt, yp),
    "F1-score": lambda yt, yp: f1_score(yt, yp),
    "ROC-AUC": lambda yt, yp: roc_auc_score(yt, yp),
    "Dice": lambda yt, yp: (2*np.sum(yt*yp))/(np.sum(yt)+np.sum(yp)+1e-7),
    "IoU": lambda yt, yp: np.sum(yt*yp)/(np.sum(yt)+np.sum(yp)-np.sum(yt*yp)+1e-7)
}

# NOTE(review): these bootstraps resample individual pixels as if independent;
# neighbouring pixels are strongly correlated, so the intervals are likely too narrow.
print("\n📊 Metrics with 95% Confidence Intervals:")
for name, fn in metrics.items():
    # ROC-AUC ranks the raw probabilities; all other metrics use the 0.5-thresholded labels.
    scores = y_pred_probs if name == "ROC-AUC" else y_pred_flat
    mean_val, std_val, (low, high) = bootstrap_confidence_interval(y_true_flat, scores, fn)
    print(f"{name}: {mean_val:.4f} ± {std_val:.4f}  (95% CI: {low:.4f} – {high:.4f})")


In [None]:
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt

# Flatten ground truth and predictions
# Pixel-wise comparison of test masks with predictions thresholded at 0.5.
y_true_flat = y_test.flatten()
y_pred_flat = (y_pred.flatten() > 0.5).astype(int)

cm = confusion_matrix(y_true_flat, y_pred_flat)
print("Confusion Matrix:\n", cm)
print(classification_report(y_true_flat, y_pred_flat))

# Rows = actual class, columns = predicted class.
class_labels = ["Non-Forest", "Forest"]
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
            xticklabels=class_labels, yticklabels=class_labels)
plt.xlabel("Predicted")
plt.ylabel("Actual")
plt.show()


In [None]:
# Plot training history for loss, accuracy, dice, iou (train and validation).
fig_hist, ax_hist = plt.subplots(figsize=(12, 6))
for key, label in (("loss", "Loss"), ("accuracy", "Accuracy"),
                   ("dice_coef", "Dice"), ("iou_coef", "IoU")):
    # history keys follow the compiled metric names; skip any that are absent.
    if key in history.history:
        ax_hist.plot(history.history[key], label=f"Train {label}")
    if f"val_{key}" in history.history:
        ax_hist.plot(history.history[f"val_{key}"], label=f"Val {label}")

ax_hist.set_xlabel('Epoch')
ax_hist.set_ylabel('Metrics')
ax_hist.set_title('DenseNet U-Net Segmentation Training History')
ax_hist.legend()
plt.show()

In [None]:
import matplotlib.pyplot as plt

# Per-epoch metrics of one earlier training run, hard-coded so the figures below
# can be redrawn without retraining. They are NOT read from `history` and will not
# change if training is re-run. NOTE(review): presumably copied from the fit log of
# the DenseNet U-Net cell above; confirm which run they come from.
# Every list holds one value per epoch (17 entries), aligned with `epochs`.
# Epochs: 1 to 17 (early stopping at 17)
epochs = range(1, 18)

# Training metrics
train_loss = [
    0.5527, 0.3876, 0.3325, 0.3180, 0.3140, 0.3094, 0.3067, 0.2958, 0.3021,
    0.2911, 0.2939, 0.2940, 0.2890, 0.2801, 0.2838, 0.2831, 0.2860
]

train_accuracy = [
    0.7705, 0.8303, 0.8567, 0.8621, 0.8624, 0.8648, 0.8671, 0.8710, 0.8671,
    0.8712, 0.8677, 0.8683, 0.8720, 0.8742, 0.8731, 0.8723, 0.8748
]

train_dice = [
    0.6196, 0.7474, 0.7922, 0.8061, 0.8046, 0.8096, 0.8120, 0.8223, 0.8167,
    0.8271, 0.8253, 0.8202, 0.8267, 0.8323, 0.8298, 0.8297, 0.8303
]

train_iou = [
    0.4520, 0.5984, 0.6573, 0.6774, 0.6761, 0.6825, 0.6856, 0.6995, 0.6926,
    0.7071, 0.7040, 0.6975, 0.7066, 0.7145, 0.7111, 0.7117, 0.7117
]

# Validation metrics
val_loss = [
    0.5907, 0.4341, 0.3285, 0.2855, 0.2707, 0.2582, 0.2762, 0.2757, 0.2792,
    0.2456, 0.2389, 0.2173, 0.2440, 0.2383, 0.2281, 0.2203, 0.2591
]

val_accuracy = [
    0.7914, 0.7953, 0.8303, 0.8937, 0.8900, 0.9038, 0.8832, 0.8651, 0.8852,
    0.8923, 0.9029, 0.9078, 0.9044, 0.9029, 0.9062, 0.9088, 0.9045
]

# Validation Dice peaks at epoch 12 (0.8614), five epochs before the stop at 17,
# consistent with EarlyStopping(patience=5) on val_dice_coef.
val_dice = [
    0.7017, 0.7694, 0.7922, 0.8099, 0.8067, 0.8197, 0.8285, 0.8105, 0.8318,
    0.8371, 0.8443, 0.8614, 0.8433, 0.8432, 0.8533, 0.8592, 0.8261
]

val_iou = [
    0.5427, 0.6273, 0.6574, 0.6829, 0.6781, 0.6967, 0.7098, 0.6831, 0.7149,
    0.7215, 0.7327, 0.7583, 0.7313, 0.7309, 0.7461, 0.7551, 0.7060
]

In [None]:
plt.figure(figsize=(16, 12))

# One panel per tracked metric: (display name, training curve, validation curve),
# laid out left-to-right, top-to-bottom in a 2x2 grid.
panels = [
    ("Loss", train_loss, val_loss),
    ("Accuracy", train_accuracy, val_accuracy),
    ("Dice Coefficient", train_dice, val_dice),
    ("IoU", train_iou, val_iou),
]
for position, (metric_name, train_curve, val_curve) in enumerate(panels, start=1):
    plt.subplot(2, 2, position)
    plt.plot(epochs, train_curve, 'bo-', label=f'Training {metric_name}', linewidth=2)
    plt.plot(epochs, val_curve, 'r-o', label=f'Validation {metric_name}', linewidth=2)
    plt.title(f'Training and Validation {metric_name}', fontsize=14)
    plt.xlabel('Epochs')
    plt.ylabel(metric_name)
    plt.legend()
    plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import KFold
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    confusion_matrix, roc_auc_score
)
import tensorflow as tf

# ==============================
# Parameters
# ==============================
K = 5
random_state = 42
batch_size = 32
# Named cv_epochs (not `epochs`) so this cell does not overwrite the epoch range
# used by the hard-coded training-history plot cells.
cv_epochs = 100
size = 256  # image size

# Normalized tiles for validation/prediction; float32 halves the memory of the
# float64 array that `tiles_img / 255.0` would produce.
tiles_img_norm = tiles_img.astype(np.float32) / 255.0
# Masks need a trailing channel axis: (N, H, W) -> (N, H, W, 1). Build a new array
# instead of reassigning `tiles_mask`, and add the axis only when it is missing, so
# re-running this cell (or running another CV cell first) never yields 5-D masks.
tiles_mask_cv = tiles_mask[..., np.newaxis] if tiles_mask.ndim == 3 else tiles_mask

# Initialize K-Fold
kf = KFold(n_splits=K, shuffle=True, random_state=random_state)

# Store metrics per fold
metrics_per_fold = {
    "Accuracy": [], "Precision": [], "Recall": [], "F1": [],
    "Dice": [], "mIoU": [], "ROC-AUC": []
}
conf_matrices = []


# ==============================
# Learning Rate Scheduler
# ==============================
def lr_schedule_cv(epoch):
    """CV learning rate: 1e-4; x0.5 for epoch indices 51-70; x0.1 from index 71.

    Separate name so defining it does not replace the notebook's ``lr_schedule``.
    """
    lr = 1e-4
    if epoch > 70:
        lr *= 0.1
    elif epoch > 50:
        lr *= 0.5
    return lr


# ==============================
# Paired image/mask augmentation
# ==============================
class _PairedAugmentSequence(tf.keras.utils.Sequence):
    """Augmented (image, mask) batches that share one random transform per sample.

    ``ImageDataGenerator.flow(x, y)`` treats ``y`` as labels and never transforms
    it, so flips/rotations/shifts/zooms applied to an image were not applied to
    its mask. Here each sample gets a single ``get_random_transform`` draw: the
    full transform followed by ``standardize`` (e.g. ``rescale``) is applied to
    the image, and only the geometric part to the mask, which is then
    re-binarised at 0.5 (interpolation blurs its edges).
    """

    # Intensity-only transform parameters; they must not be applied to masks.
    _INTENSITY_KEYS = ("brightness", "channel_shift_intensity")

    def __init__(self, generator, x, y, batch_size=32, shuffle=True, seed=None):
        super().__init__()
        self.generator = generator
        self.x = x
        self.y = y
        self.batch_size = batch_size
        self.shuffle = shuffle
        self._rng = np.random.default_rng(seed)
        self._order = np.arange(len(x))
        if shuffle:
            self._rng.shuffle(self._order)

    def __len__(self):
        return int(np.ceil(len(self.x) / self.batch_size))

    def __getitem__(self, idx):
        batch_ids = self._order[idx * self.batch_size:(idx + 1) * self.batch_size]
        images, masks = [], []
        for i in batch_ids:
            # astype copies, so in-place standardize() never touches the source data.
            image = self.x[i].astype(np.float32)
            mask = self.y[i].astype(np.float32)
            params = self.generator.get_random_transform(image.shape)
            image = self.generator.standardize(self.generator.apply_transform(image, params))
            geometric = {k: v for k, v in params.items() if k not in self._INTENSITY_KEYS}
            mask = self.generator.apply_transform(mask, geometric)
            images.append(image)
            masks.append((mask > 0.5).astype(np.float32))
        return np.stack(images), np.stack(masks)

    def on_epoch_end(self):
        if self.shuffle:
            self._rng.shuffle(self._order)


# ==============================
# Start Cross-Validation
# ==============================
for fold, (train_index, val_index) in enumerate(kf.split(tiles_img_norm), start=1):
    print(f"\n===== Fold {fold}/{K} =====")

    # Split data: raw 0-255 tiles for training (the generator rescales them after
    # augmenting, because its PIL-based brightness shift expects 8-bit-range pixels),
    # pre-normalized tiles for validation.
    X_train_cv, X_val_cv = tiles_img[train_index], tiles_img_norm[val_index]
    y_train_cv, y_val_cv = tiles_mask_cv[train_index], tiles_mask_cv[val_index]

    # Data augmentation, applied jointly to images and masks
    datagen_cv = tf.keras.preprocessing.image.ImageDataGenerator(
        rescale=1./255,
        shear_range=0.2, zoom_range=0.2, horizontal_flip=True,
        rotation_range=20, width_shift_range=0.2, height_shift_range=0.2,
        brightness_range=[0.8, 1.2]
    )
    train_seq_cv = _PairedAugmentSequence(
        datagen_cv, X_train_cv, y_train_cv, batch_size=batch_size, seed=random_state + fold
    )

    # Fresh model for every fold
    model_cv = densenet_unet_model(input_size=(size, size, 3), freeze_encoder=True)

    # Compile model
    model_cv.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
        loss='binary_crossentropy',
        metrics=['accuracy']
    )

    # Callbacks
    lr_scheduler_cv = tf.keras.callbacks.LearningRateScheduler(lr_schedule_cv)
    early_stop_cv = tf.keras.callbacks.EarlyStopping(
        monitor='val_loss', patience=5, mode='min', restore_best_weights=True, verbose=1
    )

    # Train model
    print(f"Training Fold {fold}...")
    model_cv.fit(
        train_seq_cv,
        epochs=cv_epochs,
        validation_data=(X_val_cv, y_val_cv),
        callbacks=[lr_scheduler_cv, early_stop_cv],
        verbose=1
    )

    # Predict
    y_prob = model_cv.predict(X_val_cv, verbose=0)
    y_pred = (y_prob > 0.5).astype(np.uint8)

    # Flatten for metric computation
    y_true_flat = y_val_cv.flatten()
    y_pred_flat = y_pred.flatten()
    y_prob_flat = y_prob.flatten()

    # Compute metrics
    acc = accuracy_score(y_true_flat, y_pred_flat)
    prec = precision_score(y_true_flat, y_pred_flat, zero_division=0)
    rec = recall_score(y_true_flat, y_pred_flat, zero_division=0)
    f1 = f1_score(y_true_flat, y_pred_flat, zero_division=0)

    # Dice coefficient (equals F1 for binary masks; reported as the segmentation convention)
    intersection = np.logical_and(y_true_flat, y_pred_flat).sum()
    dice = 2. * intersection / (y_true_flat.sum() + y_pred_flat.sum() + 1e-8)

    # IoU of the foreground class over all validation pixels (reported as "mIoU")
    union = np.logical_or(y_true_flat, y_pred_flat).sum()
    miou = intersection / union if union > 0 else 0.0

    # ROC-AUC on the raw probabilities
    roc = roc_auc_score(y_true_flat, y_prob_flat)

    # Store metrics
    metrics_per_fold["Accuracy"].append(acc)
    metrics_per_fold["Precision"].append(prec)
    metrics_per_fold["Recall"].append(rec)
    metrics_per_fold["F1"].append(f1)
    metrics_per_fold["Dice"].append(dice)
    metrics_per_fold["mIoU"].append(miou)
    metrics_per_fold["ROC-AUC"].append(roc)

    # Confusion matrix
    cm = confusion_matrix(y_true_flat, y_pred_flat)
    conf_matrices.append(cm)

    # Print fold results
    print(f"Fold {fold} Results:")
    print(f"  Accuracy: {acc:.4f}, Precision: {prec:.4f}, Recall: {rec:.4f}")
    print(f"  F1: {f1:.4f}, Dice: {dice:.4f}, mIoU: {miou:.4f}, ROC-AUC: {roc:.4f}")

# ==============================
# Final Results Summary
# ==============================
print("\n" + "="*70)
print(f"            CROSS-VALIDATION RESULTS ({K}-FOLD)")
print("="*70)

# 95% CI of the mean fold score using the Student-t distribution and the sample
# std (ddof=1). With only a handful of folds, a z-value of 1.96 and the population
# std give an interval that is too narrow (t_{0.975,4} ~= 2.78 for 5 folds).
# Folds share training data, so treat the interval as approximate.
results_summary = {}
for metric, values in metrics_per_fold.items():
    n_folds = len(values)
    mean_val = np.mean(values)
    std_val = np.std(values, ddof=1) if n_folds > 1 else 0.0
    t_crit = stats.t.ppf(0.975, df=n_folds - 1) if n_folds > 1 else 0.0
    ci95 = t_crit * std_val / np.sqrt(n_folds)
    ci_lower = mean_val - ci95
    ci_upper = mean_val + ci95
    results_summary[metric] = (mean_val, std_val, ci_lower, ci_upper)
    print(f"{metric:<10}: {mean_val:.4f} ± {std_val:.4f} | 95% CI: [{ci_lower:.4f}, {ci_upper:.4f}]")

print("="*70)

# ==============================
# Per-Fold Results Table
# ==============================
# One row per fold; metric columns follow the key order of metrics_per_fold.
results_df = pd.DataFrame({
    'Fold': [f"F{i}" for i in range(1, K + 1)],
    **{name: [f"{score:.4f}" for score in scores] for name, scores in metrics_per_fold.items()},
})
print("\nPer-Fold Results:")
print(results_df.to_string(index=False))

# Optional: Save to CSV
# results_df.to_csv("cv_per_fold_results.csv", index=False)

# ==============================
# Plot: Metric Stability Across Folds
# ==============================
plt.figure(figsize=(12, 6))
metrics_to_plot = ["Accuracy", "Dice", "mIoU", "F1", "Recall"]
x_folds = np.arange(1, K+1)

for metric in metrics_to_plot:
    plt.plot(x_folds, metrics_per_fold[metric], 'o-', label=metric)

mean_dice = np.mean(metrics_per_fold["Dice"])
plt.axhline(mean_dice, color='gray', linestyle='--', alpha=0.7,
            label=f"Mean Dice = {mean_dice:.4f}")
plt.xlabel("Fold", fontsize=12)
plt.ylabel("Score", fontsize=12)
plt.title("Cross-Validation Performance per Fold", fontsize=14, fontweight='bold')
plt.legend()
plt.grid(True, alpha=0.3)
# Keep the 0.8-1.0 zoom when every score fits, but widen it instead of cutting off
# lower curves: IoU = Dice / (2 - Dice), so e.g. Dice 0.86 gives IoU ~0.75.
lowest_score = min(min(metrics_per_fold[m]) for m in metrics_to_plot)
plt.ylim(min(0.8, lowest_score - 0.02), 1.0)
plt.xticks(x_folds)
plt.tight_layout()
plt.show()

# ==============================
# Average Confusion Matrix
# ==============================
# Element-wise mean of the per-fold pixel-count confusion matrices.
avg_cm = np.mean(conf_matrices, axis=0)
tick_labels = ["Non-Forest", "Forest"]

fig_cm, ax_cm = plt.subplots(figsize=(6, 5))
sns.heatmap(avg_cm, annot=True, fmt=".0f", cmap="Blues", cbar=True,
            xticklabels=tick_labels, yticklabels=tick_labels, square=True, ax=ax_cm)
ax_cm.set_xlabel("Predicted Label", fontsize=12)
ax_cm.set_ylabel("True Label", fontsize=12)
ax_cm.set_title("Average Confusion Matrix (5-Fold CV)", fontsize=13)
fig_cm.tight_layout()
plt.show()

# Optional: save via the figure handle (plt.savefig after plt.show() would write a blank figure)
# fig_cm.savefig("avg_confusion_matrix_cv.png", dpi=300, bbox_inches='tight')

In [None]:
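# 5-fold cross-validation, variant 2: early stopping, checkpointing and LR reduction all driven by validation Dice (max 20 epochs)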

In [None]:
import numpy as np
import tensorflow as tf
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import KFold
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    confusion_matrix, roc_auc_score
)
import tensorflow as tf
from tensorflow.keras import backend as K

# ==============================
# Parameters
# ==============================
K = 5
random_state = 42
batch_size = 32
# Named cv_epochs (not `epochs`) so this cell does not overwrite the epoch range
# used by the hard-coded training-history plot cells.
cv_epochs = 20
size = 256  # image size

# Normalized tiles for validation/prediction; float32 halves the memory of the
# float64 array that `tiles_img / 255.0` would produce.
tiles_img_norm = tiles_img.astype(np.float32) / 255.0
# Masks need a trailing channel axis: (N, H, W) -> (N, H, W, 1). Build a new array
# instead of reassigning `tiles_mask`, and add the axis only when it is missing, so
# re-running this cell (or running another CV cell first) never yields 5-D masks.
tiles_mask_cv = tiles_mask[..., np.newaxis] if tiles_mask.ndim == 3 else tiles_mask

# Initialize K-Fold
kf = KFold(n_splits=K, shuffle=True, random_state=random_state)

# Store metrics per fold
metrics_per_fold = {
    "Accuracy": [], "Precision": [], "Recall": [], "F1": [],
    "Dice": [], "mIoU": [], "ROC-AUC": []
}
conf_matrices = []


# ==============================
# Dice Coefficient Metric
# ==============================
def dice_coef_cv(y_true, y_pred, smooth=1e-8):
    """Batch-level soft Dice with a tiny smoothing term, used to monitor CV training.

    Named differently from the notebook's ``dice_coef`` so defining it does not
    silently replace the metric ``densenet_unet_model`` compiles with. It is
    registered under the metric name ``dice_coef`` in the loop below, so logs and
    callbacks still see ``val_dice_coef``.
    """
    y_true_f = tf.cast(tf.reshape(y_true, [-1]), tf.float32)
    y_pred_f = tf.cast(tf.reshape(y_pred, [-1]), tf.float32)
    intersection = tf.reduce_sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + smooth)


# ==============================
# Paired image/mask augmentation
# ==============================
class _PairedAugmentSequence(tf.keras.utils.Sequence):
    """Augmented (image, mask) batches that share one random transform per sample.

    ``ImageDataGenerator.flow(x, y)`` treats ``y`` as labels and never transforms
    it, so flips/rotations/shifts/zooms applied to an image were not applied to
    its mask. Here each sample gets a single ``get_random_transform`` draw: the
    full transform followed by ``standardize`` (e.g. ``rescale``) is applied to
    the image, and only the geometric part to the mask, which is then
    re-binarised at 0.5 (interpolation blurs its edges).
    """

    # Intensity-only transform parameters; they must not be applied to masks.
    _INTENSITY_KEYS = ("brightness", "channel_shift_intensity")

    def __init__(self, generator, x, y, batch_size=32, shuffle=True, seed=None):
        super().__init__()
        self.generator = generator
        self.x = x
        self.y = y
        self.batch_size = batch_size
        self.shuffle = shuffle
        self._rng = np.random.default_rng(seed)
        self._order = np.arange(len(x))
        if shuffle:
            self._rng.shuffle(self._order)

    def __len__(self):
        return int(np.ceil(len(self.x) / self.batch_size))

    def __getitem__(self, idx):
        batch_ids = self._order[idx * self.batch_size:(idx + 1) * self.batch_size]
        images, masks = [], []
        for i in batch_ids:
            # astype copies, so in-place standardize() never touches the source data.
            image = self.x[i].astype(np.float32)
            mask = self.y[i].astype(np.float32)
            params = self.generator.get_random_transform(image.shape)
            image = self.generator.standardize(self.generator.apply_transform(image, params))
            geometric = {k: v for k, v in params.items() if k not in self._INTENSITY_KEYS}
            mask = self.generator.apply_transform(mask, geometric)
            images.append(image)
            masks.append((mask > 0.5).astype(np.float32))
        return np.stack(images), np.stack(masks)

    def on_epoch_end(self):
        if self.shuffle:
            self._rng.shuffle(self._order)


# ==============================
# Start Cross-Validation
# ==============================
for fold, (train_index, val_index) in enumerate(kf.split(tiles_img_norm), start=1):
    print(f"\n===== Fold {fold}/{K} =====")

    # Split data: raw 0-255 tiles for training (the generator rescales them after
    # augmenting, because its PIL-based brightness shift expects 8-bit-range pixels),
    # pre-normalized tiles for validation.
    X_train_cv, X_val_cv = tiles_img[train_index], tiles_img_norm[val_index]
    y_train_cv, y_val_cv = tiles_mask_cv[train_index], tiles_mask_cv[val_index]

    # Data augmentation, applied jointly to images and masks
    datagen_cv = tf.keras.preprocessing.image.ImageDataGenerator(
        rescale=1./255,
        shear_range=0.2, zoom_range=0.2, horizontal_flip=True,
        rotation_range=20, width_shift_range=0.2, height_shift_range=0.2,
        brightness_range=[0.8, 1.2]
    )
    train_seq_cv = _PairedAugmentSequence(
        datagen_cv, X_train_cv, y_train_cv, batch_size=batch_size, seed=random_state + fold
    )

    # Fresh model for every fold
    model_cv = densenet_unet_model(input_size=(size, size, 3), freeze_encoder=True)

    # Compile with Dice as a metric, registered under the name 'dice_coef'
    model_cv.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
        loss='binary_crossentropy',
        metrics=['accuracy', tf.keras.metrics.MeanMetricWrapper(dice_coef_cv, name='dice_coef')]
    )

    # Callbacks (fresh per fold): LR reduction, early stopping and checkpointing on Dice
    lr_scheduler_cv = tf.keras.callbacks.ReduceLROnPlateau(
        monitor='val_dice_coef',
        factor=0.5,
        patience=3,
        mode='max',
        min_lr=1e-7,
        verbose=1
    )
    early_stop_cv = tf.keras.callbacks.EarlyStopping(
        monitor='val_dice_coef',
        patience=5,
        mode='max',
        restore_best_weights=True,
        verbose=1
    )
    model_checkpoint = tf.keras.callbacks.ModelCheckpoint(
        f'best_densenet_unet_fold_{fold}.h5',
        monitor='val_dice_coef',
        save_best_only=True,
        mode='max',
        verbose=1
    )

    # Train model
    print(f"Training Fold {fold}...")
    model_cv.fit(
        train_seq_cv,
        epochs=cv_epochs,
        validation_data=(X_val_cv, y_val_cv),
        callbacks=[lr_scheduler_cv, early_stop_cv, model_checkpoint],
        verbose=1
    )

    # Predict
    y_prob = model_cv.predict(X_val_cv, verbose=0)
    y_pred = (y_prob > 0.5).astype(np.uint8)

    # Flatten for metric computation
    y_true_flat = y_val_cv.flatten()
    y_pred_flat = y_pred.flatten()
    y_prob_flat = y_prob.flatten()

    # Compute metrics
    acc = accuracy_score(y_true_flat, y_pred_flat)
    prec = precision_score(y_true_flat, y_pred_flat, zero_division=0)
    rec = recall_score(y_true_flat, y_pred_flat, zero_division=0)
    f1 = f1_score(y_true_flat, y_pred_flat, zero_division=0)

    # Dice coefficient
    intersection = np.logical_and(y_true_flat, y_pred_flat).sum()
    dice = 2. * intersection / (y_true_flat.sum() + y_pred_flat.sum() + 1e-8)

    # IoU of the foreground class over all validation pixels (reported as "mIoU")
    union = np.logical_or(y_true_flat, y_pred_flat).sum()
    miou = intersection / union if union > 0 else 0.0

    # ROC-AUC on the raw probabilities
    roc = roc_auc_score(y_true_flat, y_prob_flat)

    # Store metrics
    metrics_per_fold["Accuracy"].append(acc)
    metrics_per_fold["Precision"].append(prec)
    metrics_per_fold["Recall"].append(rec)
    metrics_per_fold["F1"].append(f1)
    metrics_per_fold["Dice"].append(dice)
    metrics_per_fold["mIoU"].append(miou)
    metrics_per_fold["ROC-AUC"].append(roc)

    # Confusion matrix
    cm = confusion_matrix(y_true_flat, y_pred_flat)
    conf_matrices.append(cm)

    # Print fold results
    print(f"Fold {fold} Results:")
    print(f"  Accuracy: {acc:.4f}, Precision: {prec:.4f}, Recall: {rec:.4f}")
    print(f"  F1: {f1:.4f}, Dice: {dice:.4f}, mIoU: {miou:.4f}, ROC-AUC: {roc:.4f}")

# ==============================
# Final Results Summary
# ==============================
print("\n" + "="*70)
print(f"            CROSS-VALIDATION RESULTS ({K}-FOLD)")
print("="*70)

# 95% CI of the mean fold score using the Student-t distribution and the sample
# std (ddof=1). With only a handful of folds, a z-value of 1.96 and the population
# std give an interval that is too narrow (t_{0.975,4} ~= 2.78 for 5 folds).
# Folds share training data, so treat the interval as approximate.
results_summary = {}
for metric, values in metrics_per_fold.items():
    n_folds = len(values)
    mean_val = np.mean(values)
    std_val = np.std(values, ddof=1) if n_folds > 1 else 0.0
    t_crit = stats.t.ppf(0.975, df=n_folds - 1) if n_folds > 1 else 0.0
    ci95 = t_crit * std_val / np.sqrt(n_folds)
    ci_lower = mean_val - ci95
    ci_upper = mean_val + ci95
    results_summary[metric] = (mean_val, std_val, ci_lower, ci_upper)
    print(f"{metric:<10}: {mean_val:.4f} ± {std_val:.4f} | 95% CI: [{ci_lower:.4f}, {ci_upper:.4f}]")

print("="*70)

# ==============================
# Per-Fold Results Table
# ==============================
# One row per fold; metric columns follow the key order of metrics_per_fold.
results_df = pd.DataFrame({
    'Fold': [f"F{i}" for i in range(1, K + 1)],
    **{name: [f"{score:.4f}" for score in scores] for name, scores in metrics_per_fold.items()},
})
print("\nPer-Fold Results:")
print(results_df.to_string(index=False))

# Optional: Save to CSV
# results_df.to_csv("cv_per_fold_results.csv", index=False)

# ==============================
# Plot: Metric Stability Across Folds
# ==============================
plt.figure(figsize=(12, 6))
metrics_to_plot = ["Accuracy", "Dice", "mIoU", "F1", "Recall"]
x_folds = np.arange(1, K+1)

for metric in metrics_to_plot:
    plt.plot(x_folds, metrics_per_fold[metric], 'o-', label=metric)

mean_dice = np.mean(metrics_per_fold["Dice"])
plt.axhline(mean_dice, color='gray', linestyle='--', alpha=0.7,
            label=f"Mean Dice = {mean_dice:.4f}")
plt.xlabel("Fold", fontsize=12)
plt.ylabel("Score", fontsize=12)
plt.title("Cross-Validation Performance per Fold", fontsize=14, fontweight='bold')
plt.legend()
plt.grid(True, alpha=0.3)
# Keep the 0.8-1.0 zoom when every score fits, but widen it instead of cutting off
# lower curves: IoU = Dice / (2 - Dice), so e.g. Dice 0.86 gives IoU ~0.75.
lowest_score = min(min(metrics_per_fold[m]) for m in metrics_to_plot)
plt.ylim(min(0.8, lowest_score - 0.02), 1.0)
plt.xticks(x_folds)
plt.tight_layout()
plt.show()

# ==============================
# Average Confusion Matrix
# ==============================
# Element-wise mean of the per-fold pixel-count confusion matrices.
avg_cm = np.mean(conf_matrices, axis=0)
tick_labels = ["Non-Forest", "Forest"]

fig_cm, ax_cm = plt.subplots(figsize=(6, 5))
sns.heatmap(avg_cm, annot=True, fmt=".0f", cmap="Blues", cbar=True,
            xticklabels=tick_labels, yticklabels=tick_labels, square=True, ax=ax_cm)
ax_cm.set_xlabel("Predicted Label", fontsize=12)
ax_cm.set_ylabel("True Label", fontsize=12)
ax_cm.set_title("Average Confusion Matrix (5-Fold CV)", fontsize=13)
fig_cm.tight_layout()
plt.show()

# Optional: save via the figure handle (plt.savefig after plt.show() would write a blank figure)
# fig_cm.savefig("avg_confusion_matrix_cv.png", dpi=300, bbox_inches='tight')