# Image Subsetter
---

In [3]:
from sklearn.model_selection import train_test_split
from skimage.transform import AffineTransform, warp
from skimage.transform import resize
from skimage.morphology import disk
from skimage.filters import rank
import matplotlib.pyplot as plt
from scipy.misc import imread
from tqdm import tqdm
import numpy as np
import pickle
import glob
import uuid


% matplotlib inline

In [14]:
def get_image_id(image_filepath):
    r"""Return the image identifier: the file name without directory or extension.

    Parameters
    ----------
    image_filepath : str
        Path to an image file. Both ``\`` and ``/`` are accepted as
        directory separators, regardless of the platform.

    Returns
    -------
    str
        The base file name with its extension removed, e.g. ``'img_01'`` for
        ``'.\data\resized-images\img_01.png'``.
    """
    # ntpath understands both '\\' and '/' on every platform, so the
    # Windows-style paths used in this notebook and POSIX-style paths both work.
    import ntpath

    image_name = ntpath.basename(image_filepath)
    # splitext removes whatever extension is present rather than assuming that
    # the name always ends in exactly four characters such as '.png'.
    image_id_str, _extension = ntpath.splitext(image_name)
    return image_id_str


def get_mask_from_id(image_id, masked_image_filepaths):
    """Load the mask image whose id equals ``image_id``.

    The id of every path in ``masked_image_filepaths`` is computed with
    ``get_image_id``; the first path whose id matches is read with ``imread``
    and returned.  ``list.index`` raises ``ValueError`` when nothing matches.
    """
    mask_ids = list(map(get_image_id, masked_image_filepaths))
    match_position = mask_ids.index(image_id)
    mask_path = masked_image_filepaths[match_position]
    return imread(mask_path)


def resize_image_and_mask(image, mask, dst_size):
    """Resize an image and its mask to ``dst_size`` and return them as a pair.

    Both arrays are passed through ``resize`` with ``order = 0``; the image is
    processed first, then the mask.
    """
    resized_image, resized_mask = (
        resize(array, dst_size, order = 0) for array in (image, mask)
    )
    return resized_image, resized_mask


def train_validation_test_split(X, y, frac_train=0.6, frac_validate=0.15, frac_test=0.25, rng = 1):
    """Split ``X`` and ``y`` into stratified training, validation and test subsets.

    The test subset is split off first; the remainder is then divided into
    training and validation subsets.  Both splits are stratified on the labels
    and use ``rng`` as ``random_state``.

    Parameters
    ----------
    X, y : array-like
        Samples and their labels, passed straight to ``train_test_split``.
    frac_train, frac_validate, frac_test : float
        Fractions of the data for each subset; they must sum to 1.
    rng : int
        Value forwarded as ``random_state`` to both ``train_test_split`` calls.

    Returns
    -------
    tuple
        ``(X_train, X_valid, X_test, y_train, y_valid, y_test)``.

    Raises
    ------
    ValueError
        If the three fractions do not sum to 1 (within floating-point tolerance).
    """
    # Compare with a tolerance: sums such as 0.7 + 0.2 + 0.1 are not exactly 1
    # in floating point.  An explicit check is used instead of ``assert`` so
    # the validation still runs when Python is started with -O.
    if not np.isclose(frac_train + frac_validate + frac_test, 1.0):
        raise ValueError('The training, test, and validation fractions do not sum to 1.')

    # Share of the non-test data (train + validation) that goes to training.
    train_share = frac_train/(frac_train + frac_validate)

    # Stage 1: hold out the test subset.
    X_train, X_test, y_train, y_test  = train_test_split(X, y, train_size = 1 - frac_test,
                                                         test_size = frac_test, random_state = rng,
                                                         stratify = y)
    # Stage 2: divide what is left into training and validation subsets.
    X_train, X_valid, y_train, y_valid = train_test_split(X_train, y_train, train_size = train_share,
                                                          test_size = 1 - train_share, random_state = rng,
                                                          stratify = y_train)
    return X_train, X_valid, X_test, y_train, y_valid, y_test


def pickle_data(dataset, labels, name):
    """Pickle a dataset and its labels into the local ``data`` directory.

    Writes ``<name>_dataset.p`` and ``<name>_labels.p`` under ``./data`` and
    prints a confirmation message.

    Parameters
    ----------
    dataset : object
        Object written to ``<name>_dataset.p``.
    labels : object
        Object written to ``<name>_labels.p``.
    name : str
        Prefix used for both file names.
    """
    import os

    # Build the path with os.path.join so it resolves to '.\\data\\...' on
    # Windows (as before) and to './data/...' elsewhere.
    data_dir = os.path.join('.', 'data')
    # Create the output directory on first use instead of failing in open().
    os.makedirs(data_dir, exist_ok = True)

    for payload, suffix in ((dataset, '_dataset.p'), (labels, '_labels.p')):
        # The context manager guarantees the handle is flushed and closed.
        with open(os.path.join(data_dir, name + suffix), 'wb') as out_file:
            pickle.dump(payload, out_file)
    print('Data pickled successully.')
    
    
def augment_image(img):
    """Return a randomly augmented, rank-equalized and rescaled copy of ``img``.

    Steps: draw two flip flags and a scale factor, warp ``img`` with an
    ``AffineTransform`` using that scale on both axes, flip along axis 0
    and/or axis 1 according to the flags, apply ``rank.equalize`` with a
    ``disk(7)`` neighbourhood, and finally map the result with ``(x - 128)/128``.
    """
    # Random parameters are drawn in a fixed order: vertical flag, horizontal
    # flag, then the scale factor.
    mirror_rows, mirror_cols = np.random.randint(0, 2), np.random.randint(0, 2)
    # Keep the scaling within 2%: some of the cracks sit at the extremes of the frame.
    zoom = np.random.uniform(1/1.02, 1)

    out = warp(img, AffineTransform(scale = (zoom, zoom)))
    for axis, do_flip in enumerate((mirror_rows, mirror_cols)):
        if do_flip:
            out = np.flip(out, axis = axis)

    out = rank.equalize(out, selem=disk(7)).astype(float)
    # skimage bumps the images from [0, 1] to [0, 255], hence the 128 offset/scale.
    return (out - 128)/128


def balance_dataset(dataset, labels, patch_shape=None, oversample=4):
    """Build a class-balanced dataset by sampling rows with replacement and augmenting them.

    For every distinct label, ``oversample * (size of the smallest class)``
    rows are drawn at random from that class.  Each row is reshaped to
    ``patch_shape``, passed through ``augment_image`` and flattened again.

    Parameters
    ----------
    dataset : numpy.ndarray
        2-D array with one flattened patch per row.
    labels : numpy.ndarray
        1-D array with one label per row of ``dataset``.
    patch_shape : sequence of two ints, optional
        Height and width used to un-flatten a row before augmentation.
        When omitted, the notebook-level ``kernel_size`` is used, which keeps
        existing calls working.
    oversample : int, optional
        Multiplier applied to the smallest class count (default 4).

    Returns
    -------
    (numpy.ndarray, numpy.ndarray)
        The balanced dataset (rows grouped by class, not shuffled) and the
        matching labels.

    Raises
    ------
    ValueError
        If ``labels`` is empty.
    """
    if patch_shape is None:
        # Backward-compatible fallback to the global defined in the notebook.
        patch_shape = kernel_size
    patch_shape = tuple(patch_shape)

    class_values = np.unique(labels)
    if class_values.size == 0:
        raise ValueError('Cannot balance a dataset without labels.')

    n_labels    = len(class_values)
    min_count   = min(np.count_nonzero(labels == class_label) for class_label in class_values)
    max_count   = min_count*oversample
    bal_dataset = np.zeros([max_count*n_labels, dataset.shape[1]])
    bal_labels  = np.zeros([max_count*n_labels])
    ix = 0
    for class_label in class_values:
        # Row indices of this class; invariant across the draws below.
        class_ix = np.nonzero(labels == class_label)[0]
        for _ in range(max_count):
            rnd_ix = np.random.choice(class_ix)
            patch  = dataset[rnd_ix, :].reshape(patch_shape)
            bal_dataset[ix, :] = augment_image(patch).reshape(patch_shape[0]*patch_shape[1])
            bal_labels[ix]     = labels[rnd_ix]
            ix += 1
    print('Dataset successfully balanced.')
    return bal_dataset, bal_labels

def shuffle_dataset(dataset, labels):
    """Return ``dataset`` and ``labels`` reordered by one shared random permutation.

    Parameters
    ----------
    dataset : numpy.ndarray
        Array whose first axis indexes samples.
    labels : numpy.ndarray
        1-D array with one label per sample; its length sets the permutation size.

    Returns
    -------
    (numpy.ndarray, numpy.ndarray)
        New arrays in which row ``i`` of the dataset still corresponds to
        label ``i``.
    """
    shuffle_ix = np.random.permutation(labels.shape[0])
    # Fancy indexing copies the rows in one vectorised step and keeps the
    # trailing dimensions and dtype even when there are zero samples.
    return np.asarray(dataset)[shuffle_ix], np.asarray(labels)[shuffle_ix]

In [5]:
# Collect image and mask paths; sorting makes the patch order (and therefore
# the seeded split downstream) independent of the file-system listing order.
image_filepaths = sorted(glob.glob('.\\data\\resized-images\\*.png'))
mask_filepaths  = sorted(glob.glob('.\\data\\annotated-images-masks\\*.png'))

image_size  = [608, 608]   # every image/mask is resized to this shape
kernel_size = [128, 128]   # patch height and width
step_size   = [40, 40]     # stride of the sliding window (rows, cols)

# Top-left corners of the sliding window.  The patch counts are derived from
# these ranges so the pre-allocated arrays always match the loops below, even
# when the step does not divide (image_size - kernel_size) evenly.
# NOTE(review): the ranges stop before image_size - kernel_size, so the window
# flush with the bottom/right edge is never extracted; kept unchanged to
# preserve the existing dataset -- confirm whether that is intended.
row_starts = range(0, image_size[0] - kernel_size[0], step_size[0])
col_starts = range(0, image_size[1] - kernel_size[1], step_size[1])

n_cols    = len(col_starts)
n_rows    = len(row_starts)
n_patches = len(image_filepaths)*n_rows*n_cols

dataset   = np.zeros([n_patches, kernel_size[0], kernel_size[1]])
masks     = np.zeros([n_patches, kernel_size[0], kernel_size[1]])
labels    = np.zeros(n_patches)
ids       = np.array([None]*n_patches)
ix        = 0

for fp in tqdm(image_filepaths):
    image       = imread(fp)
    image_id    = get_image_id(fp)
    mask        = get_mask_from_id(image_id, mask_filepaths)
    image, mask = resize_image_and_mask(image, mask, image_size)
    for r in row_starts:
        for c in col_starts:
            image_patch       = image[r:r + kernel_size[0], c:c + kernel_size[1]]
            mask_patch        = mask[r:r + kernel_size[0], c:c + kernel_size[1]]
            dataset[ix, :, :] = image_patch
            masks[ix, :, :]   = mask_patch
            # Label is 1 when the summed mask values in the patch exceed 40, else 0.
            labels[ix]        = np.heaviside(np.sum(mask_patch) - 40, 0)
            ids[ix]           = uuid.uuid4().hex
            ix += 1

# Every pre-allocated slot must have been filled exactly once.
assert ix == n_patches

# Flatten each patch to one row (height * width, not height * height).
dataset = dataset.reshape([dataset.shape[0], kernel_size[0]*kernel_size[1]])
masks = masks.reshape([masks.shape[0], kernel_size[0]*kernel_size[1]])

100%|██████████████████████████████████████████| 52/52 [00:03<00:00, 13.30it/s]


In [6]:
# Split the patches into training, validation and test subsets using the
# default fractions of train_validation_test_split.
train_dataset, valid_dataset, test_dataset, \
    train_labels, valid_labels, test_labels = train_validation_test_split(X = dataset, y = labels)
print('Images subset and split successfully.')

Images subset and split successfully.


In [7]:
# Rescale the validation and test patches from [0, 1] to [-1, 1] with 2*x - 1,
# in line with the training patches, which augment_image maps with
# (x - 128)/128.  The previous 2*(x - 1) shifted the data to [-2, 0] instead.
# NOTE(review): assumes the patches produced by resize are floats in [0, 1] --
# confirm.  Also note that these sets are not passed through the rank
# equalization that augment_image applies to the training set.
pickle_data(2*valid_dataset - 1, valid_labels, 'valid')
pickle_data(2*test_dataset  - 1, test_labels, 'test')
print('Testing and validation sets successfully pickled.')

Data pickled successully.
Data pickled successully.
Testing and validation sets successfully pickled.


In [15]:
# Seed NumPy's global RNG so the random sampling and augmentation performed by
# balance_dataset give the same training set on every Restart & Run All.
np.random.seed(1)
btrain_dataset, btrain_labels = balance_dataset(train_dataset, train_labels)

  "%s to %s" % (dtypeobj_in, dtypeobj))


Dataset successfully balanced.


In [17]:
# Shuffle the balanced training set (rows and labels together), then persist it.
btrain_dataset, btrain_labels = shuffle_dataset(dataset = btrain_dataset,
                                                labels  = btrain_labels)
pickle_data(dataset = btrain_dataset, labels = btrain_labels, name = 'train')
print('Training set successfully pickled.')

Data pickled successully.
Training set successfully pickled.
