In [None]:
"""
imports all necessary libraries for data handling, deep learning, visualization, and xai (explainable ai).
 os: for file and directory operations
 numpy/pandas: for numerical and tabular data processing
 matplotlib/seaborn: for data visualization
 tensorflow/keras: for building and training deep learning models
 sklearn: for evaluation metrics
 lime/shap: for model interpretability (xai)
 cv2: for image processing
 warnings: to suppress unnecessary warnings
"""
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.applications import VGG16
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
from tensorflow.keras import backend as K
from sklearn.metrics import classification_report, confusion_matrix, roc_curve, auc, roc_auc_score
from sklearn.metrics import precision_recall_curve, average_precision_score, f1_score
import matplotlib.cm as cm

from tensorflow.keras.layers import (
    Conv2D, MaxPooling2D, Flatten, Dense, Dropout,
    BatchNormalization, GlobalAveragePooling2D, Activation,
    Input, Add, Lambda, SeparableConv2D, SpatialDropout2D
)
from tensorflow.keras.regularizers import l2
import tensorflow.keras.backend as K

# XAI Libraries
from lime import lime_image
from skimage.segmentation import mark_boundaries
import shap
import cv2
import warnings
warnings.filterwarnings('ignore')

"""
sets random seeds for numpy and tensorflow to ensure reproducibility.
np.random.seed(42): fixes numpy's random number generation
tf.random.set_seed(42): fixes tensorflow's random initialization
this helps in getting consistent results across multiple runs.
"""
# Seed TensorFlow's global generator and NumPy's global RNG with the same fixed
# value so that repeated runs start from the same random state.
# NOTE(review): some GPU kernels are nondeterministic, so bit-exact reruns are
# not guaranteed by these two calls alone — confirm if that matters.
tf.random.set_seed(42)
np.random.seed(42)

"""
checks if the dataset directory exists and falls back to a local path if needed.
first checks the kaggle path (/kaggle/input/...)
if not found, tries a local path ('./chest_xray')
raises an error if neither path exists.
"""
# Preferred dataset location (Kaggle input mount); a local fallback is tried below.
data_dir = '/kaggle/input/chest-xray-pneumonia/chest_xray'

# Report the TensorFlow build and enable on-demand GPU memory allocation.
print("TensorFlow Version:", tf.__version__)
gpus = tf.config.list_physical_devices('GPU')
print("Num GPUs Available: ", len(gpus))
if gpus:
    try:
        for gpu in gpus:
            # Memory growth can only be changed before a GPU is initialized;
            # skipping devices that already have it keeps a cell re-run harmless.
            if not tf.config.experimental.get_memory_growth(gpu):
                tf.config.experimental.set_memory_growth(gpu, True)
        logical_gpus = tf.config.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as e:
        # Raised when memory growth is changed after the GPUs were initialized.
        print(e)
else:
    print("No GPU detected. Training will run on CPU (which will be slow).")

# Resolve the dataset root: Kaggle path first, then './chest_xray'.
# Raise if neither exists.
if not os.path.isdir(data_dir):
    # Fallback for local execution if Kaggle path not found
    print(f"Kaggle directory {data_dir} not found. Trying local path './chest_xray'.")
    data_dir = './chest_xray'  # Example local path
    if not os.path.isdir(data_dir):
        raise FileNotFoundError(f"Dataset directory not found at: {data_dir} or the Kaggle path. Please ensure the path is correct and the dataset is downloaded/extracted.")

# Split directories; all three must exist under the resolved root.
train_dir = os.path.join(data_dir, 'train')
test_dir = os.path.join(data_dir, 'test')
val_dir = os.path.join(data_dir, 'val')

if not all(os.path.isdir(split_dir) for split_dir in (train_dir, test_dir, val_dir)):
    raise FileNotFoundError(f"Train, test, or val directory not found within {data_dir}. Check the dataset structure.")


2025-05-15 09:00:24.543173: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1747299624.724218      35 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1747299624.774283      35 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [None]:
# Report the TensorFlow build and GPU setup. Memory growth may already have been
# enabled by an earlier cell; calling set_memory_growth again after the GPUs
# were initialized would raise RuntimeError, so it is only set where it is off.
print("TensorFlow Version:", tf.__version__)
gpus = tf.config.list_physical_devices('GPU')
print("Num GPUs Available: ", len(gpus))
if gpus:
    try:
        for gpu in gpus:
            if not tf.config.experimental.get_memory_growth(gpu):
                tf.config.experimental.set_memory_growth(gpu, True)
        logical_gpus = tf.config.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as e:
        # Still possible if growth is off and the devices are already initialized.
        print(e)
else:
    print("No GPU detected. Training will run on CPU (which will be slow).")


In [None]:
# 1. Dataset Exploration
def explore_dataset():
    """
    Explore the chest X-ray dataset: class balance, sample images and image sizes.

    Performs, in order:
        - counts the entries of the NORMAL and PNEUMONIA folders in the train,
          val and test splits and prints the PNEUMONIA-per-NORMAL ratio
        - plots train/test class distributions (saved as 'class_distribution.png')
        - displays the first NORMAL and first PNEUMONIA training image in sorted
          filename order (saved as 'sample_images.png')
        - prints the distinct image shapes among the first 100 training images of
          each class (sorted filename order)

    Uses the module-level ``train_dir``, ``val_dir`` and ``test_dir`` paths.
    Counts are ``len(os.listdir(...))``, so any non-image file inside a class
    folder is counted as well.

    returns:
        dict: image counts per split and class:
              {
                  'train': {'normal': int, 'pneumonia': int},
                  'val': {'normal': int, 'pneumonia': int},
                  'test': {'normal': int, 'pneumonia': int}
              }

    raises:
        FileNotFoundError: if a NORMAL or PNEUMONIA folder is missing from any split.
        Errors while loading sample images or analysing image sizes are printed
        and do not stop execution.
    """
    try:
        normal_train = len(os.listdir(os.path.join(train_dir, 'NORMAL')))
        pneumonia_train = len(os.listdir(os.path.join(train_dir, 'PNEUMONIA')))

        normal_val = len(os.listdir(os.path.join(val_dir, 'NORMAL')))
        pneumonia_val = len(os.listdir(os.path.join(val_dir, 'PNEUMONIA')))

        normal_test = len(os.listdir(os.path.join(test_dir, 'NORMAL')))
        pneumonia_test = len(os.listdir(os.path.join(test_dir, 'PNEUMONIA')))
    except FileNotFoundError as e:
        print(f"Error accessing dataset files: {e}")
        print("Please ensure the 'NORMAL' and 'PNEUMONIA' subdirectories exist in train, test, and val folders.")
        raise

    # Imbalance expressed as 1:ratio (PNEUMONIA per NORMAL); inf if a split has no NORMAL images.
    print("Dataset Distribution:")
    ratio_train = pneumonia_train / normal_train if normal_train > 0 else float('inf')
    ratio_val = pneumonia_val / normal_val if normal_val > 0 else float('inf')
    ratio_test = pneumonia_test / normal_test if normal_test > 0 else float('inf')

    print(f"Training: Normal={normal_train}, Pneumonia={pneumonia_train}, Ratio=1:{ratio_train:.2f}")
    print(f"Validation: Normal={normal_val}, Pneumonia={pneumonia_val}, Ratio=1:{ratio_val:.2f}")
    print(f"Testing: Normal={normal_test}, Pneumonia={pneumonia_test}, Ratio=1:{ratio_test:.2f}")

    # Two-panel bar chart of the class distribution (train vs. test).
    plt.figure(figsize=(12, 5))
    panels = [
        ('Training Set Distribution', [normal_train, pneumonia_train]),
        ('Testing Set Distribution', [normal_test, pneumonia_test]),
    ]
    for position, (title, counts) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.bar(['Normal', 'Pneumonia'], counts, color=['skyblue', 'salmon'])
        plt.title(title)
        plt.ylabel('Number of Images')
        plt.grid(axis='y', alpha=0.3)
    plt.tight_layout()
    plt.savefig('class_distribution.png')
    plt.show()

    # One sample image per class. sorted() makes the pick deterministic because
    # os.listdir returns entries in arbitrary order.
    plt.figure(figsize=(12, 6))
    try:
        samples = []
        for folder, title in (('NORMAL', 'Normal'), ('PNEUMONIA', 'Pneumonia')):
            files = sorted(os.listdir(os.path.join(train_dir, folder)))
            if not files:
                raise FileNotFoundError(f"No files found in training {folder} directory")
            samples.append((title, plt.imread(os.path.join(train_dir, folder, files[0]))))

        for position, (title, img) in enumerate(samples, start=1):
            plt.subplot(1, 2, position)
            plt.imshow(img, cmap='gray')
            plt.title(title)
            plt.axis('off')
        plt.tight_layout()
        plt.savefig('sample_images.png')
        plt.show()
    except FileNotFoundError as e:
        plt.close()  # do not leave an empty figure behind
        print(f"Could not load sample images: {e}")
    except Exception as e:
        plt.close()
        print(f"An error occurred displaying sample images: {e}")

    # Distinct image shapes among the first 100 images per class, to show
    # whether resizing is needed before training.
    try:
        shapes = {}
        for folder in ('NORMAL', 'PNEUMONIA'):
            folder_path = os.path.join(train_dir, folder)
            shapes[folder] = {
                plt.imread(os.path.join(folder_path, img_file)).shape
                for img_file in sorted(os.listdir(folder_path))[:100]
            }
        print("\nImage Size Distribution (Sample):")
        print("Normal images (shapes):", shapes['NORMAL'])
        print("Pneumonia images (shapes):", shapes['PNEUMONIA'])
    except Exception as e:
        print(f"Could not analyze image sizes: {e}")

    return {
        'train': {'normal': normal_train, 'pneumonia': pneumonia_train},
        'val': {'normal': normal_val, 'pneumonia': pneumonia_val},
        'test': {'normal': normal_test, 'pneumonia': pneumonia_test}
    }

In [None]:
# # 2. Data Preprocessing
def create_data_generators(img_size=224, batch_size=32, validation_split=0.2):
    """
    Create the training, validation and test image iterators.

    The training and validation iterators both read ``train_dir`` and are split
    by ``validation_split`` via the ``subset`` argument. Only the training
    iterator applies augmentation. The test iterator reads ``test_dir``.
    The dataset's own ``val`` folder is not used here.
    Labels use class_mode='binary'; flow_from_directory orders classes by
    subfolder name, presumably NORMAL=0, PNEUMONIA=1 (verify via
    ``train_gen.class_indices``).

    args:
        img_size (int): images are resized to (img_size, img_size) (default 224)
        batch_size (int): number of images per batch (default 32)
        validation_split (float): fraction of ``train_dir`` held out for validation
                                  (default 0.2)

    returns:
        tuple: (train_gen, val_gen, test_gen) DirectoryIterator objects
            - train_gen: shuffled, augmented, 'training' subset of train_dir
            - val_gen: unshuffled, non-augmented, 'validation' subset of train_dir
            - test_gen: unshuffled, non-augmented, test_dir

    raises:
        Exception: any error from flow_from_directory is printed and re-raised.
    """
    # Training: rescale to [0, 1] plus light augmentation (rotation up to 15 deg,
    # 10% shifts/shear/zoom, horizontal flip, +/-10% brightness).
    train_datagen = ImageDataGenerator(
        rescale=1./255,
        rotation_range=15,
        width_shift_range=0.1,
        height_shift_range=0.1,
        shear_range=0.1,
        zoom_range=0.1,
        horizontal_flip=True,
        fill_mode='nearest',
        brightness_range=[0.9, 1.1],
        validation_split=validation_split
    )
    # Validation: same rescaling and the same validation_split, so the
    # 'validation' subset selects the same files as it would from train_datagen.
    # No augmentation: val metrics drive early stopping and checkpointing and must
    # be measured on unmodified images.
    val_datagen = ImageDataGenerator(rescale=1./255, validation_split=validation_split)
    # Test: rescaling only.
    val_test_datagen = ImageDataGenerator(rescale=1./255)

    flow_kwargs = dict(
        target_size=(img_size, img_size),
        batch_size=batch_size,
        class_mode='binary',
        color_mode='rgb',
    )

    try:
        train_gen = train_datagen.flow_from_directory(
            train_dir, shuffle=True, subset='training', **flow_kwargs
        )
        # shuffle=False keeps a fixed sample order, so predictions line up with val_gen.classes.
        val_gen = val_datagen.flow_from_directory(
            train_dir, shuffle=False, subset='validation', **flow_kwargs
        )
        test_gen = val_test_datagen.flow_from_directory(
            test_dir, shuffle=False, **flow_kwargs
        )

        print(f"Found {train_gen.samples} images for training ({1-validation_split:.0%} of training set)")
        print(f"Found {val_gen.samples} images for validation ({validation_split:.0%} of training set)")
        print(f"Found {test_gen.samples} images for testing")

    except Exception as e:
        print(f"Error creating data generators: {e}")
        raise
    return train_gen, val_gen, test_gen

In [None]:
# 3. Calculate Class Weights
def calculate_class_weights(dataset_info):
    """
    Compute balanced class weights from the training-split class counts.

    Each weight is inversely proportional to its class frequency:
    ``weight = total_samples / (num_classes * class_count)`` with num_classes = 2.
    The minority class therefore gets a weight > 1 and the majority class < 1.

    args:
        dataset_info (dict): counts as returned by explore_dataset(); only
                             dataset_info['train']['normal'] and
                             dataset_info['train']['pneumonia'] are read.

    returns:
        dict: {0: weight_normal, 1: weight_pneumonia}. Returns {0: 1.0, 1: 1.0}
              if either class has zero training samples.
    """
    normal_count = dataset_info['train']['normal']
    pneumonia_count = dataset_info['train']['pneumonia']
    if normal_count == 0 or pneumonia_count == 0:
        # Inverse-frequency weights are undefined for an empty class.
        print("Warning: One class has zero samples in the training set. Cannot compute weights.")
        return {0: 1.0, 1: 1.0}
    total = normal_count + pneumonia_count
    weight_normal = total / (2 * normal_count)
    weight_pneumonia = total / (2 * pneumonia_count)
    class_weights = {0: weight_normal, 1: weight_pneumonia}
    print(f"Calculated Class weights: Normal (0)={weight_normal:.2f}, Pneumonia (1)={weight_pneumonia:.2f}")
    return class_weights

# 4. Focal Loss
def focal_loss(gamma=2., alpha=.25):
    """
    Build a binary focal loss (Lin et al.) for use in ``model.compile``.

    For a predicted probability p:
        y = 1: -alpha       * (1 - p)^gamma * log(p)
        y = 0: -(1 - alpha) * p^gamma       * log(1 - p)
    Easy, confidently-correct examples are down-weighted by the modulating factor.

    args:
        gamma (float): focusing parameter (default 2.0); larger values put more
                       emphasis on hard, misclassified examples
        alpha (float): weight of the positive class (label 1); the negative
                       class gets 1 - alpha (default 0.25)

    returns:
        function: ``focal_loss_fixed(y_true, y_pred)`` returning one loss value
        per sample (averaged over the last axis, e.g. shape (batch,) for
        (batch, 1) predictions). Returning per-sample values, rather than a
        batch mean, lets Keras apply ``class_weight`` / ``sample_weight`` to each
        sample before reducing.

    note:
        Keras' ``class_weight`` is applied on top of alpha, so the two compound.
    """
    def focal_loss_fixed(y_true, y_pred):
        eps = K.epsilon()
        y_pred = tf.clip_by_value(tf.cast(y_pred, tf.float32), eps, 1. - eps)
        # Align label shape with predictions: class_mode='binary' labels are (batch,)
        # while the sigmoid head outputs (batch, 1). Without this, tf.where would
        # broadcast them into a (batch, batch) matrix when the loss is called directly.
        y_true = tf.reshape(tf.cast(y_true, tf.float32), tf.shape(y_pred))
        # Neutral value 1.0 makes both the log and the modulating factor of the
        # "other" class term vanish.
        pt_1 = tf.where(tf.equal(y_true, 1.), y_pred, tf.ones_like(y_pred))
        pt_0 = tf.where(tf.equal(y_true, 0.), 1. - y_pred, tf.ones_like(y_pred))
        per_element = -(alpha * tf.pow(1. - pt_1, gamma) * tf.math.log(pt_1)
                        + (1. - alpha) * tf.pow(1. - pt_0, gamma) * tf.math.log(pt_0))
        # Per-sample loss; Keras performs the (weighted) batch reduction.
        return tf.reduce_mean(per_element, axis=-1)
    return focal_loss_fixed

In [None]:
# Enhanced CNN model with more layers and both ReLU/Swish options
def create_custom_cnn(img_size=224, activation='relu'):
    """
    Build and compile the custom CNN for binary NORMAL/PNEUMONIA classification.

    Architecture:
        4 convolutional blocks, each:
            [Conv2D(3x3, 'same') -> activation -> BatchNormalization] x 2
            -> MaxPooling2D(2x2) -> SpatialDropout2D
          block 1: 32 filters, spatial dropout 0.2
          block 2: 64 filters, spatial dropout 0.3
          block 3: 128 filters, spatial dropout 0.4
          block 4: 256 filters, spatial dropout 0.5
        classifier head:
            Flatten -> [Dense(512) -> activation -> BN -> Dropout(0.5)]
                    -> [Dense(256) -> activation -> BN -> Dropout(0.5)]
        output: Dense(1, sigmoid)

    Compiled with Adam (learning rate 1e-4), focal_loss(gamma=2, alpha=0.25) and
    accuracy / AUC / precision / recall metrics.

    args:
        img_size (int): input images are (img_size, img_size, 3) (default 224)
        activation (str): 'swish' selects x * sigmoid(x); any other value falls
                          back to ReLU (default 'relu')

    returns:
        tensorflow.keras.models.Sequential: compiled model named
        f"Enhanced_CNN_{activation}".
    """
    # Keras ships 'swish' (x * sigmoid(x)) as a built-in activation. Referring to it
    # by name keeps saved .keras checkpoints loadable without custom objects and
    # avoids overwriting the global 'swish' entry.
    act_func = 'swish' if activation == 'swish' else 'relu'

    model = Sequential(name=f"Enhanced_CNN_{activation}")
    model.add(Input(shape=(img_size, img_size, 3)))

    # (filters, spatial dropout rate) per convolutional block; both grow with depth.
    for filters, drop_rate in ((32, 0.2), (64, 0.3), (128, 0.4), (256, 0.5)):
        for _ in range(2):
            model.add(Conv2D(filters, (3, 3), padding='same'))
            model.add(Activation(act_func))
            model.add(BatchNormalization())
        model.add(MaxPooling2D(pool_size=(2, 2)))
        # SpatialDropout2D drops entire feature maps rather than single activations.
        model.add(SpatialDropout2D(drop_rate))

    # Fully connected classifier head.
    model.add(Flatten())
    for units in (512, 256):
        model.add(Dense(units))
        model.add(Activation(act_func))
        model.add(BatchNormalization())
        model.add(Dropout(0.5))

    model.add(Dense(1, activation='sigmoid'))

    optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
    model.compile(
        optimizer=optimizer,
        loss=focal_loss(gamma=2, alpha=0.25),
        metrics=[
            'accuracy',
            tf.keras.metrics.AUC(name='auc'),
            tf.keras.metrics.Precision(name='precision'),
            tf.keras.metrics.Recall(name='recall')
        ]
    )
    print(f"Enhanced CNN with {activation} activation created.")
    model.summary()
    return model

In [None]:
# 6. Pre-trained VGG16 Model
def create_pretrained_vgg16(img_size=224):
    """
    Build a VGG16 transfer-learning classifier with a new binary head.

    The ImageNet-pretrained VGG16 convolutional base (include_top=False) is
    loaded and frozen. A new head is stacked on top of it:
    GlobalAveragePooling2D -> Dense(512, relu) -> Dropout(0.5) -> Dense(1, sigmoid).
    The model is compiled with Adam (learning rate 1e-4),
    focal_loss(gamma=2, alpha=0.25) and accuracy / AUC / precision / recall metrics.
    While the base is frozen, only the head is trained.

    NOTE(review): the data generators feed images rescaled to [0, 1], whereas
    ImageNet VGG16 weights were trained with ``vgg16.preprocess_input`` inputs.
    Confirm whether this mismatch is intended.

    args:
        img_size (int): input images are (img_size, img_size, 3) (default 224)

    returns:
        tuple: (model, base_model)
            model: compiled keras Model named "VGG16_Transfer"
            base_model: the frozen VGG16 base, kept for later fine-tuning
    """
    backbone = VGG16(
        weights='imagenet',
        include_top=False,
        input_shape=(img_size, img_size, 3)
    )
    backbone.trainable = False
    print(f"Loaded VGG16 with {len(backbone.layers)} layers. Base model frozen.")

    # Classification head on top of the frozen feature extractor.
    features = GlobalAveragePooling2D()(backbone.output)
    hidden = Dense(512, activation='relu')(features)
    hidden = Dropout(0.5)(hidden)
    probability = Dense(1, activation='sigmoid')(hidden)
    model = Model(inputs=backbone.input, outputs=probability, name="VGG16_Transfer")

    tracked_metrics = [
        'accuracy',
        tf.keras.metrics.AUC(name='auc'),
        tf.keras.metrics.Precision(name='precision'),
        tf.keras.metrics.Recall(name='recall'),
    ]
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
        loss=focal_loss(gamma=2, alpha=0.25),
        metrics=tracked_metrics,
    )
    print("VGG16 model with custom head created.")
    model.summary()
    return model, backbone

In [None]:
# 7. Model Training and Fine-tuning
def train_model(model, train_gen, val_gen, class_weights, epochs=200, model_name='model'):
    """
    Train a compiled model with early stopping, LR reduction and checkpointing.

    args:
        model (keras.Model): compiled model to train
        train_gen: training iterator (must expose ``samples`` and ``batch_size``)
        val_gen: validation iterator (must expose ``samples`` and ``batch_size``)
        class_weights (dict): {class_index: weight} passed to ``fit(class_weight=...)``
        epochs (int): maximum number of epochs (default 200)
        model_name (str): used in log messages and the checkpoint filename

    returns:
        keras.callbacks.History: the training history.

    callbacks:
        EarlyStopping: stops after 7 epochs without val_auc improvement and
                       restores the best weights
        ReduceLROnPlateau: multiplies the LR by 0.2 after 3 epochs without
                           val_loss improvement (floor 1e-6)
        ModelCheckpoint: saves the best model by val_auc to f'best_{model_name}.keras'
    """
    import math

    print(f"\n--- Starting Training: {model_name} ---")
    callbacks = [
        EarlyStopping(monitor='val_auc', patience=7, restore_best_weights=True, mode='max', verbose=1),
        ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=3, min_lr=1e-6, verbose=1),
        ModelCheckpoint(f'best_{model_name}.keras', monitor='val_auc', save_best_only=True, mode='max', verbose=0)
    ]
    # Cover every sample each epoch. Floor division would drop the final partial
    # batch (always the same files for the unshuffled validation iterator) and
    # yield 0 steps when a split is smaller than one batch.
    steps_per_epoch = max(1, math.ceil(train_gen.samples / train_gen.batch_size))
    validation_steps = max(1, math.ceil(val_gen.samples / val_gen.batch_size))
    history = model.fit(
        train_gen,
        epochs=epochs,
        validation_data=val_gen,
        callbacks=callbacks,
        class_weight=class_weights,
        steps_per_epoch=steps_per_epoch,
        validation_steps=validation_steps
    )
    print(f"--- Finished Training: {model_name} ---")
    return history

In [None]:
def fine_tune_model(model, base_model, train_gen, val_gen, class_weights, epochs=50,
                    model_name='model', fine_tune_at=15):
    """
    Fine-tune the top layers of a transfer-learning model's base.

    The layers of ``base_model`` from index ``fine_tune_at`` onwards are
    unfrozen, and the earlier ones stay frozen. ``model`` is then recompiled with
    Adam (learning rate 1e-5), focal_loss(gamma=2, alpha=0.25) and the usual
    metrics, and trained again. Recompiling is required for the new trainable
    flags to take effect.

    args:
        model (keras.Model): full model that contains ``base_model``
        base_model (keras.Model): pretrained base whose top layers are unfrozen
        train_gen: training iterator (must expose ``samples`` and ``batch_size``)
        val_gen: validation iterator (must expose ``samples`` and ``batch_size``)
        class_weights (dict): {class_index: weight} passed to ``fit(class_weight=...)``
        epochs (int): maximum number of fine-tuning epochs (default 50)
        model_name (str): used in log messages and the checkpoint filename
        fine_tune_at (int): index of the first base layer to unfreeze (default 15;
                            for VGG16 without top this is presumably block5_conv1 —
                            verify with base_model.summary())

    returns:
        keras.callbacks.History: the fine-tuning history.

    callbacks:
        EarlyStopping: stops after 5 epochs without val_auc improvement and
                       restores the best weights
        ReduceLROnPlateau: multiplies the LR by 0.2 after 3 epochs without
                           val_loss improvement (floor 1e-7)
        ModelCheckpoint: saves the best model by val_auc to
                         f'best_{model_name}_finetuned.keras'
    """
    import math

    print(f"\n--- Starting Fine-tuning: {model_name} ---")
    base_model.trainable = True
    for layer in base_model.layers[:fine_tune_at]:
        layer.trainable = False
    print(f"Unfreezing layers from index {fine_tune_at} onwards for fine-tuning.")
    # Much lower LR than initial training, so the pretrained weights change gently.
    optimizer = tf.keras.optimizers.Adam(learning_rate=1e-5)
    model.compile(
        optimizer=optimizer,
        loss=focal_loss(gamma=2, alpha=0.25),
        metrics=[
            'accuracy',
            tf.keras.metrics.AUC(name='auc'),
            tf.keras.metrics.Precision(name='precision'),
            tf.keras.metrics.Recall(name='recall')
        ]
    )
    model.summary()
    callbacks_fine = [
        EarlyStopping(monitor='val_auc', patience=5, restore_best_weights=True, mode='max', verbose=1),
        ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=3, min_lr=1e-7, verbose=1),
        ModelCheckpoint(f'best_{model_name}_finetuned.keras', monitor='val_auc', save_best_only=True, mode='max', verbose=0)
    ]
    # Ceil so the final partial batch is included and tiny splits still get a step.
    steps_per_epoch = max(1, math.ceil(train_gen.samples / train_gen.batch_size))
    validation_steps = max(1, math.ceil(val_gen.samples / val_gen.batch_size))
    history_fine = model.fit(
        train_gen,
        epochs=epochs,
        validation_data=val_gen,
        callbacks=callbacks_fine,
        class_weight=class_weights,
        steps_per_epoch=steps_per_epoch,
        validation_steps=validation_steps
    )
    print(f"--- Finished Fine-tuning: {model_name} ---")
    return history_fine

In [None]:
# 8. Model Evaluation
def evaluate_model(model, test_gen, model_name='model'):
    """
    Evaluate a binary classifier on the test iterator and plot diagnostics.

    Produces the following:
        - a classification report and confusion matrix at threshold 0.5
        - the same at the F1-optimal threshold from optimize_threshold()
        - ROC and precision-recall curves
    All plots are saved as PNG files prefixed with ``model_name``.

    NOTE(review): the threshold is tuned on the same test set it is then scored
    on, so the "optimized" metrics are optimistic. Consider tuning on the
    validation split instead.

    args:
        model (keras.Model): trained model with a single sigmoid output
        test_gen: unshuffled test iterator (shuffle=False is required so that
                  prediction i corresponds to test_gen.classes[i])
        model_name (str): used in titles and output filenames

    returns:
        dict: 'predictions' (raw model outputs), 'true_classes',
              'pred_classes_0.5', 'optimized_preds', 'best_threshold' and
              'metrics' ({'accuracy', 'auc', 'average_precision', 'f1'}).
    """
    import math

    labels = ['NORMAL', 'PNEUMONIA']

    def _plot_confusion(pred_classes, cmap, title, filename):
        # Heatmap of the confusion matrix for one set of hard predictions.
        matrix = confusion_matrix(true_classes, pred_classes)
        plt.figure(figsize=(7, 5))
        sns.heatmap(matrix, annot=True, fmt='d', cmap=cmap, xticklabels=labels, yticklabels=labels)
        plt.title(title)
        plt.ylabel('True Label')
        plt.xlabel('Predicted Label')
        plt.tight_layout()
        plt.savefig(filename)
        plt.show()

    print(f"\n--- Evaluating Model: {model_name} ---")
    # One step per (possibly partial) batch. `samples // batch_size + 1` requests a
    # non-existent extra batch whenever samples is an exact multiple of batch_size.
    steps = max(1, math.ceil(test_gen.samples / test_gen.batch_size))
    predictions = model.predict(test_gen, steps=steps, verbose=1)
    true_classes = test_gen.classes
    predictions = predictions[:len(true_classes)]
    # sklearn metrics expect a 1-D score vector; the sigmoid head yields (N, 1).
    scores = predictions.ravel()

    print("\nEvaluation with default threshold (0.5):")
    pred_classes_05 = (scores > 0.5).astype(int)
    print(classification_report(true_classes, pred_classes_05, target_names=labels, digits=3))
    _plot_confusion(pred_classes_05, 'Blues',
                    f'Confusion Matrix - {model_name} (Threshold=0.5)',
                    f'{model_name}_confusion_matrix_0.5.png')

    best_threshold, optimized_preds, best_f1 = optimize_threshold(true_classes, predictions)
    print(f"\nEvaluation with optimized threshold ({best_threshold:.3f} for best F1={best_f1:.3f}):")
    print(classification_report(true_classes, optimized_preds, target_names=labels, digits=3))
    _plot_confusion(optimized_preds, 'Oranges',
                    f'Confusion Matrix - {model_name} (Opt Threshold={best_threshold:.3f})',
                    f'{model_name}_confusion_matrix_optimized.png')

    fpr, tpr, _ = roc_curve(true_classes, scores)
    roc_auc = auc(fpr, tpr)
    plt.figure(figsize=(8, 6))
    plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {roc_auc:.3f})')
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title(f'ROC Curve - {model_name}')
    plt.legend(loc='lower right')
    plt.grid(alpha=0.3)
    plt.savefig(f'{model_name}_roc_curve.png')
    plt.show()

    precision, recall, _ = precision_recall_curve(true_classes, scores)
    average_precision = average_precision_score(true_classes, scores)
    plt.figure(figsize=(8, 6))
    plt.plot(recall, precision, color='blue', lw=2, label=f'PR curve (AP = {average_precision:.3f})')
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.title(f'Precision-Recall Curve - {model_name}')
    plt.legend(loc='lower left')
    plt.grid(alpha=0.3)
    plt.savefig(f'{model_name}_pr_curve.png')
    plt.show()

    final_accuracy = (optimized_preds == true_classes).mean()
    final_f1 = f1_score(true_classes, optimized_preds)
    print(f"Metrics Summary ({model_name}):")
    print(f"  AUC: {roc_auc:.4f}")
    print(f"  Average Precision (AP): {average_precision:.4f}")
    print(f"  Accuracy (Optimized Threshold): {final_accuracy:.4f}")
    print(f"  F1 Score (Optimized Threshold): {final_f1:.4f}")

    return {
        'predictions': predictions,
        'true_classes': true_classes,
        'pred_classes_0.5': pred_classes_05,
        'optimized_preds': optimized_preds,
        'best_threshold': best_threshold,
        'metrics': {'accuracy': final_accuracy, 'auc': roc_auc, 'average_precision': average_precision, 'f1': final_f1}
    }

In [None]:
def optimize_threshold(true_classes, predictions, plot_path='f1_vs_threshold.png'):
    """Find the decision threshold that maximises the F1 score.

    Sweeps thresholds 0.10, 0.11, ..., 0.89 (``np.arange(0.1, 0.9, 0.01)``),
    scores ``predictions >= threshold`` against ``true_classes`` with F1,
    then plots the F1-vs-threshold curve and saves it to ``plot_path``.

    Args:
        true_classes: ground-truth 0/1 labels.
        predictions: predicted positive-class scores; flattened before
            thresholding, so an ``(N, 1)`` model output is accepted.
        plot_path: file the plot is saved to. The default keeps the original
            fixed filename; pass a per-model path so that successive calls
            (one per model plus the ensemble) do not overwrite each other.

    Returns:
        ``(best_threshold, optimized_preds, best_f1)``: the threshold with the
        highest F1 (the lowest such threshold on ties; 0.5 if no threshold
        beats F1 = 0), the flattened 0/1 predictions at that threshold, and
        the corresponding F1 score.
    """
    scores = np.asarray(predictions).flatten()
    thresholds = np.arange(0.1, 0.9, 0.01)
    f1_scores = [f1_score(true_classes, (scores >= t).astype(int)) for t in thresholds]

    best_f1 = 0
    best_threshold = 0.5
    for threshold, f1 in zip(thresholds, f1_scores):
        # Strict '>' keeps the first (lowest) threshold among ties.
        if f1 > best_f1:
            best_f1 = f1
            best_threshold = threshold
    print(f"Optimal threshold found: {best_threshold:.3f} with F1 score: {best_f1:.4f}")

    plt.figure(figsize=(8, 5))
    plt.plot(thresholds, f1_scores)
    plt.title('F1 Score vs. Threshold')
    plt.xlabel('Threshold')
    plt.ylabel('F1 Score')
    plt.vlines(best_threshold, 0, best_f1, colors='r', linestyles='--', label=f'Best Threshold ({best_threshold:.3f})')
    plt.legend()
    plt.grid(alpha=0.3)
    plt.savefig(plot_path)
    plt.show()

    optimized_preds = (scores >= best_threshold).astype(int)
    return best_threshold, optimized_preds, best_f1

In [None]:
# 9. Ensemble Model
def create_ensemble(models_dict, test_gen):
    """Create an ensemble of models by averaging their test-set predictions.

    Each model predicts over exactly one pass of ``test_gen``. Its output is
    truncated to the number of test samples, and the outputs of all models
    are averaged element-wise.

    Args:
        models_dict: mapping of model name -> model exposing ``predict``.
        test_gen: test data generator exposing ``samples``, ``batch_size`` and
            ``classes``. NOTE(review): assumed to be unshuffled so that each
            model's output rows line up with ``test_gen.classes`` — confirm
            in create_data_generators.

    Returns:
        ``(ensemble_preds, true_classes)``, or ``(None, None)`` when fewer
        than two models are supplied.
    """
    print("\n--- Creating Ensemble ---")
    model_names = list(models_dict.keys())
    if len(model_names) < 2:
        print("Need at least two models to create an ensemble. Skipping.")
        return None, None
    print(f"Ensembling models: {model_names}")

    true_classes = test_gen.classes
    # Ceiling division gives exactly one pass over the data. The previous
    # `samples // batch_size + 1` requested an extra batch whenever samples
    # was an exact multiple of batch_size.
    steps = (test_gen.samples + test_gen.batch_size - 1) // test_gen.batch_size

    all_predictions = []
    for model_name, model in models_dict.items():
        print(f"Getting predictions from {model_name}...")
        if hasattr(test_gen, 'reset'):
            # Start every model from the first batch so the outputs align.
            test_gen.reset()
        preds = model.predict(test_gen, steps=steps, verbose=0)
        all_predictions.append(preds[:len(true_classes)])

    ensemble_preds = np.mean(all_predictions, axis=0)
    print("Ensemble predictions generated.")
    return ensemble_preds, true_classes


In [None]:
# 10. Explainable AI (XAI) Techniques
def get_gradcam(model, img_array, last_conv_layer_name, pred_index=None):
    """Generate a Grad-CAM heatmap for the first image of a batch.

    Args:
        model: Keras model. ``last_conv_layer_name`` must be reachable through
            ``model.get_layer`` (a layer nested inside a sub-model is not).
        img_array: input batch for the model; the heatmap is computed for
            sample 0.
        last_conv_layer_name: name of the convolutional layer whose
            activations are weighted by the pooled gradients.
        pred_index: index of the output unit to explain. Defaults to the
            arg-max unit of the first sample's prediction, which for a
            single-unit (sigmoid) output is unit 0.

    Returns:
        2-D numpy heatmap normalised to [0, 1], or None if the Grad-CAM model
        cannot be built or the gradients are None.
    """
    try:
        grad_model = tf.keras.models.Model(
            model.inputs,
            [model.get_layer(last_conv_layer_name).output, model.output]
        )
    except ValueError as e:
        print(f"Error creating Grad-CAM model: {e}")
        print(f"Ensure '{last_conv_layer_name}' is a valid convolutional layer name in the model.")
        print("Available layer names:")
        for layer in model.layers:
            print(f"  - {layer.name} (Type: {type(layer).__name__})")
        return None

    with tf.GradientTape() as tape:
        last_conv_layer_output, preds = grad_model(img_array)
        if pred_index is None:
            pred_index = tf.argmax(preds[0])
        # Differentiate the requested output unit. Previously this was
        # hard-coded to unit 0, so `pred_index` had no effect.
        class_output = preds[:, pred_index]
    grads = tape.gradient(class_output, last_conv_layer_output)
    if grads is None:
        print("Grad-CAM: Gradients are None. Check model structure and layer connectivity.")
        return None

    # Channel importance = mean gradient over the batch and spatial axes.
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
    # Importance-weighted sum of sample 0's feature maps.
    heatmap = last_conv_layer_output[0] @ pooled_grads[..., tf.newaxis]
    heatmap = tf.squeeze(heatmap)
    # Keep positive evidence only and scale to [0, 1]; epsilon avoids 0/0.
    heatmap = tf.maximum(heatmap, 0) / (tf.math.reduce_max(heatmap) + K.epsilon())
    return heatmap.numpy()

def visualize_gradcam(img, heatmap, alpha=0.5, title="Grad-CAM"):
    """Plot an image, its Grad-CAM heatmap, and the two blended together.

    The 1x3 figure is saved as ``<title lower-cased, spaces -> '_'>_gradcam.png``.

    Args:
        img: a single image of shape (H, W), (H, W, 1) or (H, W, 3). Float
            images are assumed to be scaled to [0, 1]. Grayscale images are
            expanded to RGB, and 3-channel images are assumed to be RGB.
        heatmap: 2-D array in [0, 1] (e.g. from ``get_gradcam``); it is
            resized to the image size. If None, nothing is plotted.
        alpha: weight of the heatmap in the blend (the image gets 1 - alpha).
        title: title of the overlay panel; also used for the filename.

    Returns:
        The blended RGB uint8 image, or None when ``heatmap`` is None.
    """
    if heatmap is None:
        print("Cannot visualize Grad-CAM: Heatmap is None.")
        return
    if img.dtype == np.float32 or img.dtype == np.float64:
        # Clip before the uint8 cast so values slightly outside [0, 1]
        # saturate instead of wrapping around.
        img = (np.clip(img, 0.0, 1.0) * 255).astype(np.uint8)
    if img.ndim == 2 or (img.ndim == 3 and img.shape[2] == 1):
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)

    heatmap = cv2.resize(heatmap, (img.shape[1], img.shape[0]))
    heatmap = np.uint8(255 * heatmap)
    # applyColorMap returns BGR. Convert to RGB so the blend matches the RGB
    # image and matplotlib shows the intended colours (not red/blue swapped).
    heatmap = cv2.cvtColor(cv2.applyColorMap(heatmap, cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB)
    superimposed_img = cv2.addWeighted(img, 1 - alpha, heatmap, alpha, 0)

    plt.figure(figsize=(12, 4))
    plt.subplot(1, 3, 1); plt.imshow(img); plt.title('Original Image'); plt.axis('off')
    plt.subplot(1, 3, 2); plt.imshow(heatmap); plt.title('Heatmap'); plt.axis('off')
    plt.subplot(1, 3, 3); plt.imshow(superimposed_img); plt.title(title); plt.axis('off')
    plt.tight_layout()
    plt.savefig(f'{title.replace(" ", "_").lower()}_gradcam.png')
    plt.show()
    return superimposed_img

def explain_with_lime(model, img_array, num_samples=500):
    """Generate and plot a LIME explanation for the first image of a batch.

    The original image and the LIME boundary overlay are plotted side by
    side and saved to ``lime_explanation.png``.

    Args:
        model: Keras model whose ``predict`` returns either a single
            positive-class probability column (sigmoid) or one column per
            class.
        img_array: batch of images; only ``img_array[0]`` is explained. Float
            images are assumed to be in [0, 1]; uint8 images are rescaled.
        num_samples: number of perturbed images LIME evaluates.

    Returns:
        The LIME explanation object, or None if LIME raised an exception
        (the error is printed).
    """
    print("Generating LIME explanation...")
    image_for_lime = img_array[0]
    explainer = lime_image.LimeImageExplainer()

    def predict_fn_lime(images):
        processed_images = np.array([
            img.astype('float32') / 255.0 if img.dtype == np.uint8 else img
            for img in images
        ])
        if processed_images.ndim == 3:
            processed_images = np.expand_dims(processed_images, axis=0)
        # verbose=0: LIME calls this once per perturbation batch, and each
        # call would otherwise print its own progress bar.
        probs = np.asarray(model.predict(processed_images, verbose=0))
        # LIME expects one probability column per class. A single sigmoid
        # column only holds P(class 1), so prepend P(class 0) = 1 - p; the
        # explained top label is then the class the model actually predicts.
        if probs.ndim == 2 and probs.shape[1] == 1:
            probs = np.hstack([1.0 - probs, probs])
        return probs

    try:
        explanation = explainer.explain_instance(
            image_for_lime.astype('double'),
            predict_fn_lime,
            top_labels=1,
            hide_color=0,
            num_samples=num_samples
        )
        temp, mask = explanation.get_image_and_mask(
            explanation.top_labels[0], positive_only=False, num_features=10, hide_rest=False
        )
        is_uint8 = image_for_lime.dtype == np.uint8
        display_img = (image_for_lime / 255.0) if is_uint8 else image_for_lime
        # `temp` is the explained image on its input scale. Bring it to
        # [0, 1] for display; a `temp / 2 + 0.5` mapping only suits [-1, 1]
        # inputs and washes out [0, 1] images.
        display_temp = np.clip(temp / 255.0 if is_uint8 else temp, 0.0, 1.0)

        plt.figure(figsize=(8, 4))
        plt.subplot(1, 2, 1)
        plt.imshow(display_img); plt.title('Original Image'); plt.axis('off')
        plt.subplot(1, 2, 2)
        plt.imshow(mark_boundaries(display_temp, mask)); plt.title('LIME Explanation'); plt.axis('off')
        plt.tight_layout()
        plt.savefig('lime_explanation.png')
        plt.show()
        print("LIME explanation generated.")
        return explanation
    except Exception as e:
        print(f"Error during LIME explanation: {e}")
        return None

In [None]:
# 11. Compare Different Models
def compare_models(results_dict):
    """Compare performance of different models based on evaluation results.

    Draws a 2x2 bar chart (accuracy, AUC, average precision, F1), saves it
    to ``model_comparison.png``, prints a summary table sorted by F1 and
    reports the best model.

    Args:
        results_dict: mapping of model name -> result dict whose ``'metrics'``
            entry holds ``'accuracy'``, ``'auc'``, ``'average_precision'`` and
            ``'f1'``.

    Returns:
        ``(summary, best_model_name)``. ``summary`` is a DataFrame of the
        metrics rounded to 4 decimals, in input order. ``best_model_name`` is
        the model with the highest F1 score (the first one on ties). Returns
        ``(None, None)`` when ``results_dict`` is empty, so callers can always
        unpack two values.
    """
    print("\n--- Comparing Model Performance ---")
    models = list(results_dict.keys())
    if not models:
        print("No results to compare.")
        return None, None

    def metric_values(key):
        return [results_dict[model]['metrics'][key] for model in models]

    accuracy = metric_values('accuracy')
    auc_scores = metric_values('auc')  # not `auc`, which would shadow sklearn.metrics.auc
    ap = metric_values('average_precision')
    f1 = metric_values('f1')

    # (values, colour, title, y-label) for each panel of the 2x2 grid.
    panels = [
        (accuracy, 'skyblue', 'Accuracy Comparison (Optimized Threshold)', 'Accuracy'),
        (auc_scores, 'salmon', 'AUC Comparison', 'AUC'),
        (ap, 'lightgreen', 'Average Precision Comparison', 'Average Precision (AP)'),
        (f1, 'mediumpurple', 'F1 Score Comparison (Optimized Threshold)', 'F1 Score'),
    ]
    plt.figure(figsize=(14, 10))
    bar_width = 0.6
    for position, (values, color, title, ylabel) in enumerate(panels, start=1):
        plt.subplot(2, 2, position)
        plt.bar(models, values, color=color, width=bar_width)
        plt.title(title)
        plt.ylabel(ylabel)
        plt.ylim(0, 1)
        plt.xticks(rotation=15, ha='right')
        plt.grid(axis='y', alpha=0.3)
    plt.tight_layout(pad=2.0)
    plt.savefig('model_comparison.png')
    plt.show()

    summary = pd.DataFrame({'Model': models, 'Accuracy': accuracy, 'AUC': auc_scores, 'Avg Precision': ap, 'F1 Score': f1}).round(4)
    print("\nModel Performance Summary:")
    print(summary.sort_values('F1 Score', ascending=False).to_string(index=False))
    best_model_name = summary.loc[summary['F1 Score'].idxmax(), 'Model']
    if best_model_name:
        print(f"\nBest model based on F1 Score: {best_model_name}")
    return summary, best_model_name


In [None]:
def plot_training_history(history, model_plot_name):
    """Plot the training/validation curves recorded in a Keras ``History``.

    Draws a 2x2 grid (loss, accuracy, AUC, precision & recall), plotting only
    the series present in ``history.history``. The figure is saved as
    ``<model_plot_name>_training_history.png`` and then shown.

    Args:
        history: Keras ``History`` (anything with a ``.history`` dict of
            per-epoch metric lists). A falsy value or an empty ``.history``
            skips plotting.
        model_plot_name: label used in the plot titles and the output filename.
    """
    if not history or not history.history:
        print(f"No history data to plot for {model_plot_name}.")
        return

    frame = pd.DataFrame(history.history)

    def draw_panel(position, series, title, ylabel):
        # series: (column, legend label, linestyle or None), in drawing order.
        plt.subplot(2, 2, position)
        drew_any = False
        for column, label, linestyle in series:
            if column not in frame:
                continue
            extra = {} if linestyle is None else {'linestyle': linestyle}
            plt.plot(frame[column], label=label, **extra)
            drew_any = True
        plt.title(title)
        plt.xlabel('Epochs')
        plt.ylabel(ylabel)
        if drew_any:  # a legend only makes sense when something was plotted
            plt.legend()
        plt.grid(True, alpha=0.3)

    plt.figure(figsize=(18, 12))
    draw_panel(
        1,
        [('loss', 'Training Loss', None), ('val_loss', 'Validation Loss', None)],
        f'Loss vs. Epochs - {model_plot_name}',
        'Loss',
    )
    draw_panel(
        2,
        [('accuracy', 'Training Accuracy', None), ('val_accuracy', 'Validation Accuracy', None)],
        f'Accuracy vs. Epochs - {model_plot_name}',
        'Accuracy',
    )
    draw_panel(
        3,
        [('auc', 'Training AUC', None), ('val_auc', 'Validation AUC', None)],
        f'AUC vs. Epochs - {model_plot_name}',
        'AUC',
    )
    draw_panel(
        4,
        [
            ('precision', 'Training Precision', None),
            ('val_precision', 'Validation Precision', None),
            ('recall', 'Training Recall', '--'),
            ('val_recall', 'Validation Recall', '--'),
        ],
        f'Precision & Recall vs. Epochs - {model_plot_name}',
        'Metric Value',
    )

    plt.tight_layout()
    plt.savefig(f'{model_plot_name}_training_history.png')
    plt.show()
    print(f"Training history plots saved for {model_plot_name}.")

In [None]:
def main() -> None:
    """Run the complete pneumonia-prediction pipeline end to end.

    Steps (each announced with a "=== Step N ===" banner):
      1. Explore the dataset (``explore_dataset``).
      2. Build train/validation/test generators (224x224 images, batch 32,
         ``validation_split=0.2``).
      3. Compute class weights and warn if the generator's class indices are
         not {'NORMAL': 0, 'PNEUMONIA': 1}.
      4. Train and evaluate two custom CNNs (ReLU and Swish activations) and
         a VGG16 transfer-learning model that also gets a fine-tuning phase.
      5. Average the CNN_swish and Pretrained_VGG16 predictions into an
         ensemble and evaluate it at its F1-optimal threshold.
      6. Compare all evaluated models and pick the best one by F1 score.
      7. Apply Grad-CAM and LIME to one NORMAL and one PNEUMONIA test image,
         using the best model unless that model is the ensemble.

    A failure in step 1 or 2 aborts the pipeline (message printed, early
    return). A failure while training/evaluating an individual model or the
    ensemble is printed with a traceback, and the remaining steps continue.
    Checkpoints are loaded from, and the ensemble plot saved to, relative
    paths in the current working directory.
    """
    print("--- Starting Pneumonia Prediction Project ---")

    print("\n=== Step 1: Exploring Dataset ===")
    try:
        dataset_info = explore_dataset()
    except Exception as e:
        print(f"Failed during dataset exploration: {e}")
        return

    print("\n=== Step 2: Creating Data Generators ===")
    img_size = 224
    batch_size = 32
    try:
        # NOTE(review): validation_split=0.2 presumably carves the validation
        # set out of the training images — confirm in create_data_generators.
        train_gen, val_gen, test_gen = create_data_generators(img_size, batch_size, validation_split=0.2)
    except Exception as e:
        print(f"Failed to create data generators: {e}")
        return

    print("\n=== Step 3: Calculating Class Weights ===")
    class_weights = calculate_class_weights(dataset_info)
    # The rest of the pipeline treats label 0 as NORMAL and 1 as PNEUMONIA
    # (e.g. the XAI sample search and the report target_names); warn if the
    # generator assigned indices differently.
    if train_gen.class_indices.get('NORMAL', -1) != 0 or train_gen.class_indices.get('PNEUMONIA', -1) != 1:
        print("Warning: Class indices might not match expected {0: Normal, 1: Pneumonia}. Ensure class_weights dict is correct.")
        print(f"Actual class indices: {train_gen.class_indices}")


    print("\n=== Step 4: Creating and Training Models ===")
    all_results = {}      # model id -> evaluation result dict (metrics etc.)
    trained_models = {}   # model id -> trained Keras model

    # One custom CNN per activation function.
    activation_functions = ['relu', 'swish']
    for activation in activation_functions:
        model_id = f'CNN_{activation}'
        print(f"\n--- Training {model_id} ---")
        try:
            model = create_custom_cnn(img_size, activation)
            history = train_model(model, train_gen, val_gen, class_weights, epochs=200, model_name=f'cnn_{activation}')
            plot_training_history(history, model_id)
            # Restore the best checkpoint before evaluating. Assumes train_model
            # saves it as f'best_{model_name}.keras' — TODO confirm in train_model.
            model.load_weights(f'best_cnn_{activation}.keras')
            results = evaluate_model(model, test_gen, model_name=model_id)
            all_results[model_id] = results
            trained_models[model_id] = model
        except Exception as e:
            print(f"!!! Failed to train/evaluate {model_id}: {e}")
            import traceback; traceback.print_exc()

    model_id_vgg = 'Pretrained_VGG16'
    print(f"\n--- Training {model_id_vgg} ---")
    try:
        # Phase 1: transfer learning on top of the pretrained VGG16 base, then
        # restore the best checkpoint of that phase.
        model_vgg, base_model_vgg = create_pretrained_vgg16(img_size)
        history_vgg = train_model(model_vgg, train_gen, val_gen, class_weights, epochs=100, model_name='vgg16_transfer')
        plot_training_history(history_vgg, f'{model_id_vgg}_initial_transfer')
        model_vgg.load_weights('best_vgg16_transfer.keras')
        print("Loaded best weights from initial VGG16 training.")

        # Phase 2: fine-tuning, which also involves the base model.
        history_vgg_fine = fine_tune_model(model_vgg, base_model_vgg, train_gen, val_gen, class_weights, epochs=50, model_name='vgg16')
        plot_training_history(history_vgg_fine, f'{model_id_vgg}_fine_tuning')

        # Prefer the fine-tuned checkpoint, fall back to the transfer-phase
        # checkpoint, and otherwise evaluate with the in-memory weights.
        fine_tuned_path = 'best_vgg16_finetuned.keras'
        transfer_path = 'best_vgg16_transfer.keras'
        if os.path.exists(fine_tuned_path):
             print(f"Loading best fine-tuned weights from: {fine_tuned_path}")
             model_vgg.load_weights(fine_tuned_path)
        elif os.path.exists(transfer_path): # Fallback to transfer weights if fine-tuned didn't improve or save
             print(f"Fine-tuned model checkpoint not found or did not improve. Loading best transfer weights from: {transfer_path}")
             model_vgg.load_weights(transfer_path)
        else:
             print("Warning: No saved weights found for VGG16. Evaluating with current weights.")

        results_vgg = evaluate_model(model_vgg, test_gen, model_name=model_id_vgg)
        all_results[model_id_vgg] = results_vgg
        trained_models[model_id_vgg] = model_vgg
    except Exception as e:
        print(f"!!! Failed to train/evaluate {model_id_vgg}: {e}")
        import traceback; traceback.print_exc()

    print("\n=== Step 5: Creating and Evaluating Ensemble Model ===")
    # Only CNN_swish and Pretrained_VGG16 are ensembled; CNN_relu is not used.
    ensemble_candidates = {}
    if 'CNN_swish' in trained_models: ensemble_candidates['CNN_swish'] = trained_models['CNN_swish']
    if 'Pretrained_VGG16' in trained_models: ensemble_candidates['Pretrained_VGG16'] = trained_models['Pretrained_VGG16']

    if len(ensemble_candidates) >= 2:
        try:
            ensemble_preds, ensemble_true_classes = create_ensemble(ensemble_candidates, test_gen)
            if ensemble_preds is not None:
                print("\n--- Evaluating Ensemble Model ---")
                ens_best_threshold, ens_optimized_preds, ens_best_f1 = optimize_threshold(ensemble_true_classes, ensemble_preds)
                print(f"\nEnsemble Evaluation with optimized threshold ({ens_best_threshold:.3f}):")
                print(classification_report(ensemble_true_classes, ens_optimized_preds, target_names=['NORMAL', 'PNEUMONIA'], digits=3))
                ens_accuracy = (ens_optimized_preds == ensemble_true_classes).mean()
                ens_auc = roc_auc_score(ensemble_true_classes, ensemble_preds)
                ens_ap = average_precision_score(ensemble_true_classes, ensemble_preds)
                ens_f1 = f1_score(ensemble_true_classes, ens_optimized_preds)
                # Same key layout as the dict returned by evaluate_model, so
                # compare_models treats the ensemble like any other model.
                all_results['Ensemble'] = {
                    'predictions': ensemble_preds, 'true_classes': ensemble_true_classes,
                    'pred_classes_0.5': (ensemble_preds > 0.5).astype(int),
                    'optimized_preds': ens_optimized_preds, 'best_threshold': ens_best_threshold,
                    'metrics': {'accuracy': ens_accuracy, 'auc': ens_auc, 'average_precision': ens_ap, 'f1': ens_f1}
                }
                cm_ens = confusion_matrix(ensemble_true_classes, ens_optimized_preds)
                plt.figure(figsize=(7, 5))
                sns.heatmap(cm_ens, annot=True, fmt='d', cmap='Greens', xticklabels=['NORMAL', 'PNEUMONIA'], yticklabels=['NORMAL', 'PNEUMONIA'])
                plt.title(f'Confusion Matrix - Ensemble (Opt Threshold={ens_best_threshold:.3f})')
                plt.ylabel('True Label'); plt.xlabel('Predicted Label')
                plt.tight_layout(); plt.savefig('ensemble_confusion_matrix_optimized.png'); plt.show()
        except Exception as e:
            print(f"!!! Failed to create/evaluate ensemble model: {e}")
            import traceback; traceback.print_exc()
    else:
        print("Skipping ensemble: Not enough base models trained successfully.")

    print("\n=== Step 6: Comparing All Models ===")
    # Pre-initialised so the final summary works even if nothing was evaluated.
    model_comparison_summary, best_model_name = None, None
    if all_results:
        model_comparison_summary, best_model_name = compare_models(all_results)
    else:
        print("No models were successfully trained and evaluated. Cannot compare.")

    print("\n=== Step 7: Applying XAI Techniques ===")
    # The ensemble has no single network to explain, so XAI only runs when the
    # best model is one of the individually trained models.
    if best_model_name and best_model_name != 'Ensemble' and best_model_name in trained_models:
        print(f"Applying XAI to the best individual model: {best_model_name}")
        best_model = trained_models[best_model_name]
        test_gen.reset()
        normal_img, pneumonia_img = None, None
        # Scan test batches from the start until one NORMAL (label 0) and one
        # PNEUMONIA (label 1) image have been found; each is kept as a
        # batch of size 1.
        try:
            for i in range(len(test_gen)): # Iterate through batches
                img_batch, label_batch = next(test_gen)
                if normal_img is None and 0 in label_batch:
                    normal_idx = np.where(label_batch == 0)[0][0]
                    normal_img = img_batch[normal_idx:normal_idx+1]
                    print("Found sample Normal image for XAI.")
                if pneumonia_img is None and 1 in label_batch:
                    pneumonia_idx = np.where(label_batch == 1)[0][0]
                    pneumonia_img = img_batch[pneumonia_idx:pneumonia_idx+1]
                    print("Found sample Pneumonia image for XAI.")
                if normal_img is not None and pneumonia_img is not None: break
            if normal_img is None or pneumonia_img is None: print("Warning: Could not find both Normal and Pneumonia samples in test set for XAI.")
        except Exception as e: print(f"Error getting sample images for XAI: {e}")

        # Grad-CAM target: the last Conv2D/SeparableConv2D among the model's
        # top-level layers.
        # NOTE(review): `is_in_base` is computed but never used. Conv layers
        # nested inside a sub-model (e.g. if the VGG16 base is added as a
        # single layer) are not visited here. The 'block5_conv3' fallback may
        # then not be reachable via best_model.get_layer, in which case
        # get_gradcam prints the error and returns None.
        last_conv_layer_name = None
        for layer in reversed(best_model.layers):
            if isinstance(layer, (Conv2D, tf.keras.layers.SeparableConv2D)):
                 is_in_base = any(layer.name == base_layer.name for base_layer in getattr(best_model, 'layers', []) if isinstance(base_layer, Model) and hasattr(base_layer, 'layers'))
                 last_conv_layer_name = layer.name
                 break
        if not last_conv_layer_name:
             print("Warning: Could not automatically find a Conv2D layer for Grad-CAM.")
             if 'VGG16' in best_model_name: last_conv_layer_name = 'block5_conv3' # VGG16 specific fallback

        if last_conv_layer_name and normal_img is not None:
            print(f"\nApplying Grad-CAM (layer: {last_conv_layer_name}) to Normal Image...")
            normal_heatmap = get_gradcam(best_model, normal_img, last_conv_layer_name)
            visualize_gradcam(normal_img[0], normal_heatmap, title=f"Grad-CAM Normal ({best_model_name})")
        if last_conv_layer_name and pneumonia_img is not None:
            print(f"\nApplying Grad-CAM (layer: {last_conv_layer_name}) to Pneumonia Image...")
            pneumonia_heatmap = get_gradcam(best_model, pneumonia_img, last_conv_layer_name)
            visualize_gradcam(pneumonia_img[0], pneumonia_heatmap, title=f"Grad-CAM Pneumonia ({best_model_name})")

        if normal_img is not None: print("\nApplying LIME to Normal Image..."); explain_with_lime(best_model, normal_img)
        if pneumonia_img is not None: print("\nApplying LIME to Pneumonia Image..."); explain_with_lime(best_model, pneumonia_img)



    # Final report: the best model's F1 from the comparison table, if any.
    print("\n--- Project Pipeline Completed ---")
    if best_model_name and model_comparison_summary is not None and not model_comparison_summary.empty:
        best_score_row = model_comparison_summary[model_comparison_summary['Model'] == best_model_name]
        if not best_score_row.empty:
             best_score = best_score_row['F1 Score'].iloc[0]
             print(f"Best Model Identified: {best_model_name} (F1 Score: {best_score:.4f})")
        else:
            print(f"Best model name '{best_model_name}' not found in summary. Check model IDs.")
    elif model_comparison_summary is not None and not model_comparison_summary.empty:
        print("Review the performance summary table above for details on trained models.")
    else:
        print("No models were successfully evaluated or no summary available.")

# Run the full pipeline when executed as a script. Inside a Jupyter kernel
# __name__ is also "__main__", so running this cell executes main() as well.
if __name__ == "__main__":
    main()