# Image Segmentation

In [1]:
# Make IPython echo the value of every top-level expression in a cell,
# not only the last one.
from IPython.core.interactiveshell import InteractiveShell
setattr(InteractiveShell, "ast_node_interactivity", "all")

In [18]:
import os
import json

import tensorflow as tf
import numpy as np

from tensorflow.keras.preprocessing.image import ImageDataGenerator

SEED = 1234
# Seed every RNG the pipeline can touch: TensorFlow's global seed, and NumPy's
# global RNG (Keras' ImageDataGenerator augmentation draws from np.random).
tf.random.set_seed(SEED)
np.random.seed(SEED)

cwd = os.getcwd()

## Dataset loading

In [19]:
bs = 8
img_w = 256
img_h = 256
validation_split = 0.1
dataset_dir = os.path.join(cwd, "Segmentation_Dataset")

# Random geometric augmentation shared by training images and masks. An image
# and its mask must receive the same transform. That relies on identical
# augmentation settings here and on the same `seed` in the paired
# flow_from_directory calls below.
augmentation_args = dict(rotation_range = 90,
                         width_shift_range = 10,
                         height_shift_range = 10,
                         zoom_range = 0.3,
                         horizontal_flip = True,
                         vertical_flip = True,
                         fill_mode = 'constant',
                         cval = 0)

# Training generators (augmented).
image_data_gen = ImageDataGenerator(**augmentation_args,
                                    rescale = 1./255,
                                    validation_split = validation_split)
# Masks get their own generator and are rescaled by 1/255, like the images.
# NOTE(review): the 1/255 rescale assumes masks are stored as 0/255, so they
# become 0/1 labels after prepare_target's int cast. If masks already store
# class indices, drop `rescale` here and in `validation_mask_data_gen`.
mask_data_gen = ImageDataGenerator(**augmentation_args,
                                   rescale = 1./255,
                                   validation_split = validation_split)

# Validation generators: same `validation_split`, but no random augmentation,
# so validation metrics are measured on the unmodified images. The
# training/validation subsets are selected from `validation_split` alone,
# so these see the same files an augmenting generator would.
validation_image_data_gen = ImageDataGenerator(rescale = 1./255,
                                               validation_split = validation_split)
validation_mask_data_gen = ImageDataGenerator(rescale = 1./255,
                                              validation_split = validation_split)


def _flow(data_gen, directory, subset, color_mode = 'rgb', interpolation = 'bilinear'):
    """Return a shuffled, seeded batch iterator over one subset of `directory`.

    `class_mode=None` makes the iterator yield only image batches (no labels).
    Every call uses the same SEED, which keeps the image and mask iterators of
    a subset in lock-step (same file order, same random transforms).
    """
    return data_gen.flow_from_directory(directory,
                                        target_size = (img_h, img_w),
                                        batch_size = bs,
                                        class_mode = None,
                                        shuffle = True,
                                        interpolation = interpolation,
                                        color_mode = color_mode,
                                        seed = SEED,
                                        subset = subset)


train_dir = os.path.join(dataset_dir, "training")
train_img_dir = os.path.join(train_dir, "images")
train_mask_dir = os.path.join(train_dir, "masks")

# train generators
print("Training")
train_img_gen = _flow(image_data_gen, train_img_dir, 'training')
# 'nearest' resizing never blends neighbouring mask values into new ones.
train_mask_gen = _flow(mask_data_gen, train_mask_dir, 'training',
                       color_mode = 'grayscale', interpolation = 'nearest')
train_gen = zip(train_img_gen, train_mask_gen)

# validation generators
print("\nValidation")
validation_img_gen = _flow(validation_image_data_gen, train_img_dir, 'validation')
validation_mask_gen = _flow(validation_mask_data_gen, train_mask_dir, 'validation',
                            color_mode = 'grayscale', interpolation = 'nearest')
validation_gen = zip(validation_img_gen, validation_mask_gen)

# datasets

def prepare_target(x_, y_):
    """Convert the mask batch to int32 labels; pass the image batch through.

    Used with `tf.data.Dataset.map` over (image, mask) pairs. `tf.cast` from
    float to int truncates toward zero, so any fractional mask value is not
    rounded.
    """
    return x_, tf.cast(y_, tf.int32)

def _pairs_to_dataset(pair_gen):
    """Wrap an (image batch, mask batch) Python iterator as a repeating tf.data pipeline.

    Images are declared as (batch, img_h, img_w, 3) float32 and masks as
    (batch, img_h, img_w, 1) float32; masks are then cast by `prepare_target`.
    """
    image_shape = [None, img_h, img_w, 3]
    mask_shape = [None, img_h, img_w, 1]
    pairs = tf.data.Dataset.from_generator(lambda: pair_gen,
                                           output_types = (tf.float32, tf.float32),
                                           output_shapes = (image_shape, mask_shape))
    return pairs.map(prepare_target).repeat()


train_dataset = _pairs_to_dataset(train_gen)
validation_dataset = _pairs_to_dataset(validation_gen)

# write filenames to JSON file

# Record which files ended up in each split so the split can be inspected or
# reproduced later. `.filenames` are paths relative to the images directory
# (e.g. "img/<name>"); keep only the file name. os.path.basename also handles
# a platform-specific separator or a differently named subfolder, which
# stripping a literal "img/" prefix would not.
filenames = {
    "training": [os.path.basename(fn) for fn in train_img_gen.filenames],
    "validation": [os.path.basename(fn) for fn in validation_img_gen.filenames],
}

with open('dataset_split.json', 'w') as file:
    json.dump(filenames, file, indent=4)

    
# TODO: load test

Training
Found 6883 images belonging to 1 classes.
Found 6883 images belonging to 1 classes.

Validation
Found 764 images belonging to 1 classes.
Found 764 images belonging to 1 classes.
