In [1]:
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import numpy as np
import matplotlib.pyplot as plt
import os
from tensorflow.keras.applications import EfficientNetB0
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense, Dropout
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from tensorflow.keras.applications.efficientnet import preprocess_input

## Paths

In [2]:
# Dataset root; each split (train / valid / test) lives in a sub-folder of it.
data_root = '../data/AI-CA-Data'
train_dir, valid_dir, test_dir = (f'{data_root}/{split}' for split in ('train', 'valid', 'test'))
(train_dir, valid_dir, test_dir)

('../data/AI-CA-Data/train',
 '../data/AI-CA-Data/valid',
 '../data/AI-CA-Data/test')

## Data Generators (augmentation settings match the MobileNet model; preprocessing is EfficientNet-specific)

In [3]:
# Input resolution fed to EfficientNetB0 and the mini-batch size used for training.
img_size = (224, 224)
batch_size = 32
# Fixed seed (Python / NumPy / TF) plus a seeded training iterator, so shuffling,
# augmentation and head initialisation are more repeatable between runs.
# GPU kernels may still be non-deterministic.
seed = 42
tf.keras.utils.set_random_seed(seed)

# Training images: EfficientNet preprocessing plus light geometric augmentation.
train_datagen = ImageDataGenerator(
    preprocessing_function=preprocess_input,
    horizontal_flip=True,
    rotation_range=20,
    zoom_range=0.2
)

# Validation images: the same preprocessing, no augmentation.
valid_datagen = ImageDataGenerator(
    preprocessing_function=preprocess_input
)

train_data = train_datagen.flow_from_directory(
    train_dir,
    target_size=img_size,
    batch_size=batch_size,
    class_mode='categorical',
    seed=seed
)

# shuffle=False keeps predictions aligned with valid_data.classes / filenames,
# which per-image evaluation (e.g. a confusion matrix) relies on. Epoch-level
# validation loss/accuracy during fit do not depend on the order.
valid_data = valid_datagen.flow_from_directory(
    valid_dir,
    target_size=img_size,
    batch_size=batch_size,
    class_mode='categorical',
    shuffle=False
)

# Labels are inferred from sub-folder names, so both splits must map them
# identically; otherwise validation compares against the wrong classes.
assert valid_data.class_indices == train_data.class_indices, 'train/valid class folders differ'

print('class_indices (ordered):', train_data.class_indices)
# Order names by their integer label so class_names[i] names softmax output i.
class_names = sorted(train_data.class_indices, key=train_data.class_indices.get)
print('class_names (ordered):', class_names)


Found 7946 images belonging to 70 classes.
Found 700 images belonging to 70 classes.
class_indices (ordered): {'Afghan': 0, 'African Wild Dog': 1, 'Airedale': 2, 'American Hairless': 3, 'American Spaniel': 4, 'Basenji': 5, 'Basset': 6, 'Beagle': 7, 'Bearded Collie': 8, 'Bermaise': 9, 'Bichon Frise': 10, 'Blenheim': 11, 'Bloodhound': 12, 'Bluetick': 13, 'Border Collie': 14, 'Borzoi': 15, 'Boston Terrier': 16, 'Boxer': 17, 'Bull Mastiff': 18, 'Bull Terrier': 19, 'Bulldog': 20, 'Cairn': 21, 'Chihuahua': 22, 'Chinese Crested': 23, 'Chow': 24, 'Clumber': 25, 'Cockapoo': 26, 'Cocker': 27, 'Collie': 28, 'Corgi': 29, 'Coyote': 30, 'Dalmation': 31, 'Dhole': 32, 'Dingo': 33, 'Doberman': 34, 'Elk Hound': 35, 'French Bulldog': 36, 'German Sheperd': 37, 'Golden Retriever': 38, 'Great Dane': 39, 'Great Perenees': 40, 'Greyhound': 41, 'Groenendael': 42, 'Irish Spaniel': 43, 'Irish Wolfhound': 44, 'Japanese Spaniel': 45, 'Komondor': 46, 'Labradoodle': 47, 'Labrador': 48, 'Lhasa': 49, 'Malinois': 50, '

## Build Model

In [4]:
# One output unit per class folder found by the training generator.
num_classes = train_data.num_classes

# ImageNet-pretrained EfficientNetB0 backbone without its classifier head.
# input_shape is derived from img_size so the model always matches the generators.
base_model = EfficientNetB0(
    weights="imagenet",
    include_top=False,
    input_shape=img_size + (3,)
)

# Freeze the backbone: only the new head below is trained here.
base_model.trainable = False

# Classification head: pool spatial features -> dropout -> softmax over classes.
x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dropout(0.5)(x)
outputs = Dense(num_classes, activation="softmax")(x)

model = Model(inputs=base_model.input, outputs=outputs)

## Compile

In [5]:
# Multi-class setup: Adam with its default settings, cross-entropy against the
# one-hot labels produced by class_mode='categorical', and accuracy reporting.
model.compile(
    "adam",
    "categorical_crossentropy",
    metrics=["accuracy"],
)


## Callbacks

In [6]:
# Where the best model (lowest val_loss so far) is written.
checkpoint_path = "saved_models/efficientnet_best_model.h5"
# Create the target folder up front; some Keras versions do not create it when
# saving, which would make the first checkpoint fail on a fresh checkout.
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

callbacks = [
    # Stop after 3 epochs without val_loss improvement and roll back to the best epoch.
    EarlyStopping(
        monitor="val_loss",
        patience=3,
        restore_best_weights=True
    ),
    # Save the model whenever val_loss improves.
    # NOTE(review): recent Keras versions treat .h5 as a legacy format and warn;
    # consider ".keras" once downstream loaders are updated.
    ModelCheckpoint(
        checkpoint_path,
        monitor="val_loss",
        save_best_only=True
    )
]

## Train

In [7]:
# Train the classification head. EarlyStopping in `callbacks` may end training
# before the epoch limit; per-epoch metrics are kept in `history`.
history = model.fit(
    x=train_data,
    epochs=10,
    validation_data=valid_data,
    callbacks=callbacks,
)


Epoch 1/10
[1m249/249[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1s/step - accuracy: 0.5034 - loss: 2.4197



[1m249/249[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m306s[0m 1s/step - accuracy: 0.7175 - loss: 1.4329 - val_accuracy: 0.9314 - val_loss: 0.5072
Epoch 2/10
[1m249/249[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 949ms/step - accuracy: 0.8895 - loss: 0.4635



[1m249/249[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m259s[0m 1s/step - accuracy: 0.8913 - loss: 0.4342 - val_accuracy: 0.9343 - val_loss: 0.4210
Epoch 3/10
[1m249/249[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 884ms/step - accuracy: 0.9156 - loss: 0.3322



[1m249/249[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m236s[0m 947ms/step - accuracy: 0.9151 - loss: 0.3239 - val_accuracy: 0.9457 - val_loss: 0.3985
Epoch 4/10
[1m249/249[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m206s[0m 825ms/step - accuracy: 0.9206 - loss: 0.2819 - val_accuracy: 0.9429 - val_loss: 0.4127
Epoch 5/10
[1m249/249[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m210s[0m 841ms/step - accuracy: 0.9314 - loss: 0.2422 - val_accuracy: 0.9471 - val_loss: 0.4077
Epoch 6/10
[1m249/249[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m285s[0m 1s/step - accuracy: 0.9333 - loss: 0.2267 - val_accuracy: 0.9486 - val_loss: 0.4012
