In [22]:
# ===========================================================
# ü©∫ OCT Retinal Disease Detector ‚Äî ResNet50 + Grad-CAM + XAI
# ‚úÖ Kaggle-Ready | Offline | GPU | Patient-Friendly
# ===========================================================

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.preprocessing import image_dataset_from_directory
from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay
from sklearn.utils.class_weight import compute_class_weight
import glob
import shutil
import cv2
import warnings
warnings.filterwarnings('ignore')

print("‚úÖ TensorFlow:", tf.__version__)
print("‚úÖ GPUs:", tf.config.list_physical_device('GPU'))

‚úÖ TensorFlow: 2.18.0
‚úÖ GPUs: []


In [24]:
# üîç Auto-detect dataset (handles 'OCT2017 /' with space)
oct_dirs = glob.glob("/kaggle/input/kermany2018/oct2017/OCT2017*")
if not oct_dirs:
    raise FileNotFoundError(
        "‚ùå Dataset not found! Please attach 'kermany2018':\n"
        "1. Click '+ Add data'\n2. Search 'kermany2018'\n3. Add it"
    )
BASE_DIR = oct_dirs[0]
print("‚úÖ Dataset:", BASE_DIR)

TRAIN_DIR = os.path.join(BASE_DIR, "train")
TEST_DIR = os.path.join(BASE_DIR, "test")
CLASSES = ['CNV', 'DME', 'DRUSEN', 'NORMAL']

# üí° Verify class sizes
print("\nüìä TRAIN set sizes:")
for cls in CLASSES:
    path = os.path.join(TRAIN_DIR, cls)
    count = len([f for f in os.listdir(path) if f.endswith(('.jpg', '.jpeg'))])
    print(f"  {cls:>6}: {count} images")

‚úÖ Dataset: /kaggle/input/kermany2018/oct2017/OCT2017 

üìä TRAIN set sizes:
     CNV: 37205 images
     DME: 11348 images
  DRUSEN: 8616 images
  NORMAL: 26315 images


In [25]:
# üß™ Create val_split (10% of train ‚Äî original 'val' is too small)
VAL_SPLIT_DIR = "/kaggle/working/val_split"
os.makedirs(VAL_SPLIT_DIR, exist_ok=True)

print("\nüîß Creating validation set (10% per class)...")
for cls in CLASSES:
    src = os.path.join(TRAIN_DIR, cls)
    dst = os.path.join(VAL_SPLIT_DIR, cls)
    os.makedirs(dst, exist_ok=True)
    
    imgs = [f for f in os.listdir(src) if f.endswith(('.jpg', '.jpeg'))]
    n_val = int(0.1 * len(imgs))
    np.random.seed(42)
    val_imgs = np.random.choice(imgs, n_val, replace=False)
    
    for img in val_imgs:
        shutil.copy2(os.path.join(src, img), os.path.join(dst, img))
    print(f"  {cls}: {len(imgs)} ‚Üí {n_val} val")


üîß Creating validation set (10% per class)...
  CNV: 37205 ‚Üí 3720 val
  DME: 11348 ‚Üí 1134 val
  DRUSEN: 8616 ‚Üí 861 val
  NORMAL: 26315 ‚Üí 2631 val


In [26]:
# üñºÔ∏è Load data with augmentation
IMG_SIZE = (224, 224)
BATCH_SIZE = 32

train_aug = keras.Sequential([
    layers.Rescaling(1./255),
    layers.RandomFlip("horizontal"),
    layers.RandomRotation(0.1),
    layers.RandomZoom(0.1),
])

val_test_aug = keras.Sequential([layers.Rescaling(1./255)])

train_ds = image_dataset_from_directory(
    TRAIN_DIR, image_size=IMG_SIZE, batch_size=BATCH_SIZE,
    label_mode="categorical", class_names=CLASSES, seed=42
).map(lambda x, y: (train_aug(x), y)).prefetch(tf.data.AUTOTUNE)

val_ds = image_dataset_from_directory(
    VAL_SPLIT_DIR, image_size=IMG_SIZE, batch_size=BATCH_SIZE,
    label_mode="categorical", class_names=CLASSES, seed=42
).map(lambda x, y: (val_test_aug(x), y)).prefetch(tf.data.AUTOTUNE)

test_ds = image_dataset_from_directory(
    TEST_DIR, image_size=IMG_SIZE, batch_size=BATCH_SIZE,
    label_mode="categorical", class_names=CLASSES, seed=42
).map(lambda x, y: (val_test_aug(x), y)).prefetch(tf.data.AUTOTUNE)

print(f"\n‚úÖ Datasets ready! Train: {len(train_ds)} batches, Val: {len(val_ds)}, Test: {len(test_ds)}")

Found 83484 files belonging to 4 classes.
Found 8346 files belonging to 4 classes.
Found 968 files belonging to 4 classes.

‚úÖ Datasets ready! Train: 2609 batches, Val: 261, Test: 31


In [27]:
# ‚öñÔ∏è Handle class imbalance
train_labels = []
for _, labels in train_ds.take(50):  # sample 50 batches (~1.6K images)
    train_labels.append(tf.argmax(labels, axis=1))
train_labels = tf.concat(train_labels, axis=0).numpy()

class_weights = compute_class_weight('balanced', classes=np.unique(train_labels), y=train_labels)
class_weight_dict = dict(enumerate(class_weights))
print("\n‚öñÔ∏è Class weights:")
for i, cls in enumerate(CLASSES):
    print(f"  {cls}: {class_weight_dict[i]:.2f}")


‚öñÔ∏è Class weights:
  CNV: 0.54
  DME: 1.96
  DRUSEN: 2.68
  NORMAL: 0.79


In [28]:
# üß† Build model ‚Äî ResNet50 (pre-cached in Kaggle!)
print("\nüß† Building ResNet50 model...")

base_model = keras.applications.ResNet50(
    include_top=False,
    weights='imagenet',  # ‚úÖ Pre-cached ‚Äî no internet needed!
    input_shape=(*IMG_SIZE, 3)
)
base_model.trainable = False

inputs = keras.Input(shape=(*IMG_SIZE, 3))
x = base_model(inputs, training=False)
x = layers.GlobalAveragePooling2D()(x)
x = layers.Dropout(0.5)(x)
x = layers.Dense(128, activation='relu')(x)
x = layers.Dropout(0.5)(x)
outputs = layers.Dense(len(CLASSES), activation='softmax')(x)

model = keras.Model(inputs, outputs)

model.compile(
    optimizer=keras.optimizers.Adam(1e-4),
    loss='categorical_crossentropy',
    metrics=['accuracy']
)

print("‚úÖ Model built!")
model.summary()


üß† Building ResNet50 model...
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5


Exception: URL fetch failure on https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5: None -- [Errno -3] Temporary failure in name resolution

In [None]:
# üèÉ Train with callbacks
print("\nüöÄ Training (15 epochs)...")
callbacks = [
    keras.callbacks.EarlyStopping(patience=5, restore_best_weights=True, verbose=1),
    keras.callbacks.ReduceLROnPlateau(factor=0.5, patience=3, min_lr=1e-7),
    keras.callbacks.ModelCheckpoint('/kaggle/working/best_model.h5', save_best_only=True)
]

history = model.fit(
    train_ds,
    epochs=15,
    validation_data=val_ds,
    class_weight=class_weight_dict,
    callbacks=callbacks,
    verbose=1
)

In [None]:
# üìà Plot history
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
epochs = range(1, len(acc)+1)

plt.figure(figsize=(12, 4))
plt.subplot(1,2,1)
plt.plot(epochs, acc, 'b-', label='Train'); plt.plot(epochs, val_acc, 'r--', label='Val')
plt.title('Accuracy'); plt.legend()
plt.subplot(1,2,2)
plt.plot(epochs, loss, 'b-', label='Train'); plt.plot(epochs, val_loss, 'r--', label='Val')
plt.title('Loss'); plt.legend()
plt.show()

In [None]:
# üß™ Test evaluation
print("\nüß™ Test Evaluation:")
test_loss, test_acc = model.evaluate(test_ds, verbose=0)
print(f"‚úÖ Test Accuracy: {test_acc:.4f} ({test_acc*100:.2f}%)")

# Get predictions
y_true, y_pred = [], []
for images, labels in test_ds:
    preds = model.predict(images, verbose=0)
    y_true.append(tf.argmax(labels, axis=1))
    y_pred.append(tf.argmax(preds, axis=1))
y_true = np.concatenate(y_true); y_pred = np.concatenate(y_pred)

# Confusion matrix
cm = confusion_matrix(y_true, y_pred)
disp = ConfusionMatrixDisplay(cm, display_labels=CLASSES)
disp.plot(cmap='Blues')
plt.title("Confusion Matrix (Test Set)")
plt.show()

# Report
print("\nüìù Classification Report:")
print(classification_report(y_true, y_pred, target_names=CLASSES))

In [None]:
# üîç Grad-CAM for ResNet50 (last conv layer: 'conv5_block3_out')
def make_gradcam_heatmap(img_array, model, pred_index=None):
    grad_model = keras.models.Model(
        [model.inputs], 
        [model.get_layer('conv5_block3_out').output, model.output]
    )
    with tf.GradientTape() as tape:
        conv_outputs, predictions = grad_model(img_array)
        if pred_index is None:
            pred_index = tf.argmax(predictions[0])
        class_channel = predictions[:, pred_index]
    grads = tape.gradient(class_channel, conv_outputs)
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
    conv_outputs = conv_outputs[0]
    heatmap = conv_outputs @ pooled_grads[..., tf.newaxis]
    heatmap = tf.maximum(heatmap, 0) / tf.math.reduce_max(heatmap)
    return heatmap.numpy()

# Run on 1 test image
for images, labels in test_ds.take(1):
    img = images[0:1]
    true_label = tf.argmax(labels[0]).numpy()
    break

preds = model.predict(img)
pred_idx = np.argmax(preds[0])
heatmap = make_gradcam_heatmap(img, model, pred_index=pred_idx)

# Create overlay
img_rgb = (img[0].numpy() * 255).astype(np.uint8)
heatmap_resized = cv2.resize(heatmap, (img_rgb.shape[1], img_rgb.shape[0]))
heatmap_jet = cv2.applyColorMap(np.uint8(255 * heatmap_resized), cv2.COLORMAP_JET)
overlay = cv2.addWeighted(img_rgb, 0.6, heatmap_jet, 0.4, 0)

# Display
plt.figure(figsize=(15, 4))
plt.subplot(1, 3, 1)
plt.imshow(img[0]); plt.title("OCT Scan"); plt.axis('off')
plt.subplot(1, 3, 2)
plt.imshow(heatmap_resized, cmap='jet'); plt.title("Grad-CAM Heatmap"); plt.axis('off')
plt.subplot(1, 3, 3)
plt.imshow(cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB)); plt.title("Overlay"); plt.axis('off')
plt.suptitle(f"Prediction: {CLASSES[pred_idx]} (True: {CLASSES[true_label]})", fontsize=14)
plt.show()

In [None]:
# üåü Patient Education Panel
from IPython.display import HTML

html_explainer = """
<div style="background:#e8f4fc; padding:20px; border-radius:10px; font-family:Arial, sans-serif; line-height:1.6; max-width:900px;">
<h2>üëÅÔ∏è What Does This Mean for You?</h2>
<p><strong>We show you exactly where the AI "looks"</strong> ‚Äî so you and your doctor can trust the result.</p>

<h3>üîç Understanding the 3 Images Above</h3>
<ul>
<li><b>Left:</b> Your actual OCT scan ‚Äî a cross-section of your retina.</li>
<li><b>Middle:</b> Heatmap ‚Äî <span style="color:red">red/yellow</span> = areas strongly linked to disease.</li>
<li><b>Right:</b> Combined view ‚Äî see <em>where</em> concerns are located.</li>
</ul>

<h3>üí° Clinical Guide</h3>
<table border="1" cellpadding="10" style="border-collapse:collapse; width:100%; text-align:left;">
<tr bgcolor="#d1e7ff"><th>AI Prediction</th><th>What Red Means</th><th>Your Next Step</th></tr>
<tr><td><b>CNV</b><br><small>(Wet AMD)</small></td><td>Leaky blood vessels under retina</td><td>Anti-VEGF injection often needed to prevent vision loss</td></tr>
<tr><td><b>DME</b><br><small>(Diabetic Swelling)</small></td><td>Fluid pockets in central retina</td><td>Control blood sugar + possible laser/injection</td></tr>
<tr><td><b>DRUSEN</b><br><small>(Early AMD)</small></td><td>Aging deposits under retina</td><td>Annual monitoring + healthy lifestyle (no smoking, leafy greens)</td></tr>
<tr><td><b>NORMAL</b></td><td>No concerning changes</td><td>Continue routine eye exams ‚úÖ</td></tr>
</table>

<blockquote style="background:#fff9db; padding:15px; border-left:4px solid #ffc107;">
<i>"This tool supports ‚Äî but does not replace ‚Äî your eye doctor‚Äôs expert judgment. Always discuss results with your ophthalmologist."</i>
</blockquote>
</div>
"""

display(HTML(html_explainer))

In [None]:
# üíæ Save final model
model.save("/kaggle/working/oct_resnet50_final.h5")
print("\n‚úÖ Model saved to: /kaggle/working/oct_resnet50_final.h5")
print("üéâ Project complete! Your explainable OCT detector is ready.")