In [None]:
import os

import tensorflow as tf

import matplotlib.pyplot as plt
%matplotlib inline

##### Load the data

In [None]:
# Training pipeline: scale pixel values into [0, 1] and apply random geometric
# augmentation (rotation, shifts, zoom, flips) to every image as batches are drawn.
# NOTE(review): vertical_flip only makes sense if the classes are orientation-
# invariant — confirm for this dataset.
train_datagen = tf.keras.preprocessing.image.ImageDataGenerator(
    rescale=1 / 255,
    rotation_range=40,
    zoom_range=0.2,
    width_shift_range=0.2,
    height_shift_range=0.2,
    horizontal_flip=True,
    vertical_flip=True,
)

# flow_from_directory expects one subdirectory per class under dataset/train;
# with class_mode='categorical' the labels come back one-hot encoded.
train_generator = train_datagen.flow_from_directory(
    os.path.join('dataset', 'train'),
    class_mode='categorical',
    color_mode='rgb',
    target_size=(256, 256),
)

In [None]:
# Evaluation pipelines: pixels are only rescaled (no augmentation), so validation
# and test images are fed to the model as-is.
test_datagen = tf.keras.preprocessing.image.ImageDataGenerator(
    rescale=1./255
)


def _flow_split(split, class_mode, shuffle):
    """Stream images from dataset/<split> through `test_datagen`.

    Args:
        split: subdirectory of `dataset` to read (e.g. 'valid' or 'test').
        class_mode: 'categorical' for one-hot labels, or None to yield images only.
        shuffle: whether flow_from_directory shuffles the sample order.

    Returns:
        The iterator returned by `flow_from_directory`.
    """
    return test_datagen.flow_from_directory(
        os.path.join('dataset', split),
        target_size=(256, 256),
        color_mode='rgb',
        class_mode=class_mode,
        shuffle=shuffle,
    )


valid_generator = _flow_split('valid', class_mode='categorical', shuffle=True)

# The test split is read without labels (class_mode=None) and must keep a fixed
# order so that row i of `model.predict(test_generator)` corresponds to
# `test_generator.filenames[i]`; flow_from_directory shuffles by default, which
# would silently break that correspondence.
test_generator = _flow_split('test', class_mode=None, shuffle=False)

In [None]:
# Preview up to six validation images with their ground-truth labels.
# Fetch the batch once: every `valid_generator[0]` access decodes and preprocesses
# a whole batch of images from disk.
preview_images, preview_labels = valid_generator[0]
index_to_class = {index: name for name, index in valid_generator.class_indices.items()}

fig, axes = plt.subplots(2, 3, figsize=(20, 9))
for ax, preview_image, one_hot in zip(axes.flat, preview_images, preview_labels):
    # Labels are one-hot vectors because class_mode='categorical'.
    label_index = int(one_hot.argmax())
    ax.imshow(preview_image)
    ax.set_title(f'Label - {label_index} ({index_to_class[label_index]})')
    ax.grid(visible=False)
# Blank out any panels left empty when the batch holds fewer than six images.
for ax in axes.flat[len(preview_images):]:
    ax.axis('off')
fig.tight_layout()
plt.show()

##### Build a model

In [None]:
# A deliberately tiny CNN: one convolution, global average pooling, and a softmax
# classifier with one output unit per class discovered by `train_generator`.
# NOTE(review): the Conv2D has no activation, so conv + average pooling is an affine
# map of the input; consider activation='relu' if it is meant to learn features.
model = tf.keras.models.Sequential(
    [
        # Matches the generators' target_size=(256, 256) and color_mode='rgb'.
        tf.keras.Input(shape=(256, 256, 3)),
        tf.keras.layers.Conv2D(filters=2, kernel_size=3),
        # Already outputs a (batch, channels) tensor, so no Flatten layer is needed.
        tf.keras.layers.GlobalAveragePooling2D(),
        # Take the class count from the data so the output width always matches
        # the one-hot labels, rather than hard-coding it.
        tf.keras.layers.Dense(train_generator.num_classes, activation='softmax'),
    ]
)

model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
    # Targets are one-hot vectors (class_mode='categorical').
    loss=tf.keras.losses.CategoricalCrossentropy(),
    metrics=[
        tf.keras.metrics.CategoricalAccuracy()
    ]
)

model.summary()

##### Train the model

In [None]:
# Train on the augmented stream; Keras scores `valid_generator` at the end of each
# epoch, which produces the val_* entries in `history.history`.
EPOCHS = 2
history = model.fit(train_generator, validation_data=valid_generator, epochs=EPOCHS)

##### See the training history

In [None]:
# Per-epoch metric values recorded by `model.fit`, shown via the notebook's rich display.
history.history

In [None]:
def _plot_history_metric(metrics, metric, title, ylabel):
    """Plot the training and validation curves of one metric recorded by `model.fit`.

    Args:
        metrics: the `history.history` dict (metric name -> per-epoch values).
        metric: training metric key; the validation curve is read from 'val_' + metric.
        title: figure title.
        ylabel: y-axis label.
    """
    plt.plot(metrics[metric])
    plt.plot(metrics[f'val_{metric}'])
    plt.title(title)
    plt.ylabel(ylabel)
    plt.xlabel('epoch')
    # The val_* curve comes from `validation_data=valid_generator`, not the test split.
    plt.legend(['train', 'validation'], loc='upper left')
    plt.show()


_plot_history_metric(history.history, 'categorical_accuracy', 'Model accuracy', 'accuracy')
_plot_history_metric(history.history, 'loss', 'Model loss', 'loss')

##### Predict using an image generator

In [None]:
# Run the model over every test image; each row holds one image's softmax class
# probabilities.
predictions = model.predict(test_generator)
# Preview instead of dumping the whole matrix: its shape, plus the most likely
# class index for the first ten images.
predictions.shape, predictions.argmax(axis=-1)[:10]

##### Save the model

In [None]:
# `Model.save` requires a destination path; the `.keras` extension selects the
# native Keras format. Reload it with `tf.keras.models.load_model('model.keras')`.
# NOTE(review): `.keras` needs a recent TensorFlow (2.12+, confirm); use 'model.h5'
# on older versions.
model.save('model.keras')

##### Load the model

In [None]:
# Reload the saved model (architecture and weights) from disk. The path must match
# the one passed to `model.save`.
new_model = tf.keras.models.load_model('model.keras')

In [None]:
# Print the reloaded model's layer summary; compare it with `model.summary()` above
# to confirm the architecture survived the save/load round trip.
new_model.summary()

##### Predict on a single image

In [None]:
# Classify a single image from disk, preprocessing it the same way the generators do:
# resize to 256x256 RGB and rescale pixel values into [0, 1].
image_path = os.path.join('dataset', 'test', 'test', '3.jpg')

# `target_size` genuinely resamples the image. `ndarray.resize` would only reshape
# the raw pixel buffer in place (truncating or zero-padding it), scrambling the image.
image = tf.keras.preprocessing.image.load_img(
    image_path, target_size=(256, 256), color_mode='rgb'
)
# Match the generators' rescale=1./255 so inputs are on the scale the model was trained on.
image = tf.keras.preprocessing.image.img_to_array(image) / 255.0
# Add a leading batch dimension: (256, 256, 3) -> (1, 256, 256, 3).
image = tf.expand_dims(image, axis=0)

prediction = new_model.predict(image)

prediction