## Resize and Scale Images and Annotations

In [None]:
IMG_SIZE = 256

@tf.function
def resize_and_scale(item):
    """Resize an image/annotation example to IMG_SIZE x IMG_SIZE and rescale the image.

    Args:
        item: dict with an ``'image'`` tensor and an ``'annotation'`` tensor
            (presumably HWC with pixel values in [0, 255] -- confirm against the dataset).

    Returns:
        (image, mask): ``image`` is the bilinearly resized image as float32 divided
        by 255; ``mask`` is the first annotation channel, resized with nearest
        neighbour so it keeps its original dtype and label values.
    """
    # Keep only the first annotation channel (the semantic mask). Nearest-neighbour
    # resizing treats channels independently, so slicing before resizing yields
    # the same result as slicing afterwards, with less work.
    mask = item['annotation'][..., :1]
    mask = tf.image.resize(
        mask, [IMG_SIZE, IMG_SIZE], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

    # Default (bilinear) resize already returns float32; the explicit cast just
    # documents the dtype before dividing by 255.
    image = tf.image.resize(item['image'], [IMG_SIZE, IMG_SIZE])
    image = tf.cast(image, tf.float32) / 255.0

    return image, mask

## Split Batched Images and Masks into Patches

In [None]:
def extract_patches(images, masks, patch_size):
    """Split a batch of images and their masks into non-overlapping square patches.

    Args:
        images: 4-D tensor ``(batch, height, width, image_channels)`` with static
            height, width and channel dimensions.
        masks: 4-D tensor ``(batch, height, width, mask_channels)`` with the same
            spatial size as ``images``; any channel count is accepted.
        patch_size: side length, in pixels, of each square patch.

    Returns:
        ``(image_patches, mask_patches)`` with shapes
        ``(batch, patch_count, patch_size, patch_size, image_channels)`` and
        ``(batch, patch_count, patch_size, patch_size, mask_channels)``, where
        ``patch_count = (height // patch_size) * (width // patch_size)``.
        Patches are ordered row by row; border pixels that do not fill a whole
        patch are dropped ('VALID' padding).
    """
    # Count rows and columns separately so non-square inputs are handled
    # (squaring height // patch_size is only right when height == width).
    patch_count = (images.shape[1] // patch_size) * (images.shape[2] // patch_size)
    sizes = [1, patch_size, patch_size, 1]
    # strides == sizes -> non-overlapping tiles.
    kwargs = dict(sizes=sizes, strides=sizes, rates=[1, 1, 1, 1], padding='VALID')

    def _to_patches(tensor):
        # Tile one 4-D tensor; each tile comes back flattened into the last axis
        # as (patch_size, patch_size, channels), so unflatten it using the
        # tensor's own channel count instead of a hard-coded value.
        patches = tf.image.extract_patches(tensor, **kwargs)
        return tf.reshape(
            patches, [-1, patch_count, patch_size, patch_size, tensor.shape[-1]])

    image_patches = _to_patches(images)
    mask_patches = _to_patches(masks)

    return image_patches, mask_patches

# Shape check: 16 three-channel images and two-channel masks cut into 2x2 patches.
# (A tiny, hand-checkable variant uses shapes (1, 4, 4, 3) and (1, 4, 4, 2).)
images = tf.reshape(tf.range(16 * 256 * 256 * 3), (16, 256, 256, 3))
masks = tf.reshape(tf.range(16 * 256 * 256 * 2), (16, 256, 256, 2))
image_patches, box_patches = extract_patches(images, masks, patch_size=2)

print(f'Image: {images.shape} --> Patches: {image_patches.shape}')
print(f'Box: {masks.shape} --> Patches: {box_patches.shape}')

## Split Batched Images into Patches

In [None]:
def batch_to_patch(images, patch_size):
    """Split a batch of images into non-overlapping square patches.

    Args:
        images: 4-D tensor ``(batch, height, width, channels)`` with static
            height, width and channel dimensions; any channel count is accepted.
        patch_size: side length, in pixels, of each square patch.

    Returns:
        Tensor of shape ``(batch, patch_count, patch_size, patch_size, channels)``
        with ``patch_count = (height // patch_size) * (width // patch_size)``.
        Patches are ordered row by row; border pixels that do not fill a whole
        patch are dropped ('VALID' padding).
    """
    # Rows and columns are counted separately so non-square images work.
    grid_rows = images.shape[1] // patch_size
    grid_cols = images.shape[2] // patch_size
    patch_count = grid_rows * grid_cols
    channels = images.shape[-1]
    sizes = [1, patch_size, patch_size, 1]

    # strides == sizes -> non-overlapping tiles; each tile is returned flattened
    # into the last axis as (patch_size, patch_size, channels).
    image_patches = tf.image.extract_patches(
        images, sizes=sizes, strides=sizes, rates=[1, 1, 1, 1], padding='VALID')
    image_patches = tf.reshape(
        image_patches, [-1, patch_count, patch_size, patch_size, channels])

    return image_patches

batch_size = 2
img_size = 16
patch_size = 4
input_shape = (batch_size, img_size, img_size, 3)

# Fill the batch with sequential values (0, 1, 2, ...) so patch contents are easy to trace.
images = tf.reshape(tf.range(reduce(operator.mul, input_shape)), input_shape)
image_patches = batch_to_patch(images, patch_size=patch_size)

print(f'Image: {images.shape} --> Patches: {image_patches.shape}')

## Image Standardization and Horizontal-Flip Augmentation

In [None]:
# Keras preprocessing block: per-image standardization (each image rescaled to
# zero mean / unit variance) followed by a random horizontal flip. Per the Keras
# docs, RandomFlip only flips when called with training=True.
augmentation_block = tf.keras.Sequential(name='augmentation_block')
augmentation_block.add(layers.Lambda(tf.image.per_image_standardization))
augmentation_block.add(layers.RandomFlip("horizontal"))

## Image and Mask Augmentations Using Albumentations

In [None]:
def get_augmenter(size, scale=(0.5, 1.0), flip_p=0.5):
    """Build an albumentations pipeline that random-crops/resizes and flips an image and its mask.

    Args:
        size: output height and width, in pixels.
        scale: ``(min, max)`` fraction of the input image area covered by the
            random crop before it is resized to ``size``. Must lie within (0, 1]:
            a crop cannot be larger than the image it is cut from.
        flip_p: probability of applying a horizontal flip.

    Returns:
        An ``A.Compose`` callable; call it as ``fn(image=..., mask=...)`` and read
        ``result['image']`` / ``result['mask']``.
    """
    # The previous default was scale=(0.5, 2.0). RandomResizedCrop samples a crop
    # area of scale * image_area and keeps it only if it fits inside the image, so
    # every draw above 1.0 was a wasted attempt that pushed toward the fallback
    # crop; newer albumentations versions may also reject values outside [0, 1].
    # Clamping to 1.0 keeps the effective zoom-in range while removing that.
    # NOTE(review): width=/height= are deprecated in favour of size=(h, w) in
    # recent albumentations releases -- confirm against the installed version.
    fn = A.Compose([
        A.RandomResizedCrop(width=size, height=size, scale=scale),
        A.HorizontalFlip(p=flip_p),
    ])

    return fn

# Toy example: a random 8x8 three-channel image and an int32 mask with labels in
# [0, classes). batch_size is not used in this cell.
batch_size = 2
img_size = 8
classes = 3
image = tf.random.normal((img_size, img_size, 3))
mask = tf.random.uniform((img_size, img_size, 1), maxval=classes, dtype=tf.int32)

# Albumentations works on NumPy arrays, so convert the eager tensors first.
transform_fn = get_augmenter(4)
result = transform_fn(image=image.numpy(), mask=mask.numpy())
t_image = result['image']
t_mask = result['mask']

print(f'Input Image: {image.shape} --> Transformed Image: {t_image.shape}')
print(f'Input Mask: {mask.shape} --> Transformed Mask: {t_mask.shape}')

# NOTE(review): arguments here are (image, t_image, mask, t_mask), while
# show_preprocessing_results passes (image, mask, p_image, p_mask). At most one
# order can match show_related_images' signature -- confirm against its definition.
show_related_images(image, t_image, mask, t_mask)

## Image and Mask Augmentation Pipeline Using Albumentations

In [None]:
IMG_SIZE, PATCH_SIZE = 256, 16

# One augmentation pipeline, built once and shared by every transform_fn call.
AUG_FN = get_augmenter(IMG_SIZE)

def transform_fn(image, mask):
    """Apply AUG_FN to one NumPy image/mask pair; return ``(image, mask)`` augmented."""
    augmented = AUG_FN(image=image, mask=mask)
    return tuple(augmented[key] for key in ('image', 'mask'))

@tf.function
def rescale_and_augment(item):
    """Normalize an image/annotation example and apply the albumentations augmentations.

    Args:
        item: dict with an ``'image'`` tensor (presumably HWC with values in
            [0, 255] -- confirm against the dataset) and an ``'annotation'``
            tensor whose first channel is kept as the mask.

    Returns:
        ``(image, mask)``: ``image`` is float32 (divided by 255, then augmented)
        of shape ``(IMG_SIZE, IMG_SIZE, channels)``; ``mask`` keeps the
        annotation's dtype and has shape ``(IMG_SIZE, IMG_SIZE, 1)``.
    """
    image, mask = item['image'], item['annotation']

    # Normalize the image and keep only the first mask channel.
    image = tf.cast(image, tf.float32) / 255.0
    mask = mask[..., :1]

    # Albumentations runs on NumPy arrays, so leave the graph via numpy_function.
    # Ask for the mask back in its own dtype rather than assuming uint8 (a
    # mismatched Tout raises at runtime).
    aug_image, aug_mask = tf.numpy_function(
        transform_fn, [image, mask], [tf.float32, mask.dtype])

    # numpy_function outputs have unknown static shape; restore what AUG_FN
    # produces (an IMG_SIZE x IMG_SIZE crop, channel count unchanged) so batching
    # and Keras layers see a defined shape. ensure_shape also checks it at runtime.
    aug_image = tf.ensure_shape(aug_image, [IMG_SIZE, IMG_SIZE, image.shape[-1]])
    aug_mask = tf.ensure_shape(aug_mask, [IMG_SIZE, IMG_SIZE, 1])

    return aug_image, aug_mask

# Sequential map: numpy_function bodies hold the Python GIL, so parallel calls
# would gain little here.
train_prep_ds = train_ds.map(rescale_and_augment, num_parallel_calls=None)

def show_preprocessing_results(ds, p_ds, title='Preprocessing Results'):
    """Print the shapes of, and plot, the first raw and first preprocessed examples.

    Args:
        ds: dataset of dicts with ``'image'`` and ``'annotation'`` entries.
        p_ds: dataset of ``(image, mask)`` pairs, presumably ``ds`` after preprocessing.
        title: title forwarded to ``show_related_images``.
    """
    # NOTE(review): ds and p_ds are iterated independently; if ds is shuffled,
    # the two "first" elements may be different examples -- confirm.
    raw_item = next(iter(ds))
    image = raw_item['image']
    mask = raw_item['annotation']

    p_image, p_mask = next(iter(p_ds))

    print(f'Image: {image.shape} Mask: {mask.shape} --> Image: {p_image.shape} Mask: {p_mask.shape}')

    show_related_images(image, mask, p_image, p_mask, title=title)

# Header followed by a side-by-side preview of raw vs. preprocessed training data.
print('Training Set Preprocessing', '--------------------------', sep='\n')
show_preprocessing_results(train_ds, train_prep_ds, title='Training Set')