# In [None]:
from tensorflow.keras.layers import Layer, GlobalAveragePooling2D, Conv1D, Multiply, Reshape, Add, Conv2D, BatchNormalization, Activation, Dense, Dropout, Input, MaxPooling2D, Concatenate, AveragePooling2D, GlobalMaxPooling2D
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras.applications import MobileNet
from tensorflow.keras.applications.inception_v3 import InceptionV3
import keras
import tensorflow as tf
import os
from tensorflow.keras.applications import ResNet50, DenseNet121
from tensorflow.keras.applications.inception_v3 import InceptionV3
from tensorflow.keras.applications.efficientnet import EfficientNetB0
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Activation, Lambda, GlobalAveragePooling2D, Dense, Add, Input, BatchNormalization  # Add BatchNormalization here
from tensorflow.keras.applications import VGG16
from tensorflow.keras.applications import VGG19
from tensorflow.keras.applications import InceptionResNetV2
from tensorflow.keras.applications import EfficientNetB3
from tensorflow.keras.applications import Xception
from tensorflow.keras.applications import NASNetMobile
from tensorflow.keras.applications import EfficientNetB7
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.applications import ResNet101
from tensorflow.keras.applications import DenseNet201
from tensorflow.keras.applications import NASNetLarge
from tensorflow.keras.applications import EfficientNetV2B0
from tensorflow.keras.applications import ConvNeXtBase
import tempfile 
import shutil
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.preprocessing import label_binarize
from sklearn.metrics import precision_recall_fscore_support
from sklearn.model_selection import KFold
import numpy as np
############################################Attention Mechanisms##########################
# Custom Layers (SE Block)
@keras.saving.register_keras_serializable()
class SEBlock(Layer):
    """Squeeze-and-Excitation channel gate for 4-D feature maps.

    The input is globally average-pooled over its two spatial axes, passed
    through a two-layer bottleneck (ReLU then sigmoid) and the resulting
    per-channel weights are multiplied back onto the input.

    Args:
        ratio: Reduction factor of the bottleneck; the hidden layer has
            ``channels // ratio`` units (at least 1).
        **kwargs: Forwarded to ``Layer.__init__``.
    """

    def __init__(self, ratio=16, **kwargs):
        super().__init__(**kwargs)
        self.ratio = ratio

    def build(self, input_shape):
        self.num_channels = input_shape[-1]
        # Guard against a zero-unit Dense layer when channels < ratio.
        bottleneck_units = max(1, self.num_channels // self.ratio)
        self.squeeze = GlobalAveragePooling2D()
        # Created once here instead of on every call().
        self.reshape = Reshape((1, 1, self.num_channels))
        self.excitation = Dense(bottleneck_units, activation='relu')
        self.scale = Dense(self.num_channels, activation='sigmoid')

    def call(self, inputs):
        # Squeeze: one descriptor per channel, reshaped so it broadcasts
        # over the spatial axes of the input.
        x = self.squeeze(inputs)
        x = self.reshape(x)

        # Excitation: bottleneck followed by a sigmoid gate.
        x = self.excitation(x)
        x = self.scale(x)

        # Scale: re-weight every channel of the input.
        return inputs * x

    def get_config(self):
        # Without this, a non-default `ratio` is lost when the model is
        # saved and reloaded, and the layer is rebuilt with ratio=16.
        config = super().get_config()
        config.update({'ratio': self.ratio})
        return config
############################################################################
@keras.saving.register_keras_serializable()
class CBAM(Layer):
    """Channel attention followed by spatial attention.

    Args:
        ratio: Reduction factor handed to the ``ChannelAttention`` sub-layer.
        **kwargs: Forwarded to ``Layer.__init__``.
    """

    def __init__(self, ratio=16, **kwargs):
        super().__init__(**kwargs)
        self.ratio = ratio

    def build(self, input_shape):
        self.channel_attention = ChannelAttention(ratio=self.ratio)
        self.spatial_attention = SpatialAttention()

    def call(self, inputs):
        # Channel gate first, then the spatial gate on the re-weighted map.
        x = self.channel_attention(inputs)
        x = self.spatial_attention(x)
        return x

    def get_config(self):
        # Persist `ratio`; otherwise a reloaded model silently falls back to
        # the default and the channel-attention weights no longer fit.
        config = super().get_config()
        config.update({'ratio': self.ratio})
        return config

class ChannelAttention(Layer):
    """Channel gate built from average- and max-pooled descriptors.

    Both pooled descriptors go through the same two-layer MLP; the two MLP
    outputs are summed and a single sigmoid turns the sum into per-channel
    weights in (0, 1) that multiply the input.

    Args:
        ratio: Reduction factor of the shared MLP; the hidden layer has
            ``channels // ratio`` units (at least 1).
        **kwargs: Forwarded to ``Layer.__init__``.
    """

    def __init__(self, ratio=16, **kwargs):
        super().__init__(**kwargs)
        self.ratio = ratio

    def build(self, input_shape):
        self.num_channels = input_shape[-1]
        self.avg_pool = GlobalAveragePooling2D()
        self.max_pool = GlobalMaxPooling2D()
        # Guard against a zero-unit Dense layer when channels < ratio.
        self.fc1 = Dense(max(1, self.num_channels // self.ratio), activation='relu')
        # Linear output: the sigmoid is applied once, after the two branches
        # are summed. Applying it per branch (as before) produced gates in
        # (0, 2) rather than (0, 1).
        self.fc2 = Dense(self.num_channels)

    def call(self, inputs):
        avg_out = self.fc2(self.fc1(self.avg_pool(inputs)))
        max_out = self.fc2(self.fc1(self.max_pool(inputs)))
        gate = tf.nn.sigmoid(avg_out + max_out)
        # (batch, C) -> (batch, 1, 1, C) so the gate broadcasts over the
        # spatial axes explicitly.
        gate = tf.expand_dims(tf.expand_dims(gate, axis=1), axis=1)
        return inputs * gate

    def get_config(self):
        config = super().get_config()
        config.update({'ratio': self.ratio})
        return config

class SpatialAttention(Layer):
    """Spatial gate computed from channel-wise mean and max maps.

    The mean and the max over the channel axis are stacked into a
    two-channel map, a 7x7 sigmoid convolution reduces it to a single-channel
    mask, and the mask multiplies the input.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        self.conv = Conv2D(1, 7, padding='same', activation='sigmoid')

    def call(self, inputs):
        # Collapse the channel axis (axis 3) in two different ways.
        channel_mean = tf.reduce_mean(inputs, axis=3, keepdims=True)
        channel_max = tf.reduce_max(inputs, axis=3, keepdims=True)
        descriptor = tf.concat([channel_mean, channel_max], axis=3)
        mask = self.conv(descriptor)
        return Multiply()([inputs, mask])
############################################################################        
@keras.saving.register_keras_serializable()
class SelfAttention(Layer):
    """Self-attention over all spatial positions of a 4-D feature map.

    Query/key projections use ``channels // 8`` filters (at least 1) and the
    value projection keeps the full channel count. The attended features are
    scaled by a learnable scalar ``gamma`` (initialised to zero) and added to
    the input, so the layer starts out as the identity.

    Note: the attention matrix has shape (batch, H*W, H*W), so memory grows
    quadratically with the number of spatial positions.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        self.num_channels = input_shape[-1]
        # Guard against a zero-filter convolution when channels < 8.
        self.reduced_channels = max(1, self.num_channels // 8)
        self.query = Conv2D(self.reduced_channels, 1, padding='same')
        self.key = Conv2D(self.reduced_channels, 1, padding='same')
        self.value = Conv2D(self.num_channels, 1, padding='same')
        self.gamma = self.add_weight(name='gamma', shape=[1], initializer='zeros')

    def call(self, inputs):
        # Batch and spatial sizes may be unknown at graph-construction time,
        # so read them dynamically. Channel counts are known from build();
        # using the static values keeps the channel dimension in the static
        # shape of the reshaped tensors.
        dynamic_shape = tf.shape(inputs)
        batch_size = dynamic_shape[0]
        height = dynamic_shape[1]
        width = dynamic_shape[2]
        positions = height * width

        query = tf.reshape(self.query(inputs), [batch_size, positions, self.reduced_channels])
        key = tf.reshape(self.key(inputs), [batch_size, positions, self.reduced_channels])
        value = tf.reshape(self.value(inputs), [batch_size, positions, self.num_channels])

        # (batch, positions, positions): similarity of every position with
        # every other position, normalised over the key axis.
        attention = tf.matmul(query, key, transpose_b=True)
        attention = tf.nn.softmax(attention, axis=-1)

        out = tf.matmul(attention, value)
        out = tf.reshape(out, [batch_size, height, width, self.num_channels])
        # Residual connection gated by gamma.
        out = self.gamma * out + inputs
        return out
#################################################################################
@keras.saving.register_keras_serializable()
class ECANet(Layer):
    """Channel gate produced by a 1-D convolution across channel descriptors.

    The input is mean-pooled over its spatial axes, a bias-free ``Conv1D``
    with one filter slides along the channel axis, and the sigmoid of the
    result re-weights the input channels.

    Args:
        k_size: Kernel size of the 1-D convolution (how many neighbouring
            channels interact). The default of 32 is kept for backward
            compatibility.
            NOTE(review): ECA is usually described with small odd kernels;
            confirm that 32 is intended.
        **kwargs: Forwarded to ``Layer.__init__``.
    """

    def __init__(self, k_size=32, **kwargs):
        super().__init__(**kwargs)
        self.k_size = k_size

    def build(self, input_shape):
        self.num_channels = input_shape[-1]
        self.conv = Conv1D(1, kernel_size=self.k_size, padding='same', use_bias=False)

    def call(self, inputs):
        # (batch, H, W, C) -> (batch, C): one descriptor per channel.
        x = tf.reduce_mean(inputs, axis=[1, 2])
        # Treat the channel axis as the sequence axis: (batch, C, 1).
        x = self.conv(tf.expand_dims(x, axis=-1))
        x = tf.nn.sigmoid(tf.squeeze(x, axis=-1))
        # (batch, C) -> (batch, 1, 1, C) so the gate broadcasts spatially.
        x = tf.expand_dims(tf.expand_dims(x, axis=1), axis=1)
        return inputs * x

    def get_config(self):
        # Persist `k_size`; otherwise a reloaded model rebuilds the
        # convolution with the default kernel size.
        config = super().get_config()
        config.update({'k_size': self.k_size})
        return config
###############################################################################
@keras.saving.register_keras_serializable()
class TripletAttention(Layer):
    """Three multiplicative sigmoid gates derived from mean-pooled inputs.

    One gate is computed from the mean over axis 2, one from the mean over
    axis 1 and one from the mean over both spatial axes. Each pooled tensor
    goes through its own 1x1 sigmoid convolution, and the input is multiplied
    by all three gates.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        self.num_channels = input_shape[-1]
        channels = self.num_channels
        self.conv_h = Conv2D(channels, 1, padding='same', activation='sigmoid')
        self.conv_w = Conv2D(channels, 1, padding='same', activation='sigmoid')
        self.conv_c = Conv2D(channels, 1, padding='same', activation='sigmoid')

    def call(self, inputs):
        # Gate that varies along axis 1 (axis 2 averaged out).
        gate_h = self.conv_h(tf.reduce_mean(inputs, axis=2, keepdims=True))

        # Gate that varies along axis 2 (axis 1 averaged out).
        gate_w = self.conv_w(tf.reduce_mean(inputs, axis=1, keepdims=True))

        # Gate that depends on the channel only (both spatial axes averaged).
        gate_c = self.conv_c(tf.reduce_mean(inputs, axis=[1, 2], keepdims=True))

        return inputs * gate_h * gate_w * gate_c
####################################################################################
# Bottleneck residual block: 1x1 -> 3x3 -> 1x1 (4x expansion) plus shortcut
def residual_block(x, filters, strides=1):
    """Apply a bottleneck residual block to `x`.

    Args:
        x: Input feature map (Keras tensor).
        filters: Width of the two inner convolutions; the block outputs
            ``filters * 4`` channels.
        strides: Stride of the first 1x1 convolution and, when a projection
            is needed, of the shortcut convolution.

    Returns:
        The ReLU-activated sum of the main path and the shortcut.
    """
    out_channels = filters * 4
    identity = x

    # Main path: reduce, transform, expand.
    y = Conv2D(filters, (1, 1), strides=strides, padding='same')(x)
    y = BatchNormalization()(y)
    y = Activation('relu')(y)

    y = Conv2D(filters, (3, 3), padding='same')(y)
    y = BatchNormalization()(y)
    y = Activation('relu')(y)

    y = Conv2D(out_channels, (1, 1), padding='same')(y)
    y = BatchNormalization()(y)

    # Project the shortcut whenever its shape would not match the main path.
    needs_projection = strides != 1 or identity.shape[-1] != out_channels
    if needs_projection:
        identity = Conv2D(out_channels, (1, 1), strides=strides, padding='same')(identity)
        identity = BatchNormalization()(identity)

    merged = Add()([y, identity])
    return Activation('relu')(merged)
##################################################################################
# Global context gate
def global_context_block(x):
    """Re-weight the channels of `x` with a gate computed from its global mean.

    The feature map is average-pooled to one value per channel, reshaped to
    (1, 1, C), passed through a 1x1 sigmoid convolution and multiplied back
    onto `x`.
    """
    out_channels = x.shape[-1]
    context = GlobalAveragePooling2D()(x)
    context = Reshape((1, 1, context.shape[-1]))(context)
    gate = Conv2D(out_channels, (1, 1), activation='sigmoid')(context)
    return Multiply()([x, gate])
##################################################################################
##################################################################################
def inception_module(x, filters):
    """Build a four-branch Inception module and concatenate the branches.

    Args:
        x: Input feature map (Keras tensor).
        filters: Sequence of six filter counts, used in this order:
            [1x1 branch, 3x3 reduce, 3x3, 5x5 reduce, 5x5, pool projection].

    Returns:
        The channel-wise concatenation of the 1x1, 3x3, 5x5 and pooling
        branches.
    """
    (f_1x1, f_3x3_reduce, f_3x3,
     f_5x5_reduce, f_5x5, f_pool) = (filters[0], filters[1], filters[2],
                                     filters[3], filters[4], filters[5])

    def relu_conv(tensor, n_filters, kernel):
        # Same-padded convolution with a fused ReLU.
        return Conv2D(n_filters, kernel, padding='same', activation='relu')(tensor)

    # Plain 1x1 convolution.
    tower_1x1 = relu_conv(x, f_1x1, (1, 1))

    # 1x1 reduction followed by a 3x3 convolution.
    tower_3x3 = relu_conv(relu_conv(x, f_3x3_reduce, (1, 1)), f_3x3, (3, 3))

    # 1x1 reduction followed by a 5x5 convolution.
    tower_5x5 = relu_conv(relu_conv(x, f_5x5_reduce, (1, 1)), f_5x5, (5, 5))

    # Stride-1 average pooling followed by a 1x1 projection.
    tower_pool = AveragePooling2D((3, 3), strides=(1, 1), padding='same')(x)
    tower_pool = relu_conv(tower_pool, f_pool, (1, 1))

    return Concatenate()([tower_1x1, tower_3x3, tower_5x5, tower_pool])
####################################################################################
# Custom Inception-Attention Block
def inception_mfr_attention_block(x, filters, attention_type='se', residual_filters=64):
    """Inception module -> two residual blocks -> attention -> ReLU.

    Args:
        x: Input feature map (Keras tensor).
        filters: Six filter counts forwarded to ``inception_module``.
        attention_type: Which attention layer to apply after the residual
            blocks: 'se', 'eca', 'cbam', 'self' or 'triplet'. Any other
            value skips the attention step.
        residual_filters: Inner width of the two residual blocks (each
            outputs ``residual_filters * 4`` channels).

    Returns:
        The ReLU-activated output feature map.
    """
    # Selectable attention layers, each constructed with its default
    # arguments.
    attention_layers = {
        'se': SEBlock,
        'eca': ECANet,
        'cbam': CBAM,
        'self': SelfAttention,
        'triplet': TripletAttention,
    }

    # Multi-scale feature extraction.
    x = inception_module(x, filters)
    # The first residual block halves the spatial resolution.
    x = residual_block(x, filters=residual_filters, strides=2)
    x = residual_block(x, filters=residual_filters, strides=1)

    # Finally apply attention (unknown types leave the features unchanged).
    attention_cls = attention_layers.get(attention_type)
    if attention_cls is not None:
        x = attention_cls()(x)

    return Activation('relu')(x)
#######################################################################################
# Define paths and classes
data_dir = '/kaggle/input/kvasirv2/KvasirV2'
CLASSES = [
    'dyed-lifted-polyps', 'dyed-resection-margins', 'esophagitis', 
    'normal-cecum', 'normal-pylorus', 'normal-z-line', 'polyps', 'ulcerative-colitis'
]

# First, collect all image paths and labels (label = index into CLASSES)
image_paths = []
labels = []

for class_idx, class_name in enumerate(CLASSES):
    class_path = os.path.join(data_dir, class_name)
    # os.listdir returns entries in arbitrary order. Sorting makes the sample
    # order, and therefore the seeded KFold split below, reproducible across
    # runs, which matters because the evaluation step regenerates the splits.
    for img_name in sorted(os.listdir(class_path)):
        image_paths.append(os.path.join(class_path, img_name))
        labels.append(class_idx)

# Convert to numpy arrays so they can be indexed with the fold index arrays
image_paths = np.array(image_paths)
labels = np.array(labels)

Folds = 3
# Initialize KFold (shuffled with a fixed seed so the splits can be regenerated)
kfold = KFold(n_splits=Folds, shuffle=True, random_state=42)

# Store histories for each fold
fold_histories = []
# One model.evaluate() result per fold
fold_metrics = []

# Train one model per fold. Each fold's images are copied into a temporary
# train/val directory tree so they can be read with flow_from_directory.
for fold_idx, (train_idx, val_idx) in enumerate(kfold.split(image_paths)):
    print(f"\nTraining fold {fold_idx + 1}/{Folds}")

    # Release the graph/layer state of the previous fold's model so that
    # global Keras state does not pile up across folds.
    tf.keras.backend.clear_session()

    # The context manager removes the directory even if training raises.
    with tempfile.TemporaryDirectory() as temp_dir:
        train_folder = os.path.join(temp_dir, 'train')
        val_folder = os.path.join(temp_dir, 'val')

        # Create the class sub-directories and copy the images of each split.
        for split_folder, split_idx in ((train_folder, train_idx), (val_folder, val_idx)):
            for class_name in CLASSES:
                os.makedirs(os.path.join(split_folder, class_name), exist_ok=True)
            for idx in split_idx:
                src = image_paths[idx]
                class_name = CLASSES[labels[idx]]
                dst = os.path.join(split_folder, class_name, os.path.basename(src))
                shutil.copy(src, dst)

        # Data generators (pixel values scaled to [0, 1], no augmentation)
        train_datagen = ImageDataGenerator(rescale=1./255)
        val_datagen = ImageDataGenerator(rescale=1./255)

        train_generator = train_datagen.flow_from_directory(
            train_folder,
            target_size=(224, 224),
            batch_size=20,
            color_mode='rgb',
            shuffle=True,
            class_mode='categorical'
        )

        # Not shuffled, so predictions line up with val_generator.classes.
        val_generator = val_datagen.flow_from_directory(
            val_folder,
            target_size=(224, 224),
            batch_size=20,
            color_mode='rgb',
            shuffle=False,
            class_mode='categorical'
        )

        # Load DenseNet201 as a feature extractor
        densenet201_base = DenseNet201(include_top=False, weights='imagenet', input_shape=(224, 224, 3))

        # Build model: backbone -> inception/residual/attention -> global
        # context gate -> pooled softmax classifier
        x = densenet201_base.output
        x = inception_mfr_attention_block(x, [32, 32, 64, 32, 64, 32], attention_type='se')
        x = global_context_block(x)
        x = GlobalAveragePooling2D()(x)
        # One output unit per class instead of a hard-coded 8.
        output = Dense(len(CLASSES), activation='softmax')(x)
        model = Model(inputs=densenet201_base.input, outputs=output)

        # Compile
        model.compile(optimizer=Adam(learning_rate=0.0001), 
                      loss='categorical_crossentropy', 
                      metrics=['accuracy'])

        # Callbacks: halve the learning rate on a val_loss plateau and keep
        # the checkpoint with the best validation accuracy.
        reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=2, min_lr=1e-6)
        model_checkpoint = ModelCheckpoint(f"model_inception_mfr_attention_fold{fold_idx}.keras", 
                                         save_best_only=True, 
                                         monitor='val_accuracy', mode = 'max')

        # Train
        history = model.fit(
            train_generator,
            epochs=20,
            validation_data=val_generator,
            callbacks=[reduce_lr, model_checkpoint]
        )

        # Save history and metrics.
        # NOTE: this evaluates the in-memory model after the final epoch,
        # which is not necessarily the best checkpoint written above.
        fold_histories.append(history)
        fold_metrics.append(model.evaluate(val_generator))

# Report the per-fold (loss, accuracy) pairs and their means
print("\nCross-validation results:")
for fold_number, (loss, acc) in enumerate(fold_metrics, start=1):
    print(f"Fold {fold_number}: Loss = {loss:.4f}, Accuracy = {acc:.4f}")

# Index 0 of each entry is the loss, index 1 the accuracy.
fold_losses = [entry[0] for entry in fold_metrics]
fold_accuracies = [entry[1] for entry in fold_metrics]
mean_loss = np.mean(fold_losses)
mean_acc = np.mean(fold_accuracies)
print(f"\nMean across folds: Loss = {mean_loss:.4f}, Accuracy = {mean_acc:.4f}")
###########################################################################################
# Per-fold accumulators filled by the evaluation loop: classification
# reports, confusion matrices and metric dictionaries (one entry per fold)
all_reports, all_confusion_matrices, all_metrics = [], [], []

# Function to evaluate a model on a generator and compute per-class metrics
def evaluate_model(model, generator, fold_idx):
    """Predict on `generator` and compute classification metrics.

    Args:
        model: Model exposing ``predict``; its output is arg-maxed over
            axis 1 to obtain class indices.
        generator: Data iterator exposing ``classes`` (true class indices)
            and ``class_indices`` (class name -> index). It must not shuffle,
            otherwise predictions and ``classes`` do not line up.
        fold_idx: Unused; kept for call-site compatibility.

    Returns:
        Tuple ``(metrics, report, cm)`` where ``metrics`` is a dict with
        keys 'accuracy', 'precision', 'recall', 'f1_score', 'specificity',
        'support' and 'confusion_matrix'; ``report`` is the dict form of
        ``classification_report``; ``cm`` is the confusion matrix.
    """
    # Get true labels and predictions
    y_true = np.asarray(generator.classes)
    y_pred = model.predict(generator)
    y_pred_classes = np.argmax(y_pred, axis=1)

    class_names = list(generator.class_indices.keys())
    n_classes = len(class_names)
    # Pin every metric to the full label set so the arrays always have one
    # entry per class, even if a class is absent from both the labels and
    # the predictions of this fold.
    label_ids = np.arange(n_classes)

    # Classification report
    report = classification_report(y_true, y_pred_classes, labels=label_ids,
                                   target_names=class_names, output_dict=True)

    # Per-class precision / recall / F1
    precision, recall, f1, _ = precision_recall_fscore_support(
        y_true, y_pred_classes, labels=label_ids, average=None)
    support = np.bincount(y_true, minlength=n_classes)

    # Specificity per class: TN / (TN + FP)
    cm = confusion_matrix(y_true, y_pred_classes, labels=label_ids)
    true_pos = np.diag(cm)
    false_pos = cm.sum(axis=0) - true_pos
    false_neg = cm.sum(axis=1) - true_pos
    true_neg = cm.sum() - true_pos - false_pos - false_neg
    negatives = true_neg + false_pos
    # A class with no negative samples gets specificity 0.0 instead of a
    # division-by-zero NaN.
    specificity = np.divide(true_neg, negatives,
                            out=np.zeros(n_classes, dtype=float),
                            where=negatives > 0).tolist()

    # Overall accuracy computed directly, so it does not depend on which
    # summary key the report exposes.
    accuracy = float(np.mean(y_true == y_pred_classes))

    metrics = {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1_score': f1,
        'specificity': specificity,
        'support': support,
        'confusion_matrix': cm
    }

    return metrics, report, cm


# Process each fold: reload the best checkpoint, rebuild that fold's
# validation set on disk, compute metrics and plot the confusion matrix.

# The seeded KFold yields the same splits as during training; compute them
# once instead of on every iteration.
fold_splits = list(kfold.split(image_paths))

for fold_idx in range(Folds):
    print(f"\nEvaluating Fold {fold_idx + 1}")
    
    # Load the saved model for this fold
    model = load_model(f"model_inception_mfr_attention_fold{fold_idx}.keras", compile=False)
    model.compile(optimizer=Adam(learning_rate=0.0001), 
                  loss='categorical_crossentropy', 
                  metrics=['accuracy'])
    
    # Get the original fold indices
    _, val_idx = fold_splits[fold_idx]
    
    # Recreate validation data for this fold; the context manager removes the
    # temporary directory even if evaluation raises.
    with tempfile.TemporaryDirectory() as temp_dir:
        val_folder = os.path.join(temp_dir, 'val')
        os.makedirs(val_folder, exist_ok=True)
        
        for class_name in CLASSES:
            os.makedirs(os.path.join(val_folder, class_name), exist_ok=True)
        
        # Copy validation images
        for idx in val_idx:
            src = image_paths[idx]
            class_name = CLASSES[labels[idx]]
            dst = os.path.join(val_folder, class_name, os.path.basename(src))
            shutil.copy(src, dst)
        
        # Create validation generator (unshuffled so predictions align with
        # val_generator.classes)
        val_datagen = ImageDataGenerator(rescale=1./255)
        val_generator = val_datagen.flow_from_directory(
            val_folder,
            target_size=(224, 224),
            batch_size=20,
            color_mode='rgb',
            shuffle=False,
            class_mode='categorical'
        )
        
        # Evaluate the model. The result gets its own name so the
        # `fold_metrics` list filled during training is not overwritten.
        fold_result, report, cm = evaluate_model(model, val_generator, fold_idx)

    all_reports.append(report)
    all_confusion_matrices.append(cm)
    all_metrics.append(fold_result)
    
    # Plot confusion matrix.
    # NOTE(review): tick labels come from CLASSES while the matrix follows
    # the generator's class order; this assumes the two orders match.
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
                xticklabels=CLASSES, 
                yticklabels=CLASSES)
    plt.title(f'Confusion Matrix - Fold {fold_idx + 1}')
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.xticks(rotation=45)
    plt.yticks(rotation=45)
    plt.tight_layout()
    plt.savefig(f'confusion_matrix_fold_{fold_idx+1}.png', dpi=300, bbox_inches='tight')
    plt.show()
    # Free the figure so figures do not accumulate across folds.
    plt.close()
    
    # Print metrics for this fold
    print(f"\nMetrics for Fold {fold_idx + 1}:")
    print(f"Accuracy: {fold_result['accuracy']:.4f}")
    print("Class-wise metrics:")
    for i, class_name in enumerate(CLASSES):
        print(f"{class_name}:")
        print(f"  Precision: {fold_result['precision'][i]:.4f}")
        print(f"  Recall: {fold_result['recall'][i]:.4f}")
        print(f"  F1-score: {fold_result['f1_score'][i]:.4f}")
        print(f"  Specificity: {fold_result['specificity'][i]:.4f}")

# Calculate average metrics across folds (per-class arrays are averaged
# element-wise over the folds)
avg_metrics = {
    'accuracy': np.mean([m['accuracy'] for m in all_metrics]),
    'precision': np.mean([m['precision'] for m in all_metrics], axis=0),
    'recall': np.mean([m['recall'] for m in all_metrics], axis=0),
    'f1_score': np.mean([m['f1_score'] for m in all_metrics], axis=0),
    'specificity': np.mean([m['specificity'] for m in all_metrics], axis=0),
    # Per-class support differs between folds because the KFold split is not
    # stratified, so report the total number of validation samples per class
    # over all folds rather than the first fold's counts.
    'support': np.sum([m['support'] for m in all_metrics], axis=0)
}

# Calculate macro-averaged metrics (unweighted mean over classes)
macro_precision = np.mean(avg_metrics['precision'])
macro_recall = np.mean(avg_metrics['recall'])
macro_f1 = np.mean(avg_metrics['f1_score'])

# Write every fold's metrics, the class-wise averages and the overall
# averages to separate sheets of one Excel workbook
with pd.ExcelWriter('cross_validation_metrics.xlsx') as writer:
    # (column header, metric key) pairs shared by the per-class sheets
    per_class_columns = (
        ('Precision', 'precision'),
        ('Recall', 'recall'),
        ('F1-score', 'f1_score'),
        ('Specificity', 'specificity'),
    )

    # One sheet per fold
    for fold_idx in range(Folds):
        metrics = all_metrics[fold_idx]
        fold_table = {'Class': CLASSES}
        for header, key in per_class_columns:
            fold_table[header] = metrics[key]
        fold_table['Support'] = metrics['support']
        fold_df = pd.DataFrame(fold_table)
        fold_df.to_excel(writer, sheet_name=f'Fold_{fold_idx+1}', index=False)

    # Class-wise averages across folds
    avg_table = {'Class': CLASSES}
    for header, key in per_class_columns:
        avg_table[f'Avg_{header}'] = avg_metrics[key]
    avg_table['Support'] = avg_metrics['support']
    avg_df = pd.DataFrame(avg_table)
    avg_df.to_excel(writer, sheet_name='Average_Metrics', index=False)

    # Overall accuracy and macro-averaged scores
    overall_rows = [
        ('Accuracy', avg_metrics['accuracy']),
        ('Precision (Macro)', macro_precision),
        ('Recall (Macro)', macro_recall),
        ('F1-score (Macro)', macro_f1),
    ]
    overall_df = pd.DataFrame({
        'Metric': [name for name, _ in overall_rows],
        'Average': [value for _, value in overall_rows]
    })
    overall_df.to_excel(writer, sheet_name='Overall', index=False)

# Console summary of the fold-averaged metrics
print("\nAverage Metrics Across All Folds:")
overall_lines = (
    ("Accuracy", avg_metrics['accuracy']),
    ("Macro Precision", macro_precision),
    ("Macro Recall", macro_recall),
    ("Macro F1-score", macro_f1),
)
for title, value in overall_lines:
    print(f"{title}: {value:.4f}")

print("Class-wise averages:")
class_metric_keys = (
    ("Precision", 'precision'),
    ("Recall", 'recall'),
    ("F1-score", 'f1_score'),
    ("Specificity", 'specificity'),
)
for i, class_name in enumerate(CLASSES):
    print(f"{class_name}:")
    for title, key in class_metric_keys:
        print(f"  {title}: {avg_metrics[key][i]:.4f}")