<a href="https://colab.research.google.com/github/LeoFades/COS30049_SmartPlant_Sarawak/blob/main/plantclassifier/PlantRecog_Train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

1. Required Libraries

Import all required libraries. We use EfficientNetB0 as the backbone because it keeps both training and inference lightweight (low latency) while still giving good accuracy; if needed, the backbone can be scaled up to B1, B2, and so on.


In [1]:
import tensorflow as tf
from tensorflow.keras.applications import EfficientNetB0
from tensorflow.keras import layers, models
import os
IMG_SIZE = (224, 224)

# Only allow these extensions (compared case-insensitively).
VALID_EXTS = (".jpg", ".jpeg", ".png")


def list_valid_files(data_dir):
    """Collect image files from a one-folder-per-class dataset directory.

    Each immediate sub-directory of ``data_dir`` is treated as one class.
    Class indices follow the sorted order of the sub-directory names.

    Args:
        data_dir: Path to the dataset root containing one folder per class.

    Returns:
        ``(file_paths, class_names)``. ``file_paths`` is a list of
        ``(path, class_index)`` tuples, in class order and then in sorted
        file-name order. ``class_names`` is the sorted list of class folder
        names; ``class_index`` is a position in that list.

    Raises:
        FileNotFoundError: If ``data_dir`` is not an existing directory.
    """
    if not os.path.isdir(data_dir):
        # Without this check, next(os.walk(...)) leaks a bare StopIteration.
        raise FileNotFoundError(f"Dataset directory not found: {data_dir}")

    class_names = sorted(entry.name for entry in os.scandir(data_dir) if entry.is_dir())

    file_paths = []
    for class_index, class_name in enumerate(class_names):
        class_dir = os.path.join(data_dir, class_name)
        # Sort so the listing is deterministic across runs and filesystems.
        for fname in sorted(os.listdir(class_dir)):
            fpath = os.path.join(class_dir, fname)
            # Skip sub-directories even if their names end in an image extension.
            if os.path.isfile(fpath) and fname.lower().endswith(VALID_EXTS):
                file_paths.append((fpath, class_index))
            else:
                print("Skipping non-image:", fname)
    return file_paths, class_names

def decode_img(path, label):
    """Read one image file and return it with its one-hot label.

    Args:
        path: Image file path (str or scalar string tensor).
        label: Integer class index.

    Returns:
        ``(image, one_hot_label)``. ``image`` is a float32 tensor of shape
        ``(*IMG_SIZE, 3)`` scaled to [0, 1]. ``one_hot_label`` has length
        ``num_classes``, a notebook global that must already be defined
        (it is set in the data-preparation cell).

    If the file cannot be read or decoded, prints "Corrupted:" and returns an
    all-zero image so one bad file does not abort the run.

    NOTE: the Python try/except only catches errors raised eagerly. If this
    function is traced (e.g. via ``tf.data.Dataset.map``), runtime decode
    failures are not caught here.
    """
    try:
        img = tf.io.read_file(path)
        img = tf.image.decode_image(img, channels=3, expand_animations=False)
    except tf.errors.OpError:
        # Read failures (e.g. missing file) and decode failures are OpErrors;
        # anything else is a real bug and should propagate.
        tf.print("Corrupted:", path)
        img = tf.zeros((*IMG_SIZE, 3))
    else:
        img = tf.image.resize(img, IMG_SIZE) / 255.0  # normalize to [0,1]
    return img, tf.one_hot(label, depth=num_classes)

2. Drive Connection

Connect to Google Drive, where the dataset is stored.







In [2]:
# Mount Google Drive so the dataset paths under /content/drive/MyDrive resolve.
import google.colab.drive as drive

drive.mount("/content/drive")


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


3. Data Preparation

Next, prepare and clean the data: load the training and validation datasets, normalise them, and apply augmentation and any other techniques deemed appropriate.


In [3]:
# Dataset locations on the mounted Google Drive.
DATASET_ROOT = "/content/drive/MyDrive/COS30049 - PlantRecog/dataset"
TRAIN_DIR = f"{DATASET_ROOT}/Training"
VAL_DIR = f"{DATASET_ROOT}/Validation"
BATCH_SIZE = 32

# Load datasets: labels are inferred from the class sub-folders and one-hot encoded.
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    TRAIN_DIR, image_size=(224, 224), batch_size=BATCH_SIZE, label_mode="categorical"
)
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    VAL_DIR, image_size=(224, 224), batch_size=BATCH_SIZE, label_mode="categorical"
)

# Capture class names now: datasets returned by .map()/.cache()/.prefetch()
# do not carry the .class_names attribute.
class_names = train_ds.class_names
num_classes = len(class_names)

print(num_classes)

AUTOTUNE = tf.data.AUTOTUNE

# Normalization layer: pixels [0, 255] -> [0, 1].
# NOTE: Keras' EfficientNetB0 rescales internally and expects [0, 255] input;
# whatever model consumes these [0, 1] images must account for that.
normalization_layer = layers.Rescaling(1./255)

# Augmentation pipeline (can be tweaked if model underfits due to low data volume)
data_augmentation = tf.keras.Sequential([
    layers.RandomFlip("horizontal"),       # random horizontal flip
    layers.RandomRotation(0.2),            # rotate up to ±20% of a full turn
    layers.RandomZoom(0.2),                # zoom in/out
    layers.RandomContrast(0.2),            # change contrast
])

# Order matters: cache the deterministic (normalised) images, then shuffle and
# augment AFTER the cache so every epoch gets fresh random augmentations.
# Augmenting before .cache() would replay epoch-1 augmentations forever.
# (The cache also fixes which images share a batch; shuffle(1000) reorders
# whole batches only.)
train_ds = (
    train_ds
    .map(lambda x, y: (normalization_layer(x), y), num_parallel_calls=AUTOTUNE)
    .cache()
    .shuffle(1000)
    .map(lambda x, y: (data_augmentation(x, training=True), y), num_parallel_calls=AUTOTUNE)
    .prefetch(buffer_size=AUTOTUNE)
)

# Validation data: normalisation only, no augmentation.
val_ds = (
    val_ds
    .map(lambda x, y: (normalization_layer(x), y), num_parallel_calls=AUTOTUNE)
    .cache()
    .prefetch(buffer_size=AUTOTUNE)
)

Found 838 files belonging to 4 classes.
Found 80 files belonging to 4 classes.


3.1 Troubleshooting

Sanity checks to run when the pipeline or training fails: dataset sizes, GPU availability, and corrupted or non-image files.

In [9]:
# Quick sanity checks for when the input pipeline or training fails.
print("Train batches:", train_ds.cardinality().numpy())
print("Val batches:", val_ds.cardinality().numpy())

print("GPUs:", tf.config.list_physical_devices("GPU"))

import os

from PIL import Image

DATASET_ROOT = "/content/drive/MyDrive/COS30049 - PlantRecog/dataset"
IMAGE_EXTS = (".jpg", ".jpeg", ".png")

# Scan BOTH splits: a bad file in Training breaks training just as surely
# as one in Validation.
bad_files = []
for split in ("Training", "Validation"):
    data_dir = os.path.join(DATASET_ROOT, split)
    for root, dirs, files in os.walk(data_dir):
        for fname in files:
            fpath = os.path.join(root, fname)
            if not fname.lower().endswith(IMAGE_EXTS):
                print("Non-image file:", fpath)
            try:
                # verify() checks file integrity; the context manager closes
                # the file handle afterwards instead of leaking it.
                with Image.open(fpath) as img:
                    img.verify()
            except Exception as e:  # diagnostic cell: report anything PIL rejects
                print(f"Corrupted: {fpath} ({e})")
                bad_files.append(fpath)

print("Bad files:", bad_files)


Train batches: 27
Val batches: 3
GPUs: [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
Bad files: []


4. Model Training

This step trains the model: an EfficientNetB0 backbone with ImageNet weights, loaded without its original classification head, topped with our own classifier head. The backbone is first frozen, then unfrozen for fine-tuning at a low learning rate.


In [None]:
# Backbone: ImageNet-pretrained EfficientNetB0 without its classification head.
base = EfficientNetB0(weights="imagenet", include_top=False, input_shape=(224, 224, 3))
base.trainable = False

# Keras' EfficientNet applies its own rescaling/normalisation internally and
# expects raw pixels in [0, 255]. The input pipeline and the inference cells
# feed images already divided by 255, so scale them back up here. The saved
# model therefore accepts [0, 1] inputs end to end.
inputs = layers.Input(shape=(224, 224, 3))
x = layers.Rescaling(255.0)(inputs)
# training=False keeps the backbone's BatchNormalization layers in inference
# mode even after the backbone is unfrozen for fine-tuning. This is the Keras
# transfer-learning recommendation, so the small dataset does not overwrite
# the pretrained BN statistics.
x = base(x, training=False)
x = layers.GlobalAveragePooling2D()(x)
x = layers.Dropout(0.3)(x)
out = layers.Dense(num_classes, activation="softmax")(x)

model = models.Model(inputs, out)
model.compile(optimizer="adam",
              loss="categorical_crossentropy",
              metrics=["accuracy"])

# Callbacks (recommended)
callbacks = [
    tf.keras.callbacks.ModelCheckpoint("best_model.h5", save_best_only=True),
    tf.keras.callbacks.EarlyStopping(patience=5, restore_best_weights=True)
]

# Phase 1: train only the new classifier head.
history = model.fit(train_ds, validation_data=val_ds, epochs=10, callbacks=callbacks)

# Phase 2: fine-tune the whole backbone with a small learning rate.
# Recompiling is required for the trainable change to take effect.
base.trainable = True
model.compile(optimizer=tf.keras.optimizers.Adam(1e-5),
              loss="categorical_crossentropy",
              metrics=["accuracy"])

# Continue epoch numbering from where phase 1 stopped (it may have stopped
# early), then run 10 more epochs.
initial_epochs = len(history.epoch)
model.fit(train_ds, validation_data=val_ds,
          epochs=initial_epochs + 10, initial_epoch=initial_epochs,
          callbacks=callbacks)


Epoch 1/10


5. Progress Saving

Once satisfied with the results, run this cell to save the whole model: the trained architecture, weights, and optimiser state.

In [None]:
# Save entire model (recommended for checkpointing)
# Persist the full model (architecture, weights and optimiser state) to Drive.
# NOTE(review): Keras 3 (TF >= 2.16) rejects save paths without a ".keras" or
# ".h5" suffix; this extension-less path only works (as a TF SavedModel) on
# older TF releases. Confirm the runtime's Keras version.
MODEL_PATH = "/content/drive/MyDrive/plant_identifier/efficientnet_model"
tf.keras.models.save_model(model, MODEL_PATH)


6.1 Model Testing

Use this cell to load the saved model and run it on a test image (adjust the image path as needed). Ideally, this logic would be deployed on a Gradio site for quick testing before being integrated into the mobile app.

In [None]:
import os

import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing import image

MODEL_PATH = "/content/drive/MyDrive/plant_identifier/efficientnet_model"
TRAIN_DIR = "/content/drive/MyDrive/COS30049 - PlantRecog/dataset/Training"

# Load model (via tf.keras; a bare `keras` name is not imported in this notebook).
model = tf.keras.models.load_model(MODEL_PATH)

# Recover class names the way image_dataset_from_directory infers them:
# the sorted sub-directory names of the training folder. This also works in a
# fresh session, and the mapped train_ds no longer exposes .class_names anyway.
class_names = sorted(entry.name for entry in os.scandir(TRAIN_DIR) if entry.is_dir())

# Load a test image. Use bilinear resizing to match how the training images
# were resized (load_img defaults to nearest-neighbour).
img_path = "/content/drive/MyDrive/test_images/mango1.jpg"
img = image.load_img(img_path, target_size=(224, 224), interpolation="bilinear")
img_array = image.img_to_array(img)
img_array = np.expand_dims(img_array, axis=0)  # add batch dimension -> (1, 224, 224, 3)

# Normalize to match training
img_array = img_array / 255.0

# Predict: preds has one row (batch of 1) with one score per class.
preds = model.predict(img_array)
print("Prediction:", class_names[int(np.argmax(preds[0]))])


6.2 Model Checkpoint Loading

Use this cell to load a saved model checkpoint so that training can be resumed (work in progress).

In [None]:
import os

import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing import image

MODEL_PATH = "/content/drive/MyDrive/plant_identifier/efficientnet_model"
TRAIN_DIR = "/content/drive/MyDrive/COS30049 - PlantRecog/dataset/Training"

# Load the saved model/checkpoint (via tf.keras; a bare `keras` name is not
# imported in this notebook).
model = tf.keras.models.load_model(MODEL_PATH)

# TODO: resume training from here (e.g. model.fit(...) on the prepared
# datasets). For now this cell only runs a prediction smoke test below.

# Class names as image_dataset_from_directory infers them: sorted training
# sub-directory names (works without train_ds being in memory).
class_names = sorted(entry.name for entry in os.scandir(TRAIN_DIR) if entry.is_dir())

# Load a test image, resized bilinearly like the training images.
img_path = "/content/drive/MyDrive/test_images/mango1.jpg"
img = image.load_img(img_path, target_size=(224, 224), interpolation="bilinear")
img_array = image.img_to_array(img)
img_array = np.expand_dims(img_array, axis=0)  # add batch dimension -> (1, 224, 224, 3)

# Normalize to match training
img_array = img_array / 255.0

# Predict: one row of per-class scores for the single image.
preds = model.predict(img_array)
print("Prediction:", class_names[int(np.argmax(preds[0]))])
