In [1]:
# !pip install tensorflow==2.12
# !pip install tensorflow_io==0.23.1

In [2]:
import os
from dotenv import load_dotenv
from pymongo import MongoClient
import tensorflow as tf
from PIL import Image
import numpy as np
from MongoDBDataset import create_dataset
import datetime

load_dotenv()

True

In [3]:
client = MongoClient(os.environ["URI"])
db = client.flowers
N_CLASSES = len(db.test.distinct("category"))

In [4]:
train_dataset = create_dataset(db.train, batch_size=3, mode="train")
eval_dataset = create_dataset(db.test, batch_size=1, mode="test")

In [5]:
import keras.layers as layers


def create_model():
    input = layers.Input(name="flower_image", shape=(256, 256, 3), dtype=tf.float32)
    model_design = [
        layers.Rescaling(1 / 127.5, -1),
        # layers.Lambda(tf.image.resize_with_crop_or_pad, arguments=(256, 256)),
        layers.Conv2D(
            filters=64,
            kernel_size=(5, 5),
            strides=(2, 2),
            activation="relu",
            name="layers_conv2d_1",
        ),
        layers.MaxPooling2D(),  # 256 -> 128
        layers.Conv2D(
            filters=64,
            kernel_size=(3, 3),
            strides=(1, 1),
            activation="relu",
            name="layers_conv2d_2",
        ),
        layers.Conv2D(
            filters=64,
            kernel_size=(3, 3),
            strides=(1, 1),
            activation="relu",
            name="layers_conv2d_3",
        ),
        layers.Conv2D(
            filters=64,
            kernel_size=(3, 3),
            strides=(1, 1),
            activation="relu",
            name="layers_conv2d_4",
        ),
        layers.MaxPooling2D(),  # 128 -> 64
        layers.Conv2D(
            filters=64,
            kernel_size=(3, 3),
            strides=(1, 1),
            activation="relu",
            name="layers_conv2d_5",
        ),
        layers.MaxPooling2D(),  # 64 -> 32
        layers.Flatten(name="layers_flatten"),  # 32*32*64 = 65536
        layers.Dropout(0.2, name="layers_dropout"),
        layers.Dense(N_CLASSES, activation="softmax", name="layers_classifier"),
    ]
    output = input
    for layer in model_design:
        output = layer(output)

    return tf.keras.models.Model(inputs=input, outputs=output, name="FlowerFinder")

In [6]:
model = create_model()
model.summary()
# tf.keras.utils.plot_model(model, "my_first_model.png")

Model: "FlowerFinder"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 flower_image (InputLayer)   [(None, 256, 256, 3)]     0         
                                                                 
 rescaling (Rescaling)       (None, 256, 256, 3)       0         
                                                                 
 layers_conv2d_1 (Conv2D)    (None, 126, 126, 64)      4864      
                                                                 
 max_pooling2d (MaxPooling2  (None, 63, 63, 64)        0         
 D)                                                              
                                                                 
 layers_conv2d_2 (Conv2D)    (None, 61, 61, 64)        36928     
                                                                 
 layers_conv2d_3 (Conv2D)    (None, 59, 59, 64)        36928     
                                                      

In [7]:
model.compile(
    loss=tf.keras.losses.CategoricalCrossentropy(),
    optimizer=tf.keras.optimizers.Adam(),
    metrics=["accuracy", tf.keras.metrics.CategoricalAccuracy(), tf.keras.metrics.Precision(), tf.keras.metrics.Recall()],
)
log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir=log_dir, histogram_freq=1, update_freq="batch"
)

model.fit(train_dataset, epochs=10, callbacks=[tensorboard_callback])

Epoch 1/10
    584/Unknown - 322s 502ms/step - loss: 1.6986 - accuracy: 0.3884 - categorical_accuracy: 0.3884 - precision: 0.5422 - recall: 0.1284

In [None]:
model.evaluate(eval_dataset, callbacks=[tensorboard_callback])



[0.12605267763137817, 0.6836734414100647]

In [None]:
client.close()