In [None]:
# Step 1: Imports
import os
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import EfficientNetV2B3
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Dropout
from tensorflow.keras.optimizers import Adam
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns

# Step 2: Paths
base_dir = 'aptos-augmented-images'
train_dir = os.path.join(base_dir, 'train/retina')
valid_dir = os.path.join(base_dir, 'valid/retina')
test_dir = os.path.join(base_dir, 'test')

# Step 3: Parameters
IMG_SIZE = (300, 300)
BATCH_SIZE = 32
NUM_CLASSES = 5
EPOCHS = 10

# Step 4: Optimized Data Pipeline using tf.data
AUTOTUNE = tf.data.AUTOTUNE


def _load_split(directory, batch_size, shuffle=True):
    """Build a prefetched image dataset from a class-per-subfolder directory.

    Images are resized to IMG_SIZE and labels are one-hot encoded
    (label_mode='categorical').

    Pixel values are deliberately left in the raw [0, 255] range:
    EfficientNetV2 ships with its own preprocessing layer
    (include_preprocessing=True is the default), so rescaling here as well
    would normalise the input twice and feed the pretrained weights data
    they were never trained on.

    Args:
        directory: Root folder containing one sub-folder per class.
        batch_size: Number of images per batch.
        shuffle: Whether to shuffle the files; must be False when the
            iteration order has to stay stable (e.g. for evaluation).

    Returns:
        A tf.data.Dataset yielding (images, one_hot_labels) batches.
    """
    dataset = tf.keras.utils.image_dataset_from_directory(
        directory,
        image_size=IMG_SIZE,
        batch_size=batch_size,
        label_mode='categorical',
        shuffle=shuffle,
    )
    # Overlap input loading with model execution.
    return dataset.prefetch(buffer_size=AUTOTUNE)


train_ds = _load_split(train_dir, BATCH_SIZE)
valid_ds = _load_split(valid_dir, BATCH_SIZE)
# Keep the test set unshuffled so predictions and labels line up by position.
test_ds = _load_split(test_dir, batch_size=1, shuffle=False)

# Step 5: Build the Model
# The backbone's input shape is derived from IMG_SIZE so that the network and
# the input pipeline cannot drift apart when the image size is changed.
base_model = EfficientNetV2B3(
    include_top=False,
    weights='imagenet',
    input_shape=IMG_SIZE + (3,),
)

# Freeze the pretrained backbone so only the new classification head is
# trained during the first phase.
for layer in base_model.layers:
    layer.trainable = False

# Classification head: pooled features -> dropout -> softmax over the classes.
x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dropout(0.3)(x)
output = Dense(NUM_CLASSES, activation='softmax')(x)

model = Model(inputs=base_model.input, outputs=output)

model.compile(optimizer=Adam(1e-4), loss='categorical_crossentropy', metrics=['accuracy'])

model.summary()

# Step 6: Train the Model
# Phase 1: backbone frozen, only the classification head learns.
history = model.fit(
    train_ds,
    validation_data=valid_ds,
    epochs=EPOCHS
)


# Step 7: Unfreeze and Fine-tune
# BatchNormalization layers stay frozen: letting them update their moving
# statistics during fine-tuning tends to wreck the pretrained features.
for layer in base_model.layers:
    if not isinstance(layer, tf.keras.layers.BatchNormalization):
        layer.trainable = True

# Re-compile so the trainable changes take effect; use a lower learning rate
# to avoid large updates to the pretrained weights.
model.compile(optimizer=Adam(1e-5), loss='categorical_crossentropy', metrics=['accuracy'])

fine_tune_epochs = 10
# history.epoch holds 0-based epoch indices, so its length (not its last
# element) is the index of the next epoch to run. Using epoch[-1] would
# repeat one epoch number and train for fine_tune_epochs + 1 epochs.
initial_epoch = len(history.epoch)
total_epochs = initial_epoch + fine_tune_epochs

history_fine = model.fit(
    train_ds,
    validation_data=valid_ds,
    epochs=total_epochs,
    initial_epoch=initial_epoch
)


# Step 8: Plot Curves
def plot_curves(hist):
    """Show accuracy and loss curves side by side for one training run.

    Args:
        hist: Object with a ``history`` mapping that contains the keys
            'accuracy', 'val_accuracy', 'loss' and 'val_loss' (for example
            the value returned by ``model.fit``).
    """
    panels = (
        ("Accuracy", (("accuracy", 'Train Acc'), ("val_accuracy", 'Val Acc'))),
        ("Loss", (("loss", 'Train Loss'), ("val_loss", 'Val Loss'))),
    )

    # Look every series up before any figure is created, so a missing metric
    # fails early instead of leaving a half-drawn figure behind.
    series = {
        key: hist.history[key]
        for _, curves in panels
        for key, _ in curves
    }

    plt.figure(figsize=(12, 5))
    for position, (title, curves) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        for key, label in curves:
            plt.plot(series[key], label=label)
        plt.legend()
        plt.title(title)

    plt.show()

plot_curves(history_fine)

# Step 9: Evaluate on Test Set
Y_pred = model.predict(test_ds)
y_pred = np.argmax(Y_pred, axis=1)

# Extract true labels from test_ds (the test set is not shuffled, so a second
# pass yields the samples in the same order as the predictions above).
true_labels = np.concatenate([labels for _, labels in test_ds], axis=0)
true_classes = np.argmax(true_labels, axis=1)

# The mapped/prefetched datasets no longer expose `.class_names`, so the names
# are rebuilt from the test directory. Sorting the non-hidden sub-folders
# mirrors the alphanumeric order image_dataset_from_directory uses when it
# infers label indices.
class_names = sorted(
    entry for entry in os.listdir(test_dir)
    if os.path.isdir(os.path.join(test_dir, entry)) and not entry.startswith('.')
)
# Pin the label set so the report and the matrix always cover every class,
# even one that never occurs in the labels or the predictions.
label_ids = np.arange(len(class_names))

print("Classification Report:")
print(classification_report(true_classes, y_pred,
                            labels=label_ids,
                            target_names=class_names))

# Confusion Matrix
cm = confusion_matrix(true_classes, y_pred, labels=label_ids)

plt.figure(figsize=(6,6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_names,
            yticklabels=class_names)
plt.title('Confusion Matrix')
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
plt.show()


# Step 10: Save Model
model.save('efficientnetv2b3_dr_grading.h5')


Found 10000 files belonging to 5 classes.
Found 10000 files belonging to 5 classes.
Found 9433 files belonging to 5 classes.


Epoch 1/10
[1m 13/313[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m30:13[0m 6s/step - accuracy: 0.1972 - loss: 1.6564