# Demo: database
Create and separate a TF database with augmented images and files for testing, training and validation.

In [4]:
import os
import sys
from google.colab import drive


# Mount Google Drive
drive.mount('/content/gdrive', force_remount=True)

# Automatically reload imported programmes
%load_ext autoreload
%autoreload 2


# Database name
dataset_name = 'all2D'  # Refers to data_file = 'tomograms2D/all'
augmentation_choice = 'none'  # Choose from 'zoom'/'full'/'none'
database_name = dataset_name + '-' + augmentation_choice


# Directories (ammend as necessary)
root_dir = '/content/gdrive/MyDrive/IDSAI/PROOF/filament-segmentation'
os.chdir(root_dir)  # Move to root_dir
sys.path.insert(0, root_dir)


# RAW data location
data_file = 'tomograms2D/all'  # No leading/trailing `/`
data_dir = os.path.join(root_dir, 'data/' + data_file)
image_path = os.path.join(data_dir, 'png-original')
masks_path = os.path.join(data_dir, 'png-masks/semantic/*.png')


# New training and validation files
train_dir = os.path.join(root_dir, 'data/databases/' + database_name + '/train')
valid_dir = os.path.join(root_dir, 'data/databases/' + database_name + '/valid')

Mounted at /content/gdrive
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Create TF database for training

Set parameters for augmentation.

In [5]:
# Batch size and shuffle flag for the dataset
batch_size = 10
shuffle_on = True

# One preset per supported `augmentation_choice`, laid out as
# (num_patches, num_duplicates, apply_augmentation).
presets = {
    'zoom': (20, 1, True),   # Subsample training data before augmenting
    'full': (1, 30, True),   # Duplicate full image to augment
    'none': (1, 1, False),   # No augmentation (small dataset)
}

# Reject anything that is not one of the presets above
if augmentation_choice not in presets:
    raise ValueError('Please select a pre-defined `augmentation_choice`.')

num_patches, num_duplicates, apply_augmentation = presets[augmentation_choice]

Load and save dataset.

In [6]:
import tensorflow as tf
from loader import augment_data, get_data

print('\nProcessing data...')

# Load the images and masks and split them into training/validation parts.
# The validation paths are passed as empty strings; presumably get_data then
# carves the validation data out of the training paths using `valid_frac`
# -- TODO confirm in loader.get_data.
# The last two return values are not used in this notebook and are discarded.
train_imgs, train_msks, valid_imgs, valid_msks, _, _ = get_data(
    path_train_imgs=image_path,
    path_train_msks=masks_path,
    path_valid_imgs='',
    path_valid_msks='',
    train_frac=0.8,
    valid_frac=0.1,
    image_size=[256, 256],
    num_patches_per_image=num_patches,
    num_duplicates_per_image=num_duplicates,
)

# Build the training and validation datasets from the loaded arrays.
train_set, valid_set = augment_data(
    train_imgs,
    train_msks,
    valid_imgs,
    valid_msks,
    batch_size,
    one_hot=False,
    augment_on=apply_augmentation,
    # Use the flag from the parameter cell; it was previously hard-coded to
    # True here, so changing `shuffle_on` had no effect.
    shuffle_on=shuffle_on,
)

# Write both datasets to the database directories on Drive.
# NOTE(review): newer TensorFlow releases deprecate `tf.data.experimental.save`
# in favour of `Dataset.save` -- confirm against the installed TF version
# before switching.
tf.data.experimental.save(train_set, train_dir)
tf.data.experimental.save(valid_set, valid_dir)

print('Data processed and saved.\n')
# Lengths are presumably counted in batches of `batch_size` -- TODO confirm
print('Training set length: ', len(train_set))
print('Validation set length: ', len(valid_set))


Processing data...


100%|██████████| 186/186 [12:14<00:00,  3.95s/it]


Data processed and saved.

Training set length:  15
Validation set length:  2
