In [None]:
import os
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models
from tensorflow.keras.applications import EfficientNetB3, ResNet152V2, Xception
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, CSVLogger
from tensorflow.keras.utils import to_categorical
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import (
    confusion_matrix, classification_report,
    accuracy_score, precision_recall_fscore_support,
    roc_auc_score, cohen_kappa_score
)
import matplotlib.pyplot as plt
import seaborn as sns
import cv2

In [None]:
# Seed the NumPy and TensorFlow global RNGs so data shuffling, augmentation
# draws and weight initialisation repeat from run to run.
np.random.seed(42)
tf.random.set_seed(42)

# --- Hyperparameters --------------------------------------------------------
IMG_SIZE = (224, 224)            # dsize handed to cv2.resize in load_dataset
INPUT_SHAPE = (224, 224, 3)      # Keras model input shape (H, W, channels)
NUM_CLASSES = 4                  # width of the softmax output layer
BATCH_SIZE = 32
EPOCHS = 30
LEARNING_RATE = 0.001            # initial Adam learning rate
DROPOUT_RATE = 0.45              # dropout before the classification layer
SENET_REDUCTION_RATIO = 16       # bottleneck divisor inside SENetBlock

# Root folder with one sub-directory per class. "-" is presumably a
# placeholder -- set it to the real dataset root before running.
DATA_DIR = "-"

# --- Echo the configuration -------------------------------------------------
print("=" * 80)
print("WEATHER CLASSIFICATION - EXACT PAPER IMPLEMENTATION")
print("=" * 80)
for _setting, _value in (
    ("Image Size", IMG_SIZE),
    ("Batch Size", BATCH_SIZE),
    ("Epochs", EPOCHS),
    ("Learning Rate", LEARNING_RATE),
    ("Dropout Rate", DROPOUT_RATE),
    ("SENet Reduction Ratio", SENET_REDUCTION_RATIO),
):
    print(f"{_setting}: {_value}")
print("=" * 80)
# Keep the notebook namespace clean of the loop temporaries.
del _setting, _value


In [None]:
# ============================================================================
# 1. DATA LOADING AND PREPROCESSING
# ============================================================================

def load_dataset(data_dir):
    """Load every image under ``data_dir`` into memory.

    ``data_dir`` must contain one sub-directory per class. Every
    ``.png``/``.jpg``/``.jpeg`` file inside a class directory is read with
    OpenCV, converted BGR -> RGB, resized to ``IMG_SIZE`` and scaled to
    [0, 1]. Files OpenCV cannot decode are skipped.

    Args:
        data_dir: Root dataset directory.

    Returns:
        Tuple ``(X, y, class_names)``:
            X: float32 array, shape (num_images, H, W, 3), values in [0, 1].
            y: int32 array of integer labels; label ``i`` is ``class_names[i]``.
            class_names: Sorted list of the class sub-directory names.

    Raises:
        ValueError: If no readable image was found.
    """
    # Only sub-directories are classes. Filtering *before* enumerating keeps
    # labels contiguous (0..K-1) and aligned with class_names even when the
    # root also contains stray files (README, .DS_Store, ...).
    class_names = sorted(
        entry for entry in os.listdir(data_dir)
        if os.path.isdir(os.path.join(data_dir, entry))
    )

    print(f"\nLoading dataset from: {data_dir}")
    print(f"Classes: {class_names}")

    X, y = [], []
    for label, class_name in enumerate(class_names):
        class_path = os.path.join(data_dir, class_name)

        # Sorted because os.listdir order is filesystem-dependent; a stable
        # sample order keeps seeded train/val/test splits reproducible.
        image_files = sorted(
            f for f in os.listdir(class_path)
            if f.lower().endswith(('.png', '.jpg', '.jpeg'))
        )

        loaded = 0
        for img_file in image_files:
            img = cv2.imread(os.path.join(class_path, img_file))
            if img is None:
                # Unreadable / corrupt file: skip it rather than abort.
                continue

            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img = cv2.resize(img, IMG_SIZE)
            X.append(img.astype(np.float32) / 255.0)
            y.append(label)
            loaded += 1

        # Report what was actually loaded, not just what was listed.
        skipped = len(image_files) - loaded
        note = f" ({skipped} unreadable skipped)" if skipped else ""
        print(f"  {class_name}: {loaded} images{note}")

    if not X:
        raise ValueError(f"No readable images found under {data_dir!r}")

    X = np.array(X, dtype=np.float32)
    y = np.array(y, dtype=np.int32)

    print(f"\nTotal images loaded: {len(X)}")
    print(f"Image shape: {X[0].shape}")

    return X, y, class_names


def create_data_generators():

    train_datagen = ImageDataGenerator(
        rotation_range=20,
        width_shift_range=0.2,
        height_shift_range=0.2,
        horizontal_flip=True,
        vertical_flip=True,
        zoom_range=0.2,
        brightness_range=[0.8, 1.2],
        fill_mode='nearest'
    )


    val_test_datagen = ImageDataGenerator()

    return train_datagen, val_test_datagen


In [None]:
# ============================================================================
# 2. MODEL COMPONENTS
# ============================================================================

class SENetBlock(layers.Layer):
    """Squeeze-and-excitation channel gating.

    Squeezes each channel of a (batch, H, W, C) feature map to one value with
    global average pooling, passes the C values through a two-layer
    bottleneck (ReLU then sigmoid, no biases) and multiplies every channel of
    the input by its sigmoid gate. The output has the same shape as the input.

    Args:
        reduction_ratio: Bottleneck divisor; the hidden layer has
            ``max(1, C // reduction_ratio)`` units.
    """

    def __init__(self, reduction_ratio=16, **kwargs):
        super(SENetBlock, self).__init__(**kwargs)
        self.reduction_ratio = reduction_ratio

    def build(self, input_shape):
        channels = input_shape[-1]
        # Clamp to at least one unit: when C < reduction_ratio the integer
        # division would otherwise request a zero-width Dense layer.
        hidden_units = max(1, channels // self.reduction_ratio)

        # Squeeze: (B, H, W, C) -> (B, C).
        self.gap = layers.GlobalAveragePooling2D()

        # Excitation: bottleneck MLP producing per-channel gates in (0, 1).
        self.fc1 = layers.Dense(
            hidden_units,
            activation='relu',
            use_bias=False,
            name='se_fc1'
        )
        self.fc2 = layers.Dense(
            channels,
            activation='sigmoid',
            use_bias=False,
            name='se_fc2'
        )

        super(SENetBlock, self).build(input_shape)

    def call(self, inputs):
        z = self.gap(inputs)  # (B, C)

        s = self.fc1(z)       # (B, hidden_units)
        s = self.fc2(s)       # (B, C)

        # Reshape gates to (B, 1, 1, C) so they broadcast over H and W.
        s = tf.reshape(s, [-1, 1, 1, tf.shape(inputs)[-1]])

        return inputs * s

    def get_config(self):
        config = super(SENetBlock, self).get_config()
        config.update({'reduction_ratio': self.reduction_ratio})
        return config


class AttentionBlock(layers.Layer):
    """Residual self-attention across the spatial positions of a feature map.

    Query, key and value are 1x1-conv projections that keep the channel count.
    Every one of the N = H*W positions attends to every other via scaled
    dot-product attention (scale 1/sqrt(C)). The attended features are added
    back onto the input, so the output shape equals the input shape.
    """

    def __init__(self, **kwargs):
        super(AttentionBlock, self).__init__(**kwargs)

    def build(self, input_shape):
        n_channels = input_shape[-1]

        def pointwise(layer_name):
            # 1x1 convolution that preserves the channel count.
            return layers.Conv2D(n_channels, 1, padding='same', name=layer_name)

        # Created in query -> key -> value order so weight ordering is stable.
        self.query_conv = pointwise('attention_query')
        self.key_conv = pointwise('attention_key')
        self.value_conv = pointwise('attention_value')

        super(AttentionBlock, self).build(input_shape)

    def call(self, inputs):
        dims = tf.shape(inputs)
        batch, h, w, c = dims[0], dims[1], dims[2], dims[3]
        flat_shape = [batch, -1, c]

        # Flatten the spatial grid: (B, H, W, C) -> (B, N, C).
        q = tf.reshape(self.query_conv(inputs), flat_shape)
        k = tf.reshape(self.key_conv(inputs), flat_shape)
        v = tf.reshape(self.value_conv(inputs), flat_shape)

        # (B, N, N) affinities, scaled by sqrt(C) and normalised per query row.
        scale = tf.sqrt(tf.cast(c, tf.float32))
        scores = tf.matmul(q, k, transpose_b=True) / scale
        weights = tf.nn.softmax(scores, axis=-1)

        # Weighted sum of values, restored to the spatial layout.
        attended = tf.reshape(tf.matmul(weights, v), [batch, h, w, c])

        # Residual connection.
        return inputs + attended

    def get_config(self):
        return super(AttentionBlock, self).get_config()

In [None]:
# ============================================================================
# 3. MODEL ARCHITECTURES
# ============================================================================

def _build_frozen_backbone_classifier(backbone_fn, input_scale, input_offset, model_name):
    """Wrap a frozen ImageNet backbone in the shared baseline classification head.

    Architecture::

        Input (pixels in [0, 1]) -> Rescaling(input_scale, input_offset)
            -> backbone (frozen, global-average-pooled)
            -> BatchNorm -> Dense(256, relu) -> Dropout(DROPOUT_RATE)
            -> Dense(NUM_CLASSES, softmax)

    Args:
        backbone_fn: Keras Applications constructor, e.g. ``Xception``.
        input_scale: Multiplier applied to the [0, 1] input images.
        input_offset: Offset added after scaling. Together with
            ``input_scale`` this maps the data pipeline's [0, 1] images into
            the pixel range the backbone's ImageNet weights expect.
        model_name: Name given to the returned model.

    Returns:
        An uncompiled ``keras.Model``.
    """
    base_model = backbone_fn(
        include_top=False,
        weights='imagenet',
        input_shape=INPUT_SHAPE,
        pooling='avg'
    )

    # Transfer learning: only the new head is trained.
    base_model.trainable = False

    inputs = layers.Input(shape=INPUT_SHAPE)
    # NOTE: layers.Rescaling is available as tf.keras.layers.Rescaling from
    # TF 2.6 onward.
    x = layers.Rescaling(input_scale, offset=input_offset)(inputs)
    # training=False keeps the backbone's BatchNorm layers in inference mode.
    x = base_model(x, training=False)
    x = layers.BatchNormalization()(x)
    x = layers.Dense(256, activation='relu')(x)
    x = layers.Dropout(DROPOUT_RATE)(x)
    outputs = layers.Dense(NUM_CLASSES, activation='softmax')(x)

    return models.Model(inputs, outputs, name=model_name)


def create_baseline_efficientnetb3():
    """Frozen EfficientNetB3 baseline.

    Keras EfficientNet models rescale/normalise internally and expect raw
    [0, 255] pixels, so the [0, 1] inputs are multiplied by 255 first.
    NOTE(review): the backbone is frozen despite the '_FineTuned' name; the
    name is kept because it is used in output file names.
    """
    return _build_frozen_backbone_classifier(
        EfficientNetB3, 255.0, 0.0, 'EfficientNetB3_FineTuned'
    )


def create_baseline_resnet152v2():
    """Frozen ResNet152V2 baseline.

    ResNetV2 ImageNet weights expect inputs in [-1, 1] (the 'tf'
    preprocessing mode), so [0, 1] inputs are mapped with ``2x - 1``.
    """
    return _build_frozen_backbone_classifier(
        ResNet152V2, 2.0, -1.0, 'ResNet152V2_FineTuned'
    )


def create_baseline_xception():
    """Frozen Xception baseline.

    Xception ImageNet weights expect inputs in [-1, 1], so [0, 1] inputs are
    mapped with ``2x - 1``.
    """
    return _build_frozen_backbone_classifier(
        Xception, 2.0, -1.0, 'Xception_FineTuned'
    )


def create_proposed_model():
    """Build the proposed Xception + SE + spatial-attention classifier.

    Architecture::

        Input (pixels in [0, 1]) -> Rescaling to [-1, 1]
            -> Xception (frozen, no pooling: 4-D feature map)
            -> BatchNorm -> SENetBlock -> AttentionBlock
            -> GlobalAveragePooling2D -> Dense(256, relu)
            -> Dropout(DROPOUT_RATE) -> Dense(NUM_CLASSES, softmax)

    Returns:
        An uncompiled ``keras.Model`` named 'Xception_SENet_Attention'.
    """
    # No pooling: SE and attention operate on the spatial feature map.
    base_model = Xception(
        include_top=False,
        weights='imagenet',
        input_shape=INPUT_SHAPE
    )

    # Transfer learning: only the layers added below are trained.
    base_model.trainable = False

    inputs = layers.Input(shape=INPUT_SHAPE, name='input')

    # The data pipeline yields [0, 1] images, while Xception's ImageNet
    # weights expect [-1, 1] inputs; map with 2x - 1.
    # (tf.keras.layers.Rescaling is available from TF 2.6.)
    x = layers.Rescaling(2.0, offset=-1.0, name='xception_input_scaling')(inputs)

    # training=False keeps the backbone's BatchNorm layers in inference mode.
    x = base_model(x, training=False)

    x = layers.BatchNormalization(name='bn_after_xception')(x)

    # Channel recalibration, then spatial self-attention.
    x = SENetBlock(reduction_ratio=SENET_REDUCTION_RATIO, name='senet_block')(x)
    x = AttentionBlock(name='attention_block')(x)

    x = layers.GlobalAveragePooling2D(name='global_avg_pool')(x)
    x = layers.Dense(256, activation='relu', name='dense_256')(x)
    x = layers.Dropout(DROPOUT_RATE, name='dropout')(x)
    outputs = layers.Dense(NUM_CLASSES, activation='softmax', name='output')(x)

    model = models.Model(
        inputs,
        outputs,
        name='Xception_SENet_Attention'
    )

    return model

In [None]:
# ============================================================================
# 4. TRAINING FUNCTION
# ============================================================================

def train_model(model, X_train, y_train, X_val, y_val, model_name, class_weights=None):
    """Compile ``model`` and train it on augmented batches.

    The model is compiled with Adam at ``LEARNING_RATE`` and categorical
    cross-entropy, so ``y_train``/``y_val`` are expected to be one-hot.
    Accuracy, precision and recall are tracked.

    Callbacks:
        * checkpoint of the best ``val_accuracy`` model to
          ``{model_name}_best.h5``;
        * learning-rate halving after a ``val_loss`` plateau;
        * per-epoch CSV log in ``{model_name}_training_log.csv``.

    Args:
        model: Uncompiled Keras model.
        X_train, y_train: Training images and labels (augmented on the fly).
        X_val, y_val: Validation images and labels (not augmented).
        model_name: Used in printed headers and output file names.
        class_weights: Optional mapping forwarded to ``fit(class_weight=...)``.

    Returns:
        The ``History`` object returned by ``model.fit``.
    """
    separator = '=' * 80
    print(f"\n{separator}")
    print(f"Training {model_name}")
    print(separator)

    optimizer = keras.optimizers.Adam(learning_rate=LEARNING_RATE)
    model.compile(
        optimizer=optimizer,
        loss='categorical_crossentropy',
        metrics=[
            'accuracy',
            keras.metrics.Precision(name='precision'),
            keras.metrics.Recall(name='recall'),
        ],
    )

    best_checkpoint = ModelCheckpoint(
        f'{model_name}_best.h5',
        monitor='val_accuracy',
        save_best_only=True,
        mode='max',
        verbose=1,
    )
    # Halve the LR after 3 epochs without val_loss improvement (floor 1e-7).
    plateau_scheduler = ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.5,
        patience=3,
        min_lr=1e-7,
        verbose=1,
    )
    epoch_logger = CSVLogger(f'{model_name}_training_log.csv')

    augmenting_gen, plain_gen = create_data_generators()
    train_batches = augmenting_gen.flow(
        X_train, y_train, batch_size=BATCH_SIZE, shuffle=True
    )
    # Validation order is fixed; no augmentation is applied.
    val_batches = plain_gen.flow(
        X_val, y_val, batch_size=BATCH_SIZE, shuffle=False
    )

    return model.fit(
        train_batches,
        validation_data=val_batches,
        epochs=EPOCHS,
        callbacks=[best_checkpoint, plateau_scheduler, epoch_logger],
        class_weight=class_weights,
        verbose=1,
    )

In [None]:
# ============================================================================
# 5. EVALUATION FUNCTION
# ============================================================================

def evaluate_model(model, X_test, y_test, class_names, model_name):
    """Evaluate a trained classifier on the held-out test set.

    Prints the following:
        * overall accuracy and Cohen's kappa;
        * macro-averaged precision/recall/F1;
        * a per-class table and sklearn's classification report.

    It also draws the confusion matrix (saved as
    ``{model_name}_confusion_matrix.png``) and computes macro one-vs-rest
    AUC-ROC when possible.

    Args:
        model: Trained Keras model returning class probabilities.
        X_test: Test images.
        y_test: One-hot test labels, shape (num_samples, num_classes).
        class_names: Display names; index ``i`` names label ``i``.
        model_name: Used in printed headers, plot titles and file names.

    Returns:
        dict with keys:
            * 'accuracy', 'kappa';
            * 'precision', 'recall', 'f1_score' (unweighted means over classes);
            * 'confusion_matrix', 'y_pred', 'y_pred_probs', 'y_true';
            * 'auc_roc' (None if it could not be computed).
    """
    print(f"\n{'='*80}")
    print(f"Evaluating {model_name} on Test Set")
    print(f"{'='*80}")

    y_pred_probs = model.predict(X_test, verbose=0)
    y_pred = np.argmax(y_pred_probs, axis=1)
    y_true = np.argmax(y_test, axis=1)

    # Pin the label set so every per-class array, the report and the
    # confusion matrix line up with class_names even when some class is
    # absent from both y_true and y_pred. Without this, sklearn returns
    # shorter arrays and the per-class loop below would index past the end.
    label_ids = list(range(len(class_names)))

    accuracy = accuracy_score(y_true, y_pred)
    kappa = cohen_kappa_score(y_true, y_pred)

    # zero_division=0 yields the same values sklearn's default uses, minus
    # the warning.
    precision, recall, f1, support = precision_recall_fscore_support(
        y_true, y_pred, labels=label_ids, average=None, zero_division=0
    )

    avg_precision = np.mean(precision)
    avg_recall = np.mean(recall)
    avg_f1 = np.mean(f1)

    print(f"\nOverall Accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")
    print(f"Cohen's Kappa:    {kappa:.4f}")
    print(f"\nAverage Metrics:")
    print(f"  Precision: {avg_precision:.4f}")
    print(f"  Recall:    {avg_recall:.4f}")
    print(f"  F1-Score:  {avg_f1:.4f}")

    print(f"\nPer-Class Performance:")
    print(f"{'Class':<12} {'Precision':<12} {'Recall':<12} {'F1-Score':<12} {'Support':<12}")
    print("-" * 60)
    for i, class_name in enumerate(class_names):
        print(f"{class_name:<12} {precision[i]:<12.4f} {recall[i]:<12.4f} "
              f"{f1[i]:<12.4f} {support[i]:<12.0f}")

    cm = confusion_matrix(y_true, y_pred, labels=label_ids)

    plt.figure(figsize=(8, 7))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names,
                yticklabels=class_names,
                cbar_kws={'label': 'Count'})
    plt.title(f'Confusion Matrix - {model_name}',
              fontsize=14, fontweight='bold', pad=20)
    plt.xlabel('Predicted Label', fontsize=12)
    plt.ylabel('True Label', fontsize=12)
    plt.tight_layout()
    plt.savefig(f'{model_name}_confusion_matrix.png', dpi=300, bbox_inches='tight')
    plt.show()

    print(f"\nDetailed Classification Report:")
    print(classification_report(y_true, y_pred, labels=label_ids,
                                target_names=class_names, digits=4,
                                zero_division=0))

    try:
        auc_roc = roc_auc_score(y_test, y_pred_probs, multi_class='ovr', average='macro')
        print(f"\nAUC-ROC (macro): {auc_roc:.4f}")
    except ValueError as exc:
        # sklearn raises ValueError when AUC is undefined, e.g. a class with
        # no positive samples in y_test. Report why instead of hiding it;
        # any other exception is a real bug and should propagate.
        print(f"\nAUC-ROC: Could not compute ({exc})")
        auc_roc = None

    return {
        'accuracy': accuracy,
        'kappa': kappa,
        'precision': avg_precision,
        'recall': avg_recall,
        'f1_score': avg_f1,
        'confusion_matrix': cm,
        'y_pred': y_pred,
        'y_pred_probs': y_pred_probs,
        'y_true': y_true,
        'auc_roc': auc_roc
    }

In [None]:
# ============================================================================
# 5A. CALIBRATION METRICS - ECE (Expected Calibration Error)
# ============================================================================

def compute_ece(y_true, y_pred_probs, n_bins=10):
    """Expected Calibration Error (ECE) of top-1 predictions.

    Top-1 confidences are grouped into ``n_bins`` equal-width bins over
    [0, 1]; each bin is half-open ``(lower, upper]``. ECE is the
    sample-weighted sum of |mean confidence - accuracy| over non-empty bins.

    Args:
        y_true: Integer class labels, shape (num_samples,).
        y_pred_probs: Predicted probabilities, shape (num_samples, num_classes).
        n_bins: Number of confidence bins.

    Returns:
        ``(ece, bin_data)``. ``bin_data`` maps 'accuracies', 'confidences',
        'counts' and 'gaps' to lists with one entry per bin; empty bins
        contribute zeros.
    """
    top_conf = np.max(y_pred_probs, axis=1)
    hits = np.argmax(y_pred_probs, axis=1) == y_true
    total = len(y_true)

    edges = np.linspace(0, 1, n_bins + 1)
    fields = ('accuracies', 'confidences', 'counts', 'gaps')
    stats = {field: [] for field in fields}

    ece = 0.0
    for lo, hi in zip(edges[:-1], edges[1:]):
        members = (top_conf > lo) & (top_conf <= hi)
        count = np.sum(members)

        if count > 0:
            acc = np.mean(hits[members])
            conf = np.mean(top_conf[members])
            gap = np.abs(conf - acc)
            ece += (count / total) * gap
            row = (acc, conf, count, gap)
        else:
            row = (0, 0, 0, 0)

        for field, value in zip(fields, row):
            stats[field].append(value)

    return ece, stats


def plot_calibration_curve(y_true, y_pred_probs, model_name, n_bins=10, save_path=None):
    """Draw a reliability diagram next to a histogram of top-1 confidences.

    Left panel: per-bin accuracy as bars (empty bins drawn at 0) against the
    diagonal of perfect calibration, titled with the ECE. Right panel: a
    20-bin histogram of top-1 confidences. Summary calibration statistics are
    printed after the figure is shown.

    Args:
        y_true: Integer class labels.
        y_pred_probs: Predicted probabilities, shape (num_samples, num_classes).
        model_name: Used in the figure title.
        n_bins: Number of confidence bins for the ECE / reliability diagram.
        save_path: If given, the figure is also saved there at 300 dpi.

    Returns:
        The expected calibration error.
    """
    ece, bin_data = compute_ece(y_true, y_pred_probs, n_bins)
    top_conf = np.max(y_pred_probs, axis=1)

    fig, (reliability_ax, spread_ax) = plt.subplots(1, 2, figsize=(14, 5))

    # Reliability diagram: one bar centred in each equal-width bin.
    centres = np.linspace(0, 1, n_bins + 1)[:-1] + 0.5 / n_bins
    reliability_ax.plot([0, 1], [0, 1], 'k--', label='Perfect Calibration', linewidth=2)
    reliability_ax.bar(centres, bin_data['accuracies'], width=1.0/n_bins,
                       alpha=0.7, edgecolor='black', label='Model Output')
    reliability_ax.set_xlabel('Confidence', fontsize=12, fontweight='bold')
    reliability_ax.set_ylabel('Accuracy', fontsize=12, fontweight='bold')
    reliability_ax.set_title(f'Calibration Curve\nECE = {ece:.4f}',
                             fontsize=13, fontweight='bold')
    reliability_ax.legend(loc='upper left')
    reliability_ax.grid(True, alpha=0.3)
    reliability_ax.set_xlim([0, 1])
    reliability_ax.set_ylim([0, 1])

    # Distribution of top-1 confidences.
    spread_ax.hist(top_conf, bins=20, edgecolor='black', alpha=0.7, color='steelblue')
    spread_ax.set_xlabel('Confidence', fontsize=12, fontweight='bold')
    spread_ax.set_ylabel('Count', fontsize=12, fontweight='bold')
    spread_ax.set_title(f'Confidence Distribution\nMean: {np.mean(top_conf):.3f}',
                        fontsize=13, fontweight='bold')
    spread_ax.grid(True, alpha=0.3, axis='y')
    spread_ax.set_xlim([0, 1])

    fig.suptitle(f'Model Calibration - {model_name}',
                 fontsize=14, fontweight='bold', y=1.02)
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

    print(f"\nCalibration Metrics:")
    for label, value in (
        ("ECE (Expected Calibration Error)", ece),
        ("Mean Confidence", np.mean(top_conf)),
        ("Max Confidence", np.max(top_conf)),
        ("Min Confidence", np.min(top_conf)),
    ):
        print(f"  {label}: {value:.4f}")

    return ece


# ============================================================================
# 5B. SOFTMAX CONFIDENCE PLOT
# ============================================================================

def plot_softmax_confidence(y_true, y_pred_probs, class_names, model_name, save_path=None):
    """Visualise top-1 softmax confidence for correct vs. incorrect predictions.

    Draws a 2x2 figure:
        * box plots of confidence for correct vs. incorrect predictions;
        * overlaid histograms of the same two groups;
        * confidence box plots per true class;
        * the empirical CDF of confidence for each group.

    Summary statistics are printed afterwards. Either group may be empty
    (e.g. with 100% test accuracy); its statistics are then reported as NaN
    instead of raising.

    Args:
        y_true: Integer class labels, shape (num_samples,).
        y_pred_probs: Predicted probabilities, shape (num_samples, num_classes).
        class_names: Display names indexed by label id.
        model_name: Used in the figure title.
        save_path: If given, the figure is also saved there at 300 dpi.
    """
    predictions = np.argmax(y_pred_probs, axis=1)
    confidences = np.max(y_pred_probs, axis=1)

    correct_mask = predictions == y_true
    correct_conf = confidences[correct_mask]
    incorrect_conf = confidences[~correct_mask]

    def _safe(reducer, values):
        # NaN for an empty group instead of a ValueError from np.min/np.max
        # (or a RuntimeWarning from np.mean) on a zero-size array.
        return float(reducer(values)) if values.size else float('nan')

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))

    # (0, 0) Box plot: correct vs. incorrect.
    ax1 = axes[0, 0]
    bp = ax1.boxplot([correct_conf, incorrect_conf],
                     patch_artist=True, showmeans=True)
    # Tick labels are set explicitly: boxplot's ``labels=`` argument is
    # deprecated in recent Matplotlib (renamed ``tick_labels``).
    ax1.set_xticks([1, 2])
    ax1.set_xticklabels(['Correct', 'Incorrect'])
    bp['boxes'][0].set_facecolor('lightgreen')
    bp['boxes'][1].set_facecolor('lightcoral')
    ax1.set_ylabel('Confidence Score', fontsize=11, fontweight='bold')
    ax1.set_title('Confidence Distribution: Correct vs Incorrect',
                  fontsize=12, fontweight='bold')
    ax1.grid(True, alpha=0.3, axis='y')
    ax1.set_ylim([0, 1.05])

    # Mean annotations only for non-empty groups (a NaN y-position is invalid).
    for x_pos, group in ((1, correct_conf), (2, incorrect_conf)):
        if group.size:
            mu = np.mean(group)
            ax1.text(x_pos, mu + 0.05, f'μ={mu:.3f}',
                     ha='center', fontweight='bold')

    # (0, 1) Overlaid histograms.
    ax2 = axes[0, 1]
    ax2.hist(correct_conf, bins=20, alpha=0.6, label='Correct',
             color='green', edgecolor='black')
    ax2.hist(incorrect_conf, bins=20, alpha=0.6, label='Incorrect',
             color='red', edgecolor='black')
    ax2.set_xlabel('Confidence Score', fontsize=11, fontweight='bold')
    ax2.set_ylabel('Count', fontsize=11, fontweight='bold')
    ax2.set_title('Confidence Histogram', fontsize=12, fontweight='bold')
    ax2.legend()
    ax2.grid(True, alpha=0.3, axis='y')

    # (1, 0) Confidence per true class.
    ax3 = axes[1, 0]
    class_confidences = [confidences[y_true == i] for i in range(len(class_names))]
    bp2 = ax3.boxplot(class_confidences, patch_artist=True, showmeans=True)
    ax3.set_xticks(range(1, len(class_names) + 1))
    ax3.set_xticklabels(class_names)
    colors = ['lightblue', 'lightgreen', 'lightyellow', 'lightcoral']
    for patch, color in zip(bp2['boxes'], colors):
        patch.set_facecolor(color)

    ax3.set_ylabel('Confidence Score', fontsize=11, fontweight='bold')
    ax3.set_xlabel('Class', fontsize=11, fontweight='bold')
    ax3.set_title('Confidence by True Class', fontsize=12, fontweight='bold')
    ax3.grid(True, alpha=0.3, axis='y')
    ax3.set_ylim([0, 1.05])
    plt.setp(ax3.xaxis.get_majorticklabels(), rotation=15, ha='right')

    # (1, 1) Empirical CDFs.
    ax4 = axes[1, 1]
    sorted_correct = np.sort(correct_conf)
    sorted_incorrect = np.sort(incorrect_conf)

    ax4.plot(sorted_correct, np.linspace(0, 1, len(sorted_correct)),
             label='Correct', color='green', linewidth=2)
    ax4.plot(sorted_incorrect, np.linspace(0, 1, len(sorted_incorrect)),
             label='Incorrect', color='red', linewidth=2)
    ax4.set_xlabel('Confidence Score', fontsize=11, fontweight='bold')
    ax4.set_ylabel('Cumulative Probability', fontsize=11, fontweight='bold')
    ax4.set_title('Cumulative Distribution Function', fontsize=12, fontweight='bold')
    ax4.legend()
    ax4.grid(True, alpha=0.3)
    ax4.set_xlim([0, 1])
    ax4.set_ylim([0, 1])

    fig.suptitle(f'Softmax Confidence Analysis - {model_name}',
                 fontsize=14, fontweight='bold', y=1.00)
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

    print(f"\nSoftmax Confidence Statistics:")
    print(f"  Correct Predictions:")
    print(f"    Mean: {_safe(np.mean, correct_conf):.4f}")
    print(f"    Std:  {_safe(np.std, correct_conf):.4f}")
    print(f"    Min:  {_safe(np.min, correct_conf):.4f}")
    print(f"  Incorrect Predictions:")
    print(f"    Mean: {_safe(np.mean, incorrect_conf):.4f}")
    print(f"    Std:  {_safe(np.std, incorrect_conf):.4f}")
    print(f"    Max:  {_safe(np.max, incorrect_conf):.4f}")


# ============================================================================
# 5C. ATTENTION HEATMAP VISUALIZATION
# ============================================================================

def extract_attention_heatmap(model, image):
    """Return a normalised 2-D attention map for a single image.

    The image is run through ``model`` up to its first top-level
    ``AttentionBlock``. The block's output (attended features plus its
    residual input) is averaged over channels, negatives are clamped to 0,
    and the result is scaled so its maximum is ~1.

    Args:
        model: Functional Keras model that may contain an ``AttentionBlock``.
        image: One image array without a batch dimension.

    Returns:
        Float array with the spatial size of the block's feature map, values
        in [0, 1]; ``None`` if the model has no top-level ``AttentionBlock``.
    """
    attention_layer = next(
        (layer for layer in model.layers if isinstance(layer, AttentionBlock)),
        None,
    )
    if attention_layer is None:
        print("Warning: No AttentionBlock found in model")
        return None

    # Sub-model that stops at the attention block's output.
    probe = keras.Model(inputs=model.input, outputs=attention_layer.output)
    feature_map = probe.predict(np.expand_dims(image, axis=0), verbose=0)[0]

    # Channel-average, keep positive activations, normalise (eps avoids /0).
    heat = np.maximum(np.mean(feature_map, axis=-1), 0)
    return heat / (np.max(heat) + 1e-8)


def plot_attention_overlay(model, X_samples, y_true, y_pred, class_names,
                           model_name, num_samples=6, save_path=None):
    """Show original image, attention heatmap and overlay for each sample.

    One row per sample with three panels:
        * the original image, titled with the true class;
        * the attention heatmap from ``extract_attention_heatmap``;
        * a 60/40 blend of image and heatmap, titled with the predicted class
          in green if correct and red if wrong.

    Rows whose attention cannot be extracted are left blank.

    Args:
        model: Model passed to ``extract_attention_heatmap``.
        X_samples: Images; assumed float RGB in [0, 1] (as from load_dataset).
        y_true, y_pred: Integer labels for ``X_samples``.
        class_names: Display names indexed by label id.
        model_name: Used in the figure title.
        num_samples: Maximum number of rows to draw.
        save_path: If given, the figure is also saved there at 300 dpi.
    """
    num_samples = min(num_samples, len(X_samples))
    if num_samples <= 0:
        # plt.subplots rejects zero rows, and there is nothing to draw.
        print("No samples available for attention overlay")
        return

    # squeeze=False keeps ``axes`` 2-D even when num_samples == 1.
    fig, axes = plt.subplots(num_samples, 3, figsize=(12, 4*num_samples),
                             squeeze=False)

    colorbar_added = False
    for i in range(num_samples):
        image = X_samples[i]
        true_class = class_names[y_true[i]]
        pred_class = class_names[y_pred[i]]

        attention_map = extract_attention_heatmap(model, image)

        if attention_map is None:
            print(f"Skipping sample {i+1}: Could not extract attention")
            # Blank the row instead of leaving empty framed axes.
            for ax in axes[i]:
                ax.axis('off')
            continue

        # Upsample the low-resolution map to the image size (cv2 takes (w, h)).
        attention_resized = cv2.resize(
            attention_map,
            (image.shape[1], image.shape[0]),
            interpolation=cv2.INTER_LINEAR
        )

        # OpenCV colormaps are BGR; convert to RGB for matplotlib / blending.
        heatmap_uint8 = np.uint8(255 * attention_resized)
        heatmap_colored = cv2.cvtColor(
            cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET),
            cv2.COLOR_BGR2RGB
        )

        img_uint8 = np.uint8(image * 255)
        overlay = cv2.addWeighted(img_uint8, 0.6, heatmap_colored, 0.4, 0)

        axes[i, 0].imshow(img_uint8)
        axes[i, 0].set_title(f'Original\nTrue: {true_class}',
                             fontsize=10, fontweight='bold')
        axes[i, 0].axis('off')

        im = axes[i, 1].imshow(attention_resized, cmap='jet', vmin=0, vmax=1)
        axes[i, 1].set_title('Attention Heatmap', fontsize=10, fontweight='bold')
        axes[i, 1].axis('off')

        axes[i, 2].imshow(overlay)
        color = 'green' if true_class == pred_class else 'red'
        axes[i, 2].set_title(f'Overlay\nPred: {pred_class}',
                             fontsize=10, fontweight='bold', color=color)
        axes[i, 2].axis('off')

        # One shared colorbar, attached to the first row actually drawn so
        # it still appears when earlier rows were skipped.
        if not colorbar_added:
            cbar = plt.colorbar(im, ax=axes[i, 1], fraction=0.046, pad=0.04)
            cbar.set_label('Attention\nIntensity', fontsize=8)
            cbar.set_ticks([0, 0.5, 1])
            cbar.set_ticklabels(['Low', 'Med', 'High'], fontsize=7)
            colorbar_added = True

    fig.suptitle(f'Attention Heatmap Visualization - {model_name}',
                 fontsize=14, fontweight='bold', y=0.995)
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()


# ============================================================================
# 5D. COMPREHENSIVE ANALYSIS FUNCTION
# ============================================================================

def comprehensive_analysis(model, X_test, y_test, class_names, model_name):
    """Run the full test-set analysis for one model and return its metrics.

    Steps:
        1. ``evaluate_model`` (metrics + confusion matrix).
        2. Calibration: reliability diagram and ECE.
        3. Softmax-confidence plots.
        4. For models that contain a top-level ``AttentionBlock`` only:
           attention overlays for up to 4 correctly and 2 incorrectly
           classified test images, sampled with the global NumPy RNG.

    Args:
        model: Trained Keras model.
        X_test: Test images.
        y_test: One-hot test labels.
        class_names: Display names indexed by label id.
        model_name: Used in headers, titles and output file names.

    Returns:
        The dict from ``evaluate_model`` with an added ``'ece'`` entry.
    """
    def _section(title):
        print(f"\n{'='*80}")
        print(title)
        print(f"{'='*80}")

    _section(f"COMPREHENSIVE ANALYSIS: {model_name}")

    results = evaluate_model(model, X_test, y_test, class_names, model_name)

    _section("CALIBRATION ANALYSIS")
    results['ece'] = plot_calibration_curve(
        results['y_true'],
        results['y_pred_probs'],
        model_name,
        n_bins=10,
        save_path=f'{model_name}_calibration.png'
    )

    _section("SOFTMAX CONFIDENCE ANALYSIS")
    plot_softmax_confidence(
        results['y_true'],
        results['y_pred_probs'],
        class_names,
        model_name,
        save_path=f'{model_name}_confidence.png'
    )

    _section("ATTENTION HEATMAP VISUALIZATION")

    if not any(isinstance(layer, AttentionBlock) for layer in model.layers):
        # Baseline models have no AttentionBlock: every sample would be
        # skipped and a blank figure saved, so don't try.
        print(f"{model_name} has no AttentionBlock - skipping attention visualization")
    else:
        correct_indices = np.flatnonzero(results['y_pred'] == results['y_true'])
        incorrect_indices = np.flatnonzero(results['y_pred'] != results['y_true'])

        # Up to 4 correct + 2 incorrect examples, drawn without replacement;
        # if a pool is smaller than requested, take all of it.
        selected_indices = []
        for pool, wanted in ((correct_indices, 4), (incorrect_indices, 2)):
            if len(pool) >= wanted:
                selected_indices.extend(np.random.choice(pool, wanted, replace=False))
            else:
                selected_indices.extend(pool)

        if selected_indices:
            selected = np.asarray(selected_indices, dtype=int)
            plot_attention_overlay(
                model,
                X_test[selected],
                results['y_true'][selected],
                results['y_pred'][selected],
                class_names,
                model_name,
                num_samples=len(selected),
                save_path=f'{model_name}_attention_overlay.png'
            )
        else:
            print("No samples available for attention visualization")

    _section("ANALYSIS COMPLETE")
    print(f"\nSummary for {model_name}:")
    print(f"  Accuracy:      {results['accuracy']:.4f}")
    print(f"  Cohen's Kappa: {results['kappa']:.4f}")
    print(f"  Precision:     {results['precision']:.4f}")
    print(f"  Recall:        {results['recall']:.4f}")
    print(f"  F1-Score:      {results['f1_score']:.4f}")
    print(f"  ECE:           {results['ece']:.4f}")
    # Explicit None check: a (legitimate) AUC of 0.0 is falsy.
    if results['auc_roc'] is not None:
        print(f"  AUC-ROC:       {results['auc_roc']:.4f}")

    return results

In [None]:
# ============================================================================
# 6. VISUALIZATION FUNCTIONS
# ============================================================================

def plot_training_history(history, model_name):
    """Plot train vs. validation curves for accuracy, loss, precision and recall.

    Draws a 2x2 grid (accuracy, loss / precision, recall) from
    ``history.history`` and saves it as ``{model_name}_training_history.png``.

    Args:
        history: Keras ``History`` whose ``history`` dict has the keys
            'accuracy', 'loss', 'precision', 'recall' and their 'val_'
            counterparts (as produced by ``train_model``).
        model_name: Used in the title and output file name.
    """
    # (train key, validation key, panel title, y-axis label), row-major order.
    panels = (
        ('accuracy', 'val_accuracy', 'Model Accuracy', 'Accuracy'),
        ('loss', 'val_loss', 'Model Loss', 'Loss'),
        ('precision', 'val_precision', 'Model Precision', 'Precision'),
        ('recall', 'val_recall', 'Model Recall', 'Recall'),
    )

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))

    for ax, (train_key, val_key, title, ylabel) in zip(axes.flat, panels):
        ax.plot(history.history[train_key], label='Train', color='blue')
        ax.plot(history.history[val_key], label='Validation', color='red')
        ax.set_title(title, fontweight='bold')
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.legend()
        ax.grid(True, alpha=0.3)

    fig.suptitle(f'Training History - {model_name}',
                 fontsize=16, fontweight='bold', y=1.00)
    fig.tight_layout()
    fig.savefig(f'{model_name}_training_history.png', dpi=300, bbox_inches='tight')
    plt.show()

In [None]:
# ============================================================================
# 7. MAIN EXECUTION
# ============================================================================

def main():

    print("\n" + "="*80)
    print("LOADING DATASET")
    print("="*80)

    # Load dataset