In [1]:
import os, random
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from tensorflow.keras.preprocessing.image import ImageDataGenerator

In [2]:
dataaug = ImageDataGenerator(
    rotation_range=5, 
    zoom_range = 0.05,
    width_shift_range=0.05, 
    height_shift_range=0.05,
)

In [3]:
def balance_class(source_directory, target_directory, imagegen_obj, target_size=None, cmap='viridis'):
    
    class_examples = []
    os.mkdir(target_directory)
    for directory in os.listdir(source_directory):
        os.mkdir(os.path.join(target_directory, directory))
        class_examples.append(len(os.listdir(os.path.join(source_directory, directory))))
    
    print('Class Distribution:', class_examples)
    
    if target_size == None:
        target_size = max(class_examples)
        
    for i, directory in enumerate(os.listdir(source_directory)):
        dest_dir = os.path.join(target_directory, directory)
        if target_size > class_examples[i]:        
            q, r = divmod(target_size, class_examples[i])
            q, r = q-1, r/class_examples[i]
            for imgpath in tqdm(os.listdir(os.path.join(source_directory, directory))):
                filename, file_extension = os.path.splitext(imgpath)
                img = plt.imread(os.path.join(source_directory, directory, imgpath))
                plt.imsave(os.path.join(dest_dir, imgpath), img, cmap=cmap)
                for i in range(q):
                    if cmap == 'gray':
                        aug_img = np.squeeze(dataaug.random_transform(np.expand_dims(img, axis=2)))
                    else:
                        aug_img = imagegen_obj.random_transform(img)
                    plt.imsave(os.path.join(dest_dir, f"{filename}-{i+1}{file_extension}"), aug_img, cmap=cmap)
                prob = random.random()
                if prob <= r:
                    if cmap == 'gray':
                        aug_img = np.squeeze(dataaug.random_transform(np.expand_dims(img, axis=2)))
                    else:
                        aug_img = imagegen_obj.random_transform(img)
                    plt.imsave(os.path.join(dest_dir, f"{filename}-{0}{file_extension}"), aug_img, cmap=cmap)
        else:
            r = target_size/class_examples[i]
            for imgpath in tqdm(os.listdir(os.path.join(source_directory, directory))):
                filename, file_extension = os.path.splitext(imgpath)
                img = plt.imread(os.path.join(source_directory, directory, imgpath))
                prob = random.random()
                if prob <= r:
                    plt.imsave(os.path.join(dest_dir, imgpath), img, cmap=cmap)
    class_examples_updated = []
    for directory in os.listdir(target_directory):
        class_examples_updated.append(len(os.listdir(os.path.join(target_directory, directory))))   
    print('Updated Class Distribution:', class_examples_updated)

In [4]:
balance_class(
    source_directory='../Data/MathSymbols/', target_directory='../Data/MathSymbols-balanced/',
    imagegen_obj=dataaug, target_size=7000, cmap='gray'
) 

Class Distribution: [14294, 14355, 25112, 1067, 13104, 3251, 33997]


HBox(children=(FloatProgress(value=0.0, max=14294.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=14355.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=25112.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1067.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=13104.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=3251.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=33997.0), HTML(value='')))


Updated Class Distribution: [7008, 7116, 6973, 7030, 6958, 6986, 6951]
