# **AI-DRIVEN PRECISION AGRICULTURE FOR EARLY DISEASE DETECTION AND SUSTAINABLE CROP PROTECTION**

# **STAGE 2: MULTI-CLASS CLASSIFIER FOR THE 59 DISEASED CLASSES (HEALTHY CLASSES EXCLUDED)**

# **DATASET: PlantWild (Benchmarking In-the-Wild Multimodal Plant Disease Recognition and A Versatile Baseline)**
# **LINK TO PAPER: https://tqwei05.github.io/PlantWild**
# **LINK TO DATASET: https://huggingface.co/datasets/uqtwei2/PlantWild/tree/main**

# **IMPORTS**

In [None]:
# Import essential libraries for deep learning and data processing
import tensorflow as tf
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Dropout, BatchNormalization
from tensorflow.keras.optimizers import Adam, SGD, RMSprop
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.utils import to_categorical

# Import data manipulation and analysis libraries
import pandas as pd
import numpy as np
import os
import zipfile
import gdown
import json
import pickle
from collections import Counter

# Import machine learning evaluation libraries
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
from sklearn.metrics import classification_report, roc_auc_score, average_precision_score
from sklearn.metrics import cohen_kappa_score, matthews_corrcoef
from sklearn.utils.class_weight import compute_class_weight
from sklearn.model_selection import StratifiedKFold

# Import visualization libraries
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.patches import Rectangle
import cv2

# Import statistical analysis libraries
from scipy import stats
from scipy.stats import bootstrap
import warnings
warnings.filterwarnings('ignore')

# Import Google Drive integration
from google.colab import drive

# Pin the NumPy and TensorFlow random number generators to a fixed seed
np.random.seed(42)
tf.random.set_seed(42)

# Plot appearance: select the seaborn-v0_8 matplotlib style first, then let
# seaborn install the "husl" colour palette on top of it
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

print(
    "All libraries imported successfully!",
    "Ready for enhanced Stage 2 implementation...",
    sep="\n",
)

# **GOOGLE DRIVE MOUNT AND DATASET DOWNLOAD**

In [None]:
# Mount Google Drive at /content/drive so later cells can read and write files under it
from google.colab import drive  # Also imported in the imports cell; repeated so this cell can run on its own
drive.mount('/content/drive')  # Mount point that the DRIVE_* paths in the configuration cell are rooted at
print("Google Drive mounted successfully!")  # Only reached when drive.mount() did not raise

In [None]:
# Create working directory and download dataset
import os  # Import operating system interface
import zipfile  # Import zip file handling functionality
import gdown  # Import Google Drive download utility

# Working directory that receives both the downloaded archive and its extracted contents
os.makedirs("/content/plantwild", exist_ok=True)

# Google Drive file ID of the shared PlantWild archive and the local path it is saved to
file_id = "1TVvXiJIWvpOYUba78gm6ALuy52Ks6IwW"
zip_path = "/content/plantwild/plantwild.zip"

# Download the archive with gdown
print("Downloading PlantWild dataset...")
downloaded_path = gdown.download(f"https://drive.google.com/uc?id={file_id}", zip_path, quiet=False)

# gdown may report a failed download by returning None instead of raising.
# Stop here with a clear message rather than letting zipfile fail on a
# missing archive a few lines further down.
if downloaded_path is None or not os.path.exists(zip_path):
    raise RuntimeError(
        f"Download of PlantWild archive (Google Drive id {file_id}) failed; "
        f"nothing was written to {zip_path}"
    )

# Unpack everything next to the archive
print("Extracting dataset...")
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
    zip_ref.extractall("/content/plantwild")

print("PlantWild dataset downloaded and extracted to /content/plantwild")

# The archive is expected to unpack into a top-level "plantwild" folder.
# Every later cell reads from DATASET_ROOT, so a missing folder is a hard
# error rather than a message followed by a misleading "completed" line.
DATASET_ROOT = "/content/plantwild/plantwild"
if not os.path.isdir(DATASET_ROOT):
    raise FileNotFoundError(f"Dataset not found! Expected extracted folder at: {DATASET_ROOT}")

print(f"Dataset found at: {DATASET_ROOT}")
print(f"Contents: {os.listdir(DATASET_ROOT)}")

print("Dataset download and verification completed!")

# **CONFIGURATION AND SET-UP**

In [None]:
# ---------------------------------------------------------------------------
# Stage 2 configuration: constants read by the training and analysis cells
# ---------------------------------------------------------------------------

# Input geometry and class bookkeeping
IMG_HEIGHT = 224  # Input height used for MobileNetV2
IMG_WIDTH = 224  # Input width used for MobileNetV2
TOTAL_CLASSES = 89  # All classes in PlantWild
DISEASED_CLASSES = 59  # Classes Stage 2 is trained on
HEALTHY_CLASSES = TOTAL_CLASSES - DISEASED_CLASSES  # 30 healthy classes, left out of Stage 2

# Training schedule (phase 1 = head training, phase 2 = fine-tuning)
EPOCHS_PHASE1 = 15
EPOCHS_PHASE2 = 20
BATCH_SIZE = 16  # NOTE: the data-preparation cell later rebinds BATCH_SIZE to 32

# Per-optimizer learning rates; all three optimizers share the same schedule
LEARNING_RATES = {
    optimizer_name: {'phase1': 1e-3, 'phase2': 1e-5}
    for optimizer_name in ('AdamW', 'Lion', 'AdaBelief')
}

# Regularization settings
DROPOUT_RATE = 0.5
L2_REGULARIZATION = 1e-4
WEIGHT_DECAY = 1e-4

# Augmentation settings. NOTE: create_stage2_generators() defines its own
# local AUGMENTATION_PARAMS and does not read this module-level dictionary.
AUGMENTATION_PARAMS = dict(
    rotation_range=25,
    width_shift_range=0.2,
    height_shift_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    vertical_flip=False,
    brightness_range=[0.7, 1.3],
    fill_mode='nearest',
    cval=0.0,
)

# Output locations: one local directory plus several Google Drive folders
LOCAL_MODEL_DIR = "/content/models_stage2_enhanced"
DRIVE_BACKUP_DIR = "/content/drive/MyDrive/plantwild_stage2_enhanced"
DRIVE_MODELS_DIR = "/content/drive/MyDrive/Stage2_Enhanced_Models"
DRIVE_ANALYSIS_DIR = "/content/drive/MyDrive/Stage2_Enhanced_Analysis"
DRIVE_VISUALIZATIONS_DIR = "/content/drive/MyDrive/Stage2_Enhanced_Visualizations"
DRIVE_GRADCAM_DIR = "/content/drive/MyDrive/Stage2_Enhanced_GradCAM"

# Make sure every output directory exists before anything tries to write to it
for _output_dir in (
    LOCAL_MODEL_DIR,
    DRIVE_BACKUP_DIR,
    DRIVE_MODELS_DIR,
    DRIVE_ANALYSIS_DIR,
    DRIVE_VISUALIZATIONS_DIR,
    DRIVE_GRADCAM_DIR,
):
    os.makedirs(_output_dir, exist_ok=True)
del _output_dir

# One entry per ensemble member: display name, optimizer identifier and the
# optimizer-specific keyword it is configured with
ENSEMBLE_CONFIGS_STAGE2 = [
    dict(name='AdamW', optimizer_class='AdamW', weight_decay=1e-4),
    dict(name='Lion', optimizer_class='Lion', weight_decay=1e-4),
    dict(name='AdaBelief', optimizer_class='AdaBelief', eps=1e-16),
]

print("Enhanced Stage 2 configuration completed!")
print(f"Training on {DISEASED_CLASSES} diseased classes with advanced optimizers")
print(f"Models will be saved to: {DRIVE_MODELS_DIR}")

# **DATA LOADING FUNCTIONS**

In [None]:
# Enhanced data loading functions with precision agriculture focus
def load_class_mapping(classes_file):
    """
    Parse classes.txt into an id -> name mapping plus healthy/diseased name lists.

    Each usable line has the form "<integer id> <class name>"; the line is
    split on the first space only, so class names may contain spaces. Lines
    that do not split into exactly two parts are ignored.

    Returns:
        (class_mapping, healthy_classes, diseased_classes) where class_mapping
        is a dict {int id: name} and the two lists hold the class names for
        which is_healthy() returned True / False, in file order.
    """
    id_to_name = {}
    healthy_names, diseased_names = [], []

    with open(classes_file, 'r') as handle:
        for raw_line in handle:
            fields = raw_line.strip().split(' ', 1)
            if len(fields) != 2:
                continue  # Not an "<id> <name>" line

            raw_id, name = fields
            id_to_name[int(raw_id)] = name

            # Route the name into the matching bucket
            bucket = healthy_names if is_healthy(name) else diseased_names
            bucket.append(name)

    return id_to_name, healthy_names, diseased_names

def is_healthy(class_name):
    """
    Return True when a class name denotes a healthy plant.

    A name counts as healthy only if it ends with the exact suffix ' leaf'
    (case-sensitive) and none of the disease keywords occurs anywhere in its
    lower-cased form. Keywords are matched as plain substrings, so e.g. 'rot'
    also matches inside 'carrot'.
    """
    disease_keywords = (
        'rot', 'blight', 'virus', 'mosaic', 'rust', 'scab',
        'spot', 'canker', 'wilt', 'anthracnose', 'mildew',
        'curl', 'greening', 'smut', 'cavity', 'pocket', 'scorch',
        'disease', 'infection', 'lesion', 'necrosis',
    )

    # Anything that is not "<something> leaf" is treated as diseased
    if not class_name.endswith(' leaf'):
        return False

    lowered = class_name.lower()
    for keyword in disease_keywords:
        if keyword in lowered:
            return False
    return True

def load_stage2_dataframe(dataset_root):
    """
    Build the sample table from classes.txt and trainval.txt under dataset_root.

    trainval.txt is read as '='-separated text with three columns
    (image_path, class_id, split). The returned frame additionally carries
    class_name, binary_label ('healthy'/'diseased') and split_name, and only
    keeps rows whose image exists under <dataset_root>/images.

    Returns:
        (df, class_mapping, healthy_classes, diseased_classes)
    """
    images_dir = os.path.join(dataset_root, 'images')
    class_mapping, healthy_classes, diseased_classes = load_class_mapping(
        os.path.join(dataset_root, 'classes.txt')
    )

    df = pd.read_csv(
        os.path.join(dataset_root, 'trainval.txt'),
        sep='=',
        names=['image_path', 'class_id', 'split'],
    )

    # Derived columns: readable class name, healthy/diseased tag, split name
    df['class_name'] = df['class_id'].map(class_mapping)
    binary_names = {True: 'healthy', False: 'diseased'}
    df['binary_label'] = df['class_name'].apply(is_healthy).map(binary_names)
    split_names = {0: 'test', 1: 'train', 2: 'val'}  # Codes as stored in trainval.txt
    df['split_name'] = df['split'].map(split_names)

    # Keep only rows whose image file is actually present on disk
    on_disk = df['image_path'].apply(
        lambda rel_path: os.path.exists(os.path.join(images_dir, rel_path))
    )
    df = df[on_disk.astype(bool)].copy()

    return df, class_mapping, healthy_classes, diseased_classes

# Load the sample table and class lists from the extracted dataset
print("Loading enhanced Stage 2 dataset for precision agriculture analysis...")
df_stage2, class_mapping, healthy_classes, diseased_classes = load_stage2_dataframe(DATASET_ROOT)

# Report header
_banner = "=" * 60
print(f"\n{_banner}")
print("ENHANCED STAGE 2 DATASET ANALYSIS FOR PRECISION AGRICULTURE")
print(_banner)

# Headline numbers
_n_samples = len(df_stage2)
print(f"Total samples: {_n_samples:,}")
print(f"Total classes: {len(class_mapping)}")
print(f"Healthy classes: {len(healthy_classes)}")
print(f"Diseased classes: {len(diseased_classes)}")

# Samples per split, with their share of the whole table
print("\nDataset Split Distribution:")
split_counts = df_stage2['split_name'].value_counts()
for split, count in split_counts.items():
    share = count / _n_samples * 100
    print(f"  {split.capitalize()}: {count:,} samples ({share:.1f}%)")

# Samples per binary label, with their share of the whole table
print("\nBinary Classification Distribution:")
binary_counts = df_stage2['binary_label'].value_counts()
for label, count in binary_counts.items():
    share = count / _n_samples * 100
    print(f"  {label.capitalize()}: {count:,} samples ({share:.1f}%)")

# The ten classes with the most samples
print("\nTop 10 Most Frequent Classes:")
class_counts = df_stage2['class_name'].value_counts().head(10)
for i, (class_name, count) in enumerate(class_counts.items(), 1):
    print(f"  {i:2d}. {class_name}: {count:,} samples")

# Diseased subset summary
print("\nDiseased Classes Sample Counts:")
diseased_df = df_stage2[df_stage2['binary_label'] == 'diseased']
diseased_class_counts = diseased_df['class_name'].value_counts()
_n_diseased = len(diseased_df)
print(f"Total diseased samples: {_n_diseased:,}")
print(f"Average samples per diseased class: {_n_diseased/len(diseased_classes):.1f}")

# Raw split codes, printed so the code -> name mapping can be checked by eye
print("\nDEBUG: Actual split values in dataset:")
print(df_stage2['split'].value_counts().sort_index())

print("\nDataset loading and analysis completed successfully!")

# **DATA PREPARATION AND PREPROCESSING**

In [None]:
# **ENHANCED STAGE 2 DATA PREPARATION AND PREPROCESSING FOR PRECISION AGRICULTURE**

# Import required libraries for enhanced data preparation
import pandas as pd # Import pandas for data manipulation
import numpy as np # Import numpy for numerical operations
from sklearn.model_selection import train_test_split # Import train-test split functionality
from sklearn.preprocessing import LabelEncoder # Import label encoder for class encoding
import os # Import operating system interface
import json # Import JSON handling for saving class mappings
from tensorflow.keras.preprocessing.image import ImageDataGenerator # Import image data generator
import tensorflow as tf # Import TensorFlow

print("Preparing enhanced Stage 2 data for precision agriculture disease classification...")

# Dataset location and input parameters used by the rest of this cell
DATASET_ROOT = "/content/plantwild/plantwild"
IMAGES_DIR = os.path.join(DATASET_ROOT, "images")
IMG_HEIGHT, IMG_WIDTH = 224, 224
BATCH_SIZE = 32  # NOTE(review): replaces BATCH_SIZE = 16 from the configuration cell - confirm which is intended

# Fail with an exception instead of exit(): in a notebook exit() tears down
# the kernel rather than just stopping this cell.
if not os.path.isdir(DATASET_ROOT):
    raise FileNotFoundError(
        f"Dataset not found at: {DATASET_ROOT}. "
        "Please ensure you have downloaded and extracted the PlantWild dataset"
    )

print(f"Dataset found at: {DATASET_ROOT}")
print(f"Contents: {os.listdir(DATASET_ROOT)}")

if not os.path.isdir(IMAGES_DIR):
    raise FileNotFoundError(f"Images directory not found at: {IMAGES_DIR}")

print(f"Images directory found at: {IMAGES_DIR}")

# One sub-folder per class. os.listdir() returns entries in arbitrary order,
# so sort them: the seeded train/val/test split further down depends on row
# order and would otherwise differ between machines and runs.
class_folders = sorted(
    d for d in os.listdir(IMAGES_DIR) if os.path.isdir(os.path.join(IMAGES_DIR, d))
)
print(f"Found {len(class_folders)} class folders")

# Collect one record per image file, in (class, file name) order
dataset_data = []

for class_name in class_folders:
    class_path = os.path.join(IMAGES_DIR, class_name)
    image_files = sorted(
        f for f in os.listdir(class_path)
        if f.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp'))
    )

    for image_file in image_files:
        dataset_data.append({
            'image_path': os.path.join(class_name, image_file),  # Relative to IMAGES_DIR
            'class_name': class_name,  # Folder name doubles as the class label
            'full_path': os.path.join(class_path, image_file)  # Absolute path, kept for verification
        })

# Explicit columns keep the frame well-formed even when no image was found
df_stage2 = pd.DataFrame(dataset_data, columns=['image_path', 'class_name', 'full_path'])
if df_stage2.empty:
    raise RuntimeError(f"No image files found under: {IMAGES_DIR}")

print(f"Total images found: {len(df_stage2)}")
print(f"Classes found: {len(df_stage2['class_name'].unique())}")

# Show the first few classes with their image counts (single counting pass)
print("\nFirst 10 classes found:")
images_per_class = df_stage2['class_name'].value_counts()
for i, class_name in enumerate(df_stage2['class_name'].unique()[:10]):
    class_count = images_per_class[class_name]
    print(f"  {i+1}. {class_name}: {class_count} images")

# Integer-encode the class names across all classes (healthy and diseased)
label_encoder = LabelEncoder()
df_stage2['encoded_class_id'] = label_encoder.fit_transform(df_stage2['class_name'])

# Filter for diseased classes only (excluding healthy classes)
def is_healthy(class_name):
    """Return True when a class (folder) name denotes a healthy plant.

    Matching is done on the lower-cased name with plain substring tests.
    Any disease keyword makes the class diseased. Otherwise the class is
    healthy when it ends with 'leaf' or contains a healthy keyword.
    NOTE: this definition replaces the earlier is_healthy() from the data
    loading cell and uses different keyword lists.
    """
    healthy_keywords = ('healthy', 'normal', 'good')
    disease_keywords = (
        'rot', 'blight', 'mildew', 'mosaic', 'rust', 'scab', 'spot', 'virus',
        'disease', 'infected', 'anthracnose', 'canker', 'wilt', 'curl',
        'yellows', 'mottle', 'streak', 'necrosis', 'chlorosis', 'lesion',
    )

    lowered = class_name.lower()

    # A single disease keyword overrides every healthy indicator
    if any(keyword in lowered for keyword in disease_keywords):
        return False

    if lowered.endswith('leaf'):
        return True
    return any(keyword in lowered for keyword in healthy_keywords)

# Tag every image as healthy or diseased based on its class name
df_stage2['binary_label'] = df_stage2['class_name'].apply(lambda x: 'healthy' if is_healthy(x) else 'diseased')

# value_counts() on the per-image column counts images, not classes, so the
# lines below are labelled as sample counts.
binary_counts = df_stage2['binary_label'].value_counts()
print(f"\nBinary classification results:")
print(f"  Healthy samples: {binary_counts.get('healthy', 0)}")
print(f"  Diseased samples: {binary_counts.get('diseased', 0)}")

# Stage 2 only sees diseased images
df_diseased = df_stage2[df_stage2['binary_label'] == 'diseased'].copy()
if df_diseased.empty:
    raise ValueError("No diseased samples found; check the class folder names against is_healthy()")

# Re-encode so that diseased class ids are contiguous, starting at 0
df_diseased = df_diseased.reset_index(drop=True)
diseased_label_encoder = LabelEncoder()
df_diseased['encoded_class_id'] = diseased_label_encoder.fit_transform(df_diseased['class_name'])

# encoded id -> class name, in the encoder's order
diseased_class_mapping = {i: name for i, name in enumerate(diseased_label_encoder.classes_)}

print(f"\nStage 2 data preparation completed:")
print(f" Original diseased samples: {len(df_diseased):,}")
print(f" Diseased classes: {len(diseased_class_mapping)}")
print(f" Class ID range: 0 to {len(diseased_class_mapping) - 1}")

# Stratified 70/15/15 split: 30% is held out first, then halved into val/test
train_df, temp_df = train_test_split(df_diseased, test_size=0.3, random_state=42, stratify=df_diseased['encoded_class_id'])
val_df, test_df = train_test_split(temp_df, test_size=0.5, random_state=42, stratify=temp_df['encoded_class_id'])

# Split codes used from here on: 0=train, 1=val, 2=test. These differ from
# the codes stored in trainval.txt (0=test, 1=train, 2=val in
# load_stage2_dataframe). assign() returns new frames, so nothing is written
# into the objects handed back by train_test_split.
train_df = train_df.assign(split=0)
val_df = val_df.assign(split=1)
test_df = test_df.assign(split=2)

# Recombine into one table that carries the split as a column
df_diseased = pd.concat([train_df, val_df, test_df], ignore_index=True)

split_mapping = {0: 'train', 1: 'val', 2: 'test'}
df_diseased['split_name'] = df_diseased['split'].map(split_mapping)

# Debug output: sizes of the resulting splits
print("DEBUG: Checking split values in diseased dataframe:")
print(f" Total diseased samples: {len(df_diseased)}")
print(f" Split value counts: {df_diseased['split'].value_counts()}")
print(f" Split name counts: {df_diseased['split_name'].value_counts()}")

# Create absolute paths for images
def create_absolute_path(relative_path):
    """Return the path of an image under IMAGES_DIR, or None for a missing value.

    The input is converted to a string, stripped of surrounding whitespace
    and then of surrounding single/double quotes before being joined onto
    IMAGES_DIR.
    """
    if pd.isna(relative_path):
        return None

    cleaned = str(relative_path).strip().strip('"\'')
    return os.path.join(IMAGES_DIR, cleaned)

# Resolve every relative image path against IMAGES_DIR
df_diseased['image_path_absolute'] = df_diseased['image_path'].apply(create_absolute_path)

# Rows whose path could not be built (missing value) are discarded
has_path = df_diseased['image_path_absolute'].notna()
df_diseased = df_diseased[has_path]

# Record whether each file is present on disk and keep only the rows that are;
# the 'file_exists' column stays in the table
df_diseased['file_exists'] = df_diseased['image_path_absolute'].apply(
    lambda path: os.path.exists(path) if path else False
)
df_diseased = df_diseased[df_diseased['file_exists'].astype(bool)]

print(f"Valid image files found: {len(df_diseased)}")

# Create data generators for Stage 2 training
def create_stage2_generators(df_diseased, image_size, batch_size):
    """Create the train/validation/test Keras generators for Stage 2.

    Args:
        df_diseased: DataFrame with the columns 'split' (0=train, 1=val,
            2=test), 'encoded_class_id' (integer class id) and
            'image_path_absolute' (path of the image file).
        image_size: (height, width) passed to Keras as target_size.
        batch_size: Number of images per batch.

    Returns:
        (train_gen, val_gen, test_gen). Labels are one-hot vectors in which
        position i corresponds to encoded_class_id i for all three generators.

    Raises:
        ValueError: If df_diseased contains no rows.

    Note:
        The augmentation settings are defined locally here; the module-level
        AUGMENTATION_PARAMS dictionary is not used.
    """
    if df_diseased.empty:
        raise ValueError("df_diseased is empty; cannot create Stage 2 generators")

    # Keras needs string labels for class_mode='categorical'. When it infers
    # the class list itself it orders the strings alphanumerically
    # ('0', '1', '10', '11', ..., '2', ...), so one-hot index i would NOT be
    # encoded_class_id i. Passing the list explicitly pins index i to id i and
    # keeps the three generators on the same index space.
    n_classes = int(df_diseased['encoded_class_id'].max()) + 1
    class_labels = [str(class_id) for class_id in range(n_classes)]

    preprocess = tf.keras.applications.mobilenet_v2.preprocess_input

    # Training images are augmented; validation and test images are only preprocessed
    train_datagen = ImageDataGenerator(
        preprocessing_function=preprocess,
        rotation_range=20,  # Degrees
        width_shift_range=0.2,  # Fraction of total width
        height_shift_range=0.2,  # Fraction of total height
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True,
        fill_mode='nearest',
        validation_split=0.0,  # Splits come from the 'split' column instead
    )
    val_datagen = ImageDataGenerator(preprocessing_function=preprocess, validation_split=0.0)
    test_datagen = ImageDataGenerator(preprocessing_function=preprocess, validation_split=0.0)

    def _split_frame(split_code):
        """Return an independent copy of one split with a string label column."""
        frame = df_diseased[df_diseased['split'] == split_code].copy()
        frame['encoded_class_id_str'] = frame['encoded_class_id'].astype(str)
        return frame

    def _flow(datagen, frame, shuffle):
        """Wrap one split in a generator that uses the shared class order."""
        return datagen.flow_from_dataframe(
            dataframe=frame,
            x_col='image_path_absolute',
            y_col='encoded_class_id_str',
            target_size=image_size,
            batch_size=batch_size,
            class_mode='categorical',
            classes=class_labels,
            shuffle=shuffle,
            seed=42,
        )

    # Only the training generator shuffles; evaluation order stays fixed
    train_gen = _flow(train_datagen, _split_frame(0), True)
    val_gen = _flow(val_datagen, _split_frame(1), False)
    test_gen = _flow(test_datagen, _split_frame(2), False)

    return train_gen, val_gen, test_gen

# Build the three generators from the split table
train_generator, val_generator, test_generator = create_stage2_generators(
    df_diseased, (IMG_HEIGHT, IMG_WIDTH), BATCH_SIZE
)

# Report the number of images each generator holds. len(generator) is the
# number of batches, so multiplying it by BATCH_SIZE would round every count
# up to a full batch; .samples is the actual image count.
print(f"Stage 2 data generators created successfully:")
print(f" Training samples: {train_generator.samples}")
print(f" Validation samples: {val_generator.samples}")
print(f" Test samples: {test_generator.samples}")
print(f" Number of classes: {len(diseased_class_mapping)}")

# Make sure the Google Drive output folder exists
DRIVE_MODELS_DIR = '/content/drive/MyDrive/Stage2_Enhanced_Models'
os.makedirs(DRIVE_MODELS_DIR, exist_ok=True)

# Persist the encoded id -> class name mapping (JSON stores the integer keys as strings)
class_mapping_path = os.path.join(DRIVE_MODELS_DIR, 'stage2_class_mapping.json')
with open(class_mapping_path, 'w') as f:
    json.dump(diseased_class_mapping, f, indent=2)

print(f"Class mapping saved to: {class_mapping_path}")

# Persist the filtered, split-annotated sample table
final_dataset_path = os.path.join(DRIVE_MODELS_DIR, 'stage2_final_dataset.csv')
df_diseased.to_csv(final_dataset_path, index=False)

print(f"Final dataset saved to: {final_dataset_path}")

print("Stage 2 data preparation completed successfully!")

In [None]:
# **FINAL FIX FOR CLASS COUNT DISPLAY**

# Report class counts before and after the diseased-only filtering
print(f"\nFinal verification:")
print(f"  Total classes in dataset: {len(label_encoder.classes_)}")
print(f"  Diseased classes after filtering: {df_diseased['encoded_class_id'].nunique()}")

# Number of diseased classes actually present in the final table
num_classes = df_diseased['encoded_class_id'].nunique()
print(f"  Corrected number of classes: {num_classes}")

# Pair every encoded id with its class name, ordered by encoded id.
# df_diseased was shuffled by the train/val/test split, so taking class names
# in order of first appearance would attach names to the wrong ids.
_id_name_pairs = (
    df_diseased[['encoded_class_id', 'class_name']]
    .drop_duplicates()
    .sort_values('encoded_class_id')
)
_encoded_ids = [int(class_id) for class_id in _id_name_pairs['encoded_class_id']]
unique_class_names = _id_name_pairs['class_name'].to_numpy()  # Names in encoded-id order
print(f"  Unique diseased class names: {len(unique_class_names)}")

# Mapping keyed by the encoded id the model is trained on
class_mapping = {
    'num_classes': num_classes,  # Count of diseased classes present
    'class_indices': {str(class_id): class_id for class_id in _encoded_ids},  # "id" -> id
    'class_names': {str(class_id): name for class_id, name in zip(_encoded_ids, unique_class_names)},  # "id" -> name
    'label_encoder_classes': unique_class_names.tolist()  # Names only, in encoded-id order
}

# Overwrite the mapping file written by the data-preparation cell
with open(class_mapping_path, 'w') as f:
    json.dump(class_mapping, f, indent=4)

print(f"Updated class mapping saved with {num_classes} diseased classes")

# Show the first few id -> name pairs for a visual check
print(f"\nFirst 10 diseased classes:")
for i, class_name in list(zip(_encoded_ids, unique_class_names))[:10]:
    print(f"  Class {i}: {class_name}")

## **VISUALIZE STAGE-2 DATASET: TRAIN/VAL/TEST DISTRIBUTION, IMAGES PER CLASS AND CLASS BALANCE**

In [None]:
# **VISUALIZE STAGE 2 DATASET: DISEASED CLASSES ANALYSIS AND SAMPLE IMAGES**

# Import required libraries
import matplotlib.pyplot as plt # Import matplotlib for plotting
import numpy as np # Import numpy for numerical operations
import os # Import operating system interface
import json # Import JSON handling for saving analysis

# Switch matplotlib to its built-in 'default' style for the plots that follow;
# this replaces the 'seaborn-v0_8' style selected in the imports cell
plt.style.use('default') # Use default matplotlib style

def _shorten_label(label, limit):
    """Return ``label`` cut to ``limit`` characters, adding '...' only when it was actually cut."""
    text = str(label)
    return text[:limit] + '...' if len(text) > limit else text


def _read_rgb_image(img_path):
    """Load ``img_path`` as an RGB numpy array (PIL first, matplotlib as fallback)."""
    try:
        from PIL import Image  # Local import, as in the original cell
        with Image.open(img_path) as img:  # Context manager releases the file handle
            return np.array(img.convert('RGB'))
    except Exception:
        # Best-effort fallback; if this fails too, the caller marks the tile as an error
        return plt.imread(img_path)


def _plot_sample_grid(sample_images, save_path, num_cols=5):
    """Draw the sampled images in a ``num_cols``-wide grid and save it to ``save_path``.

    ``sample_images`` maps class name -> DataFrame of sampled rows (needs the
    'image_path_absolute' and 'split_name' columns). Returns True when a figure
    was drawn and saved, False when there was nothing to draw.
    """
    tiles = [
        (class_name, row)
        for class_name, samples in sample_images.items()
        for _, row in samples.iterrows()
    ]
    if not tiles:
        # plt.subplots() cannot build a grid with zero rows
        print("No sample images available to display.")
        return False

    num_rows = int(np.ceil(len(tiles) / num_cols))
    # squeeze=False always yields a 2-D axes array, so no special-casing of 1-row grids
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(20, 4 * num_rows), squeeze=False)
    axes_flat = axes.ravel()

    # Same colours as the bar chart: blue = train, purple = val, orange = anything else (test)
    border_by_split = {'train': '#2E86AB', 'val': '#A23B72'}

    for ax, (class_name, sample) in zip(axes_flat, tiles):
        img_path = sample.get('image_path_absolute')
        try:
            if not os.path.exists(img_path):
                ax.text(0.5, 0.5, 'File\nNot Found',
                        ha='center', va='center', fontsize=10, fontweight='bold', color='red')
                ax.axis('off')
                continue

            ax.imshow(_read_rgb_image(img_path))

            split_name = sample['split_name']
            ax.set_title(f"{_shorten_label(class_name, 15)}\n{split_name}",
                         fontsize=8, fontweight='bold')

            # Hide only the ticks: ax.axis('off') would also hide the spines and
            # with them the split-coloured border drawn below.
            ax.set_xticks([])
            ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_visible(True)
                spine.set_color(border_by_split.get(split_name, '#F18F01'))
                spine.set_linewidth(3)

        except Exception as exc:  # Keep the grid going if a single image cannot be shown
            print(f"  Could not display {img_path}: {exc}")
            ax.text(0.5, 0.5, 'Error\nLoading Image',
                    ha='center', va='center', fontsize=10, fontweight='bold', color='red')
            ax.axis('off')

    # Blank out the cells of the last row that received no image
    for ax in axes_flat[len(tiles):]:
        ax.axis('off')

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()
    return True


def visualize_stage2_dataset_overview(df_diseased, num_images_per_class=3, max_classes_to_show=20):
    """Visualize the Stage 2 (diseased-only) dataset and return summary statistics.

    Produces two figures, both written into the module-level ``DRIVE_MODELS_DIR``:
    a per-split bar chart of the most frequent classes plus a balanced/imbalanced
    pie chart ('stage2_class_distribution.png'), and a 5-column grid of randomly
    sampled images ('stage2_sample_images_per_class.png'). Also prints textual
    statistics. Relies on the notebook-level names ``plt``, ``np`` and ``os``.

    Args:
        df_diseased: DataFrame with 'class_name', 'split_name' and
            'image_path_absolute' columns.
        num_images_per_class: images sampled per class for the grid; classes with
            fewer images contribute all the images they have.
        max_classes_to_show: number of most frequent classes shown in the bar
            chart and in the image grid.

    Returns:
        dict of plain Python ints/floats (JSON-serializable): total_classes,
        total_samples, mean_samples_per_class, median_samples_per_class,
        balanced_classes, imbalanced_classes, class_balance_ratio. A class counts
        as "balanced" when it has at least the mean number of images per class.

    Raises:
        ValueError: if ``df_diseased`` is empty.
    """
    if df_diseased is None or len(df_diseased) == 0:
        raise ValueError("df_diseased is empty; nothing to visualize.")

    print("="*60)
    print(" STAGE 2 DATASET VISUALIZATION FOR PRECISION AGRICULTURE")
    print("="*60)

    # 1. Diseased Class Distribution Analysis
    print("\n DISEASED CLASS DISTRIBUTION ANALYSIS")
    print("-" * 50)

    # Per-split frames and per-class counts (the split column is 'split_name')
    train_rows = df_diseased[df_diseased['split_name'] == 'train']
    val_rows = df_diseased[df_diseased['split_name'] == 'val']
    test_rows = df_diseased[df_diseased['split_name'] == 'test']
    train_counts = train_rows['class_name'].value_counts()
    val_counts = val_rows['class_name'].value_counts()
    test_counts = test_rows['class_name'].value_counts()

    print(f"Total diseased classes: {len(df_diseased['class_name'].unique())}")
    print(f"Total diseased images: {len(df_diseased)}")
    print(f"Train images: {len(train_rows)}")
    print(f"Validation images: {len(val_rows)}")
    print(f"Test images: {len(test_rows)}")

    # 2. Diseased Class Balance Visualization
    fig, axes = plt.subplots(2, 1, figsize=(16, 12))

    # value_counts() sorts descending, so head() gives the most frequent classes
    all_counts = df_diseased['class_name'].value_counts()
    top_classes = all_counts.head(max_classes_to_show)

    # Grouped bars: one train/val/test triple per class
    x_pos = np.arange(len(top_classes))
    width = 0.25

    train_vals = [train_counts.get(cls, 0) for cls in top_classes.index]
    val_vals = [val_counts.get(cls, 0) for cls in top_classes.index]
    test_vals = [test_counts.get(cls, 0) for cls in top_classes.index]

    axes[0].bar(x_pos - width, train_vals, width, label='Train', color='#2E86AB', alpha=0.8)
    axes[0].bar(x_pos, val_vals, width, label='Validation', color='#A23B72', alpha=0.8)
    axes[0].bar(x_pos + width, test_vals, width, label='Test', color='#F18F01', alpha=0.8)

    axes[0].set_xlabel('Diseased Classes', fontsize=12, fontweight='bold')
    axes[0].set_ylabel('Number of Images', fontsize=12, fontweight='bold')
    # Report the number of classes actually plotted, which may be below the requested maximum
    axes[0].set_title(f'Diseased Class Distribution - Top {len(top_classes)} Classes',
                      fontsize=14, fontweight='bold')
    axes[0].set_xticks(x_pos)
    axes[0].set_xticklabels([_shorten_label(cls, 20) for cls in top_classes.index],
                            rotation=45, ha='right')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)

    # Value labels above every non-empty bar
    for i, (train, val, test) in enumerate(zip(train_vals, val_vals, test_vals)):
        if train > 0:
            axes[0].text(i - width, train + 5, str(train), ha='center', va='bottom', fontsize=8)
        if val > 0:
            axes[0].text(i, val + 5, str(val), ha='center', va='bottom', fontsize=8)
        if test > 0:
            axes[0].text(i + width, test + 5, str(test), ha='center', va='bottom', fontsize=8)

    # 3. Class Balance Analysis: classes at or above the mean count are "balanced"
    mean_samples = all_counts.mean()
    balanced_classes = int((all_counts >= mean_samples).sum())
    imbalanced_classes = int(len(all_counts) - balanced_classes)

    balance_data = [balanced_classes, imbalanced_classes]
    balance_labels = ['Balanced Classes', 'Imbalanced Classes']
    balance_colors = ['#4CAF50', '#F44336']  # Green for balanced, red for imbalanced

    axes[1].pie(balance_data, labels=balance_labels, colors=balance_colors, autopct='%1.1f%%', startangle=90)
    axes[1].set_title('Class Balance Analysis for Early Disease Detection', fontsize=14, fontweight='bold')

    distribution_path = os.path.join(DRIVE_MODELS_DIR, 'stage2_class_distribution.png')
    plt.tight_layout()
    plt.savefig(distribution_path, dpi=300, bbox_inches='tight')
    plt.show()

    # 4. Sample Images Per Diseased Class (5 tiles per row)
    print(f"\n SAMPLE IMAGES PER DISEASED CLASS (showing first {len(top_classes)} classes)")
    print("-" * 70)

    sample_images = {}
    for class_name in top_classes.index:
        class_df = df_diseased[df_diseased['class_name'] == class_name]
        # Small classes contribute what they have instead of being dropped from the grid
        n_take = min(num_images_per_class, len(class_df))
        if n_take > 0:
            sample_images[class_name] = class_df.sample(n=n_take)

    samples_path = os.path.join(DRIVE_MODELS_DIR, 'stage2_sample_images_per_class.png')
    samples_saved = _plot_sample_grid(sample_images, samples_path, num_cols=5)

    # 5. Summary Statistics for Precision Agriculture
    print(f"\n SUMMARY STATISTICS FOR PRECISION AGRICULTURE")
    print("-" * 50)
    print(f"Most common diseased class: {all_counts.index[0]} ({all_counts.iloc[0]} images)")
    print(f"Least common diseased class: {all_counts.index[-1]} ({all_counts.iloc[-1]} images)")
    print(f"Average images per diseased class: {all_counts.mean():.1f}")
    print(f"Standard deviation: {all_counts.std():.1f}")
    print(f"Diseased classes with < 10 images: {(all_counts < 10).sum()}")
    print(f"Diseased classes with > 100 images: {(all_counts > 100).sum()}")

    # Class balance statistics
    print(f"\nClass balance analysis:")
    print(f"  Balanced classes (≥{mean_samples:.1f} samples): {balanced_classes} ({balanced_classes/len(all_counts)*100:.1f}%)")
    print(f"  Imbalanced classes (<{mean_samples:.1f} samples): {imbalanced_classes} ({imbalanced_classes/len(all_counts)*100:.1f}%)")

    # Split distribution statistics
    print(f"\nSplit distribution:")
    split_counts = df_diseased['split_name'].value_counts()
    for split_name, count in split_counts.items():
        percentage = count / len(df_diseased) * 100
        print(f"  {split_name.capitalize()}: {count} images ({percentage:.1f}%)")

    print(f"\n Visualizations saved to Google Drive:")
    print(f"   Class distribution: {distribution_path}")
    if samples_saved:
        print(f"   Sample images: {samples_path}")
    else:
        print("   Sample images: not generated (no samples to show)")

    # Plain Python numbers so the caller can json.dump() the result directly
    return {
        'total_classes': int(len(all_counts)),
        'total_samples': int(len(df_diseased)),
        'mean_samples_per_class': float(all_counts.mean()),
        'median_samples_per_class': float(all_counts.median()),
        'balanced_classes': int(balanced_classes),
        'imbalanced_classes': int(imbalanced_classes),
        'class_balance_ratio': float(balanced_classes / len(all_counts))
    }

# **RUN THE STAGE 2 VISUALIZATION**
print("Creating enhanced Stage 2 dataset visualizations for precision agriculture analysis...")
dataset_analysis = visualize_stage2_dataset_overview(
    df_diseased,
    num_images_per_class=3,
    max_classes_to_show=20,
)

# Headline numbers from the returned analysis dict, framed by a banner
summary_lines = [
    "\n" + "="*60,
    "STAGE 2 DATASET ANALYSIS SUMMARY FOR PRECISION AGRICULTURE",
    "="*60,
    f"Total diseased classes: {dataset_analysis['total_classes']}",
    f"Total diseased samples: {dataset_analysis['total_samples']:,}",
    f"Average samples per class: {dataset_analysis['mean_samples_per_class']:.1f}",
    f"Class balance ratio: {dataset_analysis['class_balance_ratio']:.2%}",
    f"Classes requiring augmentation: {dataset_analysis['imbalanced_classes']}",
]
for summary_line in summary_lines:
    print(summary_line)

# Persist the analysis dict next to the model artefacts for later reference
analysis_path = os.path.join(DRIVE_MODELS_DIR, 'stage2_dataset_analysis.json')
with open(analysis_path, 'w') as f:
    json.dump(dataset_analysis, f, indent=4)

print(f"Dataset analysis saved to: {analysis_path}")
print("Stage 2 dataset visualization completed successfully!")

# **MODEL TRAINING**

In [None]:
# ===== Stage 2: Single-model MobileNetV2 training cell (with class order) =====
import os, json, numpy as np, pandas as pd, tensorflow as tf
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Dropout, BatchNormalization
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau, CSVLogger
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.regularizers import l2

# -----------------------------
# preconditions / hyper-parameters
# -----------------------------
# The prepared frame must already be in the notebook namespace with the
# columns the generators below read.
assert 'df_diseased' in globals(), "df_diseased not found. Run dataset-prep cell first."
need_cols = {'image_path_absolute', 'encoded_class_id', 'split_name'}
missing = need_cols.difference(df_diseased.columns)
assert not missing, f"df_diseased is missing columns: {missing}"

# Input resolution used for the generators and the model input
IMG_HEIGHT = 224
IMG_WIDTH = 224
BATCH_SIZE = 32

# Epoch budgets: phase 1 trains the head on a frozen base, phase 2 fine-tunes
EPOCHS_PHASE1 = 20
EPOCHS_PHASE2 = 15

# Callback settings (EarlyStopping patience, ReduceLROnPlateau factor / floor)
PATIENCE = 10
FACTOR = 0.5
MIN_LR = 1e-7

# Regularisation of the classification head
DROPOUT_RATE = 0.5
L2_REG = 1e-4

# Prefix for every saved artefact and the Drive folder they are written to
MODEL_TAG = "stage2_mnv2"
DRIVE_MODELS_DIR = '/content/drive/MyDrive/Stage2_Enhanced_Models'
os.makedirs(DRIVE_MODELS_DIR, exist_ok=True)

# -----------------------------
# Generators with IDENTICAL class order
# -----------------------------
# Keras' categorical mode needs string labels, so add a string copy of the id
df_diseased = df_diseased.copy()
df_diseased['encoded_class_id_str'] = df_diseased['encoded_class_id'].astype(str)

# "0", "1", ... passed to every generator so all splits share one label -> index map
num_classes = int(df_diseased['encoded_class_id'].nunique())
classes_fixed = list(map(str, range(num_classes)))

train_df, val_df, test_df = (
    df_diseased[df_diseased['split_name'] == split_label].copy()
    for split_label in ('train', 'val', 'test')
)

# Training data is augmented; validation/test only get MobileNetV2 preprocessing
_mnv2_preprocess = tf.keras.applications.mobilenet_v2.preprocess_input
train_idg = ImageDataGenerator(
    preprocessing_function=_mnv2_preprocess,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest',
)
plain_idg = ImageDataGenerator(preprocessing_function=_mnv2_preprocess)


def _flow_split(idg, frame, shuffle):
    """Build a categorical batch iterator over ``frame`` with the shared class order."""
    return idg.flow_from_dataframe(
        frame,
        x_col='image_path_absolute',
        y_col='encoded_class_id_str',
        target_size=(IMG_HEIGHT, IMG_WIDTH),
        batch_size=BATCH_SIZE,
        class_mode='categorical',
        shuffle=shuffle,
        seed=42,
        classes=classes_fixed,
    )


train_gen = _flow_split(train_idg, train_df, True)   # shuffled for training
val_gen = _flow_split(plain_idg, val_df, False)      # fixed order
test_gen = _flow_split(plain_idg, test_df, False)    # fixed order

print(f"Train samples: {train_gen.n}")
print(f"Val samples:   {val_gen.n}")
print(f"Test samples:  {test_gen.n}")
print(f"Num classes:   {num_classes}")
print("class_indices consistent across splits:",
      train_gen.class_indices == val_gen.class_indices == test_gen.class_indices)

# -----------------------------
# Class weights (inverse frequency on encoded_class_id)
# -----------------------------
# weight_i = total / (num_classes * count_i); an id absent from the training
# split is treated as having one image so the division stays defined.
counts = train_df['encoded_class_id'].value_counts().sort_index()
total = counts.sum()
class_weights = {}
for class_id in range(num_classes):
    images_in_class = counts.get(class_id, 1)
    class_weights[class_id] = float(total / (num_classes * images_in_class))

lightest_weight = min(class_weights.values())
heaviest_weight = max(class_weights.values())
print(f"class_weight range: {lightest_weight:.4f} -> {heaviest_weight:.4f}")

# -----------------------------
# Model (do NOT preprocess again inside the model; generators already do it)
# -----------------------------
def build_model(input_shape=(IMG_HEIGHT, IMG_WIDTH, 3), num_classes=num_classes,
                dropout=DROPOUT_RATE, l2_reg=L2_REG):
    """Build the MobileNetV2 classifier with a three-layer dense head.

    The ImageNet-initialised base is created frozen (phase 1) and is called
    with ``training=False``. The head is GlobalAveragePooling followed by
    Dense(1024) -> Dense(512) -> Dense(256), each with L2 regularisation,
    BatchNormalization and Dropout (the last block uses ``dropout / 2``),
    then a softmax layer with ``num_classes`` units.

    Returns:
        (model, base): the full Keras model and the MobileNetV2 base, so the
        caller can unfreeze the base later.
    """
    backbone = MobileNetV2(weights='imagenet', include_top=False, input_shape=input_shape)
    backbone.trainable = False  # Phase 1: only the head learns

    image_in = tf.keras.Input(shape=input_shape)
    features = backbone(image_in, training=False)
    features = GlobalAveragePooling2D()(features)

    # (units, dropout rate) for each dense block of the head
    head_blocks = ((1024, dropout), (512, dropout), (256, dropout / 2))
    for units, drop_rate in head_blocks:
        features = Dense(units, activation='relu', kernel_regularizer=l2(l2_reg))(features)
        features = BatchNormalization()(features)
        features = Dropout(drop_rate)(features)

    class_probs = Dense(num_classes, activation='softmax',
                        kernel_regularizer=l2(l2_reg))(features)
    return Model(image_in, class_probs), backbone

model, base_model = build_model()

# -----------------------------
# Callbacks
# -----------------------------
# Per-phase artefacts: best checkpoint and per-epoch CSV log
phase1_best = os.path.join(DRIVE_MODELS_DIR, f"{MODEL_TAG}_phase1_best.keras")
history_csv1 = os.path.join(DRIVE_MODELS_DIR, f"{MODEL_TAG}_phase1_history.csv")
phase2_best = os.path.join(DRIVE_MODELS_DIR, f"{MODEL_TAG}_phase2_best.keras")
history_csv2 = os.path.join(DRIVE_MODELS_DIR, f"{MODEL_TAG}_phase2_history.csv")


def make_callbacks(best_path, csv_path):
    """Return the callback list for one training phase.

    All monitoring is on ``val_loss``: early stopping (restoring the best
    weights), best-only checkpointing to ``best_path``, LR reduction after
    half of ``PATIENCE`` epochs without improvement, and a CSV log written
    to ``csv_path`` (overwritten, not appended).
    """
    early_stop = EarlyStopping(monitor='val_loss', patience=PATIENCE,
                               restore_best_weights=True, verbose=1)
    checkpoint = ModelCheckpoint(best_path, monitor='val_loss',
                                 save_best_only=True, verbose=1)
    lr_patience = max(1, PATIENCE // 2)
    lr_reducer = ReduceLROnPlateau(monitor='val_loss', factor=FACTOR,
                                   patience=lr_patience, min_lr=MIN_LR, verbose=1)
    csv_log = CSVLogger(csv_path, append=False)
    return [early_stop, checkpoint, lr_reducer, csv_log]

def _compile_model(learning_rate):
    """(Re)compile ``model`` with AdamW at ``learning_rate`` and fresh metric objects."""
    model.compile(
        optimizer=tf.keras.optimizers.AdamW(learning_rate=learning_rate, weight_decay=1e-4),
        loss='categorical_crossentropy',
        metrics=['accuracy',
                 tf.keras.metrics.Precision(name='precision'),
                 tf.keras.metrics.Recall(name='recall'),
                 tf.keras.metrics.AUC(name='auc')]
    )


def _fit_phase(epochs, best_path, csv_path):
    """Run one training phase on the shared generators and return its History."""
    return model.fit(
        train_gen,
        validation_data=val_gen,
        epochs=epochs,
        callbacks=make_callbacks(best_path, csv_path),
        class_weight=class_weights,
        verbose=1,
    )


# -----------------------------
# Phase 1: train head
# -----------------------------
_compile_model(1e-3)
print("\n=== Phase 1: training with frozen base ===")
hist1 = _fit_phase(EPOCHS_PHASE1, phase1_best, history_csv1)
phase1_final = os.path.join(DRIVE_MODELS_DIR, f"{MODEL_TAG}_phase1_final.keras")
model.save(phase1_final)
print("Saved:", phase1_final)

# -----------------------------
# Phase 2: fine-tune whole base (lower LR)
# -----------------------------
# Unfreeze before recompiling so the change in trainable weights takes effect.
# optionally: fine-tune last N layers only by setting .trainable for subsets
base_model.trainable = True
_compile_model(1e-4)
print("\n=== Phase 2: fine-tuning entire model ===")
hist2 = _fit_phase(EPOCHS_PHASE2, phase2_best, history_csv2)
phase2_final = os.path.join(DRIVE_MODELS_DIR, f"{MODEL_TAG}_phase2_final.keras")
model.save(phase2_final)
print("Saved:", phase2_final)

# -----------------------------
# Save class_indices + combined history for reproducibility
# -----------------------------
# class_indices is written first: the resume cell asserts that this file
# exists, so it must not be lost if serialising the history fails.
with open(os.path.join(DRIVE_MODELS_DIR, f"{MODEL_TAG}_class_indices.json"), "w") as f:
    json.dump(train_gen.class_indices, f, indent=2)


def _history_to_jsonable(history):
    """Convert a Keras ``History.history`` dict to plain Python floats.

    Depending on the Keras version, logged values (e.g. the learning rate added
    by ReduceLROnPlateau) can be numpy scalars, which ``json.dump`` rejects.
    """
    return {metric: [float(value) for value in values] for metric, values in history.items()}


combined_history = {
    'phase1': _history_to_jsonable(hist1.history),
    'phase2': _history_to_jsonable(hist2.history),
}
with open(os.path.join(DRIVE_MODELS_DIR, f"{MODEL_TAG}_history.json"), "w") as f:
    json.dump(combined_history, f, indent=2)

print("\nDone. Best checkpoints:")
print(" -", phase1_best)
print(" -", phase2_best)


In [None]:
# ===== Stage 2: SAFE RESUME (hardened single cell) =====
import os, json, math, numpy as np, pandas as pd, tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau, CSVLogger

# -----------------------------
# paths & knobs
# -----------------------------
MODEL_TAG = "stage2_mnv2"
DRIVE_MODELS_DIR = "/content/drive/MyDrive/Stage2_Enhanced_Models"

# Dataset CSV (written by your prep cell) and the label -> index map
# (saved by your training cell)
FINAL_DATASET_CSV = os.path.join(DRIVE_MODELS_DIR, "stage2_final_dataset.csv")
CLASS_INDICES_JSON = os.path.join(DRIVE_MODELS_DIR, "stage2_mnv2_class_indices.json")

# Checkpoints to resume from, in order of preference: the first one found on disk wins
RESUME_CANDIDATES = [
    f"{MODEL_TAG}_phase2_best.keras",
    f"{MODEL_TAG}_phase2_final.keras",
    f"{MODEL_TAG}_phase1_best.keras",
    f"{MODEL_TAG}_phase1_final.keras",
]

# Input pipeline
IMG_SIZE = (224, 224)
BATCH_SIZE = 32

# Optimisation: short warm restart at a small learning rate
EPOCHS_RESUME = 10   # adjust if you want longer
BASE_LR = 5e-5       # warm restart LR (small)
WEIGHT_DECAY = 1e-4

# Callback settings (EarlyStopping patience, ReduceLROnPlateau factor / floor)
PATIENCE = 6
FACTOR = 0.5
MIN_LR = 1e-7

# -----------------------------
# 0) bring df_diseased into memory (load if needed)
# -----------------------------
if "df_diseased" not in globals():
    assert os.path.exists(FINAL_DATASET_CSV), "Can't find df_diseased or stage2_final_dataset.csv."
    df_diseased = pd.read_csv(FINAL_DATASET_CSV)

# -----------------------------
# 1) recreate helper columns
# -----------------------------
# Always rebuild the string label column: if it was stored in the CSV,
# read_csv would hand it back as integers, while the categorical generators
# below are given string class names.
assert "encoded_class_id" in df_diseased.columns, "encoded_class_id missing in df_diseased."
df_diseased["encoded_class_id_str"] = df_diseased["encoded_class_id"].astype(str)

if "split_name" not in df_diseased.columns:
    assert "split" in df_diseased.columns, "split/split_name missing in df_diseased."
    # assumes 0/1/2 encode train/val/test - TODO confirm against the prep cell
    df_diseased["split_name"] = df_diseased["split"].map({0:"train",1:"val",2:"test"})
    unmapped_rows = int(df_diseased["split_name"].isna().sum())
    if unmapped_rows:
        # Rows with an unknown split code end up in none of the three splits
        print(f"WARNING: {unmapped_rows} rows have a 'split' value outside 0/1/2 and will be ignored.")

if "image_path_absolute" not in df_diseased.columns:
    # derive from dataset root used earlier
    IMAGES_DIR = "/content/plantwild/plantwild/images"
    df_diseased["image_path_absolute"] = df_diseased["image_path"].apply(lambda p: os.path.join(IMAGES_DIR, str(p)))

# basic sanity
need_cols = {"image_path_absolute","encoded_class_id","encoded_class_id_str","split_name"}
missing = need_cols - set(df_diseased.columns)
assert not missing, f"df_diseased missing required cols: {missing}"

# -----------------------------
# 2) load class order exactly as used during training
# -----------------------------
assert os.path.exists(CLASS_INDICES_JSON), f"Missing {CLASS_INDICES_JSON}. Run the training cell once to create it."
with open(CLASS_INDICES_JSON, "r") as f:
    saved_ci = json.load(f)  # {"0":0,"1":1,...}

# Invert label -> index into a list where position == index
classes_in_order = [None] * len(saved_ci)
for label in saved_ci:
    slot = int(saved_ci[label])
    classes_in_order[slot] = str(label)

# extra guard: every slot must have been filled (contiguous, complete indices)
assert None not in classes_in_order, "Class indices JSON is incomplete."

# -----------------------------
# 3) build generators with the *saved* class order
# -----------------------------
train_df, val_df, test_df = (
    df_diseased[df_diseased["split_name"] == split_label].copy()
    for split_label in ("train", "val", "test")
)

# Augmentation for training only; val/test get just the MobileNetV2 preprocessing
_resume_preprocess = tf.keras.applications.mobilenet_v2.preprocess_input
train_idg = ImageDataGenerator(
    preprocessing_function=_resume_preprocess,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode="nearest",
)
plain_idg = ImageDataGenerator(preprocessing_function=_resume_preprocess)


def _flow_resume_split(idg, frame, shuffle):
    """Build a categorical batch iterator over ``frame`` using the saved class order."""
    return idg.flow_from_dataframe(
        frame,
        x_col="image_path_absolute",
        y_col="encoded_class_id_str",
        target_size=IMG_SIZE,
        batch_size=BATCH_SIZE,
        class_mode="categorical",
        shuffle=shuffle,
        seed=42,
        classes=classes_in_order,
    )


train_gen = _flow_resume_split(train_idg, train_df, True)   # shuffled for training
val_gen = _flow_resume_split(plain_idg, val_df, False)      # fixed order
test_gen = _flow_resume_split(plain_idg, test_df, False)    # fixed order

print("Class order locked to saved mapping:",
      train_gen.class_indices == val_gen.class_indices == test_gen.class_indices)

num_classes = len(classes_in_order)

# -----------------------------
# 4) compute class weights (inverse freq on encoded_class_id)
# -----------------------------
# Keras looks class_weight up by the generator's class *index*, so the dict is
# keyed through train_gen.class_indices (label string -> index) rather than by
# the raw encoded id. With the identity mapping saved by the training cell this
# gives the same dict as before; it stays correct if the saved order differs.
counts = train_df["encoded_class_id"].value_counts().sort_index()
total  = counts.sum()
class_weights = {
    int(idx): float(total / (num_classes * counts.get(int(lbl), 1)))  # unseen class -> count of 1
    for lbl, idx in train_gen.class_indices.items()
}
print(f"class_weight range: {min(class_weights.values()):.4f} -> {max(class_weights.values()):.4f}")

# -----------------------------
# 5) locate a checkpoint to resume from
# -----------------------------
# Walk the candidates in preference order and keep the first file that exists
candidate_paths = (os.path.join(DRIVE_MODELS_DIR, ckpt_name) for ckpt_name in RESUME_CANDIDATES)
resume_path = next((path for path in candidate_paths if os.path.exists(path)), None)
assert resume_path is not None, f"No resume checkpoint found in {DRIVE_MODELS_DIR}."
print("Resuming from:", resume_path)

# -----------------------------
# 6) load model & set trainable policy (warm restart)
# -----------------------------
# compile=False: the optimizer state is not restored; a fresh optimizer is attached below
model = tf.keras.models.load_model(resume_path, compile=False)

# unfreeze entire network for fine-tuning (safe even if already unfrozen)
for net_layer in model.layers:
    net_layer.trainable = True

# recompile with small LR (warm restart)
resume_optimizer = tf.keras.optimizers.AdamW(learning_rate=BASE_LR, weight_decay=WEIGHT_DECAY)
resume_metrics = [
    "accuracy",
    tf.keras.metrics.Precision(name="precision"),
    tf.keras.metrics.Recall(name="recall"),
    tf.keras.metrics.AUC(name="auc"),
]
model.compile(
    optimizer=resume_optimizer,
    loss="categorical_crossentropy",
    metrics=resume_metrics,
)

# -----------------------------
# 7) callbacks
# -----------------------------
resume_tag = f"{MODEL_TAG}_resume"
best_ckpt = os.path.join(DRIVE_MODELS_DIR, f"{resume_tag}_best.keras")
history_csv = os.path.join(DRIVE_MODELS_DIR, f"{resume_tag}_history.csv")

# Everything monitors val_loss; the CSV log is appended to across resume runs
early_stop_cb = EarlyStopping(monitor="val_loss", patience=PATIENCE,
                              restore_best_weights=True, verbose=1)
checkpoint_cb = ModelCheckpoint(best_ckpt, monitor="val_loss",
                                save_best_only=True, verbose=1)
lr_plateau_patience = max(1, PATIENCE // 2)
lr_plateau_cb = ReduceLROnPlateau(monitor="val_loss", factor=FACTOR,
                                  patience=lr_plateau_patience, min_lr=MIN_LR, verbose=1)
csv_log_cb = CSVLogger(history_csv, append=True)
callbacks = [early_stop_cb, checkpoint_cb, lr_plateau_cb, csv_log_cb]

# -----------------------------
# 8) resume training
# -----------------------------
print("\n=== Resuming fine-tuning (warm restart) ===")
history = model.fit(
    train_gen,
    validation_data=val_gen,
    epochs=EPOCHS_RESUME,
    callbacks=callbacks,
    class_weight=class_weights,
    verbose=1,
)

# save a final snapshot too (the best-by-val_loss model lives in best_ckpt)
final_path = os.path.join(DRIVE_MODELS_DIR, f"{resume_tag}_final.keras")
model.save(final_path)
print("Saved final snapshot:", final_path)
print("Best checkpoint:", best_ckpt)


# **COMPREHENSIVE MODEL EVALUATION AND ANALYSIS**

In [None]:
# ===== Stage 2: comprehensive evaluation framework (paths & configuration) =====
import os, json, math, itertools, numpy as np, pandas as pd, tensorflow as tf
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.metrics import (confusion_matrix, classification_report, roc_auc_score,
                             average_precision_score, precision_recall_fscore_support,
                             log_loss, precision_recall_curve)
from sklearn.preprocessing import label_binarize
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Paths (feel free to tweak)
DRIVE_MODELS_DIR = '/content/drive/MyDrive/Stage2_Enhanced_Models'
DRIVE_ANALYSIS_DIR = '/content/drive/MyDrive/Stage2_Enhanced_Analysis'
DRIVE_GRADCAM_DIR = '/content/drive/MyDrive/Stage2_Enhanced_GradCAM'

# Output folders are created up front; the models folder is only read here
for out_dir in (DRIVE_ANALYSIS_DIR, DRIVE_GRADCAM_DIR):
    os.makedirs(out_dir, exist_ok=True)

# Best model to evaluate (from your logs)
BEST_MODEL_PATH = os.path.join(DRIVE_MODELS_DIR, 'stage2_mnv2_resume_best.keras')
SAVEDMODEL_DIR = os.path.join(DRIVE_MODELS_DIR, 'stage2_mnv2_savedmodel')
CLASS_INDEX_JSON = os.path.join(DRIVE_MODELS_DIR, 'stage2_mnv2_class_indices.json')

# Outputs
PREDICTIONS_CSV = os.path.join(DRIVE_ANALYSIS_DIR, 'stage2_test_predictions.csv')
CLASS_REPORT_CSV = os.path.join(DRIVE_ANALYSIS_DIR, 'stage2_classification_report.csv')
METRICS_JSON = os.path.join(DRIVE_ANALYSIS_DIR, 'stage2_metrics.json')

# Input pipeline settings
IMG_SIZE = (224, 224)
BATCH_SIZE = 32

print("Ready. Edit BEST_MODEL_PATH above if needed.")


In [None]:
# **Multi-Class Disease Classification Analysis for Precision Agriculture**

import os
import json
import pickle
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
from sklearn.metrics import (confusion_matrix, classification_report, roc_auc_score,
                             average_precision_score, precision_recall_fscore_support,
                             log_loss, precision_recall_curve, cohen_kappa_score)
from sklearn.preprocessing import label_binarize
from sklearn.model_selection import StratifiedKFold
import tensorflow as tf

print("="*60)
print(" STAGE 2 COMPREHENSIVE EVALUATION FRAMEWORK")
print(" Multi-Class Disease Classification Analysis")
print("="*60)

# Configuration: Google Drive folders for model artifacts and evaluation outputs
DRIVE_MODELS_DIR = '/content/drive/MyDrive/Stage2_Enhanced_Models'
DRIVE_ANALYSIS_DIR = '/content/drive/MyDrive/Stage2_Enhanced_Analysis'
DRIVE_VISUALIZATIONS_DIR = '/content/drive/MyDrive/Stage2_Enhanced_Visualizations'

# Create the output directories up front so later cells can write into them
os.makedirs(DRIVE_ANALYSIS_DIR, exist_ok=True)
os.makedirs(DRIVE_VISUALIZATIONS_DIR, exist_ok=True)

# Load the trained Stage 2 model that the following cells evaluate
BEST_MODEL_PATH = os.path.join(DRIVE_MODELS_DIR, 'stage2_mnv2_resume_best.keras')
if not os.path.exists(BEST_MODEL_PATH):
    print(f" Model not found: {BEST_MODEL_PATH}")
    print("Please ensure you have a trained Stage 2 model")
    # Raise instead of calling exit(): exit() is the interactive-shell helper and
    # ends the whole session, whereas an exception stops this cell with a clear
    # error and keeps the notebook state intact.
    raise FileNotFoundError(f"Stage 2 model not found: {BEST_MODEL_PATH}")

print(f" Loading model: {BEST_MODEL_PATH}")
model = tf.keras.models.load_model(BEST_MODEL_PATH)

print("Stage 2 evaluation framework ready!")

In [None]:
# **LOAD DATASET AND PREPARE FOR EVALUATION**

print("Loading Stage 2 dataset for evaluation...")

# Reuse df_diseased when an earlier cell already defined it; otherwise reload the
# saved CSV from the models folder.
if 'df_diseased' not in globals():
    dataset_path = os.path.join(DRIVE_MODELS_DIR, 'stage2_final_dataset.csv')
    if not os.path.exists(dataset_path):
        print(f" Dataset not found: {dataset_path}")
        # Raise instead of calling exit(): exit() is the interactive-shell helper
        # and ends the whole session, whereas an exception stops this cell with a
        # clear error and keeps the notebook state intact.
        raise FileNotFoundError(f"Stage 2 dataset not found: {dataset_path}")
    df_diseased = pd.read_csv(dataset_path)
    print(f" Dataset loaded: {len(df_diseased)} samples")

# Work on a copy so the evaluation-only column does not leak into df_diseased
df_eval = df_diseased.copy()
if 'encoded_class_id_str' not in df_eval.columns:
    # flow_from_dataframe with class_mode='categorical' is given string labels here
    df_eval['encoded_class_id_str'] = df_eval['encoded_class_id'].astype(str)

# NOTE(review): classes_fixed assumes the encoded ids are exactly 0..num_classes-1;
# confirm against the cell that builds 'encoded_class_id'.
num_classes = int(df_eval['encoded_class_id'].nunique())
classes_fixed = [str(i) for i in range(num_classes)]

# Map each encoded class id to its human-readable class name
idx2name = (df_eval[['encoded_class_id','class_name']]
            .drop_duplicates()
            .sort_values('encoded_class_id')
            .set_index('encoded_class_id')['class_name'].to_dict())

# Select the validation and test rows by their split label
val_df = df_eval[df_eval['split_name']=='val'].copy()
test_df = df_eval[df_eval['split_name']=='test'].copy()

print(f"Dataset prepared:")
print(f"  Total classes: {num_classes}")
print(f"  Validation samples: {len(val_df)}")
print(f"  Test samples: {len(test_df)}")
print(f"  Class range: 0 to {num_classes-1}")

# Evaluation generators: MobileNetV2 preprocessing only, no augmentation
plain_idg = ImageDataGenerator(
    preprocessing_function=tf.keras.applications.mobilenet_v2.preprocess_input
)

# shuffle=False keeps the generator order aligned with the dataframe rows, and
# classes=classes_fixed pins the label-to-index mapping for both generators.
val_gen = plain_idg.flow_from_dataframe(
    val_df, x_col='image_path_absolute', y_col='encoded_class_id_str',
    target_size=(224, 224), batch_size=32, class_mode='categorical',
    shuffle=False, seed=42, classes=classes_fixed
)

test_gen = plain_idg.flow_from_dataframe(
    test_df, x_col='image_path_absolute', y_col='encoded_class_id_str',
    target_size=(224, 224), batch_size=32, class_mode='categorical',
    shuffle=False, seed=42, classes=classes_fixed
)

print(" Data generators created successfully!")

In [None]:
# **GENERATE PREDICTIONS AND BASIC METRICS**

print("Generating predictions for evaluation...")

def predict_generator(gen):
    """Run the global ``model`` over ``gen`` and return labels and predictions.

    Returns a ``(y_true, y_pred, probs)`` tuple: the integer labels read from
    the generator's ``classes`` attribute (falling back to ``labels``), the
    arg-max class per row, and the raw output of ``model.predict``.
    """
    probs = model.predict(gen, verbose=1)

    # Take the first label attribute the generator actually provides
    raw_labels = None
    for attr_name in ('classes', 'labels'):
        raw_labels = getattr(gen, attr_name, None)
        if raw_labels is not None:
            break

    y_true = np.asarray(raw_labels, dtype=int)
    y_pred = np.argmax(probs, axis=1)

    # Predictions and labels must line up one-to-one
    assert probs.shape[0] == len(y_true), f"Mismatch: probs {probs.shape[0]} vs y_true {len(y_true)}"
    return y_true, y_pred, probs

# Run inference on the validation split, then on the test split
print("Validation predictions...")
y_val, yhat_val, p_val = predict_generator(val_gen)

print("Test predictions...")
y_test, yhat_test, p_test = predict_generator(test_gen)

# Top-1 accuracy: share of samples whose arg-max prediction equals the label
acc_val = float(np.mean(yhat_val == y_val))
acc_test = float(np.mean(yhat_test == y_test))

print(f"\nBasic Performance:")
print(f"  Validation Accuracy: {acc_val:.4f}")
print(f"  Test Accuracy: {acc_test:.4f}")

# Top-K accuracy for multi-class
def topk_acc(probs, y_true, k=5):
    """Return the fraction of samples whose true class is among the k top scores.

    Args:
        probs: Array-like of shape (n_samples, n_classes), one score per class.
        y_true: Array-like of n_samples integer class indices.
        k: Number of highest-scoring classes that count as a hit. Values above
            n_classes are clamped to n_classes.

    Returns:
        The top-k accuracy as a Python float; NaN when there are no samples.

    Raises:
        ValueError: If ``probs`` is not 2-D, the number of labels does not match
            the number of rows, or ``k`` is smaller than 1.
    """
    probs = np.asarray(probs)
    y_true = np.asarray(y_true).reshape(-1)

    if probs.ndim != 2:
        raise ValueError(f"probs must be 2-D (n_samples, n_classes), got shape {probs.shape}")
    n_samples, n_classes = probs.shape
    if y_true.shape[0] != n_samples:
        raise ValueError(f"Mismatch: probs has {n_samples} rows vs y_true {y_true.shape[0]}")
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")
    if n_samples == 0:
        return float('nan')

    # argpartition rejects kth >= n_classes, so cap k at the number of classes
    k = min(int(k), n_classes)
    if k == 0:
        return 0.0

    # Partitioning on every index below k puts the k largest scores in the first
    # k columns; their order is irrelevant for a membership test.
    topk = np.argpartition(-probs, kth=np.arange(k), axis=1)[:, :k]
    hits = (topk == y_true[:, None]).any(axis=1)
    return float(hits.mean())

# Top-3 and top-5 accuracy on the test split
acc_top3, acc_top5 = (topk_acc(p_test, y_test, k=top_k) for top_k in (3, 5))

print(f"  Top-3 Accuracy: {acc_top3:.4f}")
print(f"  Top-5 Accuracy: {acc_top5:.4f}")

print(" Predictions generated successfully!")

In [None]:
# **TRAINING HISTORY AND LOSS FUNCTION ANALYSIS + PLOTS**

def _load_training_history(history_files):
    """Return ``(history, source_path)`` for the first history file that loads.

    Files are tried in the given order: ``.json`` files are parsed with
    ``json.load`` and ``.csv`` files with ``pd.read_csv``. A file that fails to
    load is reported and skipped. Returns ``(None, None)`` if nothing loads.
    """
    for history_file in history_files:
        if not os.path.exists(history_file):
            continue
        try:
            if history_file.endswith('.json'):
                with open(history_file, 'r') as f:
                    history = json.load(f)
            elif history_file.endswith('.csv'):
                history = pd.read_csv(history_file)
            else:
                continue
        except Exception as e:
            # Best effort: report the problem and try the next candidate file
            print(f"✗ Error loading {history_file}: {e}")
            continue
        print(f"✓ Loaded training history from: {history_file}")
        return history, history_file
    return None, None


def _summarize_json_phase(phase, phase_label):
    """Print the final-epoch summary of one JSON phase and return its plot series.

    Returns ``None`` (and prints nothing) when the phase has no ``loss`` or
    ``val_loss`` entry.
    """
    if 'loss' not in phase or 'val_loss' not in phase:
        return None

    n_epochs = len(phase['loss'])
    final_loss = phase['loss'][-1]
    final_val_loss = phase['val_loss'][-1]

    print(f"  {phase_label}:")
    print(f"    Epochs: {n_epochs}")
    print(f"    Final Training Loss: {final_loss:.4f}")
    print(f"    Final Validation Loss: {final_val_loss:.4f}")
    print(f"    Overfitting Check: {'Yes' if final_loss < final_val_loss else 'No'}")

    return {
        'epochs': list(range(1, n_epochs + 1)),
        'loss': phase['loss'],
        'val_loss': phase['val_loss'],
        'accuracy': phase.get('accuracy', []),
        'val_accuracy': phase.get('val_accuracy', [])
    }


def _plot_train_val_curves(ax, series, key, quantity, title, log_scale=False):
    """Draw the training and validation curves of ``series[key]`` on ``ax``."""
    ax.plot(series['epochs'], series[key],
            'b-', linewidth=2, label=f'Training {quantity}', alpha=0.8)
    ax.plot(series['epochs'], series[f'val_{key}'],
            'r-', linewidth=2, label=f'Validation {quantity}', alpha=0.8)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.set_xlabel('Epochs', fontsize=12)
    ax.set_ylabel(quantity, fontsize=12)
    ax.legend()
    ax.grid(True, alpha=0.3)
    if log_scale:
        ax.set_yscale('log')


def _plot_loss_gap(ax, series, title):
    """Draw the per-epoch difference (training loss - validation loss) on ``ax``."""
    loss_diff = [t - v for t, v in zip(series['loss'], series['val_loss'])]
    ax.plot(series['epochs'], loss_diff, 'g-', linewidth=2, alpha=0.8)
    ax.axhline(y=0, color='red', linestyle='--', alpha=0.5, label='No Overfitting')
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.set_xlabel('Epochs', fontsize=12)
    ax.set_ylabel('Loss Difference', fontsize=12)
    ax.legend()
    ax.grid(True, alpha=0.3)


def _plot_phase_row(row_axes, series, phase_label, stage_label):
    """Fill one row of three axes (loss, accuracy, loss gap) for a training phase."""
    loss_ax, acc_ax, gap_ax = row_axes
    _plot_train_val_curves(loss_ax, series, 'loss', 'Loss',
                           f'{phase_label}: Loss Curves ({stage_label})', log_scale=True)
    if series['accuracy'] and series['val_accuracy']:
        _plot_train_val_curves(acc_ax, series, 'accuracy', 'Accuracy',
                               f'{phase_label}: Accuracy Curves ({stage_label})')
    if series['loss'] and series['val_loss']:
        _plot_loss_gap(gap_ax, series,
                       f'{phase_label}: Overfitting Analysis\n(Training - Validation Loss)')


def analyze_training_loss_and_history():
    """Analyze saved training history and the test-set loss, and plot both.

    Reads the module-level ``DRIVE_MODELS_DIR``, ``DRIVE_VISUALIZATIONS_DIR``,
    ``y_test``, ``p_test`` and ``num_classes``. The first loadable history file
    (combined JSON, then the phase 1 / phase 2 / resume CSVs) is summarized and
    plotted, the categorical cross-entropy of ``p_test`` against ``y_test`` is
    compared with a uniform-guess baseline, and three PNG figures are written to
    ``DRIVE_VISUALIZATIONS_DIR``.

    Returns:
        ``None`` when no training history could be loaded; otherwise a dict with
        the keys ``training_history``, ``history_source``, ``test_loss``,
        ``loss_quality``, ``random_baseline`` and ``plot_data``.
    """

    print("="*60)
    print(" TRAINING LOSS AND HISTORY ANALYSIS")
    print(" Categorical Cross-Entropy Loss Analysis with Visualizations")
    print("="*60)

    # Candidate history files, tried in this order
    history_files = [
        os.path.join(DRIVE_MODELS_DIR, 'stage2_mnv2_history.json'),
        os.path.join(DRIVE_MODELS_DIR, 'stage2_mnv2_phase1_history.csv'),
        os.path.join(DRIVE_MODELS_DIR, 'stage2_mnv2_phase2_history.csv'),
        os.path.join(DRIVE_MODELS_DIR, 'stage2_mnv2_resume_history.csv')
    ]

    training_history, history_source = _load_training_history(history_files)

    if training_history is None:
        # NOTE(review): the message announces model-only evaluation, but the
        # function returns here without running it. Kept as-is so callers that
        # check for None keep working.
        print(" No training history found. Proceeding with current model evaluation only.")
        return None

    # Analyze training history
    print(f"\nTRAINING HISTORY ANALYSIS:")
    print("-" * 40)

    # Series collected for plotting, keyed by 'phase1' / 'phase2' / 'single_phase'
    plot_data = {}

    if isinstance(training_history, dict):
        # JSON format (combined history)
        print("  Format: Combined JSON history")

        if 'phase1' in training_history and 'phase2' in training_history:
            print("  Phases: Phase 1 (Head training) + Phase 2 (Fine-tuning)")

            for phase_key, phase_label in (('phase1', 'Phase 1'), ('phase2', 'Phase 2')):
                series = _summarize_json_phase(training_history[phase_key], phase_label)
                if series is not None:
                    plot_data[phase_key] = series

    elif isinstance(training_history, pd.DataFrame):
        # CSV format
        print("  Format: CSV history")
        print(f"  Total rows: {len(training_history)}")

        columns = training_history.columns
        if 'loss' in columns and 'val_loss' in columns:
            final_loss = training_history['loss'].iloc[-1]
            final_val_loss = training_history['val_loss'].iloc[-1]

            print(f"  Final Training Loss: {final_loss:.4f}")
            print(f"  Final Validation Loss: {final_val_loss:.4f}")
            print(f"  Overfitting Check: {'Yes' if final_loss < final_val_loss else 'No'}")

            plot_data['single_phase'] = {
                'epochs': list(range(1, len(training_history) + 1)),
                'loss': training_history['loss'].tolist(),
                'val_loss': training_history['val_loss'].tolist(),
                'accuracy': training_history['accuracy'].tolist() if 'accuracy' in columns else [],
                'val_accuracy': training_history['val_accuracy'].tolist() if 'val_accuracy' in columns else []
            }

    # Current model loss evaluation
    print(f"\nCURRENT MODEL LOSS EVALUATION:")
    print("-" * 40)

    # Calculate categorical cross-entropy loss on test set
    from sklearn.metrics import log_loss

    # One-hot encode the integer labels so they match the probability matrix
    y_test_onehot = tf.keras.utils.to_categorical(y_test, num_classes=num_classes)

    test_loss = log_loss(y_test_onehot, p_test)

    print(f"  Test Set Categorical Cross-Entropy Loss: {test_loss:.4f}")

    # Loss interpretation (fixed thresholds)
    if test_loss < 0.5:
        loss_quality = "Excellent"
    elif test_loss < 1.0:
        loss_quality = "Good"
    elif test_loss < 2.0:
        loss_quality = "Fair"
    else:
        loss_quality = "Poor"

    print(f"  Loss Quality: {loss_quality}")

    # Baseline: cross-entropy of a uniform guess over num_classes classes
    random_loss = -np.log(1.0 / num_classes)
    improvement = ((random_loss - test_loss) / random_loss * 100)
    print(f"  Random Baseline Loss: {random_loss:.4f}")
    print(f"  Improvement over Random: {improvement:.1f}%")

    # ===== CREATE COMPREHENSIVE VISUALIZATIONS =====
    print(f"\nCreating comprehensive training analysis visualizations...")

    if plot_data:
        fig = None
        if 'phase1' in plot_data and 'phase2' in plot_data:
            # Two-phase training: one row of three panels per phase
            fig, axes = plt.subplots(2, 3, figsize=(20, 12))
            _plot_phase_row(axes[0], plot_data['phase1'], 'Phase 1', 'Head Training')
            _plot_phase_row(axes[1], plot_data['phase2'], 'Phase 2', 'Fine-tuning')

        elif 'single_phase' in plot_data:
            # Single phase training: 2x2 grid
            fig, axes = plt.subplots(2, 2, figsize=(16, 12))
            series = plot_data['single_phase']

            _plot_train_val_curves(axes[0, 0], series, 'loss', 'Loss',
                                   'Training and Validation Loss', log_scale=True)

            if series['accuracy'] and series['val_accuracy']:
                _plot_train_val_curves(axes[0, 1], series, 'accuracy', 'Accuracy',
                                       'Training and Validation Accuracy')

            if series['loss'] and series['val_loss']:
                _plot_loss_gap(axes[1, 0], series,
                               'Overfitting Analysis\n(Training - Validation Loss)')

            # Loss vs Accuracy Correlation
            if series['loss'] and series['accuracy']:
                axes[1, 1].scatter(series['loss'], series['accuracy'],
                                   alpha=0.7, color='purple', s=50)
                axes[1, 1].set_title('Training Loss vs Training Accuracy', fontsize=14, fontweight='bold')
                axes[1, 1].set_xlabel('Training Loss', fontsize=12)
                axes[1, 1].set_ylabel('Training Accuracy', fontsize=12)
                axes[1, 1].grid(True, alpha=0.3)

        if fig is not None:
            plt.tight_layout()

            # Save training analysis visualization
            training_viz_path = os.path.join(DRIVE_VISUALIZATIONS_DIR, 'stage2_training_analysis.png')
            plt.savefig(training_viz_path, dpi=300, bbox_inches='tight')
            plt.show()

            print(f"Training analysis visualizations saved to: {training_viz_path}")
        else:
            # Only one of the two phases had loss data, so no layout applies;
            # skip saving rather than writing an empty figure.
            print("  Incomplete phase history; skipping training curve plots.")

    # Create additional loss analysis plots
    print(f"\nCreating additional loss analysis plots...")

    # 1. Loss Quality Assessment
    fig, axes = plt.subplots(1, 2, figsize=(16, 6))

    # Loss Quality Bar Chart
    loss_categories = ['Random\nBaseline', 'Current\nModel', 'Perfect\nModel']
    loss_values = [random_loss, test_loss, 0.0]  # Perfect model has 0 loss
    colors = ['red', 'orange', 'green']

    bars = axes[0].bar(loss_categories, loss_values, color=colors, alpha=0.7)
    axes[0].set_title('Loss Quality Assessment', fontsize=14, fontweight='bold')
    axes[0].set_xlabel('Model Type', fontsize=12)
    axes[0].set_ylabel('Categorical Cross-Entropy Loss', fontsize=12)
    axes[0].grid(True, alpha=0.3)

    # Add value labels on bars
    for bar, value in zip(bars, loss_values):
        height = bar.get_height()
        axes[0].text(bar.get_x() + bar.get_width()/2., height + 0.1,
                    f'{value:.3f}', ha='center', va='bottom', fontweight='bold')

    # Improvement over Random. Axes.pie rejects negative wedge sizes, which
    # would occur when the model is worse than the random baseline, so the
    # wedge is clamped to [0, 100] while the label keeps the real value.
    pie_share = float(np.clip(improvement, 0.0, 100.0))
    axes[1].pie([pie_share, 100-pie_share], labels=[f'Improvement\n{improvement:.1f}%', 'Remaining\nGap'],
                colors=['lightgreen', 'lightcoral'], autopct='%1.1f%%', startangle=90)
    axes[1].set_title('Improvement over Random Baseline', fontsize=14, fontweight='bold')

    plt.tight_layout()

    # Save loss quality visualization
    loss_quality_path = os.path.join(DRIVE_VISUALIZATIONS_DIR, 'stage2_loss_quality_analysis.png')
    plt.savefig(loss_quality_path, dpi=300, bbox_inches='tight')
    plt.show()

    # 2. Loss Distribution Analysis
    fig, axes = plt.subplots(1, 2, figsize=(16, 6))

    # Confidence = highest predicted probability of each test sample
    confidence = p_test.max(axis=1)

    axes[0].hist(confidence, bins=30, color='skyblue', alpha=0.7, edgecolor='black')
    axes[0].axvline(confidence.mean(), color='red', linestyle='--',
                    label=f'Mean Confidence: {confidence.mean():.3f}')
    axes[0].set_title('Model Confidence Distribution', fontsize=14, fontweight='bold')
    axes[0].set_xlabel('Prediction Confidence', fontsize=12)
    axes[0].set_ylabel('Frequency', fontsize=12)
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)

    # Per-sample cross-entropy; the epsilon keeps log() finite for zero probabilities
    per_sample_loss = -np.sum(y_test_onehot * np.log(p_test + 1e-15), axis=1)

    axes[1].scatter(confidence, per_sample_loss, alpha=0.6, color='purple', s=30)
    axes[1].set_title('Loss vs Confidence Correlation', fontsize=14, fontweight='bold')
    axes[1].set_xlabel('Prediction Confidence', fontsize=12)
    axes[1].set_ylabel('Per-Sample Loss', fontsize=12)
    axes[1].grid(True, alpha=0.3)

    # Add a linear trend line and the Pearson correlation
    z = np.polyfit(confidence, per_sample_loss, 1)
    p = np.poly1d(z)
    correlation = np.corrcoef(confidence, per_sample_loss)[0, 1]
    axes[1].plot(confidence, p(confidence), "r--", alpha=0.8,
                 label=f"Trend (r={correlation:.3f})")
    axes[1].legend()

    plt.tight_layout()

    # Save loss distribution visualization
    loss_dist_path = os.path.join(DRIVE_VISUALIZATIONS_DIR, 'stage2_loss_distribution_analysis.png')
    plt.savefig(loss_dist_path, dpi=300, bbox_inches='tight')
    plt.show()

    print(f"Additional loss analysis plots saved to:")
    print(f"  - Loss quality analysis: {loss_quality_path}")
    print(f"  - Loss distribution analysis: {loss_dist_path}")

    return {
        'training_history': training_history,
        'history_source': history_source,
        'test_loss': test_loss,
        'loss_quality': loss_quality,
        'random_baseline': random_loss,
        'plot_data': plot_data
    }

# Run loss analysis with comprehensive plots
loss_analysis = analyze_training_loss_and_history()

In [None]:
# **BOOTSTRAP CONFIDENCE INTERVALS FOR MULTI-CLASS + PLOTS**

def bootstrap_confidence_intervals_multi_class(n_bootstrap=1000, confidence=0.95):
    """Bootstrap the test-set metrics and report percentile confidence intervals.

    Resamples the module-level ``y_test`` / ``yhat_test`` / ``p_test`` with
    replacement ``n_bootstrap`` times, scoring accuracy, macro precision /
    recall / F1 and top-3 / top-5 accuracy on every resample. Prints a summary
    per metric, calls ``create_stage2_bootstrap_visualizations`` and returns a
    dict mapping each metric name to its ``mean``, ``ci`` ([lower, upper]),
    ``std`` and the list of resampled ``values``.
    """
    banner = "="*60
    print(banner)
    print(" BOOTSTRAP CONFIDENCE INTERVALS - MULTI-CLASS")
    print(" Disease Classification Analysis for 59 Classes")
    print(banner)

    # Result key and printed heading for each metric, in reporting order
    metric_headings = (
        ('accuracy', "\nACCURACY ANALYSIS:"),
        ('f1_score', "\nF1-SCORE ANALYSIS (Macro):"),
        ('precision', "\nPRECISION ANALYSIS (Macro):"),
        ('recall', "\nRECALL ANALYSIS (Macro):"),
        ('top3_accuracy', "\nTOP-3 ACCURACY ANALYSIS:"),
        ('top5_accuracy', "\nTOP-5 ACCURACY ANALYSIS:"),
    )

    n_samples = len(y_test)
    resampled = {metric_name: [] for metric_name, _ in metric_headings}

    print(f"Performing bootstrap analysis on {n_samples} test samples...")
    print(f"Number of disease classes: {num_classes}")

    for iteration in range(1, n_bootstrap + 1):
        if iteration % 200 == 0:
            print(f"  Progress: {iteration}/{n_bootstrap} iterations")

        # Draw one resample of the test set (same size, with replacement)
        picked = np.random.choice(n_samples, n_samples, replace=True)
        true_labels = y_test[picked]
        pred_labels = yhat_test[picked]
        pred_probs = p_test[picked]

        resampled['accuracy'].append((pred_labels == true_labels).mean())

        # Macro-averaged precision / recall / F1 on this resample
        macro_precision, macro_recall, macro_f1, _ = precision_recall_fscore_support(
            true_labels, pred_labels, average='macro', zero_division=0
        )
        resampled['precision'].append(macro_precision)
        resampled['recall'].append(macro_recall)
        resampled['f1_score'].append(macro_f1)

        # Top-K accuracy on this resample
        resampled['top3_accuracy'].append(topk_acc(pred_probs, true_labels, k=3))
        resampled['top5_accuracy'].append(topk_acc(pred_probs, true_labels, k=5))

    # Percentile interval bounds for the requested confidence level
    alpha = 1 - confidence
    lower_pct = alpha/2 * 100
    upper_pct = (1-alpha/2) * 100

    bootstrap_results = {}
    for metric_name, _ in metric_headings:
        values = resampled[metric_name]
        bootstrap_results[metric_name] = {
            'mean': np.mean(values),
            'ci': [np.percentile(values, lower_pct), np.percentile(values, upper_pct)],
            'std': np.std(values),
            'values': values,
        }

    # Print one summary section per metric
    for metric_name, heading in metric_headings:
        summary = bootstrap_results[metric_name]
        ci_lower, ci_upper = summary['ci']
        print(heading)
        print(f"  Mean: {summary['mean']:.4f}")
        print(f"  {confidence*100}% CI: [{ci_lower:.4f}, {ci_upper:.4f}]")
        print(f"  Width: {ci_upper - ci_lower:.4f}")

    print(f"\nBootstrap analysis completed with {n_bootstrap} iterations")
    print(f"Confidence level: {confidence*100}%")

    # Create comprehensive visualizations
    create_stage2_bootstrap_visualizations(bootstrap_results, confidence, n_bootstrap, num_classes)

    return bootstrap_results

def _annotate_bar_values(ax, bars, values, offset, fmt, **text_kwargs):
    """Write each value, bold and centred, just above its bar.

    Args:
        ax: Matplotlib axes that owns ``bars``.
        bars: Bar container returned by ``ax.bar``.
        values: One number per bar; rendered with ``fmt``.
        offset: Vertical gap (in data units) between bar top and label.
        fmt: Format spec for the label, e.g. ``'.3f'``.
        **text_kwargs: Extra keyword arguments forwarded to ``ax.text``.
    """
    for bar, value in zip(bars, values):
        ax.text(bar.get_x() + bar.get_width() / 2., bar.get_height() + offset,
                format(value, fmt), ha='center', va='bottom',
                fontweight='bold', **text_kwargs)


def _plot_bootstrap_distribution(ax, metric_stats, label, color):
    """Draw a density histogram of one metric with mean and CI marker lines.

    Args:
        ax: Matplotlib axes to draw on.
        metric_stats: Mapping with ``values``, ``mean`` and ``ci``
            (``[lower, upper]``) for a single metric.
        label: Human-readable metric name used for the x-label and title.
        color: Face colour of the histogram bars.
    """
    mean = metric_stats['mean']
    ci_lower, ci_upper = metric_stats['ci']

    ax.hist(metric_stats['values'], bins=30, alpha=0.7, color=color,
            density=True, edgecolor='black')
    ax.axvline(mean, color='red', linestyle='--', linewidth=2,
               label=f'Mean: {mean:.3f}')
    ax.axvline(ci_lower, color='orange', linestyle=':', linewidth=2,
               label=f'CI Lower: {ci_lower:.3f}')
    ax.axvline(ci_upper, color='orange', linestyle=':', linewidth=2,
               label=f'CI Upper: {ci_upper:.3f}')

    ax.set_xlabel(label)
    ax.set_ylabel('Density')
    ax.set_title(f'Bootstrap Distribution of {label}')
    ax.legend()
    ax.grid(True, alpha=0.3)


def create_stage2_bootstrap_visualizations(bootstrap_results, confidence, n_bootstrap, num_classes):
    """Render, save and show the 3x3 bootstrap summary figure for Stage 2.

    Args:
        bootstrap_results: Mapping with the keys ``accuracy``, ``f1_score``,
            ``precision``, ``recall``, ``top3_accuracy`` and ``top5_accuracy``.
            Each entry provides ``mean``, ``ci`` (``[lower, upper]``), ``std``
            and ``values`` (the sequence that is histogrammed / scattered).
        confidence: Confidence level as a fraction; displayed as a percentage.
        n_bootstrap: Number of bootstrap iterations (display only).
        num_classes: Number of disease classes (display only).

    Side effects:
        Switches the global Matplotlib style to ``seaborn-v0_8``, writes a PNG
        to Google Drive (creating the target directory if needed) and calls
        ``plt.show()``.
    """

    print("\n" + "="*60)
    print(" CREATING STAGE 2 BOOTSTRAP ANALYSIS VISUALIZATIONS")
    print(" Multi-Class Disease Classification Analysis")
    print("="*60)

    # Set up the plotting style (global: later figures inherit it as well)
    plt.style.use('seaborn-v0_8')

    # Create a comprehensive figure
    fig = plt.figure(figsize=(20, 16))
    fig.suptitle(f'Stage 2: Multi-Class Disease Classification Bootstrap Analysis\n'
                 f'{n_bootstrap} iterations, {confidence*100}% confidence level, {num_classes} disease classes',
                 fontsize=18, fontweight='bold')

    # One colour per metric, shared by every per-metric bar chart
    colors = ['skyblue', 'lightcoral', 'lightgreen', 'gold', 'plum', 'lightsteelblue']
    metrics = ['accuracy', 'f1_score', 'precision', 'recall', 'top3_accuracy', 'top5_accuracy']
    metric_labels = ['Accuracy', 'F1-Score', 'Precision', 'Recall', 'Top-3 Acc', 'Top-5 Acc']
    metric_colors = colors[:len(metrics)]

    means = np.array([bootstrap_results[metric]['mean'] for metric in metrics])
    ci_lowers = np.array([bootstrap_results[metric]['ci'][0] for metric in metrics])
    ci_uppers = np.array([bootstrap_results[metric]['ci'][1] for metric in metrics])
    stds = [bootstrap_results[metric]['std'] for metric in metrics]

    # Plot 1: Performance Metrics with Confidence Intervals
    ax1 = plt.subplot(3, 3, 1)
    x = np.arange(len(metrics))
    bars = ax1.bar(x, means, color=metric_colors, alpha=0.8)

    # errorbar() rejects negative extents, so clamp in case a mean lies
    # marginally outside its own interval.
    ci_extents = np.clip([means - ci_lowers, ci_uppers - means], 0, None)
    ax1.errorbar(x, means, yerr=ci_extents,
                 fmt='none', color='black', capsize=5, capthick=1)

    ax1.set_xlabel('Metrics')
    ax1.set_ylabel('Score')
    ax1.set_title('Multi-Class Performance Metrics with Confidence Intervals')
    ax1.set_xticks(x)
    ax1.set_xticklabels(metric_labels, rotation=45)
    ax1.grid(True, alpha=0.3)
    ax1.set_ylim(0, 1)
    _annotate_bar_values(ax1, bars, means, 0.01, '.3f')

    # Plot 2: Bootstrap Distribution of Accuracy
    ax2 = plt.subplot(3, 3, 2)
    _plot_bootstrap_distribution(ax2, bootstrap_results['accuracy'], 'Accuracy', 'skyblue')

    # Plot 3: Bootstrap Distribution of F1-Score
    ax3 = plt.subplot(3, 3, 3)
    _plot_bootstrap_distribution(ax3, bootstrap_results['f1_score'], 'F1-Score', 'lightcoral')

    # Plot 4: Top-K Accuracy Comparison (error bars are +/- one bootstrap std)
    ax4 = plt.subplot(3, 3, 4)
    topk_metrics = ['accuracy', 'top3_accuracy', 'top5_accuracy']
    topk_labels = ['Top-1', 'Top-3', 'Top-5']
    topk_means = [bootstrap_results[metric]['mean'] for metric in topk_metrics]
    topk_stds = [bootstrap_results[metric]['std'] for metric in topk_metrics]

    bars = ax4.bar(topk_labels, topk_means, color=['skyblue', 'lightgreen', 'plum'], alpha=0.8)
    ax4.errorbar(topk_labels, topk_means, yerr=topk_stds, fmt='none', color='black', capsize=5, capthick=1)

    ax4.set_ylabel('Accuracy')
    ax4.set_title('Top-K Accuracy Comparison')
    ax4.grid(True, alpha=0.3)
    ax4.set_ylim(0, 1)
    _annotate_bar_values(ax4, bars, topk_means, 0.01, '.3f')

    # Plot 5: Precision vs Recall Analysis (one point per bootstrap value pair)
    ax5 = plt.subplot(3, 3, 5)
    precision_stats = bootstrap_results['precision']
    recall_stats = bootstrap_results['recall']

    ax5.scatter(precision_stats['values'], recall_stats['values'], alpha=0.6, color='lightgreen', s=20)
    ax5.axhline(recall_stats['mean'], color='red', linestyle='--', alpha=0.7,
                label=f'Recall Mean: {recall_stats["mean"]:.3f}')
    ax5.axvline(precision_stats['mean'], color='blue', linestyle='--', alpha=0.7,
                label=f'Precision Mean: {precision_stats["mean"]:.3f}')

    ax5.set_xlabel('Precision')
    ax5.set_ylabel('Recall')
    ax5.set_title('Precision vs Recall Distribution')
    ax5.legend()
    ax5.grid(True, alpha=0.3)

    # Plot 6: Confidence Interval Widths
    ax6 = plt.subplot(3, 3, 6)
    ci_widths = ci_uppers - ci_lowers

    bars = ax6.bar(metric_labels, ci_widths, color=metric_colors, alpha=0.8)
    ax6.set_ylabel('CI Width (Lower is Better)')
    ax6.set_title('Confidence Interval Widths')
    # The categorical axis already carries the labels; only rotate them.
    # Calling set_xticklabels() without fixed tick positions is discouraged.
    ax6.tick_params(axis='x', labelrotation=45)
    ax6.grid(True, alpha=0.3)
    _annotate_bar_values(ax6, bars, ci_widths, 0.001, '.4f', fontsize=8)

    # Plot 7: Bootstrap Stability Analysis
    ax7 = plt.subplot(3, 3, 7)
    bars = ax7.bar(metric_labels, stds, color=metric_colors, alpha=0.8)
    ax7.set_ylabel('Standard Deviation (Lower is Better)')
    ax7.set_title('Bootstrap Stability Analysis')
    ax7.tick_params(axis='x', labelrotation=45)
    ax7.grid(True, alpha=0.3)
    _annotate_bar_values(ax7, bars, stds, 0.001, '.4f', fontsize=8)

    # Plot 8: Performance Heatmap
    ax8 = plt.subplot(3, 3, 8)
    performance_data = np.array([
        [bootstrap_results['accuracy']['mean'], bootstrap_results['f1_score']['mean']],
        [bootstrap_results['precision']['mean'], bootstrap_results['recall']['mean']],
        [bootstrap_results['top3_accuracy']['mean'], bootstrap_results['top5_accuracy']['mean']]
    ])

    # Pin the colour scale to the metric range [0, 1]; without it imshow
    # stretches the colormap to the data, so the lowest cell is always red.
    im = ax8.imshow(performance_data, cmap='RdYlGn', aspect='auto', vmin=0, vmax=1)
    ax8.set_xticks([0, 1])
    ax8.set_yticks([0, 1, 2])
    ax8.set_xticklabels(['Metric 1', 'Metric 2'])
    ax8.set_yticklabels(['Accuracy/F1', 'Precision/Recall', 'Top-3/Top-5'])
    ax8.set_title('Performance Heatmap')

    # Add text annotations
    n_rows, n_cols = performance_data.shape
    for i in range(n_rows):
        for j in range(n_cols):
            ax8.text(j, i, f'{performance_data[i, j]:.3f}',
                     ha="center", va="center", color="black", fontweight='bold')

    plt.colorbar(im, ax=ax8, shrink=0.8)

    # Plot 9: Statistical Summary (text only)
    ax9 = plt.subplot(3, 3, 9)
    ax9.axis('off')

    accuracy_mean = bootstrap_results['accuracy']['mean']
    top3_mean = bootstrap_results['top3_accuracy']['mean']

    if accuracy_mean > 0.8:
        accuracy_verdict = "✓ High accuracy achieved\n"
    elif accuracy_mean > 0.6:
        accuracy_verdict = "~ Moderate accuracy\n"
    else:
        accuracy_verdict = "⚠ Low accuracy - needs improvement\n"

    summary_parts = [
        "STAGE 2 BOOTSTRAP SUMMARY:\n\n",
        f"Iterations: {n_bootstrap}\n",
        f"Confidence Level: {confidence*100}%\n",
        f"Disease Classes: {num_classes}\n\n",
        "KEY METRICS:\n",
        f"Accuracy: {accuracy_mean:.4f}\n",
        f"F1-Score: {bootstrap_results['f1_score']['mean']:.4f}\n",
        f"Top-3 Acc: {top3_mean:.4f}\n",
        f"Top-5 Acc: {bootstrap_results['top5_accuracy']['mean']:.4f}\n\n",
        "PERFORMANCE ASSESSMENT:\n",
        accuracy_verdict,
    ]
    # Only mention top-k when it beats top-1 by more than 10 points.
    if top3_mean > accuracy_mean + 0.1:
        summary_parts.append("✓ Top-K accuracy shows improvement\n")
    summary_parts.append("\nAnalysis completed successfully!")
    summary_text = "".join(summary_parts)

    ax9.text(0.05, 0.95, summary_text, transform=ax9.transAxes, fontsize=9,
             verticalalignment='top', fontfamily='monospace',
             bbox=dict(boxstyle="round,pad=0.3", facecolor="lightgray"))

    plt.tight_layout()

    # Save the comprehensive visualization
    # NOTE(review): this Stage 2 figure is written into the "stage1" models
    # folder, while other cells save under DRIVE_VISUALIZATIONS_DIR - confirm
    # that this location is intended.
    save_path = '/content/drive/MyDrive/plantwild_stage1_models/stage2_bootstrap_analysis_comprehensive.png'
    os.makedirs(os.path.dirname(save_path), exist_ok=True)  # savefig does not create folders
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"✓ Comprehensive Stage 2 bootstrap visualization saved to: {save_path}")

    plt.show()

# Run the Stage 2 bootstrap analysis when this cell executes. The call prints
# the per-metric report, renders the bootstrap figure, and its returned results
# dict is kept in `bootstrap_results`.
bootstrap_results = bootstrap_confidence_intervals_multi_class()

In [None]:
# **PRECISION, RECALL, F1-SCORE COMPREHENSIVE ANALYSIS + PLOTS (FIXED)**

def _metrics_json_default(obj):
    """``json.dump`` fallback that keeps NumPy numbers numeric.

    NumPy scalars become the matching Python scalar and arrays become lists;
    anything else is stringified as a last resort.
    """
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    return str(obj)


def comprehensive_precision_recall_f1_analysis():
    """Report precision/recall/F1 for the Stage 2 test predictions.

    Reads the notebook globals ``y_test``, ``yhat_test``, ``num_classes``,
    ``idx2name``, ``acc_test``, ``acc_top3``, ``acc_top5`` and
    ``DRIVE_ANALYSIS_DIR``. Prints overall, macro/micro/weighted and per-class
    metrics plus a class-support summary, renders the figure via
    ``create_precision_recall_f1_visualizations`` and writes the summary to
    ``stage2_precision_recall_f1_analysis.json``.

    Returns:
        dict: Summary with the keys ``overall``, ``macro_averaged``,
        ``micro_averaged``, ``weighted_averaged``, ``per_class`` (list of
        per-class dicts sorted by F1-score, best first) and
        ``class_imbalance``.
    """

    print("="*60)
    print(" PRECISION, RECALL, F1-SCORE COMPREHENSIVE ANALYSIS")
    print(" Multi-Class Disease Classification Performance Analysis")
    print("="*60)

    # Per-class metrics. `labels` pins the result to one entry per class id in
    # 0..num_classes-1; without it scikit-learn only returns entries for labels
    # that occur in the data, and indexing by class id below would be
    # misaligned (or raise IndexError) whenever a class is missing.
    class_labels = np.arange(num_classes)
    per_class_precision, per_class_recall, per_class_f1, per_class_support = precision_recall_fscore_support(
        y_test, yhat_test, labels=class_labels, average=None, zero_division=0
    )

    # Macro-averaged metrics (treats all classes equally)
    macro_precision, macro_recall, macro_f1, _ = precision_recall_fscore_support(
        y_test, yhat_test, average='macro', zero_division=0
    )

    # Micro-averaged metrics (aggregates all classes)
    micro_precision, micro_recall, micro_f1, _ = precision_recall_fscore_support(
        y_test, yhat_test, average='micro', zero_division=0
    )

    # Weighted-averaged metrics (weighted by class support)
    weighted_precision, weighted_recall, weighted_f1, _ = precision_recall_fscore_support(
        y_test, yhat_test, average='weighted', zero_division=0
    )

    print("OVERALL PERFORMANCE METRICS:")
    print(f"  Accuracy: {acc_test:.4f}")
    print(f"  Top-3 Accuracy: {acc_top3:.4f}")
    print(f"  Top-5 Accuracy: {acc_top5:.4f}")

    averaged_sections = (
        ("\nMACRO-AVERAGED METRICS (All classes treated equally):",
         macro_precision, macro_recall, macro_f1),
        ("\nMICRO-AVERAGED METRICS (Aggregated across all classes):",
         micro_precision, micro_recall, micro_f1),
        ("\nWEIGHTED-AVERAGED METRICS (Weighted by class frequency):",
         weighted_precision, weighted_recall, weighted_f1),
    )
    for heading, precision, recall, f1 in averaged_sections:
        print(heading)
        print(f"  Precision: {precision:.4f}")
        print(f"  Recall: {recall:.4f}")
        print(f"  F1-Score: {f1:.4f}")

    # Per-class detailed analysis
    print("\nPER-CLASS PERFORMANCE ANALYSIS:")
    print("-" * 80)
    print(f"{'Class ID':<8} {'Class Name':<35} {'Precision':<10} {'Recall':<10} {'F1-Score':<10} {'Support':<8}")
    print("-" * 80)

    class_performance = [
        {
            'class_id': i,
            'class_name': idx2name.get(i, f"Class_{i}"),
            'precision': per_class_precision[i],
            'recall': per_class_recall[i],
            'f1_score': per_class_f1[i],
            'support': per_class_support[i]
        }
        for i in range(num_classes)
    ]

    # Sort by F1-score (descending) so the head/tail slices are best/worst
    class_performance.sort(key=lambda perf: perf['f1_score'], reverse=True)

    def print_rows(rows):
        """Print one aligned table row per class entry."""
        for perf in rows:
            print(f"{perf['class_id']:<8} {perf['class_name'][:34]:<35} {perf['precision']:<10.4f} {perf['recall']:<10.4f} {perf['f1_score']:<10.4f} {perf['support']:<8}")

    # Display top 15 and bottom 15 classes
    print("TOP 15 PERFORMING CLASSES:")
    print_rows(class_performance[:15])

    print("\nBOTTOM 15 PERFORMING CLASSES:")
    print_rows(class_performance[-15:])

    # Class imbalance analysis
    print("\nCLASS IMBALANCE ANALYSIS:")
    print("-" * 50)

    support_values = np.array(per_class_support)
    mean_support = support_values.mean()
    median_support = np.median(support_values)
    std_support = support_values.std()

    print(f"  Mean samples per class: {mean_support:.1f}")
    print(f"  Median samples per class: {median_support:.1f}")
    print(f"  Standard deviation: {std_support:.1f}")
    print(f"  Classes with < 10 samples: {(support_values < 10).sum()}")
    print(f"  Classes with > 100 samples: {(support_values > 100).sum()}")

    # Pearson correlation between class support and F1-score
    f1_scores = np.array(per_class_f1)
    correlation = np.corrcoef(support_values, f1_scores)[0, 1]
    print(f"  Correlation (Support vs F1-Score): {correlation:.4f}")

    # Create comprehensive visualizations
    create_precision_recall_f1_visualizations(
        class_performance, per_class_precision, per_class_recall, per_class_f1,
        per_class_support, macro_precision, macro_recall, macro_f1,
        micro_precision, micro_recall, micro_f1,
        weighted_precision, weighted_recall, weighted_f1,
        acc_test, acc_top3, acc_top5, correlation,
        mean_support, median_support, support_values
    )

    # Collect everything that is persisted and returned
    metrics_summary = {
        'overall': {
            'accuracy': acc_test,
            'top3_accuracy': acc_top3,
            'top5_accuracy': acc_top5
        },
        'macro_averaged': {
            'precision': macro_precision,
            'recall': macro_recall,
            'f1_score': macro_f1
        },
        'micro_averaged': {
            'precision': micro_precision,
            'recall': micro_recall,
            'f1_score': micro_f1
        },
        'weighted_averaged': {
            'precision': weighted_precision,
            'recall': weighted_recall,
            'f1_score': weighted_f1
        },
        'per_class': class_performance,
        'class_imbalance': {
            'mean_support': mean_support,
            'median_support': median_support,
            'std_support': std_support,
            'correlation_with_f1': correlation
        }
    }

    # Save to file. A plain `default=str` would write NumPy integers such as
    # the per-class support as quoted strings; the custom fallback keeps them
    # numeric in the JSON output.
    os.makedirs(DRIVE_ANALYSIS_DIR, exist_ok=True)
    metrics_path = os.path.join(DRIVE_ANALYSIS_DIR, 'stage2_precision_recall_f1_analysis.json')
    with open(metrics_path, 'w') as f:
        json.dump(metrics_summary, f, indent=2, default=_metrics_json_default)

    print(f"\nDetailed metrics saved to: {metrics_path}")

    return metrics_summary

def create_precision_recall_f1_visualizations(
    class_performance, per_class_precision, per_class_recall, per_class_f1,
    per_class_support, macro_precision, macro_recall, macro_f1,
    micro_precision, micro_recall, micro_f1,
    weighted_precision, weighted_recall, weighted_f1,
    acc_test, acc_top3, acc_top5, correlation,
    mean_support, median_support, support_values
):
    """Render, save and show the 3x3 precision/recall/F1 figure for Stage 2.

    Args:
        class_performance: List of per-class dicts with at least
            ``class_name`` and ``f1_score``. The first and last 20 entries are
            plotted as "top" and "bottom", so the list is expected to be
            sorted by F1-score, best first.
        per_class_precision, per_class_recall, per_class_f1: Per-class metric
            sequences of equal length.
        per_class_support: Per-class sample counts, same length as above.
        macro_*, micro_*, weighted_*: Averaged precision/recall/F1 scalars.
        acc_test, acc_top3, acc_top5: Overall top-1/3/5 accuracy scalars.
        correlation: Support-vs-F1 correlation, shown in labels only.
        mean_support, median_support: Support statistics drawn as marker
            lines on the support histogram.
        support_values: Array of per-class support that accepts elementwise
            ``<`` / ``>`` comparisons.

    Side effects:
        Switches the global Matplotlib style to ``seaborn-v0_8``, writes a PNG
        to Google Drive (creating the target directory if needed) and calls
        ``plt.show()``.
    """

    print("\n" + "="*60)
    print(" CREATING PRECISION, RECALL, F1-SCORE VISUALIZATIONS")
    print(" Multi-Class Disease Classification Performance Analysis")
    print("="*60)

    # Set up the plotting style (global: later figures inherit it as well)
    plt.style.use('seaborn-v0_8')

    # Create a comprehensive figure
    fig = plt.figure(figsize=(20, 16))
    fig.suptitle('Stage 2: Multi-Class Disease Classification Performance Analysis\n'
                 'Precision, Recall, and F1-Score Comprehensive Analysis',
                 fontsize=18, fontweight='bold')

    # Plot 1: Overall Performance Metrics Comparison
    ax1 = plt.subplot(3, 3, 1)
    overall_metrics = ['Accuracy', 'Top-3', 'Top-5']
    overall_values = [acc_test, acc_top3, acc_top5]
    colors = ['skyblue', 'lightgreen', 'plum']

    bars = ax1.bar(overall_metrics, overall_values, color=colors, alpha=0.8)
    ax1.set_ylabel('Score')
    ax1.set_title('Overall Performance Metrics')
    ax1.grid(True, alpha=0.3)
    ax1.set_ylim(0, 1)

    # Add value labels
    for bar, value in zip(bars, overall_values):
        ax1.text(bar.get_x() + bar.get_width()/2., bar.get_height() + 0.01, f'{value:.3f}',
                 ha='center', va='bottom', fontweight='bold')

    # Plot 2: Averaging Methods Comparison (grouped bars)
    ax2 = plt.subplot(3, 3, 2)
    averaging_methods = ['Macro', 'Micro', 'Weighted']
    x = np.arange(len(averaging_methods))
    width = 0.25

    grouped_series = (
        (-width, [macro_precision, micro_precision, weighted_precision], 'Precision', 'skyblue'),
        (0, [macro_recall, micro_recall, weighted_recall], 'Recall', 'lightcoral'),
        (width, [macro_f1, micro_f1, weighted_f1], 'F1-Score', 'lightgreen'),
    )
    for shift, series, series_label, series_color in grouped_series:
        ax2.bar(x + shift, series, width, label=series_label, alpha=0.8, color=series_color)

    ax2.set_xlabel('Averaging Method')
    ax2.set_ylabel('Score')
    ax2.set_title('Performance Metrics by Averaging Method')
    ax2.set_xticks(x)
    ax2.set_xticklabels(averaging_methods)
    ax2.legend()
    ax2.grid(True, alpha=0.3)
    ax2.set_ylim(0, 1)

    # Plots 3 and 4: first / last 20 entries of the ranked class list
    ranked_panels = (
        (3, class_performance[:20], 'lightgreen', 'Top 20 Performing Classes'),
        (4, class_performance[-20:], 'lightcoral', 'Bottom 20 Performing Classes'),
    )
    for position, subset, bar_color, title in ranked_panels:
        ax = plt.subplot(3, 3, position)
        names = [perf['class_name'][:15] for perf in subset]
        scores = [perf['f1_score'] for perf in subset]
        rows = range(len(names))

        ax.barh(rows, scores, color=bar_color, alpha=0.8)
        ax.set_yticks(rows)
        ax.set_yticklabels(names)
        ax.set_xlabel('F1-Score')
        ax.set_title(title)
        ax.grid(True, alpha=0.3)
        ax.set_xlim(0, 1)

    # Plot 5: Precision vs Recall Scatter Plot
    ax5 = plt.subplot(3, 3, 5)
    # Draw the points once and reuse the handle for the colorbar; scattering
    # a second time would stack two semi-transparent layers on each other.
    scatter = ax5.scatter(per_class_precision, per_class_recall, alpha=0.6, s=30,
                          c=per_class_f1, cmap='RdYlGn')

    # Add diagonal line for equal precision/recall
    ax5.plot([0, 1], [0, 1], 'k--', alpha=0.5, label='Equal Precision/Recall')

    ax5.set_xlabel('Precision')
    ax5.set_ylabel('Recall')
    ax5.set_title('Precision vs Recall by Class (colored by F1-Score)')
    ax5.legend()
    ax5.grid(True, alpha=0.3)
    plt.colorbar(scatter, ax=ax5, shrink=0.8, label='F1-Score')

    # Plot 6: Class Support Distribution
    ax6 = plt.subplot(3, 3, 6)
    ax6.hist(per_class_support, bins=30, alpha=0.7, color='lightsteelblue', edgecolor='black')
    ax6.axvline(mean_support, color='red', linestyle='--', linewidth=2,
                label=f'Mean: {mean_support:.1f}')
    ax6.axvline(median_support, color='orange', linestyle='--', linewidth=2,
                label=f'Median: {median_support:.1f}')

    ax6.set_xlabel('Samples per Class')
    ax6.set_ylabel('Number of Classes')
    ax6.set_title('Class Support Distribution')
    ax6.legend()
    ax6.grid(True, alpha=0.3)

    # Plot 7: Performance vs Class Support
    ax7 = plt.subplot(3, 3, 7)
    ax7.scatter(per_class_support, per_class_f1, alpha=0.6, s=30, color='purple')

    # Add least-squares trend line (degree-1 polynomial fit)
    trend = np.poly1d(np.polyfit(per_class_support, per_class_f1, 1))
    ax7.plot(per_class_support, trend(per_class_support), "r--", alpha=0.8,
             label=f'Correlation: {correlation:.3f}')

    ax7.set_xlabel('Class Support (Number of Samples)')
    ax7.set_ylabel('F1-Score')
    ax7.set_title('Performance vs Class Support')
    ax7.legend()
    ax7.grid(True, alpha=0.3)

    # Plot 8: Metrics Distribution by Class
    ax8 = plt.subplot(3, 3, 8)
    metrics_data = [per_class_precision, per_class_recall, per_class_f1]
    metric_labels = ['Precision', 'Recall', 'F1-Score']
    colors_metrics = ['skyblue', 'lightcoral', 'lightgreen']

    # Label the boxes through the tick API instead of boxplot(labels=...):
    # that keyword is deprecated in newer Matplotlib releases, whereas explicit
    # ticks work on old and new versions. Boxes sit at positions 1..N.
    bp = ax8.boxplot(metrics_data, patch_artist=True)
    ax8.set_xticks(range(1, len(metric_labels) + 1))
    ax8.set_xticklabels(metric_labels)
    for patch, color in zip(bp['boxes'], colors_metrics):
        patch.set_facecolor(color)
        patch.set_alpha(0.7)

    ax8.set_ylabel('Score')
    ax8.set_title('Distribution of Metrics Across Classes')
    ax8.grid(True, alpha=0.3)
    ax8.set_ylim(0, 1)

    # Plot 9: Statistical Summary (text only)
    ax9 = plt.subplot(3, 3, 9)
    ax9.axis('off')

    summary_text = "".join([
        "PERFORMANCE ANALYSIS SUMMARY:\n\n",
        f"Overall Accuracy: {acc_test:.4f}\n",
        f"Top-3 Accuracy: {acc_top3:.4f}\n",
        f"Top-5 Accuracy: {acc_top5:.4f}\n\n",
        "AVERAGING METHODS:\n",
        f"Macro F1: {macro_f1:.4f}\n",
        f"Micro F1: {micro_f1:.4f}\n",
        f"Weighted F1: {weighted_f1:.4f}\n\n",
        "CLASS IMBALANCE:\n",
        f"Mean Support: {mean_support:.1f}\n",
        f"Classes < 10: {(support_values < 10).sum()}\n",
        f"Classes > 100: {(support_values > 100).sum()}\n\n",
        f"Correlation: {correlation:.3f}",
    ])

    ax9.text(0.05, 0.95, summary_text, transform=ax9.transAxes, fontsize=9,
             verticalalignment='top', fontfamily='monospace',
             bbox=dict(boxstyle="round,pad=0.3", facecolor="lightgray"))

    plt.tight_layout()

    # Save the comprehensive visualization
    # NOTE(review): this Stage 2 figure is written into the "stage1" models
    # folder, while other cells save under DRIVE_VISUALIZATIONS_DIR - confirm
    # that this location is intended.
    save_path = '/content/drive/MyDrive/plantwild_stage1_models/stage2_precision_recall_f1_analysis_comprehensive.png'
    os.makedirs(os.path.dirname(save_path), exist_ok=True)  # savefig does not create folders
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"✓ Comprehensive precision/recall/F1 visualization saved to: {save_path}")

    plt.show()

# Run the precision/recall/F1 analysis when this cell executes. The call prints
# the report, renders the figure, writes the JSON summary, and its returned
# summary dict is kept in `precision_recall_f1_analysis`.
precision_recall_f1_analysis = comprehensive_precision_recall_f1_analysis()

In [None]:
# **STATISTICAL SIGNIFICANCE ANALYSIS FOR MULTI-CLASS + PLOTS**

def _mean_t_interval(values, confidence=0.95):
    """Return the Student-t confidence interval ``(lower, upper)`` for the mean of *values*."""
    return stats.t.interval(confidence, len(values) - 1,
                            loc=values.mean(),
                            scale=stats.sem(values))


def statistical_significance_analysis_multi_class():
    """Summarise the spread of per-class metrics for the Stage 2 test predictions.

    Reads the notebook globals ``y_test``, ``yhat_test``, ``num_classes`` and
    ``DRIVE_VISUALIZATIONS_DIR``. Prints per-class mean/std, macro and micro
    averages, a Shapiro-Wilk normality test on the per-class F1 scores, 95%
    t-intervals for the per-class means and a spread statistic, then renders
    and saves a 2x3 figure.

    Returns:
        dict: ``per_class_f1``, ``per_class_precision``, ``per_class_recall``,
        ``per_class_acc`` (arrays with one entry per class),
        ``macro_metrics``, ``micro_metrics``, the four interval tuples
        ``f1_ci``/``prec_ci``/``rec_ci``/``acc_ci`` and ``cohens_d``.

    Note:
        The value reported as "Cohen's d" is ``(max - min) / std`` of the
        per-class F1 scores, i.e. a range-to-spread ratio rather than a
        two-group standardised mean difference. The label and the result key
        are kept for compatibility with existing consumers.
    """

    print("="*60)
    print(" STATISTICAL SIGNIFICANCE ANALYSIS - MULTI-CLASS")
    print("="*60)

    # Per-class metrics. `labels` pins the result to one entry per class id in
    # 0..num_classes-1 so it lines up with `per_class_acc` below even when a
    # class is absent from both the labels and the predictions.
    class_labels = np.arange(num_classes)
    per_class_precision, per_class_recall, per_class_f1, per_class_support = precision_recall_fscore_support(
        y_test, yhat_test, labels=class_labels, average=None, zero_division=0
    )

    # Calculate macro-averaged metrics
    macro_precision, macro_recall, macro_f1, _ = precision_recall_fscore_support(
        y_test, yhat_test, average='macro', zero_division=0
    )

    # Calculate micro-averaged metrics
    micro_precision, micro_recall, micro_f1, _ = precision_recall_fscore_support(
        y_test, yhat_test, average='micro', zero_division=0
    )

    # Per-class accuracy: share of a class's true samples that were predicted
    # as that class (numerically the same quantity as per-class recall).
    # Classes without test samples are scored 0.0.
    per_class_acc = []
    for i in range(num_classes):
        class_mask = (y_test == i)
        if class_mask.sum() > 0:
            per_class_acc.append((yhat_test[class_mask] == y_test[class_mask]).mean())
        else:
            per_class_acc.append(0.0)
    per_class_acc = np.array(per_class_acc)

    per_class_series = (
        ('F1-Score', per_class_f1),
        ('Precision', per_class_precision),
        ('Recall', per_class_recall),
        ('Accuracy', per_class_acc),
    )

    print("PER-CLASS PERFORMANCE ANALYSIS:")
    print(f"  Number of classes: {num_classes}")
    for series_name, series in per_class_series:
        print(f"  Mean {series_name}: {series.mean():.4f}")
        print(f"  Std {series_name}: {series.std():.4f}")

    print("\nAVERAGING METHODS COMPARISON:")
    averaged = (
        ('Macro', macro_precision, macro_recall, macro_f1),
        ('Micro', micro_precision, micro_recall, micro_f1),
    )
    for method, precision, recall, f1 in averaged:
        print(f"  {method}-Averaged:")
        print(f"    Precision: {precision:.4f}")
        print(f"    Recall: {recall:.4f}")
        print(f"    F1-Score: {f1:.4f}")

    # Statistical tests
    print("\nSTATISTICAL ANALYSIS:")

    # Test for normality (Shapiro-Wilk). Best effort: a failure is reported
    # with its cause instead of being swallowed by a bare `except`, which
    # would also trap KeyboardInterrupt.
    try:
        f1_stat, f1_p = stats.shapiro(per_class_f1)
    except Exception as exc:
        print(f"  F1-Score normality test failed: {exc}")
    else:
        print("  F1-Score normality test (Shapiro-Wilk):")
        print(f"    Statistic: {f1_stat:.4f}")
        print(f"    P-value: {f1_p:.6f}")
        print(f"    Normal distribution: {'Yes' if f1_p > 0.05 else 'No'}")

    # Confidence intervals for mean performance (each class is one observation)
    f1_ci = _mean_t_interval(per_class_f1)
    prec_ci = _mean_t_interval(per_class_precision)
    rec_ci = _mean_t_interval(per_class_recall)
    acc_ci = _mean_t_interval(per_class_acc)

    print("\nCONFIDENCE INTERVALS (95%):")
    print(f"  F1-Score: [{f1_ci[0]:.4f}, {f1_ci[1]:.4f}]")
    print(f"  Precision: [{prec_ci[0]:.4f}, {prec_ci[1]:.4f}]")
    print(f"  Recall: [{rec_ci[0]:.4f}, {rec_ci[1]:.4f}]")
    print(f"  Accuracy: [{acc_ci[0]:.4f}, {acc_ci[1]:.4f}]")

    # Effect size analysis
    print("\nEFFECT SIZE ANALYSIS:")

    # Range-to-spread ratio of the per-class F1 scores. With zero spread the
    # range is zero too; report 0.0 instead of dividing 0/0, which would give
    # NaN and fall through to the "Large" bucket below.
    f1_std = per_class_f1.std()
    f1_range = per_class_f1.max() - per_class_f1.min()
    f1_cohens_d = f1_range / f1_std if f1_std > 0 else 0.0

    print(f"  F1-Score variation (Cohen's d): {f1_cohens_d:.4f}")

    if f1_cohens_d < 0.2:
        effect_size = "Negligible"
    elif f1_cohens_d < 0.5:
        effect_size = "Small"
    elif f1_cohens_d < 0.8:
        effect_size = "Medium"
    else:
        effect_size = "Large"

    print(f"  Effect Size: {effect_size}")

    # ===== CREATE VISUALIZATIONS =====
    print("\nCreating statistical analysis visualizations...")

    # Create comprehensive subplot grid
    fig, axes = plt.subplots(2, 3, figsize=(20, 12))

    # 1-3. Per-class F1 / precision / recall histograms (top row)
    histogram_panels = (
        (axes[0, 0], per_class_f1, 'F1-Score', 'skyblue'),
        (axes[0, 1], per_class_precision, 'Precision', 'lightgreen'),
        (axes[0, 2], per_class_recall, 'Recall', 'lightcoral'),
    )
    for ax, series, series_name, color in histogram_panels:
        series_mean = series.mean()
        ax.hist(series, bins=20, color=color, alpha=0.7, edgecolor='black')
        ax.axvline(series_mean, color='red', linestyle='--',
                   label=f'Mean: {series_mean:.3f}')
        if series is per_class_f1:
            # Only the F1 panel also marks one standard deviation either side.
            series_std = series.std()
            ax.axvline(series_mean + series_std, color='orange', linestyle='--',
                       label=f'+1 Std: {series_mean + series_std:.3f}')
            ax.axvline(series_mean - series_std, color='orange', linestyle='--',
                       label=f'-1 Std: {series_mean - series_std:.3f}')
        ax.set_title(f'Per-Class {series_name} Distribution', fontsize=14, fontweight='bold')
        ax.set_xlabel(series_name, fontsize=12)
        ax.set_ylabel('Number of Classes', fontsize=12)
        ax.legend()
        ax.grid(True, alpha=0.3)

    # 4. Averaging Methods Comparison (grouped bars)
    averaging_methods = ['Macro', 'Micro']
    x_pos = np.arange(len(averaging_methods))
    width = 0.25

    ax_avg = axes[1, 0]
    ax_avg.bar(x_pos - width, [macro_precision, micro_precision], width,
               label='Precision', color='#FF6B6B', alpha=0.8)
    ax_avg.bar(x_pos, [macro_recall, micro_recall], width,
               label='Recall', color='#4ECDC4', alpha=0.8)
    ax_avg.bar(x_pos + width, [macro_f1, micro_f1], width,
               label='F1-Score', color='#45B7D1', alpha=0.8)

    ax_avg.set_title('Averaging Methods Comparison', fontsize=14, fontweight='bold')
    ax_avg.set_xlabel('Averaging Method', fontsize=12)
    ax_avg.set_ylabel('Score', fontsize=12)
    ax_avg.set_xticks(x_pos)
    ax_avg.set_xticklabels(averaging_methods)
    ax_avg.legend()
    ax_avg.grid(True, alpha=0.3)

    # 5. Performance vs Class Support Scatter
    ax_support = axes[1, 1]
    support_values = np.array(per_class_support)
    ax_support.scatter(support_values, per_class_f1, alpha=0.7, color='purple', s=50)
    ax_support.set_title('F1-Score vs Class Support', fontsize=14, fontweight='bold')
    ax_support.set_xlabel('Number of Samples (Support)', fontsize=12)
    ax_support.set_ylabel('F1-Score', fontsize=12)
    ax_support.grid(True, alpha=0.3)

    # Add least-squares trend line, labelled with the Pearson correlation
    trend = np.poly1d(np.polyfit(support_values, per_class_f1, 1))
    correlation = np.corrcoef(support_values, per_class_f1)[0, 1]
    ax_support.plot(support_values, trend(support_values), "r--", alpha=0.8,
                    label=f"Trend (r={correlation:.3f})")
    ax_support.legend()

    # 6. Confidence Intervals Visualization
    ax_ci = axes[1, 2]
    metrics = ['F1-Score', 'Precision', 'Recall', 'Accuracy']
    means = [per_class_f1.mean(), per_class_precision.mean(), per_class_recall.mean(), per_class_acc.mean()]
    ci_lower = [f1_ci[0], prec_ci[0], rec_ci[0], acc_ci[0]]
    ci_upper = [f1_ci[1], prec_ci[1], rec_ci[1], acc_ci[1]]

    x_pos = np.arange(len(metrics))
    # Row 0: distance down to the lower bound; row 1: distance up to the upper bound
    yerr = np.array([np.array(means) - np.array(ci_lower), np.array(ci_upper) - np.array(means)])

    ax_ci.errorbar(x_pos, means, yerr=yerr, fmt='o', capsize=5, capthick=2,
                   markersize=8, color='blue', alpha=0.8)
    ax_ci.set_title('95% Confidence Intervals', fontsize=14, fontweight='bold')
    ax_ci.set_xlabel('Metrics', fontsize=12)
    ax_ci.set_ylabel('Score', fontsize=12)
    ax_ci.set_xticks(x_pos)
    ax_ci.set_xticklabels(metrics, rotation=45)
    ax_ci.grid(True, alpha=0.3)

    # Add value labels: mean above the upper cap, interval below the lower cap
    for i, (mean, ci_l, ci_u) in enumerate(zip(means, ci_lower, ci_upper)):
        ax_ci.text(i, mean + yerr[1, i] + 0.01, f'{mean:.3f}',
                   ha='center', va='bottom', fontweight='bold')
        ax_ci.text(i, mean - yerr[0, i] - 0.01, f'[{ci_l:.3f}, {ci_u:.3f}]',
                   ha='center', va='top', fontsize=8)

    plt.tight_layout()

    # Save visualization (savefig does not create missing folders)
    os.makedirs(DRIVE_VISUALIZATIONS_DIR, exist_ok=True)
    stats_viz_path = os.path.join(DRIVE_VISUALIZATIONS_DIR, 'stage2_statistical_analysis.png')
    plt.savefig(stats_viz_path, dpi=300, bbox_inches='tight')
    plt.show()

    print(f"Statistical analysis visualizations saved to: {stats_viz_path}")

    return {
        'per_class_f1': per_class_f1,
        'per_class_precision': per_class_precision,
        'per_class_recall': per_class_recall,
        'per_class_acc': per_class_acc,
        'macro_metrics': {'precision': macro_precision, 'recall': macro_recall, 'f1': macro_f1},
        'micro_metrics': {'precision': micro_precision, 'recall': micro_recall, 'f1': micro_f1},
        'f1_ci': f1_ci,
        'prec_ci': prec_ci,
        'rec_ci': rec_ci,
        'acc_ci': acc_ci,
        'cohens_d': f1_cohens_d
    }

# Run the per-class statistical analysis when this cell executes. The call
# prints the report, renders and saves the figure, and its returned results
# dict is kept in `stats_results`.
stats_results = statistical_significance_analysis_multi_class()

In [None]:
# **CROSS-VALIDATION ANALYSIS FOR MULTI-CLASS + PLOTS**

def cross_validation_analysis_multi_class(k_folds=5):
    """Score the trained model on stratified folds of the test set and plot the results.

    NOTE: the model is never re-fitted in this function. Every "fold" is a
    stratified partition of ``test_df`` that is scored with the already
    trained global ``model``. The numbers therefore describe how stable the
    test metrics are across partitions; they are not a train/validate
    cross-validation estimate of generalisation.

    Reads the notebook globals ``test_df`` (columns ``image_path_absolute``
    and ``encoded_class_id``), ``model`` and ``DRIVE_VISUALIZATIONS_DIR``.
    Images that fail to load are reported and skipped.

    Args:
        k_folds: Number of stratified folds to split the test set into.

    Returns:
        A list with one dict per fold, holding the keys ``fold``,
        ``accuracy``, ``f1_score``, ``precision``, ``recall`` (macro averaged)
        and ``micro_precision``, ``micro_recall``, ``micro_f1``.

    Raises:
        ValueError: If not a single test image could be loaded.
    """

    print("="*60)
    print(" CROSS-VALIDATION ANALYSIS - MULTI-CLASS")
    print("="*60)

    # Prepare data for cross-validation
    X = []
    y = []

    print(f"Preparing {len(test_df)} test samples for cross-validation...")

    for _, sample in test_df.iterrows():
        try:
            label = sample['encoded_class_id']
            img = tf.keras.preprocessing.image.load_img(
                sample['image_path_absolute'], target_size=(224, 224)
            )
            img_array = tf.keras.preprocessing.image.img_to_array(img)
            img_array = tf.keras.applications.mobilenet_v2.preprocess_input(img_array)
        except Exception as e:
            # Best effort: an unreadable image is reported and left out.
            print(f"Error processing {sample['image_path_absolute']}: {e}")
            continue
        # Append image and label together so X and y can never get out of step.
        X.append(img_array)
        y.append(label)

    if not X:
        raise ValueError("No test images could be loaded; cannot run the fold analysis")

    X = np.array(X)
    y = np.array(y)

    print(f"Data prepared: X shape {X.shape}, y shape {y.shape}")

    # K-fold partitioning (stratified so every fold keeps the class mix)
    kf = StratifiedKFold(n_splits=k_folds, shuffle=True, random_state=42)
    cv_scores = []

    print(f"\nPerforming {k_folds}-fold cross-validation...")

    # Only the held-out indices are used: the model is not re-trained per fold.
    for fold, (_, val_idx) in enumerate(kf.split(X, y), start=1):
        print(f"\nFold {fold}/{k_folds}")

        X_val_fold = X[val_idx]
        y_val_fold = y[val_idx]

        # Evaluate on validation fold
        predictions = model.predict(X_val_fold, verbose=0)
        pred_labels = predictions.argmax(axis=1)

        accuracy = (pred_labels == y_val_fold).mean()

        # Macro-averaged precision, recall, F1 for this fold
        precision, recall, f1, _ = precision_recall_fscore_support(
            y_val_fold, pred_labels, average='macro', zero_division=0
        )

        # Micro-averaged metrics
        micro_precision, micro_recall, micro_f1, _ = precision_recall_fscore_support(
            y_val_fold, pred_labels, average='micro', zero_division=0
        )

        cv_scores.append({
            'fold': fold,
            'accuracy': accuracy,
            'f1_score': f1,
            'precision': precision,
            'recall': recall,
            'micro_precision': micro_precision,
            'micro_recall': micro_recall,
            'micro_f1': micro_f1
        })

        print(f"  Accuracy: {accuracy:.4f}")
        print(f"  F1-Score (Macro): {f1:.4f}")
        print(f"  Precision (Macro): {precision:.4f}")
        print(f"  Recall (Macro): {recall:.4f}")
        print(f"  F1-Score (Micro): {micro_f1:.4f}")

    # Cross-validation summary
    print("\n" + "="*60)
    print(" CROSS-VALIDATION SUMMARY")
    print("="*60)

    accuracies = [score['accuracy'] for score in cv_scores]
    f1_scores = [score['f1_score'] for score in cv_scores]
    precisions = [score['precision'] for score in cv_scores]
    recalls = [score['recall'] for score in cv_scores]
    micro_f1s = [score['micro_f1'] for score in cv_scores]

    print(f"Accuracy: {np.mean(accuracies):.4f} ± {np.std(accuracies):.4f}")
    print(f"F1-Score (Macro): {np.mean(f1_scores):.4f} ± {np.std(f1_scores):.4f}")
    print(f"Precision (Macro): {np.mean(precisions):.4f} ± {np.std(precisions):.4f}")
    print(f"Recall (Macro): {np.mean(recalls):.4f} ± {np.std(recalls):.4f}")
    print(f"F1-Score (Micro): {np.mean(micro_f1s):.4f} ± {np.std(micro_f1s):.4f}")

    # Stability assessment based on the spread of fold accuracies
    acc_std = np.std(accuracies)
    if acc_std < 0.05:
        stability = 'Good'
    elif acc_std < 0.1:
        stability = 'Moderate'
    else:
        stability = 'Poor'
    print(f"Stability: {stability}")

    # ===== CREATE VISUALIZATIONS =====
    print("\nCreating cross-validation visualizations...")

    # Create comprehensive subplot grid
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))

    # 1. Fold-by-Fold Performance Comparison
    fold_numbers = [score['fold'] for score in cv_scores]

    axes[0, 0].plot(fold_numbers, accuracies, 'o-', linewidth=2, markersize=8,
                     color='blue', label='Accuracy', alpha=0.8)
    axes[0, 0].plot(fold_numbers, f1_scores, 's-', linewidth=2, markersize=8,
                     color='red', label='F1-Score (Macro)', alpha=0.8)
    axes[0, 0].plot(fold_numbers, micro_f1s, '^-', linewidth=2, markersize=8,
                     color='green', label='F1-Score (Micro)', alpha=0.8)

    # Add mean lines
    axes[0, 0].axhline(np.mean(accuracies), color='blue', linestyle='--', alpha=0.5,
                       label=f'Mean Acc: {np.mean(accuracies):.3f}')
    axes[0, 0].axhline(np.mean(f1_scores), color='red', linestyle='--', alpha=0.5,
                       label=f'Mean F1: {np.mean(f1_scores):.3f}')

    axes[0, 0].set_title('Fold-by-Fold Performance Comparison', fontsize=14, fontweight='bold')
    axes[0, 0].set_xlabel('Fold Number', fontsize=12)
    axes[0, 0].set_ylabel('Score', fontsize=12)
    axes[0, 0].set_xticks(fold_numbers)
    # The legend is built from the artists that exist when legend() is called,
    # so it must come after the mean lines for their labels to show up.
    axes[0, 0].legend()
    axes[0, 0].grid(True, alpha=0.3)

    # 2. Precision vs Recall by Fold (drawn once; the handle feeds the colorbar)
    scatter = axes[0, 1].scatter(precisions, recalls, s=100, c=fold_numbers, cmap='viridis',
                                 alpha=0.8, edgecolors='black', linewidth=2)

    # Add fold labels
    for fold, prec, rec in zip(fold_numbers, precisions, recalls):
        axes[0, 1].annotate(f'Fold {fold}', (prec, rec),
                           xytext=(5, 5), textcoords='offset points',
                           fontsize=10, fontweight='bold')

    axes[0, 1].set_title('Precision vs Recall by Fold', fontsize=14, fontweight='bold')
    axes[0, 1].set_xlabel('Precision (Macro)', fontsize=12)
    axes[0, 1].set_ylabel('Recall (Macro)', fontsize=12)
    axes[0, 1].grid(True, alpha=0.3)

    # Add colorbar
    cbar = plt.colorbar(scatter, ax=axes[0, 1])
    cbar.set_label('Fold Number', fontsize=12)

    # 3. Performance Stability Analysis
    metrics_names = ['Accuracy', 'F1-Score\n(Macro)', 'Precision\n(Macro)', 'Recall\n(Macro)', 'F1-Score\n(Micro)']
    metrics_values = [accuracies, f1_scores, precisions, recalls, micro_f1s]
    metrics_means = [np.mean(values) for values in metrics_values]
    metrics_stds = [np.std(values) for values in metrics_values]

    x_pos = np.arange(len(metrics_names))
    bars = axes[1, 0].bar(x_pos, metrics_means, yerr=metrics_stds, capsize=5,
                           color=['blue', 'red', 'green', 'orange', 'purple'], alpha=0.8)

    axes[1, 0].set_title('Performance Stability Across Folds', fontsize=14, fontweight='bold')
    axes[1, 0].set_xlabel('Metrics', fontsize=12)
    axes[1, 0].set_ylabel('Score', fontsize=12)
    axes[1, 0].set_xticks(x_pos)
    axes[1, 0].set_xticklabels(metrics_names, rotation=45, ha='right')
    axes[1, 0].grid(True, alpha=0.3)

    # Add value labels above the error-bar caps
    for bar, mean, std in zip(bars, metrics_means, metrics_stds):
        height = bar.get_height()
        axes[1, 0].text(bar.get_x() + bar.get_width()/2., height + std + 0.01,
                        f'{mean:.3f}\n±{std:.3f}', ha='center', va='bottom',
                        fontweight='bold', fontsize=10)

    # 4. Fold Performance Heatmap (rows = metrics, columns = folds)
    heatmap_data = np.array(metrics_values)
    heatmap_labels = metrics_names

    im = axes[1, 1].imshow(heatmap_data, cmap='YlOrRd', aspect='auto', alpha=0.8)

    # Set labels
    axes[1, 1].set_xticks(range(len(fold_numbers)))
    axes[1, 1].set_xticklabels([f'Fold {f}' for f in fold_numbers])
    axes[1, 1].set_yticks(range(len(heatmap_labels)))
    axes[1, 1].set_yticklabels(heatmap_labels)

    # Add text annotations
    for i in range(len(heatmap_labels)):
        for j in range(len(fold_numbers)):
            axes[1, 1].text(j, i, f'{heatmap_data[i, j]:.3f}',
                            ha="center", va="center", color="black", fontweight='bold')

    axes[1, 1].set_title('Cross-Validation Performance Heatmap', fontsize=14, fontweight='bold')

    # Add colorbar
    cbar = plt.colorbar(im, ax=axes[1, 1])
    cbar.set_label('Score Value', fontsize=12)

    plt.tight_layout()

    # Save visualization
    cv_viz_path = os.path.join(DRIVE_VISUALIZATIONS_DIR, 'stage2_cross_validation_analysis.png')
    plt.savefig(cv_viz_path, dpi=300, bbox_inches='tight')
    plt.show()

    print(f"Cross-validation visualizations saved to: {cv_viz_path}")

    return cv_scores

# Run cross-validation with plots (default k_folds=5).
# Executes as soon as the cell runs; the returned list of per-fold metric
# dicts is kept in `cv_results`.
cv_results = cross_validation_analysis_multi_class()

In [None]:
# **ERROR ANALYSIS AND MISCLASSIFICATION STUDY - MULTI-CLASS + PLOTS**

def error_analysis_multi_class():
    """Analyse test-set misclassifications and plot a six-panel error report.

    Reads the notebook globals ``y_test`` and ``yhat_test`` (true and
    predicted class ids, indexed with boolean masks here), ``p_test``
    (per-class scores; the row maximum is used as the confidence),
    ``num_classes``, ``idx2name`` and ``DRIVE_VISUALIZATIONS_DIR``.

    Classes with no test samples are left out of the per-class statistics,
    the correlations and the plots, since no error rate can be computed
    for them.

    Returns:
        dict with the keys ``class_errors`` (per-class totals, error rate,
        precision, recall, F1), ``most_confused_pairs`` (off-diagonal
        confusion-matrix entries sorted by count), ``high_conf_errors``,
        ``low_conf_errors``, ``overall_error_rate`` and ``correlations``
        (``support_vs_f1``, ``support_vs_error``).
    """

    print("="*60)
    print(" ERROR ANALYSIS AND MISCLASSIFICATION STUDY - MULTI-CLASS")
    print("="*60)

    # Analyze test predictions
    correct_mask = (yhat_test == y_test)
    total_samples = len(y_test)
    correct_predictions = correct_mask.sum()
    misclassifications = total_samples - correct_predictions

    print(f"Total test samples: {total_samples}")
    print(f"Correct predictions: {correct_predictions} ({correct_predictions/total_samples*100:.1f}%)")
    print(f"Misclassifications: {misclassifications} ({misclassifications/total_samples*100:.1f}%)")

    # Calculate per-class error rates
    print("\nPER-CLASS ERROR ANALYSIS:")
    print("-" * 50)

    # One call yields precision/recall/F1 for every class; entry i of each
    # array belongs to label i because labels are passed in order.
    all_labels = list(range(num_classes))
    precision_all, recall_all, f1_all, _ = precision_recall_fscore_support(
        y_test, yhat_test, labels=all_labels, average=None, zero_division=0
    )

    class_errors = {}
    for i in all_labels:
        class_mask = (y_test == i)
        class_total = class_mask.sum()
        if class_total == 0:
            # No test samples for this class: nothing to measure.
            continue

        class_correct = correct_mask[class_mask].sum()
        class_errors[i] = {
            'total': class_total,
            'correct': class_correct,
            'errors': class_total - class_correct,
            'error_rate': (class_total - class_correct) / class_total,
            'precision': precision_all[i],
            'recall': recall_all[i],
            'f1_score': f1_all[i]
        }

    # Sort by error rate
    sorted_errors = sorted(class_errors.items(), key=lambda x: x[1]['error_rate'], reverse=True)

    print("Top 10 Classes with Highest Error Rates:")
    print(f"{'Class ID':<8} {'Class Name':<35} {'Error Rate':<12} {'Precision':<12} {'Recall':<12} {'F1-Score':<12}")
    print("-" * 100)

    for class_id, error_info in sorted_errors[:10]:
        class_name = idx2name.get(class_id, f"Class_{class_id}")
        print(f"{class_id:<8} {class_name[:34]:<35} {error_info['error_rate']:<12.1%} {error_info['precision']:<12.4f} {error_info['recall']:<12.4f} {error_info['f1_score']:<12.4f}")

    # Analyze misclassifications by class
    print("\nMISCLASSIFICATION PATTERN ANALYSIS:")
    print("-" * 50)

    # Create confusion matrix for detailed analysis
    cm = confusion_matrix(y_test, yhat_test, labels=range(num_classes))

    # Find most confused class pairs (every non-zero off-diagonal cell)
    most_confused = []
    for i in range(num_classes):
        row_total = cm[i, :].sum()
        for j in range(num_classes):
            if i != j and cm[i, j] > 0:
                most_confused.append({
                    'true_class': i,
                    'predicted_class': j,
                    'count': cm[i, j],
                    'true_class_name': idx2name.get(i, f"Class_{i}"),
                    'predicted_class_name': idx2name.get(j, f"Class_{j}"),
                    'error_rate': cm[i, j] / row_total if row_total > 0 else 0
                })

    # Sort by confusion count
    most_confused.sort(key=lambda x: x['count'], reverse=True)

    print("Top 10 Most Confused Class Pairs:")
    for rank, confusion in enumerate(most_confused[:10], start=1):
        print(f"  {rank:2d}. {confusion['true_class_name'][:25]:<25} → {confusion['predicted_class_name'][:25]:<25}")
        print(f"      Count: {confusion['count']} samples, Error Rate: {confusion['error_rate']:.1%}")

    # Analyze confidence vs accuracy
    print("\nCONFIDENCE ANALYSIS:")
    print("-" * 30)

    # Get prediction confidences
    confidences = p_test.max(axis=1)

    def _accuracy_text(group_mask):
        # Accuracy inside a confidence group, guarded against an empty group.
        group_size = group_mask.sum()
        if group_size == 0:
            return "n/a (no samples)"
        return f"{(group_mask & correct_mask).sum() / group_size:.1%}"

    # High confidence errors
    high_conf_threshold = 0.8
    high_conf_mask = confidences >= high_conf_threshold
    high_conf_errors = high_conf_mask & ~correct_mask

    print(f"High confidence errors (≥{high_conf_threshold}): {high_conf_errors.sum()}")
    print(f"High confidence accuracy: {_accuracy_text(high_conf_mask)}")

    # Low confidence errors
    low_conf_threshold = 0.5
    low_conf_mask = confidences < low_conf_threshold
    low_conf_errors = low_conf_mask & ~correct_mask

    print(f"Low confidence errors (<{low_conf_threshold}): {low_conf_errors.sum()}")
    print(f"Low confidence accuracy: {_accuracy_text(low_conf_mask)}")

    # Performance correlation analysis
    print("\nPERFORMANCE CORRELATION ANALYSIS:")
    print("-" * 40)

    # Only classes present in class_errors (i.e. with test samples) take part;
    # indexing by range(num_classes) would raise KeyError for absent classes.
    support_values = np.array([info['total'] for info in class_errors.values()])
    f1_values = np.array([info['f1_score'] for info in class_errors.values()])
    class_error_rates = np.array([info['error_rate'] for info in class_errors.values()])

    # Correlation between class support and performance
    support_f1_corr = np.corrcoef(support_values, f1_values)[0, 1]
    print(f"Correlation (Support vs F1-Score): {support_f1_corr:.4f}")

    # Correlation between class support and error rate
    support_error_corr = np.corrcoef(support_values, class_error_rates)[0, 1]
    print(f"Correlation (Support vs Error Rate): {support_error_corr:.4f}")

    # ===== CREATE VISUALIZATIONS =====
    print("\nCreating error analysis visualizations...")

    # Create comprehensive subplot grid
    fig, axes = plt.subplots(2, 3, figsize=(20, 12))

    # 1. Top Error Classes Bar Chart
    top_10_errors = sorted_errors[:10]
    class_names = [idx2name.get(cls_id, f"Class_{cls_id}")[:20] + "..." for cls_id, _ in top_10_errors]
    top_error_rates = [error_info['error_rate'] for _, error_info in top_10_errors]

    bars = axes[0, 0].bar(range(len(class_names)), top_error_rates, color='red', alpha=0.7)
    axes[0, 0].set_title('Top 10 Classes with Highest Error Rates', fontsize=14, fontweight='bold')
    axes[0, 0].set_xlabel('Classes', fontsize=12)
    axes[0, 0].set_ylabel('Error Rate', fontsize=12)
    axes[0, 0].set_xticks(range(len(class_names)))
    axes[0, 0].set_xticklabels(class_names, rotation=45, ha='right')
    axes[0, 0].grid(True, alpha=0.3)

    # Add value labels on bars
    for bar, rate in zip(bars, top_error_rates):
        height = bar.get_height()
        axes[0, 0].text(bar.get_x() + bar.get_width()/2., height + 0.01,
                        f'{rate:.1%}', ha='center', va='bottom', fontweight='bold')

    # 2. Error Rate vs F1-Score Scatter
    all_error_rates = [error_info['error_rate'] for _, error_info in sorted_errors]
    all_f1_scores = [error_info['f1_score'] for _, error_info in sorted_errors]

    axes[0, 1].scatter(all_error_rates, all_f1_scores, alpha=0.7, color='blue', s=50)
    axes[0, 1].set_title('Error Rate vs F1-Score Correlation', fontsize=14, fontweight='bold')
    axes[0, 1].set_xlabel('Error Rate', fontsize=12)
    axes[0, 1].set_ylabel('F1-Score', fontsize=12)
    axes[0, 1].grid(True, alpha=0.3)

    # Add trend line (least-squares straight line)
    trend = np.poly1d(np.polyfit(all_error_rates, all_f1_scores, 1))
    correlation = np.corrcoef(all_error_rates, all_f1_scores)[0, 1]
    axes[0, 1].plot(all_error_rates, trend(all_error_rates), "r--", alpha=0.8,
                     label=f"Trend (r={correlation:.3f})")
    axes[0, 1].legend()

    # 3. Class Support vs Error Rate
    all_supports = [error_info['total'] for _, error_info in sorted_errors]

    axes[0, 2].scatter(all_supports, all_error_rates, alpha=0.7, color='green', s=50)
    axes[0, 2].set_title('Class Support vs Error Rate', fontsize=14, fontweight='bold')
    axes[0, 2].set_xlabel('Number of Samples (Support)', fontsize=12)
    axes[0, 2].set_ylabel('Error Rate', fontsize=12)
    axes[0, 2].grid(True, alpha=0.3)

    # Add trend line
    trend = np.poly1d(np.polyfit(all_supports, all_error_rates, 1))
    correlation = np.corrcoef(all_supports, all_error_rates)[0, 1]
    axes[0, 2].plot(all_supports, trend(all_supports), "r--", alpha=0.8,
                     label=f"Trend (r={correlation:.3f})")
    axes[0, 2].legend()

    # 4. Most Confused Class Pairs
    top_confused = most_confused[:10]
    confusion_counts = [confusion['count'] for confusion in top_confused]
    confusion_labels = [f"{confusion['true_class_name'][:15]}...\n→ {confusion['predicted_class_name'][:15]}..."
                       for confusion in top_confused]

    bars = axes[1, 0].bar(range(len(confusion_counts)), confusion_counts, color='orange', alpha=0.7)
    axes[1, 0].set_title('Top 10 Most Confused Class Pairs', fontsize=14, fontweight='bold')
    axes[1, 0].set_xlabel('Class Pairs', fontsize=12)
    axes[1, 0].set_ylabel('Number of Confusions', fontsize=12)
    axes[1, 0].set_xticks(range(len(confusion_labels)))
    axes[1, 0].set_xticklabels(confusion_labels, rotation=45, ha='right')
    axes[1, 0].grid(True, alpha=0.3)

    # Add value labels on bars
    for bar, count in zip(bars, confusion_counts):
        height = bar.get_height()
        axes[1, 0].text(bar.get_x() + bar.get_width()/2., height + 0.5,
                        str(count), ha='center', va='bottom', fontweight='bold')

    # 5. Confidence vs Accuracy Analysis
    confidence_bins = np.linspace(0, 1, 11)
    last_bin = len(confidence_bins) - 2
    bin_accuracies = []
    bin_centers = []

    for i, (lower, upper) in enumerate(zip(confidence_bins[:-1], confidence_bins[1:])):
        if i == last_bin:
            # Close the final bin on the right so confidence == 1.0 is counted.
            bin_mask = (confidences >= lower) & (confidences <= upper)
        else:
            bin_mask = (confidences >= lower) & (confidences < upper)
        if bin_mask.sum() > 0:
            bin_accuracies.append(correct_mask[bin_mask].mean())
            bin_centers.append((lower + upper) / 2)

    axes[1, 1].plot(bin_centers, bin_accuracies, 'o-', linewidth=2, markersize=8,
                     color='purple', alpha=0.8)
    axes[1, 1].plot([0, 1], [0, 1], '--', color='red', alpha=0.5, label='Perfect Calibration')
    axes[1, 1].set_title('Confidence vs Accuracy Analysis', fontsize=14, fontweight='bold')
    axes[1, 1].set_xlabel('Confidence', fontsize=12)
    axes[1, 1].set_ylabel('Accuracy', fontsize=12)
    axes[1, 1].grid(True, alpha=0.3)
    axes[1, 1].legend()

    # 6. Error Distribution by Performance Level
    performance_levels = ['High (F1 > 0.8)', 'Medium (0.5 < F1 ≤ 0.8)', 'Low (F1 ≤ 0.5)']
    high_perf = [error_info for _, error_info in sorted_errors if error_info['f1_score'] > 0.8]
    medium_perf = [error_info for _, error_info in sorted_errors if 0.5 < error_info['f1_score'] <= 0.8]
    low_perf = [error_info for _, error_info in sorted_errors if error_info['f1_score'] <= 0.5]

    performance_counts = [len(high_perf), len(medium_perf), len(low_perf)]
    performance_colors = ['green', 'orange', 'red']

    bars = axes[1, 2].bar(performance_levels, performance_counts, color=performance_colors, alpha=0.7)
    axes[1, 2].set_title('Error Distribution by Performance Level', fontsize=14, fontweight='bold')
    axes[1, 2].set_xlabel('Performance Level', fontsize=12)
    axes[1, 2].set_ylabel('Number of Classes', fontsize=12)
    axes[1, 2].grid(True, alpha=0.3)

    # Add value labels on bars
    for bar, count in zip(bars, performance_counts):
        height = bar.get_height()
        axes[1, 2].text(bar.get_x() + bar.get_width()/2., height + 0.1,
                        str(count), ha='center', va='bottom', fontweight='bold')

    plt.tight_layout()

    # Save visualization
    error_viz_path = os.path.join(DRIVE_VISUALIZATIONS_DIR, 'stage2_error_analysis.png')
    plt.savefig(error_viz_path, dpi=300, bbox_inches='tight')
    plt.show()

    print(f"Error analysis visualizations saved to: {error_viz_path}")

    return {
        'class_errors': class_errors,
        'most_confused_pairs': most_confused,
        'high_conf_errors': high_conf_errors.sum(),
        'low_conf_errors': low_conf_errors.sum(),
        'overall_error_rate': misclassifications / total_samples,
        'correlations': {
            'support_vs_f1': support_f1_corr,
            'support_vs_error': support_error_corr
        }
    }

# Run error analysis with plots.
# Executes as soon as the cell runs; the returned summary dict (per-class
# errors, confused pairs, confidence error counts, correlations) is kept in
# `error_results`.
error_results = error_analysis_multi_class()

In [None]:
# **CONFUSION MATRIX ANALYSIS - MULTI-CLASS**

def _row_normalize_cm(matrix):
    """Return ``matrix`` with each row divided by its row sum (all-zero rows stay zero)."""
    row_sums = matrix.sum(axis=1, keepdims=True)
    return np.divide(matrix, row_sums,
                     out=np.zeros(matrix.shape, dtype=float),
                     where=row_sums != 0)


def confusion_matrix_analysis_multi_class():
    """Print, plot and save a confusion-matrix report for the test predictions.

    Reads the notebook globals ``y_test``, ``yhat_test``, ``num_classes``,
    ``idx2name``, ``DRIVE_VISUALIZATIONS_DIR`` and ``DRIVE_ANALYSIS_DIR``.

    Side effects: prints per-class TP/FP/FN/TN tables, saves two heatmaps
    (``stage2_full_confusion_matrix.png`` and
    ``stage2_focused_confusion_matrix.png``) and writes
    ``stage2_confusion_matrix_analysis.json``.

    All counts and rates stored in the result are converted to built-in
    ``int``/``float`` so that they are written to JSON as numbers.

    Returns:
        dict with the keys ``confusion_matrix``,
        ``normalized_confusion_matrix``, ``class_confusion_analysis`` (sorted
        by FP+FN, highest first), ``most_confused_pairs`` (top 20),
        ``confusion_patterns`` and ``matrix_statistics``.
    """

    print("="*60)
    print(" CONFUSION MATRIX ANALYSIS - MULTI-CLASS")
    print(" Disease Classification Confusion Analysis")
    print("="*60)

    # Calculate confusion matrix
    cm = confusion_matrix(y_test, yhat_test, labels=range(num_classes))
    total = int(cm.sum())
    correct = int(np.trace(cm))

    print("CONFUSION MATRIX OVERVIEW:")
    print(f"  Matrix shape: {cm.shape} ({num_classes} x {num_classes})")
    print(f"  Total samples: {total}")
    print(f"  Correct predictions (diagonal): {correct}")
    print(f"  Incorrect predictions (off-diagonal): {total - correct}")

    # Basic confusion matrix statistics
    print("\nBASIC STATISTICS:")
    print(f"  Overall Accuracy: {np.trace(cm) / cm.sum():.4f}")
    print(f"  Error Rate: {(cm.sum() - np.trace(cm)) / cm.sum():.4f}")

    # Per-class confusion analysis
    print("\nPER-CLASS CONFUSION ANALYSIS:")
    print("-" * 80)
    print(f"{'Class ID':<8} {'Class Name':<35} {'TP':<8} {'FP':<8} {'FN':<8} {'TN':<8} {'Precision':<12} {'Recall':<12}")
    print("-" * 80)

    class_confusion_analysis = []

    for i in range(num_classes):
        # Built-in ints keep the JSON output numeric (NumPy integers would be
        # routed through the json `default` hook and written as strings).
        tp = int(cm[i, i])              # correctly predicted as class i
        fp = int(cm[:, i].sum()) - tp   # other classes predicted as class i
        fn = int(cm[i, :].sum()) - tp   # class i predicted as something else
        tn = total - tp - fp - fn       # everything not involving class i

        # Calculate precision and recall
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0

        class_name = idx2name.get(i, f"Class_{i}")

        print(f"{i:<8} {class_name[:34]:<35} {tp:<8} {fp:<8} {fn:<8} {tn:<8} {precision:<12.4f} {recall:<12.4f}")

        class_confusion_analysis.append({
            'class_id': i,
            'class_name': class_name,
            'tp': tp,
            'fp': fp,
            'fn': fn,
            'tn': tn,
            'precision': precision,
            'recall': recall
        })

    # Most confused class pairs and per-class confusion patterns.
    # Both are built from the same non-zero off-diagonal cells.
    print("\nMOST CONFUSED CLASS PAIRS:")
    print("-" * 60)

    most_confused = []
    confusion_patterns = {}

    for i in range(num_classes):
        row_total = int(cm[i, :].sum())
        true_name = idx2name.get(i, f"Class_{i}")
        class_errors = []
        for j in range(num_classes):
            if i == j or cm[i, j] <= 0:
                continue
            count = int(cm[i, j])
            rate = count / row_total if row_total > 0 else 0.0
            predicted_name = idx2name.get(j, f"Class_{j}")
            most_confused.append({
                'true_class': i,
                'predicted_class': j,
                'count': count,
                'true_class_name': true_name,
                'predicted_class_name': predicted_name,
                'error_rate': rate
            })
            class_errors.append({
                'predicted_as': j,
                'count': count,
                'error_rate': rate,
                'predicted_class_name': predicted_name
            })

        if class_errors:
            # Sort by error count
            class_errors.sort(key=lambda x: x['count'], reverse=True)
            confusion_patterns[i] = class_errors

    # Sort by confusion count
    most_confused.sort(key=lambda x: x['count'], reverse=True)

    print("Top 15 Most Confused Class Pairs:")
    for rank, confusion in enumerate(most_confused[:15], start=1):
        print(f"  {rank:2d}. {confusion['true_class_name'][:25]:<25} → {confusion['predicted_class_name'][:25]:<25}")
        print(f"      Count: {confusion['count']} samples, Error Rate: {confusion['error_rate']:.1%}")

    # Class-wise error analysis
    print("\nCLASS-WISE ERROR ANALYSIS:")
    print("-" * 50)

    # tp+fp+fn+tn equals the total sample count for every class, so ranking by
    # (fp+fn)/total is the same as ranking by fp+fn, without dividing by zero
    # on an empty matrix.
    class_confusion_analysis.sort(key=lambda x: x['fp'] + x['fn'], reverse=True)

    print("Top 10 Classes with Highest Error Rates:")
    for rank, analysis in enumerate(class_confusion_analysis[:10], start=1):
        # This rate is (FP + FN) over ALL test samples, not over the class support.
        error_rate = (analysis['fp'] + analysis['fn']) / total if total > 0 else 0
        print(f"  {rank:2d}. {analysis['class_name'][:30]:<30}")
        print(f"      Error Rate: {error_rate:.1%}, FP: {analysis['fp']}, FN: {analysis['fn']}")

    # Confusion matrix visualization
    print("\nCREATING CONFUSION MATRIX VISUALIZATIONS...")

    # 1. Full confusion matrix heatmap
    plt.figure(figsize=(16, 14))

    # Normalize confusion matrix by true class for better visualization
    cm_normalized = _row_normalize_cm(cm)

    full_labels = [f"{i}\n{idx2name.get(i, f'Class_{i}')[:15]}..." for i in range(num_classes)]

    # Create heatmap
    sns.heatmap(cm_normalized,
                annot=True,
                fmt='.2f',
                cmap='Blues',
                xticklabels=full_labels,
                yticklabels=full_labels,
                cbar_kws={'label': 'Normalized Confusion Rate'})

    plt.title('Stage 2 - Multi-Class Disease Classification Confusion Matrix\n(Normalized by True Class)',
              fontsize=16, fontweight='bold')
    plt.xlabel('Predicted Class', fontsize=14, fontweight='bold')
    plt.ylabel('True Class', fontsize=14, fontweight='bold')
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()

    # Save full confusion matrix
    full_cm_path = os.path.join(DRIVE_VISUALIZATIONS_DIR, 'stage2_full_confusion_matrix.png')
    plt.savefig(full_cm_path, dpi=300, bbox_inches='tight')
    plt.show()

    # 2. Top confused classes heatmap (focus on most problematic classes)
    print("Creating focused confusion matrix for top confused classes...")

    # The list is already ordered by FP+FN, so the first 20 are the top error classes.
    top_class_ids = [cls['class_id'] for cls in class_confusion_analysis[:20]]

    # Extract sub-matrix for top confused classes
    cm_top = cm[np.ix_(top_class_ids, top_class_ids)]

    plt.figure(figsize=(14, 12))

    # NOTE: rows are normalized within the sub-matrix, so the values are rates
    # among the selected classes only, not the full-matrix rates.
    cm_top_normalized = _row_normalize_cm(cm_top)

    top_labels = [f"{i}\n{idx2name.get(i, f'Class_{i}')[:20]}" for i in top_class_ids]

    # Create heatmap
    sns.heatmap(cm_top_normalized,
                annot=True,
                fmt='.2f',
                cmap='Reds',
                xticklabels=top_labels,
                yticklabels=top_labels,
                cbar_kws={'label': 'Normalized Confusion Rate'})

    plt.title('Stage 2 - Top Confused Classes Confusion Matrix\n(Focus on Most Problematic Classes)',
              fontsize=16, fontweight='bold')
    plt.xlabel('Predicted Class', fontsize=14, fontweight='bold')
    plt.ylabel('True Class', fontsize=14, fontweight='bold')
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()

    # Save focused confusion matrix
    focused_cm_path = os.path.join(DRIVE_VISUALIZATIONS_DIR, 'stage2_focused_confusion_matrix.png')
    plt.savefig(focused_cm_path, dpi=300, bbox_inches='tight')
    plt.show()

    # 3. Confusion pattern analysis (collected above together with the pairs)
    print("Creating confusion pattern analysis...")

    # Save confusion analysis results
    confusion_analysis_results = {
        'confusion_matrix': cm.tolist(),
        'normalized_confusion_matrix': cm_normalized.tolist(),
        'class_confusion_analysis': class_confusion_analysis,
        'most_confused_pairs': most_confused[:20],  # Top 20
        'confusion_patterns': confusion_patterns,
        'matrix_statistics': {
            'total_samples': total,
            'correct_predictions': correct,
            'incorrect_predictions': total - correct,
            'overall_accuracy': float(np.trace(cm) / cm.sum()),
            'overall_error_rate': float((cm.sum() - np.trace(cm)) / cm.sum())
        }
    }

    # Save to file. default=str is kept only as a fallback for unexpected
    # value types; all counts and rates above are already built-in numbers.
    confusion_analysis_path = os.path.join(DRIVE_ANALYSIS_DIR, 'stage2_confusion_matrix_analysis.json')
    with open(confusion_analysis_path, 'w') as f:
        json.dump(confusion_analysis_results, f, indent=2, default=str)

    print("\nConfusion matrix analysis completed!")
    print(f"Results saved to: {confusion_analysis_path}")
    print("Visualizations saved to:")
    print(f"  - Full confusion matrix: {full_cm_path}")
    print(f"  - Focused confusion matrix: {focused_cm_path}")

    return confusion_analysis_results

# Run confusion matrix analysis.
# Executes as soon as the cell runs (plots, PNG files and a JSON report are
# written); the returned results dict is kept in `confusion_matrix_results`.
confusion_matrix_results = confusion_matrix_analysis_multi_class()

# **MODEL TESTING WITH GRAD-CAM VISUALIZATION**

In [None]:
# Grad-CAM core for Stage 2 (MobileNetV2 head rebuilt in one graph)

import os
import numpy as np
import tensorflow as tf
import cv2

# Cache to avoid rebuilding the feature+prediction model each call.
# create_stage2_gradcam stores entries under id(model) and looks them up by it.
# NOTE(review): id() values can be reused after an object is garbage-collected,
# so a newly loaded model could hit a stale entry — consider clearing this dict
# whenever the model is replaced.
_GRADCAM_FEATPRED_CACHE = {}

def _build_featpred_model_from_existing(model):
    """Build a two-output model: (backbone feature map, final head output).

    The nested layer named ``'mobilenetv2_1.00_224'`` is called on the
    original input tensor, then every layer from index 2 onward is replayed
    on its output, so both outputs belong to one graph.

    NOTE(review): assumes ``model.layers[0]`` is the input layer and
    ``model.layers[1]`` is the backbone, with the head being a purely
    sequential chain — confirm against ``model.summary()``.

    Args:
        model: The trained Keras model to wrap.

    Returns:
        A ``tf.keras.Model`` mapping the original input to
        ``[feature_map, head_output]``.
    """
    source_input = model.input
    backbone = model.get_layer('mobilenetv2_1.00_224')
    conv_features = backbone(source_input)

    # Re-apply the classification head, layer by layer, on the backbone output.
    head_output = conv_features
    for head_layer in model.layers[2:]:
        head_output = head_layer(head_output)

    return tf.keras.Model(inputs=source_input, outputs=[conv_features, head_output])

def _ensure_probabilities(scores, tol=1e-2):
    """Return `scores` as class probabilities without applying softmax twice.

    If every value is non-negative and each row already sums to ~1 the tensor
    is treated as a probability distribution and returned unchanged (cast to
    float32); otherwise it is treated as logits and passed through softmax.
    """
    scores = tf.cast(scores, tf.float32)
    row_sums = tf.reduce_sum(scores, axis=-1)
    already_probs = tf.logical_and(
        tf.reduce_all(scores >= 0.0),
        tf.reduce_all(tf.abs(row_sums - 1.0) <= tol),
    )
    if bool(already_probs):
        return scores
    return tf.nn.softmax(scores, axis=-1)


def create_stage2_gradcam(model, img_path, target_class=None, target_size=(224, 224), idx2name=None):
    """Compute a Grad-CAM heatmap and overlay for one image.

    Args:
        model: Stage 2 Keras classifier (see `_build_featpred_model_from_existing`
            for the layout it must have).
        img_path: Path of the image to explain.
        target_class: Class index whose score is explained. `None` explains the
            predicted class.
        target_size: (height, width) the image is resized to before inference.
        idx2name: Optional `{class_id: class_name}` dict used for `pred_name`.

    Returns:
        Tuple `(orig, heat_rgb, overlay, pred_cls, pred_name, pred_conf)`:
        the resized uint8 image, the colorized heatmap, the blended overlay,
        the predicted class index, the predicted class name and the predicted
        class probability. `heat_rgb` and `overlay` are `None` when no
        gradient could be computed.
    """
    # Get or build the single-graph (features, predictions) model
    key = id(model)
    featpred = _GRADCAM_FEATPRED_CACHE.get(key)
    if featpred is None:
        featpred = _build_featpred_model_from_existing(model)
        _GRADCAM_FEATPRED_CACHE[key] = featpred

    # Load + preprocess image
    img = tf.keras.preprocessing.image.load_img(img_path, target_size=target_size)
    img_arr = tf.keras.preprocessing.image.img_to_array(img)
    orig = img_arr.astype(np.uint8)
    img_proc = tf.keras.applications.mobilenet_v2.preprocess_input(img_arr.copy())
    inp = tf.convert_to_tensor(np.expand_dims(img_proc, 0), dtype=tf.float32)

    # Single forward pass under the tape: used both for the prediction and
    # for the gradient of the target score w.r.t. the conv features.
    with tf.GradientTape() as tape:
        feat_map, preds = featpred(inp, training=False)
        # Avoid a double softmax: if the head already outputs probabilities,
        # a second softmax would flatten them towards 1/num_classes.
        probs_t = _ensure_probabilities(preds)
        probs = probs_t.numpy()[0]
        pred_cls = int(np.argmax(probs))
        target = int(target_class) if target_class is not None else pred_cls
        score = probs_t[:, target]
    grads = tape.gradient(score, feat_map)

    pred_conf = float(probs[pred_cls])
    # The returned name describes the *predicted* class, matching pred_cls and
    # pred_conf, even when a different target class is being explained.
    if isinstance(idx2name, dict):
        pred_name = idx2name.get(pred_cls, f"Class_{pred_cls}")
    else:
        pred_name = f"Class_{pred_cls}"

    if grads is None:
        return orig, None, None, pred_cls, pred_name, pred_conf

    # Grad-CAM heatmap (global-average pool grads -> channel weights)
    weights = tf.reduce_mean(grads, axis=(0, 1, 2))           # (C,)
    cam = tf.reduce_sum(weights * feat_map[0], axis=-1)       # (Hf, Wf)
    cam = tf.nn.relu(cam)
    cam = cam / (tf.reduce_max(cam) + 1e-8)                   # normalize to [0, 1]
    heat = cam.numpy()

    # Resize, colorize, overlay (cv2.resize takes (width, height))
    heat_r = cv2.resize(heat, (target_size[1], target_size[0]), interpolation=cv2.INTER_CUBIC)
    focus = np.where(heat_r > 0.30, heat_r, 0.0)              # drop weak activations
    heat_rgb = cv2.applyColorMap(np.uint8(255 * focus), cv2.COLORMAP_JET)
    heat_rgb = cv2.cvtColor(heat_rgb, cv2.COLOR_BGR2RGB)

    # Blend the heatmap into the image only where attention is strong
    overlay = orig.copy().astype(np.float32)
    mask = focus > 0.40
    if np.any(mask):
        overlay[mask] = 0.6 * heat_rgb[mask] + 0.4 * orig[mask]
    overlay = np.clip(overlay, 0, 255).astype(np.uint8)

    return orig, heat_rgb, overlay, pred_cls, pred_name, pred_conf

def build_idx2name_from_df(df):
    """Return a {class_id: class_name} dict built from `df`.

    Uses the 'encoded_class_id' and 'class_name' columns; if either column is
    missing an empty dict is returned. Duplicate (id, name) pairs are collapsed
    and the pairs are ordered by class id before the mapping is built.
    """
    id_col, name_col = 'encoded_class_id', 'class_name'
    if any(col not in df.columns for col in (id_col, name_col)):
        return {}

    pairs = df[[id_col, name_col]].drop_duplicates()
    pairs = pairs.sort_values(id_col)
    return dict(zip(pairs[id_col].tolist(), pairs[name_col].tolist()))

In [None]:
# Random Grad-CAM panel visualizer (picks different samples each run unless seed is set)

import time
import matplotlib.pyplot as plt

# Output folder: keep DRIVE_GRADCAM_DIR if an earlier cell already defined it,
# otherwise fall back to the default Drive location; create it if missing
DRIVE_GRADCAM_DIR = globals().get(
    'DRIVE_GRADCAM_DIR',
    '/content/drive/MyDrive/Stage2_Enhanced_GradCAM',
)
os.makedirs(DRIVE_GRADCAM_DIR, exist_ok=True)

def _frame_gradcam_axis(ax, border_color):
    """Hide the ticks of one panel cell but keep a colored frame around it."""
    # ax.axis('off') would also suppress the spines, so the correct/incorrect
    # border would never be drawn; hide only the ticks instead.
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(True)
        spine.set_color(border_color)
        spine.set_linewidth(3)


def visualize_and_save_stage2_gradcam(
    model,
    test_df,
    idx2name=None,
    num_samples=6,
    target_mode='true',    # 'true' -> explain true class, 'pred' -> explain predicted class
    seed=None,             # None => different each run; set int for reproducibility
    save_individual_overlays=True
):
    """Plot and save a Grad-CAM panel for one random image from each of several classes.

    Each row shows the original image, the Grad-CAM heatmap and the overlay.
    Rows are framed green when the prediction matches the true class and red
    otherwise. The panel (and optionally each overlay) is written to
    DRIVE_GRADCAM_DIR.

    Args:
        model: Stage 2 Keras classifier passed on to `create_stage2_gradcam`.
        test_df: DataFrame with 'encoded_class_id' and 'image_path_absolute'
            columns (and 'class_name' when `idx2name` is not given).
        idx2name: Optional `{class_id: class_name}` dict; built from `test_df`
            when `None`.
        num_samples: Maximum number of distinct classes (rows) to show.
        target_mode: 'true' explains the true class; any other value explains
            the predicted class.
        seed: Integer for a reproducible selection, `None` for a random one.
        save_individual_overlays: Also save each overlay as its own PNG.

    Returns:
        `(panel_path, saved_overlay_paths)`. When no sample could be selected
        nothing is plotted and `(None, [])` is returned.
    """
    # Build idx2name if not provided
    if idx2name is None:
        idx2name = build_idx2name_from_df(test_df)

    rng = np.random.default_rng(None if seed is None else int(seed))

    # Pick distinct classes at random
    unique_classes = test_df['encoded_class_id'].unique()
    n_pick = min(num_samples, len(unique_classes))
    chosen = rng.choice(unique_classes, size=n_pick, replace=False) if n_pick > 0 else []

    # Pick one random image per chosen class
    samples = []
    for cls in chosen:
        sub = test_df[test_df['encoded_class_id'] == cls]
        if len(sub) == 0:
            continue
        rand_idx = int(rng.integers(0, len(sub)))
        samples.append(sub.iloc[rand_idx])

    rows = len(samples)
    if rows == 0:
        # plt.subplots(0, 3) would raise; report and bail out instead.
        print("No samples available for Grad-CAM visualization; nothing was saved.")
        return None, []

    # squeeze=False keeps `axes` two-dimensional even for a single row
    fig, axes = plt.subplots(rows, 3, figsize=(18, 4 * rows), squeeze=False)

    explain_true = (target_mode == 'true')
    saved_overlay_paths = []

    for i, sample in enumerate(samples):
        img_path = sample['image_path_absolute']
        true_id = int(sample['encoded_class_id'])
        true_name = idx2name.get(true_id, f"Class_{true_id}")

        original, heatmap, overlay, pred_id, _, pred_conf = create_stage2_gradcam(
            model=model,
            img_path=img_path,
            target_class=true_id if explain_true else None,
            target_size=(224, 224),
            idx2name=idx2name
        )
        # Always label the "Pred:" title with the predicted class, regardless
        # of which class the heatmap explains.
        pred_name = idx2name.get(pred_id, f"Class_{pred_id}")

        is_correct = (pred_id == true_id)
        border_color = 'green' if is_correct else 'red'

        axes[i, 0].imshow(original.astype(np.uint8))
        axes[i, 0].set_title(f"Original\nTrue: {true_name[:28]}", fontsize=11, fontweight='bold')

        if heatmap is not None:
            axes[i, 1].imshow(heatmap)
            axes[i, 1].set_title("Disease Attention Map\nRed = High Attention", fontsize=11, fontweight='bold')
        else:
            axes[i, 1].text(0.5, 0.5, 'Grad-CAM Unavailable', ha='center', va='center', fontsize=12, color='red', fontweight='bold')

        axes[i, 2].imshow((overlay if overlay is not None else original).astype(np.uint8))
        axes[i, 2].set_title(f"Pred: {pred_name[:28]}\nConf: {pred_conf:.1%}", fontsize=11, fontweight='bold')

        for col in range(3):
            _frame_gradcam_axis(axes[i, col], border_color)

        if save_individual_overlays and overlay is not None:
            base = os.path.splitext(os.path.basename(img_path))[0]
            tag = 'true' if explain_true else 'pred'
            overlay_name = f"{base}_overlay_{tag}_cls{true_id}_pred{pred_id}.png"
            overlay_path = os.path.join(DRIVE_GRADCAM_DIR, overlay_name)
            plt.imsave(overlay_path, overlay.astype(np.uint8))
            saved_overlay_paths.append(overlay_path)

    mode_label = 'TrueClass' if explain_true else 'PredClass'
    fig.suptitle(f"Stage 2 Grad-CAM ({mode_label}) – {rows} Classes", fontsize=16, fontweight='bold')
    fig.tight_layout()
    fig.subplots_adjust(top=0.90)   # leave room for the suptitle

    # Timestamp keeps panels from successive runs from overwriting each other
    ts = int(time.time())
    panel_name = f"stage2_gradcam_panel_{mode_label}_{rows}rows_{ts}.png"
    panel_path = os.path.join(DRIVE_GRADCAM_DIR, panel_name)
    fig.savefig(panel_path, dpi=300, bbox_inches='tight', facecolor='white')
    plt.show()
    plt.close(fig)                  # release the figure's memory

    print(f"Saved panel: {panel_path}")
    if saved_overlay_paths:
        print(f"Saved {len(saved_overlay_paths)} overlay(s) to: {DRIVE_GRADCAM_DIR}")
    return panel_path, saved_overlay_paths

In [None]:
# Prerequisites from earlier cells: `model` (the loaded Stage 2 model) and
# `test_df` (the Stage 2 test split).

# Build the {class_id: class_name} map once up front. This is optional: the
# visualizer builds it itself when idx2name is None.
idx2name = build_idx2name_from_df(test_df)

# Render and save a Grad-CAM panel; with seed=None a different random set of
# samples is drawn on every run.
panel_path, overlay_paths = visualize_and_save_stage2_gradcam(
    model=model,
    test_df=test_df,
    idx2name=idx2name,
    num_samples=6,
    target_mode='true',          # 'true' explains the true class; use 'pred' to explain the predicted class
    seed=None,                   # None => different random selection each run; set an int to reproduce
    save_individual_overlays=True
)
print(panel_path)

In [None]:
# **COMPREHENSIVE RESULTS SUMMARY AND VISUALIZATION**

def create_comprehensive_summary():
    """Collect the Stage 2 evaluation results, print them and save them as JSON.

    Reads the results produced by earlier cells from module-level names
    (BEST_MODEL_PATH, num_classes, y_test, yhat_test, acc_test, acc_top3,
    acc_top5, bootstrap_results, cv_results, error_results, stats_results),
    prints a textual report and writes the combined dict to
    'stage2_comprehensive_evaluation_summary.json' under DRIVE_ANALYSIS_DIR.

    Returns:
        The summary dict that was written to disk.
    """
    banner = "=" * 60
    print(banner)
    print(" COMPREHENSIVE EVALUATION SUMMARY - STAGE 2")
    print(banner)

    # Macro-averaged F1 on the test set (third element of the sklearn tuple)
    macro_f1 = precision_recall_fscore_support(
        y_test, yhat_test, average='macro', zero_division=0
    )[2]

    # Per-fold cross-validation scores, gathered once and reused below
    fold_accuracies = [fold['accuracy'] for fold in cv_results]
    fold_f1_scores = [fold['f1_score'] for fold in cv_results]

    performance = {
        'accuracy': acc_test,
        'top3_accuracy': acc_top3,
        'top5_accuracy': acc_top5,
        'f1_macro': macro_f1,
    }
    cross_validation = {
        'mean_accuracy': np.mean(fold_accuracies),
        'std_accuracy': np.std(fold_accuracies),
        'mean_f1': np.mean(fold_f1_scores),
        'std_f1': np.std(fold_f1_scores),
    }

    # Assemble the full report (key order is kept for the JSON output)
    summary = {
        'model_info': {
            'model_path': BEST_MODEL_PATH,
            'num_classes': num_classes,
            'total_test_samples': len(y_test),
        },
        'performance_metrics': performance,
        'bootstrap_results': bootstrap_results,
        'cross_validation': cross_validation,
        'error_analysis': error_results,
        'statistical_analysis': stats_results,
    }

    # Headline metrics
    print("MODEL PERFORMANCE SUMMARY:")
    print(f"  Test Accuracy: {performance['accuracy']:.4f}")
    print(f"  Top-3 Accuracy: {performance['top3_accuracy']:.4f}")
    print(f"  Top-5 Accuracy: {performance['top5_accuracy']:.4f}")
    print(f"  F1-Score (Macro): {performance['f1_macro']:.4f}")

    # Bootstrap means with their confidence-interval bounds
    print("\nBOOTSTRAP CONFIDENCE INTERVALS (95%):")
    for label, metric_key in (("Accuracy", 'accuracy'), ("F1-Score", 'f1_score')):
        entry = bootstrap_results[metric_key]
        lower = entry['ci'][0]
        upper = entry['ci'][1]
        print(f"  {label}: {entry['mean']:.4f} [{lower:.4f}, {upper:.4f}]")

    # Cross-validation mean ± std
    print("\nCROSS-VALIDATION RESULTS:")
    print(f"  Accuracy: {cross_validation['mean_accuracy']:.4f} ± {cross_validation['std_accuracy']:.4f}")
    print(f"  F1-Score: {cross_validation['mean_f1']:.4f} ± {cross_validation['std_f1']:.4f}")

    # Error breakdown
    print("\nERROR ANALYSIS:")
    print(f"  Overall Error Rate: {error_results['overall_error_rate']:.1%}")
    print(f"  High Confidence Errors: {error_results['high_conf_errors']}")
    print(f"  Low Confidence Errors: {error_results['low_conf_errors']}")

    # Persist the report; default=str stringifies anything json cannot encode
    summary_path = os.path.join(DRIVE_ANALYSIS_DIR, 'stage2_comprehensive_evaluation_summary.json')
    with open(summary_path, 'w') as f:
        json.dump(summary, f, indent=2, default=str)

    print(f"\nComprehensive summary saved to: {summary_path}")

    return summary

# Build, print and save the comprehensive summary; keep the returned dict
final_summary = create_comprehensive_summary()

# Closing banner for the Stage 2 evaluation
print("\n" + "="*60)
print(" STAGE 2 COMPREHENSIVE EVALUATION COMPLETE!")
print("="*60)
print("All evaluation metrics, visualizations, and analyses have been completed.")
