# In [None]:
from google.colab import drive
drive.mount('/content/drive')

import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.applications import EfficientNetB0
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
from sklearn.utils import class_weight
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# ========== BANNER ==========
# A single print with a newline separator emits the rule, the title and the
# closing rule followed by one blank line.
print(
    "=" * 70,
    "DIABETIC RETINOPATHY CLASSIFICATION",
    "=" * 70 + "\n",
    sep="\n",
)

# ========== CONFIGURATION ==========
IMG_SIZE, BATCH_SIZE = 224, 16  # square image edge in pixels; images per batch
EPOCHS = 100                    # epoch count handed to model.fit
LEARNING_RATE = 5e-4            # starting learning rate for the Adam optimizer

# ========== DATA AUGMENTATION ==========
# Random image transforms, applied in the order listed. The stack is used as
# the first stage of the model graph further down.
_augmentation_steps = [
    layers.RandomFlip(mode="horizontal_and_vertical"),
    layers.RandomRotation(factor=0.2),
    layers.RandomZoom(height_factor=0.15),
    layers.RandomTranslation(height_factor=0.1, width_factor=0.1),
    layers.RandomContrast(factor=0.2),
]
data_augmentation = tf.keras.Sequential(_augmentation_steps, name='data_augmentation')

# ========== LOAD DATA ==========
print("Loading dataset...")

# The training and validation subsets are read from the same directory with
# the same seed, split fraction, image size and batch size, so those shared
# arguments are written once and reused for both calls.
_loader_kwargs = dict(
    validation_split=0.2,
    seed=123,
    image_size=(IMG_SIZE, IMG_SIZE),
    batch_size=BATCH_SIZE,
)
_dataset_dir = "/content/drive/MyDrive/Colab Notebooks/dataset"

train_ds = tf.keras.utils.image_dataset_from_directory(
    _dataset_dir, subset="training", shuffle=True, **_loader_kwargs
)

# Held-out 20 %; it is divided into validation and test halves later on.
val_test_ds = tf.keras.utils.image_dataset_from_directory(
    _dataset_dir, subset="validation", **_loader_kwargs
)

# Class names as reported by the training dataset object.
class_names = train_ds.class_names
num_classes = len(class_names)
print(f"\nClasses detected: {class_names}")
print(f"Number of classes: {num_classes}")

# Split validation and test sets
# `val_test_ds` was created with the loader's default shuffle=True, so each
# new pass over it can yield the elements in a different order. Slicing such
# a stream directly with take()/skip() lets samples drift between the two
# halves (a sample may land in both, or in neither). Caching the stream and
# reading it once to the end pins a single order; take()/skip() on the cached
# dataset are then disjoint and stable across epochs.
val_test_ds = val_test_ds.cache()
for _ in val_test_ds:
    pass  # a complete pass is what finalises the in-memory cache

val_batches = tf.data.experimental.cardinality(val_test_ds)
val_size = val_batches // 2  # first half of the batches -> validation
val_ds = val_test_ds.take(val_size)
test_ds = val_test_ds.skip(val_size)  # remaining batches -> test

print(f"\nDataset split:")
print(f"  Training batches: {tf.data.experimental.cardinality(train_ds)}")
print(f"  Validation batches: {tf.data.experimental.cardinality(val_ds)}")
print(f"  Test batches: {tf.data.experimental.cardinality(test_ds)}")

# Optimize performance
AUTOTUNE = tf.data.AUTOTUNE
# cache() replays exactly the same batches in the same order on every epoch,
# so the batch order is re-shuffled *after* the cache to keep epochs varied.
train_ds = train_ds.cache().shuffle(1000, seed=123).prefetch(buffer_size=AUTOTUNE)
# val/test already read from the cached parent dataset; caching them again
# would only hold a second copy of the same tensors.
val_ds = val_ds.prefetch(buffer_size=AUTOTUNE)
test_ds = test_ds.prefetch(buffer_size=AUTOTUNE)

# ========== CALCULATE CLASS WEIGHTS ==========
print("\nCalculating class weights for imbalanced data...")

# Collect every training label into one integer array (one pass over train_ds).
labels = np.concatenate([batch_labels.numpy() for _, batch_labels in train_ds])

# 'balanced' weights are computed only for class ids that occur in `labels`.
present_classes = np.unique(labels)
present_weights = class_weight.compute_class_weight(
    class_weight='balanced',
    classes=present_classes,
    y=labels
)

# Store the weights by class id rather than by position in `present_classes`.
# If a class has no training samples the two differ, and positional indexing
# would shift every following weight onto the wrong class. Classes without
# samples keep a neutral weight of 1.0 so the dict has an entry for every id.
class_weights = np.ones(num_classes, dtype=float)
class_weights[present_classes] = present_weights
class_weight_dict = {class_id: float(weight) for class_id, weight in enumerate(class_weights)}

# Per-class sample counts in a single pass; minlength keeps empty classes at 0.
class_counts = np.bincount(labels, minlength=num_classes)

print("\nClass distribution:")
for name, count, weight in zip(class_names, class_counts, class_weights):
    print(f"  {name}: {count} samples (weight: {weight:.2f})")

# ========== BUILD MODEL ==========
print("\n" + "=" * 70, "BUILDING MODEL", "=" * 70 + "\n", sep="\n")

_input_shape = (IMG_SIZE, IMG_SIZE, 3)

# EfficientNetB0 with ImageNet weights and without its top classifier.
# Marking it non-trainable leaves only the layers added below trainable.
base_model = EfficientNetB0(
    weights='imagenet',
    include_top=False,
    input_shape=_input_shape,
)
base_model.trainable = False

inputs = tf.keras.Input(shape=_input_shape)

# Front end: augmentation -> EfficientNet preprocessing -> backbone.
features = data_augmentation(inputs)
features = tf.keras.applications.efficientnet.preprocess_input(features)
features = base_model(features, training=False)  # backbone is always called with training=False

# Classification head: pooled features, then two L2-regularised dense layers
# with dropout rates stepping down 0.5 -> 0.4 -> 0.3.
features = layers.GlobalAveragePooling2D()(features)
features = layers.BatchNormalization()(features)
features = layers.Dropout(0.5)(features)
features = layers.Dense(
    256,
    activation='relu',
    kernel_regularizer=tf.keras.regularizers.l2(0.01),
)(features)
features = layers.BatchNormalization()(features)
features = layers.Dropout(0.4)(features)
features = layers.Dense(
    128,
    activation='relu',
    kernel_regularizer=tf.keras.regularizers.l2(0.01),
)(features)
features = layers.Dropout(0.3)(features)

# One softmax output per class.
outputs = layers.Dense(num_classes, activation='softmax')(features)

model = models.Model(inputs, outputs)

# Sparse categorical cross-entropy is paired with the softmax output above.
_optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)
model.compile(
    optimizer=_optimizer,
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

print("Model architecture:")
model.summary()

# ========== CALLBACKS ==========
# Stop after 25 epochs without a better validation accuracy and restore the
# weights from the best epoch.
_early_stopping = EarlyStopping(
    monitor='val_accuracy',
    mode='max',
    patience=25,
    restore_best_weights=True,
    verbose=1,
)

# Halve the learning rate after 8 epochs without a lower validation loss,
# never going below 1e-8.
_lr_on_plateau = ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.5,
    patience=8,
    min_lr=1e-8,
    verbose=1,
)

# Write the model to disk only when validation accuracy improves.
_checkpoint = ModelCheckpoint(
    'best_diabetic_retinopathy_model.keras',
    monitor='val_accuracy',
    mode='max',
    save_best_only=True,
    verbose=1,
)

callbacks = [_early_stopping, _lr_on_plateau, _checkpoint]

# ========== TRAINING ==========
print("\n" + "=" * 70, "TRAINING MODEL", "=" * 70 + "\n", sep="\n")

# Train for at most EPOCHS epochs, validating on val_ds after each one; the
# callbacks and the per-class loss weights are the ones prepared above.
history = model.fit(
    train_ds,
    epochs=EPOCHS,
    validation_data=val_ds,
    class_weight=class_weight_dict,
    callbacks=callbacks,
    verbose=1,
)

# ========== EVALUATION ==========
print("\n" + "="*70)
print("MODEL EVALUATION")
print("="*70 + "\n")

# The model was compiled with a single metric, so evaluate() yields
# the loss followed by the accuracy.
test_loss, test_accuracy = model.evaluate(test_ds, verbose=0)
# The check marks were mis-encoded (UTF-8 bytes shown as another code page).
print(f"✓ Test Accuracy: {test_accuracy*100:.2f}%")
print(f"✓ Test Loss: {test_loss:.4f}")

# Get predictions
print("\nGenerating predictions...")
y_true = []        # ground-truth class id per test image
y_pred = []        # arg-max class id per test image
y_pred_proba = []  # full softmax vector per test image

# Labels and predictions are collected in the same loop so they stay aligned.
# The loop variable is not called `labels`, which would overwrite the
# training-label collection built earlier in the script.
for images, batch_labels in test_ds:
    # predict_on_batch runs one forward pass on the given batch; predict()
    # sets up its own batching loop on every call, which is wasted work here.
    predictions = model.predict_on_batch(images)
    y_true.extend(batch_labels.numpy())
    y_pred.extend(np.argmax(predictions, axis=1))
    y_pred_proba.extend(predictions)

# ========== DETAILED METRICS ==========
from sklearn.metrics import classification_report, confusion_matrix, precision_recall_fscore_support

# Explicit list of every class id. Without it scikit-learn infers the label
# set from y_true/y_pred, so a class absent from the (small) test split makes
# the per-class arrays shorter than `class_names`: classification_report then
# rejects `target_names`, and the table below indexes past the array end.
all_class_ids = list(range(len(class_names)))

print("\n" + "="*70)
print("CLASSIFICATION REPORT")
print("="*70 + "\n")
print(classification_report(
    y_true,
    y_pred,
    labels=all_class_ids,
    target_names=class_names,
    digits=3,
    zero_division=0,  # undefined precision/recall is reported as 0, without a warning
))

# Calculate additional metrics
# average=None returns one value per entry of `labels`, in that order, so
# index i lines up with class_names[i].
precision, recall, f1, support = precision_recall_fscore_support(
    y_true,
    y_pred,
    labels=all_class_ids,
    average=None,
    zero_division=0,
)

print("\n" + "="*70)
print("DETAILED PER-CLASS METRICS")
print("="*70 + "\n")
print(f"{'Class':<20} {'Precision':<12} {'Recall':<12} {'F1-Score':<12} {'Support':<10}")
print("-" * 66)
for i, name in enumerate(class_names):
    print(f"{name:<20} {precision[i]:>7.3f}      {recall[i]:>7.3f}      {f1[i]:>7.3f}      {support[i]:>5}")

print("\n" + "-" * 66)
print(f"{'Accuracy':<20} {test_accuracy:>7.3f}")
# Macro F1 is the plain mean over all classes (a class with no test samples
# contributes 0); weighted F1 weights each class by its test support.
print(f"{'Macro Avg F1':<20} {np.mean(f1):>7.3f}")
print(f"{'Weighted Avg F1':<20} {np.average(f1, weights=support):>7.3f}")

# ========== CONFUSION MATRIX ==========
print("\n" + "="*70)
print("CONFUSION MATRIX")
print("="*70 + "\n")

# Passing every class id fixes the matrix at len(class_names) x len(class_names).
# Otherwise a class that never appears in y_true or y_pred is dropped, and the
# rows/columns no longer line up with the `class_names` tick labels.
cm = confusion_matrix(y_true, y_pred, labels=list(range(len(class_names))))

# Rows are true classes, columns are predicted classes; cells show raw counts.
plt.figure(figsize=(12, 10))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_names, yticklabels=class_names,
            cbar_kws={'label': 'Number of Predictions'},
            annot_kws={'size': 14, 'weight': 'bold'})
plt.title('Confusion Matrix - Diabetic Retinopathy Classification',
          fontsize=16, fontweight='bold', pad=20)
plt.ylabel('True Label', fontsize=14, fontweight='bold')
plt.xlabel('Predicted Label', fontsize=14, fontweight='bold')
plt.xticks(rotation=45, ha='right')
plt.yticks(rotation=0)
plt.tight_layout()
plt.show()

# ========== TRAINING HISTORY ==========
print("\n" + "=" * 70, "TRAINING HISTORY", "=" * 70 + "\n", sep="\n")

# Per-epoch curves recorded by model.fit.
_curves = history.history
acc, val_acc = _curves['accuracy'], _curves['val_accuracy']
loss, val_loss = _curves['loss'], _curves['val_loss']

# Peak validation accuracy and the 1-based epoch where it first occurred.
best_val_acc = max(val_acc)
best_epoch = np.argmax(val_acc) + 1

print(f"Training completed in {len(acc)} epochs")
print(f"Best validation accuracy: {best_val_acc*100:.2f}% (Epoch {best_epoch})")
print(f"Final training accuracy: {acc[-1]*100:.2f}%")
print(f"Final validation accuracy: {val_acc[-1]*100:.2f}%")

# Plot training curves
fig, axes = plt.subplots(2, 2, figsize=(16, 12))

# Accuracy and loss panels share one layout; only the plotted series, the
# metric name and the legend corner differ, so both are drawn by one loop.
curve_panels = [
    (axes[0, 0], acc, val_acc, 'Accuracy', 'lower right'),
    (axes[0, 1], loss, val_loss, 'Loss', 'upper right'),
]
for ax, train_curve, val_curve, metric, legend_loc in curve_panels:
    ax.plot(train_curve, 'b-', linewidth=2, label=f'Training {metric}')
    ax.plot(val_curve, 'r-', linewidth=2, label=f'Validation {metric}')
    # best_epoch is 1-based while the x axis is 0-based, hence the -1.
    ax.axvline(x=best_epoch-1, color='green', linestyle='--',
               linewidth=2, alpha=0.7, label=f'Best Epoch ({best_epoch})')
    ax.set_title(f'Model {metric} Over Time', fontsize=14, fontweight='bold')
    ax.set_xlabel('Epoch', fontsize=12)
    ax.set_ylabel(metric, fontsize=12)
    ax.legend(loc=legend_loc, fontsize=11)
    ax.grid(True, alpha=0.3)

# Overfitting analysis
# Positive gap = training accuracy above validation accuracy at that epoch.
gap = np.array(acc) - np.array(val_acc)
axes[1, 0].plot(gap, 'purple', linewidth=2)
axes[1, 0].axhline(y=0, color='black', linestyle='-', alpha=0.3)
axes[1, 0].fill_between(range(len(gap)), gap, 0, alpha=0.3, color='purple')
axes[1, 0].set_title('Overfitting Analysis (Train-Val Gap)', fontsize=14, fontweight='bold')
axes[1, 0].set_xlabel('Epoch', fontsize=12)
axes[1, 0].set_ylabel('Accuracy Difference', fontsize=12)
axes[1, 0].grid(True, alpha=0.3)

# Per-class accuracy
# Correct predictions on the diagonal divided by the row total of the
# confusion matrix; a class with an empty row is reported as 0.
class_accuracy = []
for class_idx in range(len(class_names)):
    row_total = cm[class_idx].sum()
    class_accuracy.append(cm[class_idx, class_idx] / row_total if row_total > 0 else 0)

# Traffic-light colouring: > 0.7 green, > 0.5 orange, otherwise red.
# The loop variables below are not named `acc`, which would overwrite the
# training-accuracy curve with a single float.
colors = ['green' if class_acc > 0.7 else 'orange' if class_acc > 0.5 else 'red'
          for class_acc in class_accuracy]
axes[1, 1].bar(range(len(class_names)), class_accuracy, color=colors, alpha=0.7)
axes[1, 1].set_title('Per-Class Accuracy', fontsize=14, fontweight='bold')
axes[1, 1].set_xlabel('Class', fontsize=12)
axes[1, 1].set_ylabel('Accuracy', fontsize=12)
axes[1, 1].set_xticks(range(len(class_names)))
axes[1, 1].set_xticklabels(class_names, rotation=45, ha='right')
axes[1, 1].set_ylim([0, 1])
axes[1, 1].grid(True, alpha=0.3, axis='y')

# Percentage label just above each bar.
for bar_idx, class_acc in enumerate(class_accuracy):
    axes[1, 1].text(bar_idx, class_acc + 0.02, f'{class_acc*100:.1f}%',
                    ha='center', fontsize=10, fontweight='bold')

plt.tight_layout()
plt.show()

# ========== SAMPLE PREDICTIONS ==========
print("\n" + "="*70)
print("SAMPLE PREDICTIONS VISUALIZATION")
print("="*70 + "\n")

# 3 x 5 grid: up to five images from each of the first three test batches.
fig, axes = plt.subplots(3, 5, figsize=(20, 12))
axes = axes.ravel()

sample_idx = 0
# 3 batches x 5 images can never exceed the 15 grid cells, so no extra bound
# check is needed inside the loops.
for images, batch_labels in test_ds.take(3):
    # One forward pass on exactly this batch.
    predictions = model.predict_on_batch(images)
    batch_images = images.numpy()
    batch_true = batch_labels.numpy()

    for image, true_class, probs in zip(batch_images[:5], batch_true[:5], predictions[:5]):
        ax = axes[sample_idx]
        ax.imshow(image.astype("uint8"))

        pred_class = np.argmax(probs)
        confidence = probs[pred_class] * 100
        is_correct = pred_class == true_class

        title = f"True: {class_names[true_class]}\n"
        title += f"Predicted: {class_names[pred_class]}\n"
        title += f"Confidence: {confidence:.1f}%"

        if not is_correct:
            # On a miss, also show the second most probable class.
            runner_up = np.argsort(probs)[-2]
            title += f"\n(2nd: {class_names[runner_up]}, {probs[runner_up]*100:.1f}%)"

        ax.set_title(title, color='green' if is_correct else 'red',
                     fontsize=10, fontweight='bold')
        ax.axis('off')
        sample_idx += 1

# Hide grid cells that received no image (fewer than 15 samples available);
# otherwise they are drawn as empty framed axes.
for unused_ax in axes[sample_idx:]:
    unused_ax.axis('off')

plt.tight_layout()
plt.show()



# ========== FINAL SUMMARY ==========
# The icons in these messages were mis-encoded (UTF-8 bytes shown as another
# code page); they are restored to the intended characters.
best_class_idx = np.argmax(recall)   # class with the highest test recall
worst_class_idx = np.argmin(recall)  # class with the lowest test recall

print("\n" + "="*70)
print("FINAL SUMMARY")
print("="*70)
print(f"\n🎯 Test Accuracy: {test_accuracy*100:.2f}%")
print(f"📊 Best Validation Accuracy: {best_val_acc*100:.2f}%")
print(f"📈 Macro F1-Score: {np.mean(f1)*100:.2f}%")
print(f"⚖️  Weighted F1-Score: {np.average(f1, weights=support)*100:.2f}%")
print(f"\n🏆 Best Performing Class: {class_names[best_class_idx]} ({recall[best_class_idx]*100:.1f}% recall)")
print(f"⚠️  Weakest Class: {class_names[worst_class_idx]} ({recall[worst_class_idx]*100:.1f}% recall)")
print(f"\n📁 Model saved and ready for deployment!")
print("="*70)