In [1]:
# %pip install tensorflow
# %pip install --upgrade pip

In [2]:
import tensorflow as tf
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from tensorflow.keras.utils import image_dataset_from_directory
import pickle

## Data Preprocessing

In [3]:
### Training set
training_set = image_dataset_from_directory(
    "dataset/Dataset - train+val+test/train",
    labels="inferred", # label names inferred from dataset structure
    label_mode="categorical", # catergories are CNV, DME, DRUSSEN, NORMAL << chnage with new categories
    class_names=None,
    color_mode="rgb",
    batch_size=32,
    image_size=(224, 224), # changed to fit MobileNet model
    shuffle=True,
    seed=None,
    validation_split=None,
    subset=None,
    interpolation="bilinear",
    follow_links=False,
    crop_to_aspect_ratio=False,
    pad_to_aspect_ratio=False,
    data_format=None,
    verbose=True,
)

Found 76515 files belonging to 4 classes.


In [4]:
### validation set
validation_set = image_dataset_from_directory(
    "dataset/Dataset - train+val+test/val",
    labels="inferred", # label names inferred from dataset structure
    label_mode="categorical", # catergories are CNV, DME, DRUSSEN, NORMAL << chnage with new categories
    class_names=None,
    color_mode="rgb",
    batch_size=32,
    image_size=(224, 224), # changed to fit MobileNet model
    shuffle=True,
    seed=None,
    validation_split=None,
    subset=None,
    interpolation="bilinear",
    follow_links=False,
    crop_to_aspect_ratio=False,
    pad_to_aspect_ratio=False,
    data_format=None,
    verbose=True,
)

Found 21861 files belonging to 4 classes.


In [5]:
training_set

<_PrefetchDataset element_spec=(TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32, name=None), TensorSpec(shape=(None, 4), dtype=tf.float32, name=None))>

In [6]:
INPUT_SHAPE = (224,224,3)

## Model Training

In [7]:
mobnet = tf.keras.applications.MobileNetV3Large(
    input_shape=INPUT_SHAPE,
    alpha=1.0,
    minimalistic=False,
    include_top=True,
    weights="imagenet",
    input_tensor=None,
    classes=1000,
    pooling=None,
    dropout_rate=0.2,
    classifier_activation="softmax",
    include_preprocessing=True
    # name="MobileNetV3Large",
)

In [8]:
### Transfer learning
# mobnet.trainable = False # freeze pretrained layer, best if low training data

In [9]:
model = tf.keras.models.Sequential()

In [10]:
model.add(tf.keras.Input(shape=INPUT_SHAPE))

In [11]:
model.add(mobnet)

In [12]:
model.add(tf.keras.layers.Dense(units = 4, activation="softmax"))

In [13]:
metrics_list = ["accuracy", tf.keras.metrics.F1Score()] # monitor model performance

In [14]:
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001), loss="categorical_crossentropy", metrics=metrics_list)

In [15]:
model.summary()

In [16]:
training_history = model.fit(x=training_set, validation_data=validation_set, epochs=15)

Epoch 1/15
[1m2392/2392[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7112s[0m 3s/step - accuracy: 0.8796 - f1_score: 0.6702 - loss: 1.1314 - val_accuracy: 0.8933 - val_f1_score: 0.6881 - val_loss: 0.9641
Epoch 2/15
[1m2392/2392[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5779s[0m 2s/step - accuracy: 0.8959 - f1_score: 0.6929 - loss: 0.8256 - val_accuracy: 0.8957 - val_f1_score: 0.6924 - val_loss: 0.7147
Epoch 3/15
[1m2392/2392[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6213s[0m 3s/step - accuracy: 0.9443 - f1_score: 0.8980 - loss: 0.6237 - val_accuracy: 0.9517 - val_f1_score: 0.9133 - val_loss: 0.5528
Epoch 4/15
[1m2392/2392[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4554s[0m 2s/step - accuracy: 0.9583 - f1_score: 0.9284 - loss: 0.4816 - val_accuracy: 0.9516 - val_f1_score: 0.9188 - val_loss: 0.4354
Epoch 5/15
[1m2392/2392[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4283s[0m 2s/step - accuracy: 0.9631 - f1_score: 0.9368 - loss: 0.3748 - val_accuracy: 0.9611 - 

KeyboardInterrupt: 

In [None]:
#Save model
model.save("./Trained_Eye_disease_model.h5")
model.save("./Trained_Eye_disease_model.keras")



In [None]:
training_history.history

In [None]:
# Save History
with open("Training_history.pkl", "wb") as f:
    pickle.dump(training_history.history, f)