In [3]:
import tensorflow as tf
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.models import Model
from sklearn.metrics import confusion_matrix, classification_report
import matplotlib.pyplot as plt
import numpy as np
import os


# Set random seed for reproducibility
tf.random.set_seed(42)

# Define constants
IMG_SIZE = (128, 128)  # As per requirement, images should be 128x128
BATCH_SIZE = 32  # Smaller batch size for better generalization
EPOCHS = 50  # Maximum number of epochs, early stopping will prevent overfitting
LEARNING_RATE = 0.0001  # Small learning rate for fine-tuning

# Define paths to your data
train_dir = '/content/train/'
test_dir = '/content/test/'

# Data augmentation and preprocessing for training data
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    validation_split=0.2  # 20% of data will be used for validation
)

# Only rescaling for test data
test_datagen = ImageDataGenerator(rescale=1./255)

# Load and prepare the training data
train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='binary',  # Binary classification for DME vs AMD
    subset='training',
    color_mode='rgb'  # Converts grayscale to RGB
)

# Load and prepare the validation data
validation_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='binary',  # Binary classification for DME vs AMD
    subset='validation',
    color_mode='rgb'  # Converts grayscale to RGB
)

# Load and prepare the test data
test_generator = test_datagen.flow_from_directory(
    test_dir,
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='binary',  # Binary classification for DME vs AMD
    color_mode='rgb'  # Converts grayscale to RGB
)

# Load pre-trained ResNet50 model
base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(128, 128, 3))

# Freeze the base model layers
for layer in base_model.layers:
    layer.trainable = False

# Add custom layers on top of ResNet50
x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dense(1024, activation='relu')(x)
output = Dense(1, activation='sigmoid')(x)  # Single neuron with sigmoid for binary classification

model = Model(inputs=base_model.input, outputs=output)

# Compile the model
optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)
model.compile(optimizer=optimizer, loss='binary_crossentropy', metrics=['accuracy'])

# Set up callbacks for early stopping and model checkpointing
early_stopping = EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)
model_checkpoint = ModelCheckpoint('best_model.keras', save_best_only=True, monitor='val_accuracy')

# Train the model
history = model.fit(
    train_generator,
    steps_per_epoch=train_generator.samples // BATCH_SIZE,
    validation_data=validation_generator,
    validation_steps=validation_generator.samples // BATCH_SIZE,
    epochs=EPOCHS,
    callbacks=[early_stopping, model_checkpoint]
)

# Plot training and validation loss
plt.figure(figsize=(10, 5))
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title('Model Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.savefig('loss_plot.png')
plt.close()

# Evaluate the model on the test set
test_loss, test_accuracy = model.evaluate(test_generator)
print(f"Test accuracy: {test_accuracy:.4f}")

# Generate predictions for the test set
test_generator.reset()
y_pred = model.predict(test_generator, steps=len(test_generator))
y_pred = (y_pred > 0.5).astype(int)
y_true = test_generator.classes

# Plot confusion matrix
cm = confusion_matrix(y_true, y_pred)
plt.figure(figsize=(8, 6))
plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
plt.title('Confusion Matrix')
plt.colorbar()
class_labels = ['DME', 'AMD']
plt.xticks([0, 1], class_labels)
plt.yticks([0, 1], class_labels)
plt.xlabel('Predicted Label')
plt.ylabel('True Label')

# Add text annotations to confusion matrix
thresh = cm.max() / 2.
for i in range(2):
    for j in range(2):
        plt.text(j, i, format(cm[i, j], 'd'),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

plt.tight_layout()
plt.savefig('confusion_matrix.png')
plt.close()

print("Training completed. Check 'loss_plot.png' and 'confusion_matrix.png' for visualizations.")

# Print class distribution in the test set
print("\nClass distribution in the test set:")
for class_name, class_index in test_generator.class_indices.items():
    class_count = np.sum(test_generator.classes == class_index)
    print(f"{class_name}: {class_count}")

# Print classification report
print("\nClassification Report:")
print(classification_report(y_true, y_pred, target_names=class_labels))

# Explanation of hyperparameter choices
print("\nHyperparameter Choices:")
print(f"Image Size: {IMG_SIZE} - As per requirement, images are scaled to 128x128")
print(f"Batch Size: {BATCH_SIZE} - A moderate batch size for better generalization")
print(f"Max Epochs: {EPOCHS} - Set high, but early stopping prevents overfitting")
print(f"Learning Rate: {LEARNING_RATE} - Small learning rate for fine-tuning pre-trained model")
print("Optimizer: Adam - Adaptive learning rate optimization algorithm, well-suited for transfer learning")

Found 10 images belonging to 3 classes.
Found 2 images belonging to 3 classes.
Found 12 images belonging to 3 classes.
Epoch 1/50


  self._warn_if_super_not_called()


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m17s[0m 17s/step - accuracy: 0.0000e+00 - loss: 0.9004 - val_accuracy: 0.0000e+00 - val_loss: 0.7850
Epoch 2/50
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 8s/step - accuracy: 0.0000e+00 - loss: 0.7357 - val_accuracy: 0.5000 - val_loss: 0.6061
Epoch 3/50
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 837ms/step - accuracy: 0.5000 - loss: 0.5975 - val_accuracy: 0.5000 - val_loss: 0.4658
Epoch 4/50
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 1s/step - accuracy: 0.5000 - loss: 0.4719 - val_accuracy: 0.5000 - val_loss: 0.3609
Epoch 5/50
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step - accuracy: 0.5000 - loss: 0.3241 - val_accuracy: 0.5000 - val_loss: 0.2225
Epoch 6/50
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step - accuracy: 0.5000 - loss: 0.1977 - val_accuracy: 0.5000 - val_loss: 0.0733
Epoch 7/50
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [6]:
! pip freeze > requirements.txt