In [2]:
from glob import glob
import os
import IPython.display as display
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import datetime
from tensorflow.keras.layers import *
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from tensorflow.keras.optimizers import Adam
from IPython.display import clear_output
from keras.preprocessing.image import ImageDataGenerator
import tensorflow_io as tfio

In [3]:
# important for reproducibility
seed = 42

# https://www.tensorflow.org/guide/data_performance#prefetching
AUTOTUNE = tf.data.experimental.AUTOTUNE

base_dir = os.path.abspath('..')
train_dir = os.path.join(base_dir,"data","train")
val_dir = os.path.join(base_dir,"data","test")

# # Image size that we are going to use
# IMG_SIZE = 128
# # Our images are RGB (3 channels)
# N_CHANNELS = 3
# # Number of classes + 1 for background
# N_CLASSES = 4

In [5]:
def parse_image(img_path: str) -> dict:
    
    """Load an image and its annotation (mask) and returning
    a dictionary.

    Parameters
    ----------
    img_path : str
        Image (not the mask) location.

    Returns
    -------
    dict
        Dictionary mapping an image and its ground truth
    """
    image = tf.io.read_file(img_path)
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.image.convert_image_dtype(image, tf.uint8)


    mask_path = tf.strings.regex_replace(img_path, "original_retinal_images", "masks_new")
    mask_path = tf.strings.regex_replace(mask_path, "jpg", "tif")
    
    mask = tf.io.read_file(mask_path)
    
    # The masks contain a class index for each pixels
    mask = tfio.experimental.image.decode_tiff(mask)
   

    return {'image': image, 'segmentation_mask': mask}



train_dataset = tf.data.Dataset.list_files(os.path.join(train_dir , 'original_retinal_images' , "*.jpg"), seed=seed)
train_dataset = train_dataset.map(parse_image)

val_dataset = tf.data.Dataset.list_files(os.path.join(val_dir , 'original_retinal_images' , "*.jpg"), seed=seed)
val_dataset = val_dataset.map(parse_image)

In [6]:
images = np.array([img['image'] for img in val_dataset.as_numpy_iterator()])


In [7]:
masks = np.array([img['segmentation_mask'] for img in val_dataset.as_numpy_iterator()],dtype='uint8')

In [8]:
images.shape

(27, 2848, 4288, 3)

In [9]:
masks.shape

(27, 2848, 4288, 4)

In [10]:
batch_size = 1
epochs = 50

# Training path
X_path= os.path.join(train_dir, 'original_retinal_images') # input image
Y_path = os.path.join(train_dir, 'masks_new') # ground-truth label

# Validation path
val_X_path = os.path.join(val_dir, 'original_retinal_images')
val_Y_path = os.path.join(val_dir, 'masks_new')

# Train data generator
x_gen_args = dict(
                        rescale=1./255,
                        #featurewise_center=True,
                        #featurewise_std_normalization=True,
                        #shear_range=0.2,
                        #zoom_range=0.5,
                        #channel_shift_range=?,
                        #width_shift_range=0.5,
                        #height_shift_range=0.5,
                        rotation_range = 10,
                        horizontal_flip=True
                    )
y_gen_args = dict(
                        #featurewise_center=True,
                        #featurewise_std_normalization=True,
                        #shear_range=0.2,
                        #zoom_range=0.5,
                        #channel_shift_range=?,
                        #width_shift_range=0.5,
                        #height_shift_range=0.5,
                        rotation_range = 10,
                        horizontal_flip=True
                    )

image_datagen = ImageDataGenerator(**x_gen_args)
mask_datagen = ImageDataGenerator(**y_gen_args)

image_datagen.fit(images, augment=True, seed=seed)
mask_datagen.fit(masks, augment=True, seed=seed)

image_generator = image_datagen.flow_from_directory(
    X_path,
    batch_size=batch_size,
    shuffle = True, # shuffle the training data
    class_mode=None, # set to None, in this case
    interpolation='nearest',
    seed=seed)

mask_generator = mask_datagen.flow_from_directory(
    Y_path,
    batch_size=batch_size,
    shuffle = True,
    class_mode=None,
    interpolation='nearest',
    seed=seed)





Found 0 images belonging to 0 classes.
Found 0 images belonging to 0 classes.


In [None]:
# combine image_ and mask_generator into one
train_generator = zip(image_generator, mask_generator)
num_train = len(image_generator)

# val data generator
image_datagen = ImageDataGenerator()
mask_datagen = ImageDataGenerator()

image_generator = image_datagen.flow_from_directory(
    val_X_path,
    target_size=(h, w),
    batch_size=batch_size,
    shuffle = False, # we dont need to shuffle validation set
    class_mode=None,
    seed=seed)

mask_generator = mask_datagen.flow_from_directory(
    val_Y_path,
    target_size=(h, w),
    batch_size=batch_size,
    shuffle = False,
    seed=seed)

val_generator = zip(image_generator, mask_generator)
num_val = len(image_generator)


In [None]:
# fit the generators
model.fit_generator(
                    train_generator,
                    steps_per_epoch = num_train/batch_size, 
                    validation_data=val_generator,
                    validation_steps =num_val/batch_size,
                    epochs=epochs,
                    verbose=1
                    )