# Text Transformations

# Vision Transformations

## Resize and Scale Images and Annotations

In [None]:
IMG_SIZE = 256

@tf.function
def resize_and_scale(item):
    """Resize an image/annotation pair to IMG_SIZE x IMG_SIZE and scale the image.

    Args:
        item: Mapping with an ``'image'`` tensor and an ``'annotation'`` mask
            tensor, each accepted by ``tf.image.resize`` (presumably
            ``(H, W, C)`` per example — TODO confirm against the dataset).

    Returns:
        ``(image, mask)``: the image as float32 divided by 255.0, and the mask
        as uint16 class labels, both resized to ``IMG_SIZE x IMG_SIZE``.
    """
    image, mask = item['image'], item['annotation']
    target_size = [IMG_SIZE, IMG_SIZE]

    # Pixel intensities can be interpolated smoothly (default bilinear).
    image = tf.image.resize(image, target_size)

    # Masks hold discrete labels: bilinear interpolation would blend adjacent
    # labels into new, wrong ones (e.g. between 1 and 5 -> 3) that rounding
    # cannot undo. Nearest-neighbor only ever copies existing labels.
    mask = tf.image.resize(mask, target_size, method='nearest')

    # Normalize the image to [0, 1] (assuming 0-255 input intensities).
    image = tf.cast(image, tf.float32) / 255.0

    # Nearest-neighbor keeps the input dtype; round float masks so the cast to
    # an integer type does not truncate values like 0.9999 down to 0.
    if mask.dtype.is_floating:
        mask = tf.math.round(mask)
    mask = tf.cast(mask, tf.uint16)

    return image, mask

## Split Batches of Images and Masks into Patches

In [None]:
def extract_patches(images, masks, patch_size):
    """Split batched images and masks into non-overlapping square patches.

    Args:
        images: Rank-4 tensor ``(batch, height, width, channels)`` whose height,
            width and channel dims are statically known.
        masks: Rank-4 tensor laid out like ``images``; it may have a different
            channel count (and spatial size) than ``images``.
        patch_size: Side length, in pixels, of each square patch.

    Returns:
        ``(image_patches, box_patches)``, each shaped
        ``(batch, (height // patch_size) * (width // patch_size),
        patch_size, patch_size, channels)`` for its own input. Border pixels
        that do not fill a whole patch are dropped (``'VALID'`` padding).
    """
    sizes = [1, patch_size, patch_size, 1]
    kwargs = dict(sizes=sizes, strides=sizes, rates=[1, 1, 1, 1], padding='VALID')

    def to_patches(batch):
        # Derive the patch grid and channel count from this tensor's own shape
        # rather than assuming square inputs or a fixed number of channels.
        grid_rows = batch.shape[1] // patch_size
        grid_cols = batch.shape[2] // patch_size
        channels = batch.shape[-1]
        # extract_patches flattens each patch into the last axis in
        # (patch_size, patch_size, channels) order, so this reshape restores it.
        patches = tf.image.extract_patches(batch, **kwargs)
        return tf.reshape(
            patches, [-1, grid_rows * grid_cols, patch_size, patch_size, channels])

    image_patches = to_patches(images)
    box_patches = to_patches(masks)

    return image_patches, box_patches

# Synthetic batch: 16 three-channel 256x256 images plus matching two-channel
# masks, filled with consecutive integers so patch contents are traceable.
images = tf.reshape(tf.range(16*256*256*3), (16, 256, 256, 3))
masks = tf.reshape(tf.range(16*256*256*2), (16, 256, 256, 2))
# For a hand-checkable run, swap in tiny (1, 4, 4, 3) / (1, 4, 4, 2) inputs.

image_patches, box_patches = extract_patches(images, masks, patch_size=2)

print('Image: {} --> Patches: {}'.format(images.shape, image_patches.shape))
print('Box: {} --> Patches: {}'.format(masks.shape, box_patches.shape))

## Split a Batch of Images into Patches

In [None]:
def batch_to_patch(images, patch_size):
    """Split a batch of images into non-overlapping square patches.

    Args:
        images: Rank-4 tensor ``(batch, height, width, channels)`` whose height,
            width and channel dims are statically known.
        patch_size: Side length, in pixels, of each square patch.

    Returns:
        Tensor shaped ``(batch, (height // patch_size) * (width // patch_size),
        patch_size, patch_size, channels)``. Border pixels that do not fill a
        whole patch are dropped (``'VALID'`` padding).
    """
    # Use rows x cols and the real channel count instead of assuming square,
    # 3-channel input.
    grid_rows = images.shape[1] // patch_size
    grid_cols = images.shape[2] // patch_size
    channels = images.shape[-1]
    sizes = [1, patch_size, patch_size, 1]
    kwargs = dict(sizes=sizes, strides=sizes, rates=[1, 1, 1, 1], padding='VALID')

    # extract_patches flattens each patch into the last axis in
    # (patch_size, patch_size, channels) order; reshape restores that layout.
    image_patches = tf.image.extract_patches(images, **kwargs)
    image_patches = tf.reshape(
        image_patches, [-1, grid_rows * grid_cols, patch_size, patch_size, channels])

    return image_patches

# Small synthetic batch of consecutive integers for eyeballing the patch layout.
batch_size = 2
img_size = 16
patch_size = 4
input_shape = (batch_size, img_size, img_size, 3)
num_elements = reduce(operator.mul, input_shape)
images = tf.reshape(tf.range(num_elements), input_shape)

image_patches = batch_to_patch(images, patch_size=patch_size)
print('Image: {} --> Patches: {}'.format(images.shape, image_patches.shape))