# Image segmentation with a U-Net-like architecture
Rebuilt to demonstrate the validity of Wedge Dropout and to compare it with Spatial Dropout.

The original example is: https://keras.io/examples/vision/oxford_pets_image_segmentation/

In [1]:
!pip uninstall -y tensorflow
!pip install -q tensorflow==2.3.3

Found existing installation: tensorflow 2.3.3
Uninstalling tensorflow-2.3.3:
  Successfully uninstalled tensorflow-2.3.3


In [2]:
# Download the Oxford-IIIT Pet image and annotation archives.
# -q silences wget; -nc (no-clobber) skips the download when the archive is
# already present, so re-running this cell is cheap.
!wget -q -nc http://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz
!wget -q -nc http://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz
# Unpack both archives into the working directory; later cells read from
# "images/" and "annotations/trimaps/".
!tar -xf images.tar.gz
!tar -xf annotations.tar.gz


In [3]:
import os

# --- Dataset locations and run configuration ---
input_dir = "images/"
target_dir = "annotations/trimaps/"
img_size = (160, 160)  # size every image and mask is resized to when loaded
# Number of output channels of the per-pixel classifier.
# NOTE(review): the data loader shifts mask values down by one (`y - 1`); if the
# trimaps only contain the labels 1..3, three classes would suffice -- confirm.
num_classes = 4
batch_size = 96

# Sorted lists of image and mask paths. Hidden files (names starting with ".")
# are excluded from the mask listing.
input_img_paths = sorted(
    os.path.join(input_dir, fname)
    for fname in os.listdir(input_dir)
    if fname.endswith(".jpg")
)
target_img_paths = sorted(
    os.path.join(target_dir, fname)
    for fname in os.listdir(target_dir)
    if fname.endswith(".png") and not fname.startswith(".")
)

# Images and masks are paired purely by position in the two sorted lists, so
# fail fast if the listings do not line up instead of silently training on
# mismatched pairs.
if len(input_img_paths) != len(target_img_paths):
    raise ValueError(
        "Found {} images but {} masks".format(
            len(input_img_paths), len(target_img_paths)
        )
    )
mismatched_pairs = [
    (image_path, mask_path)
    for image_path, mask_path in zip(input_img_paths, target_img_paths)
    if os.path.splitext(os.path.basename(image_path))[0]
    != os.path.splitext(os.path.basename(mask_path))[0]
]
if mismatched_pairs:
    raise ValueError(
        "Image/mask file names do not match, first mismatch: {} | {}".format(
            *mismatched_pairs[0]
        )
    )

print("Number of samples:", len(input_img_paths))

# Show the first few pairs as a visual sanity check
for input_path, target_path in zip(input_img_paths[:10], target_img_paths[:10]):
    print(input_path, "|", target_path)


Number of samples: 7390
images/Abyssinian_1.jpg | annotations/trimaps/Abyssinian_1.png
images/Abyssinian_10.jpg | annotations/trimaps/Abyssinian_10.png
images/Abyssinian_100.jpg | annotations/trimaps/Abyssinian_100.png
images/Abyssinian_101.jpg | annotations/trimaps/Abyssinian_101.png
images/Abyssinian_102.jpg | annotations/trimaps/Abyssinian_102.png
images/Abyssinian_103.jpg | annotations/trimaps/Abyssinian_103.png
images/Abyssinian_104.jpg | annotations/trimaps/Abyssinian_104.png
images/Abyssinian_105.jpg | annotations/trimaps/Abyssinian_105.png
images/Abyssinian_106.jpg | annotations/trimaps/Abyssinian_106.png
images/Abyssinian_107.jpg | annotations/trimaps/Abyssinian_107.png


In [4]:
from tensorflow import keras
import numpy as np
from tensorflow.keras.preprocessing.image import load_img


class OxfordPets(keras.utils.Sequence):
    """Helper to iterate over the data (as Numpy arrays).

    Each item is one batch: a float32 image array scaled by 1/255 and a uint8
    mask array whose values are shifted down by one.

    Args:
        batch_size: number of samples per batch.
        img_size: tuple passed as `target_size` when loading images and masks.
        input_img_paths: list of image file paths.
        target_img_paths: list of mask file paths, index-aligned with
            `input_img_paths`.
    """

    def __init__(self, batch_size, img_size, input_img_paths, target_img_paths):
        self.batch_size = batch_size
        self.img_size = img_size
        self.input_img_paths = input_img_paths
        self.target_img_paths = target_img_paths

    def __len__(self):
        """Number of full batches; a trailing partial batch is dropped."""
        return len(self.target_img_paths) // self.batch_size

    def __getitem__(self, idx):
        """Returns tuple (input, target) corresponding to batch #idx."""
        i = idx * self.batch_size
        batch_input_img_paths = self.input_img_paths[i : i + self.batch_size]
        batch_target_img_paths = self.target_img_paths[i : i + self.batch_size]
        # Size the buffers from the instance's own batch size rather than a
        # module-level global, so the class works with any batch size it is
        # constructed with.
        x = np.zeros((self.batch_size,) + self.img_size + (3,), dtype="float32")
        for j, path in enumerate(batch_input_img_paths):
            img = load_img(path, target_size=self.img_size)
            x[j] = img
        y = np.zeros((self.batch_size,) + self.img_size + (1,), dtype="uint8")
        for j, path in enumerate(batch_target_img_paths):
            img = load_img(path, target_size=self.img_size, color_mode="grayscale")
            # Add a trailing channel axis so the mask matches y's last dimension
            y[j] = np.expand_dims(img, 2)
        # Scale pixel values by 1/255 and shift mask values down by one
        # (presumably so labels start at 0 for the sparse loss -- confirm).
        return x / 255, y - 1



# Wedge2D Dropout



In [5]:
!pip uninstall -y keras-wedge-dropout
!pip install -q git+https://github.com/LanceNorskog/keras-wedge.git
import numpy as np
from keras_wedge_dropout import WedgeDropout2D

Found existing installation: keras-wedge-dropout 0.1.0
Uninstalling keras-wedge-dropout-0.1.0:
  Successfully uninstalled keras-wedge-dropout-0.1.0
  Building wheel for keras-wedge-dropout (setup.py) ... [?25l[?25hdone


## Prepare U-Net Xception-style model


In [6]:
from tensorflow.keras import layers

# Kernel initializer shared by every convolution built in get_model
he = 'he_normal'
def get_model(img_size, num_classes, wedgeDropout=False, wedge_rate=0.65):
    """Build the U-Net-like, Xception-style segmentation model.

    Args:
        img_size: tuple giving the spatial size of the input; a 3-channel axis
            is appended to form the input shape.
        num_classes: number of output channels of the final softmax
            convolution.
        wedgeDropout: when True, a `WedgeDropout2D` layer is inserted right
            before the final classification convolution.
        wedge_rate: first positional argument forwarded to `WedgeDropout2D`
            (presumably the drop rate -- confirm against keras_wedge_dropout).
            Defaults to 0.65, the value that used to be hard-coded.

    Returns:
        Tuple `(model, wedge)`. `model` is the `keras.Model`; `wedge` is the
        value returned by calling the `WedgeDropout2D` layer on its input, or
        None when `wedgeDropout` is False.
        NOTE(review): `wedge` is the layer's output, not the layer object, so
        setting attributes such as `.trainable` on it presumably does not
        reach the layer -- confirm before relying on it.
    """
    inputs = keras.Input(shape=img_size + (3,))

    ### [First half of the network: downsampling inputs] ###

    # Entry block (stride 2 halves the spatial resolution)
    x = layers.Conv2D(32, 3, strides=2, padding="same", kernel_initializer=he)(inputs)
    x = layers.BatchNormalization()(x)

    previous_block_activation = x  # Set aside residual

    # Blocks 1, 2, 3 are identical apart from the feature depth.
    for filters in [64, 128, 256]:
        x = layers.Activation("relu")(x)
        x = layers.SeparableConv2D(filters, 3, padding="same", kernel_initializer=he)(x)
        x = layers.BatchNormalization()(x)

        x = layers.Activation("relu")(x)
        x = layers.SeparableConv2D(filters, 3, padding="same", kernel_initializer=he)(x)
        x = layers.BatchNormalization()(x)

        x = layers.MaxPooling2D(3, strides=2, padding="same")(x)

        # Project residual: a strided 1x1 convolution so it matches the pooled
        # main path in both resolution and depth.
        residual = layers.Conv2D(filters, 1, strides=2, padding="same", kernel_initializer=he)(
            previous_block_activation
        )
        x = layers.add([x, residual])  # Add back residual
        previous_block_activation = x  # Set aside next residual

    ### [Second half of the network: upsampling inputs] ###

    previous_block_activation = x  # Set aside residual

    for step, filters in enumerate([256, 128, 64, 32]):
        x = layers.Activation("relu")(x)
        x = layers.Conv2DTranspose(filters, 3, padding="same", kernel_initializer=he)(x)
        x = layers.BatchNormalization()(x)

        x = layers.Activation("relu")(x)
        x = layers.Conv2DTranspose(filters, 3, padding="same", kernel_initializer=he)(x)
        x = layers.BatchNormalization()(x)

        x = layers.UpSampling2D(2)(x)

        # Project residual: upsample, then a 1x1 convolution to match depth
        residual = layers.UpSampling2D(2)(previous_block_activation)
        residual = layers.Conv2D(filters, 1, padding="same", kernel_initializer=he)(residual)
        x = layers.add([x, residual], name='final_add_'+str(step))  # Add back residual
        previous_block_activation = x  # Set aside next residual

    # Optional wedge dropout on the last feature map, just before the classifier
    wedge = None
    if wedgeDropout:
        x = WedgeDropout2D(wedge_rate, batchwise=True)(x)
        wedge = x

    # Add a per-pixel classification layer
    outputs = layers.Conv2D(num_classes, 3, activation="softmax", padding="same", name='output', kernel_initializer=he)(x)

    # Define the model
    model = keras.Model(inputs, outputs)
    return model, wedge


# Reset Keras' global state so repeated runs of the model-definition cells
# do not keep accumulating memory.
keras.backend.clear_session()

# Baseline network without wedge dropout; the second return value is unused.
model, _ = get_model(img_size=img_size, num_classes=num_classes)
model.summary()


Model: "functional_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 160, 160, 3) 0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 80, 80, 32)   896         input_1[0][0]                    
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 80, 80, 32)   128         conv2d[0][0]                     
__________________________________________________________________________________________________
activation (Activation)         (None, 80, 80, 32)   0           batch_normalization[0][0]        
_______________________________________________________________________________________

## Set aside a validation split


In [7]:
import random

# Split our img paths into a training and a validation set
val_samples = 1000

# Restore the sorted order before shuffling so that re-running this cell
# always produces the same split. Shuffling the already-shuffled lists in
# place would otherwise yield a different permutation on every run.
input_img_paths = sorted(input_img_paths)
target_img_paths = sorted(target_img_paths)

# Two generators seeded identically apply the same permutation to both
# (equal-length) lists, which keeps every image aligned with its mask.
random.Random(1337).shuffle(input_img_paths)
random.Random(1337).shuffle(target_img_paths)

# The last `val_samples` pairs form the validation set; the rest is training data
train_input_img_paths = input_img_paths[:-val_samples]
train_target_img_paths = target_img_paths[:-val_samples]
val_input_img_paths = input_img_paths[-val_samples:]
val_target_img_paths = target_img_paths[-val_samples:]

# Instantiate data Sequences for each split
train_gen = OxfordPets(
    batch_size, img_size, train_input_img_paths, train_target_img_paths
)
val_gen = OxfordPets(batch_size, img_size, val_input_img_paths, val_target_img_paths)


## Train the model


In [8]:
# Compile the baseline model. The targets are integer class ids, hence the
# "sparse" flavour of categorical cross-entropy.
model.compile(
    optimizer="rmsprop",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)

# Stop when val_loss has not improved for 10 epochs and roll back to the
# best weights seen so far.
callbacks = [
    keras.callbacks.EarlyStopping(
        monitor="val_loss",
        patience=10,
        restore_best_weights=True,
    )
]

# Fit for at most `epochs` epochs, validating at the end of each one, then
# report the final metrics on the validation set.
epochs = 100
model.fit(
    train_gen,
    epochs=epochs,
    validation_data=val_gen,
    callbacks=callbacks,
    verbose=2,
)
print(model.evaluate(val_gen))


Epoch 1/100
66/66 - 36s - loss: 4.1725 - accuracy: 0.3018 - val_loss: 5.1031 - val_accuracy: 0.0101
Epoch 2/100
66/66 - 35s - loss: 0.8967 - accuracy: 0.3046 - val_loss: 1.6175 - val_accuracy: 0.0915
Epoch 3/100
66/66 - 33s - loss: 0.7542 - accuracy: 0.3058 - val_loss: 1.5730 - val_accuracy: 2.9553e-04
Epoch 4/100
66/66 - 34s - loss: 0.6245 - accuracy: 0.3135 - val_loss: 1.3110 - val_accuracy: 6.3745e-04
Epoch 5/100
66/66 - 33s - loss: 0.5496 - accuracy: 0.3099 - val_loss: 1.1826 - val_accuracy: 0.0071
Epoch 6/100
66/66 - 35s - loss: 0.4955 - accuracy: 0.3081 - val_loss: 0.8450 - val_accuracy: 0.0588
Epoch 7/100
66/66 - 33s - loss: 0.4660 - accuracy: 0.3107 - val_loss: 0.5835 - val_accuracy: 0.2210
Epoch 8/100
66/66 - 35s - loss: 0.4230 - accuracy: 0.3090 - val_loss: 0.4825 - val_accuracy: 0.2514
Epoch 9/100
66/66 - 33s - loss: 0.4007 - accuracy: 0.3103 - val_loss: 0.4455 - val_accuracy: 0.2948
Epoch 10/100
66/66 - 35s - loss: 0.3907 - accuracy: 0.3089 - val_loss: 0.4763 - val_accuracy

In [9]:
# Reset Keras' global state before building a fresh model
keras.backend.clear_session()

# Same architecture with WedgeDropout2D inserted before the classifier;
# the second return value is not needed in this run.
model_wedge, _ = get_model(img_size, num_classes, wedgeDropout=True)
model_wedge.summary()

model_wedge.compile(
    optimizer="rmsprop",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)

# Stop when val_loss has not improved for 10 epochs and roll back to the
# best weights seen so far.
callbacks = [
    keras.callbacks.EarlyStopping(
        monitor="val_loss",
        patience=10,
        restore_best_weights=True,
    )
]

# Fit for at most `epochs` epochs, validating at the end of each one, then
# report the final metrics on the validation set.
epochs = 100
model_wedge.fit(
    train_gen,
    epochs=epochs,
    validation_data=val_gen,
    callbacks=callbacks,
    verbose=2,
)
print(model_wedge.evaluate(val_gen))


WedgeDropout2D.build: input_shape: (None, 160, 160, 32)
Model: "functional_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 160, 160, 3) 0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 80, 80, 32)   896         input_1[0][0]                    
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 80, 80, 32)   128         conv2d[0][0]                     
__________________________________________________________________________________________________
activation (Activation)         (None, 80, 80, 32)   0           batch_normalization[0][0]        
_______________________________

In [10]:
# Experiment: RMSprop, warm up without wedge dropout, then enable it.
# Free up RAM in case the model definition cells were run multiple times
keras.backend.clear_session()

# Build model
# `wedge` is the second value returned by get_model.
model_wedge, wedge = get_model(img_size, num_classes, wedgeDropout=True)
model_wedge.summary()

model_wedge.compile(optimizer="rmsprop", 
            loss="sparse_categorical_crossentropy",
            metrics=['accuracy'])

# Stop when val_loss has not improved for 10 epochs; restore the best weights.
callbacks = [
    keras.callbacks.EarlyStopping(
        monitor='val_loss', patience=10, restore_best_weights=True)
]

# Do 5 epochs with WedgeDropout disabled, then turn it on.
# NOTE(review): get_model returns the *output* of the WedgeDropout2D call as
# `wedge`, not the layer object, so assigning `.trainable` here presumably
# does not reach the layer. Keras also expects compile() to be called again
# after `trainable` changes. Confirm that this toggle has any effect.
starting_epochs=5
wedge.trainable = False
# Warm-up phase: no validation data and no early stopping
model_wedge.fit(train_gen, epochs=starting_epochs, verbose=2)
wedge.trainable = True
print('WedgeDropout2D enabled')
epochs = 100
# Resume epoch numbering at `starting_epochs` and train up to epoch `epochs`
model_wedge.fit(train_gen, initial_epoch=starting_epochs, epochs=epochs, validation_data=val_gen, callbacks=callbacks, verbose=2)
print(model_wedge.evaluate(val_gen))


WedgeDropout2D.build: input_shape: (None, 160, 160, 32)
Model: "functional_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 160, 160, 3) 0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 80, 80, 32)   896         input_1[0][0]                    
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 80, 80, 32)   128         conv2d[0][0]                     
__________________________________________________________________________________________________
activation (Activation)         (None, 80, 80, 32)   0           batch_normalization[0][0]        
_______________________________

In [11]:
# Experiment: Adam, warm up without wedge dropout, then enable it.
# Free up RAM in case the model definition cells were run multiple times
keras.backend.clear_session()

# Build model
# `wedge` is the second value returned by get_model.
model_wedge, wedge = get_model(img_size, num_classes, wedgeDropout=True)
model_wedge.summary()

model_wedge.compile(optimizer="adam", 
            loss="sparse_categorical_crossentropy",
            metrics=['accuracy'])

# Stop when val_loss has not improved for 10 epochs; restore the best weights.
callbacks = [
    keras.callbacks.EarlyStopping(
        monitor='val_loss', patience=10, restore_best_weights=True)
]

# Do 5 epochs with WedgeDropout disabled, then turn it on.
# NOTE(review): get_model returns the *output* of the WedgeDropout2D call as
# `wedge`, not the layer object, so assigning `.trainable` here presumably
# does not reach the layer. Keras also expects compile() to be called again
# after `trainable` changes. Confirm that this toggle has any effect.
starting_epochs=5
wedge.trainable = False
# Warm-up phase: no validation data and no early stopping
model_wedge.fit(train_gen, epochs=starting_epochs, verbose=2)
wedge.trainable = True
print('WedgeDropout2D enabled')
epochs = 100
# Resume epoch numbering at `starting_epochs` and train up to epoch `epochs`
model_wedge.fit(train_gen, initial_epoch=starting_epochs, epochs=epochs, validation_data=val_gen, callbacks=callbacks, verbose=2)
print(model_wedge.evaluate(val_gen))


WedgeDropout2D.build: input_shape: (None, 160, 160, 32)
Model: "functional_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 160, 160, 3) 0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 80, 80, 32)   896         input_1[0][0]                    
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 80, 80, 32)   128         conv2d[0][0]                     
__________________________________________________________________________________________________
activation (Activation)         (None, 80, 80, 32)   0           batch_normalization[0][0]        
_______________________________

In [12]:
# Experiment: Adam, start with wedge dropout enabled, then disable it.
# Free up RAM in case the model definition cells were run multiple times
keras.backend.clear_session()

# Build model
# `wedge` is the second value returned by get_model.
model_wedge, wedge = get_model(img_size, num_classes, wedgeDropout=True)
model_wedge.summary()

model_wedge.compile(optimizer="adam", 
            loss="sparse_categorical_crossentropy",
            metrics=['accuracy'])

# Stop when val_loss has not improved for 10 epochs; restore the best weights.
callbacks = [
    keras.callbacks.EarlyStopping(
        monitor='val_loss', patience=10, restore_best_weights=True)
]

# Do 5 epochs with WedgeDropout enabled, then turn it off.
# NOTE(review): get_model returns the *output* of the WedgeDropout2D call as
# `wedge`, not the layer object, so assigning `.trainable` here presumably
# does not reach the layer. Keras also expects compile() to be called again
# after `trainable` changes. Confirm that this toggle has any effect.
starting_epochs=5
wedge.trainable = True
# First phase: no validation data and no early stopping
model_wedge.fit(train_gen, epochs=starting_epochs, verbose=2)
wedge.trainable = False
print('WedgeDropout2D disabled')
epochs = 100
# Resume epoch numbering at `starting_epochs` and train up to epoch `epochs`
model_wedge.fit(train_gen, initial_epoch=starting_epochs, epochs=epochs, validation_data=val_gen, callbacks=callbacks, verbose=2)
print(model_wedge.evaluate(val_gen))


WedgeDropout2D.build: input_shape: (None, 160, 160, 32)
Model: "functional_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 160, 160, 3) 0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 80, 80, 32)   896         input_1[0][0]                    
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 80, 80, 32)   128         conv2d[0][0]                     
__________________________________________________________________________________________________
activation (Activation)         (None, 80, 80, 32)   0           batch_normalization[0][0]        
_______________________________