In [1]:
import os
import random

In [2]:
# Dataset locations (relative paths, resolved against the current working directory).
BASE_DATASET_PATH = './dataset'
TRAIN_DATASET_PATH = f'{BASE_DATASET_PATH}/Training'
TEST_DATASET_PATH = f'{BASE_DATASET_PATH}/Test'

In [3]:
def load_images(path: str) -> dict:
    """Index the files stored under ``path``, grouped by category.

    Every sub-folder of ``path`` is treated as one collection. Its category is
    the first whitespace-separated token of the folder name, so folders such
    as ``Apple Red 1`` and ``Apple Golden 2`` both end up under ``Apple``.

    Args:
        path: Directory whose immediate sub-folders contain the files.

    Returns:
        A dict mapping each category to a list of collections (one per
        sub-folder). A collection is the list of ``path/folder/file`` strings
        for every entry found in that sub-folder.

    Raises:
        OSError: If ``path`` or one of its sub-folders cannot be listed.
    """
    images = {}

    # os.listdir returns entries in arbitrary order; sorting keeps the result
    # stable between runs and machines.
    for folder in sorted(os.listdir(path)):
        folder_path = path + '/' + folder

        # Stray files next to the category folders (e.g. .DS_Store) are not
        # collections; listing them would raise NotADirectoryError.
        if not os.path.isdir(folder_path):
            continue

        category = folder.split()[0]
        images.setdefault(category, []).append(
            [folder_path + '/' + file for file in sorted(os.listdir(folder_path))]
        )

    return images

In [4]:
# Index both splits: category -> list of collections -> list of file paths.
train_images = load_images(TRAIN_DATASET_PATH)
test_images = load_images(TEST_DATASET_PATH)

# Sanity check: total number of images per category in each split.
print('Train Dataset')
print({ name: sum(map(len, groups)) for name, groups in train_images.items() })

print()

print('Test Dataset')
print({ name: sum(map(len, groups)) for name, groups in test_images.items() })

Train Dataset
{'Apple': 6404, 'Tomato': 5103, 'Orange': 479, 'Cocos': 490, 'Kiwi': 466, 'Lemon': 982}

Test Dataset
{'Apple': 2134, 'Tomato': 1707, 'Orange': 160, 'Cocos': 166, 'Kiwi': 156, 'Lemon': 330}


In [5]:
# Balance training dataset
#
# Every category is cut down to the size of the smallest category. Within a
# category the images are drawn as evenly as possible from its collections,
# and any shortfall is topped up from the images that were not picked yet.

# A dedicated, seeded generator makes the selection repeatable for a given
# `train_images` (same categories, collections and ordering) without touching
# the global `random` state.
BALANCE_SEED = 42
balance_rng = random.Random(BALANCE_SEED)

# Not used by the balancing below; kept because it was defined here originally.
max_number_collections = max(len(collections) for collections in train_images.values())
# Size of the smallest category, i.e. the target size for every category.
min_number_samples = min(
    sum(len(collection) for collection in collections)
    for collections in train_images.values()
)


def _balance_category(collections: list, n_samples: int, rng: random.Random) -> list:
    """Pick ``n_samples`` distinct images from one category.

    Args:
        collections: List of collections, each a list of image paths.
        n_samples: Number of images to return; must not exceed the total
            number of images across ``collections``.
        rng: Random generator used for every draw.

    Returns:
        A flat list of ``n_samples`` image paths.
    """
    quota = n_samples // len(collections)

    # Even share per collection, capped at the collection size: rng.sample
    # raises ValueError when asked for more items than the population holds.
    selected = [
        image
        for collection in collections
        for image in rng.sample(collection, min(quota, len(collection)))
    ]

    # Top up (integer-division remainder plus any capped collections) from the
    # images not selected yet; a set keeps the membership test O(1).
    already_selected = set(selected)
    leftovers = [
        image
        for collection in collections
        for image in collection
        if image not in already_selected
    ]
    return selected + rng.sample(leftovers, n_samples - len(selected))


train_images_balanced = {
    category: _balance_category(collections, min_number_samples, balance_rng)
    for category, collections in train_images.items()
}

# Debugging information
print('Train Dataset')
print({ category: len(set(images)) for category, images in train_images_balanced.items() })

Train Dataset
{'Apple': 466, 'Tomato': 466, 'Orange': 466, 'Cocos': 466, 'Kiwi': 466, 'Lemon': 466}


In [None]:
# Final datasets: category -> flat list of image paths.
# The training split is the balanced selection (same dict object, not a copy).
train_dataset = train_images_balanced
# The test split keeps every image; its collections are merged into one list.
test_dataset = {
    label: [path for group in groups for path in group]
    for label, groups in test_images.items()
}