In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout,Flatten
from tensorflow.keras.applications import MobileNet

import onnxruntime as rt
import onnx
import tf2onnx

import matplotlib.pyplot as plt
import numpy as np
import cv2

In [None]:
# Name of the dataset handed to tfds.load.
DATASET_NAME = 'cats_vs_dogs'

# Input-pipeline settings: shuffle buffer size and batch size.
BUFFER_SIZE, BATCH_SIZE = 1000, 64

# Spatial size every image is resized to before it reaches the network.
IMG_WIDTH = IMG_HEIGHT = 224

# Display names indexed by predicted class id.
# NOTE(review): presumably 0 = cat and 1 = dog -- confirm against the printed metadata.
LABEL = ["Cat", "Dog"]

In [None]:
def load_and_preprocess_data(dataset_name):
    """Load a TFDS dataset and build batched train/validation/test pipelines.

    Three slices of the dataset's ``train`` split are used: the first 50% for
    training, 80%-90% for validation and the final 10% for testing. Every
    image is converted to float32 and resized to IMG_HEIGHT x IMG_WIDTH.

    Args:
        dataset_name: Dataset name forwarded to ``tfds.load``.

    Returns:
        A ``(train, validation, test)`` tuple of batched ``tf.data.Dataset``
        objects yielding ``(image, label)`` pairs. Only the training pipeline
        is cached and shuffled.
    """
    (raw_train, raw_validation, raw_test), metadata = tfds.load(
        dataset_name,
        split=['train[:50%]', 'train[80%:90%]', 'train[90%:]'],
        with_info=True,
        as_supervised=True,
    )
    print(metadata)

    def preprocess_image(image, label):
        """Convert one image to float32 and resize it; the label passes through."""
        # For integer inputs convert_image_dtype also rescales values to [0, 1].
        image = tf.image.convert_image_dtype(image, tf.float32)
        # tf.image.resize expects the size as [new_height, new_width].
        image = tf.image.resize(image, [IMG_HEIGHT, IMG_WIDTH])
        return image, label

    autotune = tf.data.AUTOTUNE

    def prepare(raw_dataset):
        """Apply the per-image preprocessing, decoding images in parallel."""
        return raw_dataset.map(preprocess_image, num_parallel_calls=autotune)

    # Cache before shuffling so the preprocessed images are reused across
    # epochs while the order still changes every epoch.
    train = (
        prepare(raw_train)
        .cache()
        .shuffle(BUFFER_SIZE)
        .batch(BATCH_SIZE)
        .prefetch(autotune)
    )
    # Prefetching overlaps input preparation with model execution.
    validation = prepare(raw_validation).batch(BATCH_SIZE).prefetch(autotune)
    test = prepare(raw_test).batch(BATCH_SIZE).prefetch(autotune)

    return train, validation, test


# Build the three input pipelines once; the later cells reuse them.
datasets = load_and_preprocess_data(DATASET_NAME)
train_dataset, validation_dataset, test_dataset = datasets

In [None]:
# ImageNet-pretrained MobileNet without its classification head, used as the
# feature extractor. Keras expects input_shape as (height, width, channels),
# so the height constant goes first.
mobilenat = MobileNet(
    input_shape=(IMG_HEIGHT, IMG_WIDTH, 3),
    include_top=False,
    weights='imagenet',
)

In [None]:
# Classifier: MobileNet feature maps -> flatten -> small dense head.
# NOTE(review): the backbone is not frozen here, so its weights are presumably
# fine-tuned together with the head -- confirm this is intended.
model = Sequential([
    mobilenat,
    Flatten(),

    Dense(100, activation='relu'),
    Dropout(0.2),

    # Softmax rather than sigmoid: the model is compiled with
    # SparseCategoricalCrossentropy(from_logits=False), which expects the two
    # outputs to form a probability distribution over the classes.
    Dense(2, activation="softmax"),

])

In [None]:
# Training setup: Adam with a small learning rate, and a sparse cross-entropy
# loss that reads integer class labels and takes the model outputs as
# probabilities (from_logits=False).
optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
model.compile(optimizer=optimizer, loss=loss_fn, metrics=['accuracy'])

In [None]:
# Train for five epochs, evaluating on the validation pipeline after each one.
fit_options = dict(epochs=5, validation_data=validation_dataset)
history = model.fit(train_dataset, **fit_options)

In [None]:
# Score the trained model on the held-out test pipeline and report both
# values returned by evaluate (the loss and the compiled accuracy metric).
loss, accuracy = model.evaluate(test_dataset)
for caption, value in (("Test Loss:", loss), ("Test Accuracy:", accuracy)):
    print(caption, value)

In [None]:
# Persist the trained Keras model; the path is relative to the working directory.
keras_model_path = 'drive/MyDrive/cats_vs_dogs.h5'
model.save(keras_model_path)

In [None]:
# Convert the trained Keras model to ONNX. The converter returns a pair; only
# its first item (the ONNX model) is used by the following cells.
conversion_result = tf2onnx.convert.from_keras(model)
onnx_model, _ = conversion_result

In [None]:
# Write the converted model to disk so it can be loaded by ONNX Runtime.
onnx_model_path = './drive/MyDrive/cats_vs_dogs_model.onnx'
onnx.save(onnx_model, onnx_model_path)

In [None]:
# Load the exported model into an ONNX Runtime inference session.
# The execution provider is named explicitly: since ONNX Runtime 1.9, builds
# that ship more than the CPU provider raise an error when `providers` is
# omitted. The CPU provider is sufficient for the single-image check below.
session = rt.InferenceSession(
    './drive/MyDrive/cats_vs_dogs_model.onnx',
    providers=['CPUExecutionProvider'],
)

In [None]:
# Look up the names of the graph's first input and first output; session.run
# needs them to feed the image and fetch the scores.
first_input, first_output = session.get_inputs()[0], session.get_outputs()[0]
input_name = first_input.name
print(input_name)
output_name = first_output.name
print(output_name)

In [None]:
# Pull a single batch from the test pipeline to use as sample input.
first_element = test_dataset.take(1)
element = next(iter(first_element))
image, label = element
print(image.shape, label.shape)

In [None]:
# First image of the sampled batch, converted to a NumPy array and displayed.
first_image = image[0]
test_img = first_image.numpy()
test_img

In [None]:
# Label belonging to that first image, converted to NumPy and displayed.
first_label = label[0]
test_label = first_label.numpy()
test_label

In [None]:
# Show the sample image with its ground-truth label id in the title.
fig, ax = plt.subplots()
ax.imshow(test_img)
ax.set_title(f"Label : {test_label}")
plt.show()

In [None]:
def onnxPred(input : np.ndarray):
    """Classify one image with the ONNX Runtime session and return its class index.

    Args:
        input: A single image as a NumPy array. It is resized to
            IMG_WIDTH x IMG_HEIGHT and cast to float32; no other scaling is
            applied here.
            NOTE(review): assumes pixel values are already scaled the same way
            as in the training pipeline -- confirm for images from other sources.

    Returns:
        The index (plain ``int``) of the highest-scoring class, usable as an
        index into ``LABEL``.
    """
    # cv2.resize takes the target size as (width, height).
    resized = cv2.resize(input, (IMG_WIDTH, IMG_HEIGHT)).astype(np.float32)
    # The model expects a leading batch dimension.
    batch = np.expand_dims(resized, axis=0)
    # session.run returns a list with one array per requested output; take the
    # single requested output, then the scores of the only image in the batch.
    scores = session.run([output_name], {input_name : batch})[0]
    return int(np.argmax(scores[0]))

In [None]:
# Run the ONNX model on the sample image and print the readable class name.
predicted_class = onnxPred(test_img)
print(LABEL[predicted_class])