Data augmentation is done offline in this notebook rather than on the fly during training, because Keras does not support generating a multi-channel y (i.e. y channels that do not exist in x).

In [None]:
from __future__ import print_function
import keras
from keras.preprocessing.image import ImageDataGenerator
from keras.models import load_model

from keras.models import Model

from sklearn.utils import shuffle

import numpy as np
from time import time


In [None]:
# Plotting setup for the notebook.
# Select the non-interactive Agg backend before pyplot is imported.
import matplotlib
matplotlib.use('Agg');
import matplotlib.pyplot as plt
# Use the 'Greys' colormap as the default for subsequent image plots.
plt.set_cmap('Greys');

# IPython magic (not plain Python): render figures inline in the notebook.
# NOTE(review): this presumably replaces the Agg backend chosen above - confirm which one is intended.
%matplotlib inline

In [2]:
# Root folder of the input data; the loading cells read its 'npy_data' and
# 'csv_dir' sub-folders.
input_dir = '/YOURPATH/'
# The loading cells refer to this folder as `input_folder`; bind both names so
# those cells do not fail with a NameError.
input_folder = input_dir

# Folder that receives the augmented images written by the generators.
output_dir = "/YOUR_OUTPUT_DIR/"

#### Data augmentation:

In [None]:
def load_input_set_shuffle_and_split(is_shuffle=True, shuffle_seed=0, n_train=20, is_norm_y_ims=False):
    """Load the x/y arrays named in ``params``, optionally shuffle, and split them.

    Reads ``<input_folder>/npy_data/<params["x_file"]>.npy`` and the matching
    ``y_file`` array. ``input_folder`` and ``params`` are notebook-level
    globals that must be defined before this function is called.

    Args:
        is_shuffle: If truthy, shuffle x and y together before splitting.
        shuffle_seed: ``random_state`` passed to ``sklearn.utils.shuffle``.
        n_train: Number of leading samples that form the training set; the
            remaining samples form the test set.
        is_norm_y_ims: If truthy, divide every y image (per sample and per
            channel) by its own maximum. All-zero images are left untouched.

    Returns:
        Tuple ``(x_train, x_test, y_train, y_test)``. The x arrays are float32
        and 3-D x input gets a trailing channel axis; 3-D y input is cast to
        float32 and gets a trailing channel axis as well.
    """
    import os  # local import: the notebook only imports os in a later cell

    # Load X and Y
    data_dir = os.path.join(input_folder, 'npy_data')
    x = np.load(os.path.join(data_dir, f'{params["x_file"]}.npy'))
    y = np.load(os.path.join(data_dir, f'{params["y_file"]}.npy'))

    if y.ndim == 3:
        # Single-channel y: cast and add the channel axis.
        y = y.astype('float32')[..., None]
    elif is_norm_y_ims and not np.issubdtype(y.dtype, np.floating):
        # Normalisation is done in place; an integer array would silently
        # truncate the resulting fractions, so switch to float first.
        y = y.astype('float32')

    print(f'x size: {x.shape}')
    print(f'y size: {y.shape}')

    # Check Y Range:
    print(f'max Y: {np.max(y)}, min Y: {np.min(y)}')

    if is_shuffle:
        x, y = shuffle(x, y, random_state=shuffle_seed)

    # Normalize Y - each image:
    # Already done for euclidean dist images.
    if is_norm_y_ims:
        # One maximum per (sample, channel) image, kept broadcastable against y.
        peaks = y.max(axis=(1, 2), keepdims=True)
        # Skip images whose maximum is 0 to avoid 0/0 -> NaN.
        np.divide(y, peaks, out=y, where=peaks != 0)

    # Split data to train and test:
    x_train, y_train = x[:n_train], y[:n_train]
    x_test, y_test = x[n_train:], y[n_train:]

    print(f'y train size: {y_train.shape}')
    print(f'y test size: {y_test.shape}')

    x_train = x_train.astype('float32')
    x_test = x_test.astype('float32')
    if x.ndim == 3:
        # Add the channel axis only when it is missing (mirrors the y handling).
        x_train = x_train[..., None]
        x_test = x_test[..., None]

    print(f'x train size: {x_train.shape}')
    print(f'x test size: {x_test.shape}')

    # Print x normalization data:
    # Check that x is already normalized:
    print(f'mean x: {np.mean(x)}')
    print(f'std x: {np.std(x)}')
    print(f'min x: {np.min(x)}')
    print(f'max x: {np.max(x)}')

    print(f'mean x_train: {np.mean(x_train)}')
    print(f'std x_train: {np.std(x_train)}')
    print(f'min x_train: {np.min(x_train)}')
    print(f'max x_train: {np.max(x_train)}')

    return x_train, x_test, y_train, y_test

In [None]:
def show_images(x, y, n_ims_show=5):
    """Plot the first ``n_ims_show`` samples side by side in three figures.

    Figure 1 shows channel 0 of x, figure 2 shows channel 0 of y (the first
    landmark), and figure 3 shows x channel 0 plus the per-pixel maximum over
    all y channels.

    Args:
        x: Array indexed as ``x[sample, row, col, channel]``.
        y: Array indexed as ``y[sample, row, col, channel]``.
        n_ims_show: Number of leading samples to display.
    """
    # Must be set before the first figure is created, otherwise that figure
    # keeps the previous default size.
    plt.rcParams['figure.figsize'] = (15, 5)

    # Only the displayed samples are needed, so slice before any arithmetic.
    x_shown = x[:n_ims_show, :, :, 0]
    y_shown = y[:n_ims_show].astype(np.float32)

    panels = (
        x_shown,                              # examples of the x images
        y_shown[:, :, :, 0],                  # examples of the y images - first landmark
        x_shown + np.max(y_shown, axis=3),    # overlay of all y channels on x
    )
    for panel in panels:
        plt.figure()
        # Concatenating along axis 1 lays the samples out left to right.
        plt.imshow(np.concatenate(panel, axis=1), interpolation='none')
        plt.axis('off')

In [None]:
import csv
import os

# Base name (without extension) of the CSV file holding the run parameters.
csv_file = '1'

csv_folder = os.path.join(input_folder, 'csv_dir')

# Each CSV row is read as a key/value pair: column 0 -> column 1.
with open(os.path.join(csv_folder, f'{csv_file}.csv'), newline='') as csvfile:
    reader = csv.reader(csvfile, delimiter=',')
    # csv.reader yields an empty list for a blank line; skip those rows
    # instead of failing with IndexError on row[0].
    params = {row[0]: row[1] for row in reader if row}

In [None]:
# Load, shuffle and split the data set.
# The first positional parameter is `is_shuffle`; the previous call passed the
# `show_images` function there, which only worked because a function object is
# truthy. Pass the flag explicitly (same behaviour: shuffling enabled).
# `show_images(x_train, y_train)` can be called afterwards to preview the data.
x_train, x_test, y_train, y_test = load_input_set_shuffle_and_split(is_shuffle=True)

training_set = [x_train, y_train]

In [None]:
# Augmentation settings shared by the x generator and every y-channel generator.
# NOTE(review): brightness_range is therefore applied to the y channels too -
# confirm this is intended for the target images.
data_gen_args = dict(
    rotation_range=15,
    width_shift_range=0.1,
    height_shift_range=0.1,
    shear_range=0.01,
    zoom_range=[0.95, 1.05],
    horizontal_flip=True,
    vertical_flip=True,
    fill_mode='reflect',
    data_format='channels_last',
    brightness_range=[0.8, 1.2],
)

image_datagen = ImageDataGenerator(**data_gen_args)
n_y_channels = y_train.shape[3]
# One generator per y channel, each fed a single-channel slice of y_train.
mask_datagen_l = [ImageDataGenerator(**data_gen_args) for _ in range(n_y_channels)]

# The same seed is given to every fit()/flow() call so that x and all y
# channels are meant to receive matching random transforms.
seed = 1
batch_size = 32

image_datagen.fit(x_train, augment=True, seed=seed)
for i in range(n_y_channels):
    mask_datagen_l[i].fit(y_train[:, :, :, i, np.newaxis], augment=True, seed=seed)

# Build the augmenting iterators. Every batch drawn from them is also written
# to output_dir as PNG files. Each generator gets its own file-name prefix
# ('x', 'y0', 'y1', ...): with a shared empty prefix the saved x and y files
# are indistinguishable and can overwrite one another.
image_generator = image_datagen.flow(
    x_train,
    seed=seed,
    batch_size=batch_size,
    save_to_dir=output_dir,
    save_prefix='x',
    save_format='png',
)
mask_generator_l = [
    mask_datagen_l[i].flow(
        y_train[:, :, :, i, np.newaxis],
        seed=seed,
        batch_size=batch_size,
        save_to_dir=output_dir,
        save_prefix=f'y{i}',
        save_format='png',
    )
    for i in range(n_y_channels)
]

# Each item is a tuple: (x batch, y channel 0 batch, y channel 1 batch, ...).
train_generator = zip(image_generator, *mask_generator_l)

In [None]:
# Rebuild the paired iterator (same expression as at the end of the previous
# cell): each item is a tuple of one x batch followed by one batch per y channel.
train_generator = zip(image_generator, *mask_generator_l)

In [None]:
# Draw one augmented batch from every generator and stack the resulting tuple
# into a single array; axis 0 then indexes the generator (x first, then the
# y channels).
augmented_batch = next(train_generator)
new_augments = np.asarray(augmented_batch)

In [None]:
# Index (within the drawn batch) of the sample to inspect.
im_num = 10
# Keep, for every generator, channel 0 of that one sample.
train_example = new_augments[:, im_num][:, :, :, 0]

In [None]:
# Entry 0 comes from the x generator; the remaining entries are the y channels.
x1 = train_example[0]
# Collapse all y channels into one image by taking the per-pixel maximum.
y1 = train_example[1:].max(axis=0)

In [None]:
# Left panel: x with the y overlay added; right panel: x alone.
xy1 = x1 + y1
xy2 = x1

fig = plt.figure(3, figsize=(10, 10))
for subplot_index, panel in enumerate((xy1, xy2), start=1):
    ax = fig.add_subplot(1, 2, subplot_index)
    plt.imshow(panel)
    plt.axis('off')