<a href="https://colab.research.google.com/github/EdduardaLara/embedded-img-classification/blob/main/cnn_plant_diseases.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from tensorflow.keras import backend as K
K.clear_session()

# Install some packages

In [2]:
pip install tensorflow-model-optimization

Collecting tensorflow-model-optimization
  Downloading tensorflow_model_optimization-0.8.0-py2.py3-none-any.whl.metadata (904 bytes)
Downloading tensorflow_model_optimization-0.8.0-py2.py3-none-any.whl (242 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m242.5/242.5 kB[0m [31m4.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: tensorflow-model-optimization
Successfully installed tensorflow-model-optimization-0.8.0


# Mount drive

In [3]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


# Import packages

In [4]:
import numpy as np
import os
import matplotlib.pyplot as plt
from pathlib import Path

import tensorflow as tf
from tensorflow_model_optimization.python.core.keras.compat import keras

import shutil

In [5]:
# shutil.rmtree('/content/data/plantvillage')

In [6]:
# shutil.rmtree('/content/results/plant-village')

Explore some runtime ressources

In [7]:
print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))
print("List of GPUs Available: ", tf.config.list_physical_devices('GPU'))

Num GPUs Available:  0
List of GPUs Available:  []


# Global variables setting

In [8]:
PLANT_CULTURE = "tomato"

In [9]:
DOWNLOAD_DATA = True

In [10]:
DELETE_DATASET_FOLDER = True

In [11]:
DELETE_SOME_SUBFOLDERS = True

Paths

In [12]:
dataset_name = 'plant-village'

In [13]:
project_dir = os.getcwd()
data_dir: str = project_dir + '/' + 'data/' + dataset_name
print(f"Project dir: {project_dir} | Data dir: {data_dir}")

Project dir: /content | Data dir: /content/data/plant-village


In [14]:
results_dir = '/content/drive/MyDrive/IC_VANT/PlantVillage/results'
if not os.path.exists(results_dir):
    os.makedirs(results_dir)
print(f"Results dir: {results_dir}")

Results dir: /content/drive/MyDrive/IC_VANT/PlantVillage/results


Model hyperparameters

In [15]:
BATCH_SIZE = 4
EPOCHS = 10

Image parameters

In [16]:
IMAGE_HEIGHT = 256

In [17]:
IMAGE_WIDTH = 256

In [18]:
IMAGE_SIZE = (IMAGE_HEIGHT, IMAGE_WIDTH)

# Functions

In [19]:
def check_data_dir(silent_console: bool = True):
    """
    check if the data_dir exists

    :param:
    :return:
    """
    if os.path.exists(data_dir):
        print("Data_dir found !") if not silent_console else None
        return True
    else:
        print("The data_dir not found !") if not silent_console else None
        return False

## Dataset manip

In [20]:
def get_dataset_info(directory: str) -> int:
    """
    get the number of images in the dataset

    :param directory: str
    :return: int
    """
    dir_path = Path(directory)
    image_count = len(list(dir_path.glob('*/*.jpg')))
    image_count += len(list(dir_path.glob('*/*.JPG')))
    return image_count


In [21]:
def check_nb_of_data_in_dataset(dataset: tf.data.Dataset):
    """
    check the number of data in the dataset

    :param dataset: tf.data.Dataset
    :return:
    """
    nb_of_batches = dataset.cardinality().numpy()
    nb_of_data = nb_of_batches * BATCH_SIZE
    print(f"Nb of data: {nb_of_data} | Nb of batches: {nb_of_batches}")
    return None

In [22]:
def check_nb_of_classes_in_dataset(dataset: tf.data.Dataset):
    """
    check the number of classes in the dataset

    :param dataset: tf.data.Dataset
    :return:
    """
    class_names = dataset.class_names
    print(f"Nb of classes: {len(class_names)} | Class names: {class_names}")
    return None


In [23]:
def load_split_dataset(val_split: float, test_split: float, silent_console: bool = True):
    """
    load and split the dataset

    :return: tf.data.Dataset, tf.data.Dataset, tf.data.Dataset
    """
    # get training dataset
    eval_split = val_split + test_split
    train_ds = tf.keras.utils.image_dataset_from_directory(
        data_dir,
        validation_split=eval_split,
        subset="training",
        seed=123,
        image_size=IMAGE_SIZE,
        batch_size=BATCH_SIZE
    )

    # get data to eval (validation and test)
    val_ds = tf.keras.utils.image_dataset_from_directory(
        data_dir,
        validation_split=eval_split,
        subset="validation",
        seed=123,
        image_size=IMAGE_SIZE,
        batch_size=BATCH_SIZE
    )

    return train_ds, val_ds

In [24]:
def get_dataset_classes(dataset: tf.data.Dataset, dataset_type: str, silent_console: bool = True):
    """
    check the dataset classes

    :return:
    """
    class_names = dataset.class_names
    if not silent_console:
        print(f"Dataset: {dataset_type} | Nb of classes: {len(class_names)} | Class names: {class_names}")
    return class_names

In [25]:
def check_batch_size(dataset: tf.data.Dataset, dataset_type: str):
    """
    check the batch size of the dataset

    :return:
    """
    print(f"\n------ Checking batch size of the {dataset_type} dataset...")
    for image_batch, labels_batch in dataset:
        print(f"Image batch shape: {image_batch.shape}")
        print(f"Label batch shape: {labels_batch.shape}\n")
        break

In [26]:
def display_img_sample_of_dataset(dataset: tf.data.Dataset, dataset_type: str):
    """
    display a sample of images from the dataset

    :return:
    """
    plt.figure(figsize=(10, 10))
    class_names = dataset.class_names

    # Take one batch of images and create a subplot
    for images, labels in dataset.take(1):
        for i in range(9):
            ax = plt.subplot(3, 3, i + 1)  # Create the subplot
            ax.imshow(images[i].numpy().astype("uint8"))  # Show the image

            # Set the title on the subplot (not the entire plot)
            ax.set_title(f"{class_names[labels[i]]}", fontsize=8)
            ax.axis("off")  # Remove the axis labels

    # Adjust layout to prevent overlapping titles
    plt.subplots_adjust(top=0.9, bottom=0.1, left=0.1, right=0.9, hspace=0.3, wspace=0.3)

    # Add a title to the entire plot
    plt.suptitle(f"Sample of images from the {dataset_type} dataset", fontsize=16)

    # Save the plot
    if not os.path.exists(results_dir):
        os.makedirs(results_dir)
    plt.savefig(results_dir + f"sample_images_{dataset_type}.png")

In [27]:
def set_prefetch(dataset: tf.data.Dataset):
    """
    set the prefetch for the dataset

    :param dataset: tf.data.Dataset
    :return: tf.data.Dataset
    """
    AUTOTUNE = tf.data.AUTOTUNE

    return dataset.shuffle(1000).prefetch(buffer_size=AUTOTUNE)

In [28]:
def normalize_dataset(dataset: tf.data.Dataset, silent_console: bool = True):
    """
    Normalize image data.
    RGB values are in the [0, 255] range, so we need to scale them to the [0, 1] range.

    :return:
    """
    print("Normalizing dataset...")
    normalizer_layer = keras.layers.Rescaling(1. / 255)
    normalized_ds = dataset.map(lambda x, y: (normalizer_layer(x), y))
    image_batch, labels_batch = next(iter(normalized_ds))

    # get the first image to check the pixel values
    if not silent_console:
        first_image = image_batch[0]
        print(f"First image pixel values -> Min: {np.min(first_image)} | Max: {np.max(first_image)} | "
              f"Shape: {first_image.shape}")
    return normalized_ds

In [29]:
def show_first_data_in_dataset(dataset: tf.data.Dataset, class_names: list):
    """
    Show the first data in the dataset

    :return:
    """
    print("\n---- Showing the first data in the dataset...")
    # get the first batch of data
    for image, label in dataset.take(1):
        # Show the first image and label
        first_img = image[0]
        first_label = label[0]
        print(f"First image shape: {first_img.shape} | First label: {first_label}")
        print(f"First image pixel values -> Min: {np.min(first_img)} | Max: {np.max(first_img)}")

        # Display the first image
        plt.figure()
        plt.imshow(first_img)
        plt.title(f"First image example | Label: {first_label} | Class name: {class_names[first_label]}")
        plt.grid(False)
        plt.show()

## Model

In [30]:
def data_augmentation():
    """
    Create a data augmentation layer

    :return:
    """
    data_augmentation = tf.keras.Sequential([
        keras.layers.RandomFlip("horizontal_and_vertical"),
        keras.layers.RandomRotation(0.2),
    ])
    return data_augmentation

In [31]:
def create_model(class_names: list, img_height: int, img_width: int):
    """
    Create a CNN model to classify image

    :return:
    """
    num_classes = len(class_names)
    image_shape = (img_height, img_width, 3)
    model = keras.Sequential([
        keras.layers.Input(shape=image_shape),
        #data_augmentation(),
        keras.layers.Reshape((img_height, img_width, 3)),
        keras.layers.Conv2D(32, kernel_size=(3, 3), activation='relu'),
        keras.layers.MaxPooling2D((2, 2)),
        keras.layers.Conv2D(64, kernel_size=(3, 3), activation='relu'),
        keras.layers.MaxPooling2D((2, 2)),
        keras.layers.Conv2D(64, kernel_size=(3, 3), activation='relu'),
        keras.layers.MaxPooling2D((2, 2)),
        keras.layers.Conv2D(64, (3, 3), activation='relu'),
        keras.layers.MaxPooling2D((2, 2)),
        keras.layers.Conv2D(64, (3, 3), activation='relu'),
        keras.layers.MaxPooling2D((2, 2)),
        keras.layers.Conv2D(64, (3, 3), activation='relu'),
        keras.layers.MaxPooling2D((2, 2)),
        keras.layers.Flatten(),
        keras.layers.Dense(64, activation='relu'),
        keras.layers.Dense(num_classes, activation='softmax'),
    ])

    return model


In [32]:
def compile_mode(model: tf.keras.Model):
    """
    Compile the model

    :return:
    """
    model.compile(optimizer='adam',
                  loss=keras.losses.SparseCategoricalCrossentropy(from_logits=False),
                  metrics=['accuracy'])

    return model


In [33]:
def show_model_fit(hist: tf.keras.callbacks.History, nb_epochs: int):
    """
    Show the model fit

    :return:
    """
    acc = hist.history['accuracy']
    val_acc = hist.history['val_accuracy']

    loss = hist.history['loss']
    val_loss = hist.history['val_loss']

    epochs_range = range(nb_epochs)

    plt.figure(figsize=(8, 8))
    plt.subplot(1, 2, 1)
    plt.plot(epochs_range, acc, label='Training Accuracy')
    plt.plot(epochs_range, val_acc, label='Validation Accuracy')
    plt.legend(loc='lower right')
    plt.title('Training and Validation Accuracy')

    plt.subplot(1, 2, 2)
    plt.plot(epochs_range, loss, label='Training Loss')
    plt.plot(epochs_range, val_loss, label='Validation Loss')
    plt.legend(loc='upper right')
    plt.title('Training and Validation Loss')

    # Save the plot
    if not os.path.exists(results_dir):
        os.makedirs(results_dir)
    plt.savefig(results_dir + "model_fit.png")

In [34]:
def save_model(model: tf.keras.Model, model_name: str, file_format: str = "keras"):
    """
    Save the model

    :return:
    """
    if not os.path.exists(results_dir):
        os.makedirs(results_dir)
    model_file = model_name + "." + file_format
    model_path = results_dir + model_file
    model.save(model_path)

In [35]:
def create_checkpoint_weights_callback():
    """
    Create a checkpoint callback to save trained weights per epoch done.

    :return:
    """
    checkpoint_dir = results_dir + "/training_checkpoints"
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}.weights.h5")
    checkpoint_callback = keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_prefix,
        save_weights_only=True
    )
    return checkpoint_callback, checkpoint_dir

In [36]:
def weights_files_key(filename):
    """
    Get the epoch number of each trained weight saved.

    Parameters
    ----------
    filename: str

    Returns
    -------
    nb of the epoch
    """
    base_name = filename.split('.')[0]
    epoch_number = base_name.split('_')[1]
    print(f"Epoch number: {epoch_number} | Base name: {base_name}")
    return int(epoch_number)

# Main code

## Download data

In [37]:
if DOWNLOAD_DATA:
  import kagglehub

  # download
  !kaggle datasets download -d emmarex/plantdisease -p /content/data/ --unzip

  #rename data folder
  !mv /content/data/PlantVillage /content/data/plant-village
else:
  print("Data already downloaded !")

Dataset URL: https://www.kaggle.com/datasets/emmarex/plantdisease
License(s): unknown
Downloading plantdisease.zip to /content/data
 99% 651M/658M [00:08<00:00, 126MB/s]
100% 658M/658M [00:08<00:00, 78.1MB/s]


In [38]:
if DELETE_DATASET_FOLDER:

  # Delete the folder and all its contents
  shutil.rmtree('/content/data/plantvillage')


Check subfolders in data_dir, delete all subdirectories with a culture different of our set culture

In [39]:
if DELETE_SOME_SUBFOLDERS:

  def delete_subfolders(data_dir, target_culture):
    for subdir in os.listdir(data_dir):
        subdir_path = os.path.join(data_dir, subdir)
        print(f"Checking subdirectory: {subdir_path}")
        if os.path.isdir(subdir_path) and PLANT_CULTURE in subdir.lower():
            print(f"Keeping subdirectory: {subdir_path}")
        else:
            # move subdir to plant-village-others
            path = data_dir + '-others'
            if not os.path.exists(path):
                os.makedirs(path)
            print(f"Moving subdirectory: {subdir_path} to {path}")
            shutil.move(subdir_path, os.path.join(path, 'plant-village-others'))
        print("\n")

  delete_subfolders(data_dir, PLANT_CULTURE)

Checking subdirectory: /content/data/plant-village/Tomato_Early_blight
Keeping subdirectory: /content/data/plant-village/Tomato_Early_blight


Checking subdirectory: /content/data/plant-village/Pepper__bell___Bacterial_spot
Moving subdirectory: /content/data/plant-village/Pepper__bell___Bacterial_spot to /content/data/plant-village-others


Checking subdirectory: /content/data/plant-village/Pepper__bell___healthy
Moving subdirectory: /content/data/plant-village/Pepper__bell___healthy to /content/data/plant-village-others


Checking subdirectory: /content/data/plant-village/Tomato_Leaf_Mold
Keeping subdirectory: /content/data/plant-village/Tomato_Leaf_Mold


Checking subdirectory: /content/data/plant-village/Tomato_Spider_mites_Two_spotted_spider_mite
Keeping subdirectory: /content/data/plant-village/Tomato_Spider_mites_Two_spotted_spider_mite


Checking subdirectory: /content/data/plant-village/Tomato_Bacterial_spot
Keeping subdirectory: /content/data/plant-village/Tomato_Bacterial_spo

## Load data

In [40]:
if not check_data_dir():
    raise ValueError("Data dir not found !")
nb_img_data = get_dataset_info(data_dir)
print(f"# {dataset_name.upper()} dataset contains {nb_img_data} images\n")

# PLANT-VILLAGE dataset contains 16010 images



In [41]:
train_dataset, val_dataset = load_split_dataset(0.2, 0.2, silent_console=True)

Found 16011 files belonging to 10 classes.
Using 9607 files for training.
Found 16011 files belonging to 10 classes.
Using 6404 files for validation.


### Check dataset classes

In [42]:
train_classes = get_dataset_classes(train_dataset, "train")
val_classes = get_dataset_classes(val_dataset, "validation")
if train_classes != val_classes:
    raise ValueError("The classes in the train and validation datasets are different")
else:
    print(f"Nb of class: {len(train_classes)} | Classes: {train_classes}\n")

Nb of class: 10 | Classes: ['Tomato_Bacterial_spot', 'Tomato_Early_blight', 'Tomato_Late_blight', 'Tomato_Leaf_Mold', 'Tomato_Septoria_leaf_spot', 'Tomato_Spider_mites_Two_spotted_spider_mite', 'Tomato__Target_Spot', 'Tomato__Tomato_YellowLeaf__Curl_Virus', 'Tomato__Tomato_mosaic_virus', 'Tomato_healthy']



### Display sample of images

Train img sample

In [43]:
check_batch_size(train_dataset, "train")
# display_img_sample_of_dataset(train_dataset, "train")


------ Checking batch size of the train dataset...
Image batch shape: (4, 256, 256, 3)
Label batch shape: (4,)



Validation img sample

In [44]:
check_batch_size(val_dataset, "validation")
# display_img_sample_of_dataset(val_dataset, "validation")


------ Checking batch size of the validation dataset...
Image batch shape: (4, 256, 256, 3)
Label batch shape: (4,)



## Data pre-processing and cleaning

### Prefetch data

In [45]:
print("\n---- Setting prefetch for the dataset...")
train_dataset = set_prefetch(train_dataset)
val_dataset = set_prefetch(val_dataset)


---- Setting prefetch for the dataset...


### Normalize data

In [46]:
print("\n---- Normalizing dataset...")
train_dataset_normalized = normalize_dataset(train_dataset)
val_dataset_normalized = normalize_dataset(val_dataset)


---- Normalizing dataset...
Normalizing dataset...
Normalizing dataset...


In [47]:
# show_first_data_in_dataset(train_dataset_normalized, train_classes)

## Create and fit a CNN model

### Visualize some augmented images

### Build model

In [48]:
print("\n---- Building the model...")
model = create_model(train_classes, IMAGE_HEIGHT, IMAGE_WIDTH)
model = compile_mode(model)


---- Building the model...


In [49]:
model.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 reshape (Reshape)           (None, 256, 256, 3)       0         
                                                                 
 conv2d (Conv2D)             (None, 254, 254, 32)      896       
                                                                 
 max_pooling2d (MaxPooling2  (None, 127, 127, 32)      0         
 D)                                                              
                                                                 
 conv2d_1 (Conv2D)           (None, 125, 125, 64)      18496     
                                                                 
 max_pooling2d_1 (MaxPoolin  (None, 62, 62, 64)        0         
 g2D)                                                            
                                                                 
 conv2d_2 (Conv2D)           (None, 60, 60, 64)        3

### Train model

In [50]:
start_ep = -1
fit_model = True

Create callbacks to save fit by epoch

In [51]:
checkpoint_callback, checkpoint_dir = create_checkpoint_weights_callback()

In [52]:
print(f"Checkpoint_dir: {checkpoint_dir}")

Checkpoint_dir: /content/drive/MyDrive/IC_VANT/PlantVillage/results/training_checkpoints


Check entire model fit

In [53]:
if start_ep == -1:
  # check the entire model
  if os.path.exists(results_dir + "cnn_first_model.keras"):
    print("\n---- An entire model is already saved !")
    print("\n---- Loading the model...")
    model = keras.models.load_model(results_dir + f"cnn_model_ep={EPOCHS}.keras")
    fit_model = False
  else:
    start_ep = 0
    print("\n---- No entire model saved !")


---- No entire model saved !


Check checkpoints of epochs fit

In [54]:
if start_ep == 0:
  print(f"\n---- Checking ckpt results")

  if os.path.exists(checkpoint_dir):
    files = os.listdir(checkpoint_dir)
    if len(files) > 0:
      print(f"\n We found some ckpt model files !")

      highest_epoch_file = max(files, key=weights_files_key)
      highest_epoch = weights_files_key(highest_epoch_file)

      # get the last epoch fitting
      if highest_epoch < EPOCHS:
        start_ep = highest_epoch
        path_to_load_model = os.path.join(checkpoint_dir, "ckpt_{}.weights.h5".format(start_ep))
        print("\n---- Loading weights of: {}".format(path_to_load_model))
        model.load_weights(path_to_load_model)
        fit_model = True
      else:
        fit_model = False
        # load checkpoint fitting
        path_to_load_model = os.path.join(checkpoint_dir, "ckpt_{}.weights.h5".format(start_ep))
        print("\n---- Loading weights of: {}".format(path_to_load_model))
        model.load_weights(path_to_load_model)
        # save complete model
        save_model(model, f"cnn_model_ep={EPOCHS}")
        print("\n---- Complete model saved !")

    else:
      print(f"\n No ckpt model files found !")
      fit_model = True



---- Checking ckpt results

 No ckpt model files found !


Fit model

In [55]:
print(start_ep)

0


In [None]:
if fit_model:
  print("\n---- Fitting the model...")
  print(f"Start epoch: {start_ep}")
  print(f"End epoch: {EPOCHS}\n")

  # fit using GPU
  with tf.device('/device:GPU:0'):
    hist = model.fit(train_dataset_normalized,
                    validation_data=val_dataset_normalized,
                    epochs=EPOCHS,
                    initial_epoch=start_ep,
                    callbacks=[checkpoint_callback]
                    )


---- Fitting the model...
Start epoch: 0
End epoch: 10

Epoch 1/10
 248/2402 [==>...........................] - ETA: 15:59 - loss: 2.1979 - accuracy: 0.1925

In [None]:
for layer in model.layers:
    print(layer.name, layer.output.device)

In [None]:
if fit_model:
  show_model_fit(hist, EPOCHS)
  save_model(model, f"cnn_model_ep={EPOCHS}")

### Evaluate model

In [None]:
print("\n---- Evaluating the model...")

In [None]:
print("### Val dataset evaluation:")
val_score = model.evaluate(val_dataset_normalized)
print(f"Val accuracy: {val_score[1]*100.00:.2f} %")

In [None]:
print("### Train dataset evaluation:")
train_score = model.evaluate(train_dataset_normalized)
print(f"Train accuracy: {train_score[1]*100.00:.2f} %")

## Create a Quantization Aware Model

In [None]:
import tensorflow_model_optimization as tfmot

### Quantize layers

In [None]:
base_model = keras.models.clone_model(model)
base_model.set_weights(model.get_weights())
base_model = compile_mode(base_model)

Test model copy

In [None]:
print("### Train dataset evaluation (COPY MODEL):")
train_score = base_model.evaluate(train_dataset_normalized)
print(f"Train accuracy: {train_score[1]*100.00:.2f} %")

In [None]:
print(type(base_model))

In [None]:
print(tf.__version__)  # Check TensorFlow version
print(keras.__version__)  # Check Keras version

Quantize only the Dense, MaxPool2D, Conv2D Layers

In [None]:
quantize_model = tfmot.quantization.keras.quantize_model

Apply quantization-aware training to specific layers

In [None]:
quant_aware_model = quantize_model(base_model)

### Compile and fit the quantization model

Compile the quantization model

In [None]:
quant_aware_model.compile(optimizer='adam',
                          loss='sparse_categorical_crossentropy',
                          metrics=['accuracy'])

Fit the quantization model

In [None]:
quant_aware_model.fit(train_dataset, validation_data=val_dataset, epochs=EPOCHS)

### Save q-model

In [None]:
save_model(quant_aware_model, f"cnn_quant_aware_model_ep={EPOCHS}")

### Evaluate q-model

In [None]:
print("[INFO] Calculating Quant Aware model accuracy")
scores = quant_aware_model.evaluate(val_dataset)
print(f"Val Accuracy: {round(scores[1],4)*100}%")
scores_train = quant_aware_model.evaluate(train_dataset)
print(f"Train Accuracy: {round(scores_train[1],4)*100}%")

## TF to TFlite

Convert model to tensorflowlite model

In [None]:
converter = tf.lite.TFLiteConverter.from_keras_model(quant_aware_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]

quantized_tflite_model = converter.convert()

Save TFlite model

In [None]:
quantized_tflite_model.save(results_dir + "q_cnn_model.tflite")