# 🧠 CNN Baseline for Diabetic Retinopathy Classification

## 📋 Experiment Overview
- **Model**: Custom CNN built from scratch
- **Task**: Binary classification (No DR vs Has DR)
- **Dataset**: Kaggle Diabetic Retinopathy (35,126 images)
- **Strategy**: Simple CNN architecture as baseline
- **Expected AUC**: 0.60-0.70 (baseline for comparison)

## 🎯 Key Features:
1. **Lightweight architecture** (faster training)
2. **From-scratch learning** (no transfer learning)
3. **Conservative augmentation** (small rotations, shifts, zooms, and brightness changes for retinal images)
4. **Class imbalance handling** (focal loss + class weights)
5. **Low compute requirements** (128×128 inputs, ~0.6M parameters)

In [None]:
# ============================================
# STEP 1: IMPORT LIBRARIES
# ============================================
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import time
from pathlib import Path

# Deep Learning
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import (
    Conv2D, MaxPooling2D, GlobalAveragePooling2D,
    Dense, Dropout, BatchNormalization, Flatten
)
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Scikit-learn
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score, 
    roc_auc_score, confusion_matrix, roc_curve, precision_recall_curve
)
from sklearn.utils.class_weight import compute_class_weight

# Suppress warnings
import warnings
warnings.filterwarnings('ignore')

print("📦 All libraries imported successfully!")
print(f"🔥 TensorFlow version: {tf.__version__}")
print(f"🎯 GPU available: {tf.config.list_physical_devices('GPU')}")

In [None]:
# ============================================
# STEP 2: FOCAL LOSS IMPLEMENTATION
# ============================================
print("\n" + "="*60)
print("STEP 2: FOCAL LOSS FOR CLASS IMBALANCE")
print("="*60)

def focal_loss(gamma=2.0, alpha=0.75):
    """
    Focal Loss for handling severe class imbalance in medical data.
    
    Args:
        gamma: Focusing parameter (higher = more focus on hard examples)
        alpha: Weighting factor for minority class
    
    Returns:
        Focal loss function for Keras
    """
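    def focal_loss_fixed(y_true, y_pred):
        # Match dtype and shape: the generator yields labels of shape (batch,),
        # while the sigmoid output has shape (batch, 1)
        y_true = tf.reshape(tf.cast(y_true, tf.float32), tf.shape(y_pred))
        epsilon = tf.keras.backend.epsilon()
        y_pred = tf.clip_by_value(y_pred, epsilon, 1.0 - epsilon)
        
        # p_t: predicted probability of the true class
        p_t = tf.where(tf.equal(y_true, 1), y_pred, 1 - y_pred)
        
        # alpha_t: alpha for positives, (1 - alpha) for negatives
        alpha_factor = tf.ones_like(y_true) * alpha
        alpha_t = tf.where(tf.equal(y_true, 1), alpha_factor, 1 - alpha_factor)
        
        # Focal weight down-weights easy examples (p_t close to 1)
        cross_entropy = -tf.math.log(p_t)
        weight = alpha_t * tf.pow((1 - p_t), gamma)
        
        # Return one loss per example (mean over the last axis only) so that Keras
        # can apply class_weight / sample_weight before averaging over the batch
        loss = weight * cross_entropy
        return tf.reduce_mean(loss, axis=-1)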
    
    return focal_loss_fixed

print("✅ Focal loss function defined!")
print("   Gamma (focusing): 2.0")
print("   Alpha (weighting): 0.75")
print("   📋 This will help CNN focus on hard-to-classify DR cases")
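
To see what the focal loss above does numerically, here is a NumPy sketch of the same formula, α_t·(1−p_t)^γ·(−log p_t), returning one value per example. `focal_loss_np` is a hypothetical name used only for illustration; comparing it with plain binary cross-entropy shows how strongly easy examples are down-weighted at γ=2.0, α=0.75.

```python
import numpy as np

def focal_loss_np(y_true, y_pred, gamma=2.0, alpha=0.75, eps=1e-7):
    """Per-example binary focal loss (illustrative NumPy re-implementation)."""
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.clip(np.asarray(y_pred, dtype=np.float64), eps, 1 - eps)
    p_t = np.where(y_true == 1, y_pred, 1 - y_pred)      # probability of the true class
    alpha_t = np.where(y_true == 1, alpha, 1 - alpha)    # class weighting term
    return -alpha_t * (1 - p_t) ** gamma * np.log(p_t)   # one loss value per example

# Easy positive, hard positive, easy negative, hard negative
y_true = np.array([1, 1, 0, 0])
y_pred = np.array([0.95, 0.30, 0.05, 0.70])

fl = focal_loss_np(y_true, y_pred)
bce = -np.log(np.where(y_true == 1, y_pred, 1 - y_pred))  # plain binary cross-entropy
print(np.round(fl / bce, 4))  # focal / BCE ratio: tiny for easy examples, larger for hard ones
```

With γ=0 and α=0.5 the expression collapses to half of binary cross-entropy, which is a quick sanity check that the weighting terms are wired correctly.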

In [None]:
# ============================================
# STEP 3: DATA LOADING & SETUP
# ============================================
# For Kaggle - adjust paths as needed
KAGGLE_INPUT_PATH = '/kaggle/input/diabetic-retinopathy-resized'
LOCAL_PATH = '/Users/muhirwa/Desktop/projects/diabetes_retinopathy/data'

# Try Kaggle path first, fallback to local
if os.path.exists(KAGGLE_INPUT_PATH):
    DATA_PATH = KAGGLE_INPUT_PATH
    print("🌐 Using Kaggle dataset path")
elif os.path.exists(LOCAL_PATH):
    DATA_PATH = LOCAL_PATH
    print("💻 Using local dataset path")
else:
    print("❌ Dataset not found! Please check paths.")
    DATA_PATH = None

print(f"📁 Data path: {DATA_PATH}")

# Load labels
if DATA_PATH:
    labels_path = os.path.join(DATA_PATH, 'trainLabels.csv')
    if os.path.exists(labels_path):
        labels_df = pd.read_csv(labels_path)
        print(f"📊 Labels loaded! Total images: {len(labels_df)}")
        
        # Convert to binary classification
        labels_df['binary_label'] = (labels_df['level'] > 0).astype(int)
        
        print("\n🎯 Binary classification:")
        print(labels_df['binary_label'].value_counts())
        print(f"  No DR (0): {(labels_df['binary_label']==0).sum()} images ({(labels_df['binary_label']==0).sum()/len(labels_df)*100:.1f}%)")
        print(f"  Has DR (1): {(labels_df['binary_label']==1).sum()} images ({(labels_df['binary_label']==1).sum()/len(labels_df)*100:.1f}%)")
        
        # Add file paths
        image_dir = os.path.join(DATA_PATH, 'resized_train_cropped', 'resized_train_cropped')
        labels_df['filepath'] = labels_df['image'].apply(lambda x: os.path.join(image_dir, f"{x}.jpeg"))
        
        # Use subset for faster baseline training (uncomment if needed)
        # labels_df = labels_df.sample(n=10000, random_state=42).reset_index(drop=True)
        # print(f"Using subset of {len(labels_df)} images for faster baseline training")
        
    else:
        print(f"❌ Labels file not found at {labels_path}")
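
`flow_from_dataframe` validates file paths by default (`validate_filenames=True`) and ignores rows whose image cannot be found, so a wrong `image_dir` can silently shrink the dataset. A quick existence check right after building `filepath` catches this early. `check_image_files` below is a hypothetical helper, demonstrated on a throwaway directory with made-up file names.

```python
import os
import tempfile

import pandas as pd

def check_image_files(df, path_col='filepath'):
    """Hypothetical helper: split a labels DataFrame into rows whose image exists and rows whose image is missing."""
    exists = df[path_col].map(os.path.exists)
    return df[exists].reset_index(drop=True), df[~exists].reset_index(drop=True)

# Demo on a temporary directory; only one of the two images is created
with tempfile.TemporaryDirectory() as image_dir:
    open(os.path.join(image_dir, '10_left.jpeg'), 'wb').close()
    demo = pd.DataFrame({'image': ['10_left', '10_right'], 'level': [0, 2]})
    demo['binary_label'] = (demo['level'] > 0).astype(int)  # grades 1-4 -> "Has DR"
    demo['filepath'] = demo['image'].apply(lambda x: os.path.join(image_dir, f"{x}.jpeg"))
    present, missing = check_image_files(demo)
    print(f"present={len(present)}, missing={len(missing)}")
```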

In [None]:
# ============================================
# STEP 4: TRAIN/VAL/TEST SPLITS
# ============================================
print("\n" + "="*60)
print("STEP 4: CREATE TRAIN/VAL/TEST SPLITS")
print("="*60)

# Use all images for CNN baseline
labels_df_sample = labels_df.copy()
print(f"Using all {len(labels_df_sample)} images for CNN baseline")

# Split: 70% train, 15% validation, 15% test
train_df, temp_df = train_test_split(
    labels_df_sample,
    test_size=0.3,
    random_state=42,
    stratify=labels_df_sample['binary_label']  # Maintain class distribution
)

val_df, test_df = train_test_split(
    temp_df,
    test_size=0.5,
    random_state=42,
    stratify=temp_df['binary_label']
)

print("\n📊 Data split complete:")
print(f"  Train:      {len(train_df):6d} images ({len(train_df)/len(labels_df_sample)*100:.1f}%)")
print(f"  Validation: {len(val_df):6d} images ({len(val_df)/len(labels_df_sample)*100:.1f}%)")
print(f"  Test:       {len(test_df):6d} images ({len(test_df)/len(labels_df_sample)*100:.1f}%)")

# Check class distribution in each split
print("\n📊 Class distribution per split:")
for name, df in [('Train', train_df), ('Val', val_df), ('Test', test_df)]:
    class_counts = df['binary_label'].value_counts()
    print(f"  {name:5s}: No DR = {class_counts[0]:4d} ({class_counts[0]/len(df)*100:.1f}%), Has DR = {class_counts[1]:4d} ({class_counts[1]/len(df)*100:.1f}%)")
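
The 70/15/15 proportions are nominal; the exact counts come from the two chained `train_test_split` calls. Assuming scikit-learn's rule that a fractional `test_size` is rounded up and the training side receives the remainder, the sizes for the full dataset can be worked out by hand. `two_stage_split_sizes` is a hypothetical helper for this arithmetic.

```python
from math import ceil

def two_stage_split_sizes(n_samples, first_test_size=0.3, second_test_size=0.5):
    """Sizes from two chained splits, assuming a fractional test_size is rounded up (ceil)
    and the train side of each split gets the remainder."""
    n_temp = ceil(first_test_size * n_samples)   # held out by the first split
    n_train = n_samples - n_temp
    n_test = ceil(second_test_size * n_temp)     # second split halves the held-out part
    n_val = n_temp - n_test
    return n_train, n_val, n_test

train_n, val_n, test_n = two_stage_split_sizes(35126)
print(train_n, val_n, test_n)
```

Under that assumption, a `trainLabels.csv` with 35,126 rows yields 24,588 / 5,269 / 5,269 images for train / validation / test.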

In [None]:
# ============================================
# STEP 5: DATA AUGMENTATION & GENERATORS
# ============================================
print("\n" + "="*60)
print("STEP 5: DATA AUGMENTATION & GENERATORS")
print("="*60)

# Conservative augmentation for baseline CNN
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=15,          # Conservative rotation
    width_shift_range=0.1,
    height_shift_range=0.1,
    shear_range=0.05,
    zoom_range=0.1,
    horizontal_flip=True,
    brightness_range=[0.9, 1.1], # Minimal brightness change
    fill_mode='nearest'
)

# Validation and test (no augmentation)
val_test_datagen = ImageDataGenerator(rescale=1./255)

# Convert labels to strings (required by flow_from_dataframe)
train_df['binary_label_str'] = train_df['binary_label'].astype(str)
val_df['binary_label_str'] = val_df['binary_label'].astype(str)
test_df['binary_label_str'] = test_df['binary_label'].astype(str)

# Create generators
BATCH_SIZE = 32  # Standard batch size
IMG_SIZE = (128, 128)  # Smaller size for faster baseline training

train_generator = train_datagen.flow_from_dataframe(
    train_df,
    x_col='filepath',
    y_col='binary_label_str',
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='binary',
    shuffle=True,
    seed=42
)

val_generator = val_test_datagen.flow_from_dataframe(
    val_df,
    x_col='filepath',
    y_col='binary_label_str',
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='binary',
    shuffle=False
)

test_generator = val_test_datagen.flow_from_dataframe(
    test_df,
    x_col='filepath',
    y_col='binary_label_str',
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='binary',
    shuffle=False
)

print(f"\n✅ Generators created successfully!")
print(f"   Input size: {IMG_SIZE}")
print(f"   Batch size: {BATCH_SIZE}")
print(f"   Train batches: {len(train_generator)}")
print(f"   Val batches:   {len(val_generator)}")
print(f"   Test batches:  {len(test_generator)}")
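
When `classes` is not passed, `flow_from_dataframe` numbers the string labels in sorted (alphanumeric) order. With the labels `'0'` and `'1'` this maps "Has DR" to index 1, so the sigmoid output is P(Has DR). The hypothetical re-implementation of that ordering rule below shows why descriptive label strings would silently flip the positive class.

```python
def infer_class_indices(labels):
    """Illustrative re-implementation: unique label strings, sorted, numbered from 0."""
    return {label: idx for idx, label in enumerate(sorted(set(labels)))}

print(infer_class_indices(['1', '0', '1', '0']))          # labels as used in this notebook
print(infer_class_indices(['No DR', 'Has DR', 'No DR']))  # descriptive names reverse the order
```

Inspecting `train_generator.class_indices` after the generators are built confirms the mapping actually used.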

In [None]:
# ============================================
# STEP 6: CNN BASELINE ARCHITECTURE
# ============================================
print("\n" + "="*60)
print("STEP 6: BUILDING CNN BASELINE MODEL")
print("="*60)

def create_cnn_baseline(input_shape=(128, 128, 3)):
    """
    Create a simple CNN baseline for diabetic retinopathy classification.
    Designed to be fast and effective for medical images.
    """
    model = Sequential([
        # First Conv Block
        Conv2D(32, (3, 3), activation='relu', input_shape=input_shape),
        BatchNormalization(),
        MaxPooling2D(2, 2),
        Dropout(0.25),
        
        # Second Conv Block
        Conv2D(64, (3, 3), activation='relu'),
        BatchNormalization(),
        MaxPooling2D(2, 2),
        Dropout(0.25),
        
        # Third Conv Block
        Conv2D(128, (3, 3), activation='relu'),
        BatchNormalization(),
        MaxPooling2D(2, 2),
        Dropout(0.25),
        
        # Fourth Conv Block
        Conv2D(256, (3, 3), activation='relu'),
        BatchNormalization(),
        MaxPooling2D(2, 2),
        Dropout(0.25),
        
        # Global Average Pooling: one value per channel, far fewer parameters than Flatten
        GlobalAveragePooling2D(),
        
        # Dense layers
        Dense(512, activation='relu'),
        BatchNormalization(),
        Dropout(0.5),
        
        Dense(128, activation='relu'),
        Dropout(0.3),
        
        # Output layer
        Dense(1, activation='sigmoid')  # Binary classification
    ])
    
    return model

# Create the model
model = create_cnn_baseline(input_shape=(*IMG_SIZE, 3))

print(f"🧠 CNN Baseline Architecture:")
model.summary()

# Count parameters
total_params = model.count_params()
print(f"\n✅ Total parameters: {total_params:,}")
print(f"📊 ResNet50 has ~25.6M parameters, about {25_600_000/total_params:.0f}x more than this model")
print(f"⚡ Faster training expected")
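
The parameter count reported by `model.summary()` can be reproduced by hand, which is a useful check on the architecture. This sketch assumes Keras defaults (`padding='valid'` for `Conv2D`, non-overlapping 2×2 max pooling) and counts all four BatchNormalization vectors per channel, as `count_params()` does (gamma and beta are trainable; the moving mean and variance are not).

```python
def conv_params(k, c_in, c_out):
    return k * k * c_in * c_out + c_out   # kernel weights + biases

def dense_params(n_in, n_out):
    return n_in * n_out + n_out           # weight matrix + biases

def bn_params(c):
    return 4 * c                          # gamma, beta, moving mean, moving variance

total, c_in, size = 0, 3, 128             # 128x128 RGB input
for filters in [32, 64, 128, 256]:
    total += conv_params(3, c_in, filters) + bn_params(filters)
    size = (size - 2) // 2                # 3x3 'valid' conv shrinks by 2, pooling halves
    c_in = filters

# GlobalAveragePooling2D -> c_in features, then Dense(512)+BN, Dense(128), Dense(1)
total += dense_params(c_in, 512) + bn_params(512) + dense_params(512, 128) + dense_params(128, 1)
print(size, total)
```

The feature map entering `GlobalAveragePooling2D` is 6×6×256 and the total comes to 589,761 parameters, roughly 43x fewer than ResNet50.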

In [None]:
# ============================================
# STEP 7: COMPILE MODEL & CALLBACKS
# ============================================
print("\n" + "="*60)
print("STEP 7: COMPILE MODEL & SETUP CALLBACKS")
print("="*60)

# Calculate class weights
y_train_labels = train_df['binary_label'].astype(int).values
class_weights = compute_class_weight(
    class_weight='balanced',
    classes=np.unique(y_train_labels),
    y=y_train_labels
)
class_weight_dict = {0: class_weights[0], 1: class_weights[1]}

print(f"⚖️  Class weights calculated:")
print(f"   No DR (0): {class_weight_dict[0]:.3f}")
print(f"   Has DR (1): {class_weight_dict[1]:.3f}")

# Compile model
model.compile(
    optimizer=Adam(learning_rate=0.001),  # Standard learning rate
    loss=focal_loss(gamma=2.0, alpha=0.75),  # Focal loss for imbalance
    metrics=['accuracy', tf.keras.metrics.AUC(name='auc')]
)

print(f"✅ Model compiled with:")
print(f"   Optimizer: Adam (LR=0.001)")
print(f"   Loss: Focal Loss (γ=2.0, α=0.75)")
print(f"   Metrics: Accuracy, AUC")

# Define callbacks
callbacks = [
    EarlyStopping(
        monitor='val_auc',
        patience=10,  # More patience for baseline
        mode='max',
        restore_best_weights=True,
        verbose=1
    ),
    
    ReduceLROnPlateau(
        monitor='val_auc',
        factor=0.5,
        patience=5,
        mode='max',
        min_lr=1e-7,
        verbose=1
    ),
    
    ModelCheckpoint(
        'best_cnn_baseline.h5',
        monitor='val_auc',
        mode='max',
        save_best_only=True,
        verbose=1
    )
]

print(f"\n🔧 Callbacks configured:")
print(f"   📊 Monitoring: val_auc (maximize)")
print(f"   ⏹️  Early stopping: 10 epochs patience")
print(f"   📉 Learning rate reduction: 5 epochs patience")
print(f"   💾 Model checkpoint: Save best AUC")
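
`compute_class_weight('balanced', ...)` uses the formula n_samples / (n_classes × count_c). Because the focal loss's α=0.75 already favours the positive class, using both multiplies the two corrections. The sketch below uses illustrative counts (700 "No DR", 300 "Has DR"; not the real dataset) and a hypothetical re-implementation, `balanced_class_weights`, to show the resulting relative weight of a positive example before the (1−p_t)^γ factor is applied.

```python
import numpy as np

def balanced_class_weights(y):
    """Illustrative re-implementation of the 'balanced' formula: n_samples / (n_classes * count_c)."""
    y = np.asarray(y)
    classes, counts = np.unique(y, return_counts=True)
    weights = len(y) / (len(classes) * counts)
    return dict(zip(classes.tolist(), weights.tolist()))

# Illustrative labels only: 700 "No DR" (0) and 300 "Has DR" (1)
y = np.array([0] * 700 + [1] * 300)
w = balanced_class_weights(y)

alpha = 0.75
# Relative weight of one positive vs one negative example with both corrections applied
combined_ratio = (w[1] * alpha) / (w[0] * (1 - alpha))
print(w, round(combined_ratio, 2))
```

Here the class weights alone give positives 700/300 ≈ 2.33x the weight of negatives, and α/(1−α) = 3 raises that to 7x.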

In [None]:
# ============================================
# STEP 8: TRAIN CNN BASELINE
# ============================================
print("\n" + "="*60)
print("STEP 8: TRAINING CNN BASELINE")
print("="*60)

EPOCHS = 30  # Upper bound; early stopping may end training sooner

print(f"⏳ Training CNN baseline for up to {EPOCHS} epochs...")
print(f"📊 Learning rate: 0.001")
print(f"🎯 Goal: Establish baseline performance for comparison")

start_time = time.time()

history = model.fit(
    train_generator,
    validation_data=val_generator,
    epochs=EPOCHS,
    callbacks=callbacks,
    verbose=1,
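    # Note: the focal loss (alpha=0.75) already up-weights "Has DR", so these
    # balanced class weights compound that correction rather than replace it
    class_weight=class_weight_dict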
)

training_time = time.time() - start_time
print(f"\n✅ Training completed in {training_time:.2f} seconds ({training_time/60:.2f} minutes)")

# Get training results
best_val_auc = max(history.history['val_auc'])
final_epoch = len(history.history['loss'])

print(f"📊 Training Summary:")
print(f"   Epochs completed: {final_epoch}")
print(f"   Best validation AUC: {best_val_auc:.4f}")
print(f"   Training time: {training_time/60:.1f} minutes")
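
With `restore_best_weights=True`, the epoch count printed above is not the epoch whose weights are kept. The patience rule can be sketched in plain Python; this is a simplified model of `EarlyStopping(mode='max')` with `min_delta=0`, not Keras's source, and both `simulate_early_stopping` and the AUC history are made up for illustration.

```python
def simulate_early_stopping(val_scores, patience):
    """Simplified patience rule: stop once the score has not improved for `patience`
    consecutive epochs. Returns (epochs_run, best_epoch_index)."""
    best, best_epoch, wait = float('-inf'), 0, 0
    for epoch, score in enumerate(val_scores):
        if score > best:
            best, best_epoch, wait = score, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch + 1, best_epoch
    return len(val_scores), best_epoch

# Made-up val_auc history, for illustration only
val_auc = [0.60, 0.62, 0.65, 0.64, 0.64, 0.63, 0.66]
print(simulate_early_stopping(val_auc, patience=3))
```

In this simplified model, when early stopping fires the best epoch (counting from 1) is `final_epoch - patience`, and `best_val_auc` refers to that earlier epoch.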

In [None]:
# ============================================
# STEP 9: EVALUATE ON TEST SET
# ============================================
print("\n" + "="*60)
print("STEP 9: EVALUATING ON TEST SET")
print("="*60)

# Get predictions
print("🔮 Generating predictions on test set...")
y_test_proba = model.predict(test_generator, verbose=1).ravel()  # 1-D probabilities
y_test_pred = (y_test_proba > 0.5).astype(int).flatten()
y_test_true = test_generator.classes

# Calculate metrics
test_acc = accuracy_score(y_test_true, y_test_pred)
test_prec = precision_score(y_test_true, y_test_pred, zero_division=0)
test_rec = recall_score(y_test_true, y_test_pred, zero_division=0)
test_f1 = f1_score(y_test_true, y_test_pred, zero_division=0)
test_auc = roc_auc_score(y_test_true, y_test_proba)

print(f"\n📊 Test Set Results (threshold=0.5):")
print(f"   Accuracy:  {test_acc*100:.2f}%")
print(f"   Precision: {test_prec*100:.2f}%")
print(f"   Recall:    {test_rec*100:.2f}%")
print(f"   F1 Score:  {test_f1*100:.2f}%")
print(f"   ROC AUC:   {test_auc:.4f}")

# Comparison with other models
print(f"\n🆚 Comparison with other models:")
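# AUC and recall (%) for the other models are hard-coded from separate runs;
# the CNN recall here is at threshold 0.5, so recalls are not strictly comparable
models_comparison = [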
    ('Random Forest', 0.551, 60.2),
    ('SVM', 0.505, 58.8),
    ('VGG16', 0.706, 1.0),
    ('CNN Baseline', test_auc, test_rec*100)
]

for model_name, auc, recall in models_comparison:
    print(f"   {model_name:<15}: AUC {auc:.3f}, Recall {recall:.1f}%")

# Determine performance level
if test_auc >= 0.70:
    print("🎉 CNN Baseline performed excellently!")
elif test_auc >= 0.60:
    print("👍 CNN Baseline shows good performance!")
elif test_auc >= 0.55:
    print("📈 CNN Baseline shows moderate performance")
else:
    print("⚠️  CNN Baseline needs improvement")

print(f"\n🎯 Expected baseline performance: AUC 0.60-0.70")
if test_auc >= 0.60:
    print("✅ Baseline performance target achieved!")
else:
    print("📊 Consider adjusting architecture or hyperparameters")
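
ROC AUC is threshold-independent: it equals the probability that a randomly chosen "Has DR" image receives a higher score than a randomly chosen "No DR" image. A direct pairwise computation makes that concrete; this is a sketch for intuition, and `roc_auc_score` remains the tool to use in practice.

```python
import numpy as np

def auc_pairwise(y_true, scores):
    """ROC AUC as P(score_pos > score_neg), ties counted as half.
    Uses O(n_pos * n_neg) memory, fine for a few thousand images."""
    y_true = np.asarray(y_true).ravel()
    scores = np.asarray(scores, dtype=float).ravel()
    pos, neg = scores[y_true == 1], scores[y_true == 0]
    diff = pos[:, None] - neg[None, :]   # every positive-negative pair
    return ((diff > 0).sum() + 0.5 * (diff == 0).sum()) / (len(pos) * len(neg))

y = np.array([0, 0, 1, 1])
p = np.array([0.10, 0.40, 0.35, 0.80])
print(auc_pairwise(y, p))        # 0.75
print(auc_pairwise(y, p ** 3))   # unchanged: only the ranking of scores matters
```

Because only the ranking matters, AUC can look reasonable while accuracy, precision, or recall at the fixed 0.5 threshold look poor; choosing a better threshold is a separate step.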

In [None]:
# ============================================
# STEP 10: THRESHOLD OPTIMIZATION
# ============================================
print("\n" + "="*60)
print("STEP 10: THRESHOLD OPTIMIZATION")
print("="*60)

# Choose the threshold on the VALIDATION set; tuning it on the test set would
# leak test information into the reported "optimized" results
y_val_proba = model.predict(val_generator, verbose=0).ravel()
y_val_true = val_generator.classes
precision_curve, recall_curve, thresholds = precision_recall_curve(y_val_true, y_val_proba)

# Target 60% recall for baseline (lower than transfer learning)
target_recall = 0.60
recall_diffs = np.abs(recall_curve[:-1] - target_recall)  # last recall point has no threshold
optimal_idx = np.argmin(recall_diffs)
optimal_threshold = thresholds[optimal_idx]

print(f"🎯 Threshold optimization:")
print(f"   Target recall: {target_recall*100:.0f}% (baseline target)")
print(f"   Optimal threshold: {optimal_threshold:.3f}")

# Test different thresholds
print(f"\n📊 Performance at different thresholds:")
print(f"{'Threshold':<10} {'Accuracy':<10} {'Precision':<10} {'Recall':<10} {'F1':<10}")
print("-" * 50)

for thresh in [0.1, 0.2, 0.3, 0.4, 0.5, optimal_threshold]:
    y_pred_thresh = (y_test_proba > thresh).astype(int).flatten()
    acc = accuracy_score(y_test_true, y_pred_thresh)
    prec = precision_score(y_test_true, y_pred_thresh, zero_division=0)
    rec = recall_score(y_test_true, y_pred_thresh, zero_division=0)
    f1 = f1_score(y_test_true, y_pred_thresh, zero_division=0)
    
    print(f"{thresh:<10.2f} {acc*100:<10.1f} {prec*100:<10.1f} {rec*100:<10.1f} {f1*100:<10.1f}")

# Optimal threshold results
y_pred_optimal = (y_test_proba > optimal_threshold).astype(int).flatten()
optimal_acc = accuracy_score(y_test_true, y_pred_optimal)
optimal_prec = precision_score(y_test_true, y_pred_optimal, zero_division=0)
optimal_rec = recall_score(y_test_true, y_pred_optimal, zero_division=0)
optimal_f1 = f1_score(y_test_true, y_pred_optimal, zero_division=0)

print(f"\n🏥 Optimized for baseline use (threshold={optimal_threshold:.3f}):")
print(f"   Accuracy:  {optimal_acc*100:.2f}%")
print(f"   Precision: {optimal_prec*100:.2f}%")
print(f"   Recall:    {optimal_rec*100:.2f}%")
print(f"   F1 Score:  {optimal_f1*100:.2f}%")
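
The nearest-recall rule above can land slightly below the target. An alternative sketch picks the highest threshold that still reaches the target recall under this notebook's `proba > threshold` rule. `threshold_for_recall` is a hypothetical helper; it should be fed validation predictions, and the resulting threshold then applied unchanged to the test set.

```python
import numpy as np

def threshold_for_recall(y_true, scores, target_recall):
    """Hypothetical helper: highest threshold t such that predicting `scores > t`
    reaches at least `target_recall` on the given labels."""
    y_true = np.asarray(y_true).ravel()
    scores = np.asarray(scores, dtype=float).ravel()
    pos_scores = scores[y_true == 1]
    # Try cut points just below each distinct score, from highest to lowest
    for s in np.unique(scores)[::-1]:
        t = np.nextafter(s, -np.inf)       # so that `scores > t` includes s itself
        if np.mean(pos_scores > t) >= target_recall:
            return float(t)
    return float(np.nextafter(scores.min(), -np.inf))  # fallback: everything positive

# Toy validation labels and probabilities
y_val = np.array([0, 0, 1, 1, 1, 0])
p_val = np.array([0.10, 0.40, 0.35, 0.80, 0.60, 0.20])
t = threshold_for_recall(y_val, p_val, target_recall=0.60)
print(round(t, 4), np.mean(p_val[y_val == 1] > t))
```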

In [None]:
# ============================================
# STEP 11: VISUALIZATION & ANALYSIS
# ============================================
print("\n" + "="*60)
print("STEP 11: VISUALIZATION & ANALYSIS")
print("="*60)

# Create comprehensive visualization
fig, axes = plt.subplots(2, 2, figsize=(15, 10))
fig.suptitle('CNN Baseline Training Progress & Results', fontsize=16)

# Loss plot
axes[0, 0].plot(history.history['loss'], label='Train Loss', color='blue')
axes[0, 0].plot(history.history['val_loss'], label='Val Loss', color='orange')
axes[0, 0].set_title('Model Loss')
axes[0, 0].set_xlabel('Epoch')
axes[0, 0].set_ylabel('Loss')
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3)

# Accuracy plot
axes[0, 1].plot(history.history['accuracy'], label='Train Accuracy', color='blue')
axes[0, 1].plot(history.history['val_accuracy'], label='Val Accuracy', color='orange')
axes[0, 1].set_title('Model Accuracy')
axes[0, 1].set_xlabel('Epoch')
axes[0, 1].set_ylabel('Accuracy')
axes[0, 1].legend()
axes[0, 1].grid(True, alpha=0.3)

# AUC plot
axes[1, 0].plot(history.history['auc'], label='Train AUC', color='blue')
axes[1, 0].plot(history.history['val_auc'], label='Val AUC', color='orange')
axes[1, 0].set_title('Model AUC')
axes[1, 0].set_xlabel('Epoch')
axes[1, 0].set_ylabel('AUC')
axes[1, 0].legend()
axes[1, 0].grid(True, alpha=0.3)

# Confusion Matrix
cm = confusion_matrix(y_test_true, y_pred_optimal)
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=axes[1, 1])
axes[1, 1].set_title(f'Confusion Matrix (threshold={optimal_threshold:.3f})')
axes[1, 1].set_xlabel('Predicted Label')
axes[1, 1].set_ylabel('True Label')
axes[1, 1].set_xticklabels(['No DR', 'Has DR'])
axes[1, 1].set_yticklabels(['No DR', 'Has DR'])

plt.tight_layout()
plt.show()

# ROC Curve
fpr, tpr, _ = roc_curve(y_test_true, y_test_proba)

plt.figure(figsize=(8, 6))
plt.plot(fpr, tpr, color='orange', lw=2, label=f'ROC (AUC = {test_auc:.3f})')
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', label='Random')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curve - CNN Baseline')
plt.legend(loc="lower right")
plt.grid(True, alpha=0.3)
plt.show()

print("📊 Visualizations complete!")
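
The confusion-matrix heatmap follows scikit-learn's `confusion_matrix` layout: rows are true labels, columns are predictions, both ordered [0, 1]. Sensitivity and specificity, the usual screening metrics, can be read straight off that matrix. A NumPy sketch with a hypothetical helper and toy labels:

```python
import numpy as np

def binary_confusion(y_true, y_pred):
    """2x2 confusion matrix in scikit-learn's layout: [[TN, FP], [FN, TP]]."""
    y_true = np.asarray(y_true).ravel()
    y_pred = np.asarray(y_pred).ravel()
    tn = int(np.sum((y_true == 0) & (y_pred == 0)))
    fp = int(np.sum((y_true == 0) & (y_pred == 1)))
    fn = int(np.sum((y_true == 1) & (y_pred == 0)))
    tp = int(np.sum((y_true == 1) & (y_pred == 1)))
    return np.array([[tn, fp], [fn, tp]])

cm = binary_confusion([0, 0, 0, 1, 1, 1, 1], [0, 1, 0, 1, 1, 0, 1])
(tn, fp), (fn, tp) = cm
sensitivity = tp / (tp + fn)   # recall on "Has DR"
specificity = tn / (tn + fp)   # share of "No DR" images correctly cleared
print(cm.tolist(), round(sensitivity, 3), round(specificity, 3))
```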

In [None]:
# ============================================
# STEP 12: FINAL RESULTS SUMMARY
# ============================================
print("\n" + "="*60)
print("STEP 12: FINAL RESULTS SUMMARY")
print("="*60)

print(f"🧠 CNN Baseline Results Summary:")
print(f"\n📊 Model Architecture:")
print(f"   Type: Custom CNN (4 conv blocks + 2 dense layers)")
print(f"   Parameters: {total_params:,}")
print(f"   Input size: {IMG_SIZE}")
print(f"   Training time: {training_time/60:.1f} minutes")

print(f"\n🎯 Performance Results:")
print(f"   Test AUC: {test_auc:.4f}")
print(f"   Default threshold (0.5):")
print(f"     Accuracy: {test_acc*100:.1f}%")
print(f"     Precision: {test_prec*100:.1f}%")
print(f"     Recall: {test_rec*100:.1f}%")
print(f"     F1: {test_f1*100:.1f}%")

print(f"\n🏥 Optimized threshold ({optimal_threshold:.3f}):")
print(f"     Accuracy: {optimal_acc*100:.1f}%")
print(f"     Precision: {optimal_prec*100:.1f}%")
print(f"     Recall: {optimal_rec*100:.1f}%")
print(f"     F1: {optimal_f1*100:.1f}%")

print(f"\n🆚 Model Comparison (Final Ranking):")
all_models = [
    ('Random Forest', 0.551, 60.2, 'Traditional ML'),
    ('SVM', 0.505, 58.8, 'Traditional ML'),
    ('CNN Baseline', test_auc, optimal_rec*100, 'Deep Learning'),
    ('VGG16', 0.706, 1.0, 'Transfer Learning'),  # Note: VGG16 recall was broken
]

# Sort by AUC
all_models_sorted = sorted(all_models, key=lambda x: x[1], reverse=True)

print(f"   {'Rank':<5} {'Model':<15} {'AUC':<8} {'Recall':<8} {'Type':<15}")
print("-" * 60)
for i, (name, auc, recall, model_type) in enumerate(all_models_sorted, 1):
    print(f"   {i:<5} {name:<15} {auc:<8.3f} {recall:<8.1f} {model_type:<15}")

print(f"\n🎯 Baseline Assessment:")
if test_auc >= 0.65:
    print(f"   ✅ STRONG BASELINE: CNN achieved competitive performance")
    print(f"   🚀 Ready for advanced architectures (ResNet50, ensemble)")
elif test_auc >= 0.60:
    print(f"   👍 GOOD BASELINE: CNN shows promise for deep learning approach")
    print(f"   📈 Transfer learning should provide significant improvement")
elif test_auc >= 0.55:
    print(f"   📊 MODERATE BASELINE: CNN outperforms some traditional ML")
    print(f"   🔧 Consider architecture improvements or more data")
else:
    print(f"   ⚠️  WEAK BASELINE: Traditional ML may be more suitable")
    print(f"   🛠️ Revise architecture or feature engineering approach")

print(f"\n💡 Key Insights:")
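# Training AUC at the best-validation epoch (the weights saved by ModelCheckpoint),
# not the maximum over all epochs. Training AUC is measured with augmentation and
# dropout active, so the gap below is only a rough overfitting indicator.
best_epoch = int(np.argmax(history.history['val_auc']))
train_auc = history.history['auc'][best_epoch]
overfitting = train_auc - test_auc
print(f"   • Training AUC (best epoch): {train_auc:.3f}")
print(f"   • Test AUC: {test_auc:.3f}")
print(f"   • Overfitting gap: {overfitting:.3f} ({'Low' if overfitting < 0.05 else 'Moderate' if overfitting < 0.1 else 'High'})")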
print(f"   • Baseline established for {'transfer learning' if test_auc >= 0.60 else 'traditional ML'} comparison")

print(f"\n📋 Recommendations:")
if test_auc >= 0.65:
    print(f"   1. ✅ CNN baseline is strong - proceed with ResNet50")
    print(f"   2. 🎯 Consider ensemble methods (CNN + XGBoost)")
    print(f"   3. 🔬 Experiment with different architectures")
elif test_auc >= 0.60:
    print(f"   1. 📈 ResNet50 transfer learning should significantly improve")
    print(f"   2. 🎛️ Consider data augmentation improvements")
    print(f"   3. 🔍 Compare with XGBoost performance")
else:
    print(f"   1. 🛠️ Revise CNN architecture (more layers, different activation)")
    print(f"   2. 📊 Focus on XGBoost and feature engineering")
    print(f"   3. 🔄 Consider larger input size (224x224)")

print(f"\n💾 Model saved as: best_cnn_baseline.h5")
print(f"🎯 Use this as baseline for comparison with advanced models")

## 🎯 CNN Baseline Experiment Summary

### Purpose:
Establish a **deep learning baseline** for diabetic retinopathy classification to compare against:
1. Traditional ML methods (Random Forest, SVM)
2. Advanced deep learning (ResNet50, VGG16)
3. Ensemble approaches

### Key Features:
- **Lightweight architecture**: 4 conv blocks + 2 dense layers
- **Fast training**: expected ~30 minutes (hardware-dependent) vs hours for transfer learning
- **Imbalance handling**: focal loss + class weights (both up-weight the "Has DR" class, so the corrections compound)
- **Conservative augmentation**: Appropriate for medical images

### Expected Performance:
- **Target AUC**: 0.60-0.70 (baseline expectation)
- **Training time**: 20-40 minutes
- **Memory usage**: Low (smaller than transfer learning)

### Success Criteria:
- **AUC ≥ 0.65**: Strong baseline, proceed with advanced models
- **AUC 0.60-0.64**: Good baseline, transfer learning should help
- **AUC < 0.60**: Consider architecture improvements

### Next Steps Based on Results:
1. **If CNN baseline performs well**: Compare with ResNet50
2. **If CNN baseline struggles**: Focus on XGBoost and feature engineering
3. **Either way**: Use as ensemble component

### For Production Use:
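```python
# Load the baseline model for inference
from tensorflow.keras.models import load_model

# compile=False skips deserializing the focal loss, which is only needed for training.
# To resume training instead, keep STEP 2's focal_loss in scope and pass:
#   custom_objects={'focal_loss_fixed': focal_loss(2.0, 0.75)}
model = load_model('best_cnn_baseline.h5', compile=False)

# `images` must be preprocessed like the generators: RGB, resized to IMG_SIZE (128x128),
# scaled by 1/255, shape (n, 128, 128, 3)
predictions = model.predict(images).ravel()

# `optimal_threshold` is the value selected in STEP 10
binary_predictions = (predictions > optimal_threshold).astype(int)
```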
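
The `model.predict(images)` call assumes `images` has been preprocessed exactly like the training data. A minimal sketch of the final step, assuming the images are already resized to 128×128 RGB `uint8` arrays; `to_model_input` is a hypothetical helper. For the resize itself, `flow_from_dataframe` defaults to `interpolation='nearest'`, so using the same method at inference keeps inputs consistent with training.

```python
import numpy as np

IMG_SIZE = (128, 128)  # must match the size used for training

def to_model_input(images_uint8):
    """Hypothetical helper: stack already-resized RGB images (uint8, 0-255) into the
    float32 batch the model expects, applying the same 1/255 rescaling as training."""
    batch = np.stack([np.asarray(img) for img in images_uint8]).astype(np.float32) / 255.0
    if batch.shape[1:] != (*IMG_SIZE, 3):
        raise ValueError(f"expected images of shape {(*IMG_SIZE, 3)}, got {batch.shape[1:]}")
    return batch

# Two dummy images: all white and all black
demo = [np.full((128, 128, 3), 255, dtype=np.uint8), np.zeros((128, 128, 3), dtype=np.uint8)]
batch = to_model_input(demo)
print(batch.shape, batch.dtype, batch.min(), batch.max())
```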