In [None]:
import tensorflow as tf
from tensorflow.keras.applications import VGG16
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.layers import Dense, Input, GlobalAveragePooling2D
from tensorflow.keras.models import Model

In [None]:
from tensorflow.keras.callbacks import EarlyStopping


def create_model(img_width, img_height, num_classes):
    """Build a classifier head on top of a frozen, ImageNet-pretrained VGG16.

    The VGG16 convolutional stack (without its dense top) is applied to a
    3-channel input, the resulting feature maps are global-average pooled,
    and a softmax ``Dense`` layer emits one probability per class.

    Args:
        img_width: First spatial dimension of the model's input shape.
        img_height: Second spatial dimension of the model's input shape.
        num_classes: Number of units in the softmax output layer.

    Returns:
        An uncompiled ``Model`` mapping an image batch to class scores.

    NOTE(review): Keras interprets an image input shape as
    (height, width, channels); the (width, height) order used here is only
    indistinguishable for square inputs -- confirm before going non-square.
    """
    backbone = VGG16(include_top=False, weights="imagenet")
    image_input = Input(shape=(img_width, img_height, 3))

    # Keep every pre-trained layer fixed so only the new head is trained.
    for pretrained_layer in backbone.layers:
        pretrained_layer.trainable = False

    # Backbone features -> one averaged value per channel.
    feature_maps = backbone(image_input)
    pooled = GlobalAveragePooling2D()(feature_maps)

    # Classification head.
    class_scores = Dense(units=num_classes, activation="softmax")(pooled)

    return Model(inputs=image_input, outputs=class_scores)


# Spatial size shared by the model input and the data generators.
img_width = 224
img_height = 224

# Build the classifier with 14 output classes.
# NOTE(review): 14 presumably matches the number of class sub-directories
# read by the generators -- confirm against the dataset.
model = create_model(img_width, img_height, 14)

# Categorical cross-entropy pairs with the softmax output and the one-hot
# labels produced by class_mode="categorical".
model.compile(
    optimizer="adam",
    loss="categorical_crossentropy",
    metrics=["accuracy"],
)

train_data_path = "train"
validation_data_path = "valid"

# Training pipeline: scale pixel values by 1/255 and apply random
# augmentation (rotation, shifts, horizontal flip, shear, zoom).
# NOTE(review): the ImageNet VGG16 weights were trained with
# `vgg16.preprocess_input` rather than a plain 1/255 rescale -- consider
# switching both generators to `preprocessing_function=preprocess_input`.
train_datagen = ImageDataGenerator(
    rescale=1.0 / 255,
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    horizontal_flip=True,
    shear_range=0.2,
    zoom_range=0.2,
)

# Images are resized to the model's input size as they are loaded.
train_generator = train_datagen.flow_from_directory(
    train_data_path,
    target_size=(img_width, img_height),
    batch_size=32,
    class_mode="categorical",
)

# Validation pipeline: same rescaling, but no augmentation.
validation_datagen = ImageDataGenerator(rescale=1.0 / 255)

# Use the same size variables as the training generator (instead of a
# hard-coded 224x224) so both always match the model's input shape.
validation_generator = validation_datagen.flow_from_directory(
    validation_data_path,
    target_size=(img_width, img_height),
    batch_size=32,
    class_mode="categorical",
)

# Stop training once validation loss has gone 5 epochs without improving.
early_stopping = EarlyStopping(monitor="val_loss", patience=5)

# Train the model with early stopping.
# len(generator) is already the number of batches per epoch, so it is used
# directly; dividing it by batch_size again would run only a small fraction
# of the data each epoch (or zero steps on a small dataset).
model.fit(
    train_generator,
    steps_per_epoch=len(train_generator),
    epochs=20,
    validation_data=validation_generator,
    validation_steps=len(validation_generator),
    callbacks=[early_stopping],
)