# DATASET CREATION

## 1-IMPORTS

In [None]:
from AxonDeepSeg.data import dataset_building
import os, shutil
from scipy.misc import imread, imsave

## 2-PARAMETERS

First we set the path parameters

In [None]:
# Folder read by build_dataset below.
path_data = '../../data/TEM_3classes_raw/TEM_segmented_train/' # where to find the raw data
# Root folder of the generated dataset. The renumbering cells further down
# append '/training/Train', '/training/Validation' and '/temp/' to this path.
# NOTE(review): this path starts with '../' while path_data starts with
# '../../' -- confirm both are meant to be relative to the same directory.
trainingset_path = '../data/TEM_3classes_512_train/' # where to put the generated dataset

Important! We now define the random seed. This will enable us to reproduce the exact same images each time we use the same random seed.

This will be used to enable the generation of the same validation set and the same testing set.

In [None]:
# Random seed, passed to build_dataset as `random_seed` in section 3.
rng = 2017

Note: rng used for the different datasets:
- SEM: 2017

We then call the function build_dataset. It will automatically create the dataset in the previously specified folder.

## 3-CREATING TRAIN AND VALIDATION SET

Change `thresh_indices` if you want to generate a mask with different classes.

In [None]:
# Build the dataset from path_data into trainingset_path, using the seed
# defined above.
# NOTE(review): trainRatio=1.0 presumably sends every generated image to the
# train split -- confirm against dataset_building.build_dataset.
dataset_building.build_dataset(path_data, trainingset_path, thresh_indices=[0, 0.2, 0.8], random_seed=rng, trainRatio=1.0, data_type='TEM')

## 4-HAND CHOOSING TEST SET

Here, you should pick >> **in the generated train set** << the images you want to include in the test set.

To do that, move them from the training set path to the test set path.

With the usual structure of the project, this looks like this:

    -- data/
    ---- dataset_1/
    ------ training/
    -------- Train/ __ move images and masks to use for the test set from here __
    -------- Validation/
    ------ testing/ __ put them here __
    -------- raw/



## 5-RENUMBERING THE TRAINING SET

Now that we have extracted the images from the training set to the test set, we need to renumber them for our algorithm to be able to work.

No need to renumber the test set for the moment, as it is not used during the training phase.

In [None]:
# Renumber the images and masks remaining in the training folder so that
# their indices run from 0 upwards without gaps.
#
# Files are processed in ascending order of the number that follows the last
# '_' in their name. os.listdir returns entries in arbitrary order, so without
# this sort image_k and mask_k could receive different new indices and stop
# referring to the same sample.
#
# NOTE(review): scipy.misc.imread/imsave are deprecated in recent SciPy
# releases -- confirm the installed SciPy version still provides them.

subpath_data = trainingset_path + '/training/Train'
temp_path = trainingset_path + '/temp/'
# os.mkdir raises if temp/ already exists. That is deliberate: a leftover
# temp/ from an interrupted run may hold the only copy of some files.
os.mkdir(temp_path)


def _numeric_suffix_key(filename):
    """Sort key: numeric suffix first (ascending), other names last by name."""
    suffix = os.path.splitext(filename)[0].split('_')[-1]
    if suffix.isdigit():
        return (0, int(suffix), filename)
    return (1, 0, filename)


# Renumbering data

ordered_names = sorted(os.listdir(subpath_data), key=_numeric_suffix_key)
# Same classification as before: 'image' takes precedence over 'mask'.
image_names = [name for name in ordered_names if 'image' in name]
mask_names = [name for name in ordered_names
              if 'image' not in name and 'mask' in name]

for i, name in enumerate(image_names):
    img = imread(os.path.join(subpath_data, name), flatten=False, mode='L')
    imsave(temp_path + '/image_%s.png' % i, img, 'png')

for j, name in enumerate(mask_names):
    mask = imread(os.path.join(subpath_data, name), flatten=False, mode='L')
    imsave(temp_path + '/mask_%s.png' % j, mask, 'png')

# Replacing old images and masks by new images and mask

# The originals are only deleted once every renumbered copy has been written.
filelist = [f for f in os.listdir(subpath_data) if f.endswith(".png")]
for f in filelist:
    os.remove(os.path.join(subpath_data, f))

filelist = [f for f in os.listdir(temp_path) if f.endswith(".png")]
for f in filelist:
    shutil.move(os.path.join(temp_path, f), subpath_data)

shutil.rmtree(temp_path)

In [None]:
import numpy as np

In [None]:
# Renumber the images and masks in the validation folder so that their
# indices run from 0 upwards without gaps, keeping the original numeric order.
#
# NOTE(review): scipy.misc.imread/imsave are deprecated in recent SciPy
# releases -- confirm the installed SciPy version still provides them.

L_img, L_mask = [], []

subpath_data = trainingset_path + '/training/Validation'
temp_path = trainingset_path + '/temp/'
# os.mkdir raises if temp/ already exists (e.g. after an interrupted run).
os.mkdir(temp_path)

# Renumbering data
# Load every image/mask together with the integer that follows the last '_'
# in its file name (extension stripped as the last four characters).
for data in os.listdir(subpath_data):
    data_name = data[:-4].split('_')
    if 'image' in data:
        img = imread(os.path.join(subpath_data, data), flatten=False, mode='L')
        L_img.append((img, int(data_name[-1])))

    elif 'mask' in data:
        mask = imread(os.path.join(subpath_data, data), flatten=False, mode='L')
        L_mask.append((mask, int(data_name[-1])))

# Sort once, after everything is loaded, by the original file index. Doing
# this outside the loop also keeps both names defined when the folder is empty.
L_img_sorted = sorted(L_img, key=lambda x: x[1])
L_mask_sorted = sorted(L_mask, key=lambda x: x[1])

for i, (img, k) in enumerate(L_img_sorted):
    imsave(temp_path + '/image_%s.png' % i, img, 'png')

for j, (mask, k) in enumerate(L_mask_sorted):
    imsave(temp_path + '/mask_%s.png' % j, mask, 'png')

# Replacing old images and masks by new images and mask

# The originals are only deleted once every renumbered copy has been written.
filelist = [f for f in os.listdir(subpath_data) if f.endswith(".png")]
for f in filelist:
    os.remove(os.path.join(subpath_data, f))

filelist = [f for f in os.listdir(temp_path) if f.endswith(".png")]
for f in filelist:
    shutil.move(os.path.join(temp_path, f), subpath_data)

shutil.rmtree(temp_path)

You're all done!