In [2]:
# %pip install tensorflow
# %pip install --upgrade pip

In [None]:
import tensorflow as tf
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from tensorflow.keras.utils import image_dataset_from_directory
import pickle

## Data Preprocessing

In [4]:
### Training set
training_set = image_dataset_from_directory(
    "dataset/Dataset - train+val+test/train",
    labels="inferred", # label names inferred from dataset structure
    label_mode="categorical", # catergories are CNV, DME, DRUSSEN, NORMAL << chnage with new categories
    class_names=None,
    color_mode="rgb",
    batch_size=32,
    image_size=(224, 224), # changed to fit MobileNet model
    shuffle=True,
    seed=None,
    validation_split=None,
    subset=None,
    interpolation="bilinear",
    follow_links=False,
    crop_to_aspect_ratio=False,
    pad_to_aspect_ratio=False,
    data_format=None,
    verbose=True,
)

Found 76515 files belonging to 4 classes.


In [5]:
### validation set
validation_set = image_dataset_from_directory(
    "dataset/Dataset - train+val+test/val",
    labels="inferred", # label names inferred from dataset structure
    label_mode="categorical", # catergories are CNV, DME, DRUSSEN, NORMAL << chnage with new categories
    class_names=None,
    color_mode="rgb",
    batch_size=32,
    image_size=(224, 224), # changed to fit MobileNet model
    shuffle=True,
    seed=None,
    validation_split=None,
    subset=None,
    interpolation="bilinear",
    follow_links=False,
    crop_to_aspect_ratio=False,
    pad_to_aspect_ratio=False,
    data_format=None,
    verbose=True,
)

Found 21861 files belonging to 4 classes.


In [6]:
training_set

<_PrefetchDataset element_spec=(TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32, name=None), TensorSpec(shape=(None, 4), dtype=tf.float32, name=None))>

In [7]:
INPUT_SHAPE = (224,224,3)

## Model Training

In [8]:
mobnet = tf.keras.applications.MobileNetV3Large(
    input_shape=INPUT_SHAPE,
    alpha=1.0,
    minimalistic=False,
    include_top=True,
    weights="imagenet",
    input_tensor=None,
    classes=1000,
    pooling=None,
    dropout_rate=0.2,
    classifier_activation="softmax",
    include_preprocessing=True
    # name="MobileNetV3Large",
)

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v3/weights_mobilenet_v3_large_224_1.0_float.h5
[1m22661472/22661472[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 0us/step


In [None]:
### Transfer learning
# mobnet.trainable = False # freeze pretrained layer, best if low training data

In [9]:
model = tf.keras.models.Sequential()

In [None]:
model.add(tf.keras.Input(shape=INPUT_SHAPE))

In [10]:
model.add(mobnet)

In [12]:
model.add(tf.keras.layers.Dense(units = 4, activation="softmax"))

In [13]:
metrics_list = ["accuracy", tf.keras.metrics.F1Score()] # monitor model performance

In [14]:
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001), loss="categorical_crossentropy", metrics=metrics_list)

In [15]:
model.summary()

In [None]:
training_history = model.fit(x=training_set, validation_data=validation_set, epochs=15)

In [16]:
#Save model
model.save("./Trained_Eye_disease_model.h5")
model.save("./Trained_Eye_disease_model.keras")



In [None]:
training_history.history

In [None]:
# Save History
with open("Training_history.pkl", "wb") as f:
    pickle.dump(training_history.history, f)