<h2>Run this first</h2>

In [None]:
!pip install tensorflow==2.7.0 pillow matplotlib

In [None]:
import random

import numpy as np
from tensorflow.keras.utils import get_file
from tensorflow.random import set_seed

# Seed every RNG the notebook may touch (Python, NumPy, TensorFlow) so that
# weight initialisation, dropout and any shuffling are reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
set_seed(SEED)

Download pretrained model:

In [None]:
# Download the pretrained ZF_UNET_224 weights (or reuse the cached copy under
# the Keras cache directory's `models/` folder); `file_hash` verifies the file.
MODEL_PATH = get_file(
    fname='pretrained_unet.h5',
    origin='https://github.com/ZFTurbo/ZF_UNET_224_Pretrained_Model/releases/download/v1.0/zf_unet_224.h5',
    file_hash='203146f209baf34ac0d793e1691f1ab7',
    cache_subdir='models',
)

<h1>Data Preprocessing</h1>

In [None]:
import os
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

def preprocess_input(x):
    """Map pixel values from [0, 255] into the model's input range.

    The result is ``-(x / 255 - 0.5)``: 0 maps to 0.5 and 255 maps to -0.5
    (centred and sign-flipped). ``reverse_transform_input`` undoes this.

    NOTE(review): presumably matches the normalisation the pretrained weights
    expect — confirm against the ZF_UNET_224 repository.
    """
    scaled = x / 255          # [0, 255] -> [0, 1]
    centered = scaled - 0.5   # [0, 1] -> [-0.5, 0.5]
    return -centered          # sign flip -> [0.5, -0.5]

def reverse_transform_input(x):
    """Invert ``preprocess_input``, mapping model-space values back to uint8 pixels.

    ``preprocess_input`` maps a pixel ``p`` to ``-(p / 255 - 0.5)``; solving
    for ``p`` gives ``(0.5 - x) * 255``.

    Args:
        x: array (or scalar) of preprocessed values, nominally in [-0.5, 0.5].

    Returns:
        uint8 array of the same shape with values in [0, 255].
    """
    pixels = (0.5 - np.asarray(x)) * 255
    # Round rather than truncate: the float round-trip can land just below an
    # integer (e.g. 2.9999999), which a bare astype(uint8) would floor to 2.
    # Clip so out-of-range values saturate instead of wrapping around in uint8.
    return np.clip(np.round(pixels), 0, 255).astype(np.uint8)

def preprocess_masks(x):
    """Binarise 0-255 alpha masks into a uint8 {0, 1} mask.

    Values are divided by 255 and rounded to the nearest integer (NumPy rounds
    halves to even), so for uint8 input 0..127 -> 0 and 128..255 -> 1.
    """
    scaled = x / 255
    return np.rint(scaled).astype(np.uint8)

def reverse_transform_masks(x):
    """Scale a binary {0, 1} mask back to {0, 255} for display."""
    return 255 * x

def preprocess_predictions(x):
    """Turn model probabilities into a displayable {0, 255} uint8 mask.

    Values are rounded to the nearest integer (0.5 rounds to 0 because NumPy
    rounds halves to even) and the resulting {0, 1} mask is scaled by 255.
    """
    hard_mask = np.rint(x).astype(np.uint8)
    return hard_mask * 255

def load_images(path, size=(224, 224)):
    """Load every image file in ``path`` as a NumPy array resized to ``size``.

    Files are read in sorted filename order so the resulting array (and any
    train/test split made by slicing it) is reproducible; ``os.listdir`` on its
    own returns entries in arbitrary order. Hidden entries (names starting with
    ``.``, e.g. ``.DS_Store`` or ``.ipynb_checkpoints``) and sub-directories
    are skipped.

    Args:
        path: directory containing the images.
        size: ``(width, height)`` passed to ``PIL.Image.resize``.

    Returns:
        ``np.ndarray`` stacking the resized images along a new first axis. The
        per-pixel layout is whatever Pillow yields for each file's mode (e.g. 4
        channels for RGBA images); all files are expected to share one mode.
    """
    filenames = sorted(
        name for name in os.listdir(path)
        if not name.startswith('.') and os.path.isfile(os.path.join(path, name))
    )
    arrays = []
    for name in filenames:
        # The context manager closes the underlying file handle promptly.
        with Image.open(os.path.join(path, name)) as img:
            # NEAREST copies source pixels instead of interpolating, so alpha
            # values are not blended at object edges.
            arrays.append(np.array(img.resize(size, resample=Image.NEAREST)))
    return np.array(arrays)

def _show_panel(n_rows, row, column, title, image, cmap=None):
    """Draw ``image`` in cell (row, column) of an ``n_rows`` x 3 grid, titled and without axis ticks."""
    plt.subplot(n_rows, 3, row * 3 + column)
    plt.title(title, fontsize=15)
    plt.imshow(image, cmap=cmap)
    ax = plt.gca()
    ax.axes.xaxis.set_visible(False)
    ax.axes.yaxis.set_visible(False)


def _as_batch(array, batch_ndim, squeeze):
    """Return ``(array_as_batch, was_single)``.

    Optionally squeezes singleton axes first; if the result does not have
    ``batch_ndim`` dimensions it is treated as a single sample and given a
    leading batch axis.
    """
    if squeeze:
        array = np.squeeze(array)
    if len(array.shape) != batch_ndim:
        return np.expand_dims(array, axis=0), True
    return array, False


def show_images(images=None, masks=None, predictions=None, n_rows=5, figsize=(30, 30), preprocess=False):
    """Plot images, true masks and predicted masks side by side, one sample per row.

    Column 1 shows the image, column 2 the true mask and column 3 the predicted
    mask; any of the three may be omitted.

    Args:
        images: a single image ``(H, W, C)`` or a batch ``(N, H, W, C)``.
        masks: a single mask or a batch; singleton axes are squeezed away, so
            ``(N, H, W, 1)`` and ``(H, W, 1)`` are both accepted.
        predictions: same layout as ``masks``.
        n_rows: maximum number of rows to draw. Forced to 1 when any input is a
            single sample, and clipped to the smallest provided batch.
        figsize: matplotlib figure size.
        preprocess: if True, undo the model-space transforms before plotting
            (``reverse_transform_input`` for images, ``reverse_transform_masks``
            for masks, ``preprocess_predictions`` for predictions).
    """
    plt.figure(figsize=figsize)

    if images is not None:
        images, single = _as_batch(images, batch_ndim=4, squeeze=False)
        if single:
            n_rows = 1
    if masks is not None:
        masks, single = _as_batch(masks, batch_ndim=3, squeeze=True)
        if single:
            n_rows = 1
    if predictions is not None:
        predictions, single = _as_batch(predictions, batch_ndim=3, squeeze=True)
        if single:
            n_rows = 1

    # Clip to the smallest provided batch so a short batch never raises IndexError.
    provided = [batch for batch in (images, masks, predictions) if batch is not None]
    if provided:
        n_rows = min(n_rows, *(len(batch) for batch in provided))

    # Keep only the rows that will be drawn before applying the transforms.
    if images is not None:
        images = images[:n_rows]
        if preprocess:
            images = reverse_transform_input(images)
    if masks is not None:
        masks = masks[:n_rows]
        if preprocess:
            masks = reverse_transform_masks(masks)
    if predictions is not None:
        predictions = predictions[:n_rows]
        if preprocess:
            predictions = preprocess_predictions(predictions)

    for row in range(n_rows):
        if images is not None:
            _show_panel(n_rows, row, 1, "Original image", images[row])
        if masks is not None:
            _show_panel(n_rows, row, 2, "True Mask", masks[row], cmap='gray')
        if predictions is not None:
            _show_panel(n_rows, row, 3, "Predicted Mask", predictions[row], cmap='gray')

    # Render the figure; plt.plot() would add an empty plot call instead.
    plt.show()

In [None]:
# Load the whole dataset and preview the first (raw) image.
images = load_images(path='homm3_dataset')
show_images(images=images[0], figsize=(10, 10))

In [None]:
# Sanity check: dimensions should be (img_count, height, width, channels) with
# 4 RGBA channels. Fail early otherwise, because the next cells split off the
# last channel as the mask.
assert images.ndim == 4 and images.shape[-1] == 4, f'unexpected images shape {images.shape}'
images.shape

There are 4 channels (the format is RGBA), so let's split each image into its RGB channels and its alpha channel; the alpha channel will serve as the mask.

In [None]:
# Everything but the last channel is the colour image; the last (alpha) channel,
# kept as a trailing axis of size 1, becomes the mask.
rgb_images, masks = images[..., :-1], images[..., -1:]
print(rgb_images.shape, masks.shape)

In [None]:
show_images(images=images[0], masks=masks[0], figsize=(10, 20))  # raw image next to its alpha mask; the image should look unchanged

In [None]:
# Build rgb_images from the raw `images` rather than from itself, so re-running
# this cell never applies preprocess_input twice.
print(images[0, ..., :-1].min(), images[0, ..., :-1].max())
rgb_images = preprocess_input(images[..., :-1])  # preprocessing step
print(rgb_images[0].min(), rgb_images[0].max())

In [None]:
show_images(images=rgb_images[0], figsize=(10, 10))  # model-space values shown as-is: colors should change

In [None]:
show_images(images=rgb_images[0], figsize=(10, 10), preprocess=True)  # preprocess=True undoes the transform, so colors shouldn't change

The alpha channel can take any value from 0 to 255, so let's convert it into a binary mask (0 or 1).

In [None]:
# Rebuild masks from the raw alpha channel so re-running this cell is
# idempotent (binarising the already-binary masks again would turn every 1 into 0).
print(images[..., -1].min(), images[..., -1].max())
masks = preprocess_masks(images[..., -1:])
print(masks.min(), masks.max())

In [None]:
show_images(images=rgb_images[0], masks=masks[0], figsize=(10, 20), preprocess=True)  # after undoing the transforms, nothing should change visually

<h1>Loading pretrained UNet</h1>

In [None]:
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import Activation
from tensorflow.keras.layers import SpatialDropout2D
from tensorflow.keras.layers import Input
from tensorflow.keras.layers import MaxPooling2D
from tensorflow.keras.layers import concatenate
from tensorflow.keras.layers import UpSampling2D

from tensorflow.keras.models import Model

In [None]:
# INPUT_CHANNELS: channels of the network input (3 for RGB, 1 for grayscale).
# OUTPUT_MASK_CHANNELS: one output mask per object type (1 = a single object type).
INPUT_CHANNELS, OUTPUT_MASK_CHANNELS = 3, 1

def double_conv_layer(x, size, dropout=0.0, batch_norm=True):
    """Apply two (3x3 conv -> optional BatchNorm -> ReLU) stages, then optional SpatialDropout2D.

    Args:
        x: input Keras tensor.
        size: number of filters in each 3x3 convolution ('same' padding).
        dropout: SpatialDropout2D rate; no dropout layer is added when it is 0.
        batch_norm: whether to insert BatchNormalization(axis=3) after each conv.

    Returns:
        The output Keras tensor.
    """
    out = x
    # Layers are created in the same order as two hand-written stages, so the
    # resulting graph and auto-generated layer names are unchanged.
    for _ in range(2):
        out = Conv2D(size, (3, 3), padding='same')(out)
        if batch_norm:
            out = BatchNormalization(axis=3)(out)
        out = Activation('relu')(out)
    if dropout > 0:
        # Drops whole feature maps instead of individual elements
        # (https://www.tensorflow.org/api_docs/python/tf/keras/layers/SpatialDropout2D).
        out = SpatialDropout2D(dropout)(out)
    return out

# Model definition adapted from:
# https://github.com/ZFTurbo/ZF_UNET_224_Pretrained_Model/blob/master/zf_unet_224_model.py
def ZF_UNET_224(dropout_val=0.2):
    """Build the ZF_UNET_224 segmentation network (224x224 input, sigmoid mask output).

    Architecture:
    - a 5-level encoder (double conv + 2x2 max-pool, with 32, 64, 128, 256 and
      512 filters);
    - a 1024-filter bottleneck at 7x7;
    - a mirrored decoder that upsamples 2x and concatenates the matching encoder
      output before each double conv.
    Only the last decoder block applies ``dropout_val`` spatial dropout.

    Layers are created in the same order and wired in the same graph as the
    explicit level-by-level version. This matters because ``Model.load_weights``
    on an HDF5 file (``by_name=False``) matches weights by network topology.

    Args:
        dropout_val: SpatialDropout2D rate for the final decoder block (0 disables it).

    Returns:
        An uncompiled Keras ``Model`` mapping ``(224, 224, INPUT_CHANNELS)`` inputs
        to ``(224, 224, OUTPUT_MASK_CHANNELS)`` sigmoid outputs.
    """
    channel_axis = 3  # channels-last: skip connections are concatenated along channels
    base_filters = 32
    depth = 5  # number of 2x down/upsampling steps: 224 -> 7 -> 224

    inputs = Input((224, 224, INPUT_CHANNELS))

    # Encoder: remember each level's pre-pooling output for the skip connections.
    skips = []
    x = inputs
    for level in range(depth):
        x = double_conv_layer(x, base_filters * 2 ** level)
        skips.append(x)
        x = MaxPooling2D(pool_size=(2, 2))(x)

    # Bottleneck at 7x7.
    x = double_conv_layer(x, base_filters * 2 ** depth)

    # Decoder: upsample, concatenate the matching encoder output, double conv.
    for level in reversed(range(depth)):
        x = concatenate([UpSampling2D(size=(2, 2))(x), skips[level]], axis=channel_axis)
        is_last_block = level == 0
        x = double_conv_layer(x, base_filters * 2 ** level, dropout_val if is_last_block else 0.0)

    conv_final = Conv2D(OUTPUT_MASK_CHANNELS, (1, 1), name='final_conv')(x)
    conv_final = Activation('sigmoid', name='final_activation')(conv_final)

    return Model(inputs, conv_final, name="ZF_UNET_224")

In [None]:
model = ZF_UNET_224(dropout_val=0.2)  # 0.2 is the builder's default rate
model.summary()

In [None]:
model.load_weights(filepath=MODEL_PATH)  # weights downloaded by get_file at the top of the notebook

<h1>Using pretrained model</h1>

Now let's see how well our pretrained UNet segments our images.

In [None]:
preds = model.predict(x=rgb_images[0:5])  # predictions for the first five images only

In [None]:
# Slice images/masks to the predicted rows so all three columns stay aligned
# (and only the displayed rows are reverse-transformed).
show_images(rgb_images[:len(preds)], masks[:len(preds)], preds, n_rows=len(preds), preprocess=True)

As we can see, because the UNet was trained on a similar task, its outputs already make some sense even though it has never seen our data before.

<h1>Fine-tuning Model</h1>

In [None]:
TRAIN_PERCENTAGE = 0.7
split = int(images.shape[0] * TRAIN_PERCENTAGE)

X_train, y_train = rgb_images[:split], masks[:split]
X_test, y_test = rgb_images[split:], masks[split:]
print('Train:', X_train.shape, y_train.shape)
print('Test:', X_test.shape, y_test.shape)
print(f'{X_train.shape[0] + X_test.shape[0]} == {images.shape[0]}')

# Enforce the split invariants instead of relying on reading the output above.
assert X_train.shape[0] + X_test.shape[0] == images.shape[0]
assert len(X_train) > 0 and len(X_test) > 0, 'dataset too small for a train/test split'

In [None]:
preds = model.predict(X_test)
# Plot before training for comparison. Cap the rows at the test-set size so a
# small test split does not raise IndexError.
show_images(X_test, y_test, preds, n_rows=min(5, len(X_test)), preprocess=True)

<h2>TASK</h2>
Take the pretrained model and fine-tune it on our dataset. The dataset is very small, so remember to freeze some layers. <br><br>
<strong>THE GOAL</strong> is to get accuracy above 98% on the test set. Even without fine-tuning the model should reach around 97.5%, so don't expect a big improvement. Check your results with the <i>show_images</i> function to make sure the predictions don't look worse than before. <br><br>
<strong>IMPORTANT NOTE:</strong> DO NOT freeze the BatchNormalization layers. Transfer learning doesn't work well with frozen BatchNorms (for an explanation, see: https://stackoverflow.com/questions/51123198/strange-behaviour-of-the-loss-function-in-keras-model-with-pretrained-convoluti/51124511)