# Global Greening

## Installing & Importing Libraries

In [None]:
import pickle
import numpy as np
import pandas as pd
from PIL import Image
from patchify import patchify
import albumentations as A
from IPython.display import SVG
import graphviz
import matplotlib.pyplot as plt
%matplotlib inline
import os, re, sys, random, shutil, cv2

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import backend as K
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam, Nadam
from tensorflow.keras import applications, optimizers
from tensorflow.keras.applications import InceptionResNetV2
from tensorflow.keras.applications.resnet50 import preprocess_input

from tensorflow.keras.preprocessing.image import ImageDataGenerator, load_img, img_to_array
from tensorflow.keras.utils import model_to_dot, plot_model
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, EarlyStopping, CSVLogger, LearningRateScheduler
from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, Activation, MaxPool2D, Conv2DTranspose, Concatenate, ZeroPadding2D, Dropout

from sklearn.preprocessing import MinMaxScaler

## Prepare Data Augmentation

**Augmentation using Albumentations Library**

[Albumentations](https://albumentations.ai/) is a Python library for fast and flexible image augmentations. Albumentations efficiently implements a rich variety of image transform operations that are optimized for performance, and does so while providing a concise, yet powerful image augmentation interface for different computer vision tasks, including object classification, segmentation, and detection.

Data augmentation is done by the following techniques:

1. Random Cropping - left out since we will have same size pictures
2. Horizontal Flipping
3. Vertical Flipping
4. Rotation
5. Random Brightness & Contrast
6. Contrast Limited Adaptive Histogram Equalization (CLAHE)
7. Grid Distortion
8. Optical Distortion

In [None]:
# function to augment
def augment():
    """Build the Albumentations pipeline used for every image/mask pair.

    All top-level steps are applied with p=1.0: horizontal flip, vertical
    flip, rotation (limit [60, 300], nearest-neighbour interpolation),
    random brightness/contrast, and finally a ``OneOf`` group containing
    CLAHE, grid distortion and optical distortion.

    Returns:
        The composed ``A.Compose`` transform; call it as
        ``transform(image=..., mask=...)``.
    """
    # A.RandomCrop(width=width, height=height, p=1.0) is deliberately left
    # out: all pictures already share the same size.
    flips_and_rotation = [
        A.HorizontalFlip(p=1.0),
        A.VerticalFlip(p=1.0),
        A.Rotate(limit=[60, 300], p=1.0, interpolation=cv2.INTER_NEAREST),
    ]
    brightness_contrast = A.RandomBrightnessContrast(
        brightness_limit=[-0.2, 0.3],
        contrast_limit=0.2,
        p=1.0,
    )
    contrast_or_distortion = A.OneOf(
        [
            A.CLAHE(clip_limit=1.5, tile_grid_size=(8, 8), p=0.5),
            A.GridDistortion(p=0.5),
            A.OpticalDistortion(distort_limit=1, shift_limit=0.5, interpolation=cv2.INTER_NEAREST, p=0.5),
        ],
        p=1.0,
    )

    pipeline = [*flips_and_rotation, brightness_contrast, contrast_or_distortion]
    return A.Compose(pipeline, p=1.0)

In [None]:
# visualize the augmentations

def visualize(image, mask, original_image=None, original_mask=None):
    """Plot an image/mask pair and save the figure to disk.

    Args:
        image: image to show (the transformed one when originals are given).
        mask: mask belonging to ``image``.
        original_image: optional untransformed image.
        original_mask: optional untransformed mask.

    When both originals are ``None`` a 2x1 grid (image above mask) is drawn.
    Otherwise a 2x2 grid is drawn with the originals in the left column and
    the transformed pair in the right column. In both cases the figure is
    saved as ``sample_augmented_image.png``.
    """
    title_size = 16
    originals_given = not (original_image is None and original_mask is None)

    if originals_given:
        _, axes = plt.subplots(2, 2, figsize=(16, 12))
        panels = [
            (axes[0, 0], original_image, 'Original Image'),
            (axes[1, 0], original_mask, 'Original Mask'),
            (axes[0, 1], image, 'Transformed Image'),
            (axes[1, 1], mask, 'Transformed Mask'),
        ]
        for axis, picture, title in panels:
            axis.imshow(picture)
            axis.set_title(title, fontsize=title_size)
    else:
        _, axes = plt.subplots(2, 1, figsize=(10, 10))
        axes[0].imshow(image)
        axes[1].imshow(mask)

    plt.savefig(
        'sample_augmented_image.png',
        facecolor='w',
        transparent=False,
        bbox_inches='tight',
        dpi=100,
    )

## Loading the Data

In [None]:
# check where we are
!pwd

In [None]:
# load the data
dataset_root_folder = '/Users/Alenka/code/Alastair908/Downloads'
dataset_name = 'Dubai_data'

In [None]:
# Load every image/mask pair (8 tiles x 9 parts) with Pillow instead of cv2.
# The image and its mask are read in the same iteration, so index i of
# images_dataset always belongs to index i of masks_dataset.

TARGET_SIZE = (512, 512)

images_dataset = []
masks_dataset = []

for tile_id in range(1, 9):
    tile_folder = f'{dataset_root_folder}/{dataset_name}/Tile {tile_id}'
    for image_id in range(1, 10):
        path_image = f'{tile_folder}/images/image_part_{image_id:03d}.jpg'
        path_mask = f'{tile_folder}/masks/image_part_{image_id:03d}.png'
        print(path_image)
        print(path_mask)

        # The context manager closes the file handle once the pixels are copied.
        with Image.open(path_image) as image_file:
            images_dataset.append(np.array(image_file.resize(TARGET_SIZE)))

        with Image.open(path_mask) as mask_file:
            # Masks store class labels as exact RGB colors. Pillow's default
            # resampling filter interpolates between neighbouring pixels and
            # would create blended colors that match no class, so convert to
            # RGB first and resize with nearest-neighbour sampling.
            resized_mask = mask_file.convert('RGB').resize(TARGET_SIZE, resample=Image.NEAREST)
            masks_dataset.append(np.array(resized_mask))

In [None]:
len(images_dataset), len(masks_dataset)

In [None]:
image = images_dataset[0] 
mask = masks_dataset[0] 

print(image.shape, mask.shape, type(image))

In [None]:
f, ax = plt.subplots(1, 2, figsize=(6, 6)) 
ax[0].imshow(images_dataset[0])
ax[1].imshow(masks_dataset[0])

In [None]:
# this piece of code causing trouble
# visualize(image, mask)

## Image masks

The images are densely labeled and contain the following 6 classes:

| Name       | R   | G   | B   | Color                                                                                              |
| ---------- | --- | --- | --- | -------------------------------------------------------------------------------------------------- |
| Building   | 60  | 16  | 152 | <p align="center"><div style="background-color: rgb(60, 16, 152); padding: 10px; "/></p>   |
| Land       | 132 | 41  | 246 | <p align="center"><div style="background-color: rgb(132, 41, 246); padding: 10px; "/></p>   |
| Road       | 110 | 193 | 228 | <p align="center"><div style="background-color: rgb(110, 193, 228); padding: 10px; "/></p>   |
| Vegetation | 254 | 221 | 58  | <p align="center"><div style="background-color: rgb(254, 221, 58); padding: 10px; "/></p>   |
| Water      | 226 | 169 | 41  | <p align="center"><div style="background-color: rgb(226, 169, 41); padding: 10px; "/></p>   |
| Unlabeled  | 155 | 155 | 155 | <p align="center"><div style="background-color: rgb(155, 155, 155); padding: 10px; "/></p>   |


## Perform Augmentation

In [None]:
transform = augment()
transformed = transform(image=image, mask=mask)
transformed_image = transformed['image']
transformed_mask = transformed['mask']

visualize(transformed_image, transformed_mask, image, mask)

In [None]:
def augment_dataset(count, images=None, masks=None):
    '''Create augmented copies of an image/mask dataset.

        Input:
            count - number of augmentation passes; the result holds
                    len(images) * count augmented pairs
            images - optional sequence of image arrays; defaults to the
                     notebook-level images_dataset
            masks - optional sequence of mask arrays paired with images by
                    index; defaults to the notebook-level masks_dataset
        Output:
            (aug_images_dataset, aug_masks_dataset) - two lists ordered pass
            by pass, i.e. element [p * len(images) + j] is the p-th
            augmentation of pair j. Nothing is written to disk.
        Raises:
            ValueError - if images and masks have different lengths
    '''
    if images is None:
        images = images_dataset
    if masks is None:
        masks = masks_dataset
    if len(images) != len(masks):
        raise ValueError(
            f'images and masks must have the same length, got {len(images)} and {len(masks)}'
        )

    transform = augment()
    aug_images_dataset = []
    aug_masks_dataset = []

    for _ in range(count):
        for img, msk in zip(images, masks):
            # Image and mask go through the same call so they receive the
            # same transformation.
            transformed = transform(image=img, mask=msk)
            aug_images_dataset.append(transformed['image'])
            aug_masks_dataset.append(transformed['mask'])
    return aug_images_dataset, aug_masks_dataset

In [None]:
aug_images_dataset, aug_masks_dataset  = augment_dataset(8)

In [None]:
len(images_dataset), len(masks_dataset), len(aug_images_dataset), len(aug_masks_dataset)

In [None]:
# Collect the array shape of every loaded image and display the
# (lexicographically) smallest and largest shape as a sanity check.
image_sizes = [img.shape for img in images_dataset]

min(image_sizes), max(image_sizes)

In [None]:
# Collect the array shape of every loaded mask and display the
# (lexicographically) smallest and largest shape as a sanity check.
mask_sizes = [msk.shape for msk in masks_dataset]

min(mask_sizes), max(mask_sizes)

In [None]:
# Show the first original image/mask and its first 3 augmented versions.

# augment_dataset appends one full pass over the originals at a time, so the
# p-th augmentation of image 0 lives at index p * n_originals. Deriving the
# offset from the data keeps the plot correct if the dataset size changes.
n_originals = len(images_dataset)

f, ax = plt.subplots(2, 4, figsize=(12, 6))

ax[0, 0].imshow(images_dataset[0])
ax[0, 1].imshow(masks_dataset[0])

# (row, column) of the image panel for augmentation pass 0, 1 and 2;
# the matching mask is drawn in the column to its right.
panel_slots = [(0, 2), (1, 0), (1, 2)]
for aug_pass, (row, col) in enumerate(panel_slots):
    aug_index = aug_pass * n_originals
    ax[row, col].imshow(aug_images_dataset[aug_index])
    ax[row, col + 1].imshow(aug_masks_dataset[aug_index])


## Preparing labels

In [None]:
#labels_dict = {"classes": [{"title": "Water", "shape": "polygon", "color": "#50E3C2", "geometry_config": {}}, {"title": "Land (unpaved area)", "shape": "polygon", "color": "#F5A623", "geometry_config": {}}, {"title": "Road", "shape": "polygon", "color": "#DE597F", "geometry_config": {}}, {"title": "Building", "shape": "polygon", "color": "#D0021B", "geometry_config": {}}, {"title": "Vegetation", "shape": "polygon", "color": "#417505", "geometry_config": {}}, {"title": "Unlabeled", "shape": "polygon", "color": "#9B9B9B", "geometry_config": {}}]}
labels_dict = {"classes": [{"title": "Water", "r": 226, "g": 169, "b": 41 }, 
                           {"title": "Land", "r": 132, "g": 41, "b": 246 }, 
                           {"title": "Road", "r": 110, "g": 193, "b": 228 }, 
                           {"title": "Building", "r": 60, "g": 16, "b": 152 }, 
                           {"title": "Vegetation", "r": 254, "g": 221, "b": 58 }, 
                           {"title": "Unlabeled", "r": 155, "g": 155, "b": 155 }]}

labels_dict_df = pd.DataFrame(labels_dict['classes'])
labels_dict_df

In [None]:
# Class names and their (r, g, b) color codes, both in the row order of
# labels_dict_df.
label_names = list(labels_dict_df['title'])

r, g, b = (np.asarray(labels_dict_df[channel]) for channel in ('r', 'g', 'b'))
label_codes = list(zip(r, g, b))

label_codes, label_names

In [None]:
code2id = {v:k for k,v in enumerate(label_codes)}
id2code = {k:v for k,v in enumerate(label_codes)}

name2id = {v:k for k,v in enumerate(label_names)}
id2name = {k:v for k,v in enumerate(label_names)}

In [None]:
id2code

In [None]:
id2name

## Function to One-hot Encode RGB Labels/Masks and Decoding Encoded Predictions

In [None]:
def rgb_to_onehot(rgb_mask_image, colormap=None):
    '''Function to one hot encode RGB mask labels
        Inputs:
            rgb_mask_image - mask matrix (eg. 256 x 256 x 3 dimension numpy ndarray)
            colormap - dictionary mapping class id (0 .. num_classes-1) to its
                       (r, g, b) color; defaults to the notebook-level id2code,
                       looked up at call time so a redefined id2code is honoured
        Output: One hot encoded mask of dimensions (height x width x num_classes),
                dtype int8, where num_classes = len(colormap). Channel k is 1
                where the pixel equals colormap[k]; a pixel whose color matches
                no class is 0 in every channel.
    '''
    if colormap is None:
        colormap = id2code

    rgb_mask_image = np.asarray(rgb_mask_image)
    num_classes = len(colormap)
    # same height/width as the mask, one channel per class instead of 3 RGB channels
    shape = rgb_mask_image.shape[:2] + (num_classes,)
    encoded_mask = np.zeros(shape, dtype=np.int8)

    for class_id, color in colormap.items():
        # a pixel belongs to the class only when all three channels match its color
        encoded_mask[:, :, class_id] = np.all(rgb_mask_image == np.asarray(color), axis=-1)

    return encoded_mask

In [None]:
def onehot_to_rgb(onehot, colormap=None):
    '''Function to decode encoded mask labels
        Inputs:
            onehot - one hot encoded (or per-class score) matrix
                     (height x width x num_classes)
            colormap - dictionary mapping class id to its (r, g, b) color;
                       defaults to the notebook-level id2code, looked up at
                       call time so a redefined id2code is honoured
        Output: Decoded RGB image (height x width x 3), dtype uint8. Each pixel
                gets the color of the channel with the highest value.
    '''
    if colormap is None:
        colormap = id2code

    # index of the strongest channel per pixel
    single_layer = np.argmax(onehot, axis=-1)
    output = np.zeros(onehot.shape[:2] + (3,), dtype=np.uint8)
    for class_id, color in colormap.items():
        output[single_layer == class_id] = color
    return output

In [None]:
# checking that it works
print(f'mask shape is RGB image {mask.shape}')
encoded_mask = rgb_to_onehot(mask, colormap = id2code)
decoded_mask = onehot_to_rgb(encoded_mask, colormap = id2code)
plt.imshow(decoded_mask);
print(f'encoded mask is 6 channel array {encoded_mask.shape}')

**Input on loading and preprocessing Images**

Deprecated: tf.keras.preprocessing.image.ImageDataGenerator is not recommended for new code. Prefer loading images with tf.keras.utils.image_dataset_from_directory and transforming the output tf.data.Dataset with preprocessing layers. For more information, see the tutorials for loading images and augmenting images, as well as the preprocessing layer guide.


we will use function [`tf.keras.utils.image_dataset_from_directory`](https://www.tensorflow.org/api_docs/python/tf/keras/utils/image_dataset_from_directory)

and resize with [keras.layers.resizing](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Resizing)

Note: in this notebook the resizing was already done at load time with Pillow (`PIL.Image.resize`).

## Model

In [None]:
# files to use
len(images_dataset), len(masks_dataset), len(aug_images_dataset), len(aug_masks_dataset)

In [None]:
# files to use
images_dataset[0].shape, masks_dataset[0].shape, aug_images_dataset[0].shape, aug_masks_dataset[0].shape

In [None]:
#generate list of images and masks
image_full_dataset = images_dataset + aug_images_dataset
image_full_dataset_np = np.array(image_full_dataset)

masks_full_dataset = masks_dataset + aug_masks_dataset
#masks_full_dataset_np = np.array(masks_full_dataset)

#### Problem solving - delete later

In [None]:
image_full_dataset = images_dataset + aug_images_dataset
len(image_full_dataset), type(image_full_dataset), type(image_full_dataset[0]), image_full_dataset[0].dtype

In [None]:
masks_full_dataset = masks_dataset + aug_masks_dataset
len(masks_full_dataset ), type(masks_full_dataset ), type(masks_full_dataset[0]), masks_full_dataset[0].dtype

In [None]:
image_full_dataset_np = np.array(image_full_dataset)
len(image_full_dataset_np), type(image_full_dataset_np), type(image_full_dataset_np[0]), image_full_dataset_np[0].dtype

In [None]:
masks_full_dataset_np = np.array(masks_full_dataset)
#len(masks_full_dataset_), type(masks_full_dataset ), type(masks_full_dataset[0]), masks_full_dataset[0].dtype

In [None]:
masks_full_dataset

In [None]:
len(image_full_dataset), image_full_dataset[0].shape, image_full_dataset_np.shape

In [None]:
type(masks_full_dataset)
len(masks_full_dataset), masks_full_dataset[0].shape

In [None]:
image_full_dataset[0]

In [None]:
plt.imshow(masks_full_dataset[0])

In [None]:
# Debug check: encode two masks and compare RGB shape with one-hot shape.
mask1 = masks_dataset[0]
mask2 = masks_dataset[10]
encoded_mask1 = rgb_to_onehot(mask1)
encoded_mask2 = rgb_to_onehot(mask2)
# Each mask is paired with its own encoding, and both lines are printed
# (a bare expression is only displayed when it is the last line of a cell).
print(mask1.shape, encoded_mask1.shape)
print(mask2.shape, encoded_mask2.shape)

#### Back to modelling

In [None]:
# One-hot encode every mask (originals followed by augmentations).
# A comprehension keeps the loop variable local, so the notebook-level `mask`
# and `encoded_mask` used by earlier cells are not overwritten when this
# cell runs.
encoded_masks = [rgb_to_onehot(rgb_mask) for rgb_mask in masks_full_dataset]

In [None]:
# y: stacked one-hot masks; X: stacked images scaled to [0, 1].
y = np.array(encoded_masks)
# float32 is enough for pixel values scaled to [0, 1] and needs half the
# memory of the float64 array that dividing a uint8 array would create.
X = np.array(image_full_dataset, dtype=np.float32) / 255.
len(y), len(X)

In [None]:
X[0].shape

In [None]:
plt.imshow(X[1])

In [None]:
# Preparing X(images) and y(labels) - to be added to load images later

# Split by SOURCE image, not by sample. X holds the originals followed by
# whole augmentation passes over them, so sample i comes from original image
# i % n_sources. Shuffling all samples before splitting would put augmented
# copies of test images into the training data and inflate the test scores.
n_sources = len(images_dataset)
source_ids = np.arange(len(X)) % n_sources

# one sixth of the source images (with all their augmentations) form the
# test set, the rest is used for training/validation
shuffled_sources = np.random.permutation(n_sources)
test_sources = shuffled_sources[:n_sources // 6]
in_test = np.isin(source_ids, test_sources)

# shuffle inside each split so batches mix originals and augmentations
test_idx = np.random.permutation(np.flatnonzero(in_test))
train_val_idx = np.random.permutation(np.flatnonzero(~in_test))

X_test, X_train_val = X[test_idx], X[train_val_idx]
y_test, y_train_val = y[test_idx], y[train_val_idx]

### InceptionResNetV2 UNet

In [None]:
def conv_block(input, num_filters):
    """Apply two consecutive Conv2D(3x3, same) -> BatchNorm -> ReLU stages.

    Args:
        input: tensor to process.
        num_filters: number of filters used by both convolutions.

    Returns:
        The output tensor of the second ReLU activation.
    """
    x = input
    for _ in range(2):
        x = Conv2D(num_filters, 3, padding="same")(x)
        x = BatchNormalization()(x)
        x = Activation("relu")(x)
    return x

def decoder_block(input, skip_features, num_filters):
    """Upsample ``input`` by 2, merge it with ``skip_features`` and refine.

    Args:
        input: tensor coming from the previous (coarser) stage.
        skip_features: encoder tensor concatenated after upsampling.
        num_filters: filters for the transposed convolution and conv block.

    Returns:
        The output tensor of ``conv_block`` applied to the merged features.
    """
    upsampled = Conv2DTranspose(num_filters, (2, 2), strides=2, padding="same")(input)
    merged = Concatenate()([upsampled, skip_features])
    return conv_block(merged, num_filters)

def build_inception_resnetv2_unet(input_shape, num_classes=6):
    """Build a U-Net whose encoder is an ImageNet-pretrained InceptionResNetV2.

    Args:
        input_shape: shape of one input image, e.g. (512, 512, 3). The size
            comments below assume 512 x 512 inputs.
        num_classes: number of output channels of the final softmax
            convolution (default 6, the number of mask classes).

    Returns:
        A Keras ``Model`` named "InceptionResNetV2-UNet" mapping an image to
        per-pixel class probabilities.
    """
    # Input
    inputs = Input(input_shape)

    # Pre-trained InceptionResNetV2 model
    encoder = InceptionResNetV2(include_top=False, weights="imagenet", input_tensor=inputs)

    # Encoder skip connections.
    # NOTE: the "activation*" names are auto-generated by Keras, so they are
    # only valid when the model is built in a fresh session
    # (call K.clear_session() first).
    # The first skip is the input tensor itself; using it directly avoids
    # depending on the auto-generated input layer name.
    s1 = inputs                                         ## (512 x 512)

    s2 = encoder.get_layer("activation").output        ## (255 x 255)
    s2 = ZeroPadding2D(( (1, 0), (1, 0) ))(s2)         ## (256 x 256)

    s3 = encoder.get_layer("activation_3").output      ## (126 x 126)
    s3 = ZeroPadding2D((1, 1))(s3)                     ## (128 x 128)

    s4 = encoder.get_layer("activation_74").output      ## (61 x 61)
    s4 = ZeroPadding2D(( (2, 1),(2, 1) ))(s4)           ## (64 x 64)

    # Bridge
    b1 = encoder.get_layer("activation_161").output     ## (30 x 30)
    b1 = ZeroPadding2D((1, 1))(b1)                      ## (32 x 32)

    # Decoder
    d1 = decoder_block(b1, s4, 512)                     ## (64 x 64)
    d2 = decoder_block(d1, s3, 256)                     ## (128 x 128)
    d3 = decoder_block(d2, s2, 128)                     ## (256 x 256)
    d4 = decoder_block(d3, s1, 64)                      ## (512 x 512)

    # Output: one softmax probability per class and pixel
    dropout = Dropout(0.3)(d4)
    outputs = Conv2D(num_classes, 1, padding="same", activation="softmax")(dropout)

    model = Model(inputs, outputs, name="InceptionResNetV2-UNet")
    return model

In [None]:
# Start from a fresh Keras session before building the model.
K.clear_session()


def dice_coef(y_true, y_pred):
    """Smoothed Dice coefficient: (2 * |A n B| + 1) / (|A| + |B| + 1).

    Dice is related to, but not the same as, IoU: it divides twice the
    overlap by the SUM of both areas rather than by their union. The sums run
    over every element of the batch (all pixels and classes at once). The
    constant 1 is a smoothing term that avoids division by zero.
    """
    # Masks are stored as integers; cast so the product with the float
    # predictions is well defined when the function is called directly.
    y_true = K.cast(y_true, y_pred.dtype)
    intersection = K.sum(y_true * y_pred)
    return (2. * intersection + 1.) / (K.sum(y_true) + K.sum(y_pred) + 1.)


model = build_inception_resnetv2_unet(input_shape = (512, 512, 3))
# `learning_rate` is the supported argument name; `lr` is a deprecated alias.
model.compile(optimizer=Adam(learning_rate=0.0001), loss='categorical_crossentropy', metrics=[dice_coef, "accuracy"])
model.summary()

In [None]:
# graph the model
SVG(model_to_dot(model).create(prog='dot', format='svg'))
plot_model(model, to_file='model.png', show_shapes=True, show_layer_names=True, expand_nested=True)

### Modelling

In [None]:
def exponential_decay(lr0, s):
    """Create an epoch -> learning-rate schedule with exponential decay.

    Args:
        lr0: learning rate at epoch 0.
        s: number of epochs over which the rate shrinks by a factor of 10.

    Returns:
        A function taking the epoch index and returning
        ``lr0 * 0.1 ** (epoch / s)``.
    """
    def schedule(epoch):
        decay_factor = 0.1 ** (epoch / s)
        return lr0 * decay_factor

    return schedule

# Learning-rate schedule: starts at 1e-4 and shrinks by a factor of 10 over
# every 60 epochs.
exponential_decay_fn = exponential_decay(0.0001, 60)
lr_scheduler = LearningRateScheduler(exponential_decay_fn, verbose=1)

# Save the model to disk, keeping only the best one according to val_loss.
# (save_weights_only is left at its default.)
checkpoint = ModelCheckpoint(
    filepath='InceptionResNetV2-UNet.h5',
    monitor='val_loss',
    mode='auto',
    save_best_only=True,
    verbose=1,
)

# Stop when val_loss has not improved by at least 0.001 for 12 epochs and
# roll back to the best weights seen.
earlystop = EarlyStopping(
    monitor='val_loss',
    mode='auto',
    min_delta=0.001,
    patience=12,
    restore_best_weights=True,
    verbose=1,
)

# Write the per-epoch metrics to a CSV file, overwriting any previous log.
csvlogger = CSVLogger(filename="model_training.csv", separator=",", append=False)

callbacks = [checkpoint, earlystop, csvlogger, lr_scheduler]

In [None]:
batch_size = 16

# Number of batches per epoch for an 80/20 train/validation split of
# X_train_val (printed for information).
n_train_val = len(X_train_val)

steps_per_epoch = np.ceil(n_train_val * 0.8 / batch_size)
print('steps_per_epoch: ', steps_per_epoch)

validation_steps = np.ceil(n_train_val * 0.2 / batch_size)
print('validation_steps: ', validation_steps)

In [None]:
history = model.fit(
    X_train_val, 
    y_train_val,
    batch_size=batch_size,
    validation_split = 0.2, 
    epochs = 50,
    callbacks=callbacks, 
    verbose=1
)

In [None]:
df_result = pd.DataFrame(history.history)
df_result

In [None]:
# load history from csv file
history_saved = pd.read_csv("model_training.csv")
history_saved

In [None]:
# Plot the training curves from the saved CSV log (history_saved).

fig, ax = plt.subplots(1, 4, figsize=(40, 10))
ax = ax.ravel()

# column in the log -> label used for title and y axis
metric_labels = {
    'dice_coef': 'Dice Coefficient',
    'accuracy': 'Accuracy',
    'loss': 'Loss',
    'lr': 'Learning Rate',
}

# Derive the ticks from the number of logged epochs instead of a fixed
# range, so runs of any length are labelled correctly.
epoch_ticks = np.arange(0, len(history_saved), 4)

for axis, (met, label) in zip(ax, metric_labels.items()):
    if met not in history_saved.columns:
        # e.g. no 'lr' column in the log: hide the panel instead of
        # leaving an empty frame or raising a KeyError
        axis.set_visible(False)
        continue

    axis.plot(history_saved[met])
    val_column = 'val_' + met
    if val_column in history_saved.columns:
        # the learning rate has no validation counterpart
        axis.plot(history_saved[val_column])
        axis.legend(['Train', 'Validation'])

    axis.set_title('{} vs Epochs'.format(label), fontsize=16)
    axis.set_xlabel('Epochs')
    axis.set_ylabel(label)
    axis.set_xticks(epoch_ticks)
    axis.xaxis.grid(True, color="lightgray", linewidth=0.8, linestyle="-")
    axis.yaxis.grid(True, color="lightgray", linewidth=0.8, linestyle="-")

plt.savefig('model_metrics_plot.png', facecolor= 'w',transparent= False, bbox_inches= 'tight', dpi= 150)

In [None]:
# # this portion not working as we dont have history
# fig, ax = plt.subplots(1, 4, figsize=(40, 5))
# ax = ax.ravel()
# metrics = ['Dice Coefficient', 'Accuracy', 'Loss', 'Learning Rate']

# for i, met in enumerate(['dice_coef', 'accuracy', 'loss', 'lr']): 
#     if met != 'lr':
#         ax[i].plot(history.history[met])
#         ax[i].plot(history.history['val_' + met])
#         ax[i].set_title('{} vs Epochs'.format(metrics[i]), fontsize=16)
#         ax[i].set_xlabel('Epochs')
#         ax[i].set_ylabel(metrics[i])
#         ax[i].set_xticks(np.arange(0,45,4))
#         ax[i].legend(['Train', 'Validation'])
#         ax[i].xaxis.grid(True, color = "lightgray", linewidth = "0.8", linestyle = "-")
#         ax[i].yaxis.grid(True, color = "lightgray", linewidth = "0.8", linestyle = "-")
#     else:
#         ax[i].plot(history.history[met])
#         ax[i].set_title('{} vs Epochs'.format(metrics[i]), fontsize=16)
#         ax[i].set_xlabel('Epochs')
#         ax[i].set_ylabel(metrics[i])
#         ax[i].set_xticks(np.arange(0,45,4))
#         ax[i].xaxis.grid(True, color = "lightgray", linewidth = "0.8", linestyle = "-")
#         ax[i].yaxis.grid(True, color = "lightgray", linewidth = "0.8", linestyle = "-")
        
# plt.savefig('model_metrics_plot.png', facecolor= 'w',transparent= False, bbox_inches= 'tight', dpi= 150)

In [None]:
model.load_weights("./InceptionResNetV2-UNet.h5")

In [None]:
!mkdir predictions

In [None]:
# Predict the test set and save one comparison figure per test sample:
# input image | ground-truth mask | predicted mask.

# Create the output folder here so the cell does not depend on a separate
# shell command (and does not fail when the folder already exists).
os.makedirs('predictions', exist_ok=True)

pred_all = model.predict(X_test)
title_font = {'fontsize': 16, 'fontweight': 'medium'}

samples = zip(X_test, y_test, pred_all)
for count, (input_image, true_onehot, pred_onehot) in enumerate(samples, start=1):
    fig, axes = plt.subplots(1, 3, figsize=(20, 8))

    panels = [
        ('Input Image', input_image),
        ('Ground Truth Mask', onehot_to_rgb(true_onehot, id2code)),
        ('Predicted Mask', onehot_to_rgb(pred_onehot, id2code)),
    ]
    for axis, (title, picture) in zip(axes, panels):
        axis.imshow(picture)
        axis.set_title(title, fontdict=title_font)
        axis.grid(False)

    fig.savefig('./predictions/prediction_{}.png'.format(count), facecolor= 'w', transparent= False, bbox_inches= 'tight', dpi= 200)
    plt.show()
    # release the figure so memory does not grow with the size of the test set
    plt.close(fig)

### Sandbox

In [None]:
# Sandbox: Keras preprocessing layers that could replace the Pillow resize.
# The pasted API signature (height, width, **kwargs) was not runnable, so
# concrete values matching the 512 x 512 size used in this notebook are
# filled in.
resizing = tf.keras.layers.Resizing(
    512,
    512,
    interpolation='bilinear',
    crop_to_aspect_ratio=False,
)

# The assignment must be on one line; a line break after `=` is a SyntaxError.
preprocessing = tf.keras.layers.CenterCrop(512, 512)

In [None]:
# Sandbox: scratch experiments with tf.data and TensorFlow resizing helpers.

# Build a tiny dataset from a Python list and print each element.
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
for element in dataset:
  print(element)

# NOTE(review): `x` is not defined anywhere in the visible notebook and `size`
# is only assigned further down, so this call looks like a pasted API
# signature rather than runnable code - confirm or remove.
tf.keras.preprocessing.image.smart_resize(
    x, size, interpolation='bilinear'
)

# Resize every element of a dataset to 512 x 512.
# NOTE(review): `ds` is not defined in the visible notebook - presumably a
# tf.data.Dataset of images; confirm before running.
size = (512, 512)
ds = ds.map(lambda img: tf.image.resize(img, size))