In [None]:
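# Uncomment if `tensorflow_examples` (which provides `pix2pix.upsample`) is not installed in this kernel:
# %pip install git+https://github.com/tensorflow/examples.git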

In [None]:
import tensorflow as tf

import tensorflow_datasets as tfds

In [None]:
from tensorflow_examples.models.pix2pix import pix2pix

from IPython.display import clear_output
import matplotlib.pyplot as plt

In [None]:
# Oxford-IIIT Pet from TFDS (any 3.x.x version); `info` supplies the split sizes used below.
dataset, info = tfds.load(name='oxford_iiit_pet:3.*.*', with_info=True)

In [None]:
def normalize(input_img, input_mask):
    """Scale an image to [0, 1] floats and shift the mask labels down by one.

    Args:
        input_img: image tensor/array; cast to float32 and divided by 255.
        input_mask: integer mask tensor/array; every value is decremented by 1
            (presumably mapping 1-based labels {1, 2, 3} to 0-based {0, 1, 2}
            for SparseCategoricalCrossentropy — TODO confirm against the dataset).

    Returns:
        Tuple ``(input_img, input_mask)`` of new tensors; the arguments are not modified.
    """
    input_img = tf.cast(input_img, tf.float32) / 255.0
    # Out-of-place on purpose: `input_mask -= 1` would mutate a caller's NumPy
    # array in place (ndarray implements __isub__), unlike an immutable tf.Tensor.
    input_mask = input_mask - 1
    return input_img, input_mask

In [None]:
def load_img(datapoint):
    """Resize one dataset example to 128x128 and normalize it.

    Args:
        datapoint: mapping with 'image' and 'segmentation_mask' entries
            (one example from a tfds split).

    Returns:
        ``(image, mask)`` as returned by ``normalize``.
    """
    target_size = (128, 128)
    # Nearest-neighbour resizing copies existing mask values instead of
    # interpolating between class labels.
    resized_mask = tf.image.resize(
        datapoint['segmentation_mask'],
        target_size,
        method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
    )
    resized_img = tf.image.resize(datapoint['image'], target_size)
    return normalize(resized_img, resized_mask)

In [None]:
BATCH_SIZE = 64
BUFFER_SIZE = 1000  # shuffle buffer size for the training pipeline

TRAIN_LENGTH = info.splits['train'].num_examples
# Whole batches per epoch; leftover examples beyond the last full batch are not counted.
STEPS_PER_EPOCH = TRAIN_LENGTH // BATCH_SIZE

In [None]:
# Resize and normalize every example of both splits, in parallel.
train_images, test_images = (
    dataset[split].map(load_img, num_parallel_calls=tf.data.AUTOTUNE)
    for split in ('train', 'test')
)

In [None]:
class Augment(tf.keras.layers.Layer):
    """Randomly flip a batch of images and its batch of masks horizontally.

    Images and masks each get their own ``RandomFlip`` layer, both built with
    the same ``seed``. The intent appears to be that an image and its mask
    receive the same flip. NOTE(review): this relies on RandomFlip's seeding
    behaviour keeping the two layers in lock-step; verify on the installed
    TF/Keras version.
    """

    def __init__(self, seed=42):
        super().__init__()
        # Identical configuration and seed for both flip layers (see class docstring).
        self.augment_inputs = tf.keras.layers.RandomFlip(seed=seed, mode='horizontal')
        self.augment_labels = tf.keras.layers.RandomFlip(seed=seed, mode='horizontal')

    def call(self, inputs, labels):
        return self.augment_inputs(inputs), self.augment_labels(labels)

In [None]:
# Training pipeline: cache the preprocessed examples, shuffle, batch, repeat
# indefinitely (epoch length is set by steps_per_epoch in fit), flip randomly
# per batch, then prefetch.
train_batches = train_images.cache().shuffle(BUFFER_SIZE)
train_batches = train_batches.batch(BATCH_SIZE).repeat()
train_batches = train_batches.map(Augment()).prefetch(buffer_size=tf.data.AUTOTUNE)

# Evaluation batches: no shuffling, no augmentation.
test_batches = test_images.batch(BATCH_SIZE)

In [None]:
def display(display_list, titles=('Input Image', 'True Mask', 'Predicted Mask')):
    """Show images side by side in one row of subplots.

    Note: defining this function shadows IPython's built-in ``display()`` in
    the notebook namespace.

    Args:
        display_list: sequence of images accepted by ``plt.imshow`` (e.g. an
            input image, its true mask and optionally a predicted mask).
        titles: per-panel titles, matched to ``display_list`` by position.
            Panels beyond ``len(titles)`` are drawn untitled instead of raising
            IndexError, as the original fixed 3-entry title list did.
    """
    plt.figure(figsize=(15, 15))
    num_panels = len(display_list)
    for i, image in enumerate(display_list):
        plt.subplot(1, num_panels, i + 1)
        if i < len(titles):
            plt.title(titles[i])
        plt.imshow(image)
        plt.axis('off')
    plt.show()
    

In [None]:
for images, masks in train_batches.take(2):
    # Show the first example of each batch. After the loop, sample_img /
    # sample_mask hold the last one shown; show_predictions() reads them.
    sample_img = images[0]
    sample_mask = masks[0]
    display([sample_img, sample_mask])

In [None]:
# Same as the previous cell: draws two more batches (typically different
# examples, because the pipeline shuffles and flips randomly) and overwrites
# sample_img / sample_mask with the first example of the last batch.
for images, masks in train_batches.take(2):
    first_img, first_mask = images[0], masks[0]
    sample_img, sample_mask = first_img, first_mask
    display([sample_img, sample_mask])

In [None]:
base_model = tf.keras.applications.MobileNetV2(
    input_shape=(128, 128, 3),
    include_top=False,
)

# Encoder activations used by the U-Net, shallow to deep. The last one is the
# bottleneck; the others become skip connections in unet_model. (Presumably
# 64x64 down to 4x4 for a 128x128 input — verify with down_stack.summary().)
layer_names = [
    'block_1_expand_relu',
    'block_3_expand_relu',
    'block_6_expand_relu',
    'block_13_expand_relu',
    'block_16_project',
]
base_model_outputs = [base_model.get_layer(layer_name).output for layer_name in layer_names]

# Feature-extraction model exposing those activations; frozen so only the decoder trains.
down_stack = tf.keras.Model(inputs=base_model.input, outputs=base_model_outputs)
down_stack.trainable = False

In [None]:
# Decoder blocks from deep to shallow, each with kernel size 3; presumably each
# upsamples spatially by 2 (pix2pix.upsample) — verify against tensorflow_examples.
up_stack = [pix2pix.upsample(filters, 3) for filters in (512, 256, 128, 64)]

In [None]:
def unet_model(output_channels):
    """Build a U-Net from the global ``down_stack`` encoder and ``up_stack`` decoder.

    Args:
        output_channels: number of channels (classes) in the output map.

    Returns:
        An uncompiled ``tf.keras.Model`` taking (128, 128, 3) images and
        producing ``output_channels`` channels of per-pixel logits.
    """
    inputs = tf.keras.layers.Input(shape=[128, 128, 3])

    # The deepest encoder output is the bottleneck; the rest are skip connections.
    *skip_features, x = down_stack(inputs)

    # Decode deep -> shallow, concatenating the matching skip after each upsample.
    for upsample_block, skip in zip(up_stack, skip_features[::-1]):
        x = upsample_block(x)
        x = tf.keras.layers.Concatenate()((x, skip))

    # Stride-2 transposed conv upsamples once more. There is no activation, so
    # the outputs are logits.
    output = tf.keras.layers.Conv2DTranspose(
        filters=output_channels,
        kernel_size=3,
        strides=2,
        padding='same',
    )(x)

    return tf.keras.Model(inputs=inputs, outputs=output)

In [None]:
# Three segmentation classes: mask labels 0..2 after normalize() subtracts 1
# (assumes the raw masks use labels 1..3 — TODO confirm).
OUTPUT_CLASSES = 3

model = unet_model(output_channels=OUTPUT_CLASSES)
# from_logits=True matches the activation-free final layer of unet_model.
model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)

In [None]:
# Architecture diagram with tensor shapes (plot_model needs pydot and graphviz installed).
tf.keras.utils.plot_model(model=model, show_shapes=True)

In [None]:
def create_mask(pred_mask):
    """Convert a batch of per-pixel class scores into the first sample's label map.

    Args:
        pred_mask: model output, presumably shaped (batch, H, W, classes).

    Returns:
        Tensor of argmax class indices for batch element 0, with a trailing
        channel axis of size 1 (presumably (H, W, 1)).
    """
    first_scores = pred_mask[0]
    labels = tf.math.argmax(first_scores, axis=-1)
    return labels[..., tf.newaxis]

In [None]:
def show_predictions(dataset=None, num=1):
    """Plot input image, true mask and predicted mask using the global ``model``.

    Args:
        dataset: optional batched dataset of ``(image, mask)`` pairs. When
            given, the first example of each of the first ``num`` batches is
            shown. When None, the global ``sample_img`` / ``sample_mask`` pair
            is used instead.
        num: number of batches to take from ``dataset``; ignored when
            ``dataset`` is None.
    """
    # Explicit None check: the original `if dataset:` relied on the truthiness
    # of a Dataset object rather than on whether one was passed at all.
    if dataset is not None:
        for image, mask in dataset.take(num):
            pred_mask = model.predict(image)
            display([image[0], mask[0], create_mask(pred_mask)])
    else:
        # Add a leading batch axis; the model's Input is a single (128, 128, 3) image.
        pred_mask = model.predict(sample_img[tf.newaxis, ...])
        display([sample_img, sample_mask, create_mask(pred_mask)])

In [None]:
# Baseline: prediction on the stored sample before training starts.
show_predictions(dataset=None)

In [None]:
class DisplayCallback(tf.keras.callbacks.Callback):
    """Keras callback that redraws a sample prediction at the end of every epoch."""

    def on_epoch_end(self, epoch, logs=None):
        # Replace the previous epoch's output instead of stacking figures.
        clear_output(wait=True)
        show_predictions()
        epoch_number = epoch + 1  # `epoch` is 0-based; report it 1-based
        message = '\nSample Prediction after epoch {}\n'.format(epoch_number)
        print(message)


In [None]:
EPOCHS = 20
VAL_SUBSPLITS = 5
# Validate on roughly 1/VAL_SUBSPLITS of the full test batches each epoch.
VALIDATION_STEPS = info.splits['test'].num_examples // BATCH_SIZE // VAL_SUBSPLITS

model_history = model.fit(
    train_batches,
    epochs=EPOCHS,
    steps_per_epoch=STEPS_PER_EPOCH,  # required: train_batches repeats forever
    validation_data=test_batches,
    validation_steps=VALIDATION_STEPS,
    callbacks=[DisplayCallback()],
)

In [None]:
loss = model_history.history['loss']
val_loss = model_history.history['val_loss']
# 1-based epoch numbers so the x-axis matches the "after epoch N" messages
# printed by DisplayCallback (plotting raw list indices started at 0).
epochs = range(1, len(loss) + 1)

fig, ax = plt.subplots()
ax.plot(epochs, loss, label='loss')
ax.plot(epochs, val_loss, 'bo', label='val_loss')
ax.set(title='Training and validation loss', xlabel='Epoch', ylabel='Loss Value')
# Fixed range kept from the original; any loss above 1 is clipped from view.
ax.set_ylim(0, 1)
ax.legend()
plt.show()

In [None]:
# NOTE(review): under Keras 3 (TF >= 2.16) model.save() expects a ".keras" or
# ".h5" path, so this extension-less path may raise there; model.export(...)
# writes a SavedModel in that case — confirm against the installed version.
model.save(filepath="model_tf")

In [None]:
# Convert the trained Keras model into a TensorFlow Lite flatbuffer (bytes).
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()

# Write the serialized model to disk.
with open('model.tflite', 'wb') as tflite_file:
    tflite_file.write(tflite_model)