### Train the model

In [None]:
import tensorflow as tf

# One output logit per fruit class; must equal the number of class folders
# found by image_dataset_from_directory in the data-loading cell.
NUM_CLASSES = 3
# Keep in sync with img_height / img_width in the data-loading cell.
IMG_SHAPE = (32, 32, 3)

model = tf.keras.Sequential(
    [
        # Declaring the input up front builds the model immediately, so
        # model.summary() works before training.
        tf.keras.Input(shape=IMG_SHAPE),
        # Images arrive as 0-255 floats; scaling lives inside the model so the
        # inference cells (Keras and TFLite) can feed raw pixel values.
        tf.keras.layers.Rescaling(1.0 / 255),
        tf.keras.layers.Conv2D(32, 3, activation="relu"),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Conv2D(32, 3, activation="relu"),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Conv2D(32, 3, activation="relu"),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(128, activation="relu"),
        # No softmax: this layer emits logits, matched by from_logits=True below.
        tf.keras.layers.Dense(NUM_CLASSES),
    ]
)
model.compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)

In [None]:
import matplotlib.pyplot as plt

In [None]:
img_height, img_width = 32, 32
batch_size = 20
DATA_DIR = "fruits"
SEED = 123  # fixes the shuffle order so re-runs see the same batches


def load_split(split: str) -> tf.data.Dataset:
    """Load one split ("train", "validation" or "test") from ``DATA_DIR/<split>``.

    Images are resized to ``(img_height, img_width)`` and batched by
    ``batch_size``; labels are inferred from the split's sub-directory names.
    """
    return tf.keras.utils.image_dataset_from_directory(
        f"{DATA_DIR}/{split}",
        image_size=(img_height, img_width),
        batch_size=batch_size,
        seed=SEED,
    )


train_ds = load_split("train")
val_ds = load_split("validation")
test_ds = load_split("test")

# Label ids are assigned per split from its folder names; a missing or extra
# folder in one split would silently shift every label id in that split.
assert train_ds.class_names == val_ds.class_names == test_ds.class_names, (
    "train/validation/test class folders differ"
)

In [None]:
# Take the label order from the dataset itself rather than a hand-written list:
# image_dataset_from_directory maps label ids to its sorted class-folder names,
# so a hard-coded list mislabels images if a folder is renamed or added.
class_names = train_ds.class_names

plt.figure(figsize=(10, 10))
for images, labels in train_ds.take(1):
    # The first batch can hold fewer than 9 images on a tiny dataset.
    for i in range(min(9, len(images))):
        plt.subplot(3, 3, i + 1)
        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(class_names[int(labels[i])])
        plt.axis("off")

In [None]:
EPOCHS = 100
EARLY_STOP_PATIENCE = 10

# Stop when validation loss has not improved for EARLY_STOP_PATIENCE epochs and
# roll back to the best epoch's weights. The model evaluated and saved below is
# then the best one on val_ds, not whatever the final epoch produced.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss",
    patience=EARLY_STOP_PATIENCE,
    restore_best_weights=True,
)

# Keep the History object so the learning curves can be inspected afterwards.
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=EPOCHS,
    callbacks=[early_stopping],
)

In [None]:
# Score the trained network on the held-out test split; Keras returns the loss
# followed by the compiled metrics, i.e. [loss, accuracy].
test_metrics = model.evaluate(test_ds)
test_metrics

In [None]:
# Persist the trained model under the path 'fruit'; the single-image inference
# cell reloads it with load_model('fruit').
# NOTE(review): an extension-less path selects the TF SavedModel format in
# Keras 2; Keras 3 (TF >= 2.16) appears to require a '.keras' suffix. If
# upgrading, change this path and the load_model() call together.
model.save(filepath="fruit")

In [None]:
import numpy

plt.figure(figsize=(10, 10))
for images, labels in test_ds.take(1):
    # Raw logits: the final Dense layer has no activation (from_logits=True).
    classifications = model(images, training=False)
    print(classifications)

    # The first batch can hold fewer than 9 images on a tiny test set.
    for i in range(min(9, len(images))):
        plt.subplot(3, 3, i + 1)
        plt.imshow(images[i].numpy().astype("uint8"))
        # argmax over logits gives the same class as argmax over probabilities.
        index = numpy.argmax(classifications[i])
        plt.title("Pred: " + class_names[index] + " | Real: " + class_names[int(labels[i])])
        plt.axis("off")

In [None]:
import tensorflow as tf
from tensorflow.keras.models import load_model

# Reload the model written by the model.save('fruit') cell, so this cell also
# works on its own after a kernel restart.
model = load_model('fruit')

# Step 1: Read the image file into a tensor
image_path = 'apple.jpg'  # Replace this with the path to your image file
image_data = tf.io.read_file(image_path)

# Step 2: Decode the image bytes into a uint8 (height, width, 3) tensor
image_tensor = tf.image.decode_jpeg(image_data, channels=3)

# Step 3: Resize to the 32x32 training resolution (img_height / img_width).
# Do not divide by 255 here: the model itself contains a Rescaling(1./255) layer.
image_tensor_resized = tf.image.resize(image_tensor, [32, 32])

# Step 4: Add a batch dimension -> (1, 32, 32, 3)
image_tensor_batch = tf.expand_dims(image_tensor_resized, axis=0)

# The final Dense layer has no activation, so `output` holds logits. Softmax
# turns them into class probabilities; the argmax index follows the order of
# train_ds.class_names.
output = model(image_tensor_batch, training=False)
probabilities = tf.nn.softmax(output, axis=-1)[0].numpy()
predicted_index = int(tf.argmax(output[0]))
print("Predicted class index:", predicted_index, "| probabilities:", probabilities.round(3))
output

In [None]:
# Convert whatever Keras model is currently bound to `model` into a TFLite
# flatbuffer. convert() returns raw bytes, so the file is opened in binary mode.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()

with open("fruits.tflite", "wb") as tflite_file:
    tflite_file.write(tflite_model)

### TFLite inference

In [None]:
import numpy as np
import tensorflow as tf

# Load the TFLite model
interpreter = tf.lite.Interpreter(model_path='fruits.tflite')
interpreter.allocate_tensors()

# Get input and output tensors
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# The interpreter expects a (1, height, width, 3) batch. Read the spatial size
# from the model instead of hard-coding 32x32, so this cell follows the
# exported model.
input_shape = input_details[0]['shape']
input_height, input_width = int(input_shape[1]), int(input_shape[2])

# Step 1: Read the image file into a tensor
image_path = 'orange.jpg'  # Replace this with the path to your image file
image_data = tf.io.read_file(image_path)

# Step 2: Decode the image data into a uint8 (height, width, 3) tensor
image_tensor = tf.image.decode_jpeg(image_data, channels=3)

# Step 3: Resize to the model's input resolution
image_tensor_resized = tf.image.resize(image_tensor, [input_height, input_width])

# Step 4: Do NOT normalize. The model trained in this notebook starts with a
# Rescaling(1./255) layer, so it expects raw 0-255 pixel values; dividing by
# 255 here would scale the input twice.
# NOTE(review): float32 matches the default (non-quantized) converter output;
# a quantized model would need input_details[0]['quantization'] applied.
input_data = tf.expand_dims(tf.cast(image_tensor_resized, tf.float32), axis=0).numpy()

# Set the input tensor values
interpreter.set_tensor(input_details[0]['index'], input_data)

# Perform inference
interpreter.invoke()

# Get the output tensor: one row of logits (the final Dense layer has no
# activation), so argmax picks the predicted class. The index follows the
# order of train_ds.class_names.
output_data = interpreter.get_tensor(output_details[0]['index'])
print(output_data)
predicted_index = int(np.argmax(output_data[0]))
print("Predicted class index:", predicted_index)
