# Image segmentation

In [1]:
import tensorflow as tf
import os
import pandas as pd
import numpy as np
from IPython.display import clear_output
import matplotlib.pyplot as plt
from datetime import datetime
from pathlib import Path

# Report the TensorFlow build and the GPUs it can see.
gpu_devices = tf.config.list_physical_devices('GPU')

print(f"TensorFlow Version: {tf.__version__}")

if not gpu_devices:
    print("❌ No GPU found. TensorFlow is using the CPU.")
else:
    print(f"✅ Found {len(gpu_devices)} GPU(s):")
    print("\n".join(f"  - {device}" for device in gpu_devices))

# Run timestamp, formatted dd-mm-YYYY-HH:MM:SS
dt_string = f"{datetime.now():%d-%m-%Y-%H:%M:%S}"
print("date and time =", dt_string)

TensorFlow Version: 2.16.2
❌ No GPU found. TensorFlow is using the CPU.
date and time = 10-10-2025-14:59:33


## Create a Dataset from the folder of train/val images

And make a directory to save outputs on Google Drive

In [None]:
from google.colab import drive
drive.mount('/content/drive')

# Dataset location on the mounted Drive.
code_root = Path('/content/drive/MyDrive/10 - code/')
ds_path = code_root / 'Semantic-Segmentation-Suite/data/210221-colab-test-1'
print("Dataset folder:", ds_path.as_posix(), sep="\n")

# One output folder per calendar day, with sub-folders for figures and weights.
outputs_path = code_root / ('Colab Data Xfer/{}'.format(datetime.now().strftime("%Y-%m-%d")))
displays_path = outputs_path / 'displays'
checkpoints_path = outputs_path / 'checkpoints'
for folder in (outputs_path, displays_path, checkpoints_path):
    folder.mkdir(parents=True, exist_ok=True)
print("Outputs folder:", outputs_path.as_posix(), sep="\n")

In [None]:
# Datasets of file paths: every PNG directly inside train/ and val/.
list_ds = tf.data.Dataset.list_files((ds_path / 'train' / '*.png').as_posix())
test_list_ds = tf.data.Dataset.list_files((ds_path / 'val' / '*.png').as_posix())
list_ds

In [None]:
# Read the class-name -> RGB colour table stored next to the dataset.
# The path is derived from ds_path so that moving the dataset only requires
# changing ds_path in one place.
df_classes = pd.read_csv(ds_path.joinpath('class_dict.csv').as_posix())
class_dict = {
    row['name']: np.array((row['r'], row['g'], row['b']))
    for _, row in df_classes.iterrows()
}
class_dict

In [None]:
# Print one listed path to confirm that the glob matched files.
for sample_path in list_ds.take(1):
  print(sample_path.numpy())

In [None]:
# Size that every image and mask is resized to (passed to tf.image.resize)
# and that the model input is built for (passed to get_model).
INPUT_SIZE = (512, 512)

In [None]:
def mask_rgb_to_int(A):
  """Convert an RGB colour-coded mask into an integer class-index mask.

  Args:
    A: float image of shape (h, w, 3) with values in [0, 1]. May be an eager
      tensor (any object exposing `.numpy()`) or an array-like.

  Returns:
    uint8 array M of shape (h, w, 1). M[y, x, 0] is the position of the
    pixel's colour in the colour table below (0 is 'background'). Pixels whose
    colour matches no entry keep the initial value 0.
  """
  # Colour table; its insertion order defines the integer label of each class.
  # It is defined locally and shadows any notebook-level `class_dict`.
  class_dict = {
    'background': np.array([255, 255, 212]),
    # 'error_bars': np.array([54, 55, 55]),
    'markers': np.array([  3,  67, 223]),
    'x_tick_labels': np.array([229,   0,   0]),
    'x_ticks': np.array([132,   0,   0]),
    'y_tick_labels': np.array([191, 119, 246]),
    'y_ticks': np.array([154,  14, 234])
  }

  A_float = A.numpy() if hasattr(A, 'numpy') else np.asarray(A)
  # Round instead of truncating: v/255*255 evaluated in float32 can land just
  # below v, and a bare uint8 cast would turn it into v-1, so the pixel would
  # match no class and silently fall back to 0.
  A_png = np.rint(A_float * 255).astype(np.uint8)
  M = np.zeros((A_png.shape[0], A_png.shape[1], 1), dtype=np.uint8)

  for ii, rgb in enumerate(class_dict.values()):
    # Boolean (h, w) map of the pixels whose three channels all equal `rgb`.
    M_i = np.all(A_png == rgb, axis=2)
    M[M_i] = ii

  return M

def tf_mask_rgb_to_int(image):
  """Run `mask_rgb_to_int` on a tensor via `tf.py_function`.

  Returns a uint8 tensor whose static shape is set to
  (image.shape[0], image.shape[1], 1).
  """
  height, width = image.shape[0], image.shape[1]
  outputs = tf.py_function(mask_rgb_to_int, [image], [tf.uint8])
  mask = outputs[0]
  # Attach the static shape explicitly to the py_function result.
  mask.set_shape((height, width, 1))
  return mask

In [None]:
def parse_image(filepath):
  """Load one example (image plus masks) from the path of its input PNG.

  Returns a dict with:
    'image_float': the PNG decoded to 3 channels, converted to float32 and
                   resized to INPUT_SIZE with nearest-neighbour sampling
    'mask_float':  result of get_segmentation_mask(filepath)
    'mask_int':    result of tf_mask_rgb_to_int on 'mask_float'
    'label':       second-to-last '.'-separated part of the file name
  """
  path_parts = tf.strings.split(filepath, os.sep)
  name_parts = tf.strings.split(path_parts[-1], '.')
  label = name_parts[-2]

  image_float = tf.image.resize(
      tf.image.convert_image_dtype(
          tf.image.decode_png(tf.io.read_file(filepath), channels=3),
          tf.float32),
      INPUT_SIZE, method='nearest')

  mask_float = get_segmentation_mask(filepath)
  return {
      'image_float': image_float,
      'mask_float': mask_float,
      'mask_int': tf_mask_rgb_to_int(mask_float),
      'label': label,
  }

def get_segmentation_mask(image_filepath):
  """Load the RGB mask that belongs to an input image.

  The mask is expected under the same file name in a sibling folder:
  `.../train/x.png` -> `.../train_labels/x.png` and
  `.../val/x.png` -> `.../val_labels/x.png`.

  Args:
    image_filepath: string (or string tensor) path of the input image.

  Returns:
    float32 tensor: the mask PNG decoded to 3 channels, converted to float and
    resized to INPUT_SIZE with nearest-neighbour sampling.
  """
  # Rewrite only the directory that directly contains the file. An unanchored
  # replace of 'train' / 'val' would also touch any other part of the path (or
  # the file name) that happens to contain those letters.
  mask_filepath = tf.strings.regex_replace(
      image_filepath,
      r'(^|[/\\])(train|val)([/\\][^/\\]*)$',
      r'\1\2_labels\3')

  mask_raw = tf.io.read_file(mask_filepath)
  mask_png = tf.image.decode_png(mask_raw, channels=3)
  mask_float = tf.image.convert_image_dtype(mask_png, tf.float32)
  mask_float = tf.image.resize(mask_float, INPUT_SIZE, method='nearest')
  return mask_float

def show(image, label):
  """Draw `image` in a new figure, titled with the UTF-8 decoded `label`, axes hidden."""
  plt.figure()
  plt.imshow(image)
  caption = label.numpy().decode('utf-8')
  plt.title(caption)
  plt.axis('off')

def display_init(display_list, fig_size=5):
  """Show up to three arrays side by side: input image, RGB mask, integer mask.

  Each panel is `fig_size` inches square; the figure is left open as the
  current matplotlib figure.
  """
  title = ['Input Image', 'True Mask RGB', 'True Mask int']
  n_panels = len(display_list)

  plt.figure(figsize=(n_panels * fig_size, fig_size), tight_layout=True)

  for position, panel in enumerate(display_list, start=1):
    plt.subplot(1, n_panels, position)
    plt.title(title[position - 1])
    plt.imshow(tf.keras.preprocessing.image.array_to_img(panel))
    plt.axis('off')

In [None]:
# Parse one training file eagerly and plot its image, RGB mask and integer mask.
for sample_path in list_ds.take(1):
  dp = parse_image(sample_path.numpy())
  display_init([dp[key] for key in ('image_float', 'mask_float', 'mask_int')])

# Save the figure that display_init left open.
mask_test_name = 'mask_test_{}.png'.format(datetime.now().strftime("%Y-%m-%d-%H:%M:%S"))
plt.savefig(displays_path.joinpath(mask_test_name).as_posix(), dpi=300)

In [None]:
# Distinct class indices present in the sample mask (check of the RGB -> int conversion).
np.unique(dp['mask_int'].numpy())

Debugging stuff

In [None]:
# pd.DataFrame(dp['mask_int'].numpy()[:,:,0]).to_csv('array.csv')

In [None]:
# print((mask_int_test.numpy()==5).sum())
# print(np.all(mask_rgb_test.numpy()==class_dict['y_tick_labels'],axis=2).sum())

The following cells parse every listed file into a datapoint and apply a simple augmentation: during training, the image and its mask are randomly flipped left-to-right. No further normalization is needed, because `parse_image` already converts the images to floats in [0, 1]. The masks need no label shifting either: `mask_rgb_to_int` already encodes them as integer class labels {0, 1, 2, 3, 4, 5}.

In [None]:
# Map every listed path to a datapoint dict (see parse_image).
train_img_ds = list_ds.map(parse_image)
test_img_ds = test_list_ds.map(parse_image)

In [None]:
@tf.function
def load_image_train(datapoint):
  """Return (image, integer mask) of a datapoint, mirrored left-right at random.

  A single uniform draw decides the flip, so image and mask are always
  flipped together.
  """
  image, mask = datapoint['image_float'], datapoint['mask_int']

  flip = tf.random.uniform(()) > 0.5
  if flip:
    image, mask = tf.image.flip_left_right(image), tf.image.flip_left_right(mask)

  return image, mask

In [None]:
def load_image_test(datapoint):
  """Return the (image, integer mask) pair of a datapoint unchanged (no augmentation)."""
  return datapoint['image_float'], datapoint['mask_int']

The dataset already contains the required splits of test and train and so let's continue to use the same split.

In [None]:
# NOTE(review): TRAIN_LENGTH is hard-coded, not measured from list_ds — confirm it matches the dataset.
TRAIN_LENGTH = 1000  # nominal number of training examples, used only to derive STEPS_PER_EPOCH
BATCH_SIZE = 16  # examples per batch, for both training and validation
BUFFER_SIZE = 1000  # size of the shuffle buffer
STEPS_PER_EPOCH = TRAIN_LENGTH // BATCH_SIZE  # batches drawn from the repeating training set per epoch

In [None]:
# Reduce datapoint dicts to (image, mask_int) pairs; only the training pairs get the random flip.
train = train_img_ds.map(load_image_train, num_parallel_calls=tf.data.AUTOTUNE)
test = test_img_ds.map(load_image_test)

In [None]:
# Build the training pipeline directly from the parsed datapoints so that the
# cache sits BEFORE the random flip. Caching after the flip would store the
# flip decisions made on the first pass and replay them on every later epoch,
# which turns the augmentation into a fixed transformation.
train_dataset = (
    train_img_ds
    .map(load_image_test)  # keep only (image, mask_int), so the cache holds no more than before
    .cache()
    .map(lambda image, mask: load_image_train({'image_float': image, 'mask_int': mask}),
         num_parallel_calls=tf.data.AUTOTUNE)
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
    .repeat()
    .prefetch(buffer_size=tf.data.AUTOTUNE)
)
test_dataset = test.batch(BATCH_SIZE)

Let's take a look at an example image and its corresponding mask from the dataset.

In [None]:
def display(display_list, fig_size=5, show=True):
  """Plot up to three arrays side by side and save the figure as a PNG.

  Args:
    display_list: arrays in the order input image, true mask, predicted mask
      (at most three entries, one title each).
    fig_size: side length of each panel in inches.
    show: if False, the figure is closed after saving instead of being left
      open for inline display.

  The file is written to `displays_path` as `display_<timestamp>.png`.
  """
  fig = plt.figure(figsize=(len(display_list) * fig_size, fig_size))

  title = ['Input Image', 'True Mask', 'Predicted Mask']

  for i, panel in enumerate(display_list):
    plt.subplot(1, len(display_list), i+1)
    plt.title(title[i])
    plt.imshow(tf.keras.preprocessing.image.array_to_img(panel))
    plt.axis('off')

  # Include microseconds (%f): this function is called several times in a row
  # (e.g. from show_predictions), and a one-second resolution would make those
  # calls overwrite each other's file.
  stamp = datetime.now().strftime("%Y-%m-%d-%H:%M:%S.%f")
  fig.savefig(
      displays_path.joinpath('display_{}.png'.format(stamp)).as_posix(),
      dpi=300
      )

  if not show:
    plt.close(fig)

In [None]:
# Preview two training pairs. The last pair stays bound to sample_image /
# sample_mask, which show_predictions reads when it is called without a dataset.
for sample_image, sample_mask in train.take(2):
  display([sample_image, sample_mask]);

In [None]:
# Preview two validation pairs. The last pair replaces sample_image /
# sample_mask, which show_predictions reads when it is called without a dataset.
for sample_image, sample_mask in test.take(2):
  display([sample_image, sample_mask]);

## Define the model
The model being used here is a U-Net-like fully convolutional network, defined from scratch in `get_model` below. It consists of an encoder (downsampler) and a decoder (upsampler). The encoder is a strided convolution followed by three blocks of separable convolutions with max-pooling, each with a projected residual connection. The decoder mirrors it with four blocks of transposed convolutions and upsampling. No pretrained weights are used, so every layer is trained.

The reason to output six channels is that there are six possible labels for each pixel. Think of this as multi-class classification where each pixel is assigned to one of six classes.

In [None]:
# Number of classes scored per pixel (passed to get_model as num_classes);
# matches the six active entries of the colour table in mask_rgb_to_int.
OUTPUT_CHANNELS = 6

The whole network, encoder and decoder alike, is assembled from standard [tf.keras layers](https://www.tensorflow.org/api_docs/python/tf/keras/layers) rather than taken from a pretrained model. Note that, because no pretrained weights are used, the encoder is trained together with the decoder during the training process.

In [None]:
from tensorflow.keras import layers


def get_model(img_size, num_classes):
    """Build a U-Net-like fully convolutional segmentation network.

    The entry convolution and each of the three downsampling blocks use a
    stride of 2; each of the four upsampling blocks upsamples by 2. For input
    sizes divisible by 16 (such as 512) the output therefore has the same
    height and width as the input.

    Args:
        img_size: (height, width) tuple. It is concatenated with ``(3,)`` to
            form the input shape, so it must be a tuple.
        num_classes: number of filters of the final convolution, i.e. the
            number of classes scored per pixel.

    Returns:
        A ``tf.keras.Model`` whose output has ``num_classes`` channels with a
        softmax applied, i.e. per-pixel class probabilities (not logits).
    """
    inputs = tf.keras.Input(shape=img_size + (3,))

    ### [First half of the network: downsampling inputs] ###

    # Entry block (stride 2: halves the resolution)
    x = layers.Conv2D(32, 3, strides=2, padding="same")(inputs)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)

    previous_block_activation = x  # Set aside residual

    # Blocks 1, 2, 3 are identical apart from the feature depth.
    for filters in [64, 128, 256]:
        x = layers.Activation("relu")(x)
        x = layers.SeparableConv2D(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.Activation("relu")(x)
        x = layers.SeparableConv2D(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        # Stride-2 pooling halves the resolution of the main branch
        x = layers.MaxPooling2D(3, strides=2, padding="same")(x)

        # Project residual: a strided 1x1 convolution brings the block input
        # to the same resolution and depth as x so the two can be added
        residual = layers.Conv2D(filters, 1, strides=2, padding="same")(
            previous_block_activation
        )
        x = layers.add([x, residual])  # Add back residual
        previous_block_activation = x  # Set aside next residual

    ### [Second half of the network: upsampling inputs] ###

    for filters in [256, 128, 64, 32]:
        x = layers.Activation("relu")(x)
        x = layers.Conv2DTranspose(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.Activation("relu")(x)
        x = layers.Conv2DTranspose(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        # Double the resolution of the main branch
        x = layers.UpSampling2D(2)(x)

        # Project residual: upsample the block input and match its depth
        # with a 1x1 convolution before adding
        residual = layers.UpSampling2D(2)(previous_block_activation)
        residual = layers.Conv2D(filters, 1, padding="same")(residual)
        x = layers.add([x, residual])  # Add back residual
        previous_block_activation = x  # Set aside next residual

    # Add a per-pixel classification layer (softmax over num_classes channels)
    outputs = layers.Conv2D(num_classes, 3, activation="softmax", padding="same")(x)

    # Define the model
    model = tf.keras.Model(inputs, outputs)
    return model


# Free up RAM in case the model definition cells were run multiple times
tf.keras.backend.clear_session()

# Build the network for INPUT_SIZE inputs and OUTPUT_CHANNELS classes, then print its layer summary
model = get_model(INPUT_SIZE, OUTPUT_CHANNELS)
model.summary()


## Train the model
Now, all that is left to do is to compile and train the model. The loss being used here is `tf.keras.losses.SparseCategoricalCrossentropy`. The reason to use this loss function is that the network is trying to assign each pixel a label, just like multi-class prediction. In the true segmentation mask, each pixel has an integer label in {0, 1, 2, 3, 4, 5}. The network here outputs six channels. Essentially, each channel is trying to learn to predict a class, and sparse categorical cross-entropy is the recommended loss for such a scenario. Because the final layer of the model already applies a softmax, the loss receives probabilities rather than logits and should be configured with `from_logits=False`.
Using the output of the network, the label assigned to the pixel is the channel with the highest value. This is what the `create_mask` function does.

In [None]:
# get_model ends in a softmax layer, so the model outputs probabilities, not
# logits. The loss must therefore be told not to apply another softmax;
# from_logits=True would compute the cross-entropy on doubly-normalised values.
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['accuracy'])

Have a quick look at the resulting model architecture:

In [None]:
# tf.keras.utils.plot_model(model, show_shapes=True)

Let's try out the model to see what it predicts before training.

In [None]:
def create_mask(pred_mask):
  """Turn a batch of per-class scores into the class-index mask of its FIRST element.

  Takes the argmax over the last axis, appends a trailing axis of length 1 and
  returns entry 0 of the batch.
  """
  class_ids = tf.argmax(pred_mask, axis=-1)
  return tf.expand_dims(class_ids, -1)[0]

In [None]:
def show_predictions(model, dataset=None, num=1, show=True):
  """Display (and save, via `display`) model predictions next to the ground truth.

  Args:
    model: model whose `predict` method is called.
    dataset: batched dataset of (image, mask) pairs. For each of the first
      `num` batches, the first image of the batch is shown. If None, the
      notebook-level `sample_image` / `sample_mask` are used instead.
    num: number of batches to take from `dataset`.
    show: forwarded to `display`; False closes each figure after saving.
  """
  # Explicit None check: whether a dataset was passed should not depend on
  # how the Dataset object evaluates in a boolean context.
  if dataset is not None:
    for image, mask in dataset.take(num):
      pred_mask = model.predict(image)
      display([image[0], mask[0], create_mask(pred_mask)], show=show)
  else:
    display([sample_image, sample_mask,
             create_mask(model.predict(sample_image[tf.newaxis, ...]))], show=show)

In [None]:
# Show the first image of two validation batches together with the model's current prediction.
show_predictions(model, dataset=test_dataset, num=2)

Let's observe how the model improves while it is training. To accomplish this task, a callback function is defined below. Let's also save model weights while training to enable re-starting training...

In [None]:
class DisplayCallback(tf.keras.callbacks.Callback):
  """Keras callback that saves sample predictions after every epoch."""

  def on_epoch_end(self, epoch, logs=None):
    """Clear the cell output, save predictions for three validation batches, print the epoch number."""
    clear_output(wait=True)
    # self.model is the model this callback is attached to. Using it instead
    # of the notebook-level `model` keeps the callback correct when it is
    # passed to a different model's fit().
    show_predictions(self.model, dataset=test_dataset, num=3, show=False)
    print ('\nSample Prediction after epoch {}\n'.format(epoch+1))


# Keras 3 (bundled with TensorFlow >= 2.16) rejects weights-only checkpoints
# whose path does not end in ".weights.h5". Older Keras versions accept this
# name too and write an HDF5 file.
checkpoint_path = checkpoints_path.joinpath("cp.weights.h5").as_posix()
checkpoint_dir = os.path.dirname(checkpoint_path)
print("Checkpoints path:")
print(checkpoint_path)

# Create a callback that saves the model's weights
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                 save_weights_only=True,
                                                 verbose=1)

Ability to Load from Checkpoint here

In [None]:
# Resume from saved weights when a checkpoint is available; otherwise keep
# the current weights so that "Run All" also works on a fresh start.
#
# old_checkpoint_path points at an earlier run. Output folders are one per
# day and siblings of each other, so it is resolved against the PARENT of
# today's outputs_path (and must be a plain string, not a 1-tuple).
old_checkpoint_path = outputs_path.parent.joinpath("2023-12-18/checkpoints/cp.ckpt").as_posix()

# Today's checkpoint takes precedence; the older run is the fallback.
for candidate in (checkpoint_path, old_checkpoint_path):
  # An HDF5 weights file exists under its own name; a TensorFlow-format
  # checkpoint "x.ckpt" is stored as "x.ckpt.index" plus data shards.
  if os.path.exists(candidate) or os.path.exists(candidate + ".index"):
    model.load_weights(candidate)
    print("Weights loaded from checkpoint")
    print(candidate)
    break
else:
  print("No checkpoint found; continuing with the current model weights")

In [None]:
EPOCHS = 100
VAL_SUBSPLITS = 2
# Number of validation batches per epoch: the full-batch count divided by VAL_SUBSPLITS.
VALIDATION_STEPS = (len(test_img_ds) // BATCH_SIZE) // VAL_SUBSPLITS

fit_callbacks = [DisplayCallback(), cp_callback]
model_history = model.fit(
    train_dataset,
    validation_data=test_dataset,
    epochs=EPOCHS,
    steps_per_epoch=STEPS_PER_EPOCH,
    validation_steps=VALIDATION_STEPS,
    callbacks=fit_callbacks,
)

In [None]:
# Per-epoch loss curves recorded by model.fit.
loss = model_history.history['loss']
val_loss = model_history.history['val_loss']

epochs = range(EPOCHS)

fig, ax = plt.subplots()
ax.plot(epochs, loss, 'r', label='Training loss')
ax.plot(epochs, val_loss, 'bo', label='Validation loss')
ax.set_title('Training and Validation Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss Value')
ax.set_ylim([0, 1])
ax.legend()
plt.show()

## Make predictions

Let's make some predictions. In the interest of saving time, the number of epochs was kept small, but you may set this higher to achieve more accurate results.

In [None]:
def display_new_prediction(input_image, pred_mask, fig_size=5, show=True):
  """Plot an input image next to its predicted mask and save the figure.

  The PNG is written to `displays_path` as `prediction_<timestamp>.png`.
  With show=False the figure is closed after saving.
  """
  panels = (('Input Image', input_image), ('Predicted Mask', pred_mask))

  plt.figure(figsize=(2 * fig_size, fig_size))
  for column, (caption, picture) in enumerate(panels, start=1):
    plt.subplot(1, 2, column)
    plt.title(caption)
    plt.imshow(tf.keras.preprocessing.image.array_to_img(picture))
    plt.axis('off')

  out_name = 'prediction_{}.png'.format(datetime.now().strftime("%Y-%m-%d-%H:%M:%S"))
  plt.savefig(displays_path.joinpath(out_name).as_posix(), dpi=300)

  if show:
    return
  plt.close();


def load_image_to_tensor(filepath, format="jpg"):
  """Read an image file into a model-ready batch of one.

  Args:
    filepath: path of the image file.
    format: "png", "jpg" or "jpeg" (case-insensitive, a leading dot is ignored).

  Returns:
    float32 tensor with a leading batch axis of length 1: the image decoded to
    3 channels, converted to float and resized to INPUT_SIZE with
    nearest-neighbour sampling.

  Raises:
    ValueError: if `format` is not one of the supported formats. (Printing a
      message and returning None only postponed the failure to the caller.)
  """
  decoders = {
      "png": tf.image.decode_png,
      "jpg": tf.image.decode_jpeg,
      "jpeg": tf.image.decode_jpeg,
  }
  format_key = str(format).lower().lstrip(".")
  if format_key not in decoders:
    raise ValueError(
        "invalid input format {!r}; expected one of {}".format(format, sorted(decoders)))

  image_raw = tf.io.read_file(filepath)
  image_decode = decoders[format_key](image_raw, channels=3)
  image_float = tf.image.convert_image_dtype(image_decode, tf.float32)
  image_resize = tf.image.resize(image_float, INPUT_SIZE, method='nearest')
  image_4d = tf.expand_dims(image_resize, 0)
  return image_4d


def predict_from_file(model, filepath, format="jpg"):
  """Predict the segmentation of one image file, then display and save the result.

  Args:
    model: model whose `predict` method is called.
    filepath: path of the image file.
    format: image format, forwarded to `load_image_to_tensor`.

  Returns:
    The raw output of `model.predict` for the single-image batch.
  """
  # Forward the caller's format instead of always decoding as JPEG.
  image_resize = load_image_to_tensor(filepath, format=format)
  print("Input image size:", image_resize.shape)
  pred_mask = model.predict(image_resize)
  display_new_prediction(tf.squeeze(image_resize), create_mask(pred_mask))
  return pred_mask

In [None]:
# Run the trained model on one image from the real-world examples folder.
test_data_dir = code_root.joinpath('Semantic-Segmentation-Suite/data/real-world-examples')
test_filename = "example101.jpg"
test_filepath = test_data_dir.joinpath(test_filename).as_posix()
print("Predicting mask for:", test_filepath)
# predict_from_file loads the image itself, shows and saves the figure, and returns the raw model output.
predict_from_file(model, test_filepath, format="jpg")

## Next steps
Now that you have an understanding of what image segmentation is and how it works, you can try this tutorial out with different intermediate layer outputs, or even a different pretrained model. You may also challenge yourself by trying out the [Carvana](https://www.kaggle.com/c/carvana-image-masking-challenge/overview) image masking challenge hosted on Kaggle.

You may also want to see the [Tensorflow Object Detection API](https://github.com/tensorflow/models/tree/master/research/object_detection) for another model you can retrain on your own data.