In [3]:
import os
import cv2
import matplotlib.pyplot as plt
import pandas as pd

# Raw (input) and pre-processed (output) dataset directories, relative to the
# notebook's working directory. os.path.join inserts the platform's path
# separator, so the paths work on Windows as well as Linux/macOS (a literal
# ".\\name" is only a valid relative path on Windows).
raw_dir = os.path.join(".", "train_data-size-180-raw")
train_dir = os.path.join(".", "train_data-size-180")

# Output image size, passed to cv2.resize as (width, height)
img_size = (64, 64)

# Number of images kept per class. Classes with fewer images are skipped so
# that every kept class has the same number of samples (balanced dataset).
num_per_class = 60

# Number of classes to keep in the pre-processed training dataset
num_of_class = 180

In [10]:
# Perform pre-processing on the dataset: for each class folder, take the first
# `num_per_class` files (sorted by name), convert them to grayscale, resize
# them to `img_size`, and save them under `train_dir/<class name>/`.
# Folders with fewer than `num_per_class` entries are skipped, and processing
# stops once `num_of_class` classes have been written.


def load_gray_resized(path, size):
    """Read an image file, convert it to grayscale and resize it.

    Args:
        path: File path of the image to read.
        size: Target size passed to ``cv2.resize`` as ``(width, height)``.

    Returns:
        The resized single-channel uint8 image.

    Raises:
        ValueError: If OpenCV cannot read the file (``cv2.imread`` returned None).
    """
    img = cv2.imread(path)  # default flag loads a 3-channel BGR image
    if img is None:
        raise ValueError(f"Could not read image: {path}")

    # Weighted channel sum (BT.601 luma weights) approximates human perception
    # of brightness; convertScaleAbs rounds and saturates the result to uint8.
    b, g, r = cv2.split(img)
    img_gray = cv2.convertScaleAbs(0.299 * r + 0.587 * g + 0.114 * b)

    # Resize to meet the model input size
    return cv2.resize(img_gray, size)


# Sort so the selected classes are reproducible: os.listdir order is arbitrary.
class_folders = sorted(os.listdir(raw_dir))
class_count = 0

for folder_name in class_folders:

    # Stop when there is a sufficient number of classes
    if class_count == num_of_class:
        break

    folder_dir = os.path.join(raw_dir, folder_name)

    # Ignore stray files at the top level of raw_dir; only folders are classes
    if not os.path.isdir(folder_dir):
        continue

    files = os.listdir(folder_dir)

    # Skip the folder if it doesn't have enough images to keep classes balanced
    if len(files) < num_per_class:
        continue

    files = sorted(files)[:num_per_class]

    # makedirs(exist_ok=True) creates train_dir if needed and lets the cell be
    # re-run without failing on folders written by a previous run
    out_folder_dir = os.path.join(train_dir, folder_name)
    os.makedirs(out_folder_dir, exist_ok=True)

    for file_name in files:
        img_gray = load_gray_resized(os.path.join(folder_dir, file_name), img_size)

        # Save the pre-processed image; imwrite reports failure via its return value
        out_file_dir = os.path.join(out_folder_dir, file_name)
        if not cv2.imwrite(out_file_dir, img_gray):
            raise OSError(f"Could not write image: {out_file_dir}")

    class_count += 1

if class_count < num_of_class:
    print(f"Warning: only {class_count} of {num_of_class} requested classes "
          f"had at least {num_per_class} images.")


In [None]:
# Visualize the pre-processed data: show the first image (by file name) of
# each of the first grid_size * grid_size classes (by folder name), filled
# left-to-right, top-to-bottom.
grid_size = 6

class_names = sorted(os.listdir(train_dir))[:grid_size * grid_size]

fig, axes = plt.subplots(grid_size, grid_size, figsize=(6, 6))

# Hide every panel up front so unused ones (fewer classes than panels) stay blank
for ax in axes.flat:
    ax.axis('off')

for ax, class_name in zip(axes.flat, class_names):
    class_dir = os.path.join(train_dir, class_name)
    img_names = sorted(os.listdir(class_dir))
    if not img_names:
        continue  # empty class folder: leave the panel blank

    img_path = os.path.join(class_dir, img_names[0])
    # The saved images are single-channel, so read them as grayscale
    img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        continue  # unreadable file: leave the panel blank

    # Pin the intensity range; otherwise imshow rescales each panel to its own min/max
    ax.imshow(img, cmap='gray', vmin=0, vmax=255)

fig.suptitle('Pre-processed training images (one per class)')
plt.show()