In [42]:
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score, roc_curve
from sklearn.preprocessing import RobustScaler, LabelEncoder
from sklearn.utils.class_weight import compute_class_weight
from tensorflow.keras.layers import Input, Dense, GlobalAveragePooling2D, Concatenate, Dropout, RandomFlip, RandomRotation, RandomZoom, RandomContrast
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau, LearningRateScheduler
from tensorflow.keras import regularizers
import tensorflow as tf
from typing import Tuple, Dict, List
import pandas as pd
import time
from concurrent.futures import ThreadPoolExecutor

# Enhanced configuration
class Config:
    SEED = 42
    IMAGE_SIZE = (224, 224)
    BATCH_SIZE = 32
    INITIAL_EPOCHS = 15
    FINE_TUNE_EPOCHS = 10
    DATA_DIR = '/Users/adyasha/Downloads/thesis_data/'
    OUTPUT_DIR = os.path.join(DATA_DIR, 'outputs')  # New output directory
    CLINICAL_FEATURES = [
        'vertical_cdr', 'horizontal_cdr', 'area_cdr', 'rim_ratio',
        'inferior_rim', 'superior_rim', 'nasal_rim', 'temporal_rim',
        'mean_pallor', 'max_pallor', 'pallor_asymmetry',
        'vessel_density', 'vessel_tortuosity', 'istn_compliance'
    ]
    AUGMENTATION_PARAMS = {
        'flip': True,
        'rotation': 0.1,
        'zoom': 0.1,
        'contrast': (0.9, 1.1),
        'brightness': 0.1
    }
    REGULARIZATION_L2 = 0.001
    DROPOUT_RATE = 0.5
    FINE_TUNE_AT = 100
    FEATURE_LENGTH = len(CLINICAL_FEATURES)  # Ensure consistent feature length

# Create output directory if it doesn't exist
os.makedirs(Config.OUTPUT_DIR, exist_ok=True)

# Set random seeds
np.random.seed(Config.SEED)
tf.random.set_seed(Config.SEED)

# Enhanced Optic Disc Analyzer
class OpticDiscAnalyzer:
    """Optimized optic disc feature extraction with caching and shape consistency"""
    
    def __init__(self):
        self.cache = {}
        
    def analyze_image(self, img: np.ndarray) -> Dict[str, float]:
        """Consistent feature extraction"""
        if img is None or img.size == 0:
            return {k: 0.0 for k in Config.CLINICAL_FEATURES}
        
        # Use image hash for caching
        img_hash = hash(img.tobytes())
        if img_hash in self.cache:
            return self.cache[img_hash]
        
        # Preprocess image
        if len(img.shape) == 2:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        img = cv2.resize(img, Config.IMAGE_SIZE)
        
        # Simulated feature extraction - ensure consistent output
        features = {
            'vertical_cdr': float(np.random.uniform(0.3, 0.8)),
            'horizontal_cdr': float(np.random.uniform(0.3, 0.8)),
            'area_cdr': float(np.random.uniform(0.3, 0.8)),
            'rim_ratio': float(np.random.uniform(0.1, 0.5)),
            'inferior_rim': float(np.random.uniform(0.05, 0.2)),
            'superior_rim': float(np.random.uniform(0.05, 0.2)),
            'nasal_rim': float(np.random.uniform(0.05, 0.2)),
            'temporal_rim': float(np.random.uniform(0.05, 0.2)),
            'mean_pallor': float(np.random.uniform(100, 200)),
            'max_pallor': float(np.random.uniform(150, 250)),
            'pallor_asymmetry': float(np.random.uniform(0.1, 0.3)),
            'vessel_density': float(np.random.uniform(0.1, 0.4)),
            'vessel_tortuosity': float(np.random.uniform(1.0, 1.5)),
            'istn_compliance': float(np.random.choice([0, 1]))
        }
        
        self.cache[img_hash] = features
        return features

# Robust Data Generator
class GlaucomaDataGenerator(tf.keras.utils.Sequence):
    def __init__(self, directory: str, analyzer: OpticDiscAnalyzer, 
                 shuffle: bool = True, augment: bool = False):
        self.directory = directory
        self.analyzer = analyzer
        self.shuffle = shuffle
        self.augment = augment
        self.class_folders = sorted(os.listdir(directory))
        self.class_encoder = LabelEncoder()
        self.file_paths = self._get_file_paths()
        self.scaler = RobustScaler()
        self.indexes = np.arange(len(self.file_paths))
        self.on_epoch_end()
        self.precomputed_features = self._precompute_features_parallel()
    
    def _get_file_paths(self) -> List[Tuple[str, int]]:
        paths = []
        for class_name in self.class_folders:
            class_dir = os.path.join(self.directory, class_name)
            if not os.path.isdir(class_dir):
                continue
                
            for filename in os.listdir(class_dir):
                if filename.lower().endswith(('.png', '.jpg', '.jpeg')):
                    file_path = os.path.join(class_dir, filename)
                    if os.path.isfile(file_path):
                        paths.append((file_path, class_name))
        
        # Encode class labels
        all_labels = [label for _, label in paths]
        if not all_labels:
            raise ValueError(f"No valid images found in directory: {self.directory}")
            
        self.class_encoder.fit(all_labels)
        self.classes = self.class_encoder.classes_
        
        # Convert to numerical labels
        paths = [(path, self.class_encoder.transform([label])[0]) for path, label in paths]
        return paths
    
    def _process_image(self, file_path: str) -> np.ndarray:
        """Process a single image file with error handling and shape validation"""
        try:
            img = cv2.imread(file_path)
            if img is None:
                print(f"Warning: Could not load image {file_path}. Using zeros.")
                return np.zeros(Config.FEATURE_LENGTH, dtype=np.float32)
                
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            features_dict = self.analyzer.analyze_image(img)
            
            # Ensure consistent feature order and length
            features = np.array([features_dict[k] for k in Config.CLINICAL_FEATURES], dtype=np.float32)
            
            if features.shape != (Config.FEATURE_LENGTH,):
                print(f"Warning: Invalid feature shape {features.shape} for {file_path}. Using zeros.")
                return np.zeros(Config.FEATURE_LENGTH, dtype=np.float32)
                
            return features
        except Exception as e:
            print(f"Error processing {file_path}: {str(e)}")
            return np.zeros(Config.FEATURE_LENGTH, dtype=np.float32)
    
    def _precompute_features_parallel(self) -> Dict[str, np.ndarray]:
        """Precompute features using parallel processing with shape validation"""
        features = {}
        print(f"Precomputing features for {len(self.file_paths)} images...")
        start_time = time.time()
        
        # Use ThreadPoolExecutor for parallel processing
        with ThreadPoolExecutor(max_workers=min(8, os.cpu_count())) as executor:
            futures = {}
            for file_path, _ in self.file_paths:
                futures[file_path] = executor.submit(self._process_image, file_path)
            
            for file_path, future in futures.items():
                features[file_path] = future.result()
        
        # Validate all features have consistent shape
        all_features = np.array(list(features.values()))
        if all_features.ndim != 2 or all_features.shape[1] != Config.FEATURE_LENGTH:
            raise ValueError(f"Feature array has invalid shape {all_features.shape}. Expected (n, {Config.FEATURE_LENGTH})")
        
        self.scaler.fit(all_features)
        
        elapsed = time.time() - start_time
        print(f"Feature computation completed in {elapsed:.1f} seconds")
        return features
    
    def __len__(self) -> int:
        return int(np.ceil(len(self.file_paths) / Config.BATCH_SIZE))
    
    def __getitem__(self, idx: int) -> Tuple[Tuple[np.ndarray, np.ndarray], np.ndarray]:
        batch_indices = self.indexes[idx*Config.BATCH_SIZE:(idx+1)*Config.BATCH_SIZE]
        batch_paths = [self.file_paths[i] for i in batch_indices]
        
        images = []
        clinical_features = []
        diagnosis_labels = []
        
        for file_path, label in batch_paths:
            img = self._load_and_preprocess(file_path)
            images.append(img)
            clinical_features.append(self.precomputed_features[file_path])
            diagnosis_labels.append(label)
        
        # Convert to arrays with consistent shapes
        images = np.stack(images, axis=0)
        clinical_features = np.stack(clinical_features, axis=0)
        clinical_features = self.scaler.transform(clinical_features)
        diagnosis_labels = np.array(diagnosis_labels).reshape(-1, 1)  # Ensure 2D shape
        
        return (images, clinical_features), diagnosis_labels
    
    def _augment_image(self, img: np.ndarray) -> np.ndarray:
        """Apply data augmentation while maintaining shape consistency"""
        # Random horizontal flip
        if Config.AUGMENTATION_PARAMS['flip'] and np.random.rand() > 0.5:
            img = np.fliplr(img)
        
        # Random rotation
        if Config.AUGMENTATION_PARAMS['rotation'] > 0 and np.random.rand() > 0.5:
            angle = np.random.uniform(-Config.AUGMENTATION_PARAMS['rotation'] * 180, 
                                     Config.AUGMENTATION_PARAMS['rotation'] * 180)
            rows, cols = img.shape[:2]
            M = cv2.getRotationMatrix2D((cols/2, rows/2), angle, 1)
            img = cv2.warpAffine(img, M, (cols, rows), borderMode=cv2.BORDER_REFLECT)
        
        # Random zoom with safe resizing
        if Config.AUGMENTATION_PARAMS['zoom'] > 0 and np.random.rand() > 0.5:
            zoom_factor = 1 + np.random.uniform(-Config.AUGMENTATION_PARAMS['zoom'], 
                                               Config.AUGMENTATION_PARAMS['zoom'])
            h, w = img.shape[:2]
            new_h, new_w = int(h * zoom_factor), int(w * zoom_factor)
            img = cv2.resize(img, (new_w, new_h))
            # Resize back to original dimensions
            img = cv2.resize(img, (w, h))
        
        # Random contrast
        if Config.AUGMENTATION_PARAMS['contrast'] and np.random.rand() > 0.5:
            alpha = np.random.uniform(Config.AUGMENTATION_PARAMS['contrast'][0],
                                     Config.AUGMENTATION_PARAMS['contrast'][1])
            img = np.clip(alpha * img, 0, 1)
        
        # Random brightness
        if Config.AUGMENTATION_PARAMS['brightness'] > 0 and np.random.rand() > 0.5:
            beta = np.random.uniform(-Config.AUGMENTATION_PARAMS['brightness'],
                                    Config.AUGMENTATION_PARAMS['brightness'])
            img = np.clip(img + beta, 0, 1)
        
        return img
    
    def _load_and_preprocess(self, file_path: str) -> np.ndarray:
        """Load and preprocess image with shape validation"""
        img = cv2.imread(file_path)
        if img is None:
            print(f"Warning: Could not load image {file_path}. Using zeros.")
            img = np.zeros((*Config.IMAGE_SIZE, 3), dtype=np.uint8)
        else:
            img = cv2.resize(img, Config.IMAGE_SIZE)
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        
        # Convert to float and normalize
        img = img.astype(np.float32) / 255.0
        
        if self.augment:
            img = self._augment_image(img)
            
        # Ensure consistent shape
        if img.shape != (*Config.IMAGE_SIZE, 3):
            # Force correct shape
            img = cv2.resize(img, Config.IMAGE_SIZE)
            if len(img.shape) == 2:
                img = np.stack([img, img, img], axis=-1)
            elif img.shape[2] > 3:
                img = img[:, :, :3]
            elif img.shape[2] == 1:
                img = np.stack([img.squeeze()]*3, axis=-1)
                
        return img
    
    def on_epoch_end(self):
        if self.shuffle:
            np.random.shuffle(self.indexes)
    
    def get_class_distribution(self):
        counts = {class_name: 0 for class_name in self.classes}
        for _, label in self.file_paths:
            class_name = self.class_encoder.inverse_transform([label])[0]
            counts[class_name] += 1
        return counts
    
    def get_diagnosis_labels(self):
        return [label for _, label in self.file_paths]

# Enhanced Model Architecture (Single Output)
def build_glaucomanet(trainable_base: bool = False) -> tf.keras.Model:
    # Image augmentation layers
    augmentation = tf.keras.Sequential([
        RandomFlip("horizontal"),
        RandomRotation(Config.AUGMENTATION_PARAMS['rotation']),
        RandomZoom(Config.AUGMENTATION_PARAMS['zoom']),
        RandomContrast(Config.AUGMENTATION_PARAMS['brightness']),
    ], name='augmentation')
    
    # Image input pipeline
    img_input = Input(shape=(*Config.IMAGE_SIZE, 3), name='image_input')
    augmented = augmentation(img_input)
    
    # Base model
    base_model = tf.keras.applications.MobileNetV2(
        include_top=False, 
        weights='imagenet', 
        input_tensor=augmented
    )
    base_model.trainable = trainable_base
    x_img = GlobalAveragePooling2D()(base_model.output)
    
    # Clinical features input
    clinical_input = Input(shape=(Config.FEATURE_LENGTH,), name='clinical_input')
    
    # Feature fusion
    fused_features = Concatenate()([x_img, clinical_input])
    
    # Classification head with regularization
    x = Dense(128, activation='relu', 
              kernel_regularizer=regularizers.l2(Config.REGULARIZATION_L2))(fused_features)
    x = Dropout(Config.DROPOUT_RATE)(x)
    
    # Single output for diagnosis
    diagnosis_output = Dense(1, activation='sigmoid', name='diagnosis')(x)
    
    return Model(
        inputs=[img_input, clinical_input],
        outputs=diagnosis_output
    ), base_model

# Learning rate scheduler
def lr_schedule(epoch):
    """Learning rate schedule with warmup and decay"""
    warmup_epochs = 3
    initial_lr = 1e-4
    decay_factor = 0.1
    
    if epoch < warmup_epochs:
        return initial_lr * (epoch + 1) / warmup_epochs
    elif epoch < Config.INITIAL_EPOCHS:
        return initial_lr
    else:
        return initial_lr * (decay_factor ** ((epoch - Config.INITIAL_EPOCHS) // 4))

# Training and Evaluation
class GlaucomaTrainer:
    def __init__(self):
        self.analyzer = OpticDiscAnalyzer()
        self.model, self.base_model = build_glaucomanet(trainable_base=False)
        self.output_dir = Config.OUTPUT_DIR
        
    def get_data_generators(self):
        train_gen = GlaucomaDataGenerator(
            os.path.join(Config.DATA_DIR, 'train'), 
            self.analyzer,
            augment=True
        )
        val_gen = GlaucomaDataGenerator(
            os.path.join(Config.DATA_DIR, 'validate'), 
            self.analyzer, 
            shuffle=False
        )
        test_gen = GlaucomaDataGenerator(
            os.path.join(Config.DATA_DIR, 'test'), 
            self.analyzer, 
            shuffle=False
        )
        return train_gen, val_gen, test_gen
    
    def compute_class_weights(self, train_gen):
        y_train = train_gen.get_diagnosis_labels()
        unique_classes = np.unique(y_train)
        if len(unique_classes) < 2:
            print("Warning: Only one class present in training data")
            return {0: 1.0, 1: 1.0}
            
        class_weights = compute_class_weight('balanced', classes=unique_classes, y=y_train)
        return {i: weight for i, weight in enumerate(class_weights)}
    
    def plot_class_distribution(self, train_gen, val_gen, test_gen):
        fig, axes = plt.subplots(1, 3, figsize=(18, 5))
        
        train_dist = train_gen.get_class_distribution()
        axes[0].bar(train_dist.keys(), train_dist.values())
        axes[0].set_title('Train Distribution')
        axes[0].set_ylabel('Count')
        
        val_dist = val_gen.get_class_distribution()
        axes[1].bar(val_dist.keys(), val_dist.values())
        axes[1].set_title('Validation Distribution')
        
        test_dist = test_gen.get_class_distribution()
        axes[2].bar(test_dist.keys(), test_dist.values())
        axes[2].set_title('Test Distribution')
        
        plt.tight_layout()
        output_path = os.path.join(self.output_dir, 'class_distribution.png')
        plt.savefig(output_path)
        plt.close()
    
    def plot_feature_distribution(self, gen, name):
        features = []
        for i in range(len(gen)):
            (_, clinical), _ = gen[i]
            features.append(clinical)
            
        features = np.vstack(features)
        df = pd.DataFrame(features, columns=Config.CLINICAL_FEATURES)
        
        plt.figure(figsize=(15, 10))
        df.boxplot()
        plt.title(f'Clinical Feature Distributions - {name}')
        plt.xticks(rotation=45)
        plt.tight_layout()
        output_path = os.path.join(self.output_dir, f'feature_distribution_{name}.png')
        plt.savefig(output_path)
        plt.close()
        
        plt.figure(figsize=(12, 10))
        sns.heatmap(df.corr(), annot=True, cmap='coolwarm', fmt='.2f')
        plt.title(f'Feature Correlation Matrix - {name}')
        plt.tight_layout()
        output_path = os.path.join(self.output_dir, f'feature_correlation_{name}.png')
        plt.savefig(output_path)
        plt.close()
    
    def train(self):
        train_gen, val_gen, _ = self.get_data_generators()
        
        self.plot_class_distribution(train_gen, val_gen, val_gen)
        self.plot_feature_distribution(train_gen, 'train')
        self.plot_feature_distribution(val_gen, 'validation')
        
        # Compute class weights for imbalanced data
        class_weights = self.compute_class_weights(train_gen)
        print(f"Class weights: {class_weights}")
        
        # Phase 1: Train the head
        print("\n=== Phase 1: Training head ===")
        self.model.compile(
            optimizer=Adam(1e-4),
            loss='binary_crossentropy',
            metrics=['accuracy', tf.keras.metrics.AUC(name='auc')]
        )
        
        callbacks = [
            EarlyStopping(
                monitor='val_auc',
                patience=5,
                restore_best_weights=True,
                mode='max',
                verbose=1
            ),
            ModelCheckpoint(
                os.path.join(self.output_dir, 'best_model_phase1.weights.h5'),
                save_best_only=True,
                monitor='val_auc',
                mode='max',
                save_weights_only=True
            ),
            ReduceLROnPlateau(
                monitor='val_auc',
                factor=0.5,
                patience=2,
                min_lr=1e-6,
                mode='max',
                verbose=1
            ),
            LearningRateScheduler(lr_schedule)
        ]
        
        history_phase1 = self.model.fit(
            train_gen,
            validation_data=val_gen,
            epochs=Config.INITIAL_EPOCHS,
            callbacks=callbacks,
            verbose=1,
            class_weight=class_weights
        )
        
        # Phase 2: Fine-tuning
        print("\n=== Phase 2: Fine-tuning ===")
        self.base_model.trainable = True
        
        # Freeze layers before FINE_TUNE_AT
        for layer in self.base_model.layers[:Config.FINE_TUNE_AT]:
            layer.trainable = False
            
        print(f"Number of trainable layers: {sum([1 for layer in self.base_model.layers if layer.trainable])}")
        
        self.model.compile(
            optimizer=Adam(1e-5),  # Lower learning rate
            loss='binary_crossentropy',
            metrics=['accuracy', tf.keras.metrics.AUC(name='auc')]
        )
        
        callbacks_fine = [
            EarlyStopping(
                monitor='val_auc',
                patience=5,
                restore_best_weights=True,
                mode='max',
                verbose=1
            ),
            ModelCheckpoint(
                os.path.join(self.output_dir, 'best_model_final.weights.h5'),
                save_best_only=True,
                monitor='val_auc',
                mode='max',
                save_weights_only=True
            ),
            ReduceLROnPlateau(
                monitor='val_auc',
                factor=0.5,
                patience=2,
                min_lr=1e-7,
                mode='max',
                verbose=1
            )
        ]
        
        history_fine = self.model.fit(
            train_gen,
            validation_data=val_gen,
            epochs=Config.INITIAL_EPOCHS + Config.FINE_TUNE_EPOCHS,
            initial_epoch=history_phase1.epoch[-1] + 1,
            callbacks=callbacks_fine,
            verbose=1,
            class_weight=class_weights
        )
        
        # Combine histories
        history = {
            'accuracy': history_phase1.history['accuracy'] + history_fine.history['accuracy'],
            'val_accuracy': history_phase1.history['val_accuracy'] + history_fine.history['val_accuracy'],
            'auc': history_phase1.history['auc'] + history_fine.history['auc'],
            'val_auc': history_phase1.history['val_auc'] + history_fine.history['val_auc'],
            'loss': history_phase1.history['loss'] + history_fine.history['loss'],
            'val_loss': history_phase1.history['val_loss'] + history_fine.history['val_loss']
        }
        
        self.plot_training_history(history)
        return history
    
    def plot_training_history(self, history):
        plt.figure(figsize=(12, 10))
        
        plt.subplot(2, 2, 1)
        plt.plot(history['accuracy'], label='Train')
        plt.plot(history['val_accuracy'], label='Validation')
        plt.title('Diagnosis Accuracy')
        plt.ylabel('Accuracy')
        plt.xlabel('Epoch')
        plt.legend()
        
        plt.subplot(2, 2, 2)
        plt.plot(history['auc'], label='Train')
        plt.plot(history['val_auc'], label='Validation')
        plt.title('Diagnosis AUC')
        plt.ylabel('AUC')
        plt.xlabel('Epoch')
        plt.legend()
        
        plt.subplot(2, 2, 3)
        plt.plot(history['loss'], label='Train')
        plt.plot(history['val_loss'], label='Validation')
        plt.title('Loss')
        plt.ylabel('Loss')
        plt.xlabel('Epoch')
        plt.legend()
        
        plt.tight_layout()
        output_path = os.path.join(self.output_dir, 'training_history.png')
        plt.savefig(output_path)
        plt.close()
    
    def evaluate(self):
        _, _, test_gen = self.get_data_generators()
        
        print("\nEvaluating model...")
        test_dist = test_gen.get_class_distribution()
        print("Test set class distribution:", test_dist)
        
        # Load best weights
        self.model.load_weights(os.path.join(self.output_dir, 'best_model_final.weights.h5'))
        
        results = self.model.evaluate(test_gen, verbose=0)
        print("\nTest Results:")
        print(f"Loss: {results[0]:.4f}")
        print(f"Accuracy: {results[1]:.4f}")
        print(f"AUC: {results[2]:.4f}")
        
        y_true = []
        y_pred = []
        
        for i in range(len(test_gen)):
            (images, clinical), labels = test_gen[i]
            batch_pred = self.model.predict([images, clinical], verbose=0)
            y_true.extend(labels.flatten().tolist())
            y_pred.extend(batch_pred.flatten().tolist())
        
        y_true = np.array(y_true)
        y_pred = np.array(y_pred)
        
        # Convert labels to 0 and 1
        unique_vals = np.unique(y_true)
        print(f"Unique values in y_true: {unique_vals}")
        
        if set(unique_vals) == {1, 2}:
            print("Adjusting labels: subtracting 1")
            y_true = y_true - 1
        elif set(unique_vals) == {0, 1}:
            print("Labels are 0 and 1, no adjustment needed")
        else:
            print("Unexpected label values. Converting to binary:")
            y_true = (y_true != np.min(y_true)).astype(int)
        
        print(f"Adjusted unique values: {np.unique(y_true)}")
        
        # Save classification report
        report = classification_report(y_true, y_pred > 0.5, target_names=['Normal', 'Glaucoma'])
        print("\nClassification Report:")
        print(report)
        
        # Save report to file
        report_path = os.path.join(self.output_dir, 'classification_report.txt')
        with open(report_path, 'w') as f:
            f.write("Classification Report:\n")
            f.write(report)
            f.write(f"\nTest Loss: {results[0]:.4f}")
            f.write(f"\nTest Accuracy: {results[1]:.4f}")
            f.write(f"\nTest AUC: {results[2]:.4f}")
        
        # Confusion matrix
        cm = confusion_matrix(y_true, y_pred > 0.5)
        plt.figure(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
                   xticklabels=['Normal', 'Glaucoma'],
                   yticklabels=['Normal', 'Glaucoma'])
        plt.title('Confusion Matrix')
        plt.ylabel('True Label')
        plt.xlabel('Predicted Label')
        plt.tight_layout()
        output_path = os.path.join(self.output_dir, 'confusion_matrix.png')
        plt.savefig(output_path)
        plt.close()
        
        # ROC curve
        fpr, tpr, _ = roc_curve(y_true, y_pred)
        roc_auc = roc_auc_score(y_true, y_pred)
        
        plt.figure(figsize=(8, 6))
        plt.plot(fpr, tpr, color='darkorange', lw=2, 
                label=f'ROC curve (AUC = {roc_auc:.2f})')
        plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title('Receiver Operating Characteristic')
        plt.legend(loc="lower right")
        plt.tight_layout()
        output_path = os.path.join(self.output_dir, 'roc_curve.png')
        plt.savefig(output_path)
        plt.close()
        
        self.analyze_feature_importance(test_gen)
        
        return results
    
    def analyze_feature_importance(self, test_gen):
        """Analyze clinical feature importance using permutation on entire test set"""
        print("\nAnalyzing feature importance...")
        
        # Collect entire test set
        all_images = []
        all_clinical = []
        all_diagnosis = []
        
        for i in range(len(test_gen)):
            (images, clinical), diagnosis = test_gen[i]
            all_images.append(images)
            all_clinical.append(clinical)
            all_diagnosis.append(diagnosis)
        
        all_images = np.vstack(all_images)
        all_clinical = np.vstack(all_clinical)
        all_diagnosis = np.vstack(all_diagnosis)
        
        # Convert labels to 0 and 1
        unique_vals = np.unique(all_diagnosis)
        if set(unique_vals) == {1, 2}:
            all_diagnosis = all_diagnosis - 1
        elif set(unique_vals) == {0, 1}:
            pass
        else:
            all_diagnosis = (all_diagnosis != np.min(all_diagnosis)).astype(int)
        
        # Baseline prediction
        baseline_pred = self.model.predict([all_images, all_clinical])
        baseline_auc = roc_auc_score(all_diagnosis, baseline_pred)
        print(f"Baseline AUC: {baseline_auc:.4f}")
        
        feature_importance = []
        for i, feature_name in enumerate(Config.CLINICAL_FEATURES):
            print(f"Processing feature {i+1}/{len(Config.CLINICAL_FEATURES)}: {feature_name}")
            
            clinical_permuted = all_clinical.copy()
            np.random.shuffle(clinical_permuted[:, i])
            
            permuted_pred = self.model.predict([all_images, clinical_permuted])
            permuted_auc = roc_auc_score(all_diagnosis, permuted_pred)
            
            importance = baseline_auc - permuted_auc
            feature_importance.append(importance)
            print(f"  Feature importance: {importance:.4f}")
        
        # Save feature importance to file
        importance_path = os.path.join(self.output_dir, 'feature_importance.txt')
        with open(importance_path, 'w') as f:
            f.write("Feature Importance (Permutation Test):\n")
            for name, imp in zip(Config.CLINICAL_FEATURES, feature_importance):
                f.write(f"{name}: {imp:.6f}\n")
        
        # Plot feature importance
        plt.figure(figsize=(12, 6))
        plt.barh(Config.CLINICAL_FEATURES, feature_importance)
        plt.title('Clinical Feature Importance (Permutation Test)')
        plt.xlabel('AUC Decrease')
        plt.tight_layout()
        output_path = os.path.join(self.output_dir, 'feature_importance.png')
        plt.savefig(output_path)
        plt.close()

if __name__ == "__main__":
    print("Starting Robust Glaucoma Detection System...")
    print(f"All outputs will be saved to: {Config.OUTPUT_DIR}")
    start_time = time.time()
    
    try:
        trainer = GlaucomaTrainer()
        history = trainer.train()
        results = trainer.evaluate()
        
        elapsed = time.time() - start_time
        print(f"\nExecution completed successfully in {elapsed/60:.2f} minutes!")
        print(f"All outputs saved to: {Config.OUTPUT_DIR}")
    except Exception as e:
        print(f"\nCritical error encountered: {str(e)}")
        import traceback
        traceback.print_exc()
        print("\nSystem terminated due to error")

Starting Robust Glaucoma Detection System...
All outputs will be saved to: /Users/adyasha/Downloads/thesis_data/outputs
Precomputing features for 4319 images...
Feature computation completed in 9.5 seconds
Precomputing features for 98 images...
Feature computation completed in 0.0 seconds
Precomputing features for 335 images...
Feature computation completed in 0.1 seconds
Class weights: {0: 1.2869487485101312, 1: 0.8176826959485044}

=== Phase 1: Training head ===
Epoch 1/15
[1m135/135[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m50s[0m 336ms/step - accuracy: 0.6323 - auc: 0.6839 - loss: 0.9053 - val_accuracy: 0.7449 - val_auc: 0.7755 - val_loss: 0.8244 - learning_rate: 3.3333e-05
Epoch 2/15
[1m135/135[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m46s[0m 337ms/step - accuracy: 0.7404 - auc: 0.8172 - loss: 0.7483 - val_accuracy: 0.7347 - val_auc: 0.8609 - val_loss: 0.7088 - learning_rate: 6.6667e-05
Epoch 3/15
[1m135/135[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m46s[0m 343