# An image segmentation example
In this example, we’ll focus on **semantic segmentation**: we’ll be looking once again at images of cats and dogs, and this time we’ll learn how to tell apart the main subject and its background.

We’ll work with the [Oxford-IIIT Pets dataset](https://www.robots.ox.ac.uk/~vgg/data/pets/), which contains $7,390$ pictures of various breeds of cats and dogs, together with foreground-background _segmentation masks_ for each picture.

A **segmentation mask** is the image-segmentation equivalent of a label: it’s an image the same size as the input image, with a single color channel where each integer value corresponds to the class of the corresponding pixel in the input image. In our case, the pixels of our segmentation masks can take one of three integer values:
- 1 (foreground)
- 2 (background)
- 3 (contour)
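
To make this encoding concrete, here is a minimal sketch that uses a made-up 4 × 4 mask (the values are invented for illustration and are not taken from the dataset). It counts how many pixels fall into each of the three classes.

```python
import numpy as np

# Hypothetical 4x4 trimap, invented for illustration:
# 1 = foreground, 2 = background, 3 = contour.
toy_mask = np.array([
    [2, 2, 2, 2],
    [2, 3, 3, 2],
    [2, 3, 1, 3],
    [2, 2, 3, 2],
], dtype="uint8")

class_names = {1: "foreground", 2: "background", 3: "contour"}

# np.unique returns the distinct labels and how often each one occurs.
values, counts = np.unique(toy_mask, return_counts=True)
pixel_counts = {class_names[int(v)]: int(c) for v, c in zip(values, counts)}
print(pixel_counts)  # {'foreground': 1, 'background': 10, 'contour': 5}
```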

## Data preparation
### Downloading the dataset
You can use the link above to download both the images and the annotations archives. Alternatively, you can run the following snippet.

In [None]:
!wget http://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz
!wget http://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz
!tar -xf images.tar.gz
!tar -xf annotations.tar.gz

The input pictures are stored as **JPG** files in the ```images/``` folder, and the segmentation mask for each picture is stored as a **PNG** file with the same name in the ```annotations/trimaps/``` folder.

### Preparing the input paths
Let’s prepare the list of input file paths, as well as the list of the corresponding
mask file paths.

In [None]:
## LOCAL VERSION
import os
from constants import PET_PATH

input_dir = os.path.join(PET_PATH, "images")
target_dir = os.path.join(PET_PATH, "annotations", "trimaps")

input_img_paths = sorted(
    [os.path.join(input_dir, fname)
     for fname in os.listdir(input_dir)
     if fname.endswith(".jpg")])
target_paths = sorted(
    [os.path.join(target_dir, fname)
     for fname in os.listdir(target_dir)
     if fname.endswith(".png") and not fname.startswith(".")])

In [None]:
## COLAB VERSION
import os

input_dir = "images/"
target_dir = "annotations/trimaps/"

input_img_paths = sorted(
    [os.path.join(input_dir, fname)
     for fname in os.listdir(input_dir)
     if fname.endswith(".jpg")])
target_paths = sorted(
    [os.path.join(target_dir, fname)
     for fname in os.listdir(target_dir)
     if fname.endswith(".png") and not fname.startswith(".")])
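
Because the two lists are built independently, it can be worth checking that every image really lines up with the mask of the same name. The helper below is a sketch of such a check; `find_mismatched_pairs` and the example file names are hypothetical and only serve the illustration.

```python
import os

def find_mismatched_pairs(image_paths, mask_paths):
    """Return (image, mask) pairs whose base names (without extension) differ."""
    def stem(path):
        return os.path.splitext(os.path.basename(path))[0]
    return [(img, msk) for img, msk in zip(image_paths, mask_paths)
            if stem(img) != stem(msk)]

# Hypothetical file names, used only to illustrate the check.
example_images = ["images/cat_1.jpg", "images/dog_7.jpg"]
example_masks = ["annotations/trimaps/cat_1.png", "annotations/trimaps/dog_7.png"]
print(find_mismatched_pairs(example_images, example_masks))  # []
```

On the real lists, you would call `find_mismatched_pairs(input_img_paths, target_paths)` and also compare the two lengths, since `zip` silently stops at the shorter list.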

### Understanding the dataset
What does one of these inputs and its mask look like? Let’s take a quick look.

In [None]:
import matplotlib.pyplot as plt
from tensorflow.keras.utils import load_img, img_to_array

plt.axis("off")
plt.imshow(load_img(input_img_paths[9]))

In [None]:
def display_target(target_array):
    # The original labels are 1, 2 and 3. We subtract 1 so that the labels range from 0 to 2.
    normalized_array = (target_array.astype("uint8") - 1) * 127  # we multiply by 127 so that the labels become 0 (black), 127 (gray), 254 (near-white).
    plt.axis("off")
    plt.imshow(normalized_array[:, :, 0])

img = img_to_array(load_img(target_paths[9], color_mode="grayscale"))
display_target(img)
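
The arithmetic inside `display_target` can be checked in isolation. This short sketch applies it to the three possible label values.

```python
import numpy as np

labels = np.array([1, 2, 3], dtype="uint8")  # original trimap labels
gray_levels = (labels - 1) * 127             # same arithmetic as display_target
print(gray_levels.tolist())                  # [0, 127, 254]
```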

### Training-Validation split
Next, let’s load our inputs and targets into two NumPy arrays, and let’s split the arrays into a training and a validation set. Since the dataset is very small, we can just load everything into memory.

In [None]:
import numpy as np
import random

img_size = (200, 200)  # We resize everything to 200x200
num_imgs = len(input_img_paths)

# Shuffle the file paths (they were originally sorted by breed).
# We use the same seed (1337) for both shuffles so that the input paths and target paths stay in the same order.
random.Random(1337).shuffle(input_img_paths)
random.Random(1337).shuffle(target_paths)

def path_to_input_image(path):
    return img_to_array(load_img(path, target_size=img_size))

def path_to_target(path):
    img = img_to_array(
        load_img(path, target_size=img_size, color_mode="grayscale"))
    img = img.astype("uint8") - 1
    return img

input_imgs = np.zeros((num_imgs,) + img_size + (3,), dtype="float32")
targets = np.zeros((num_imgs,) + img_size + (1,), dtype="uint8")
for i in range(num_imgs):
    # Load all images into the input_imgs float32 array. The inputs have three channels (RGB values).
    input_imgs[i] = path_to_input_image(input_img_paths[i])
    # Load their masks in the targets uint8 array (same order). The targets have a single channel (which contains integer labels).
    targets[i] = path_to_target(target_paths[i])

num_val_samples = 1000  # We'll use 1,000 samples for validation
train_input_imgs = input_imgs[:-num_val_samples]
train_targets = targets[:-num_val_samples]
val_input_imgs = input_imgs[-num_val_samples:]
val_targets = targets[-num_val_samples:]
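
The seeded double shuffle is the step most likely to go wrong silently, so here is a small self-contained sketch of why it works. The file names are hypothetical stand-ins for the real path lists, and the trick relies on both lists having the same length.

```python
import random

# Hypothetical stand-ins for the image and mask path lists.
images = [f"img_{i}.jpg" for i in range(6)]
masks = [f"img_{i}.png" for i in range(6)]

# Two independent generators with the same seed apply the same permutation
# to lists of equal length.
random.Random(1337).shuffle(images)
random.Random(1337).shuffle(masks)

# Compare the names without their 4-character extensions.
still_paired = all(img[:-4] == msk[:-4] for img, msk in zip(images, masks))
print(still_paired)  # True

# Hold out the last two samples for validation, mirroring the slicing above.
num_val = 2
train_images, val_images = images[:-num_val], images[-num_val:]
print(len(train_images), len(val_images))  # 4 2
```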

## Model definition

In [None]:
from tensorflow import keras
from tensorflow.keras import layers

def get_model(img_size, num_classes):
    inputs = keras.Input(shape=img_size + (3,))
    x = layers.Rescaling(1./255)(inputs)

    x = layers.Conv2D(64, 3, strides=2, activation="relu", padding="same")(x)
    x = layers.Conv2D(64, 3, activation="relu", padding="same")(x)
    x = layers.Conv2D(128, 3, strides=2, activation="relu", padding="same")(x)
    x = layers.Conv2D(128, 3, activation="relu", padding="same")(x)
    x = layers.Conv2D(256, 3, strides=2, padding="same", activation="relu")(x)
    x = layers.Conv2D(256, 3, activation="relu", padding="same")(x)

    x = layers.Conv2DTranspose(256, 3, activation="relu", padding="same")(x)
    x = layers.Conv2DTranspose(256, 3, activation="relu", padding="same", strides=2)(x)
    x = layers.Conv2DTranspose(128, 3, activation="relu", padding="same")(x)
    x = layers.Conv2DTranspose(128, 3, activation="relu", padding="same", strides=2)(x)
    x = layers.Conv2DTranspose(64, 3, activation="relu", padding="same")(x)
    x = layers.Conv2DTranspose(64, 3, activation="relu", padding="same", strides=2)(x)

    outputs = layers.Conv2D(num_classes, 3, activation="softmax", padding="same")(x)

    model = keras.Model(inputs, outputs)
    return model

model = get_model(img_size=img_size, num_classes=3)
model.summary()

The first half of the model closely resembles the kind of convnet you’d use for image classification: a stack of ```Conv2D``` layers with a gradually increasing number of filters. We downsample our images three times by a factor of two each, ending up with activations of size $(25, 25, 256)$. The purpose of this first half is to encode the images into smaller feature maps, where each spatial location (or pixel) contains information about a large spatial chunk of the original image.
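
The spatial sizes can be verified with a few lines of arithmetic. With `padding="same"`, a strided convolution produces an output of size `ceil(input_size / stride)`; the helper name below is made up for this sketch.

```python
import math

def same_padding_output_size(size, stride):
    """Spatial size after a convolution with padding="same"."""
    return math.ceil(size / stride)

size = 200
sizes = [size]
for _ in range(3):  # the encoder contains three stride-2 convolutions
    size = same_padding_output_size(size, stride=2)
    sizes.append(size)
print(sizes)  # [200, 100, 50, 25]
```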

One important difference between the first half of this model and the classification models you’ve seen before is the way we do downsampling: in the classification convnets from the last notebook, we used ```MaxPooling2D``` layers to downsample feature maps. Here, we downsample by adding strides to every other convolution layer. We do this because, in the case of image segmentation, we care a lot about the spatial location of information in the image, since we need to produce per-pixel target masks as the output of the model. When you do $2 \times 2$ max pooling, you completely destroy location information within each pooling window: you return one scalar value per window, with zero knowledge of which of the four locations in the window the value came from.
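
The loss of location information is easy to see on a toy example. In this sketch (values invented for illustration), two pooling windows hold their maximum in opposite corners, yet max pooling maps both to the same scalar.

```python
import numpy as np

# Two 2x2 windows with the same maximum at different positions (toy values).
window_a = np.array([[9, 0],
                     [0, 0]])
window_b = np.array([[0, 0],
                     [0, 9]])

# 2x2 max pooling reduces each window to a single scalar.
pooled_a = int(window_a.max())
pooled_b = int(window_b.max())
print(pooled_a == pooled_b)  # True: the pooled outputs are indistinguishable

# The position of the maximum is exactly what gets discarded.
pos_a = tuple(int(i) for i in np.unravel_index(window_a.argmax(), window_a.shape))
pos_b = tuple(int(i) for i in np.unravel_index(window_b.argmax(), window_b.shape))
print(pos_a, pos_b)  # (0, 0) (1, 1)
```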

The second half of the model is a stack of ```Conv2DTranspose``` layers. _What are those?_

The output of the first half of the model is a feature map of shape $(25, 25, 256)$, but we want our final output to have the same spatial size as the target masks, with one channel per class: $(200, 200, 3)$. Therefore, we need to apply a kind of inverse of the transformations we’ve applied so far: something that will **upsample** the feature maps instead of downsampling them. That’s the purpose of the ```Conv2DTranspose``` layer.
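
To build some intuition for how a transposed convolution upsamples, here is a naive 1D version written in NumPy. It is a sketch of the general idea rather than what Keras does internally: `conv_transpose_1d` is a hypothetical helper, and it does not crop its output.

```python
import numpy as np

def conv_transpose_1d(x, kernel, stride):
    """Naive 1D transposed convolution without cropping.

    Each input value adds a scaled copy of the kernel to the output,
    and consecutive copies are offset by `stride` positions.
    """
    out = np.zeros((len(x) - 1) * stride + len(kernel))
    for i, value in enumerate(x):
        out[i * stride : i * stride + len(kernel)] += value * kernel
    return out

x = np.array([1.0, 2.0, 3.0, 4.0])
kernel = np.ones(3)
y = conv_transpose_1d(x, kernel, stride=2)
print(len(x), "->", len(y))  # 4 -> 9
```

With `padding="same"`, the Keras layer sizes its output to exactly `input_size * stride`, which is why three stride-2 ```Conv2DTranspose``` layers take the feature maps from 25 to 50, 100, and finally 200.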



## Training

In [None]:
model.compile(optimizer="rmsprop", loss="sparse_categorical_crossentropy")

callbacks = [
    keras.callbacks.ModelCheckpoint("oxford_segmentation.keras",
                                    save_best_only=True)
]

history = model.fit(train_input_imgs, train_targets,
                    epochs=50,
                    callbacks=callbacks,
                    batch_size=64,
                    validation_data=(val_input_imgs, val_targets))

In [None]:
epochs = range(1, len(history.history["loss"]) + 1)
loss = history.history["loss"]
val_loss = history.history["val_loss"]

plt.figure()
plt.plot(epochs, loss, "bo", label="Training loss")
plt.plot(epochs, val_loss, "b", label="Validation loss")
plt.title("Training and validation loss")
plt.legend()

We start overfitting midway, around epoch 25.
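
Rather than reading the turning point off the plot, you can locate the epoch with the lowest validation loss programmatically. The helper below is a sketch; `best_epoch` is a hypothetical name, and the loss values are made up for illustration and are not results from this model.

```python
def best_epoch(val_losses):
    """Return the 1-based epoch with the lowest validation loss."""
    best_index = min(range(len(val_losses)), key=val_losses.__getitem__)
    return best_index + 1

# Made-up values for illustration only.
example_val_loss = [0.90, 0.70, 0.55, 0.60, 0.72]
print(best_epoch(example_val_loss))  # 3
```

On the real run, `best_epoch(history.history["val_loss"])` gives the epoch whose weights `ModelCheckpoint` kept, since `save_best_only=True` monitors `val_loss` by default.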

## Making predictions
Let’s reload our best-performing model according to the validation loss, and demonstrate how to use it to predict a segmentation mask.

In [None]:
from tensorflow.keras.utils import array_to_img

model = keras.models.load_model("oxford_segmentation.keras")

i = 4
test_image = val_input_imgs[i]
plt.axis("off")
plt.imshow(array_to_img(test_image))

mask = model.predict(np.expand_dims(test_image, 0))[0]

def display_mask(pred):
    mask = np.argmax(pred, axis=-1)
    mask *= 127
    plt.axis("off")
    plt.imshow(mask)

display_mask(mask)
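
The key step in `display_mask` is the `argmax` over the last axis, which turns per-pixel class probabilities into a single class index per pixel. The sketch below shows this on a toy 2 × 2 prediction (values invented for illustration) and maps the indices back to the original 1, 2, 3 labels.

```python
import numpy as np

# Toy softmax output for a 2x2 image with 3 classes (invented values).
toy_pred = np.array([
    [[0.8, 0.1, 0.1], [0.2, 0.7, 0.1]],
    [[0.1, 0.2, 0.7], [0.6, 0.3, 0.1]],
])

# Pick the most probable class per pixel: shape (2, 2), values in {0, 1, 2}.
class_map = np.argmax(toy_pred, axis=-1)
print(class_map.tolist())  # [[0, 1], [2, 0]]

# Undo the "- 1" applied when the targets were loaded.
original_labels = class_map + 1
print(original_labels.tolist())  # [[1, 2], [3, 1]]
```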