In [None]:
import matplotlib.pyplot as plt
import cv2
import pandas as pd
import tensorflow as tf
import numpy as np
import os
from sklearn.model_selection import train_test_split
from tensorflow.keras.preprocessing import image
from tensorflow.keras import layers, models, callbacks
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score, roc_auc_score,
    average_precision_score, confusion_matrix, cohen_kappa_score,
    precision_recall_curve, roc_curve, classification_report
)
import time
import psutil
from tensorflow.keras.applications import DenseNet121
import gc

# Seed NumPy's and TensorFlow's global random generators with the same value
# so weight initialisation and other global-RNG draws repeat between runs.
np.random.seed(7)
tf.random.set_seed(7)

# CUDA/GPU Optimization Configuration
def setup_gpu_optimization():
    """Configure TensorFlow for GPU training when at least one GPU is visible.

    On success this enables memory growth on every GPU, sets the global Keras
    mixed-precision policy to ``mixed_float16`` and turns on XLA JIT
    compilation.

    Returns:
        ``(True, gpus)`` when a GPU was found and configured, where ``gpus`` is
        the list of physical GPU devices; ``(False, [])`` when no GPU is
        visible or TensorFlow rejected the configuration with ``RuntimeError``
        (e.g. because the GPUs were already initialised).
    """
    print("Setting up GPU/CUDA optimizations...")

    gpus = tf.config.list_physical_devices('GPU')
    if not gpus:
        print("⚠ No GPU found, using CPU")
        return False, []

    try:
        # Memory growth must be configured before any GPU is initialised;
        # otherwise TensorFlow raises RuntimeError. It is done first so a
        # failure here happens before any global policy is changed.
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)

        # Optional cap (uncomment to limit GPU memory):
        # tf.config.set_logical_device_configuration(
        #     gpus[0], [tf.config.LogicalDeviceConfiguration(memory_limit=8192)])

        print(f"✓ Found {len(gpus)} GPU(s)")
        print(f"✓ GPU devices: {[gpu.name for gpu in gpus]}")

        # Mixed precision: float16 compute with float32 variables.
        policy = tf.keras.mixed_precision.Policy('mixed_float16')
        tf.keras.mixed_precision.set_global_policy(policy)
        print("✓ Mixed precision (float16) enabled for faster training")

        # XLA JIT compilation of TensorFlow graphs.
        tf.config.optimizer.set_jit(True)
        print("✓ XLA JIT compilation enabled")

        # The former second set_memory_growth(gpus[0], True) call was removed:
        # it duplicated the loop above and, had it raised, this function would
        # have reported "no GPU" after mixed precision/XLA were already on.
        return True, gpus

    except RuntimeError as e:
        print(f"GPU setup error: {e}")
        return False, []

# Initialize GPU: configures TensorFlow (memory growth, mixed precision, XLA)
# when a GPU is present. HAS_GPU / GPU_DEVICES are read by the rest of the
# file to choose devices, batch size and optimizer settings.
HAS_GPU, GPU_DEVICES = setup_gpu_optimization()

# Enhanced Configuration
IMG_SIZE = 224  # square input resolution for the images and their masks
BATCH_SIZE = 64 if HAS_GPU else 16  # Larger batch size for GPU
EPOCHS = 50  # maximum epochs per training phase; EarlyStopping may end a phase sooner
PREFETCH_BUFFER = tf.data.AUTOTUNE  # let tf.data tune prefetch depth
NUM_PARALLEL_CALLS = tf.data.AUTOTUNE  # let tf.data tune map parallelism

# NOTE(review): the original comment here claimed these settings "enable
# tensor core usage". Tensor-core use comes from the mixed_float16 policy set
# in setup_gpu_optimization(), not from these variables. To my knowledge
# TF_ENABLE_ONEDNN_OPTS controls oneDNN (CPU) kernels and is read when
# TensorFlow is imported, and TF_GPU_THREAD_MODE is read when GPU devices are
# created. `tensorflow` is already imported (and GPUs configured) before this
# point, so these assignments likely have no effect — set them before
# `import tensorflow` if they are wanted. Verify against your TF version.
if HAS_GPU:
    os.environ['TF_ENABLE_ONEDNN_OPTS'] = '1'
    os.environ['TF_GPU_THREAD_MODE'] = 'gpu_private'

class RetinaMaskLayer(tf.keras.layers.Layer):
    """Multiply a feature map by a retina mask resized to the feature grid.

    Called on a pair ``[feature_maps, mask]``:

    * ``feature_maps``: rank-4 ``(batch, height, width, channels)`` tensor.
    * ``mask``: per-pixel mask, usually at input-image resolution, given as
      ``(batch, H, W)``, ``(batch, H, W, 1)`` or ``(batch, H, W, 1, 1)``.

    The mask is resized (bilinear, antialiased) to ``(height, width)`` and
    broadcast over all channels, so features outside the mask are scaled
    toward zero.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, inputs):
        # No @tf.function here: Keras already traces `call` when the model is
        # compiled, and the former `experimental_relax_shapes` argument is
        # deprecated. No explicit tf.device scope either — placement follows
        # the caller's device context.
        feature_maps, mask = inputs

        # Normalise the mask to rank 4 with one trailing channel.
        if mask.shape.rank == 5:
            mask = tf.squeeze(mask, axis=-1)
        elif mask.shape.rank == 3:
            mask = tf.expand_dims(mask, axis=-1)

        mask_resized = tf.image.resize(
            mask,
            tf.shape(feature_maps)[1:3],
            method='bilinear',
            antialias=True,
        )
        # tf.image.resize returns float32 for bilinear resizing. Under the
        # global mixed_float16 policy the feature maps are float16, and
        # multiplying mismatched dtypes raises, so match the feature dtype.
        mask_resized = tf.cast(mask_resized, feature_maps.dtype)

        # (batch, h, w, 1) broadcasts across the channel axis.
        return feature_maps * mask_resized

    def get_config(self):
        # No constructor arguments beyond the base Layer ones.
        return super().get_config()

@tf.function
def create_retina_mask_gpu(image_tensor):
    """Build a binary retina mask for a single image using TensorFlow ops.

    Pipeline: grayscale -> 5x5 Gaussian blur (sigma 1) -> threshold at 20 on
    a 0-255 scale -> morphological closing, then opening, with a flat 7x7
    structuring element. If that mask covers 10% of the image or less, a
    centred disc of radius ``min(h, w) // 3`` is returned instead.

    Args:
        image_tensor: One image, either ``(H, W, 3)`` RGB or a single-channel
            image that squeezes to ``(H, W)``. If its maximum value is
            <= 1.0 it is treated as ``[0, 1]`` data and rescaled to 0-255.

    Returns:
        float32 ``(H, W)`` tensor with values in {0, 1}.
    """
    # No explicit device scope: this function is also called from tf.data
    # map functions, which run on the host, so placement is left to TF.
    if len(image_tensor.shape) == 3 and image_tensor.shape[-1] == 3:
        # Standard RGB -> luma weights; cast first so integer images work too.
        weights = tf.constant([0.299, 0.587, 0.114], dtype=tf.float32)
        gray = tf.reduce_sum(tf.cast(image_tensor, tf.float32) * weights, axis=-1)
    else:
        gray = tf.squeeze(image_tensor)

    # Bring to a 0-255 scale so the fixed threshold below is meaningful.
    gray = tf.cast(gray, tf.float32)
    if tf.reduce_max(gray) <= 1.0:
        gray = gray * 255.0

    # Separable Gaussian blur: one vertical and one horizontal 1-D pass.
    kernel_size = 5
    sigma = 1.0
    taps = tf.range(-kernel_size // 2 + 1, kernel_size // 2 + 1, dtype=tf.float32)
    kernel_1d = tf.exp(-0.5 * tf.square(taps / sigma))
    kernel_1d = kernel_1d / tf.reduce_sum(kernel_1d)
    kernel_vertical = tf.reshape(kernel_1d, [kernel_size, 1, 1, 1])
    kernel_horizontal = tf.reshape(kernel_1d, [1, kernel_size, 1, 1])

    gray_4d = gray[tf.newaxis, :, :, tf.newaxis]  # (1, H, W, 1) for conv2d
    blurred = tf.nn.conv2d(gray_4d, kernel_vertical, strides=[1, 1, 1, 1], padding='SAME')
    blurred = tf.nn.conv2d(blurred, kernel_horizontal, strides=[1, 1, 1, 1], padding='SAME')

    # Threshold: pixels brighter than the dark background become 1.
    binary_mask = tf.cast(blurred > 20.0, tf.float32)  # still (1, H, W, 1)

    # tf.nn.dilation2d / erosion2d take a 3-D (height, width, depth) filter
    # and add/subtract its values, so an all-zero filter gives plain max/min
    # filtering, i.e. binary morphology with a flat 7x7 structuring element.
    flat_se = tf.zeros([7, 7, 1], dtype=tf.float32)

    def _dilate(x):
        return tf.nn.dilation2d(
            x, flat_se, strides=[1, 1, 1, 1], padding='SAME',
            data_format='NHWC', dilations=[1, 1, 1, 1])

    def _erode(x):
        return tf.nn.erosion2d(
            x, flat_se, strides=[1, 1, 1, 1], padding='SAME',
            data_format='NHWC', dilations=[1, 1, 1, 1])

    closed = _erode(_dilate(binary_mask))   # closing fills small holes
    opened = _dilate(_erode(closed))        # opening removes small specks
    mask = tf.squeeze(opened, [0, 3])

    # Circular fallback mask centred in the image.
    h, w = tf.shape(mask)[0], tf.shape(mask)[1]
    center_y, center_x = h // 2, w // 2
    radius = tf.minimum(h, w) // 3

    y_grid, x_grid = tf.meshgrid(
        tf.range(h, dtype=tf.float32),
        tf.range(w, dtype=tf.float32),
        indexing='ij',
    )
    distances = tf.sqrt(
        tf.square(y_grid - tf.cast(center_y, tf.float32))
        + tf.square(x_grid - tf.cast(center_x, tf.float32))
    )
    circular_mask = tf.cast(distances <= tf.cast(radius, tf.float32), tf.float32)

    # Keep the processed mask only if it covers more than 10% of the image.
    mask_ratio = tf.reduce_sum(mask) / tf.cast(h * w, tf.float32)
    return tf.cond(
        mask_ratio > 0.1,
        lambda: mask,
        lambda: circular_mask,
    )

# GPU-optimized data pipeline
def create_optimized_dataset_with_masks(dataframe, batch_size, shuffle=False, cache=True):
    """Build a batched tf.data pipeline yielding ``({'image', 'mask'}, label)``.

    For each row, the file at ``filepath`` is decoded as RGB, resized to
    ``IMG_SIZE`` x ``IMG_SIZE`` (bilinear, antialiased) and passed through
    DenseNet's ``preprocess_input``. The mask comes from
    ``create_retina_mask_gpu`` applied to the resized image before
    preprocessing.

    Args:
        dataframe: DataFrame with ``filepath`` (str) and ``label`` (int) columns.
        batch_size: Batch size; the final partial batch is kept.
        shuffle: Shuffle examples with a buffer of up to 10 000, reshuffled
            every epoch.
        cache: Cache the preprocessed examples in memory after the first pass.
            Each example is a float32 image plus mask (about 0.8 MB at
            224x224), so this needs roughly ``len(dataframe) * 0.8`` MB of RAM.

    Returns:
        A prefetched ``tf.data.Dataset``. Because the map runs with
        ``deterministic=False``, element order is not guaranteed even when
        ``shuffle`` is False (each label still travels with its image).
    """

    def load_and_preprocess_with_mask(path, label):
        # No explicit GPU device scope: tf.data map functions run on the host.
        image_raw = tf.io.read_file(path)
        # decode_image handles PNG and JPEG. expand_animations=False keeps a
        # rank-3 result (first frame for GIFs) so resize works.
        image_decoded = tf.io.decode_image(image_raw, channels=3, expand_animations=False)
        image_resized = tf.image.resize(
            image_decoded,
            [IMG_SIZE, IMG_SIZE],
            method='bilinear',
            antialias=True
        )
        image_float = tf.cast(image_resized, tf.float32)

        # DenseNet preprocessing for the image branch only.
        image_preprocessed = tf.keras.applications.densenet.preprocess_input(image_float)

        # The mask is computed from the un-preprocessed 0-255 image. Its
        # static shape is lost through tf.cond, so pin it for the model input.
        mask = create_retina_mask_gpu(image_resized)
        mask = tf.ensure_shape(tf.expand_dims(mask, axis=-1), [IMG_SIZE, IMG_SIZE, 1])

        return {'image': image_preprocessed, 'mask': mask}, label

    paths = tf.constant(dataframe['filepath'].values, dtype=tf.string)
    labels = tf.constant(dataframe['label'].values, dtype=tf.int32)

    dataset = tf.data.Dataset.from_tensor_slices((paths, labels))
    dataset = dataset.map(
        load_and_preprocess_with_mask,
        num_parallel_calls=NUM_PARALLEL_CALLS,
        deterministic=False  # allow out-of-order results for throughput
    )

    if cache:
        # Cache before shuffling so each epoch reshuffles the cached data.
        dataset = dataset.cache()

    if shuffle:
        # shuffle() requires a positive buffer size, even for an empty frame.
        buffer_size = max(1, min(len(dataframe), 10000))
        dataset = dataset.shuffle(buffer_size, seed=42, reshuffle_each_iteration=True)

    dataset = dataset.batch(batch_size, drop_remainder=False)
    dataset = dataset.prefetch(PREFETCH_BUFFER)

    # tf.data.Options rejects unknown attribute names, so the spelling must be
    # exact (the previous `map_paralization` raised AttributeError).
    options = tf.data.Options()
    options.experimental_optimization.map_parallelization = True
    options.experimental_optimization.parallel_batch = True
    options.threading.private_threadpool_size = 8
    dataset = dataset.with_options(options)

    return dataset

def create_retina_focused_densenet_gpu():
    """Build and compile the two-input DenseNet121 + retina-mask classifier.

    Architecture, in construction order:
      * ``image`` input (IMG_SIZE x IMG_SIZE x 3) -> ImageNet DenseNet121 with
        no top and no pooling, frozen and called with ``training=False``.
      * ``mask`` input (IMG_SIZE x IMG_SIZE x 1) -> RetinaMaskLayer, which
        multiplies the DenseNet feature maps by the (resized) mask.
      * Channel attention (GAP -> Dense(C // 8, relu) -> Dense(C, sigmoid)),
        then spatial attention (7x7 Conv2D -> sigmoid), each applied
        multiplicatively.
      * Concatenated global average + global max pooling -> Dropout/Dense
        head -> 5-way softmax named ``predictions``.

    Compiled with Adam at learning rate 1e-4 (plus ``clipnorm=1.0`` when a
    GPU was detected), sparse categorical cross-entropy and accuracy.

    Side effects:
        Calls ``tf.keras.backend.clear_session()`` and ``gc.collect()`` first,
        which resets Keras' global state (see the Keras docs for exactly what
        is reset).

    Returns:
        ``(model, base_model)``: the compiled model and the DenseNet121
        sub-model, which the caller unfreezes later for fine-tuning.
    """
    tf.keras.backend.clear_session()
    gc.collect()  # Clear memory

    with tf.device('/GPU:0' if HAS_GPU else '/CPU:0'):
        # Input layers. The names must match the dict keys ('image', 'mask')
        # produced by the tf.data pipeline and used at prediction time.
        image_input = tf.keras.Input(shape=(IMG_SIZE, IMG_SIZE, 3), name='image')
        mask_input = tf.keras.Input(shape=(IMG_SIZE, IMG_SIZE, 1), name='mask')

        # ImageNet-pretrained DenseNet121 used as a spatial feature extractor
        # (no classifier head, no pooling -> rank-4 feature maps).
        base_model = DenseNet121(
            weights="imagenet", 
            include_top=False,
            input_shape=(IMG_SIZE, IMG_SIZE, 3),
            pooling=None
        )

        # Freeze base model initially: phase 1 trains only the new layers.
        base_model.trainable = False

        # Get feature maps. training=False is intended to keep the base's
        # BatchNormalization layers in inference mode even after the base is
        # unfrozen for fine-tuning (per the Keras transfer-learning guide).
        feature_maps = base_model(image_input, training=False)

        # Apply retina mask: suppress features outside the retina region.
        masked_features = RetinaMaskLayer()([feature_maps, mask_input])

        # Enhanced attention mechanism with spatial attention
        # Channel attention: squeeze (GAP) -> bottleneck of C // 8 units ->
        # per-channel sigmoid weights. C is the DenseNet feature depth
        # (presumably 1024 for DenseNet121 — confirm via model.summary()).
        # dtype='float32' makes these layers compute in float32 even when the
        # global mixed_float16 policy is active.
        channel_attention = tf.keras.layers.GlobalAveragePooling2D()(masked_features)
        channel_attention = tf.keras.layers.Dense(
            masked_features.shape[-1] // 8, 
            activation='relu',
            dtype='float32'
        )(channel_attention)
        channel_attention = tf.keras.layers.Dense(
            masked_features.shape[-1], 
            activation='sigmoid',
            dtype='float32'
        )(channel_attention)
        # Reshape to (1, 1, C) so it broadcasts over the spatial grid.
        channel_attention = tf.keras.layers.Reshape((1, 1, masked_features.shape[-1]))(channel_attention)

        # Apply channel attention
        channel_refined = tf.keras.layers.Multiply()([masked_features, channel_attention])

        # Spatial attention: a single-channel 7x7 sigmoid map over positions.
        spatial_attention = tf.keras.layers.Conv2D(
            1, (7, 7), 
            padding='same',
            activation='sigmoid', 
            dtype='float32',
            name='spatial_attention'
        )(channel_refined)

        # Apply spatial attention (broadcast over channels).
        attended_features = tf.keras.layers.Multiply()([channel_refined, spatial_attention])

        # Advanced pooling strategy
        # Global Average Pooling
        gap = tf.keras.layers.GlobalAveragePooling2D()(attended_features)

        # Global Max Pooling
        gmp = tf.keras.layers.GlobalMaxPooling2D()(attended_features)

        # Concatenate different pooling strategies -> vector of length 2 * C.
        x = tf.keras.layers.Concatenate()([gap, gmp])

        # Enhanced classification head: decreasing dropout with depth.
        x = tf.keras.layers.Dropout(0.4)(x)
        x = tf.keras.layers.Dense(512, activation="relu", dtype='float32')(x)
        x = tf.keras.layers.BatchNormalization()(x)
        x = tf.keras.layers.Dropout(0.3)(x)
        x = tf.keras.layers.Dense(256, activation="relu", dtype='float32')(x)
        x = tf.keras.layers.Dropout(0.2)(x)

        # Output layer: 5 DR grades. float32 keeps the softmax numerically
        # stable under mixed precision.
        outputs = tf.keras.layers.Dense(
            5, 
            activation="softmax", 
            dtype='float32',  # Ensure float32 for final output
            name='predictions'
        )(x)

        # Create model
        model = tf.keras.Model(inputs=[image_input, mask_input], outputs=outputs)

    # Optimizer: identical Adam defaults on both paths. The GPU path (where
    # setup_gpu_optimization enables mixed precision) also clips gradients.
    if HAS_GPU:
        optimizer = tf.keras.optimizers.Adam(
            learning_rate=0.0001,
            beta_1=0.9,
            beta_2=0.999,
            epsilon=1e-7,
            clipnorm=1.0  # Clip each gradient's norm to 1.0 (GPU path only)
        )
    else:
        optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001)

    # Integer labels 0-4 -> sparse categorical cross-entropy.
    model.compile(
        loss="sparse_categorical_crossentropy",
        optimizer=optimizer,
        metrics=["accuracy"],
        run_eagerly=False  # Ensure graph mode for better GPU performance
    )

    return model, base_model

# Enhanced evaluation with GPU optimization
@tf.function
def predict_batch_gpu(model, batch_data):
    """Return ``model``'s inference-mode outputs for one batch, run as a graph.

    The target device (GPU:0 when a GPU was configured, else CPU:0) is chosen
    when the function is traced.
    """
    target_device = '/GPU:0' if HAS_GPU else '/CPU:0'
    with tf.device(target_device):
        outputs = model(batch_data, training=False)
    return outputs

def evaluate_model_5_class_gpu(model, test_set, model_name="Model"):
    """Evaluate a 5-class DR model on ``test_set`` and print a metrics report.

    Args:
        model: Keras model that accepts the batch inputs yielded by
            ``test_set`` and returns ``(batch, 5)`` class probabilities.
        test_set: Iterable of ``(batch_data, labels)`` pairs, e.g. the
            dataset from ``create_optimized_dataset_with_masks``.
        model_name: Label used in the printed headings.

    Returns:
        dict with ``overall_accuracy``, ``kappa_score`` (unweighted Cohen's
        kappa), ``confusion_matrix`` (always 5x5, rows = true grade,
        columns = predicted grade), ``y_true``, ``y_pred`` and
        ``y_pred_probs``.

    Raises:
        ValueError: if ``test_set`` yields no samples.
    """
    print(f"\nEvaluating {model_name} on GPU...")

    y_true, y_pred_probs = [], []

    with tf.device('/GPU:0' if HAS_GPU else '/CPU:0'):
        for batch_num, (batch_data, labels) in enumerate(test_set):
            print(f"\rEvaluating batch {batch_num + 1}", end='')

            y_true.extend(labels.numpy())

            if HAS_GPU:
                # Graph-compiled forward pass.
                predictions = predict_batch_gpu(model, batch_data)
                y_pred_probs.extend(predictions.numpy())
            else:
                predictions = model.predict(batch_data, verbose=0)
                y_pred_probs.extend(predictions)

    print()  # New line after the in-place progress counter

    if not y_true:
        raise ValueError(f"{model_name}: test_set yielded no samples to evaluate")

    y_true = np.array(y_true)
    y_pred_probs = np.array(y_pred_probs)
    y_pred = np.argmax(y_pred_probs, axis=1)

    # Fix the label set so the report and matrix always cover all five
    # grades, even when a split lacks a class or the model never predicts
    # one. Without this, classification_report raises on a target_names
    # mismatch and the matrix shrinks below the 5-column header printed below.
    class_labels = list(range(5))
    class_names = ['No DR', 'Mild DR', 'Moderate DR', 'Severe DR', 'Proliferative DR']

    accuracy = accuracy_score(y_true, y_pred)

    print(f"\n{model_name} Results (5-Class Classification):")
    print("-" * 60)
    print(f"Overall Accuracy: {accuracy:.4f}")

    print(f"\nDetailed Classification Report:")
    print(classification_report(
        y_true, y_pred,
        labels=class_labels,
        target_names=class_names,
        digits=4,
        zero_division=0,
    ))

    conf_matrix = confusion_matrix(y_true, y_pred, labels=class_labels)
    print(f"\nConfusion Matrix:")
    print("Predicted ->")
    print("True |")
    print("     v")

    header = "    " + "".join([f"{i:>12}" for i in class_labels])
    print(header)
    for i, row in enumerate(conf_matrix):
        row_str = f"{i:>2}: " + "".join([f"{val:>12}" for val in row])
        print(row_str)

    kappa = cohen_kappa_score(y_true, y_pred)
    print(f"\nCohen's Kappa Score: {kappa:.4f}")

    return {
        'overall_accuracy': accuracy,
        'kappa_score': kappa,
        'confusion_matrix': conf_matrix,
        'y_true': y_true,
        'y_pred': y_pred,
        'y_pred_probs': y_pred_probs
    }

# GPU Memory monitoring
def monitor_gpu_memory():
    """Print, and return, current and peak memory use of ``GPU:0``.

    Best-effort: intended for use inside training callbacks, so failures are
    reported rather than raised.

    Returns:
        ``(current_mb, peak_mb)`` on success. ``None`` when no GPU is in use
        or the memory info could not be read; in the latter case the reason
        is printed.
    """
    if not HAS_GPU:
        return None
    try:
        gpu_info = tf.config.experimental.get_memory_info('GPU:0')
        current_mb = gpu_info['current'] / (1024**2)
        peak_mb = gpu_info['peak'] / (1024**2)
    except Exception as err:  # never abort training over a monitoring failure
        print(f"Could not retrieve GPU memory info: {err}")
        return None
    print(f"GPU Memory - Current: {current_mb:.1f}MB, Peak: {peak_mb:.1f}MB")
    return current_mb, peak_mb

def load_and_prepare_data(base_dir, train_csv_path, images_folder_name="images"):
    """
    Read the label CSV and attach an on-disk image path to every row.

    All images are expected to sit directly inside
    ``<base_dir>/<images_folder_name>``, named ``<image_id><ext>`` for one of
    the extensions tried below (first match wins).

    Args:
        base_dir: Dataset root directory.
        train_csv_path: CSV with an id column (``image_id``, or ``id_code`` as
            a fallback) and a grade column (``level``, or ``diagnosis``).
        images_folder_name: Sub-folder of ``base_dir`` holding the images.

    Returns:
        DataFrame restricted to rows whose image was found (index reset),
        with added ``label``, ``class_name`` and ``filepath`` columns.

    Raises:
        FileNotFoundError: the CSV file or the images directory is missing.
        ValueError: a required column is missing, or no image file was found.
    """
    print("Loading and preparing dataset with new structure...")

    if not os.path.exists(train_csv_path):
        raise FileNotFoundError(f"CSV file not found: {train_csv_path}")

    df = pd.read_csv(train_csv_path)
    print(f"Loaded CSV with {len(df)} entries")
    print(f"CSV columns: {list(df.columns)}")

    # Canonical column name -> accepted fallback name, checked in this order.
    for canonical, fallback in (('image_id', 'id_code'), ('level', 'diagnosis')):
        if canonical in df.columns:
            continue
        if fallback not in df.columns:
            raise ValueError(f"Required column '{canonical}' not found in CSV. Available columns: {list(df.columns)}")
        df[canonical] = df[fallback]
        print(f"Using '{fallback}' as '{canonical}'")

    df["label"] = df["level"]

    grade_names = {
        0: "No DR",
        1: "Mild DR",
        2: "Moderate DR",
        3: "Severe DR",
        4: "Proliferative DR",
    }
    df["class_name"] = df["label"].map(grade_names)

    images_dir = os.path.join(base_dir, images_folder_name)
    if not os.path.exists(images_dir):
        raise FileNotFoundError(f"Images directory not found: {images_dir}")
    print(f"Looking for images in: {images_dir}")

    candidate_exts = ('.png', '.jpg', '.jpeg', '.JPG', '.JPEG', '.PNG')

    def resolve_path(image_id):
        """Return the first existing ``<images_dir>/<image_id><ext>``, else None."""
        candidates = (os.path.join(images_dir, f"{image_id}{ext}") for ext in candidate_exts)
        return next((p for p in candidates if os.path.exists(p)), None)

    print("Mapping image IDs to file paths...")
    df["filepath"] = df["image_id"].apply(resolve_path)

    not_found = df["filepath"].isnull()
    n_missing = not_found.sum()
    if n_missing > 0:
        print(f"Warning: {n_missing} images not found")
        print("Sample missing image IDs:")
        print(df.loc[not_found, "image_id"].head(10).tolist())

        # A peek at the real directory contents helps diagnose naming issues.
        print(f"Sample files in directory: {os.listdir(images_dir)[:10]}")

        df = df.dropna(subset=['filepath']).reset_index(drop=True)
        print(f"Keeping {len(df)} images with valid file paths")

    # Second existence pass over the resolved paths.
    on_disk = df["filepath"].apply(os.path.exists)
    print(f"Verified {on_disk.sum()} out of {len(df)} image files exist")
    df = df[on_disk].reset_index(drop=True)

    if len(df) == 0:
        raise ValueError("No valid image files found! Please check your directory structure and file paths.")

    print(f"Final dataset shape: {df.shape}")
    print("Class distribution:")
    total = len(df)
    for grade in sorted(df['label'].unique()):
        n_grade = (df['label'] == grade).sum()
        name = grade_names.get(grade, f"Unknown ({grade})")
        print(f"  {grade} ({name}): {n_grade} images ({(n_grade / total) * 100:.1f}%)")

    return df

# Updated main training pipeline
def main_training_pipeline(base_dir=None, images_folder_name="images"):
    """Run the full two-phase training, evaluation and saving pipeline.

    Steps:
      1. Load ``<base_dir>/train.csv`` and locate images in
         ``<base_dir>/<images_folder_name>``.
      2. Make a stratified 60/20/20 train/validation/test split on ``label``.
      3. Build cached tf.data pipelines.
      4. Phase 1: train the new head on a frozen DenseNet121, then evaluate
         on the test split.
      5. Phase 2: unfreeze the last ~20% of the base layers, fine-tune at
         learning rate 5e-6, evaluate again and save the final model.

    Args:
        base_dir: Dataset root. ``None`` (default) uses the original
            hard-coded path, so existing zero-argument calls are unchanged.
        images_folder_name: Folder under ``base_dir`` holding all images.

    Returns:
        ``(model, train_df, val_df, test_df, history1, history2, results1,
        results2)``, or ``None`` if the dataset could not be loaded.
    """
    print("="*60)
    print("CUDA-OPTIMIZED DIABETIC RETINOPATHY CLASSIFICATION")
    print("="*60)

    if HAS_GPU:
        print(f"✓ Running on GPU with {len(GPU_DEVICES)} device(s)")
        print(f"✓ Mixed precision: Enabled")
        print(f"✓ XLA compilation: Enabled")
        print(f"✓ Optimized batch size: {BATCH_SIZE}")
    else:
        print("⚠ Running on CPU")

    if base_dir is None:
        base_dir = r"C:\Users\nande\OneDrive\Desktop\Diabetic_Retinopathy\DataBase"
    train_csv_path = os.path.join(base_dir, "train.csv")  # needs image_id/level (or id_code/diagnosis)

    try:
        df = load_and_prepare_data(base_dir, train_csv_path, images_folder_name)
    except Exception as e:
        print(f"Error loading data: {e}")
        print("\nPlease ensure:")
        print("1. CSV file exists and contains 'image_id' and 'level' columns")
        print("2. Images folder contains the actual image files")
        print("3. File paths and names are correct")
        return None

    # 60% train, then split the remaining 40% evenly into validation and test.
    print("\nSplitting data into train/validation/test sets...")
    train_df, temp_df = train_test_split(df, test_size=0.4, stratify=df["label"], random_state=42)
    val_df, test_df = train_test_split(temp_df, test_size=0.5, stratify=temp_df["label"], random_state=42)

    print(f"Train: {len(train_df)}, Validation: {len(val_df)}, Test: {len(test_df)}")

    print("Creating GPU-optimized datasets...")
    train_set = create_optimized_dataset_with_masks(train_df, BATCH_SIZE, shuffle=True, cache=True)
    valid_set = create_optimized_dataset_with_masks(val_df, BATCH_SIZE, cache=True)
    test_set = create_optimized_dataset_with_masks(test_df, BATCH_SIZE, cache=True)

    print("Creating GPU-optimized DenseNet model...")
    model, base_model = create_retina_focused_densenet_gpu()

    print(f"\nModel Parameters:")
    print(f"Total params: {model.count_params():,}")
    print(f"Trainable params: {sum([tf.keras.backend.count_params(w) for w in model.trainable_weights]):,}")

    monitor_gpu_memory()

    # Note: no `save_format` argument. The .keras extension already selects
    # the format, and Keras 3's ModelCheckpoint does not accept the argument.
    callbacks_list = [
        tf.keras.callbacks.EarlyStopping(
            patience=10,
            restore_best_weights=True,
            verbose=1,
            monitor='val_accuracy',
            mode='max'
        ),
        tf.keras.callbacks.ReduceLROnPlateau(
            factor=0.5,
            patience=5,
            verbose=1,
            monitor='val_accuracy',
            mode='max',
            min_lr=1e-7
        ),
        tf.keras.callbacks.ModelCheckpoint(
            'best_retina_model_5class_gpu.keras',
            save_best_only=True,
            verbose=1,
            monitor='val_accuracy',
            mode='max'
        ),
        tf.keras.callbacks.LambdaCallback(
            on_epoch_end=lambda epoch, logs: monitor_gpu_memory()
        )
    ]

    # Phase 1: only the new layers are trainable (base frozen at build time).
    print(f"\nPhase 1: Training with frozen base model...")
    start_time = time.time()

    # `workers` / `use_multiprocessing` are not passed: tf.keras ignores them
    # for tf.data inputs, and Keras 3's fit() rejects them.
    with tf.device('/GPU:0' if HAS_GPU else '/CPU:0'):
        history1 = model.fit(
            train_set,
            validation_data=valid_set,
            epochs=EPOCHS,
            callbacks=callbacks_list,
            verbose=1
        )

    phase1_time = time.time() - start_time
    print(f"Phase 1 training time: {phase1_time:.2f} seconds")

    # Only garbage-collect here. The former clear_session() reset Keras'
    # global state while this same model was about to be recompiled and
    # trained again, and it could not free anything the model still
    # references.
    gc.collect()

    results1 = evaluate_model_5_class_gpu(model, test_set, "DenseNet + Retina Mask (Frozen)")

    # Phase 2: unfreeze only the last ~20% of base layers for fine-tuning.
    print(f"\nPhase 2: Fine-tuning with GPU optimization...")

    base_model.trainable = True
    n_layers = len(base_model.layers)
    freeze_layers = int(0.8 * n_layers)  # Freeze more layers for stability

    for layer in base_model.layers[:freeze_layers]:
        layer.trainable = False

    unfrozen_layers = n_layers - freeze_layers
    print(f"Unfroze {unfrozen_layers} out of {n_layers} base model layers")

    # Recompiling is required for the trainable changes to take effect.
    optimizer = tf.keras.optimizers.Adam(
        learning_rate=5e-6,  # Lower learning rate for fine-tuning
        clipnorm=1.0 if HAS_GPU else None  # Gradient clipping for stability
    )

    model.compile(
        loss="sparse_categorical_crossentropy",
        optimizer=optimizer,
        metrics=["accuracy"],
        run_eagerly=False
    )

    # ReduceLROnPlateau keeps its default monitor (val_loss) in this phase.
    callbacks_list_ft = [
        tf.keras.callbacks.EarlyStopping(
            patience=12,
            restore_best_weights=True,
            verbose=1,
            monitor='val_accuracy',
            mode='max'
        ),
        tf.keras.callbacks.ReduceLROnPlateau(
            factor=0.3,
            patience=6,
            verbose=1,
            min_lr=1e-8
        ),
        tf.keras.callbacks.ModelCheckpoint(
            'best_retina_model_5class_gpu_final.keras',
            save_best_only=True,
            verbose=1,
            monitor='val_accuracy',
            mode='max'
        ),
        tf.keras.callbacks.LambdaCallback(
            on_epoch_end=lambda epoch, logs: monitor_gpu_memory()
        )
    ]

    start_time = time.time()

    with tf.device('/GPU:0' if HAS_GPU else '/CPU:0'):
        history2 = model.fit(
            train_set,
            validation_data=valid_set,
            epochs=EPOCHS,
            callbacks=callbacks_list_ft,
            verbose=1
        )

    phase2_time = time.time() - start_time
    print(f"Phase 2 training time: {phase2_time:.2f} seconds")

    results2 = evaluate_model_5_class_gpu(model, test_set, "DenseNet + Retina Mask (Fine-tuned)")

    # The .keras extension selects the native Keras format.
    model_path = 'retina_focused_dr_classifier_5class_gpu.keras'
    print(f"\nSaving final model to: {model_path}")
    model.save(model_path)
    print(f"✓ Model saved successfully!")

    total_time = phase1_time + phase2_time
    print(f"\n" + "="*60)
    print("CUDA-OPTIMIZED TRAINING SUMMARY")
    print("="*60)
    if HAS_GPU:
        print(f"GPU Acceleration: ENABLED")
        print(f"Mixed Precision: ENABLED")
        print(f"XLA Compilation: ENABLED")
        monitor_gpu_memory()
    print(f"Phase 1 time: {phase1_time:.2f} seconds")
    print(f"Phase 2 time: {phase2_time:.2f} seconds")
    print(f"Total training time: {total_time:.2f} seconds")
    print(f"Final model accuracy: {results2['overall_accuracy']:.4f}")
    print(f"Final model kappa score: {results2['kappa_score']:.4f}")
    print(f"Batch size used: {BATCH_SIZE}")
    print(f"Model saved as: {model_path}")

    # Both phases always produce results at this point.
    improvement = results2['overall_accuracy'] - results1['overall_accuracy']
    print(f"Accuracy improvement from fine-tuning: {improvement:.4f}")

    print("\n✓ CUDA-optimized training completed successfully!")

    # The model is returned to the caller, so only unreferenced objects can be
    # reclaimed here. clear_session() is deliberately not called.
    gc.collect()

    print(f"\nTo use the trained model for prediction:")
    print(f"model = tf.keras.models.load_model('{model_path}', custom_objects={{'RetinaMaskLayer': RetinaMaskLayer}})")

    return model, train_df, val_df, test_df, history1, history2, results1, results2

def predict_single_image_gpu(model, img_path, img_size=(224, 224)):
    """Classify one image and plot it with its mask and class probabilities.

    Preprocessing follows ``create_optimized_dataset_with_masks``. The
    full-resolution RGB image is resized with
    ``tf.image.resize(method='bilinear', antialias=True)``, the retina mask is
    computed from that resized image, and DenseNet's ``preprocess_input`` is
    applied. The previous ``cv2.resize`` (no antialiasing when downscaling)
    produced inputs that differed from what the model was trained on.

    Args:
        model: Trained two-input model taking ``{'image', 'mask'}``.
        img_path: Path to an image readable by ``cv2.imread``.
        img_size: Target size as ``(width, height)`` (OpenCV order, as in the
            original ``cv2.resize`` call). Should match the training size.

    Returns:
        ``(pred_class, confidence, probabilities)``, or ``(None, None, None)``
        if the image could not be loaded.
    """
    img_bgr = cv2.imread(img_path)
    if img_bgr is None:
        print(f"Error: Could not load image from {img_path}")
        return None, None, None

    full_res_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
    target_hw = [img_size[1], img_size[0]]  # tf.image.resize takes (height, width)

    with tf.device('/GPU:0' if HAS_GPU else '/CPU:0'):
        img_tensor = tf.image.resize(
            tf.constant(full_res_rgb, dtype=tf.float32),
            target_hw,
            method='bilinear',
            antialias=True,
        )

        mask = create_retina_mask_gpu(img_tensor)

        # Tensor input, so preprocess_input returns a new tensor rather than
        # modifying img_tensor in place.
        img_array = tf.keras.applications.densenet.preprocess_input(img_tensor)
        img_array = tf.expand_dims(img_array, axis=0)
        mask_array = tf.expand_dims(tf.expand_dims(mask, axis=-1), axis=0)

        pred = model({'image': img_array, 'mask': mask_array}, training=False)
        pred = pred.numpy()

    # uint8 copy of the exact (resized) model input, for display.
    img_rgb = np.clip(np.rint(img_tensor.numpy()), 0, 255).astype(np.uint8)

    pred_class = pred.argmax(axis=1)[0]
    confidence = pred.max(axis=1)[0]

    class_names = {0: "No DR", 1: "Mild DR", 2: "Moderate DR", 3: "Severe DR", 4: "Proliferative DR"}

    print(f"Prediction: {class_names[pred_class]}")
    print(f"Confidence: {confidence:.4f}")
    print(f"All probabilities: {pred[0]}")

    plt.figure(figsize=(15, 5))

    plt.subplot(1, 4, 1)
    plt.imshow(img_rgb)
    plt.title('Original Image')
    plt.axis('off')

    plt.subplot(1, 4, 2)
    plt.imshow(mask.numpy(), cmap='viridis')
    plt.title('GPU-Generated Mask')
    plt.axis('off')

    plt.subplot(1, 4, 3)
    masked_img = img_rgb * np.expand_dims(mask.numpy(), axis=-1)
    plt.imshow(masked_img.astype(np.uint8))
    plt.title('Masked Retina Region')
    plt.axis('off')

    plt.subplot(1, 4, 4)
    bars = plt.bar(range(5), pred[0])
    plt.title(f'Prediction Probabilities\n{class_names[pred_class]} ({confidence:.3f})')
    plt.xticks(range(5), ['No DR', 'Mild', 'Mod.', 'Sev.', 'Prolif.'], rotation=45)
    plt.ylabel('Probability')

    # Highlight the predicted class bar.
    colors = ['lightblue'] * 5
    colors[pred_class] = 'orange'
    for bar, color in zip(bars, colors):
        bar.set_color(color)

    plt.tight_layout()
    plt.show()

    return pred_class, confidence, pred[0]

def benchmark_gpu_performance(model, test_set, num_batches=10):
    """Benchmark model inference time on the GPU versus the CPU.

    Runs ``model`` in inference mode (``training=False``) over the first
    ``num_batches`` batches of ``test_set`` and reports wall-clock timings.

    Args:
        model: Callable model whose output has a ``.numpy()`` method (e.g. a
            Keras model with a single output). Each batch's inputs are passed
            to it unchanged.
        test_set: Object with a ``.take(n)`` method (e.g. ``tf.data.Dataset``)
            yielding ``(inputs, labels)`` pairs; labels are ignored.
        num_batches: Number of batches to time on each device.

    Returns:
        ``(gpu_time, cpu_time, speedup)``: timings in seconds and the
        CPU/GPU ratio. When no GPU is available, ``gpu_time`` and ``speedup``
        are ``None``, so callers can always unpack
        ``gpu_time, cpu_time, speedup``. ``speedup`` is also ``None`` when the
        measured GPU time is zero (e.g. no batches).
    """
    print("\nBenchmarking GPU Performance...")

    # Materialise the subset once so every timed pass sees the same batches.
    test_batches = list(test_set.take(num_batches))

    def _timed_pass():
        """Return seconds spent running ``model`` over ``test_batches``."""
        if test_batches:
            # Warm-up call. The first call can pay one-time costs (kernel
            # loading, autotuning, tracing/compilation), which would otherwise
            # be billed to whichever device happens to be timed first.
            model(test_batches[0][0], training=False).numpy()
        start = time.perf_counter()
        for batch_inputs, _labels in test_batches:
            # .numpy() waits for the result, so asynchronously dispatched
            # device work is included in the measurement.
            model(batch_inputs, training=False).numpy()
        return time.perf_counter() - start

    if HAS_GPU:
        with tf.device('/GPU:0'):
            gpu_time = _timed_pass()

        # CPU timing for comparison
        with tf.device('/CPU:0'):
            cpu_time = _timed_pass()

        speedup = cpu_time / gpu_time if gpu_time > 0 else None
        print(f"GPU inference time: {gpu_time:.3f} seconds")
        print(f"CPU inference time: {cpu_time:.3f} seconds")
        if speedup is not None:
            print(f"GPU speedup: {speedup:.2f}x")
        else:
            print("GPU speedup: n/a (no measurable GPU time)")

        return gpu_time, cpu_time, speedup

    # CPU-only machine: time the default placement.
    cpu_time = _timed_pass()
    print(f"CPU inference time: {cpu_time:.3f} seconds")
    return None, cpu_time, None

def _print_history_phase(title, history):
    """Print epoch count and last-epoch train/val accuracy from a History-like object."""
    metrics = history.history
    print(f"   {title}:")
    print(f"     Epochs Completed: {len(metrics['loss'])}")
    print(f"     Final Training Accuracy: {metrics['accuracy'][-1]:.4f}")
    print(f"     Final Validation Accuracy: {metrics['val_accuracy'][-1]:.4f}")


def _print_results_phase(title, results):
    """Print test accuracy and Cohen's kappa from an evaluation-results dict."""
    print(f"   {title}:")
    print(f"     Test Accuracy: {results['overall_accuracy']:.4f}")
    print(f"     Cohen's Kappa: {results['kappa_score']:.4f}")


def create_gpu_training_report(history1, history2, results1, results2, total_time):
    """Print a multi-section summary of a two-phase training run to stdout.

    Args:
        history1: Phase-1 (frozen base) history object whose ``.history`` dict
            has ``'loss'``, ``'accuracy'`` and ``'val_accuracy'`` lists, or a
            falsy value to skip it.
        history2: Same as ``history1`` for phase 2 (fine-tuning).
        results1: Phase-1 evaluation dict with ``'overall_accuracy'`` and
            ``'kappa_score'`` keys, or a falsy value to skip it.
        results2: Same as ``results1`` for phase 2.
        total_time: Total training time in seconds.

    Reads the module-level ``HAS_GPU``, ``GPU_DEVICES``, ``BATCH_SIZE`` and
    ``IMG_SIZE`` settings. Returns None.
    """
    print("\n" + "="*80)
    print("COMPREHENSIVE GPU TRAINING REPORT")
    print("="*80)

    # System Information
    print("\n1. SYSTEM CONFIGURATION:")
    print("-" * 40)
    if HAS_GPU:
        for i, gpu in enumerate(GPU_DEVICES):
            print(f"   GPU {i}: {gpu.name}")
        # Report the settings actually in effect. setup_gpu_optimization()
        # enables them inside a try block, so they may not have taken hold.
        policy_name = tf.keras.mixed_precision.global_policy().name
        mixed = 'ENABLED' if policy_name.startswith('mixed') else 'DISABLED'
        print(f"   Mixed Precision: {mixed} ({policy_name})")
        xla = 'ENABLED' if tf.config.optimizer.get_jit() else 'DISABLED'
        print(f"   XLA Compilation: {xla}")
    else:
        print("   Device: CPU")
    print(f"   Batch Size: {BATCH_SIZE}")

    print(f"   Image Size: {IMG_SIZE}x{IMG_SIZE}")
    print(f"   Total Training Time: {total_time:.2f} seconds ({total_time/60:.1f} minutes)")

    # Training Performance
    print("\n2. TRAINING PERFORMANCE:")
    print("-" * 40)
    if history1:
        _print_history_phase("Phase 1 (Frozen Base)", history1)
    if history2:
        _print_history_phase("Phase 2 (Fine-tuned)", history2)

    # Model Performance
    print("\n3. MODEL EVALUATION:")
    print("-" * 40)
    if results1:
        _print_results_phase("Phase 1 Results", results1)
    if results2:
        _print_results_phase("Phase 2 Results", results2)
        if results1:
            improvement = results2['overall_accuracy'] - results1['overall_accuracy']
            # ':+' emits the correct sign; a literal '+' printed '+-0.01' on regressions.
            print(f"     Improvement: {improvement:+.4f}")

    # GPU Utilization Analysis
    print("\n4. GPU UTILIZATION ANALYSIS:")
    print("-" * 40)
    if HAS_GPU:
        try:
            peak_mb = tf.config.experimental.get_memory_info('GPU:0')['peak'] / (1024**2)
        except (ValueError, KeyError, AttributeError) as err:
            # ValueError: stats not tracked for this device; AttributeError:
            # TF build without get_memory_info.
            print(f"   GPU memory info unavailable ({err})")
        else:
            print(f"   Peak GPU Memory Usage: {peak_mb:.1f} MB")
    else:
        print("   No GPU utilized")

    print("\n" + "="*80)

def optimize_gpu_memory():
    """Reset Keras global state, collect garbage and reset GPU:0 memory stats.

    Does nothing when the module-level ``HAS_GPU`` flag is false. Resetting
    the statistics restarts the ``'peak'`` counter reported by
    ``tf.config.experimental.get_memory_info``. If the reset is unsupported,
    a warning is printed instead of failing silently.
    """
    if not HAS_GPU:
        return
    tf.keras.backend.clear_session()
    gc.collect()
    try:
        tf.config.experimental.reset_memory_stats('GPU:0')
    except ValueError as err:
        # Documented failure mode: unknown device, or stats not tracked/resettable.
        print(f"⚠ Keras session cleared, but GPU memory stats could not be reset: {err}")
    else:
        print("✓ GPU memory optimized")

def validate_data_structure(base_dir, csv_filename="train.csv", images_folder="images"):
    """
    Check that ``base_dir`` holds a usable label CSV and images folder.

    A human-readable report is printed to stdout. Checks, in order:
      1. ``base_dir`` is an existing directory.
      2. ``base_dir/csv_filename`` is an existing file.
      3. ``base_dir/images_folder`` is an existing directory.
      4. The CSV can be read by pandas and has an image-ID column
         (``image_id`` or ``id_code``) and a label column (``level`` or
         ``diagnosis``).
      5. The images folder contains at least one ``.png``/``.jpg``/``.jpeg``
         file (extension match is case-insensitive).

    Args:
        base_dir: Dataset root directory.
        csv_filename: Label CSV file name, relative to ``base_dir``.
        images_folder: Image directory name, relative to ``base_dir``.

    Returns:
        True only if every check passes; False otherwise. Problems are
        reported on stdout rather than raised.
    """
    image_exts = ('.png', '.jpg', '.jpeg')

    print("Validating data structure...")
    print("="*50)

    # isdir (not exists): a *file* at base_dir would crash os.listdir below.
    if not os.path.isdir(base_dir):
        print(f"❌ Base directory not found: {base_dir}")
        return False
    print(f"✓ Base directory exists: {base_dir}")

    csv_path = os.path.join(base_dir, csv_filename)
    if not os.path.isfile(csv_path):
        print(f"❌ CSV file not found: {csv_path}")
        print("Available files in base directory:")
        for name in sorted(os.listdir(base_dir)):
            if name.endswith('.csv'):
                print(f"  - {name}")
        return False
    print(f"✓ CSV file exists: {csv_path}")

    images_dir = os.path.join(base_dir, images_folder)
    if not os.path.isdir(images_dir):
        print(f"❌ Images directory not found: {images_dir}")
        print("Available directories in base directory:")
        for name in sorted(os.listdir(base_dir)):
            if os.path.isdir(os.path.join(base_dir, name)):
                print(f"  - {name}/")
        return False
    print(f"✓ Images directory exists: {images_dir}")

    # Analyze CSV structure; keep going afterwards so the image check is still reported.
    try:
        df = pd.read_csv(csv_path)
        print(f"✓ CSV loaded successfully with {len(df)} rows")
        print(f"✓ CSV columns: {list(df.columns)}")

        id_col = next((c for c in ('image_id', 'id_code') if c in df.columns), None)
        level_col = next((c for c in ('level', 'diagnosis') if c in df.columns), None)

        if id_col is not None:
            print(f"✓ Image ID column found: {id_col}")
        else:
            print("❌ No image ID column found (expected 'image_id' or 'id_code')")

        if level_col is not None:
            print(f"✓ Level column found: {level_col}")
            print(f"✓ Unique levels: {sorted(df[level_col].unique())}")
        else:
            print("❌ No level column found (expected 'level' or 'diagnosis')")
    except Exception as e:
        print(f"❌ Error reading CSV: {e}")
        return False
    columns_ok = id_col is not None and level_col is not None

    try:
        image_files = sorted(
            name for name in os.listdir(images_dir) if name.lower().endswith(image_exts)
        )
    except OSError as e:
        print(f"❌ Error reading images directory: {e}")
        return False
    print(f"✓ Found {len(image_files)} image files in directory")
    if image_files:
        print(f"Sample image files: {image_files[:5]}")
    else:
        print("❌ No image files found in directory")

    print("="*50)
    # Missing columns or images used to be reported with ❌ yet still return True.
    if not (columns_ok and image_files):
        print("❌ Data structure validation found problems (see above)")
        return False
    print("✓ Data structure validation completed")
    return True

# Updated main execution
if __name__ == "__main__":
    import traceback

    # Configuration - UPDATE THESE PATHS FOR YOUR SETUP.
    # Setting the DR_BASE_DIR environment variable overrides the default path,
    # so the script can run on another machine without editing the source.
    BASE_DIR = os.environ.get(
        "DR_BASE_DIR",
        r"C:\Users\nande\OneDrive\Desktop\Diabetic_Retinopathy\DataBase",
    )
    CSV_FILENAME = "train.csv"  # Your CSV file name
    IMAGES_FOLDER = "images"    # Your images folder name (all images in one folder)
    # NOTE(review): main_training_pipeline() is called without arguments below;
    # confirm it loads data from the same BASE_DIR/CSV_FILENAME/IMAGES_FOLDER
    # that are validated here.
    # NOTE(review): when run as a script, this block executes before the
    # definitions and the banner that follow it in the file.

    print("="*60)
    print("UPDATED DIABETIC RETINOPATHY CLASSIFIER")
    print("Supporting new data structure with all images in one folder")
    print("="*60)

    # First validate the data structure
    if not validate_data_structure(BASE_DIR, CSV_FILENAME, IMAGES_FOLDER):
        print("\n❌ Data structure validation failed!")
        print("\nPlease ensure:")
        print("1. Update BASE_DIR to your actual data directory")
        print("2. Update CSV_FILENAME to your actual CSV file name")
        print("3. Update IMAGES_FOLDER to your actual images folder name")
        print("4. CSV should contain 'image_id' (or 'id_code') and 'level' (or 'diagnosis') columns")
        print("5. All images should be in the specified images folder")
    else:
        print("\n✓ Data structure validation passed!")
        print("\nStarting training pipeline...")

        # Run the main training pipeline
        try:
            pipeline_start = time.perf_counter()
            results = main_training_pipeline()
            # Measured here because the pipeline's return tuple does not
            # include its duration; covers the whole main_training_pipeline() call.
            total_time = time.perf_counter() - pipeline_start

            if results:
                model, train_df, val_df, test_df, history1, history2, results1, results2 = results

                # Generate comprehensive report
                create_gpu_training_report(history1, history2, results1, results2, total_time)

                print("\n" + "="*60)
                print("TRAINING COMPLETED SUCCESSFULLY!")
                print("="*60)
                print("\nNext steps:")
                print("1. Use predict_single_image_gpu() for single image predictions")
                print("2. Use benchmark_gpu_performance() to test inference speed")
                print("3. Load the saved model for deployment")
                print("="*60)

        except Exception as e:
            print(f"\n❌ Training failed with error: {e}")
            # The message alone rarely identifies the failing step; keep the stack trace.
            traceback.print_exc()
            print("\nPlease check your data paths and try again.")

def demo_usage_examples():
    """Print copy-pasteable usage snippets for the classifier.

    Writes four numbered sections (data validation, training, prediction,
    benchmarking) to stdout, each with a 40-character dash underline. The
    whole output is framed by 60-character ``=`` rules. Returns None.
    """
    rule = "=" * 60
    # (heading, snippet) pairs; each snippet is printed verbatim.
    sections = [
        ("\n1. VALIDATE YOUR DATA STRUCTURE:", """
# Update these paths for your setup
BASE_DIR = r"your/path/to/database"
CSV_FILENAME = "your_labels.csv"
IMAGES_FOLDER = "your_images_folder"

# Validate structure
validate_data_structure(BASE_DIR, CSV_FILENAME, IMAGES_FOLDER)
    """),
        ("\n2. TRAIN THE MODEL:", """
# Run the main training pipeline
results = main_training_pipeline()
    """),
        ("\n3. MAKE PREDICTIONS:", """
# Load trained model
model = tf.keras.models.load_model(
    'retina_focused_dr_classifier_5class_gpu.keras',
    custom_objects={'RetinaMaskLayer': RetinaMaskLayer}
)

# Predict single image
pred_class, confidence, probabilities = predict_single_image_gpu(
    model, 
    'path/to/your/image.jpg'
)
    """),
        ("\n4. BENCHMARK PERFORMANCE:", """
# Benchmark GPU vs CPU performance
gpu_time, cpu_time, speedup = benchmark_gpu_performance(model, test_set)
    """),
    ]

    print("\n" + rule)
    print("USAGE EXAMPLES FOR UPDATED CLASSIFIER")
    print(rule)
    for heading, snippet in sections:
        print(heading)
        print("-" * 40)
        print(snippet)
    print("\n" + rule)

# Import-time banner: announce the module's features and remind the user to
# configure the data paths before running the main execution section.
print(
    "\n" + "=" * 80,
    "UPDATED CUDA-OPTIMIZED DIABETIC RETINOPATHY CLASSIFIER LOADED",
    "=" * 80,
    "New Features for Updated Data Structure:",
    "• Flexible data loading from single images folder",
    "• Automatic image file extension detection",
    "• CSV-based label mapping with image_id",
    "• Data structure validation and debugging",
    "• Support for different CSV column names",
    "• Comprehensive error handling and guidance",
    "=" * 80,
    "\nIMPORTANT: Update the paths in the main execution section!",
    "- BASE_DIR: Your database directory path",
    "- CSV_FILENAME: Your CSV file name (with image_id and level columns)",
    "- IMAGES_FOLDER: Your images folder name (containing all images)",
    "=" * 80,
    sep="\n",
)