# Data Augmentation

The [PlantVillage-Dataset](https://github.com/spMohanty/PlantVillage-Dataset/blob/master/README.md) is a dataset of plant leaf images categorized per species and per health condition. The dataset contains 38 classes, but the number of elements per class varies significantly, making the dataset imbalanced.

This notebook aims to balance the classes by applying class-specific data augmentation (or down-sampling), and then to train the model on the balanced dataset.

In [None]:
# Mount Google Drive at /content/drive so that the notebook can read from and write to it
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Install required libraries (IPython %pip magic: installs into the running kernel's environment)
%pip install tensorflow_datasets
# NOTE: `collections` is part of the Python standard library, so the line below stays disabled
#%pip install collections

In [None]:
# Import all necessary libraries and functions
from collections import Counter
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds
import os
import pandas as pd
import random
from tqdm.notebook import tqdm

In [None]:
# Seed setting for reproducibility: every random number generator used in the
# notebook (Python, NumPy, TensorFlow, Keras) is initialised with the same value
SEED = 42
for _set_seed in (
    random.seed,
    np.random.seed,
    tf.random.set_seed,
    tf.keras.utils.set_random_seed,
):
    _set_seed(SEED)

In [None]:
# Load the PlantVillage dataset from the tensorflow_datasets library.
# The 'train' split is sliced into three consecutive parts (80% / 10% / 10%).
_load_options = dict(
    split=['train[:80%]', 'train[80%:90%]', 'train[90%:]'],
    shuffle_files=True,
    as_supervised=True,  # returns (image, label) pairs
    with_info=True,      # returns metadata regarding the dataset
)
(ds_train, ds_val, ds_test), ds_info = tfds.load('plant_village', **_load_options)


def _count_samples(dataset):
    """Return the number of elements obtained by iterating once over `dataset`."""
    n_samples = 0
    for _ in dataset:
        n_samples += 1
    return n_samples


# Count number of samples in each subset
train_count = _count_samples(ds_train)
val_count = _count_samples(ds_val)
test_count = _count_samples(ds_test)
dataset_count = sum((train_count, val_count, test_count))

print(f"The Plantvillage dataset contains: {dataset_count} samples")
print("Dataset Splitting: 80/10/10 Train-Validation-Test Ratio")
for _subset_name, _subset_size in (
    ("Training", train_count),
    ("Validation", val_count),
    ("Test", test_count),
):
    print(f"{_subset_name} set: {_subset_size} samples")

## Training dataset data augmentation

Data augmentation is a process that can impact the response of a model during the training phase. This is due to the fact that the resulting augmented images vary depending on the specific transformations applied during the augmentation.
This notebook's aim is to evaluate the impact of different transformations on model performance by applying three different data augmentation pipelines to the training dataset:
- Geometric data augmentation: uses transformations that translate, rotate, and zoom in or out of the original image. These mimic the various spatial positions in which the same leaf might appear.

- Color data augmentation: uses transformations that alter pixel values by randomly adjusting the brightness, saturation and contrast of the original image. This mimics the variability in lighting conditions under which the same leaf may be observed.

- Combined data augmentation: uses the combination of all the previous transformations. By bringing together the two previous groups of transformations, this pipeline is expected to best simulate the randomness and variability encountered in real-world scenarios.

Original training dataset information:

In [None]:
# Extract only labels (not images) for counting: iterate over the training set,
# ignore the image and keep the integer label of every sample
labels_list = [int(label.numpy()) for _, label in ds_train]

# Convert to pandas Series for fast operations
labels_series = pd.Series(labels_list)

number_of_classes = ds_info.features['label'].num_classes
print(f"Number of classes: {number_of_classes}")
class_names = ds_info.features['label'].names

# Number of training images per class.
# The counts are re-indexed on every class id (0 .. number_of_classes - 1) so that
# a class with no sample in this split shows up with a count of 0 instead of being
# missing; this keeps the Series aligned, position by position, with class_names.
class_counts = (
    labels_series.value_counts()
    .reindex(range(number_of_classes), fill_value=0)
    .sort_index()
)

# Snapshot of the class sizes before augmentation (independent copy of class_counts)
original_counts = class_counts.copy()

# Show example from the dataset
tfds.show_examples(ds_train, ds_info)

plt.figure(figsize=(12,6))
plt.bar(class_names, original_counts.to_numpy())
plt.title("Class distribution before augmentation")
plt.xlabel("Class")
plt.ylabel("Number of images")
plt.xticks(rotation=90)
plt.show()

# Print the counts with class names
print("Number of images per class:")
for idx, count in class_counts.items():
    print(f"{class_names[idx]}: {count} images")

In [None]:
# Useful constants
IMG_SIZE = (128, 128)  # size every image is resized to by the preprocess functions
BATCH_SIZE = 64  # batch size applied to the validation pipeline
N_EPOCHS = 30  # NOTE(review): not referenced in the visible cells (the training cell defines its own n_epochs)
NUM_CLASSES = ds_info.features['label'].num_classes  # number of labels declared in the dataset metadata
DROP_RATE = 0.3  # NOTE(review): not referenced in the visible cells (simple_cnn hard-codes Dropout(0.4))
L2_REGULARIZATION = 0.005  # NOTE(review): not referenced in the visible cells

Target sample size and data augmentation pipelines definition:

In [None]:
# Preprocess for non-augmented datasets
def preprocess(image, label, dataset_info, image_size=(128, 128)):
    """Resize and rescale an image and one-hot encode its label.

    Args:
        image: image tensor to transform.
        label: integer class label.
        dataset_info: dataset metadata; `features['label'].num_classes` gives
            the depth of the one-hot vector.
        image_size: size passed to `tf.image.resize`.

    Returns:
        Tuple `(image, one_hot_label)`: the resized image as float32 divided
        by 255, and the one-hot encoded label.
    """
    depth = dataset_info.features['label'].num_classes
    resized = tf.image.resize(image, image_size)
    rescaled = tf.cast(resized, tf.float32) / 255.0
    one_hot_label = tf.one_hot(label, depth)
    return rescaled, one_hot_label

# Geometric augmentation: random zoom, then random translation, then random rotation
geo_augmentation = tf.keras.Sequential()
geo_augmentation.add(tf.keras.layers.RandomZoom(0.1))
geo_augmentation.add(tf.keras.layers.RandomTranslation(0.1, 0.1))
geo_augmentation.add(tf.keras.layers.RandomRotation(0.1))

# Preprocess for geometrically augmented datasets
def preprocess_geo_aug(image, label):
    """Resize to IMG_SIZE, rescale by 1/255 and one-hot encode the label.

    Args:
        image: image tensor to transform.
        label: integer class label.

    Returns:
        Tuple `(image, one_hot_label)`; the number of classes is read from
        the notebook-level `ds_info`.
    """
    depth = ds_info.features['label'].num_classes
    as_float = tf.cast(tf.image.resize(image, IMG_SIZE), tf.float32)
    return as_float / 255.0, tf.one_hot(label, depth)

# Color augmentation: random brightness, then contrast, then saturation changes
color_augmentation = tf.keras.Sequential()
color_augmentation.add(tf.keras.layers.RandomBrightness(0.3))  # Brightness change
color_augmentation.add(tf.keras.layers.RandomContrast(0.2))    # Contrast variation
color_augmentation.add(tf.keras.layers.RandomSaturation(0.2))  # (optional) Color intensity variation

# Preprocess for color augmented datasets
def preprocess_color_aug(image, label):
    """Resize, clamp and rescale a color-augmented image; one-hot encode its label.

    The previous version read an undefined name (`aug_image`) and therefore
    raised NameError on the first call; every step now operates on `image`.

    Args:
        image: image tensor with pixel values on the 0-255 scale.
        label: integer class label.

    Returns:
        Tuple `(image, one_hot_label)`: the image as float32 divided by 255,
        and the one-hot encoded label (depth read from the notebook-level
        `ds_info`).
    """
    image = tf.image.resize(image, IMG_SIZE)
    # Clamp before the uint8 cast: a float outside 0-255 has no valid uint8 value
    image = tf.clip_by_value(image, 0.0, 255.0)
    image = tf.cast(image, tf.uint8)
    image = tf.cast(image, tf.float32) / 255.0
    return image, tf.one_hot(label, ds_info.features['label'].num_classes)

# Combining the two types of augmentation in a single pipeline
combined_data_augmentation = tf.keras.Sequential()

# Geometric augmentations
combined_data_augmentation.add(tf.keras.layers.RandomRotation(0.1))
combined_data_augmentation.add(tf.keras.layers.RandomZoom(0.1))
combined_data_augmentation.add(tf.keras.layers.RandomTranslation(0.1, 0.1))

# Color augmentations
combined_data_augmentation.add(tf.keras.layers.RandomBrightness(0.2))
combined_data_augmentation.add(tf.keras.layers.RandomContrast(0.2))
combined_data_augmentation.add(
    tf.keras.layers.Lambda(lambda x: tf.image.random_saturation(x, 0.8, 1.2))
)

# Preprocess for geometrically and color augmented datasets
def preprocess_combined_aug(image, label):
    """Resize to IMG_SIZE, rescale by 1/255 and one-hot encode the label.

    The previous version returned the 3-tuple `(image, label, num_classes)`,
    unlike the other preprocess functions of this notebook; it now returns
    the same `(image, one_hot_label)` pair as they do.

    Args:
        image: image tensor to transform.
        label: integer class label.

    Returns:
        Tuple `(image, one_hot_label)`; the number of classes is read from
        the notebook-level `ds_info`.
    """
    image = tf.image.resize(image, IMG_SIZE)
    image = tf.cast(image, tf.float32) / 255.0
    return image, tf.one_hot(label, ds_info.features['label'].num_classes)

## Augmentation

In [None]:
# Choose number of samples per class. Each accepted keyword maps to the rule
# that derives the target from the per-class counts:
#   'max':  # of samples in largest class
#   'mean': mean # of samples
#   't':    1000 samples
#   'min':  # of samples in the smallest class
_TARGET_RULES = {
    'max': lambda counts: counts.max(),
    'mean': lambda counts: round(counts.mean()),
    't': lambda counts: 1000,
    'min': lambda counts: counts.min(),
}

target_type = input("Choose the number of samples per class('max', 'mean', 't', 'min'): ").strip().lower()
if target_type not in _TARGET_RULES:
    raise ValueError("Invalid choice! Please enter 'max', 'mean', 't', or 'min'.")

print(f"You chose: {target_type}")

target = _TARGET_RULES[target_type](class_counts)

# String form of the target, used later to build folder and file names
samplesize = str(target)

# Choose augmentation type:
#   'comb':  geometric and color augmentation
#   'geo':   geometric augmentation
#   'color': color augmentation
_AUGMENTATION_PIPELINES = {
    'comb': combined_data_augmentation,
    'geo': geo_augmentation,
    'color': color_augmentation,  # previously this choice silently used geo_augmentation
}

augmentation_type = input("Choose augmentation type ('comb', 'geo', 'color'): ").strip().lower()
if augmentation_type not in _AUGMENTATION_PIPELINES:
    # The message lists the choices that are actually accepted
    raise ValueError("Invalid choice! Please enter 'comb', 'geo', or 'color'.")

print(f"You chose: {augmentation_type}")

_selected_pipeline = _AUGMENTATION_PIPELINES[augmentation_type]


def augment(image):
    """Return a randomly transformed copy of `image` using the selected pipeline.

    Args:
        image: image tensor with pixel values on the 0-255 scale.

    Returns:
        The tensor produced by the selected Keras augmentation pipeline.
    """
    # training=True is passed explicitly so that the random layers are applied
    # regardless of the default call mode of the Keras version in use
    return _selected_pipeline(image, training=True)


def preprocess_aug(image, label):
    """Resize to IMG_SIZE and rescale to [0, 1]; the label is passed through.

    Args:
        image: image tensor with pixel values on the 0-255 scale.
        label: label in whatever encoding the caller already uses.

    Returns:
        Tuple `(image, label)` with the image as float32 divided by 255.
    """
    image = tf.image.resize(image, IMG_SIZE)
    # Clamping leaves values already inside 0-255 untouched, so one definition
    # serves all three augmentation types
    image = tf.clip_by_value(image, 0.0, 255.0)
    image = tf.cast(image, tf.float32) / 255.0
    return image, label

# Setting output directory to Google Drive folder; the folder name encodes the
# augmentation type and the target number of samples per class.
# NOTE(review): no separator is placed between 'Augmented_datasets' and
# augmentation_type, so the folder is created directly inside 'progetto-daml'
# with a name such as 'Augmented_datasetsgeo_aug_1000' - confirm this is intended.
# Alternative location (local Google Drive mount on Windows):
# output_root = 'G:/My Drive/MACHINE_LEARNING/progetto-daml'+augmentation_type+'_aug_'+samplesize
output_root = '/content/drive/MyDrive/progetto-daml/Augmented_datasets'+augmentation_type+'_aug_'+samplesize
# Alternative location (relative folder in the current working directory):
#output_root = 'augmented_data'
print(f'New classes path: {output_root}')

In [None]:
# Creating new empty classes: one sub-folder of output_root per class name
os.makedirs(output_root, exist_ok=True)
for class_name in class_names:
    os.makedirs(os.path.join(output_root, class_name), exist_ok=True)

class_indices = range(len(class_names))

# Original number of training images per class index, as plain Python ints.
# Lookups are made by class id (not by position) and a class that is absent
# from class_counts is treated as having 0 images.
_raw_counts = {int(idx): int(n) for idx, n in class_counts.items()}
class_counts_dict = {i: _raw_counts.get(i, 0) for i in class_indices}
target_per_class = {i: target for i in class_indices}

# Counters updated while the balanced dataset is written
new_class_elements = {i: 0 for i in class_indices}   # images saved so far in each new class
class_elements = class_counts_dict.copy()            # starts from the original counts

# Number of augmented images each class needs to reach the target (0 when the
# class already has at least `target` images)
new_class_aug_elements = {i: max(0, target - class_counts_dict[i]) for i in class_indices}
aug_elements_dict = {k: int(v) for k, v in new_class_aug_elements.items()}   # same values as plain ints

# Classes below the target are augmented, classes above it are down-sampled and
# classes exactly at the target are left untouched. All these names are defined
# for every value of `target`, so the summary below never hits an undefined name.
classes_needing_aug = {i for i in class_indices if class_counts_dict[i] < target}
classes_to_modify = {i for i in class_indices if class_counts_dict[i] != target}

if min(class_counts_dict.values(), default=0) >= target:
    # The target does not exceed the smallest class: nothing needs augmentation
    print("All classes will undergo downsampling")
elif max(class_counts_dict.values(), default=0) < target:
    # The target exceeds the largest class: every class needs augmentation
    print("All classes will undergo data augmentation")

print("Classes needing augmentation:")
for idx in classes_needing_aug:
    print(f"{class_names[idx]}")

print(f"Target images per class: {target}")
print(f"Classes undergoing data augmentation: {len(classes_needing_aug)}")
print(f"Classes undergoing downsampling: {len(classes_to_modify) - len(classes_needing_aug)}")

# Add progress tracking
from tqdm.notebook import tqdm

# Number of augmented images each class may still receive: its shortfall with
# respect to the target, computed from the original class sizes
aug_budget = {
    i: max(0, int(target) - int(class_counts_dict.get(i, 0)))
    for i in range(len(class_names))
}

# Augmentation progress bar
total_needed = sum(aug_budget.values())
pbar = tqdm(total=total_needed, desc="Augmenting images")
print(f"Total needed: {total_needed}")

# Saving progress bar (original + augmented images)
total_saved = target*len(class_names)
pbar1 = tqdm(total = total_saved, desc="Saving images")
print(f"Total saved: {total_saved}")

# Classes that have not reached the target yet
unfinished = {i for i, n in new_class_elements.items() if n < target}

# The training set is scanned as many times as needed:
#  - pass 0 saves every original image (until its class is full) and, while the
#    class still has augmentation budget, one augmented copy right after it;
#  - later passes only add augmented copies, so an original image is never
#    written twice (re-saving originals would fill small classes with exact
#    duplicates instead of augmented samples).
pass_index = 0
try:
    while unfinished:
        saved_this_pass = 0

        for image, label in ds_train:
            if not unfinished:
                break  # every class is complete: stop scanning

            label_index = int(label.numpy())
            if label_index not in unfinished:
                continue  # this class already holds `target` images
            label_name = class_names[label_index]

            if pass_index == 0:
                # Save the original image first
                save_path = os.path.join(output_root, label_name, f'img_{new_class_elements[label_index]}.png')
                tf.keras.preprocessing.image.save_img(save_path, image.numpy())
                new_class_elements[label_index] += 1
                saved_this_pass += 1
                pbar1.update(1)

                if new_class_elements[label_index] >= target:
                    unfinished.discard(label_index)
                    classes_needing_aug.discard(label_index)
                    continue

            if aug_budget[label_index] <= 0:
                continue  # remaining slots of this class are reserved for originals

            # Augment, then clamp to 0-255 before the uint8 cast: a float outside
            # that range has no valid uint8 value
            aug_image = tf.cast(augment(image), tf.float32)
            aug_image = tf.clip_by_value(aug_image, 0.0, 255.0)
            aug_image = tf.cast(aug_image, tf.uint8)

            # Save augmented image
            save_path_aug = os.path.join(output_root, label_name, f'aug_{new_class_elements[label_index]}.png')
            tf.keras.preprocessing.image.save_img(save_path_aug, aug_image.numpy())

            # Update counters following the augmentation of one image
            new_class_elements[label_index] += 1
            aug_budget[label_index] -= 1
            saved_this_pass += 1
            pbar.update(1)
            pbar1.update(1)

            if new_class_elements[label_index] >= target:
                unfinished.discard(label_index)
                classes_needing_aug.discard(label_index)

        if unfinished and saved_this_pass == 0:
            # A whole pass added nothing (e.g. a class has no image in ds_train):
            # stop instead of looping forever
            print("Warning: no image could be added for classes:",
                  [class_names[i] for i in sorted(unfinished)])
            break

        pass_index += 1
finally:
    pbar.close()
    pbar1.close()

if not unfinished:
    print(f" All classes have {target} images.")
    print("Current class counts:", new_class_elements)

# Print final statistics
print("\nFinal class distribution:")
for idx, count in new_class_elements.items():
    print(f"{class_names[idx]}: {count} images")

total_obtained = sum(new_class_elements.values())
print(f"Total samples in dataset: {total_obtained}")

In [None]:
# Class distribution after augmentation: one bar per class, whose height is the
# number of images saved for that class
images_per_class = list(new_class_elements.values())

plt.figure(figsize=(12,6))
plt.bar(class_names, images_per_class)
plt.xticks(rotation=90)
plt.xlabel("Class")
plt.ylabel("Number of images")
plt.title("Class distribution after augmentation")
plt.show()

# MODEL TRAINING

In [None]:
# Loading the new training set from chosen dir

from tensorflow.keras.utils import image_dataset_from_directory

# Define the path to the augmented data
augmented_data_path = output_root

# The validation set takes its label indices from the TFDS metadata. The same
# class order is imposed here: without `class_names`, the indices are assigned
# from the alphanumeric order of the folder names, which is not guaranteed to
# match the TFDS order (training and validation labels could then disagree).
tfds_class_names = list(ds_info.features['label'].names)

# Load the dataset from the new directory
ds_augmented_train = tf.keras.utils.image_dataset_from_directory(
    augmented_data_path,
    labels='inferred',
    label_mode='categorical',   # one-hot encoding
    class_names=tfds_class_names,
    batch_size=32,
    image_size=(224, 224),      # before preprocessing
    shuffle=True,
    seed=123
)

# Class names in the order actually used to encode the labels
class_names = ds_augmented_train.class_names
print("Class names:", class_names)

# Count images per class (number of entries found in each class folder)
# NOTE(review): from here on `class_counts` is a dict keyed by class name and no
# longer the pandas Series built earlier in the notebook
class_counts = {cls: len(os.listdir(os.path.join(augmented_data_path, cls))) for cls in class_names}

print("Number of images per class:")
for cls in class_names:
    print(f"{cls}: {class_counts[cls]} images")

In [None]:
# Preprocessing
def _preprocess_eval(image, label):
    """Apply `preprocess` with the notebook-level dataset metadata and image size."""
    return preprocess(image, label, ds_info, IMG_SIZE)


# Training data: already batched by the directory loader, so only map + prefetch
train_ds = ds_augmented_train.map(preprocess_aug, num_parallel_calls=tf.data.AUTOTUNE)
train_ds = train_ds.prefetch(tf.data.AUTOTUNE)

# Validation data: map, then batch, then prefetch
val_ds = ds_val.map(_preprocess_eval, num_parallel_calls=tf.data.AUTOTUNE)
val_ds = val_ds.batch(BATCH_SIZE)
val_ds = val_ds.prefetch(tf.data.AUTOTUNE)
#test_ds  = ds_test.map (lambda image, label: preprocess(image, label, ds_info, IMG_SIZE), num_parallel_calls=tf.data.AUTOTUNE).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

In [None]:
from keras.models import Sequential, Model
from keras.layers import Activation, BatchNormalization, Dense, Conv2D, MaxPooling2D, Dropout, Flatten, GlobalAveragePooling2D, ReLU, Rescaling
from keras.optimizers.legacy import Adam, SGD
from keras.losses import CategoricalCrossentropy

from keras.metrics import CategoricalAccuracy, Precision, Recall
from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau


def simple_cnn(input_shape=(IMG_SIZE[0], IMG_SIZE[1], 3), num_classes=38):
    """Build a small convolutional classifier.

    Two 3x3 convolution blocks (16 and 32 filters, each followed by 2x2 max
    pooling, with dropout after the first convolution) feed a flattened
    softmax output layer.

    Args:
        input_shape: shape of one input image, without the batch dimension.
            It was previously ignored; it now defines the model input.
        num_classes: number of units of the softmax output layer.

    Returns:
        The (uncompiled) Keras `Sequential` model.
    """
    # Local import: `Input` is not among the names imported at the top of this cell
    from keras.layers import Input

    model = Sequential([
        Input(shape=input_shape),
        Conv2D(16, (3, 3), activation='relu', padding='same'),
        Dropout(0.4),
        MaxPooling2D((2, 2)),
        Conv2D(32, (3, 3), activation='relu', padding='same'),
        MaxPooling2D((2, 2)),
        Flatten(),
        Dense(num_classes, activation='softmax')
    ])
    return model

In [None]:
# Instantiate the CNN with its default arguments
model = simple_cnn()
model.build(input_shape=(None, IMG_SIZE[0], IMG_SIZE[1], 3))  # Build the model with the input shape
# Print the layer-by-layer summary of the model
model.summary()

In [None]:
# NOTE: this import rebinds `Adam`, replacing the one imported from
# keras.optimizers.legacy in the model-definition cell
from tensorflow.keras.optimizers import Adam

optimizer = Adam(learning_rate=0.0002)
# optimizer = SGD(learning_rate=0.05, momentum=0.9)
model.compile(
    optimizer=optimizer,
    loss=CategoricalCrossentropy(),
    metrics=['accuracy']
)

n_epochs = 30

# The file name encodes the augmentation type and the target samples per class.
# The notebook never creates the destination folder, so it is created here
# rather than relying on the callback to do it.
checkpoint_path = '/content/weights/'+ augmentation_type+'_aug_'+ samplesize+'.h5'
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

history = model.fit(
    train_ds,
    validation_data = val_ds,
    epochs=n_epochs,
    callbacks=[
        # Stop when val_loss has not improved for 3 epochs and restore the best weights
        EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True),
        # Keep on disk only the model with the best val_loss
        ModelCheckpoint(checkpoint_path, monitor='val_loss', save_best_only=True),
        # Multiply the learning rate by 0.2 when val_loss has not improved for 2 epochs
        ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=2)
    ]
)

In [None]:
# Save the model once training is over.
# NOTE(review): this is the same path given to the ModelCheckpoint callback in the
# training cell, so the best-epoch file written during training is overwritten
# here - confirm this is intended.
model.save('/content/weights/'+ augmentation_type+'_aug_'+ samplesize+'.h5')