In [1]:
import glob
import os
import re
from collections import Counter

import cv2
import matplotlib.pyplot as plt
import numpy as np
import PIL
import PIL.Image
from sklearn.model_selection import train_test_split

In [2]:
def load_preprocessed_data(path, label_key):
    """Load preprocessed ``.jpg`` images stored one sub-directory per class.

    Parameters
    ----------
    path : str
        Root directory; every key of ``label_key`` names a sub-directory of it.
    label_key : dict
        Maps class sub-directory name -> integer label.

    Returns
    -------
    (np.ndarray, np.ndarray)
        ``X``: the images stacked along a new leading axis, and ``y``: the
        matching labels, in the same order.

    Raises
    ------
    ValueError
        If no ``.jpg`` file is found for any class, or if the images do not
        all share the same shape.
    """
    X = []
    y = []
    for class_name, label in label_key.items():
        class_path = os.path.join(path, class_name)
        # glob order is filesystem dependent; sort so the sample order (and
        # therefore any later seeded random selection) is reproducible
        for f in sorted(glob.glob(os.path.join(class_path, '*.jpg'))):
            # the context manager closes the underlying file handle promptly
            with PIL.Image.open(f) as img:
                X.append(np.asarray(img))
            y.append(label)
    if not X:
        # np.stack/np.concatenate would fail here with an opaque message
        raise ValueError(
            f"no .jpg images found under {path!r} for classes {list(label_key)}")
    return np.stack(X), np.array(y)

In [343]:
def get_blank_idxs(images, threshold=0.96, blank_value=255):
    """Return the indices of preprocessed images that are (almost) blank.

    An image counts as blank when the fraction of its elements equal to
    ``blank_value`` is strictly greater than ``threshold``.

    Parameters
    ----------
    images : sequence of np.ndarray
        Indexable collection of image arrays (e.g. an ``(N, H, W)`` array).
    threshold : float, default 0.96
        Minimum fraction (exclusive) of ``blank_value`` elements.
    blank_value : int, default 255
        Element value regarded as background.

    Returns
    -------
    list of int
        Indices of the blank images, in ascending order.
    """
    blank_idxs = []
    for i, image in enumerate(images):
        # normalise by the image's own element count rather than a hard-coded
        # size, so images of any shape are measured correctly; skip empties
        if image.size and np.count_nonzero(image == blank_value) / image.size > threshold:
            blank_idxs.append(i)
    return blank_idxs

In [344]:
def random_undersample(images, labels, random_state=None):
    """Randomly undersample every class down to the size of the rarest class.

    Each class is sampled *without* replacement, so no example appears twice
    in the result. Duplicates could otherwise end up on both sides of a later
    train/dev split.

    Parameters
    ----------
    images : np.ndarray
        Array aligned with ``labels`` along its first axis.
    labels : array-like
        1-D class labels.
    random_state : int or np.random.Generator, optional
        Seed or generator for the selection. When ``None`` the global
        ``np.random`` state is used, so ``np.random.seed`` still controls it.

    Returns
    -------
    (np.ndarray, np.ndarray)
        Selected images and labels, grouped by class in ascending label order.
    """
    labels = np.asarray(labels)
    if labels.size == 0:
        return images[:0], labels
    chooser = np.random if random_state is None else np.random.default_rng(random_state)
    classes, counts = np.unique(labels, return_counts=True)
    n = counts.min()
    indices = []
    for cls in classes:
        possible_indices = np.flatnonzero(labels == cls)
        if len(possible_indices) <= n:
            indices.extend(possible_indices)
        else:
            indices.extend(chooser.choice(possible_indices, size=n, replace=False))
    indices = np.asarray(indices, dtype=np.intp)
    return images[indices], labels[indices]

In [345]:
def build_directory(path: str, datasets=('train', 'dev'),
                    labels=('center', 'up', 'down', 'left', 'right')):
    """Create ``<path>/<dataset>/<label>`` for every dataset/label pair.

    Existing directories are left untouched and any missing level (including
    ``path`` itself) is created, so calling this repeatedly is safe.

    Parameters
    ----------
    path : str
        Root directory for the prebuilt datasets.
    datasets : iterable of str, default ('train', 'dev')
        Dataset split names, one sub-directory each.
    labels : iterable of str, default ('center', 'up', 'down', 'left', 'right')
        Class names, one sub-directory inside every dataset directory.
    """
    for dataset in datasets:
        for label in labels:
            # check every label directory independently so a partially built
            # tree (dataset dir present, some label dirs missing) is completed
            os.makedirs(os.path.join(path, dataset, label), exist_ok=True)
                    
def _next_image_index(directory):
    """Return one more than the largest number in a ``*.jpg`` name in ``directory`` (1 if none)."""
    numbers = []
    for filepath in glob.glob(os.path.join(directory, '*.jpg')):
        # search the file name only: digits elsewhere in the path (e.g. a
        # parent directory called ``data2``) must not be taken as the index
        match = re.search(r'\d+', os.path.basename(filepath))
        if match:
            numbers.append(int(match[0]))
    return max(numbers) + 1 if numbers else 1


def save_prebuilt_datasets(path: str, dataset_dict: dict, label_key=None):
    """Write the train/dev datasets to disk for use in modeling later on.

    Each image is written to ``<path>/<dataset>/<class_name>/<idx>.jpg``.
    ``idx`` continues after the largest number already used in that
    directory, so repeated calls append rather than overwrite. Missing class
    directories are created.

    Parameters
    ----------
    path : str
        Root directory of the prebuilt datasets.
    dataset_dict : dict
        Maps dataset name (e.g. ``'train'``) -> ``(images, labels)`` pair.
    label_key : dict, optional
        Maps integer label -> class directory name. Defaults to
        ``{0: 'center', 1: 'down', 2: 'left', 3: 'right', 4: 'up'}``.

    Raises
    ------
    OSError
        If ``cv2.imwrite`` reports that an image could not be written.
    """
    if label_key is None:
        label_key = {0: 'center',
                     1: 'down',
                     2: 'left',
                     3: 'right',
                     4: 'up'}
    # next free file number per directory. Scanning each directory once
    # (instead of once per image) keeps the total cost linear in the images.
    next_idx = {}
    for key, (images, labels) in dataset_dict.items():
        for image, label in zip(images, labels):
            save_dir = os.path.join(path, key, label_key[label])
            if save_dir not in next_idx:
                os.makedirs(save_dir, exist_ok=True)
                next_idx[save_dir] = _next_image_index(save_dir)
            save_path = os.path.join(save_dir, f'{next_idx[save_dir]}.jpg')
            # NOTE(review): cv2 assumes BGR channel order; if these arrays are
            # multi-channel RGB (as PIL loads them), channels get swapped on
            # disk. Confirm whether the images are grayscale.
            if not cv2.imwrite(save_path, image):
                # imwrite signals failure via its return value, not an exception
                raise OSError(f'failed to write image to {save_path}')
            next_idx[save_dir] += 1

In [351]:
# Load the preprocessed images for the train/dev sets. Class names are
# numbered in alphabetical order to form the label mapping.
CLASS_NAMES = ['center', 'down', 'left', 'right', 'up']
label_key = {name: label for label, name in enumerate(CLASS_NAMES)}

preprocessed_images, preprocessed_labels = load_preprocessed_data(
    './data/preprocessed/', label_key)

In [352]:
# Find the indexes of blank images and drop them. A boolean mask keeps the
# surviving images in their original order without relying on set
# iteration order.
blank_idxs = get_blank_idxs(preprocessed_images)
keep_mask = np.ones(len(preprocessed_images), dtype=bool)
keep_mask[blank_idxs] = False
non_blank_idxs = np.flatnonzero(keep_mask)

# remove those instances from the data
filtered_images = preprocessed_images[non_blank_idxs]
filtered_labels = preprocessed_labels[non_blank_idxs]
assert len(filtered_images) == len(filtered_labels) == len(preprocessed_images) - len(blank_idxs)

In [353]:
# Rebalance the classes in the dataset. random_undersample draws from
# NumPy's global RNG, so seed it to make the selected subset reproducible;
# the train/dev split below is likewise pinned with random_state=42.
np.random.seed(42)
rebalanced_images, rebalanced_labels = random_undersample(filtered_images, filtered_labels)
print(Counter(rebalanced_labels))

Counter({0: 5746, 1: 5746, 2: 5746, 3: 5746, 4: 5746})


In [368]:
# Split the rebalanced data into train/dev sets: a shuffled, stratified
# 80/20 split with a fixed seed so it is reproducible.
X_train, X_dev, y_train, y_dev = train_test_split(
    rebalanced_images,
    rebalanced_labels,
    stratify=rebalanced_labels,
    shuffle=True,
    test_size=0.2,
    random_state=42,
)

In [369]:
# Create the train/dev directory tree, then write each split into it.
destination_path = './data/'
build_directory(destination_path)

dataset_dict = {
    'train': [X_train, y_train],
    'dev': [X_dev, y_dev],
}
save_prebuilt_datasets(destination_path, dataset_dict)