## Load required packages

In [None]:
from typing import Tuple

from tflite_model_maker.image_classifier import DataLoader
from tflite_model_maker import image_classifier

import tensorflow as tf
import numpy as np

import matplotlib.pyplot as plt
%matplotlib inline

## Create data loader

In [None]:
# Build the dataset from the image folder (presumably one sub-folder per class,
# e.g. 'Tomato_Yellow' as used in the prediction example at the end).
data = DataLoader.from_folder('Fruits_Dataset/train')
# First split at 0.8, then split the remainder in half. Assuming `split(fraction)`
# gives the first part that fraction of the data, this yields roughly
# 80% train / 10% validation / 10% test -- confirm against the Model Maker docs.
train_data, rest_data = data.split(0.8)
validation_data, test_data = rest_data.split(0.5)

## View some samples

In [None]:
# Preview the first five images of the full dataset, each captioned with its
# class name looked up through `data.index_to_label`.
plt.figure(figsize=(15, 10))
preview_samples = data.gen_dataset().unbatch().take(5)
for position, (sample_image, sample_label) in enumerate(preview_samples, start=1):
  # Slots are numbered from 1 in a 5x5 grid; only the first row gets filled.
  axes = plt.subplot(5, 5, position)
  # Hide ticks and grid so only the picture and its caption remain.
  axes.set_xticks([])
  axes.set_yticks([])
  axes.grid(False)
  axes.imshow(sample_image.numpy(), cmap=plt.cm.gray)
  class_name = data.index_to_label[sample_label.numpy()]
  axes.set_xlabel(class_name, color='white')
plt.show()

## Train a model

In [None]:
# Train an image classifier on the training split; the validation split is
# passed along for use during training. Only a single epoch is run.
model = image_classifier.create(
    train_data,
    model_spec='efficientnet_lite0',  # model architecture, selected by name
    epochs=1,
    validation_data=validation_data
)

## Test the model

In [None]:
# Evaluate the trained model on the test split; the result is unpacked as
# (loss, accuracy).
loss, accuracy = model.evaluate(test_data)
# Top-2 predictions for every test sample. The plotting cell below reads
# predicts[i][0][0] as the best predicted label of sample i.
predicts = model.predict_top_k(test_data, k=2)

In [None]:
def get_label_color(val1, val2):
  """Return 'white' when the two values compare equal, otherwise 'red'."""
  return 'white' if val1 == val2 else 'red'

# Show the first ten test images in a 2x5 grid. Each caption is the model's
# best guess, coloured by `get_label_color`: white when it matches the true
# label, red otherwise.
plt.figure(figsize=(10, 10))

test_samples = test_data.gen_dataset().unbatch().take(10)
for idx, (sample_image, sample_label) in enumerate(test_samples):
  ax = plt.subplot(2, 5, idx + 1)
  ax.set_xticks([])
  ax.set_yticks([])
  ax.grid(False)
  ax.imshow(sample_image.numpy(), cmap=plt.cm.gray)
  # predicts[idx] holds the top-k entries for this sample; [0][0] is the
  # label of the highest-ranked one.
  predict_label = predicts[idx][0][0]
  true_label = test_data.index_to_label[sample_label.numpy()]
  ax.xaxis.label.set_color(get_label_color(predict_label, true_label))
  ax.set_xlabel(predict_label)
plt.show()

## Export the model

In [None]:
# Export the trained model to ./models/fruit_classifier.tflite; this is the
# file that MODEL_PATH in the single-image prediction cell points to.
model.export(
    export_dir='./models',
    tflite_filename='fruit_classifier.tflite'
)

## Predict on a single image

In [None]:
# Location of the exported TFLite model (export_dir + tflite_filename used in
# the export cell above).
MODEL_PATH = './models/fruit_classifier.tflite'

def get_interpreter(model_path: str) -> Tuple:
    """Load a TFLite model and prepare it for inference.

    Args:
        model_path: Path to the ``.tflite`` model file.

    Returns:
        A tuple ``(interpreter, input_details, output_details)``: the
        interpreter with its tensors already allocated, followed by the
        results of its ``get_input_details()`` and ``get_output_details()``.
    """
    tflite_interpreter = tf.lite.Interpreter(model_path=model_path)
    # Tensors must be allocated before any set_tensor()/invoke() call.
    tflite_interpreter.allocate_tensors()
    return (
        tflite_interpreter,
        tflite_interpreter.get_input_details(),
        tflite_interpreter.get_output_details(),
    )

def predict(image_path: str) -> int:
    """Classify a single image file with the exported TFLite model.

    The image is decoded as 3-channel RGB, resized to the model's input
    height and width, cast to ``uint8`` and run through the interpreter
    loaded from ``MODEL_PATH``.

    Args:
        image_path: Path to the image file to classify.

    Returns:
        The index of the highest-scoring class in the model output, as a
        plain Python ``int``.
    """
    interpreter, input_details, output_details = get_interpreter(MODEL_PATH)
    input_info = input_details[0]
    input_shape = input_info['shape']
    # Assumes an NHWC input layout (batch, height, width, channels), so the
    # target size is (shape[1], shape[2]). Using shape[2] for both sides is
    # only correct for square inputs.
    target_size = (input_shape[1], input_shape[2])

    img = tf.io.read_file(image_path)
    # expand_animations=False keeps the result 3-D (first frame only) even for
    # animated GIFs, which would otherwise decode to a 4-D tensor and break the
    # batch-dimension handling below.
    img = tf.io.decode_image(img, channels=3, expand_animations=False)
    img = tf.image.resize(img, target_size)
    img = tf.expand_dims(img, axis=0)  # add the batch dimension
    # NOTE: the model input is expected to be uint8; a float-input model would
    # additionally need its own normalization, which is not applied here.
    resized_img = tf.cast(img, dtype=tf.uint8)

    interpreter.set_tensor(input_info['index'], resized_img)
    interpreter.invoke()

    output_data = interpreter.get_tensor(output_details[0]['index'])
    results = np.squeeze(output_data)
    # Convert the NumPy integer to a built-in int to honour the annotation.
    return int(np.argmax(results, axis=0))

In [None]:
# Quick sanity check on one image from the 'Tomato_Yellow' folder; the cell
# output is the predicted class index returned by predict().
predict('Fruits_Dataset/train/Tomato_Yellow/18_100_jpg.rf.b526089a2adc3ca43b2b8b6ae9ea7301.jpg')