In [0]:
import os, utils
import mxnet as mx
from mxnet import gluon, image, nd
from mxnet.gluon import data
import cv2
%matplotlib inline

# Default value for the ``verbose`` parameter of get_file_name / read_file.
# It is bound when those functions are defined, so changing it afterwards
# requires re-running their defining cell.
VERBOSE = True
# NOTE(review): not referenced anywhere in the visible cells; presumably meant
# to select the train split when loading. Confirm or remove.
TRAIN_FILE = True


def get_file_name(data_dir, verbose=VERBOSE):
    """Collect the extension-less names of the train and test shadow images.

    Args:
        data_dir: Root directory containing the ``SBUTrain4KRecoveredSmall``
            and ``SBU-Test`` folders.
        verbose: When true, print how many names were found for each split.

    Returns:
        A ``(train_file_name, test_file_name)`` tuple of sorted lists of
        file names without extension.

    Raises:
        OSError: If either ``ShadowImages`` directory cannot be listed.
    """
    def _image_stems(image_dir):
        # os.listdir returns entries in arbitrary order and also yields
        # non-image entries (hidden files, checkpoint folders, thumbnails).
        # Downstream loading appends '.jpg' to every name, so keep only the
        # '.jpg' entries and sort them for a reproducible ordering.
        stems = []
        for entry in os.listdir(image_dir):
            stem, ext = os.path.splitext(entry)
            if ext.lower() == '.jpg':
                stems.append(stem)
        return sorted(stems)

    train_dir = os.path.join(data_dir, 'SBUTrain4KRecoveredSmall', 'ShadowImages')
    test_dir = os.path.join(data_dir, 'SBU-Test', 'ShadowImages')

    train_file_name = _image_stems(train_dir)
    test_file_name = _image_stems(test_dir)

    if verbose:
        print('Get {} train image, {} test image.'.format(len(train_file_name), len(test_file_name)))

    return train_file_name, test_file_name


def read_file(data_dir, train_file, verbose=VERBOSE):
    """Load every image and its shadow mask for one dataset split into memory.

    Args:
        data_dir: Root directory containing the ``SBUTrain4KRecoveredSmall``
            and ``SBU-Test`` folders.
        train_file: True to load the train split, false to load the test split.
        verbose: When true, print progress every 500 images and a final summary.

    Returns:
        A ``(features, labels)`` tuple of equally long lists. ``features[i]``
        is read from ``ShadowImages/<name>.jpg`` and ``labels[i]`` from
        ``ShadowMasks/<name>.png`` via ``image.imread``.
    """
    # Forward ``verbose`` so a quiet caller also silences the listing summary;
    # previously the module default was always used here.
    train_names, test_names = get_file_name(data_dir, verbose=verbose)
    if train_file:
        file_name = train_names
        split_dir = os.path.join(data_dir, 'SBUTrain4KRecoveredSmall')
    else:
        file_name = test_names
        split_dir = os.path.join(data_dir, 'SBU-Test')

    image_dir = os.path.join(split_dir, 'ShadowImages')
    mask_dir = os.path.join(split_dir, 'ShadowMasks')

    features, labels = [], []
    for i, fname in enumerate(file_name):
        features.append(image.imread(os.path.join(image_dir, fname + '.jpg')))
        labels.append(image.imread(os.path.join(mask_dir, fname + '.png')))
        if verbose and i % 500 == 0:
            print('{} images loaded...'.format(i))

    if verbose:
        split_name = 'train' if train_file else 'test'
        print('Loaded all {} features and labels on {} dataset.'.format(len(features), split_name))

    return features, labels

    
def label_transform(img):
    """Turn a mask image into a grayscale NDArray with values divided by 255.

    Args:
        img: NDArray holding the mask; it is converted with
            ``cv2.COLOR_RGB2GRAY``, so a 3-channel RGB layout is assumed
            (TODO confirm against the loader).

    Returns:
        An ``nd.array`` built from the grayscale image scaled by ``1 / 255``.
    """
    gray = cv2.cvtColor(img.asnumpy(), cv2.COLOR_RGB2GRAY)
    return nd.array(gray / 255.)
    

def random_crop(feature, label, width, height):
    """Randomly crop ``feature`` and cut the identical region out of ``label``.

    Args:
        feature: Image NDArray to crop at a random position.
        label: Mask NDArray cropped with the rectangle chosen for ``feature``.
        width: Passed as the second element of the size given to
            ``image.random_crop``.
        height: Passed as the first element of that size.

    Returns:
        A ``(feature, label)`` tuple of the two crops.

    NOTE(review): MXNet documents the size argument of ``image.random_crop``
    as ``(width, height)``, so ``width`` here presumably ends up as the number
    of rows of the crop and ``height`` as the number of columns. Confirm
    before renaming anything; callers rely on the current order.
    """
    crop_size = (height, width)
    cropped_feature, crop_rect = image.random_crop(feature, crop_size)
    # Reuse the sampled rectangle so image and mask stay pixel-aligned.
    cropped_label = image.fixed_crop(label, *crop_rect)
    return cropped_feature, cropped_label


class ShadowDataset(data.Dataset):
    """In-memory dataset of normalized images and shadow masks, cropped on access.

    Args:
        train_file: Forwarded to ``read_file``; true loads the train split,
            false the test split.
        crop_size: Two-element sequence. ``crop_size[0]`` is checked against
            ``img.shape[0]`` and ``crop_size[1]`` against ``img.shape[1]``;
            the pair is also unpacked into ``random_crop``.
        data_dir: Dataset root directory forwarded to ``read_file``.
        verbose: When true, print loading progress and the number of pairs kept.
    """

    def __init__(self, train_file, crop_size, data_dir, verbose=True):
        # Per-channel statistics used by normalize_image (presumably the
        # usual ImageNet mean/std; confirm).
        self.rgb_mean = nd.array([0.485, 0.456, 0.406])
        self.rgb_std = nd.array([0.229, 0.224, 0.225])
        self.crop_size = crop_size
        # Forward ``verbose`` so verbose=False really silences loading output.
        features, labels = read_file(data_dir, train_file=train_file, verbose=verbose)
        # Filter image and mask together: filtering the two lists separately
        # would misalign every later pair if only one of them were too small.
        kept = [(feature, label) for feature, label in zip(features, labels)
                if self._can_crop(feature) and self._can_crop(label)]
        self.features = [self.normalize_image(feature) for feature, _ in kept]
        self.labels = [label for _, label in kept]
        if verbose:
            print('Remain {} images can be cropped and normalized.'.format(len(self.features)))

    def _can_crop(self, img):
        """Return True when ``img`` is at least ``crop_size`` along its first two axes."""
        return img.shape[0] >= self.crop_size[0] and img.shape[1] >= self.crop_size[1]

    # drop img with error size which cannot be cropped
    # NOTE(review): newer MXNet releases appear to define Dataset.filter(fn)
    # with a different signature; this override replaces it. Confirm that
    # nothing relies on the base-class version.
    def filter(self, imgs):
        """Return the images from ``imgs`` that are large enough to be cropped."""
        return [img for img in imgs if self._can_crop(img)]

    def normalize_image(self, img):
        """Scale ``img`` to float32 in [0, 1], then standardize with rgb_mean / rgb_std."""
        return (img.astype('float32') / 255 - self.rgb_mean) / self.rgb_std

    def __getitem__(self, index):
        """Return a fresh random crop of sample ``index`` as ``(feature, label)``.

        The feature is transposed with axes ``(2, 0, 1)`` (presumably HWC to
        CHW) and the label is passed through ``label_transform``.
        """
        feature, label = random_crop(self.features[index], self.labels[index], *self.crop_size)
        return (feature.transpose((2, 0, 1)), label_transform(label))

    def __len__(self):
        """Number of image/mask pairs kept after size filtering."""
        return len(self.features)
    

In [0]:
def test_read(data_dir='.'):
    """Visual smoke test: show the first five train images above their masks.

    Args:
        data_dir: Dataset root passed to ``read_file``. The previous version
            omitted this required argument and raised ``TypeError``. The
            default assumes the dataset folders sit in the current working
            directory; pass the real location otherwise.
    """
    features, labels = read_file(data_dir, train_file=True)
    imgs = features[0:5] + labels[0:5]
    utils.show_images(imgs, 2, 5)
    
# test_read()

def test_crop(data_dir='.'):
    """Visual smoke test: show five random crops of the first train image and mask.

    Args:
        data_dir: Dataset root passed to ``read_file``. The previous version
            omitted this required argument and raised ``TypeError``. The
            default assumes the dataset folders sit in the current working
            directory; pass the real location otherwise.
    """
    features, labels = read_file(data_dir, train_file=True)
    imgs = []
    for _ in range(5):
        # random_crop returns (feature, label); extending keeps them interleaved.
        imgs += random_crop(features[0], labels[0], 300, 200)
    # Even positions are image crops, odd positions the matching mask crops:
    # display all images on the first row and all masks on the second.
    utils.show_images(imgs[::2] + imgs[1::2], 2, 5)

# test_crop()