# In [None]:
# COVID-19 Image Classification using VGG16 - 50 Epochs Version
# This notebook is designed for comprehensive training with 50 epochs
# Updated for new project structure with enhanced training configuration

# =============================================================================
# CELL 1: Import Libraries and Setup
# =============================================================================

import numpy as np 
from tqdm import tqdm
import os
import random
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score
import shutil
import tensorflow as tf
import cv2

from tensorflow.keras import layers
from tensorflow.keras.layers import *
from tensorflow.keras.models import *
from tensorflow.keras.applications.vgg16 import VGG16
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications.vgg16 import preprocess_input
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import load_model
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau

# GPU Configuration: let TensorFlow allocate GPU memory on demand instead of
# reserving all of it at start-up.
gpus = tf.config.list_physical_devices('GPU')
if not gpus:
    print("No GPU found, using CPU")
else:
    try:
        for gpu_device in gpus:
            tf.config.experimental.set_memory_growth(gpu_device, True)
        print(f"✅ GPU configuration completed. Found {len(gpus)} GPU(s)")
    except RuntimeError as err:
        # Memory growth cannot be changed once the GPUs have been initialized.
        print(f"GPU configuration error: {err}")

print("✅ All libraries imported successfully")

# =============================================================================
# CELL 2: Enhanced Configuration for 50 Epochs Training
# =============================================================================

# Path Configuration - Updated for new project structure.
# The project root can be overridden with the LUNG_CNN_ROOT environment
# variable so the notebook runs on other machines; without it the original
# hard-coded location is used.
PROJECT_ROOT = os.environ.get(
    'LUNG_CNN_ROOT', r'C:\Users\29873\code\Summer-projects\lung-cnn'
)
TRAIN_PATH = os.path.join(PROJECT_ROOT, 'data', 'train_covid19')
TEST_PATH = os.path.join(PROJECT_ROOT, 'data', 'test_healthcare')

# Output directories (will be created in project root for compatibility)
PROCESSED_TRAIN_DIR = os.path.join(PROJECT_ROOT, 'Train_covid_50')
PROCESSED_VAL_DIR = os.path.join(PROJECT_ROOT, 'Val_covid_50')

# Model save path
MODEL_SAVE_PATH = os.path.join(PROJECT_ROOT, 'models')
RESULTS_PATH = os.path.join(PROJECT_ROOT, 'results')

# Enhanced Training Configuration for 50 Epochs
IMG_SIZE = (224, 224)
BATCH_SIZE = 32
NUM_EPOCHS = 50  # Extended training
LEARNING_RATE = 1e-3
RANDOM_SEED = 100
TRAIN_VAL_SPLIT = 0.6  # 60% training, 40% validation

# Enhanced callbacks configuration
EARLY_STOPPING_PATIENCE = 10  # Increased patience for longer training
REDUCE_LR_PATIENCE = 5
MIN_LR = 1e-7

# Set random seeds for reproducibility
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
tf.random.set_seed(RANDOM_SEED)

# Create necessary directories
for required_dir in (MODEL_SAVE_PATH, RESULTS_PATH, os.path.join(RESULTS_PATH, 'plots')):
    os.makedirs(required_dir, exist_ok=True)

print("✅ Enhanced configuration completed for 50 epochs training")
print(f"Project root: {PROJECT_ROOT}")
print(f"Training path: {TRAIN_PATH}")
print(f"Testing path: {TEST_PATH}")
print(f"Model save path: {MODEL_SAVE_PATH}")
print(f"Results path: {RESULTS_PATH}")
print(f"Image size: {IMG_SIZE}")
print(f"Batch size: {BATCH_SIZE}")
print(f"Number of epochs: {NUM_EPOCHS}")
print(f"Early stopping patience: {EARLY_STOPPING_PATIENCE}")

# =============================================================================
# CELL 3: Enhanced Data Structure Analysis and Validation
# =============================================================================

def check_data_structure():
    """
    Analyze and validate the data structure with enhanced reporting.

    Prints, for TRAIN_PATH and TEST_PATH, the sub-directories found and the
    number of image files directly inside each one (not recursive). It also
    prints the training class balance and the overall train/test ratio.

    Returns True if data structure is valid, False otherwise. Currently "valid"
    only means that both TRAIN_PATH and TEST_PATH exist; the other checks only
    print warnings.
    """
    image_exts = ('.png', '.jpg', '.jpeg', '.bmp', '.tiff')

    def count_images(folder):
        # Entries in `folder` whose name ends with an image extension.
        return len([f for f in os.listdir(folder) if f.lower().endswith(image_exts)])

    print("=" * 60)
    print("DATA STRUCTURE ANALYSIS - 50 EPOCHS VERSION")
    print("=" * 60)

    # Check training data
    if not os.path.exists(TRAIN_PATH):
        print(f"❌ Training data path does not exist: {TRAIN_PATH}")
        return False

    print(f"✅ Training data path exists: {TRAIN_PATH}")
    subdirs = [d for d in os.listdir(TRAIN_PATH) if os.path.isdir(os.path.join(TRAIN_PATH, d))]
    print(f"Training data subdirectories: {subdirs}")

    if len(subdirs) != 2:
        print(f"⚠️ Expected 2 class directories, found {len(subdirs)}")

    total_train_images = 0
    class_distribution = {}
    for subdir in subdirs:
        img_count = count_images(os.path.join(TRAIN_PATH, subdir))
        total_train_images += img_count
        class_distribution[subdir] = img_count
        print(f"  {subdir}: {img_count} images")

    print(f"Total training images: {total_train_images}")

    # Check class balance
    if len(class_distribution) == 2:
        class_counts = list(class_distribution.values())
        if max(class_counts) == 0:
            # Both class folders are empty: the ratio below would divide by zero.
            print("⚠️ No training images found in the class directories.")
        else:
            balance_ratio = min(class_counts) / max(class_counts)
            print(f"Class balance ratio: {balance_ratio:.3f}")
            if balance_ratio < 0.5:
                print("⚠️ Significant class imbalance detected. Consider data augmentation.")
            else:
                print("✅ Reasonable class balance detected.")

    # Check testing data with enhanced analysis
    if not os.path.exists(TEST_PATH):
        print(f"❌ Testing data path does not exist: {TEST_PATH}")
        return False

    print(f"✅ Testing data path exists: {TEST_PATH}")

    # Test images may sit directly in TEST_PATH and/or in sub-directories.
    items = os.listdir(TEST_PATH)
    subdirs = [d for d in items if os.path.isdir(os.path.join(TEST_PATH, d))]
    files = [f for f in items if os.path.isfile(os.path.join(TEST_PATH, f))
             and f.lower().endswith(image_exts)]

    print(f"Testing data subdirectories: {subdirs}")
    print(f"Testing data root directory images: {len(files)}")

    total_test_images = len(files)
    for subdir in subdirs:
        img_count = count_images(os.path.join(TEST_PATH, subdir))
        total_test_images += img_count
        print(f"  {subdir}: {img_count} images")

    print(f"Total testing images: {total_test_images}")

    # Recommend train/test ratio
    if total_train_images > 0:
        test_ratio = total_test_images / (total_train_images + total_test_images)
        print(f"Train/Test ratio: {1-test_ratio:.3f}/{test_ratio:.3f}")

    print("=" * 60)
    return True

# Run enhanced data structure check and abort the run if the input data is
# missing.
data_structure_ok = check_data_structure()
if not data_structure_ok:
    print("❌ Please check your data paths and structure!")
    raise SystemExit("Data structure validation failed")

# =============================================================================
# CELL 4: Enhanced Directory Setup and Data Splitting
# =============================================================================

def setup_directories():
    """
    Recreate empty training/validation output folders for the 50-epoch version.

    Any existing PROCESSED_TRAIN_DIR / PROCESSED_VAL_DIR is deleted first. Each
    folder is then created with 'yes' and 'no' class sub-folders.
    """
    print("Setting up directories for 50-epoch training...")
    output_dirs = (PROCESSED_TRAIN_DIR, PROCESSED_VAL_DIR)

    # Start from a clean slate so files from a previous split cannot leak in.
    for stale_dir in filter(os.path.exists, output_dirs):
        shutil.rmtree(stale_dir)
        print(f"Removed existing directory: {stale_dir}")

    for out_dir in output_dirs:
        os.makedirs(out_dir, exist_ok=True)
        for label in ('yes', 'no'):
            os.makedirs(os.path.join(out_dir, label), exist_ok=True)
        print(f"Created directory: {out_dir}")

    print("✅ Directory setup completed for 50-epoch version")

def split_data():
    """
    Split training data into training and validation sets with enhanced logging.

    Each of the two class folders directly under TRAIN_PATH is mapped onto the
    'no'/'yes' sub-directories created by setup_directories(). Its images are
    shuffled with the (seeded) ``random`` module. The first TRAIN_VAL_SPLIT
    fraction is copied to PROCESSED_TRAIN_DIR and the rest to PROCESSED_VAL_DIR.

    Returns True if successful, False otherwise (i.e. when TRAIN_PATH does not
    contain exactly two class folders).
    """
    print("\nSplitting data into training and validation sets (50-epoch version)...")
    image_exts = ('.png', '.jpg', '.jpeg', '.bmp', '.tiff')

    # Sorted so the folder -> label mapping and the seeded shuffle below do not
    # depend on the (arbitrary) order returned by os.listdir().
    class_folders = sorted(
        (d for d in os.listdir(TRAIN_PATH) if os.path.isdir(os.path.join(TRAIN_PATH, d))),
        key=str.lower,
    )

    if len(class_folders) != 2:
        print(f"⚠️ Found {len(class_folders)} class folders, expected 2")
        print(f"Class folders: {class_folders}")
        print("❌ Please ensure training data has exactly 2 class folders!")
        return False

    # setup_directories() only creates 'yes' and 'no' targets, so both source
    # folders must be mapped onto those names; copying under any other name
    # would target a directory that does not exist.
    by_lower_name = {folder.lower(): folder for folder in class_folders}
    if set(by_lower_name) == {'no', 'yes'}:
        class_mapping = {by_lower_name['no']: 'no', by_lower_name['yes']: 'yes'}
    else:
        # Non-standard names: assume the first folder is 'no' and the second is
        # 'yes'. Check the printed mapping - a wrong guess inverts the labels.
        class_mapping = {class_folders[0]: 'no', class_folders[1]: 'yes'}
        print("⚠️ Class folders are not named 'yes'/'no'; using positional mapping")
    print(f"Class mapping: {class_mapping}")

    total_train = 0
    total_val = 0
    split_summary = {}

    for original_class, target_class in class_mapping.items():
        class_path = os.path.join(TRAIN_PATH, original_class)
        # Sorted before shuffling so the seeded split is reproducible across
        # file systems.
        files = sorted(f for f in os.listdir(class_path) if f.lower().endswith(image_exts))

        if not files:
            print(f"⚠️ No image files found in {original_class} folder")
            continue

        # Shuffle and split data
        random.shuffle(files)
        split_point = int(TRAIN_VAL_SPLIT * len(files))
        train_files = files[:split_point]
        val_files = files[split_point:]

        # Copy each subset into its target class folder
        for subset, dest_root in ((train_files, PROCESSED_TRAIN_DIR),
                                  (val_files, PROCESSED_VAL_DIR)):
            for file_name in subset:
                shutil.copy2(os.path.join(class_path, file_name),
                             os.path.join(dest_root, target_class, file_name))

        total_train += len(train_files)
        total_val += len(val_files)

        split_summary[original_class] = {
            'total': len(files),
            'train': len(train_files),
            'val': len(val_files),
            'train_ratio': len(train_files) / len(files),
            'val_ratio': len(val_files) / len(files)
        }

        print(f"{original_class} -> {target_class}: {len(train_files)} training, {len(val_files)} validation")

    print(f"\nDetailed split summary:")
    for class_name, stats in split_summary.items():
        print(f"  {class_name}: {stats['total']} total | "
              f"{stats['train']} train ({stats['train_ratio']:.1%}) | "
              f"{stats['val']} val ({stats['val_ratio']:.1%})")

    print(f"\nOverall: {total_train} training images, {total_val} validation images")
    print("✅ Data splitting completed successfully for 50-epoch version")
    return True


def print_next_steps(model_filename):
    """
    Print the end-of-session next steps and usage examples.

    Args:
        model_filename (str): Path of the saved final model, shown in the
            ``load_model`` usage example. Call this only after the final model
            has been saved.
    """
    print("\n📈 Next Steps:")
    print("1. Run the 10-epoch version for comparison")
    print("2. Use the modular code in src/ for production deployment")
    print("3. Implement model comparison and ensemble methods")
    print("4. Consider creating a web interface for predictions")
    print("5. Set up automated model retraining pipeline")

    print("\n💡 Usage Examples:")
    print("# Load the trained model for inference:")
    print("from tensorflow.keras.models import load_model")
    print(f"model = load_model('{model_filename}')")
    print("")
    print("# Use the enhanced prediction pipeline:")
    print("python src/main.py --mode predict --version 50_epochs")
    print("")
    print("# Compare with 10-epoch version:")
    print("python src/main.py --mode compare")

    print("\n" + "="*80)
    print("🎉 50-EPOCH TRAINING SESSION COMPLETED SUCCESSFULLY!")
    print("="*80)
    print("Your model is ready for advanced deployment and clinical evaluation.")
    print("Remember to validate performance on independent test sets before clinical use.")
    print("\n")

# Execute directory setup and data splitting; stop the run if the split fails.
setup_directories()
split_ok = split_data()
if not split_ok:
    raise SystemExit("Data splitting failed")

# =============================================================================
# CELL 5: Enhanced Data Visualization
# =============================================================================

def plot_samples(img_path, n=25, title="Sample Images"):
    """
    Show up to ``n`` randomly chosen images found (recursively) under ``img_path``.

    Each image is captioned with the name of the folder that contains it. For
    the processed split directories, that folder name is the class label.

    Args:
        img_path (str): Path to image directory
        n (int): Number of images to display
        title (str): Title for the plot
    """
    image_exts = ('.png', '.jpg', '.jpeg', '.bmp', '.tiff')
    samples = [
        (os.path.join(folder, name), os.path.basename(folder))
        for folder, _subdirs, names in os.walk(img_path)
        for name in names
        if name.lower().endswith(image_exts)
    ]

    if not samples:
        print(f"No image files found in {img_path}")
        return

    # Randomly order the (path, label) pairs, then show the first `count`.
    random.shuffle(samples)
    count = min(n, len(samples))
    cols = 5
    rows = (count + cols - 1) // cols  # ceiling division

    plt.figure(figsize=(20, 4 * rows))
    plt.suptitle(title, fontsize=18, fontweight='bold')

    for position, (file_path, label) in zip(range(1, count + 1), samples):
        img = cv2.imread(file_path)
        if img is None:
            continue  # unreadable file: leave its grid cell empty
        plt.subplot(rows, cols, position)
        plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        plt.title(f'Class: {label}', fontsize=12, fontweight='bold')
        plt.axis('off')

    plt.tight_layout()
    plt.show()

def plot_class_distribution():
    """
    Bar-plot the 'yes'/'no' image counts of the processed training and validation sets.

    Returns:
        tuple(dict, dict): ``(train_counts, val_counts)``, each mapping
        'yes'/'no' to the number of image files in that class folder
        (0 if the folder does not exist).
    """
    image_exts = ('.png', '.jpg', '.jpeg', '.bmp', '.tiff')

    def count_images(folder):
        if not os.path.exists(folder):
            return 0
        return len([f for f in os.listdir(folder) if f.lower().endswith(image_exts)])

    train_classes = {label: count_images(os.path.join(PROCESSED_TRAIN_DIR, label))
                     for label in ('yes', 'no')}
    val_classes = {label: count_images(os.path.join(PROCESSED_VAL_DIR, label))
                   for label in ('yes', 'no')}

    # Create visualization: one bar chart per split, each bar annotated with its count.
    fig, (ax_train, ax_val) = plt.subplots(1, 2, figsize=(12, 5))
    panels = (
        (ax_train, train_classes, 'Training Set Class Distribution'),
        (ax_val, val_classes, 'Validation Set Class Distribution'),
    )
    for ax, counts, panel_title in panels:
        ax.bar(counts.keys(), counts.values(), color=['lightcoral', 'lightblue'])
        ax.set_title(panel_title, fontweight='bold')
        ax.set_ylabel('Number of Images')
        label_offset = max(counts.values()) * 0.01
        for x_pos, value in enumerate(counts.values()):
            ax.text(x_pos, value + label_offset, str(value), ha='center', fontweight='bold')

    plt.tight_layout()
    plt.show()

    return train_classes, val_classes

# Display enhanced visualizations: a grid of random training samples followed
# by the per-class image counts of both splits.
print("\nDisplaying sample training images...")
plot_samples(PROCESSED_TRAIN_DIR, title="Training Dataset Sample Images - 50 Epochs Version", n=25)

print("\nAnalyzing class distribution...")
class_counts_by_split = plot_class_distribution()
train_dist, val_dist = class_counts_by_split

# =============================================================================
# CELL 6: Enhanced Data Generators Setup
# =============================================================================

# Enhanced training data generator with more aggressive augmentation
train_datagen = ImageDataGenerator(
    rotation_range=30,          # Increased from 20
    width_shift_range=0.25,     # Increased from 0.2
    height_shift_range=0.25,    # Increased from 0.2
    shear_range=0.25,          # Increased from 0.2
    zoom_range=0.25,           # Increased from 0.2
    horizontal_flip=True,
    vertical_flip=False,        # Keep False for medical images
    brightness_range=[0.8, 1.2], # New: brightness variation
    fill_mode='nearest',
    preprocessing_function=preprocess_input
)

# Validation data generator (no augmentation)
val_datagen = ImageDataGenerator(
    preprocessing_function=preprocess_input
)

print("Enhanced data generators created with the following augmentations:")
print("- Rotation: ±30 degrees (increased)")
print("- Width/Height shift: ±25% (increased)")
print("- Shear transformation: 25% (increased)")
print("- Zoom: ±25% (increased)")
print("- Horizontal flip: Yes")
print("- Brightness variation: 0.8-1.2 (new)")
print("- Vertical flip: No (medical image consideration)")

# Create enhanced data generators
train_generator = train_datagen.flow_from_directory(
    PROCESSED_TRAIN_DIR,
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='binary',
    seed=RANDOM_SEED,
    shuffle=True
)

validation_generator = val_datagen.flow_from_directory(
    PROCESSED_VAL_DIR,
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='binary',
    seed=RANDOM_SEED,
    shuffle=False
)

print(f"\n✅ Enhanced data generators created successfully")
print(f"Training samples: {train_generator.samples}")
print(f"Validation samples: {validation_generator.samples}")
print(f"Number of classes: {len(train_generator.class_indices)}")
print(f"Class indices: {train_generator.class_indices}")

# Calculate expected training time.
# Training: dropping the last partial batch is harmless because the data is
# reshuffled every epoch (but never use fewer than one step).
steps_per_epoch = max(1, train_generator.samples // BATCH_SIZE)
# Validation: ceiling division so every validation image is evaluated. The
# validation iterator is not shuffled and yields files grouped by class, so
# floor division would skip the tail of the last class on every pass. It would
# also give 0 steps when there are fewer than BATCH_SIZE validation images.
validation_steps = (validation_generator.samples + BATCH_SIZE - 1) // BATCH_SIZE
total_steps = steps_per_epoch * NUM_EPOCHS

print(f"\nTraining details:")
print(f"Steps per epoch: {steps_per_epoch}")
print(f"Validation steps: {validation_steps}")
print(f"Total training steps: {total_steps}")
print(f"Estimated time per epoch: ~{steps_per_epoch * 2}s (assuming 2s/step)")

# =============================================================================
# CELL 7: Enhanced Model Architecture
# =============================================================================

def create_enhanced_model():
    """
    Create and compile an enhanced CNN model based on VGG16 for 50-epoch training.

    An ImageNet-pretrained, frozen VGG16 convolutional base is topped with
    GlobalAveragePooling2D -> Dense(256) -> Dense(128) -> Dense(64) ->
    Dense(1, sigmoid), with BatchNormalization/Dropout in between, for binary
    classification.

    Returns:
        tensorflow.keras.Model: Compiled model with enhanced architecture
    """
    # Load pre-trained VGG16 model
    base_model = VGG16(
        weights='imagenet',
        include_top=False,
        input_shape=IMG_SIZE + (3,)
    )
    # Freeze pre-trained layers initially (must be set before compile()).
    base_model.trainable = False

    # Create enhanced custom classifier on top
    model = Sequential([
        base_model,
        GlobalAveragePooling2D(),
        Dropout(0.6),  # Increased dropout for longer training
        Dense(256, activation='relu'),  # Increased units
        BatchNormalization(),
        Dropout(0.4),  # Additional dropout layer
        Dense(128, activation='relu'),
        BatchNormalization(),
        Dropout(0.3),
        Dense(64, activation='relu'),  # Additional layer
        Dropout(0.2),
        Dense(1, activation='sigmoid')  # Binary classification
    ])

    # Compile model with enhanced configuration.
    # Precision/Recall are passed as metric objects with explicit names. Older
    # tf.keras versions may not resolve the plain strings 'precision'/'recall'.
    # Fixed names also keep the history/log keys stable ('precision',
    # 'val_precision', 'recall', 'val_recall'), which later cells read.
    model.compile(
        optimizer=Adam(learning_rate=LEARNING_RATE),
        loss='binary_crossentropy',
        metrics=[
            'accuracy',
            tf.keras.metrics.Precision(name='precision'),
            tf.keras.metrics.Recall(name='recall'),
        ]
    )

    return model

# Create enhanced model
print("Creating enhanced model architecture for 50-epoch training...")
model = create_enhanced_model()

header_rule = "="*70
print("\n" + header_rule)
print("ENHANCED MODEL ARCHITECTURE - 50 EPOCHS VERSION")
print(header_rule)
model.summary()

# Count parameters from the variable shapes; the frozen VGG16 base is
# reported in the non-trainable total.
trainable_params = sum(np.prod(var.get_shape().as_list()) for var in model.trainable_variables)
non_trainable_params = sum(np.prod(var.get_shape().as_list()) for var in model.non_trainable_variables)

print("\nModel parameters:")
print(f"Trainable parameters: {trainable_params:,}")
print(f"Non-trainable parameters: {non_trainable_params:,}")
print(f"Total parameters: {trainable_params + non_trainable_params:,}")

for feature_line in (
    "\nEnhanced features for 50-epoch training:",
    "- Deeper classifier network (4 dense layers)",
    "- Increased dropout rates for better regularization",
    "- Additional metrics: precision and recall",
    "- Enhanced data augmentation",
):
    print(feature_line)

# =============================================================================
# CELL 8: Enhanced Training Configuration and Callbacks
# =============================================================================

# Define enhanced callbacks for 50-epoch training
callbacks = [
    # Early stopping with increased patience
    EarlyStopping(
        monitor='val_accuracy',
        patience=EARLY_STOPPING_PATIENCE,
        restore_best_weights=True,
        verbose=1,
        mode='max'
    ),

    # Model checkpoint to save best model
    ModelCheckpoint(
        filepath=os.path.join(MODEL_SAVE_PATH, 'best_model_50epochs.h5'),
        monitor='val_accuracy',
        save_best_only=True,
        mode='max',
        verbose=1,
        save_weights_only=False
    ),

    # Learning rate reduction
    ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.2,
        patience=REDUCE_LR_PATIENCE,
        min_lr=MIN_LR,
        verbose=1,
        cooldown=2
    )
]

print("Enhanced training configuration for 50 epochs:")
print(f"- Optimizer: Adam (learning_rate={LEARNING_RATE})")
print(f"- Loss function: Binary crossentropy")
print(f"- Metrics: Accuracy, Precision, Recall")
print(f"- Early stopping: val_accuracy (patience={EARLY_STOPPING_PATIENCE})")
print(f"- Model checkpoint: best_model_50epochs.h5")
print(f"- Learning rate reduction: factor=0.2, patience={REDUCE_LR_PATIENCE}")
print(f"- Number of epochs: {NUM_EPOCHS}")

# Calculate steps per epoch (these are the values passed to model.fit).
# Training data is reshuffled each epoch, so dropping its last partial batch is
# fine, but never use fewer than one step.
steps_per_epoch = max(1, train_generator.samples // BATCH_SIZE)
# Validation uses ceiling division so every validation image is evaluated. The
# validation iterator is not shuffled and yields files grouped by class, so
# floor division would skip the tail of the last class, and would give 0 steps
# when there are fewer than BATCH_SIZE validation images.
validation_steps = (validation_generator.samples + BATCH_SIZE - 1) // BATCH_SIZE

print(f"- Steps per epoch: {steps_per_epoch}")
print(f"- Validation steps: {validation_steps}")

# =============================================================================
# CELL 9: Enhanced Model Training (50 epochs with detailed monitoring)
# =============================================================================

print("\n" + "="*70)
print("STARTING ENHANCED MODEL TRAINING - 50 EPOCHS VERSION")
print("="*70)

import time


class EnhancedVerboseCallback(tf.keras.callbacks.Callback):
    """
    Print a compact per-epoch report: duration, loss/accuracy, precision/recall
    and overall progress.

    Metrics missing from the epoch logs are shown as 'N/A' instead of raising,
    so the callback also works for models compiled without precision/recall.
    """

    def __init__(self):
        super().__init__()
        self.epoch_start_time = None

    @staticmethod
    def _fmt(logs, key):
        """Return ``logs[key]`` formatted with 4 decimals, or 'N/A' if absent."""
        value = logs.get(key)
        return 'N/A' if value is None else f"{value:.4f}"

    def on_epoch_begin(self, epoch, logs=None):
        self.epoch_start_time = time.perf_counter()
        print(f"\n--- Epoch {epoch + 1}/{NUM_EPOCHS} ---")

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        epoch_time = time.perf_counter() - self.epoch_start_time
        fmt = self._fmt

        print(f"Epoch {epoch + 1}/{NUM_EPOCHS} completed in {epoch_time:.1f}s")
        print(f"  Training   - Loss: {fmt(logs, 'loss')}, Accuracy: {fmt(logs, 'accuracy')}")
        print(f"  Validation - Loss: {fmt(logs, 'val_loss')}, Accuracy: {fmt(logs, 'val_accuracy')}")
        print(f"  Training   - Precision: {fmt(logs, 'precision')}, Recall: {fmt(logs, 'recall')}")
        print(f"  Validation - Precision: {fmt(logs, 'val_precision')}, Recall: {fmt(logs, 'val_recall')}")

        # Progress indicator
        progress = (epoch + 1) / NUM_EPOCHS * 100
        print(f"  Progress: {progress:.1f}% complete")


# Add enhanced callback
callbacks.append(EnhancedVerboseCallback())

# Train the enhanced model
print("Starting 50-epoch training session...")
print("This may take several hours depending on your hardware.")
print("Monitor GPU/CPU usage and ensure adequate cooling.")

history = model.fit(
    train_generator,
    steps_per_epoch=steps_per_epoch,
    epochs=NUM_EPOCHS,
    validation_data=validation_generator,
    validation_steps=validation_steps,
    callbacks=callbacks,
    verbose=1  # Show progress bar for each epoch
)

print("\n✅ 50-epoch training completed successfully!")

# =============================================================================
# CELL 10: Enhanced Training Results Visualization
# =============================================================================

def plot_enhanced_training_history(history):
    """
    Plot comprehensive training history for 50-epoch training
    
    Args:
        history: Training history object
    """
    # Create a comprehensive figure with multiple subplots
    fig = plt.figure(figsize=(20, 15))
    
    epochs_range = range(1, len(history.history['accuracy']) + 1)
    
    # 1. Accuracy plot
    plt.subplot(3, 3, 1)
    plt.plot(epochs_range, history.history['accuracy'], 'b-', label='Training Accuracy', linewidth=2)
    plt.plot(epochs_range, history.history['val_accuracy'], 'r-', label='Validation Accuracy', linewidth=2)
    plt.title('Model Accuracy - 50 Epochs', fontsize=14, fontweight='bold')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.grid(True, alpha=0.3)
    
    # 2. Loss plot
    plt.subplot(3, 3, 2)
    plt.plot(epochs_range, history.history['loss'], 'b-', label='Training Loss', linewidth=2)
    plt.plot(epochs_range, history.history['val_loss'], 'r-', label='Validation Loss', linewidth=2)
    plt.title('Model Loss - 50 Epochs', fontsize=14, fontweight='bold')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True, alpha=0.3)
    
    # 3. Precision plot (if available)
    if 'precision' in history.history:
        plt.subplot(3, 3, 3)
        plt.plot(epochs_range, history.history['precision'], 'g-', label='Training Precision', linewidth=2)
        if 'val_precision' in history.history:
            plt.plot(epochs_range, history.history['val_precision'], 'orange', label='Validation Precision', linewidth=2)
        plt.title('Model Precision - 50 Epochs', fontsize=14, fontweight='bold')
        plt.xlabel('Epoch')
        plt.ylabel('Precision')
        plt.legend()
        plt.grid(True, alpha=0.3)
    
    # 4. Recall plot (if available)
    if 'recall' in history.history:
        plt.subplot(3, 3, 4)
        plt.plot(epochs_range, history.history['recall'], 'purple', label='Training Recall', linewidth=2)
        if 'val_recall' in history.history:
            plt.plot(epochs_range, history.history['val_recall'], 'brown', label='Validation Recall', linewidth=2)
        plt.title('Model Recall - 50 Epochs', fontsize=14, fontweight='bold')
        plt.xlabel('Epoch')
        plt.ylabel('Recall')
        plt.legend()
        plt.grid(True, alpha=0.3)
    
    # 5. Learning rate plot (if available in callbacks)
    plt.subplot(3, 3, 5)
    if hasattr(model.optimizer, 'learning_rate'):
        # This is a simplified representation - actual LR changes would need callback logging
        plt.plot(epochs_range, [LEARNING_RATE] * len(epochs_range), 'k--', label='Learning Rate')
        plt.title('Learning Rate Schedule', fontsize=14, fontweight='bold')
        plt.xlabel('Epoch')
        plt.ylabel('Learning Rate')
        plt.legend()
        plt.grid(True, alpha=0.3)
    
    # 6. Accuracy comparison bar chart
    plt.subplot(3, 3, 6)
    final_train_acc = history.history['accuracy'][-1]
    final_val_acc = history.history['val_accuracy'][-1]
    best_val_acc = max(history.history['val_accuracy'])
    
    metrics = ['Final Train', 'Final Val', 'Best Val']
    values = [final_train_acc, final_val_acc, best_val_acc]
    colors = ['blue', 'red', 'green']
    
    bars = plt.bar(metrics, values, color=colors, alpha=0.7)
    plt.title('Accuracy Summary', fontsize=14, fontweight='bold')
    plt.ylabel('Accuracy')
    plt.ylim(0, 1)
    
    for bar, value in zip(bars, values):
        plt.text(bar.get_x() + bar.get_width()/2, value + 0.01, 
                f'{value:.4f}', ha='center', va='bottom', fontweight='bold')
    
    # 7. Loss comparison bar chart
    plt.subplot(3, 3, 7)
    final_train_loss = history.history['loss'][-1]
    final_val_loss = history.history['val_loss'][-1]
    min_val_loss = min(history.history['val_loss'])
    
    loss_metrics = ['Final Train', 'Final Val', 'Min Val']
    loss_values = [final_train_loss, final_val_loss, min_val_loss]
    
    bars = plt.bar(loss_metrics, loss_values, color=colors, alpha=0.7)
    plt.title('Loss Summary', fontsize=14, fontweight='bold')
    plt.ylabel('Loss')
    
    for bar, value in zip(bars, loss_values):
        plt.text(bar.get_x() + bar.get_width()/2, value + max(loss_values) * 0.02, 
                f'{value:.4f}', ha='center', va='bottom', fontweight='bold')
    
    # 8. Training progress over time
    plt.subplot(3, 3, 8)
    # Calculate moving averages for smoother visualization
    window_size = max(1, len(epochs_range) // 10)
    
    def moving_average(data, window_size):
        return [sum(data[max(0, i-window_size):i+1]) / min(i+1, window_size) for i in range(len(data))]
    
    smooth_train_acc = moving_average(history.history['accuracy'], window_size)
    smooth_val_acc = moving_average(history.history['val_accuracy'], window_size)
    
    plt.plot(epochs_range, smooth_train_acc, 'b-', label='Smooth Train Acc', linewidth=2)
    plt.plot(epochs_range, smooth_val_acc, 'r-', label='Smooth Val Acc', linewidth=2)
    plt.title('Smoothed Accuracy Trends', fontsize=14, fontweight='bold')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.grid(True, alpha=0.3)
    
    # 9. Overfitting analysis
    plt.subplot(3, 3, 9)
    acc_diff = [train - val for train, val in zip(history.history['accuracy'], history.history['val_accuracy'])]
    plt.plot(epochs_range, acc_diff, 'purple', linewidth=2, label='Train - Val Accuracy')
    plt.axhline(y=0, color='black', linestyle='--', alpha=0.5)
    plt.title('Overfitting Analysis', fontsize=14, fontweight='bold')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy Difference')
    plt.legend()
    plt.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    # Save the plot
    plot_path = os.path.join(RESULTS_PATH, 'plots', 'training_history_50epochs.png')
    fig.savefig(plot_path, dpi=300, bbox_inches='tight')
    print(f"✅ Training plots saved to: {plot_path}")

# Plot enhanced training results
print("Generating comprehensive training results visualization...")
plot_enhanced_training_history(history)

# Enhanced training summary
summary_rule = "="*70
print("\n" + summary_rule)
print("COMPREHENSIVE TRAINING SUMMARY - 50 EPOCHS VERSION")
print(summary_rule)

metric_log = history.history
acc_curve = metric_log['accuracy']
val_acc_curve = metric_log['val_accuracy']
val_loss_curve = metric_log['val_loss']

epochs_completed = len(acc_curve)
final_train_acc = acc_curve[-1]
final_val_acc = val_acc_curve[-1]
final_train_loss = metric_log['loss'][-1]
final_val_loss = val_loss_curve[-1]

# Epoch numbers reported 1-based
best_val_acc_epoch = np.argmax(val_acc_curve) + 1
best_val_acc = max(val_acc_curve)
min_val_loss_epoch = np.argmin(val_loss_curve) + 1
min_val_loss = min(val_loss_curve)

stopped_early = epochs_completed < NUM_EPOCHS
print("Training Progress:")
print(f"  Epochs completed: {epochs_completed}/{NUM_EPOCHS}")
print(f"  Early stopping: {'Yes' if stopped_early else 'No'}")

print("\nFinal Performance:")
for caption, value in (
    ("Training accuracy", final_train_acc),
    ("Validation accuracy", final_val_acc),
    ("Training loss", final_train_loss),
    ("Validation loss", final_val_loss),
):
    print(f"  {caption}: {value:.4f}")

print("\nBest Performance:")
print(f"  Best validation accuracy: {best_val_acc:.4f} (epoch {best_val_acc_epoch})")
print(f"  Minimum validation loss: {min_val_loss:.4f} (epoch {min_val_loss_epoch})")

# Overfitting analysis based on the last-epoch train/val accuracy gap
accuracy_gap = final_train_acc - final_val_acc
if accuracy_gap > 0.1:
    print(f"\n⚠️ Overfitting detected! Training accuracy exceeds validation by {accuracy_gap:.3f}")
    print("   Consider: more regularization, data augmentation, or early stopping")
elif accuracy_gap < 0:
    print(f"\n📈 Underfitting possible. Validation accuracy exceeds training by {abs(accuracy_gap):.3f}")
    print("   Consider: reducing regularization or training longer")
else:
    print(f"\n✅ Good fit! Accuracy gap is reasonable: {accuracy_gap:.3f}")

# Additional metrics if available
if 'precision' in metric_log:
    final_precision = metric_log['precision'][-1]
    print(f"  Final precision: {final_precision:.4f}")

if 'recall' in metric_log:
    final_recall = metric_log['recall'][-1]
    print(f"  Final recall: {final_recall:.4f}")

# =============================================================================
# CELL 11: Enhanced Model Saving and Checkpointing
# =============================================================================

# Save the final trained model
model_filename = os.path.join(MODEL_SAVE_PATH, 'covid_classifier_vgg16_50epochs_final.h5')
model.save(model_filename)
print(f"✅ Final model saved as: {model_filename}")

# Save training history together with the run configuration and key results
import pickle

run_config = {
    'epochs': NUM_EPOCHS,
    'batch_size': BATCH_SIZE,
    'learning_rate': LEARNING_RATE,
    'train_samples': train_generator.samples,
    'val_samples': validation_generator.samples,
    'train_val_split': TRAIN_VAL_SPLIT,
    'augmentation': 'enhanced',
    'architecture': 'vgg16_enhanced'
}
run_performance = {
    'best_val_acc': best_val_acc,
    'best_val_acc_epoch': best_val_acc_epoch,
    'final_val_acc': final_val_acc,
    'min_val_loss': min_val_loss,
    'min_val_loss_epoch': min_val_loss_epoch,
    'final_val_loss': final_val_loss
}
history_data = {
    'history': history.history,
    'config': run_config,
    'performance': run_performance
}

history_filename = os.path.join(MODEL_SAVE_PATH, 'training_history_50epochs_enhanced.pkl')
with open(history_filename, 'wb') as history_file:
    pickle.dump(history_data, history_file)
print(f"✅ Enhanced training history saved as: {history_filename}")

# Create a plain-text model summary file
summary_filename = os.path.join(RESULTS_PATH, 'model_summary_50epochs.txt')
summary_lines = [
    "COVID-19 Classification Model - 50 Epochs Version\n",
    "="*50 + "\n\n",
    f"Training completed: {epochs_completed}/{NUM_EPOCHS} epochs\n",
    f"Best validation accuracy: {best_val_acc:.4f} (epoch {best_val_acc_epoch})\n",
    f"Final validation accuracy: {final_val_acc:.4f}\n",
    f"Minimum validation loss: {min_val_loss:.4f} (epoch {min_val_loss_epoch})\n",
    "Model architecture: Enhanced VGG16\n",
    f"Total parameters: {trainable_params + non_trainable_params:,}\n",
    f"Training samples: {train_generator.samples}\n",
    f"Validation samples: {validation_generator.samples}\n",
]
with open(summary_filename, 'w') as summary_file:
    summary_file.writelines(summary_lines)

print(f"✅ Model summary saved as: {summary_filename}")

# =============================================================================
# CELL 12: Enhanced Test Data Prediction
# =============================================================================

# Lower-case file extensions treated as test images.
_TEST_IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff')


def _collect_test_image_paths(test_dir, extensions=_TEST_IMAGE_EXTENSIONS):
    """Return (paths, sources) for image files in *test_dir*'s subdirectories and its root.

    ``sources`` holds the subdirectory name of each path, or ``'root'`` for
    files found directly in *test_dir*. Listings are sorted so the order is
    reproducible across runs and platforms (os.listdir order is arbitrary).
    """
    image_paths = []
    sources = []

    def _is_image(folder, name):
        return name.lower().endswith(extensions) and os.path.isfile(os.path.join(folder, name))

    subdirs = sorted(d for d in os.listdir(test_dir) if os.path.isdir(os.path.join(test_dir, d)))
    if subdirs:
        print(f"Searching for images in subdirectories: {subdirs}")
        for subdir in subdirs:
            subdir_path = os.path.join(test_dir, subdir)
            subdir_files = sorted(f for f in os.listdir(subdir_path) if _is_image(subdir_path, f))
            image_paths.extend(os.path.join(subdir_path, f) for f in subdir_files)
            sources.extend([subdir] * len(subdir_files))
            print(f"  {subdir}: {len(subdir_files)} images")

    root_files = sorted(f for f in os.listdir(test_dir) if _is_image(test_dir, f))
    if root_files:
        print(f"Found {len(root_files)} images in root directory")
        image_paths.extend(os.path.join(test_dir, f) for f in root_files)
        sources.extend(['root'] * len(root_files))

    return image_paths, sources


def _load_test_images(image_paths, sources):
    """Read, validate and VGG16-preprocess each image.

    Returns (images, kept_paths, kept_sources, failures), where ``failures`` is
    a list of (path, reason) tuples for images that were skipped. Loading is
    best-effort: an error on one file is recorded and the rest still load.
    """
    images, kept_paths, kept_sources, failures = [], [], [], []
    for img_path, source in tqdm(zip(image_paths, sources),
                                 desc="Loading images",
                                 total=len(image_paths)):
        try:
            img = cv2.imread(img_path)
            if img is None:
                failures.append((img_path, "Could not read image"))
                continue
            # Check the raw pixels: after preprocess_input's per-channel mean
            # subtraction a black image no longer averages to 0, so the check
            # must run before preprocessing.
            if not np.any(img):
                failures.append((img_path, "Image appears to be corrupted"))
                continue
            # NOTE(review): cv2.resize takes (width, height), whereas Keras'
            # target_size is (height, width); identical only if IMG_SIZE is square — confirm.
            img_resized = cv2.resize(img, IMG_SIZE)
            # OpenCV decodes to BGR. Keras' directory generators load RGB, and
            # VGG16's preprocess_input expects RGB input, so convert first.
            img_rgb = cv2.cvtColor(img_resized, cv2.COLOR_BGR2RGB)
            images.append(preprocess_input(img_rgb))
            kept_paths.append(img_path)
            kept_sources.append(source)
        except Exception as e:  # best-effort: one bad file must not abort the batch
            failures.append((img_path, f"Error: {str(e)}"))
    return images, kept_paths, kept_sources, failures


def predict_test_images_enhanced():
    """
    Run the trained model on every image under TEST_PATH and return per-image results.

    Images are collected from TEST_PATH's subdirectories and its root. Each one
    is read with OpenCV, converted to RGB, resized to IMG_SIZE and passed
    through VGG16's preprocess_input. The model output is thresholded at 0.5.

    Returns:
        list: One dict per successfully loaded image. Keys: 'filename', 'path',
        'subdirectory', 'prediction', 'confidence', 'uncertainty',
        'raw_probability', 'predicted_class'. An empty list is returned when
        TEST_PATH is missing or no image could be loaded.
    """
    print("\n" + "="*70)
    print("ENHANCED TEST DATA PREDICTION - 50 EPOCHS VERSION")
    print("="*70)

    if not os.path.exists(TEST_PATH):
        print(f"❌ Test data path does not exist: {TEST_PATH}")
        return []

    test_image_paths, test_subdirs = _collect_test_image_paths(TEST_PATH)
    if not test_image_paths:
        print("❌ No test images found!")
        return []

    print(f"✅ Found {len(test_image_paths)} test images total")

    print("Preprocessing test images with enhanced validation...")
    test_images, valid_paths, valid_subdirs, failed_images = _load_test_images(
        test_image_paths, test_subdirs)

    if failed_images:
        print(f"⚠️ Failed to process {len(failed_images)} images:")
        for path, reason in failed_images[:5]:  # Show first 5 failures
            print(f"  {os.path.basename(path)}: {reason}")
        if len(failed_images) > 5:
            print(f"  ... and {len(failed_images) - 5} more")

    if not test_images:
        print("❌ No valid test images loaded!")
        return []

    test_images = np.array(test_images)
    print(f"✅ Successfully loaded {len(test_images)} test images")
    print(f"Image array shape: {test_images.shape}")

    print("Making enhanced predictions...")
    print("Using model with 50-epoch training for optimal performance...")

    # Assumes a single sigmoid output unit giving P(class index 1) — TODO confirm.
    predictions = model.predict(test_images, verbose=1, batch_size=BATCH_SIZE)
    prediction_probs = predictions.flatten()
    predicted_classes = (prediction_probs > 0.5).astype(int)
    # 1.0 at the 0.5 decision boundary, 0.0 at probability 0 or 1.
    uncertainties = 1 - np.abs(prediction_probs - 0.5) * 2

    # Invert the generator's {class_name: index} mapping.
    class_names = {v: k for k, v in train_generator.class_indices.items()}

    results = []
    for img_path, subdir, pred_prob, pred_class, uncertainty in zip(
            valid_paths, valid_subdirs, prediction_probs, predicted_classes, uncertainties):
        confidence = pred_prob if pred_class == 1 else 1 - pred_prob
        results.append({
            'filename': os.path.basename(img_path),
            'path': img_path,
            'subdirectory': subdir,
            'prediction': class_names[int(pred_class)],
            'confidence': confidence,
            'uncertainty': uncertainty,
            'raw_probability': pred_prob,
            'predicted_class': pred_class
        })

    return results

# Execute enhanced test prediction
test_results = predict_test_images_enhanced()

# =============================================================================
# CELL 13: Comprehensive Results Analysis and Visualization
# =============================================================================

def _to_json_scalar(value):
    """Return *value* with NumPy scalars converted to the matching built-in type.

    np.integer -> int and np.floating -> float, so json.dump can serialize
    them without turning integer fields (e.g. 'predicted_class') into floats.
    Any other value is returned unchanged.
    """
    if isinstance(value, np.integer):
        return int(value)
    if isinstance(value, np.floating):
        return float(value)
    return value


if test_results:
    import datetime
    import json
    from matplotlib.patches import Patch

    print("\n" + "="*70)
    print("COMPREHENSIVE PREDICTION RESULTS - 50 EPOCHS VERSION")
    print("="*70)

    # Sort in place, most confident first (later cells read test_results too).
    test_results.sort(key=lambda x: x['confidence'], reverse=True)

    # Detailed statistics. setdefault/get keep this working even if the
    # generator's class names are something other than 'yes'/'no'.
    total_predictions = len(test_results)
    class_counts = {}
    confidence_stats = {'yes': [], 'no': []}
    uncertainty_stats = {'yes': [], 'no': []}
    subdir_stats = {}

    for result in test_results:
        pred = result['prediction']
        subdir = result['subdirectory']

        # Class statistics
        class_counts[pred] = class_counts.get(pred, 0) + 1
        confidence_stats.setdefault(pred, []).append(result['confidence'])
        uncertainty_stats.setdefault(pred, []).append(result['uncertainty'])

        # Subdirectory statistics
        stats = subdir_stats.setdefault(subdir, {'yes': 0, 'no': 0, 'total': 0})
        stats[pred] = stats.get(pred, 0) + 1
        stats['total'] += 1

    # Display top predictions with enhanced information
    print("\n📊 Top predictions (by confidence):")
    print("-" * 100)
    print(f"{'Filename':<35} {'Prediction':<10} {'Confidence':<12} {'Uncertainty':<12} {'Source':<15}")
    print("-" * 100)

    for result in test_results[:25]:  # Show top 25
        print(f"{result['filename']:<35} {result['prediction']:<10} "
              f"{result['confidence']:.4f}      {result['uncertainty']:.4f}      "
              f"{result['subdirectory']:<15}")

    if len(test_results) > 25:
        print(f"... and {len(test_results) - 25} more results")

    # Enhanced statistical analysis
    print("\n" + "="*70)
    print("DETAILED PREDICTION STATISTICS")
    print("="*70)

    # Class distribution analysis
    print("\n🎯 Class Distribution:")
    for class_name, count in class_counts.items():
        percentage = (count / total_predictions) * 100
        avg_confidence = np.mean(confidence_stats[class_name])
        avg_uncertainty = np.mean(uncertainty_stats[class_name])
        std_confidence = np.std(confidence_stats[class_name])

        print(f"  {class_name.upper()}: {count:3d} images ({percentage:5.1f}%)")
        print(f"    Average confidence: {avg_confidence:.4f} (±{std_confidence:.4f})")
        print(f"    Average uncertainty: {avg_uncertainty:.4f}")

    # Subdirectory analysis
    if len(subdir_stats) > 1:
        print("\n📁 Subdirectory Analysis:")
        for subdir, stats in subdir_stats.items():
            if stats['total'] > 0:
                yes_pct = (stats['yes'] / stats['total']) * 100
                no_pct = (stats['no'] / stats['total']) * 100
                print(f"  {subdir}: {stats['total']} images")
                print(f"    YES: {stats['yes']} ({yes_pct:.1f}%), NO: {stats['no']} ({no_pct:.1f}%)")

    # Overall confidence statistics
    all_confidences = [r['confidence'] for r in test_results]
    all_uncertainties = [r['uncertainty'] for r in test_results]

    print(f"\n📈 Overall Confidence Statistics:")
    print(f"  Mean confidence: {np.mean(all_confidences):.4f}")
    print(f"  Median confidence: {np.median(all_confidences):.4f}")
    print(f"  Standard deviation: {np.std(all_confidences):.4f}")
    print(f"  Min confidence: {np.min(all_confidences):.4f}")
    print(f"  Max confidence: {np.max(all_confidences):.4f}")

    print(f"\n🎲 Uncertainty Analysis:")
    print(f"  Mean uncertainty: {np.mean(all_uncertainties):.4f}")
    print(f"  Median uncertainty: {np.median(all_uncertainties):.4f}")
    print(f"  High uncertainty (>0.8): {sum(1 for u in all_uncertainties if u > 0.8)} images")

    # Confidence buckets: >0.95, [0.9, 0.95], [0.7, 0.9), <0.7
    very_high_conf = sum(1 for c in all_confidences if c > 0.95)
    high_conf = sum(1 for c in all_confidences if 0.9 <= c <= 0.95)
    medium_conf = sum(1 for c in all_confidences if 0.7 <= c < 0.9)
    low_conf = sum(1 for c in all_confidences if c < 0.7)

    print(f"\n🎯 Enhanced Confidence Distribution:")
    print(f"  Very high confidence (>0.95): {very_high_conf} images ({very_high_conf/total_predictions*100:.1f}%)")
    print(f"  High confidence (0.9-0.95): {high_conf} images ({high_conf/total_predictions*100:.1f}%)")
    print(f"  Medium confidence (0.7-0.9): {medium_conf} images ({medium_conf/total_predictions*100:.1f}%)")
    print(f"  Low confidence (<0.7): {low_conf} images ({low_conf/total_predictions*100:.1f}%)")

    # Create enhanced visualizations
    print("\n📊 Creating enhanced result visualizations...")

    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(16, 12))

    # Confidence distribution histogram
    ax1.hist(all_confidences, bins=20, alpha=0.7, color='skyblue', edgecolor='black')
    ax1.axvline(np.mean(all_confidences), color='red', linestyle='--',
                label=f'Mean: {np.mean(all_confidences):.3f}')
    ax1.set_title('Prediction Confidence Distribution', fontweight='bold')
    ax1.set_xlabel('Confidence')
    ax1.set_ylabel('Frequency')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Class distribution pie chart. Colours follow the class, not the order in
    # which classes first appear, so they agree with the scatter plot (red = 'yes').
    pie_labels = list(class_counts)
    ax2.pie([class_counts[k] for k in pie_labels],
            labels=[f'{k.upper()}\n({class_counts[k]} images)' for k in pie_labels],
            autopct='%1.1f%%', startangle=90,
            colors=['lightcoral' if k == 'yes' else 'lightblue' for k in pie_labels])
    ax2.set_title('Class Distribution', fontweight='bold')

    # Confidence vs Uncertainty scatter plot
    confidences = [r['confidence'] for r in test_results]
    uncertainties = [r['uncertainty'] for r in test_results]
    colors = ['red' if r['prediction'] == 'yes' else 'blue' for r in test_results]

    ax3.scatter(confidences, uncertainties, c=colors, alpha=0.6, s=30)
    ax3.set_xlabel('Confidence')
    ax3.set_ylabel('Uncertainty')
    ax3.set_title('Confidence vs Uncertainty', fontweight='bold')
    ax3.grid(True, alpha=0.3)

    # Add legend for colors
    legend_elements = [Patch(facecolor='red', label='COVID-19 Positive'),
                       Patch(facecolor='blue', label='COVID-19 Negative')]
    ax3.legend(handles=legend_elements)

    # Confidence levels bar chart
    conf_levels = ['Very High\n(>0.95)', 'High\n(0.9-0.95)', 'Medium\n(0.7-0.9)', 'Low\n(<0.7)']
    conf_counts = [very_high_conf, high_conf, medium_conf, low_conf]
    colors_bar = ['darkgreen', 'green', 'orange', 'red']

    bars = ax4.bar(conf_levels, conf_counts, color=colors_bar, alpha=0.7)
    ax4.set_title('Confidence Level Distribution', fontweight='bold')
    ax4.set_ylabel('Number of Images')

    for bar, count in zip(bars, conf_counts):
        ax4.text(bar.get_x() + bar.get_width()/2, bar.get_height() + max(conf_counts) * 0.01,
                 str(count), ha='center', va='bottom', fontweight='bold')

    plt.tight_layout()

    # Save before showing, so the file is written even with backends whose
    # show() blocks or closes the figure. Create the plots directory the same
    # way the predictions directory is created below.
    results_plot_path = os.path.join(RESULTS_PATH, 'plots', 'prediction_results_50epochs.png')
    os.makedirs(os.path.dirname(results_plot_path), exist_ok=True)
    fig.savefig(results_plot_path, dpi=300, bbox_inches='tight')
    plt.show()
    print(f"✅ Results visualization saved to: {results_plot_path}")

    # Save detailed results to JSON
    results_json_path = os.path.join(RESULTS_PATH, 'predictions', 'detailed_predictions_50epochs.json')
    os.makedirs(os.path.dirname(results_json_path), exist_ok=True)

    # NumPy scalars -> built-in int/float for JSON serialization.
    json_results = [{k: _to_json_scalar(v) for k, v in result.items()} for result in test_results]

    with open(results_json_path, 'w') as f:
        json.dump({
            'metadata': {
                'model_version': '50_epochs_enhanced',
                'total_predictions': total_predictions,
                # pandas is not imported by this notebook, so the old
                # pd.Timestamp fallback always produced 'Unknown'.
                'prediction_date': str(datetime.datetime.now()),
                'model_performance': {
                    'final_val_accuracy': _to_json_scalar(final_val_acc),
                    'best_val_accuracy': _to_json_scalar(best_val_acc)
                }
            },
            'predictions': json_results,
            'summary_statistics': {
                'class_distribution': class_counts,
                'confidence_stats': {
                    'mean': float(np.mean(all_confidences)),
                    'median': float(np.median(all_confidences)),
                    'std': float(np.std(all_confidences)),
                    'min': float(np.min(all_confidences)),
                    'max': float(np.max(all_confidences))
                },
                'confidence_levels': {
                    'very_high': very_high_conf,
                    'high': high_conf,
                    'medium': medium_conf,
                    'low': low_conf
                }
            }
        }, f, indent=2)

    print(f"✅ Detailed results saved to: {results_json_path}")

else:
    print("❌ No test results to display")

# =============================================================================
# CELL 14: Final Summary and Advanced Recommendations
# =============================================================================

def _print_banner(title, width=80):
    """Print *title* between two '=' rules, with a blank line before the first rule."""
    print("\n" + "=" * width)
    print(title)
    print("=" * width)


def _print_tick(label, value):
    """Print one '✅ label: value' status line."""
    print(f"✅ {label}: {value}")


def _print_numbered_sections(heading, sections):
    """Print *heading*, then each (title, bullets) pair as a numbered sub-list.

    Every sub-list after the first is preceded by a blank line.
    """
    print(heading)
    for number, (title, bullets) in enumerate(sections, start=1):
        lead = "" if number == 1 else "\n"
        print(f"{lead}{number}. {title}:")
        for bullet in bullets:
            print(f"   - {bullet}")


_print_banner("TRAINING AND PREDICTION COMPLETED SUCCESSFULLY - 50 EPOCHS VERSION!")

print("\nFiles created during this session:")
_print_tick("Final trained model", model_filename)
_print_tick("Best model checkpoint", os.path.join(MODEL_SAVE_PATH, 'best_model_50epochs.h5'))
_print_tick("Enhanced training history", history_filename)
_print_tick("Model summary", summary_filename)
_print_tick("Training/validation directories", f"{PROCESSED_TRAIN_DIR}, {PROCESSED_VAL_DIR}")

if test_results:
    _print_tick("Detailed predictions", results_json_path)
    _print_tick("Results visualization", results_plot_path)

print("\nModel Performance Summary:")
_print_tick("Final training accuracy", f"{final_train_acc:.4f}")
_print_tick("Final validation accuracy", f"{final_val_acc:.4f}")
_print_tick("Best validation accuracy", f"{best_val_acc:.4f} (epoch {best_val_acc_epoch})")
_print_tick("Final validation loss", f"{final_val_loss:.4f}")

if test_results:
    avg_confidence = np.mean([r['confidence'] for r in test_results])
    high_conf_pct = (very_high_conf + high_conf) / total_predictions * 100
    _print_tick("Test images processed", len(test_results))
    _print_tick("Average prediction confidence", f"{avg_confidence:.4f}")
    _print_tick("High confidence predictions", f"{high_conf_pct:.1f}%")

_print_banner("ADVANCED RECOMMENDATIONS AND NEXT STEPS")

print("\n🎯 Model Performance Analysis:")
if final_val_acc > 0.9:
    verdict = "✅ Excellent model performance achieved!"
elif final_val_acc > 0.8:
    verdict = "✅ Good model performance. Consider fine-tuning for improvement."
else:
    verdict = "⚠️ Model performance could be improved. See recommendations below."
print(verdict)

_print_numbered_sections("\n🔧 For Further Improvement:", [
    ("Fine-tuning approach", [
        "Unfreeze top layers of VGG16 for domain-specific learning",
        "Use lower learning rate (1e-5) for fine-tuning",
        "Implement gradual unfreezing strategy",
    ]),
    ("Data enhancement", [
        "Collect more diverse training data",
        "Implement advanced augmentation (mixup, cutmix)",
        "Consider synthetic data generation",
    ]),
    ("Architecture experiments", [
        "Try EfficientNet, ResNet, or DenseNet backbones",
        "Implement ensemble methods",
        "Experiment with attention mechanisms",
    ]),
    ("Advanced techniques", [
        "Implement class activation maps (CAM) for explainability",
        "Use test-time augmentation for robust predictions",
        "Apply uncertainty quantification methods",
    ]),
])

_print_numbered_sections("\n📊 Production Deployment:", [
    ("Model validation", [
        "Test on external datasets",
        "Conduct clinical validation studies",
        "Implement A/B testing framework",
    ]),
    ("Implementation considerations", [
        "Set appropriate confidence thresholds",
        "Implement human-in-the-loop workflows",
        "Create model monitoring dashboards",
    ]),
    ("Regulatory and ethical", [
        "Ensure HIPAA compliance for medical data",
        "Implement bias detection and mitigation",
        "Document model limitations and failure cases",
    ]),
])

print("\n🚀 Comparison with 10-epoch version:")
for expectation in (
    "Expected better generalization with 50-epoch training",
    "More stable predictions due to extended training",
    "Better feature learning from enhanced augmentation",
    "Reduced overfitting with improved regularization",
):
    print(f"- {expectation}")

print("\n")