In [2]:
# there is a data folder with several folders inside, each folder has images of a specific type of mushroom
# the goal is to create a dataset with the following structure: 
# data/
#   train/
#       class1/
#           img1.jpg
#           img2.jpg
#           ...
#       class2/
#           img1.jpg
#           img2.jpg
#           ...
#       ...
#   val/
#       class1/
#           img1.jpg
#           img2.jpg
#           ...
#       class2/
#           img1.jpg
#           img2.jpg
#           ...
#       ...
#   test/
#       class1/
#           img1.jpg
#           img2.jpg
#           ...
#       class2/
#           img1.jpg
#           img2.jpg
#           ...
#       ...

In [None]:

import os
import shutil
import random

# define the path to the data folder
data_path = r"C:\Users\ADMIN\Desktop\mushroom_img_classification\data"

# define the path to the new dataset folder
new_dataset_path = r"C:\Users\ADMIN\Desktop\mushroom_img_classification\data_new"

# define the percentage of images to use for the validation and test sets
val_split = 0.15
test_split = 0.15
train_split = 1 - (val_split + test_split)

# create the new dataset folder
if not os.path.exists(new_dataset_path):
    os.makedirs(new_dataset_path)

# create the train, val, and test folders and in each create class folders for the images then proceed to copy the images to the new dataset folder
for split in ['train', 'val', 'test']:
    split_path = os.path.join(new_dataset_path, split)
    if not os.path.exists(split_path):
        os.makedirs(split_path)
    for class_name in os.listdir(data_path):
        class_path = os.path.join(split_path, class_name)
        if not os.path.exists(class_path):
            os.makedirs(class_path)
        class_images = os.listdir(os.path.join(data_path, class_name))
        random.shuffle(class_images)
        # split the images into 70% train, 15% val, and 15% test and round up to the nearest integer and make sure that the images are not repeated in the different sets
        split_images = class_images[:int(train_split * len(class_images))] if split == 'train' else class_images[int(train_split * len(class_images)):int(train_split * len(class_images)) + int(val_split * len(class_images))] if split == 'val' else class_images[int(train_split * len(class_images)) + int(val_split * len(class_images)):] if split == 'test' else None
        
        
        for img_name in split_images:
            img_path = os.path.join(data_path, class_name, img_name)
            new_img_path = os.path.join(class_path, img_name)
            shutil.copy(img_path, new_img_path)
            print(f"Copying {img_path} to {new_img_path}")
print("Done creating dataset")



Copying C:\Users\ADMIN\Desktop\mushroom_img_classification\data\almond_mushroom\2.png to C:\Users\ADMIN\Desktop\mushroom_img_classification\data_new\train\almond_mushroom\2.png
Copying C:\Users\ADMIN\Desktop\mushroom_img_classification\data\almond_mushroom\10.png to C:\Users\ADMIN\Desktop\mushroom_img_classification\data_new\train\almond_mushroom\10.png
Copying C:\Users\ADMIN\Desktop\mushroom_img_classification\data\almond_mushroom\8.png to C:\Users\ADMIN\Desktop\mushroom_img_classification\data_new\train\almond_mushroom\8.png
Copying C:\Users\ADMIN\Desktop\mushroom_img_classification\data\almond_mushroom\1.png to C:\Users\ADMIN\Desktop\mushroom_img_classification\data_new\train\almond_mushroom\1.png
Copying C:\Users\ADMIN\Desktop\mushroom_img_classification\data\almond_mushroom\3.png to C:\Users\ADMIN\Desktop\mushroom_img_classification\data_new\train\almond_mushroom\3.png
Copying C:\Users\ADMIN\Desktop\mushroom_img_classification\data\almond_mushroom\5.png to C:\Users\ADMIN\Desktop\m

In [14]:
# check the number of images in each class folder in the 3 splits in a tabular format with the three split being the columns and the class folders being the rows
print('Class'.ljust(20), 'Train'.ljust(10), 'Val'.ljust(10), 'Test'.ljust(10))
print('-' * 40)
for class_name in os.listdir(data_path):
    train_path = os.path.join(new_dataset_path, 'train', class_name)
    val_path = os.path.join(new_dataset_path, 'val', class_name)
    test_path = os.path.join(new_dataset_path, 'test', class_name)
    train_images = len(os.listdir(train_path))
    val_images = len(os.listdir(val_path))
    test_images = len(os.listdir(test_path))
    print(class_name.ljust(20), str(train_images).ljust(10), str(val_images).ljust(10), str(test_images).ljust(10))

# check the total number of images in each split
print(f'Train: {sum([len(os.listdir(os.path.join(new_dataset_path, "train", class_name))) for class_name in os.listdir(data_path)])}')
print(f'Val: {sum([len(os.listdir(os.path.join(new_dataset_path, "val", class_name))) for class_name in os.listdir(data_path)])}')
print(f'Test: {sum([len(os.listdir(os.path.join(new_dataset_path, "test", class_name))) for class_name in os.listdir(data_path)])}')

        

# check the total number of images
total_images = 0
for folder in ['train', 'val', 'test']:
    folder_path = os.path.join(new_dataset_path, folder)
    class_folders = os.listdir(folder_path)
    for class_folder in class_folders:
        class_folder_path = os.path.join(folder_path, class_folder)
        num_images = len(os.listdir(class_folder_path))
        total_images += num_images
print(f'Total images: {total_images}')


Class                Train      Val        Test      
----------------------------------------
almond_mushroom      8          1          3         
amanita_gemmata      10         2          3         
amethyst_chanterelle 10         2          3         
amethyst_deceiver    10         2          3         
aniseed_funnel       10         2          3         
ascot_hat            10         2          3         
bay_bolete           10         2          3         
bearded_milkcap      10         2          3         
beechwood_sickener   10         2          3         
beefsteak_fungus     10         2          3         
birch_polypore       10         2          3         
birch_woodwart       8          1          3         
bitter_beech_bolete  10         2          3         
bitter_bolete        10         2          3         
blackening_brittlegill 10         2          3         
blackening_polypore  10         2          3         
blackening_waxcap    10         2      

In [15]:
# now i wish to check if any of the images are repeated in the different sets
train_images = []
val_images = []
test_images = []
for class_name in os.listdir(data_path):
    train_images += os.listdir(os.path.join(new_dataset_path, 'train', class_name))
    val_images += os.listdir(os.path.join(new_dataset_path, 'val', class_name))
    test_images += os.listdir(os.path.join(new_dataset_path, 'test', class_name))
print(f"Images in train and val: {set(train_images) & set(val_images)}")
print(f"Images in train and test: {set(train_images) & set(test_images)}")
print(f"Images in val and test: {set(val_images) & set(test_images)}")
print("Done checking for repeated images")

Images in train and val: {'1.png', '3.png', '10.png', '9.png', '4.png', '14.png', '12.png', '2.png', '0.png', '11.png', '6.png', '7.png', '8.png', '5.png', '13.png'}
Images in train and test: {'1.png', '3.png', '10.png', '9.png', '4.png', '14.png', '6.png', '2.png', '0.png', '11.png', '12.png', '7.png', '8.png', '5.png', '13.png'}
Images in val and test: {'1.png', '3.png', '10.png', '9.png', '4.png', '14.png', '6.png', '2.png', '0.png', '11.png', '12.png', '7.png', '8.png', '5.png', '13.png'}
Done checking for repeated images
