In [3]:
import os
import random
import tensorflow as tf
from tensorflow.keras.preprocessing import image_dataset_from_directory
from sklearn.model_selection import train_test_split

In [4]:
# Input directory read by the dataset loader, and the output directories
# that receive the training and test splits
input_dir = "augmented/"
output_train_dir = "train/"
output_test_dir = "test/"

In [5]:
# Split percentage: fraction of the images assigned to the training set
train_split = 0.8

# Image size handed to the dataset loader
target_size = (224,224) # (height,width)

# Random seed for reproducibility (used for the dataset shuffle and the train/test split)
random_seed = 32

In [6]:
# Load the dataset from input_dir.
# batch_size=1 makes every element a single-image batch; the images are
# loaded at target_size and shuffled with a fixed seed.
dataset = image_dataset_from_directory(directory=input_dir,
                                      image_size = target_size,
                                      batch_size = 1,
                                      shuffle = True,
                                      seed = random_seed)

Found 6456 files belonging to 3 classes.


In [7]:
# Class names reported by the dataset; later cells use an integer label as
# an index into this list to recover the class name
classes = dataset.class_names
classes

['Potato___Early_blight_aug',
 'Potato___Late_blight_aug',
 'Potato___healthy_aug']

In [8]:
# Materialise a dataset into plain Python lists
def dataset_to_lists(dataset):
    """Unbatch ``dataset`` and collect its elements into two parallel lists.

    Each element of ``dataset.unbatch()`` is expected to be an
    ``(image, label)`` pair whose members expose ``.numpy()``.

    Returns:
        A tuple ``(images, labels)``. ``images`` holds ``image.numpy()`` for
        every element (array data, not file paths) and ``labels`` holds the
        matching ``int(label.numpy())`` values, in iteration order.
    """
    # Walk the dataset exactly once so both lists come from the same pass.
    samples = [
        (image.numpy(), int(label.numpy()))
        for image, label in dataset.unbatch()
    ]
    images = [pixels for pixels, _ in samples]
    targets = [target for _, target in samples]
    return images, targets

# Convert dataset to lists
# NOTE: despite its name, image_paths holds the values returned by
# image.numpy() for each element, not file paths.
image_paths, labels = dataset_to_lists(dataset)

In [9]:
# Sanity check: the label list and the image list should have the same length
len(labels), len(image_paths)

(6456, 6456)

In [10]:
# Convert labels to strings (later cells turn them back into class indices with int(...))
labels = [str(label) for label in labels]


In [11]:
# Tally how many images carry each class label (nothing is printed here)
def count_images_in_classes(image_paths,labels):
    """Return a dict mapping ``str(label)`` to its number of occurrences.

    ``image_paths`` and ``labels`` are walked in lockstep, so counting stops
    at the end of the shorter sequence. Keys appear in order of first
    occurrence.
    """
    tallies = {}
    for _, label in zip(image_paths, labels):
        key = str(label)
        tallies[key] = tallies.get(key, 0) + 1
    return tallies

In [12]:
# Report the per-class and overall image counts prior to the split
print("Number of images before splitting : \n")
class_counts_before = count_images_in_classes(image_paths,labels)
total_images_before = sum(class_counts_before.values())

# Keys are stringified class indices; map each one back to its class name
for cls in class_counts_before:
    count = class_counts_before[cls]
    print(f"{classes[int(cls)]} : {count} images")
print(f"\nTotal images : {total_images_before}")

Number of images before splitting : 

Potato___Late_blight_aug : 3000 images
Potato___Early_blight_aug : 3000 images
Potato___healthy_aug : 456 images

Total images : 6456


In [13]:
# Split the dataset using stratified sampling: stratify=labels keeps the
# class proportions the same in both splits, and random_state fixes the shuffle
train_image_paths, test_image_paths, train_labels, test_labels = train_test_split(image_paths,
                                                                                 labels,
                                                                                 train_size=train_split,
                                                                                 stratify=labels,
                                                                                 random_state=random_seed)

In [23]:
# Running index used to number the saved files. It lives at module level so
# that consecutive save_images calls keep producing distinct file names.
cnt = 0

def save_images(image_paths, labels, output_dir):
    """Write each image as a JPEG under ``output_dir/<class name>/``.

    One sub-directory per entry of the global ``classes`` list is created if
    it does not exist yet. Files are named ``img_<cnt>.jpg`` where ``cnt`` is
    the global counter zero-padded to five digits; it is advanced once per
    image and is not reset between calls.

    Note: directories are created with ``exist_ok=True`` and existing files
    are left in place, so calling this again without resetting ``cnt`` adds
    new files next to the old ones instead of replacing them.

    Args:
        image_paths: Sequence of image pixel arrays (despite the name, these
            are not file paths). Each one is rounded, cast to uint8 and
            JPEG-encoded.
        labels: Sequence of class indices (ints or numeric strings) aligned
            with ``image_paths``; each is used as an index into ``classes``.
        output_dir: Directory that receives the per-class sub-directories.

    Returns:
        dict mapping every class name to the number of images written for it
        by this call.

    Raises:
        ValueError: If ``image_paths`` and ``labels`` differ in length.
    """
    global cnt

    # zip() would silently drop the unmatched tail, so reject mismatches early.
    if len(image_paths) != len(labels):
        raise ValueError(
            f"image_paths and labels must have the same length "
            f"(got {len(image_paths)} and {len(labels)})"
        )

    # Create class directories
    class_counts = {}
    for cls in classes:
        class_counts[cls] = 0
        os.makedirs(os.path.join(output_dir, cls), exist_ok=True)

    # Iterate through the dataset
    for img, label in zip(image_paths, labels):
        class_name = classes[int(label)]
        img_path = os.path.join(output_dir, class_name, f"img_{cnt:05d}.jpg")
        # Round before casting: a plain float -> uint8 cast drops the
        # fractional part, which darkens resized (interpolated) pixels.
        pixels = tf.cast(tf.round(img), tf.uint8)
        tf.io.write_file(img_path, tf.io.encode_jpeg(pixels))
        class_counts[class_name] += 1
        cnt += 1

    return class_counts

In [24]:
# Save training images
# (save_images numbers files with a shared global counter, so the test files
# written below continue from where the training files stop)
print("\nSaving training images...")
class_counts_train = save_images(train_image_paths, train_labels, output_train_dir)

# Save test images
print("\nSaving test images...")
class_counts_test = save_images(test_image_paths, test_labels, output_test_dir)

print("\nDone")


Saving training images...

Saving test images...

Done


In [25]:
# Report the per-class and overall image counts for both splits,
# using the counts returned by save_images
print("Number of images after splitting : \n")

print("Training set : ")
for cls, count in class_counts_train.items():
    print(f"{cls}: {count} images")
train_total = sum(class_counts_train.values())
print(f"\nTotal training images : {train_total}")

print("\nTest set:")
for cls, count in class_counts_test.items():
    print(f"{cls}: {count} images")
test_total = sum(class_counts_test.values())
print(f"\nTotal test images: {test_total}")

# Both splits together should add up to the pre-split total
print(f"\nTotal Images : {test_total + train_total}")


Number of images after splitting : 

Training set : 
Potato___Early_blight_aug: 2400 images
Potato___Late_blight_aug: 2399 images
Potato___healthy_aug: 365 images

Total training images : 5164

Test set:
Potato___Early_blight_aug: 600 images
Potato___Late_blight_aug: 601 images
Potato___healthy_aug: 91 images

Total test images: 1292

Total Images : 6456
