In [None]:
# Installation of tensorflow_datasets (un-comment if needed).
#%pip install tensorflow_datasets
# Note: `collections` is part of the Python standard library and needs no installation.

In [None]:
from collections import Counter
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds
import os
import pandas as pd
from tqdm.notebook import tqdm

In [None]:
# Load the PlantVillage dataset from TFDS rather than the newer dataset
# (that one applied data augmentation to the validation set, which is wrong).
# The three subsets are carved out of the 'train' split: first 80% for training,
# the next 10% for validation and the last 10% for testing.
(ds_train, ds_val, ds_test), ds_info = tfds.load(
    'plant_village',
    split=['train[:80%]', 'train[80%:90%]', 'train[90%:]'],
    shuffle_files=True,
    as_supervised=True,  # returns (image, label) pairs
    with_info=True
)

In [None]:
# Number of training images per class.
# The dataset already yields integer class ids, so they are counted directly;
# converting each label to one-hot and back with argmax is unnecessary work.
class_counts = Counter(int(label.numpy()) for _, label in ds_train)
class_names = ds_info.features['label'].names

# Same counts keyed by class name, listed in class-id order so the printout
# does not depend on the (shuffled) order in which examples were read.
class_counts_named = {class_names[i]: class_counts[i] for i in sorted(class_counts)}

for name, count in class_counts_named.items():
    print(f"{name}: {count}")

In [None]:
# Desired number of images in each class after augmentation (target).
# In this run, target = number of images in the largest class.
target = max(class_counts.values())
# Independent copy of the per-class counts, so later updates leave class_counts untouched.
class_elements = Counter(class_counts)
num_classes = len(class_names)

# Random augmentation pipeline: rotation, zoom and contrast, each with factor 0.1.
data_augmentation = tf.keras.Sequential([
    # Random flips are currently disabled:
    # tf.keras.layers.RandomFlip("horizontal"),
    # tf.keras.layers.RandomFlip("vertical"),
    tf.keras.layers.RandomRotation(0.1),
    tf.keras.layers.RandomZoom(0.1),
    tf.keras.layers.RandomContrast(0.1),
])

def preprocess(image, label):
    """Resize an image, scale it to [0, 1] and one-hot encode its label.

    Args:
        image: image tensor; it is resized to the notebook-level IMG_SIZE.
        label: integer class id.

    Returns:
        (image, one_hot_label): the float32 image divided by 255 and the label
        one-hot encoded over the dataset's number of classes.
    """
    n_classes = ds_info.features['label'].num_classes
    resized = tf.image.resize(image, IMG_SIZE)
    scaled = tf.cast(resized, tf.float32) / 255.0
    one_hot_label = tf.one_hot(label, n_classes)
    return scaled, one_hot_label

def preprocess_with_aug(image, label):
    """Resize, randomly augment and scale an image; one-hot encode its label.

    Args:
        image: image tensor; it is resized to the notebook-level IMG_SIZE.
        label: integer class id.

    Returns:
        (image, one_hot_label): the augmented float32 image divided by 255 and
        the label one-hot encoded over the dataset's number of classes.
    """
    image = tf.image.resize(image, IMG_SIZE)
    # training=True is passed explicitly: the random layers only transform their
    # input in training mode, and this function is called outside of model.fit,
    # so the result must not depend on defaults or on the ambient learning phase.
    image = data_augmentation(image, training=True)
    image = tf.cast(image, tf.float32) / 255.0
    return image, tf.one_hot(label, ds_info.features['label'].num_classes)


In [None]:
# Build a class-balanced copy of the training set on disk.
#
# Every class folder under `augmented_data/` receives at most `target` images:
#   * pass 0 copies the original images (at most `target` per class, so a class
#     larger than `target` is trimmed down);
#   * every later pass writes randomly augmented variants, only for the classes
#     that are still below `target`.
# Originals are therefore written once (no exact duplicates) and no folder can
# grow past `target`.
# NOTE: files left in `augmented_data/` by a previous run are not removed here.

# Create the output folders (one per class).
output_root = 'augmented_data'
os.makedirs(output_root, exist_ok=True)
for class_name in class_names:
    os.makedirs(os.path.join(output_root, class_name), exist_ok=True)

# Number of images written so far into each new class folder.
new_class_elements = Counter({i: 0 for i in range(len(class_names))})

# Total number of images that still have to be written over all classes.
images_remaining = target * num_classes
pass_index = 0

while images_remaining > 0:
    saved_this_pass = 0

    for image, label in ds_train:
        label_index = int(label.numpy())

        # This class folder is already full: write nothing more into it.
        if new_class_elements[label_index] >= target:
            continue

        label_name = class_names[label_index]
        file_index = new_class_elements[label_index]

        if pass_index == 0:
            # First pass: store the original image unchanged.
            save_path = os.path.join(output_root, label_name, f'img_{file_index}.png')
            pixels = image.numpy()
        else:
            # Later passes: store a randomly augmented variant.
            # training=True makes the random layers active outside of model.fit.
            save_path = os.path.join(output_root, label_name, f'aug_{file_index}.png')
            aug_image = data_augmentation(image, training=True)
            # Keep the values in the valid 8-bit range before they are written.
            pixels = np.clip(aug_image.numpy(), 0, 255)
            # Original count + number of augmented images generated for this class.
            class_elements[label_index] += 1

        # scale=False: by default save_img min-max rescales every image to
        # [0, 255], which would silently alter the contrast of the saved data.
        tf.keras.preprocessing.image.save_img(save_path, pixels, scale=False)

        new_class_elements[label_index] += 1
        saved_this_pass += 1
        images_remaining -= 1

        # Checked after every save, so the loop ends as soon as the last
        # missing image has been written.
        if images_remaining == 0:
            break

    # A whole pass that wrote nothing means the remaining classes have no
    # samples in ds_train; stop instead of looping forever.
    if saved_this_pass == 0:
        break

    pass_index += 1

if images_remaining == 0:
    print(" Classes balanced!! ")
else:
    print(f"Stopped early: {images_remaining} images could not be generated (classes without samples).")

In [None]:
# Useful constants
IMG_SIZE = (128, 128)  # size passed to tf.image.resize by the preprocess functions
BATCH_SIZE = 64
APPLY_DATA_AUGMENTATION = False
N_EPOCHS = 30
NUM_CLASSES = ds_info.features['label'].num_classes  # number of label classes reported by TFDS
DROP_RATE = 0.3  # presumably the dropout rate of the model -- TODO confirm
L2_REGULARIZATION = 0.005  # presumably the L2 weight-decay factor -- TODO confirm

In [None]:
# Loading the new (balanced) training set in the training notebook.

from tensorflow.keras.utils import image_dataset_from_directory

# Path to the augmented data
augmented_data_path = os.path.abspath("augmented_data")

# Load the dataset from the new directory.
# class_names is passed explicitly so that the label indices follow the TFDS
# class order; otherwise the sub-directories are indexed in alphanumerical
# order, which need not match the label ids of the TFDS validation/test splits.
# NOTE(review): images are returned with pixel values in [0, 255]; rescale them
# downstream if the model expects the [0, 1] range produced by preprocess().
ds_augmented_train = image_dataset_from_directory(
    augmented_data_path,
    labels='inferred',
    label_mode='categorical',   # one-hot encoding
    class_names=ds_info.features['label'].names,
    batch_size=BATCH_SIZE,      # notebook-level constants instead of hard-coded values
    image_size=IMG_SIZE,
    shuffle=True,
    seed=123
)

# Class names, in the exact order used for the one-hot labels of the dataset
class_names = ds_augmented_train.class_names
print("Class names:", class_names)

# Count the files found in each class folder.
# NOTE: this rebinds class_counts to a dict keyed by class name.
class_counts = {cls: len(os.listdir(os.path.join(augmented_data_path, cls))) for cls in class_names}

print("Number of images per class:")
for cls in class_names:
    print(f"{cls}: {class_counts[cls]} images")

