In [1]:
import os
import random
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

In [2]:
def expand2square(img, background_color=(0, 0, 0)):
    """Pad an image with a solid border so it becomes square, keeping it centred.

    Parameters
    ----------
    img : PIL.Image.Image
        Image to pad. The padded result is always created in mode ``'RGB'``.
    background_color : tuple of int, optional
        Fill colour of the padding; defaults to black ``(0, 0, 0)``.

    Returns
    -------
    PIL.Image.Image
        ``img`` itself when it is already square, otherwise a new square
        ``'RGB'`` image whose side equals the longer side of ``img``, with
        ``img`` pasted in the centre.
    """
    width, height = img.size
    if width == height:
        return img

    side = max(width, height)
    canvas = Image.new('RGB', (side, side), background_color)
    # Offset by half the size difference so the original sits in the centre.
    # The subtraction must be parenthesised: `width - height // 2` subtracts
    # only half the height and pushes the image (mostly) off the canvas.
    offset = ((side - width) // 2, (side - height) // 2)
    canvas.paste(img, offset)
    return canvas

In [3]:
def preprocess(img, target_size):
    """Turn a PIL image into a flat RGB column vector.

    The image is converted to RGB, padded to a square with ``expand2square``,
    resized to ``target_size`` and flattened into a single column.

    Parameters
    ----------
    img : PIL.Image.Image
        Source image, in any mode.
    target_size : tuple of int
        ``(width, height)`` passed to ``resize``.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(target_size[0] * target_size[1] * 3, 1)``.
    """
    squared = expand2square(img.convert('RGB'))
    pixels = np.array(squared.resize(target_size))
    n_values = target_size[0] * target_size[1] * 3
    return pixels.reshape((n_values, 1))

In [4]:
def load_dataset(dataset_dir, target_size=(32, 32), seed=None):
    """Load an image-classification dataset laid out as one sub-directory per class.

    Every sub-directory of ``dataset_dir`` becomes a class. Labels are assigned
    0, 1, 2, ... in sorted directory-name order, so the label mapping is the
    same on every machine (``os.listdir`` order is arbitrary). Non-directory
    entries such as ``.DS_Store`` are ignored. Each readable image is run
    through ``preprocess`` and stored as a ``(column_vector, label)`` pair;
    files that cannot be read as images are reported and skipped.

    Parameters
    ----------
    dataset_dir : str or path-like
        Root directory containing one sub-directory per class.
    target_size : tuple of int, optional
        Size passed to ``preprocess``; defaults to ``(32, 32)``.
    seed : int or None, optional
        Seed for the final shuffle. ``None`` (default) shuffles with the
        global ``random`` module state.

    Returns
    -------
    dataset : list of (numpy.ndarray, int)
        Shuffled pairs of a ``(target_size[0] * target_size[1] * 3, 1)``
        array and its integer label.
    classes : dict of int -> str
        Maps each label to its class directory name.
    """
    dataset = []
    classes = {}
    expected_shape = (target_size[0] * target_size[1] * 3, 1)

    # Only directories are classes (skips e.g. .DS_Store); sort for a
    # reproducible label assignment.
    class_names = sorted(
        name for name in os.listdir(dataset_dir)
        if os.path.isdir(os.path.join(dataset_dir, name))
    )

    for label, class_name in enumerate(class_names):
        classes[label] = class_name
        class_img_path = os.path.join(dataset_dir, class_name)

        for img_name in sorted(os.listdir(class_img_path)):
            img_path = os.path.join(class_img_path, img_name)
            print(os.path.join(class_name, img_name))
            # PIL's Image.open is lazy: decoding errors (truncated/corrupt
            # files) only surface when preprocess() touches the pixels, so
            # both calls sit inside the try. The `with` closes the file
            # handle instead of leaking one per image.
            try:
                with Image.open(img_path) as image:
                    preprocessed_image = preprocess(image, target_size)
            except (OSError, ValueError) as exc:
                # OSError covers unidentifiable files, sub-directories and
                # truncated images; anything else is a real bug and propagates.
                print("skipping {}: {}".format(img_path, exc))
                continue

            assert preprocessed_image.shape == expected_shape, \
                "image shape = {}".format(preprocessed_image.shape)

            dataset.append((preprocessed_image, label))

    if seed is None:
        random.shuffle(dataset)
    else:
        random.Random(seed).shuffle(dataset)

    return dataset, classes

In [5]:
def split_data(dataset, train_size, test_size):
    """Split a list of ``(features, label)`` pairs into train/test arrays.

    The first ``train_size`` pairs form the training set and the next
    ``test_size`` pairs form the test set. Plain slicing is used, so a dataset
    shorter than ``train_size + test_size`` yields smaller splits rather
    than an error. Feature columns are stacked side by side, so each ``X``
    holds one column per sample.

    Parameters
    ----------
    dataset : list of (numpy.ndarray, label)
        Pairs whose feature arrays are column vectors of shape ``(n_features, 1)``
        (the layout produced by ``preprocess``).
    train_size : int
        Number of leading pairs used for training.
    test_size : int
        Number of pairs after the training slice used for testing.

    Returns
    -------
    X_train : numpy.ndarray, shape (n_features, n_train)
    Y_train : numpy.ndarray, shape (n_train,)
    X_test : numpy.ndarray, shape (n_features, n_test)
    Y_test : numpy.ndarray, shape (n_test,)
    """
    train_pairs = dataset[:train_size]
    test_pairs = dataset[train_size:train_size + test_size]

    # Shape/dtype used for empty splits; np.concatenate([]) would raise.
    if len(dataset) > 0:
        first_features = dataset[0][0]
        n_features, feature_dtype = first_features.shape[0], first_features.dtype
    else:
        n_features, feature_dtype = 0, float

    def _stack(pairs):
        # Build an (n_features, n_samples) matrix and an (n_samples,) label vector.
        if len(pairs) == 0:
            return (np.empty((n_features, 0), dtype=feature_dtype),
                    np.empty(0, dtype=int))
        features, labels = zip(*pairs)
        return np.concatenate(features, axis=1), np.array(labels)

    X_train, Y_train = _stack(train_pairs)
    X_test, Y_test = _stack(test_pairs)

    return X_train, Y_train, X_test, Y_test