# Demo: database
Create and separate a TF database with augmented images and files for testing, training and validation.

In [1]:
import os
import sys
from google.colab import drive
drive.mount('/content/gdrive', force_remount=True)

# Automatically reload imported programmes
%load_ext autoreload
%autoreload 2


# Directories (ammend as necessary)
root_dir = '/content/gdrive/MyDrive/IDSAI/PROOF/filament-segmentation'
os.chdir(root_dir)  # Move to root_dir
sys.path.insert(0, root_dir)


# Database choice
batch_size = 10
num_patches = 1  # Subsample taining data
num_duplicates = 30  # Repeats of subsamples to augment
apply_augmentation = True
shuffle_on = True


# Locate data
data_file = 'tomograms2D/all'  # No leading/trailing `/`
database_name = 'all-2D-augmented'


# Add data to root directory and locate JSON file
data_dir = os.path.join(root_dir, 'data/' + data_file)
image_path = os.path.join(data_dir, 'png-original')
masks_path = os.path.join(data_dir, 'png-masks/semantic/*.png')



# New training and validation files
train_dir = os.path.join(root_dir, 'data/databases/' + database_name + '/train')
valid_dir = os.path.join(root_dir, 'data/databases/' + database_name + '/valid')

Mounted at /content/gdrive


## Create TF database for training

In [None]:
import tensorflow as tf
from loader import augment_data, get_data


# Load the images and masks and split them into training and validation
# subsets. The validation paths are left empty; presumably get_data then takes
# the validation subset from the training data according to `valid_frac`
# -- TODO confirm against loader.get_data. The last two return values are not
# needed here and are discarded.
print('\nProcessing data...')
train_imgs, train_msks, valid_imgs, valid_msks, _, _ = get_data(
    path_train_imgs=image_path,
    path_train_msks=masks_path,
    path_valid_imgs='',
    path_valid_msks='',
    train_frac=0.8,
    valid_frac=0.1,
    image_size=[256, 256],
    num_patches_per_image=num_patches,
    num_duplicates_per_image=num_duplicates,
)

# Build the batched training and validation sets. Augmentation and shuffling
# both follow the switches set in the configuration cell; `shuffle_on` was
# previously hard-coded to True here, so editing the switch had no effect.
train_set, valid_set = augment_data(
    train_imgs,
    train_msks,
    valid_imgs,
    valid_msks,
    batch_size,
    one_hot=False,
    augment_on=apply_augmentation,
    shuffle_on=shuffle_on,
)

# Write both sets to disk.
# NOTE(review): tf.data.experimental.save is deprecated in recent TensorFlow
# releases in favour of tf.data.Dataset.save -- confirm the installed version
# before switching.
tf.data.experimental.save(train_set, train_dir)
tf.data.experimental.save(valid_set, valid_dir)

print('Data processed and saved.\n')
print('Training set length: ', len(train_set))
print('Validation set length: ', len(valid_set))