# 1. Import libraries and load dataset

In [None]:
!pip install git+https://github.com/jakeret/unet.git

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/jakeret/unet.git
  Cloning https://github.com/jakeret/unet.git to /tmp/pip-req-build-i00maz9e
  Running command git clone --filter=blob:none --quiet https://github.com/jakeret/unet.git /tmp/pip-req-build-i00maz9e
  Resolved https://github.com/jakeret/unet.git to commit f557a51b6f95aae6848cab6141e6cae573934bf8
  Preparing metadata (setup.py) ... [?25l[?25hdone


In [None]:
import unet
import tensorflow as tf
import tensorflow_datasets as tfds 
import numpy as np
import matplotlib.pyplot as plt

from unet.datasets import oxford_iiit_pet
from tensorflow import keras

In [None]:
# Load the Oxford-IIIT Pet dataset (any 3.x.x release) together with its
# metadata object, which is used later to read the split sizes.
dataset, info = tfds.load(
    name='oxford_iiit_pet:3.*.*',
    with_info=True,
)

# 2. Data processing

In [None]:
def resize(input_image, input_mask):
  """Resize an image and its segmentation mask to 128x128.

  Both tensors are resized with nearest-neighbour sampling, which does not
  blend neighbouring values and therefore keeps mask labels intact.

  Returns:
    The `(image, mask)` pair after resizing.
  """
  target_size = (128, 128)
  resized_image = tf.image.resize(input_image, target_size, method='nearest')
  resized_mask = tf.image.resize(input_mask, target_size, method='nearest')
  return resized_image, resized_mask

def augment(input_image, input_mask):
  """Randomly mirror an image/mask pair left-to-right.

  A single uniform random draw decides whether to flip; when it exceeds 0.5
  the image and the mask are flipped together so they stay aligned.

  Returns:
    The `(image, mask)` pair, flipped or unchanged.
  """
  should_flip = tf.random.uniform(()) > 0.5
  if should_flip:
    input_image = tf.image.flip_left_right(input_image)
    input_mask = tf.image.flip_left_right(input_mask)
  return input_image, input_mask

def normalize(input_image, input_mask):
  """Scale image pixels to [0, 1] and shift mask labels down by one.

  Args:
    input_image: Image tensor; assumed to hold 8-bit pixel values (0..255).
    input_mask: Segmentation mask; presumably labelled from 1 upwards, so the
      shift yields zero-based class indices -- confirm against the dataset.

  Returns:
    The `(image, mask)` pair: a float32 image and the shifted mask.
  """
  # Divide by 255 (the maximum 8-bit value) so the result lies in [0, 1].
  input_image = tf.cast(input_image, tf.float32) / 255.0
  # Build a new value rather than using `-=`, so a mutable mask passed in by
  # the caller is never modified in place.
  input_mask = input_mask - 1
  return input_image, input_mask

def load_image_train(datapoint):
  """Prepare one training example: resize, then augment, then normalize.

  Args:
    datapoint: Mapping with an 'image' and a 'segmentation_mask' entry.

  Returns:
    The processed `(image, mask)` pair.
  """
  image = datapoint['image']
  mask = datapoint['segmentation_mask']
  image, mask = normalize(*augment(*resize(image, mask)))
  return image, mask

def load_image_test(datapoint):
  """Prepare one evaluation example: resize, then normalize (no augmentation).

  Args:
    datapoint: Mapping with an 'image' and a 'segmentation_mask' entry.

  Returns:
    The processed `(image, mask)` pair.
  """
  image = datapoint['image']
  mask = datapoint['segmentation_mask']
  image, mask = normalize(*resize(image, mask))
  return image, mask

In [None]:
# Run the per-example preprocessing over each split, letting tf.data choose
# the degree of parallelism.
train_dataset = dataset['train'].map(
    load_image_train,
    num_parallel_calls=tf.data.AUTOTUNE,
)
test_dataset = dataset['test'].map(
    load_image_test,
    num_parallel_calls=tf.data.AUTOTUNE,
)

In [None]:
BATCH_SIZE = 64
BUFFER_SIZE = 100

# cache() replays exactly the elements the upstream pipeline produced on its
# first pass, so the random flip applied inside `load_image_train` would be
# frozen after the first epoch. Applying `augment` again downstream of the
# cache draws a fresh flip decision every epoch; combined with the cached
# (fixed) flip, each example is still mirrored with probability 0.5.
train_batches = (
    train_dataset
    .cache()
    .shuffle(BUFFER_SIZE)
    .map(augment, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(BATCH_SIZE)
    .repeat()
    .prefetch(buffer_size=tf.data.AUTOTUNE)
)

# The first 3000 examples of the test split are used for validation ...
validation_batches = test_dataset.take(3000).batch(BATCH_SIZE)

# ... and the next 669 are held out as the actual test set.
test_batches = test_dataset.skip(3000).take(669).batch(BATCH_SIZE)

# 3. U-net construction

In [None]:
class UNet():
  """Builder for a U-Net segmentation model using the Keras functional API.

  The network has four down-sampling stages (64, 128, 256, 512 filters), a
  1024-filter bottleneck and four up-sampling stages with skip connections,
  followed by a 1x1 softmax convolution that yields per-pixel class scores.

  Args:
    input_shape: Shape of one input image as (height, width, channels).
      NOTE(review): with four 2x poolings, height and width presumably need
      to be divisible by 16 for the skip connections to line up -- confirm
      before using other sizes.
    num_classes: Number of output channels of the final softmax layer.
    dropout_rate: Rate used by every Dropout layer in the network.
  """

  def __init__(self, input_shape=(128, 128, 3), num_classes=3, dropout_rate=0.3):
    self.input_shape = tuple(input_shape)
    self.num_classes = num_classes
    self.dropout_rate = dropout_rate

  def double_conv_block(self, x, n_filters):
    """Apply two 3x3 'same'-padded ReLU convolutions with `n_filters` each."""
    x = keras.layers.Conv2D(n_filters, 3, padding='same', activation='relu',
                            kernel_initializer='he_normal')(x)

    x = keras.layers.Conv2D(n_filters, 3, padding='same', activation='relu',
                            kernel_initializer='he_normal')(x)

    return x

  def downsample_block(self, x, n_filters):
    """Run one encoder stage.

    Returns:
      A pair `(f, p)`: `f` is the convolution output kept for the skip
      connection, `p` is `f` after 2x max-pooling and dropout.
    """
    f = self.double_conv_block(x, n_filters)

    p = keras.layers.MaxPool2D(2)(f)
    p = keras.layers.Dropout(self.dropout_rate)(p)

    return f, p

  def upsample_block(self, x, conv_features, n_filters):
    """Run one decoder stage.

    Up-samples `x` with a stride-2 transposed convolution, concatenates the
    matching encoder features `conv_features`, applies dropout and then a
    double convolution block.
    """
    x = keras.layers.Conv2DTranspose(n_filters, 3, 2, padding='same')(x)
    x = keras.layers.concatenate([x, conv_features])
    x = keras.layers.Dropout(self.dropout_rate)(x)
    x = self.double_conv_block(x, n_filters)

    return x

  def build_model(self):
    """Assemble and return the Keras model named 'U-net'."""
    # Input
    inputs = keras.layers.Input(shape=self.input_shape)

    # Encoder: contracting path with down-sampling
    f1, p1 = self.downsample_block(inputs, 64)
    f2, p2 = self.downsample_block(p1, 128)
    f3, p3 = self.downsample_block(p2, 256)
    f4, p4 = self.downsample_block(p3, 512)

    # Bottleneck
    bottleneck = self.double_conv_block(p4, 1024)

    # Decoder: expanding path with up-sampling; each stage consumes the
    # encoder features of the same resolution.
    u6 = self.upsample_block(bottleneck, f4, 512)
    u7 = self.upsample_block(u6, f3, 256)
    u8 = self.upsample_block(u7, f2, 128)
    u9 = self.upsample_block(u8, f1, 64)

    # Output: one softmax score per class for every pixel
    outputs = keras.layers.Conv2D(self.num_classes, 1, padding='same',
                                  activation='softmax')(u9)

    unet_model = keras.models.Model(inputs=inputs, outputs=outputs, name='U-net')

    return unet_model

In [None]:
# Use a distinct name for the builder so the imported `unet` package
# is not shadowed by the instance.
unet_builder = UNet()
unet_model = unet_builder.build_model()
unet_model.summary()

Model: "U-net"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 128, 128, 3  0           []                               
                                )]                                                                
                                                                                                  
 conv2d (Conv2D)                (None, 128, 128, 64  1792        ['input_1[0][0]']                
                                )                                                                 
                                                                                                  
 conv2d_1 (Conv2D)              (None, 128, 128, 64  36928       ['conv2d[0][0]']                 
                                )                                                             

In [None]:
NUM_EPOCHS = 20

# One epoch covers as many full batches as fit in the training split.
TRAIN_LENGTH = info.splits['train'].num_examples
STEP_PER_EPOCH = TRAIN_LENGTH // BATCH_SIZE

# Validation runs on only 1/VAL_SUBSPLITS of the batches that the test split
# would yield.
VAL_SUBSPLITS = 5
TEST_LENGTH = info.splits['test'].num_examples
VALIDATION_STEPS = (TEST_LENGTH // BATCH_SIZE) // VAL_SUBSPLITS

In [None]:
# The model ends in a softmax layer and the masks hold integer class ids,
# hence the sparse categorical cross-entropy loss with default settings.
unet_model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy(),
    metrics=['accuracy'],
)

# `train_batches` repeats indefinitely, so the epoch length is given explicitly.
unet_history = unet_model.fit(
    train_batches,
    epochs=NUM_EPOCHS,
    steps_per_epoch=STEP_PER_EPOCH,
    validation_data=validation_batches,
    validation_steps=VALIDATION_STEPS,
)

Epoch 1/20

KeyboardInterrupt: ignored

In [None]:
def learning_curves(history):
  """Plot training/validation accuracy and loss side by side.

  Args:
    history: Object whose `history` attribute is a dict holding the
      per-epoch lists 'accuracy', 'val_accuracy', 'loss' and 'val_loss'.
  """
  acc = history.history['accuracy']
  val_acc = history.history['val_accuracy']

  loss = history.history['loss']
  val_loss = history.history['val_loss']

  def _plot_series(values, label):
    # The x-axis is derived from the recorded values instead of NUM_EPOCHS,
    # so the plot still works when fewer (or more) epochs were recorded than
    # the constant currently says.
    plt.plot(range(len(values)), values, label=label)

  fig = plt.figure(figsize=(18, 9))

  plt.subplot(1, 2, 1)
  _plot_series(acc, 'train accuracy')
  _plot_series(val_acc, 'validation accuracy')
  plt.title('Accuracy')
  plt.xlabel('Epoch')
  plt.ylabel('Accuracy')
  plt.legend(loc='lower right')

  plt.subplot(1, 2, 2)
  _plot_series(loss, 'train loss')
  _plot_series(val_loss, 'validation loss')
  plt.title('Loss')
  plt.xlabel('Epoch')
  plt.ylabel('Loss')
  plt.legend(loc='upper right')

  fig.tight_layout()
  plt.show()

# Requires the training cell above to have run to completion, since
# `unet_history` is only assigned once `fit` returns.
learning_curves(unet_history)

In [None]:
for image, mask in train_dataset.take(3):
    sample_image, sample_mask = image, mask
    display([sample_image, sample_mask])

In [None]:
def display(display_list):
    plt.figure(figsize=(15, 15))

    title = ['Input Image', 'True Mask', 'Predicted Mask']

    for i in range(len(display_list)):
        plt.subplot(1, len(display_list), i+1)
        plt.title(title[i])
        plt.imshow(tf.keras.preprocessing.image.array_to_img(display_list[i]))
        plt.axis('off')
    plt.show()

def create_mask(pred_mask):
  pred_mask = tf.argmax(pred_mask, axis=-1)
  pred_mask = pred_mask[..., tf.newaxis]
  return pred_mask[0]

def show_predictions(dataset=None, num=1):
  if dataset:
    for image, mask in dataset.take(num):
      pred_mask = unet_model.predict(image)
      display([image[0], mask[0], create_mask(pred_mask)])
  else:
    display([sample_image, sample_mask, create_mask(unet_model.predict(sample_image[tf.newaxis, ...]))])