# Transfer Learning with ResNet50 for Plant Disease Classification


## 1. Imports
We import the main libraries used in this notebook: TensorFlow/Keras for the CNN, NumPy for arrays, and Matplotlib for plots.


In [None]:
import os
import numpy as np
import tensorflow as tf
from tensorflow import keras
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras import models
import matplotlib.pyplot as plt

from tensorflow.keras.applications.resnet50 import preprocess_input


## 2. Basic configuration
Set the key hyperparameters: image size, batch size, number of channels (RGB), and training epochs.


In [None]:
# Image geometry: every image is resized to a square of image_size x image_size.
image_size = 256   # side length in pixels (a plain int; used as (image_size, image_size))
batch_size = 16    # images per batch
channels = 3       # RGB
epochs = 50        # maximum number of training epochs (early stopping may end sooner)


## 3. Load the dataset and exploratory data analysis 
Loads images from your `PlantVillage` directory. Each subfolder name becomes a class label automatically.


In [None]:
# Load images from a folder structure into a tf.data.Dataset.
# Expected folder layout:
#   ../data/raw/PlantVillage/
#       class_1/  (images...)
#       class_2/  (images...)
#       ...
# Each subfolder name becomes the class label automatically.
#
# seed: makes the shuffled file order reproducible. Without it the shuffle is
# unseeded, so the train/val/test split built below changes from run to run,
# and a model reloaded later (TRAIN = False) would be "tested" on images it
# was trained on.
dataset = tf.keras.utils.image_dataset_from_directory(
    directory="../data/raw/PlantVillage",   # root folder containing one subfolder per class
    image_size=(image_size, image_size),    # resize every image to this fixed size (H, W)
    batch_size=batch_size,                  # how many images per batch returned by the dataset
    shuffle=True,                           # shuffle the file order (the Keras default)
    seed=12,                                # fixed seed -> same file order on every run
)

# Result:
# - dataset yields batches of (images, labels)
# - images shape: (batch_size, image_size, image_size, 3), pixel values in 0..255
# - labels are integer class IDs (0..num_classes-1)


### Class names
Keras assigns integer labels to the classes in alphanumeric order of the subfolder names; `class_names[i]` is the name of label `i`.


In [None]:
# Class (label) names inferred from the PlantVillage subfolder names;
# entry i of this list is the name of integer label i.
# Example: ["Apple___Black_rot", "Apple___healthy", ...]
class_names = dataset.class_names

# Echo the list (a notebook cell displays the value of its last expression)
class_names



### Inspect one batch
Checks the tensor shapes coming from the dataset and prints the label ids in that batch.


In [None]:
# Pull a single batch and report its image-tensor shape and its integer labels.
for image_batch, label_batch in dataset.take(1):
    batch_shape = image_batch.shape
    batch_labels = label_batch.numpy()
    print(batch_shape)
    print(batch_labels)


### Inspect one image shape
Shows the shape of a single image tensor (height, width, channels).


In [None]:
# Shape of one image tensor: (height, width, channels).
for image_batch, label_batch in dataset.take(1):
    first_image = image_batch[0]
    print(first_image.shape)


### Inspect raw pixel values
Prints the raw pixel values of one image (values in the 0–255 range, before any ResNet50 preprocessing).


In [None]:
# Dump the raw pixel array of the first image in one batch.
for image_batch, label_batch in dataset.take(1):
    pixel_array = image_batch[0].numpy()
    print(pixel_array)



### Visualize a sample image
Displays one example image from the dataset.


In [None]:
# Display the first image of one batch; imshow needs uint8 for 0..255 RGB data.
for image_batch, label_batch in dataset.take(1):
    first_image = image_batch[0].numpy().astype("uint8")
    plt.imshow(first_image)
    plt.axis("off")


### Label id → class name mapping
Prints the integer label id assigned to each class name.


In [None]:
# One line per class in the form "class_index : class_name".
label_map = dict(enumerate(dataset.class_names))
for label_id, class_name in label_map.items():
    print(f"{label_id} : {class_name}")



### Collect one sample per class
Unbatches the dataset and stores the first image seen for each class. This is useful for later visualization.


In [None]:
import math

class_names = dataset.class_names                 # class names (from folder names)
num_classes = len(class_names)                    # number of classes

# Keep the first image seen for every label id. Walk the dataset one image at
# a time and stop as soon as every class has a sample.
samples = {}
for img, label in dataset.unbatch():
    label_id = int(label.numpy())
    samples.setdefault(label_id, img)
    if len(samples) == num_classes:
        break

# Lay the samples out on a grid three columns wide.
cols = 3
rows = math.ceil(num_classes / cols)

plt.figure(figsize=(4 * cols, 4 * rows))
for i, label_id in enumerate(sorted(samples)):
    ax = plt.subplot(rows, cols, i + 1)
    img = samples[label_id].numpy().astype("uint8")   # display format
    ax.imshow(img)
    ax.set_title(f"{label_id} → {class_names[label_id]}", fontsize=10)
    ax.axis("off")

plt.tight_layout()
plt.show()


## 4. Train/validation/test split
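Splits a `tf.data.Dataset` into train/val/test using `take()` and `skip()`. This assumes the dataset has a known length. Because `dataset` is batched, the split is made on whole batches. The shuffle before splitting must yield the same order on every iteration (`reshuffle_each_iteration=False`). Otherwise `take()`/`skip()` select different batches each epoch, and the train, validation and test sets overlap.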


In [None]:
def spliting_the_data(ds, train_split=0.80, val_split=0.10, test_split=0.10, shuffle=True, shuffle_size=1000):
    """Split a finite tf.data.Dataset into train / validation / test partitions.

    The split is made with ``take()`` / ``skip()`` on the elements of *ds*.
    For a batched dataset those elements are whole batches.

    Args:
        ds: dataset with a known, finite length (``len(ds)`` must work).
        train_split: fraction of elements used for training.
        val_split: fraction of elements used for validation.
        test_split: informational only. The test partition is always
            "everything left after train and validation".
        shuffle: shuffle once, with a fixed seed, before splitting.
        shuffle_size: shuffle buffer size.

    Returns:
        (train_ds, val_ds, test_ds)

    Raises:
        ValueError: if a fraction is negative or train_split + val_split > 1.
    """
    if train_split < 0 or val_split < 0 or train_split + val_split > 1 + 1e-9:
        raise ValueError(
            f"invalid split fractions: train_split={train_split}, val_split={val_split}"
        )

    ds_size = len(ds)                          # total number of batches/elements in ds

    if shuffle:
        # reshuffle_each_iteration=False freezes the shuffled order. With the
        # default (True) the dataset is reshuffled on every pass, so take()/skip()
        # select different elements each epoch and the three splits leak into
        # each other.
        ds = ds.shuffle(shuffle_size, seed=12, reshuffle_each_iteration=False)

    train_size = int(train_split * ds_size)    # number of items for train
    val_size = int(val_split * ds_size)        # number of items for val

    train_ds = ds.take(train_size)                        # first part -> train
    val_ds = ds.skip(train_size).take(val_size)           # next part -> val
    test_ds = ds.skip(train_size + val_size)              # remaining -> test

    return train_ds, val_ds, test_ds


### Create the splits
Runs the split function and produces `train_ds`, `val_ds`, and `test_ds`.


In [None]:
# Partition the batched dataset into train / validation / test splits.
train_ds, val_ds, test_ds = spliting_the_data(ds=dataset)



### Check split sizes
Shows the number of batches in each split. `len()` raises a `TypeError` if a pipeline's cardinality is unknown or infinite.


In [None]:
# Number of batches in each split (len() requires a known, finite cardinality).
split_sizes = (len(train_ds), len(val_ds), len(test_ds))
split_sizes

## 5. Preprocessing and data augmentation
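- **Resizing** ensures all images have the same shape.
- **ResNet50 preprocessing** (`preprocess_input`) converts the 0–255 RGB pixels into the input format expected by ResNet50's ImageNet weights. No `1/255` rescaling is applied.
- **Augmentation** defines random flips/rotations to improve generalization. It only takes effect when the `data_augmentation` layer is part of the model.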


In [None]:
# ResNet50 transfer learning keeps pixels in the 0..255 range (no 1/255
# rescaling), because tf.keras.applications.resnet50.preprocess_input expects
# that range.

# Resize-only front end: fixes the spatial size, does no normalization.
resize_only = tf.keras.Sequential(
    [layers.Resizing(height=image_size, width=image_size)],
    name="resize_only",
)

@tf.keras.utils.register_keras_serializable(package="Custom")
class ResNet50Preprocess(layers.Layer):
    """Apply ``resnet50.preprocess_input`` as a Keras layer.

    The class is registered with Keras under the "Custom" package, so a model
    containing it can be saved and reloaded by class name, rather than as
    serialized Python code like a Lambda layer.
    Inputs are expected to be RGB with values in [0, 255].
    """

    def call(self, inputs):
        preprocessed = preprocess_input(inputs)
        return preprocessed

    def get_config(self):
        # No extra constructor arguments, so the base Layer config is complete.
        config = super().get_config()
        return config

# Ready-made instance of the serializable ResNet50 preprocessing layer.
resnet50_preprocess = ResNet50Preprocess(name="resnet50_preprocess")

# Random flips (both axes) and rotations. This only has an effect if the
# model includes this layer.
augmentation_layers = [
    layers.RandomFlip('horizontal_and_vertical'),
    layers.RandomRotation(0.2),
]
data_augmentation = tf.keras.Sequential(augmentation_layers, name="data_augmentation")


## 6. Build the ResNet50 model (transfer learning)


### Model architecture (ResNet50 backbone + custom head)


In [None]:
import tensorflow as tf
from tensorflow.keras import layers, models

n_classes = len(class_names)
input_shape = (image_size, image_size, channels)   # use the configured channel count

# ResNet50 backbone (no name= argument, so it keeps the default layer name
# that the fine-tuning code looks up).
base_model = tf.keras.applications.ResNet50(
    weights="imagenet",
    include_top=False,
    input_shape=input_shape,
)
base_model.trainable = False

# Pipeline: resize -> augment -> ResNet50 preprocessing -> frozen backbone -> head.
# - data_augmentation was built in the preprocessing cell but never wired in.
#   Keras' random-augmentation layers only transform images during training,
#   so evaluation and prediction see the unmodified images.
# - resnet50_preprocess is the registered ResNet50Preprocess layer. Unlike a
#   Lambda, it is saved and reloaded by class name.
model = models.Sequential([
    layers.Input(shape=input_shape),
    resize_only,
    data_augmentation,
    resnet50_preprocess,
    base_model,
    layers.GlobalAveragePooling2D(),
    layers.Dropout(0.25),
    layers.Dense(n_classes, activation="softmax"),
], name="ResNet50_transfer_learning")

model.summary()


### Compile the model
Sets the optimizer, loss function, and metrics. `SparseCategoricalCrossentropy` is correct when labels are integer-encoded.


In [None]:
# Integer labels + softmax probabilities -> sparse CE with from_logits=False.
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)

model.compile(optimizer="adam", loss=loss_fn, metrics=["accuracy"])

model.summary()  # architecture + parameter counts



### Training callbacks (early stopping, checkpointing, learning-rate reduction)

In [None]:
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau  # training callbacks

# Stop after 9 epochs without a val_loss improvement; roll back to the best weights.
early_stop = EarlyStopping(
    monitor="val_loss",
    mode="min",
    patience=9,
    restore_best_weights=True,
)

# Write the model to best_model.keras whenever val_loss reaches a new minimum.
checkpoint = ModelCheckpoint(
    "best_model.keras",
    monitor="val_loss",
    mode="min",
    save_best_only=True,
)

# Halve the learning rate after 2 stagnant epochs, never going below 1e-6.
lr_schedule = ReduceLROnPlateau(
    monitor="val_loss",
    mode="min",
    factor=0.5,
    patience=2,
    min_lr=1e-6,
    verbose=1,
)

callbacks = [early_stop, checkpoint, lr_schedule]


### GPU setup (optional)
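Enables memory growth so TensorFlow doesn't reserve all GPU memory at once. Memory growth can only be set before TensorFlow initializes the GPUs, so this cell has to run before any model is built (for example, right after the imports). Otherwise TensorFlow rejects the setting with a `RuntimeError`.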


In [None]:
import tensorflow as tf  # TensorFlow library

gpus = tf.config.list_physical_devices('GPU')          # list available GPUs
try:
    for gpu in gpus:
        # Allocate GPU memory on demand instead of reserving it all up front.
        tf.config.experimental.set_memory_growth(gpu, True)
except RuntimeError as err:
    # TensorFlow only accepts this setting before the GPUs are initialized.
    # In this notebook the ResNet50 model has already been built above, so on
    # a GPU machine this will likely fail. Restart the kernel and run this
    # cell first to apply it.
    print("Could not enable GPU memory growth:", err)

print("GPUs:", gpus)                                   # show detected GPUs



### Train or load a saved model
Use `TRAIN = True` to train and save the model to `MODEL_PATH`, or `TRAIN = False` to load the saved model from disk. Set `FINE_TUNE = True` to also fine-tune the last `UNFREEZE_LAST_N` layers of the ResNet50 backbone after the first training stage.


In [22]:
import os                                    # file/path utilities
import tensorflow as tf                      # TensorFlow

MODEL_PATH = "../models/ResNet50_image_classification_model.keras"  # where to save/load the model
TRAIN = True                                 # True = train + save, False = load existing model

# Optional: fine-tune the top layers of ResNet50 after the first training stage
FINE_TUNE = False                            # set True only after you see stable val accuracy
FINE_TUNE_EPOCHS = 10                        # extra epochs for fine-tuning
UNFREEZE_LAST_N = 30                         # unfreeze the last N layers of the ResNet50 backbone

if TRAIN:
    print("Training model...")
    model_history = model.fit(
        train_ds,                            # training dataset
        validation_data=val_ds,              # validation dataset
        epochs=epochs,                       # number of epochs
        callbacks=callbacks,                 # callbacks (early stop, checkpoint, etc.)
        verbose=1                            # show training progress
    )

    # Fine-tune (optional)
    if FINE_TUNE:
        print("Fine-tuning the top layers of ResNet50...")
        backbone = model.get_layer("resnet50")
        backbone.trainable = True

        # Freeze early layers, unfreeze only the last UNFREEZE_LAST_N layers
        if UNFREEZE_LAST_N is not None and UNFREEZE_LAST_N > 0:
            for layer in backbone.layers[:-UNFREEZE_LAST_N]:
                layer.trainable = False

        # Keep BatchNorm layers frozen for more stable fine-tuning
        for layer in backbone.layers:
            if isinstance(layer, layers.BatchNormalization):
                layer.trainable = False

        # Recompile so the trainable-flag changes take effect; use a small LR.
        model.compile(
            optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
            loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
            metrics=['accuracy']
        )

        fine_tune_history = model.fit(
            train_ds,
            validation_data=val_ds,
            epochs=FINE_TUNE_EPOCHS,
            callbacks=callbacks,
            verbose=1
        )

    # Create the target folder before saving, so a missing ../models directory
    # cannot make the save fail after a long training run.
    model_dir = os.path.dirname(MODEL_PATH)
    if model_dir:
        os.makedirs(model_dir, exist_ok=True)

    model.save(MODEL_PATH)                   # save trained model
    print("Saved model to:", MODEL_PATH)

else:
    if not os.path.exists(MODEL_PATH):       # error if model file is missing
        raise FileNotFoundError(
            f"Model not found at: {MODEL_PATH}\n"
            "Either fix the path or set TRAIN=True to train and save the model."
        )

    print("Loading saved model...")
    try:
        # safe_mode=False lets Keras deserialize Lambda layers (it refuses them
        # by default). Only load model files you trust.
        model = tf.keras.models.load_model(
            MODEL_PATH,
            compile=False,
            safe_mode=False,
        )
    except TypeError:
        # Older TF/Keras versions don't have safe_mode
        model = tf.keras.models.load_model(
            MODEL_PATH,
            compile=False,
        )

    # Compile after loading (needed for model.evaluate)
    model.compile(
        optimizer='adam',
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
        metrics=['accuracy']
    )

    print("Loaded model from:", MODEL_PATH)


 232/2348 [=>............................] - ETA: 1:12:53 - loss: 0.2034 - accuracy: 0.9351

KeyboardInterrupt: 

### Testing and model evaluation
Evaluates the model on the test split and reports the test loss and test accuracy.


In [None]:
# Evaluate once and keep the result in `scores` ([test_loss, test_accuracy]).
# Re-running the cell reuses the cached numbers. Run `del scores` after
# retraining or reloading the model, otherwise the cache belongs to the old model.
if "scores" not in globals():
    scores = model.evaluate(test_ds, verbose=1)
else:
    print("✅ Using cached scores (not re-evaluating).")

print("metrics:", model.metrics_names)
print("scores:", scores)


### Predict labels on the test set
Collects true labels and predicted labels so we can build a classification report and confusion matrix.


In [None]:
y_true_list = []                      # true labels from all batches
y_pred_list = []                      # predicted labels from all batches

# Predict in the same pass that reads the labels. Iterating test_ds twice (once
# for predictions, once for labels) may not give the same order when the input
# pipeline contains shuffle steps.
for x_batch, y_batch in test_ds:
    # predict_on_batch runs one forward pass on this batch. It avoids the
    # per-call setup cost of model.predict(), which is meant for whole datasets.
    probs = model.predict_on_batch(x_batch)           # class probabilities
    y_pred_list.append(np.argmax(probs, axis=1))      # most likely class id
    y_true_list.append(y_batch.numpy())               # true labels as numpy

if not y_true_list:
    raise ValueError("test_ds is empty - nothing to predict (check the split sizes).")

y_true = np.concatenate(y_true_list)  # all true labels in one array
y_pred = np.concatenate(y_pred_list)  # all predictions in one array

print("Manual test accuracy:", (y_true == y_pred).mean())  # compare with model.evaluate accuracy


### Confusion matrix plot
Shows where the model confuses one class for another.


In [None]:
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay  # confusion matrix tools

# labels= always produces a full n_classes x n_classes matrix, even if a class
# never appears in y_true or y_pred. Later cells index it by every class id.
cm = confusion_matrix(y_true, y_pred, labels=np.arange(len(class_names)))

# Draw into an explicitly sized axes. Without ax=, ConfusionMatrixDisplay.plot()
# opens its own figure and ignores any plt.figure(figsize=...) made beforehand.
fig, ax = plt.subplots(figsize=(10, 8))
disp = ConfusionMatrixDisplay(confusion_matrix=cm)
disp.plot(ax=ax, values_format="d")       # counts as integers
ax.set_title("Confusion Matrix (Test)")
plt.show()


### Classification report
Precision/recall/F1-score per class, plus macro and weighted averages.


In [None]:
from sklearn.metrics import classification_report, confusion_matrix

label_ids = np.arange(len(class_names))   # every class id, in class_names order

# labels= keeps the matrix square and the report aligned with class_names even
# when a class is missing from the test split. Without it, classification_report
# raises if the number of classes present differs from len(target_names).
cm = confusion_matrix(y_true, y_pred, labels=label_ids)

print(classification_report(
    y_true, y_pred,                    # true labels vs predicted labels
    labels=label_ids,                  # report every class, present or not
    target_names=class_names,          # show class names instead of numbers
    digits=4,                          # print metrics with 4 decimals
    zero_division=0,                   # 0.0 (no warning) for classes never predicted
))



### Inspect wrong predictions
Collects a few misclassified examples so you can visually check what went wrong.


In [None]:
import math                       # grid sizing
import matplotlib.pyplot as plt   # plotting
import numpy as np                # arrays

MAX_WRONG = 25   # how many misclassified images to collect and show
COLS = 5         # grid width

wrong = []  # (image, true_label, pred_label) for each mistake

for x_batch, y_batch in test_ds:                      # loop over test batches
    # One forward pass per batch, without model.predict()'s per-call overhead.
    pred = np.argmax(model.predict_on_batch(x_batch), axis=1)

    labels = y_batch.numpy()
    if labels.ndim > 1 and labels.shape[-1] > 1:      # one-hot labels
        true = np.argmax(labels, axis=1)
    else:                                             # sparse labels -> flat int ids
        true = labels.astype(int).reshape(-1)

    for i in np.flatnonzero(true != pred):            # indices of mistakes in this batch
        i = int(i)
        wrong.append((x_batch[i].numpy().astype("uint8"), true[i], pred[i]))
    if len(wrong) >= MAX_WRONG:
        break                                         # enough examples collected

shown = wrong[:MAX_WRONG]
if not shown:
    print("No misclassified test images found.")
else:
    rows = math.ceil(len(shown) / COLS)
    plt.figure(figsize=(12, 2 * rows))
    for i, (img, t, p) in enumerate(shown):
        plt.subplot(rows, COLS, i + 1)
        plt.imshow(img)
        t_name = class_names[t] if "class_names" in globals() else str(t)  # true name
        p_name = class_names[p] if "class_names" in globals() else str(p)  # pred name
        plt.title(f"T:{t_name}\nP:{p_name}", fontsize=8)  # T=true, P=pred
        plt.axis("off")
    plt.tight_layout()
    plt.show()


### Error analysis 

In [None]:
import numpy as np  # arrays + numerical ops

n = len(class_names)                # number of classes

# cm must be indexed by class id on both axes (row = true, column = predicted).
if cm.shape != (n, n):
    raise ValueError(
        f"cm has shape {cm.shape}, expected ({n}, {n}); rebuild it with "
        "confusion_matrix(y_true, y_pred, labels=range(len(class_names)))."
    )

# Off-diagonal cells are mistakes; zero the diagonal and keep non-zero cells.
mistakes = cm.copy()
np.fill_diagonal(mistakes, 0)

pairs = [                           # (count, true_class_name, predicted_class_name)
    (mistakes[i, j], class_names[i], class_names[j])
    for i, j in zip(*np.nonzero(mistakes))
]

for c, true_name, pred_name in sorted(pairs, reverse=True)[:15]:  # top 15 biggest confusions
    print(f"{c:>3}  True: {true_name:30s}  → Pred: {pred_name}")


In the evaluated run, the main remaining errors are concentrated in a few visually similar categories, especially Spider_mites → Target_Spot (9) and Spider_mites → healthy (8), plus smaller mix-ups like Septoria_leaf_spot → Early_blight (5). This is reasonable: these classes can share similar "spotty" textures, and symptoms may be subtle or appear only in early stages, which makes the visual cues hard to tell apart even for humans.