# Chapter 8: Telling things apart: Image segmentation

This notebook reproduces the code and summarizes the theoretical concepts from Chapter 8 of *'TensorFlow in Action'* by Thushan Ganegedara.

This chapter moves from image *classification* (one label per image) to **image segmentation** (one label per pixel). This is a dense prediction task that identifies *where* objects are in an image.

We will cover:
1.  **Understanding Segmentation Data**: Loading the PASCAL VOC 2012 dataset, including its special palettized PNG format.
2.  **Building a `tf.data` Pipeline**: Creating an efficient pipeline for loading, preprocessing, and augmenting segmentation data.
3.  **Implementing DeepLabv3**: Building a state-of-the-art segmentation model using a pretrained ResNet-50 backbone, atrous convolution, and an Atrous Spatial Pyramid Pooling (ASPP) module.
4.  **Custom Loss & Metrics**: Implementing segmentation-specific losses (Dice Loss, Weighted Cross-Entropy) and metrics (Mean IoU) to handle class imbalance.
5.  **Training & Evaluation**: Training the model and evaluating its performance.

---

## 8.1 Understanding the Data (PASCAL VOC 2012)

We will use the **PASCAL VOC 2012** dataset. In segmentation, the data consists of pairs:
1.  **Input Image**: A standard RGB image (e.g., `[Height, Width, 3]`).
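2.  **Target Mask**: A special "palettized" PNG. When viewed, it looks like a colored-in version of the input; when loaded, it is a 2D array (`[Height, Width]`) in which each pixel's value is an **integer class index** (0=Background, 1=Aeroplane, ..., 12=Dog, ..., 20=TV/monitor). Pixels on object boundaries and in ambiguous regions carry the special value 255 ("void"), which must be ignored by the loss and the metrics.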

In [None]:
import os
import random
import tarfile
from functools import partial

import numpy as np
import pandas as pd
import requests
import tensorflow as tf
import tensorflow.keras.backend as K
from PIL import Image
from tensorflow.keras import layers, models

# 1. Download the data
data_url = "http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar"
tar_path = os.path.join('data', 'VOCtrainval_11-May-2012.tar')
# The archive's top-level folder is VOCdevkit/, so extracting INTO extract_path
# yields extract_path/VOCdevkit/VOC2012/..., which is what the paths below expect.
extract_path = os.path.join('data', 'VOCtrainval_11-May-2012')
img_dir = os.path.join(extract_path, 'VOCdevkit', 'VOC2012', 'JPEGImages')
seg_dir = os.path.join(extract_path, 'VOCdevkit', 'VOC2012', 'SegmentationClass')
subset_dir = os.path.join(extract_path, 'VOCdevkit', 'VOC2012', 'ImageSets', 'Segmentation')

if not os.path.exists(extract_path):
    os.makedirs('data', exist_ok=True)
    if not os.path.exists(tar_path):
        print("Downloading PASCAL VOC 2012 dataset (approx. 2GB)...")
        # Stream to a temporary file in 1 MiB chunks (never hold ~2 GB in memory) and
        # only rename it to tar_path once complete, so an interrupted download is
        # not mistaken for a finished one on the next run.
        partial_tar_path = tar_path + '.part'
        with requests.get(data_url, stream=True, timeout=60) as r:
            r.raise_for_status()
            with open(partial_tar_path, 'wb') as f:
                for chunk in r.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
        os.replace(partial_tar_path, tar_path)
        print("Download complete.")

    print("Extracting data...")
    # Extract into a sibling directory and rename it when done, so a partially
    # extracted dataset never satisfies the os.path.exists(extract_path) check.
    tmp_extract_path = extract_path + '.partial'
    with tarfile.open(tar_path, 'r') as tar:
        if hasattr(tarfile, 'data_filter'):
            # Reject absolute paths / '..' members (Python 3.12+ and security backports).
            tar.extractall(tmp_extract_path, filter='data')
        else:
            tar.extractall(tmp_extract_path)
    os.replace(tmp_extract_path, extract_path)
    print("Extraction complete.")
else:
    print("Data already downloaded and extracted.")

# 2. Function to read palettized target images
# These PNGs store class indices, not RGB colors.
# We need this function to load them correctly as 2D NumPy arrays.
def load_image_func(image_path):
    """Load an image file and return its raw pixel values as a NumPy array.

    For palettized ("P" mode) PNGs such as the VOC segmentation masks, the
    returned values are the stored palette indices (class ids), not RGB colours,
    so the result is a 2-D array.

    Args:
        image_path: Path to the image file (str or bytes).

    Returns:
        np.ndarray with the image's pixel values.
    """
    # The context manager closes the file handle deterministically instead of
    # relying on Pillow's lazy-load bookkeeping or garbage collection.
    with Image.open(image_path) as img:
        return np.array(img)

print("\nSetup complete. Ready to build data pipeline.")

---

## 8.2 Defining a TensorFlow data pipeline

We need a robust `tf.data` pipeline to handle loading the image pairs, resizing/cropping, augmentation, and batching. This is the most complex data pipeline we've built so far.

In [None]:
# Set global parameters
random_seed = 42  # seeds the deterministic val/test split of val.txt
batch_size = 16  # Smaller batch size for segmentation models
input_size = (384, 384)  # (height, width) of network inputs
output_size = None  # None keeps targets at input_size, matching the model's upsampled logits
epochs = 25
num_classes = 21  # 20 PASCAL VOC object classes + 1 background; void pixels (255) are ignored

# 1. Get filenames for a given subset (train/val/test)
def _read_id_list(list_path):
    """Return the non-empty, whitespace-stripped lines (image ids) of a list file."""
    with open(list_path) as fh:
        return [line.strip() for line in fh if line.strip()]


def get_subset_filenames(orig_dir, seg_dir, subset_dir, subset, seed=None):
    """Yield (image_path, mask_path) pairs for one data subset.

    Subsets starting with 'train' use ``train.txt``. Subsets starting with 'val'
    or 'test' both use ``val.txt``: the id list is shuffled with a fixed seed and
    split in half, the first half going to 'val' and the second half to 'test',
    so the two are disjoint and reproducible.

    Args:
        orig_dir: Directory holding the ``<id>.jpg`` input images.
        seg_dir: Directory holding the ``<id>.png`` target masks.
        subset_dir: Directory holding ``train.txt`` / ``val.txt`` (one id per line).
        subset: Subset name; must start with 'train', 'val' or 'test'.
        seed: Seed for the val/test shuffle; defaults to the global ``random_seed``.

    Yields:
        Tuples ``(orig_dir/<id>.jpg, seg_dir/<id>.png)``.

    Raises:
        NotImplementedError: If ``subset`` is not recognized (raised when the
            generator is first advanced).
    """
    if subset.startswith('train'):
        file_list = _read_id_list(os.path.join(subset_dir, "train.txt"))
    elif subset.startswith(('val', 'test')):
        file_list = _read_id_list(os.path.join(subset_dir, "val.txt"))

        # Split the 'val.txt' list into validation and test sets. A private RNG
        # gives the same order as seeding the global one, without mutating
        # global `random` state for the rest of the notebook.
        rng = random.Random(random_seed if seed is None else seed)
        rng.shuffle(file_list)
        split_idx = len(file_list) // 2
        file_list = file_list[:split_idx] if subset.startswith('val') else file_list[split_idx:]
    else:
        raise NotImplementedError("Subset={} is not recognized".format(subset))

    for f in file_list:
        yield os.path.join(orig_dir, f + '.jpg'), os.path.join(seg_dir, f + '.png')

# 2. Helper functions for resizing/cropping and augmentations
def randomly_crop_or_resize(x, y, resize_to_before_crop, input_size, augmentation):
    """Resize an (image, mask) pair to ``input_size``, optionally via a random crop.

    If ``augmentation`` is True and ``resize_to_before_crop`` is at least
    ``input_size`` in both dimensions (and not equal to it), then with
    probability 0.5 the pair is resized to ``resize_to_before_crop`` and randomly
    cropped to ``input_size``; the SAME window is cut from the image and the mask.
    Otherwise the pair is resized directly to ``input_size``.

    Args:
        x: Image tensor of shape [H, W, C].
        y: Mask tensor of shape [H, W] holding class indices.
        resize_to_before_crop: (height, width) to resize to before cropping, or None.
        input_size: (height, width) of the returned pair.
        augmentation: Whether random cropping may be applied.

    Returns:
        (x, y) with shapes [input_size[0], input_size[1], C] and
        [input_size[0], input_size[1]]; y keeps its dtype.
    """
    def resize(x, y):
        x = tf.image.resize(x, input_size, method='bilinear')
        # y is [H, W]: add a channel dim for tf.image.resize, then drop it again.
        # 'nearest' keeps mask values as valid class indices (no interpolation).
        y = tf.image.resize(y[..., tf.newaxis], input_size, method='nearest')
        return x, tf.squeeze(y, axis=-1)

    def rand_crop(x, y):
        x = tf.image.resize(x, resize_to_before_crop, method='bilinear')
        y = tf.image.resize(y[..., tf.newaxis], resize_to_before_crop, method='nearest')
        # Stack image and mask along the channel axis so a single random_crop call
        # cuts the identical window out of both.
        n_channels = x.shape[-1] if x.shape[-1] is not None else tf.shape(x)[-1]
        xy = tf.concat([x, tf.cast(y, x.dtype)], axis=-1)
        xy = tf.image.random_crop(xy, [input_size[0], input_size[1], n_channels + 1])
        return xy[..., :-1], tf.cast(xy[..., -1], y.dtype)

    can_crop = (
        augmentation
        and resize_to_before_crop is not None
        and resize_to_before_crop[0] >= input_size[0]
        and resize_to_before_crop[1] >= input_size[1]
        and tuple(resize_to_before_crop) != tuple(input_size)
    )
    if can_crop:
        return tf.cond(
            tf.random.uniform([], 0.0, 1.0) < 0.5,
            lambda: rand_crop(x, y),
            lambda: resize(x, y),
        )
    return resize(x, y)

def fix_shape(x, y, size):
    """Declare static shapes after resizing: image -> [H, W, 3], mask -> [H, W].

    ``set_shape`` only refines the known shape metadata (it never resizes data),
    which lets downstream tf.data / Keras code see concrete dimensions.
    The same two objects are returned.
    """
    height, width = size[0], size[1]
    for tensor, static_shape in ((x, [height, width, 3]), (y, [height, width])):
        tensor.set_shape(static_shape)
    return x, y

def randomly_flip_horizontal(x, y):
    """Flip an (image, mask) pair left-right together with probability 0.5.

    Args:
        x: Image tensor of shape [H, W, C].
        y: Mask tensor of shape [H, W] (a trailing channel dim [H, W, 1] also works).

    Returns:
        (x, y), either both flipped or both unchanged.
    """
    rand = tf.random.uniform([], 0.0, 1.0)

    def flip(x, y):
        # tf.image.flip_left_right requires a rank >= 3 tensor, so the [H, W] mask
        # is reversed directly along its width axis (axis 1 for [H, W] and [H, W, 1]).
        return tf.image.flip_left_right(x), tf.reverse(y, axis=[1])

    return tf.cond(rand < 0.5, lambda: flip(x, y), lambda: (x, y))

# 3. The main pipeline builder function (based on Listing 8.6)
def get_subset_tf_dataset(
    subset_filename_gen_func, batch_size, epochs,
    input_size=(256, 256), output_size=None, resize_to_before_crop=None,
    augmentation=False, shuffle=False
):
    """Build a batched, repeated ``tf.data`` pipeline of (image, mask) pairs.

    Args:
        subset_filename_gen_func: Zero-argument callable returning an iterator of
            (image_path, mask_path) string pairs.
        batch_size: Batch size.
        epochs: Number of times the batched data is repeated.
        input_size: (height, width) images (and masks) are resized/cropped to.
        output_size: Optional (height, width) to resize the masks to afterwards.
        resize_to_before_crop: Size used before random cropping (training only).
        augmentation: Enable random crop, flip and colour jitter.
        shuffle: Shuffle examples before batching.

    Returns:
        A dataset of ``(x, y)``: x is float32 [B, H, W, 3] in [0, 1]; y is float32
        [B, H', W'] of class indices (255 = void), with (H', W') = ``output_size``
        if given, else ``input_size``.
    """
    # Create dataset of filenames
    filename_ds = tf.data.Dataset.from_generator(
        subset_filename_gen_func,
        output_signature=(
            tf.TensorSpec(shape=(), dtype=tf.string),
            tf.TensorSpec(shape=(), dtype=tf.string),
        ),
    )

    def _load_mask(path):
        # Passing a single dtype (not a list) as Tout makes numpy_function return
        # one tensor instead of a one-element list. Its shape is unknown to TF, so
        # declare the mask as 2-D; tf.image.resize needs a known rank.
        mask = tf.numpy_function(load_image_func, [path], tf.uint8)
        mask.set_shape([None, None])
        return mask

    # Load images from files. Use tf.numpy_function to run the PIL-based loader
    image_ds = filename_ds.map(lambda x, y: (
        tf.image.decode_jpeg(tf.io.read_file(x), channels=3),
        _load_mask(y),
    )).cache()  # Use .cache() for performance

    # Normalize input image, cast target mask
    # NOTE(review): the pretrained Keras ResNet-50 backbone was trained on inputs
    # prepared by tf.keras.applications.resnet50.preprocess_input (BGR, ImageNet
    # mean subtracted, 0-255 range), not [0, 1] RGB; consider matching that.
    image_ds = image_ds.map(lambda x, y: (tf.cast(x, 'float32') / 255.0, tf.cast(y, 'float32')))

    # Resize / crop
    image_ds = image_ds.map(lambda x, y: randomly_crop_or_resize(
        x, y, resize_to_before_crop, input_size, augmentation
    ))

    # Set static shape information
    image_ds = image_ds.map(lambda x, y: fix_shape(x, y, size=input_size))

    # Apply augmentations (only if augmentation=True)
    if augmentation:
        image_ds = image_ds.map(randomly_flip_horizontal)
        image_ds = image_ds.map(lambda x, y: (tf.image.random_hue(x, 0.1), y))
        image_ds = image_ds.map(lambda x, y: (tf.image.random_brightness(x, 0.1), y))
        image_ds = image_ds.map(lambda x, y: (tf.image.random_contrast(x, 0.8, 1.2), y))
        # Brightness/contrast jitter can push pixels outside [0, 1]; clip them back.
        image_ds = image_ds.map(lambda x, y: (tf.clip_by_value(x, 0.0, 1.0), y))

    # Resize output if needed; squeeze only the channel axis added for resizing
    # so the mask stays [H', W'] per example.
    if output_size:
        image_ds = image_ds.map(lambda x, y: (
            x,
            tf.squeeze(tf.image.resize(y[..., tf.newaxis], output_size, method='nearest'), axis=-1),
        ))

    if shuffle:
        image_ds = image_ds.shuffle(buffer_size=batch_size * 5)

    # Batch and repeat
    image_ds = image_ds.batch(batch_size).repeat(epochs)

    # Prefetch last so it overlaps the whole pipeline with training.
    return image_ds.prefetch(tf.data.experimental.AUTOTUNE)

# 4. Instantiate the pipelines (based on Listing 8.7)
# Bind the directory arguments once, then specialise per subset; each result is a
# zero-argument callable, as Dataset.from_generator requires.
partial_subset_fn = partial(get_subset_filenames, orig_dir=img_dir, seg_dir=seg_dir, subset_dir=subset_dir)
train_subset_fn, val_subset_fn, test_subset_fn = (
    partial(partial_subset_fn, subset=subset_name) for subset_name in ('train', 'val', 'test')
)

# Training: augmented (resize to 444x444 before random cropping) and shuffled.
tr_image_ds = get_subset_tf_dataset(
    train_subset_fn,
    batch_size,
    epochs,
    input_size=input_size,
    resize_to_before_crop=(444, 444),
    augmentation=True,
    shuffle=True,
)

# Validation repeats `epochs` times (one pass per training epoch); test is a single pass.
val_image_ds = get_subset_tf_dataset(val_subset_fn, batch_size, epochs, input_size=input_size, shuffle=False)
test_image_ds = get_subset_tf_dataset(test_subset_fn, batch_size, 1, input_size=input_size, shuffle=False)

print(f"Training dataset element spec: {tr_image_ds.element_spec}")
print(f"Validation dataset element spec: {val_image_ds.element_spec}")

---

## 8.3 DeepLabv3: Using pretrained networks to segment images

We will implement the **DeepLabv3** model. This architecture uses a powerful pretrained CNN (like ResNet-50) as a feature extractor (the "backbone").

Its key innovations are:
1.  **Atrous (Dilated) Convolution**: This is a convolution with "holes." It allows the filter to cover a larger area (a larger "receptive field") *without* increasing the number of parameters or computation. This is key for capturing multi-scale context.
2.  **Atrous Spatial Pyramid Pooling (ASPP)**: This module runs several parallel branches on the backbone's output: a 1×1 convolution and 3×3 atrous convolutions with *different* dilation rates (e.g., 6, 12, 18). This captures information at multiple scales simultaneously. The branch outputs are concatenated, together with an image-level feature (global average pooling followed by a 1×1 convolution, upsampled back to the feature-map size), to create a rich, multi-scale feature representation.

In [None]:
# 1. Backbone: Keras ResNet-50 without its classification head, truncated after
#    the conv4 stage ("conv4_block6_out"); conv5 is rebuilt below with atrous convs.
inp = layers.Input(shape=input_size + (3,))
resnet50 = tf.keras.applications.ResNet50(
    input_tensor=inp,
    include_top=False,
    pooling=None,
)
out = resnet50.get_layer("conv4_block6_out").output
resnet50_upto_conv4 = models.Model(inputs=resnet50.input, outputs=out)

# 2. Re-implement the 'conv5' block using atrous convolution (dilation rate = 2)
# This involves helper functions for the ResNet blocks
def block_level3(inp, filters, kernel_size, rate, block_id, convlayer_id, activation=True,
                 epsilon=1.001e-5):
    """Build one Conv2D -> BatchNormalization -> (optional) ReLU unit.

    Layers are named ``conv5_block{block_id}_{convlayer_id}_{conv,bn,relu}``, the
    naming scheme of ResNet-50's conv5 stage, so pretrained conv5 weights can be
    copied into these layers by name.

    Args:
        inp: Input Keras tensor.
        filters: Number of convolution filters.
        kernel_size: Convolution kernel size.
        rate: Dilation rate of the convolution (1 = ordinary convolution).
        block_id: Identifier inserted into the layer names (int or str).
        convlayer_id: Layer identifier inserted into the layer names.
        activation: If truthy, append a ReLU; otherwise return the BN output.
        epsilon: BatchNormalization epsilon. Defaults to 1.001e-5, the value used by
            tf.keras.applications.ResNet50, so copied moving statistics normalise
            exactly as in the pretrained network (the Keras default is 1e-3).

    Returns:
        Output Keras tensor.
    """
    prefix = f'conv5_block{block_id}_{convlayer_id}'

    conv_out = layers.Conv2D(
        filters, kernel_size, dilation_rate=rate, padding='same', name=f'{prefix}_conv'
    )(inp)
    bn_out = layers.BatchNormalization(epsilon=epsilon, name=f'{prefix}_bn')(conv_out)
    if activation:
        return layers.Activation('relu', name=f'{prefix}_relu')(bn_out)
    return bn_out

def block_level2(inp, rate, block_id):
    """Bottleneck branch of one conv5 residual block.

    1x1 conv (512) -> 3x3 conv (512) -> 1x1 conv (2048), each followed by BN.
    The first two units also apply ReLU; the last one does not (the ReLU is
    applied after the residual addition). ``rate`` is passed to all three convs.
    """
    x = inp
    for conv_no, (n_filters, ksize) in enumerate([(512, (1, 1)), (512, (3, 3))], start=1):
        x = block_level3(x, n_filters, ksize, rate, block_id, conv_no)
    return block_level3(x, 2048, (1, 1), rate, block_id, 3, activation=False)

def resnet_block(inp, rate):
    """Atrous re-implementation of ResNet-50's conv5 stage (three bottleneck blocks).

    Block 1 uses a 1x1 (2048-filter) projection shortcut; blocks 2 and 3 use
    identity shortcuts. Every convolution has stride 1 and 'same' padding, so the
    spatial size of ``inp`` is preserved while ``rate`` enlarges the receptive field.
    """
    shortcut = block_level3(inp, 2048, (1, 1), 1, block_id=1, convlayer_id=0, activation=False)
    residual = block_level2(inp, rate, block_id=1)
    x = layers.Add(name='conv5_block1_add')([shortcut, residual])
    x = layers.Activation('relu', name='conv5_block1_relu')(x)

    # Blocks 2 and 3: identity shortcut around a bottleneck branch.
    for block_no in (2, 3):
        residual = block_level2(x, rate, block_id=block_no)
        x = layers.Add(name=f'conv5_block{block_no}_add')([x, residual])
        x = layers.Activation('relu', name=f'conv5_block{block_no}_relu')(x)
    return x

print("Building atrous conv5 block...")
resnet_block4_out = resnet_block(inp=resnet50_upto_conv4.output, rate=2)

# 3. Implement the ASPP Module
def atrous_spatial_pyramid_pooling(inp):
    """Atrous Spatial Pyramid Pooling over a feature map.

    Five parallel branches, each producing 256 channels, are concatenated along
    the channel axis:
      * a1: 1x1 conv; a2-a4: 3x3 atrous convs with dilation rates 6, 12 and 18
        (all Conv-BN-ReLU);
      * b: image-level features - global average pool -> 1x1 Conv-BN-ReLU,
        bilinearly resized back to the spatial size of ``inp``.
    """
    feat_shape = K.shape(inp)
    target_hw = (feat_shape[1], feat_shape[2])

    # Branches a1..a4 as (kernel size, dilation rate).
    pyramid = [
        block_level3(inp, 256, kernel, dilation, '_aspp_a', branch_no, activation='relu')
        for branch_no, (kernel, dilation) in enumerate(
            [((1, 1), 1), ((3, 3), 6), ((3, 3), 12), ((3, 3), 18)], start=1
        )
    ]

    # Branch b: global context, broadcast back to every spatial position.
    pooled = layers.GlobalAveragePooling2D()(inp)
    pooled = layers.Reshape((1, 1, K.int_shape(pooled)[-1]))(pooled)
    pooled = block_level3(pooled, 256, (1, 1), 1, '_aspp_b', 1, activation='relu')
    pooled = tf.image.resize(pooled, target_hw, method='bilinear')

    return layers.Concatenate(axis=-1)(pyramid + [pooled])

print("Building ASPP module...")
out_aspp = atrous_spatial_pyramid_pooling(resnet_block4_out)

# 4. Classifier head: a 1x1 conv maps the ASPP features to one logit per class,
#    then the logits are bilinearly upsampled back to the input resolution.
out = layers.Conv2D(num_classes, kernel_size=(1, 1), padding='same')(out_aspp)
final_out = tf.image.resize(out, input_size, method='bilinear')

# 5. Wire input -> per-pixel logits into a single Keras model.
deeplabv3 = models.Model(inputs=resnet50_upto_conv4.input, outputs=final_out)

print("DeepLabv3 model built successfully.")
deeplabv3.summary(line_length=120)

---

## 8.4 Compiling the model: Loss functions and evaluation metrics

Plain accuracy and cross-entropy are not ideal for segmentation because of **class imbalance**: the 'background' class often covers the large majority of the pixels in an image. We therefore need specialized losses and metrics.

### 8.4.1 Loss Functions

1.  **Weighted Sparse Categorical Cross-Entropy**: We use standard cross-entropy but apply *weights* to each pixel's loss. Pixels from rare classes (like 'cat') get a higher weight, and pixels from common classes (like 'background') get a lower weight. This forces the model to pay attention to minority classes.
2.  **Dice Loss**: A popular segmentation loss based on the Dice coefficient (equivalent to the F1-score). It directly rewards overlap between the predicted mask $A$ and the true mask $B$:
    $\text{DiceLoss} = 1 - \frac{2\,|A \cap B|}{|A| + |B|}$
    In practice, a small smoothing constant is added to the numerator and denominator to avoid division by zero.

We will combine these two losses.

In [None]:
# 1. Function to get pixel weights (based on Listing 8.13)
def get_label_weights(y_true, y_pred, n_classes=None):
    """Per-pixel class-balancing weights, computed separately for each image.

    For image b and class c the weight is ``1 - count[b, c] / total[b]``, where
    ``count`` is the number of valid pixels of class c in image b and ``total``
    the number of valid pixels in image b, so frequent classes get small weights.

    Args:
        y_true: Labels of shape [B, H, W] (any numeric dtype). Values outside
            [0, n_classes) - e.g. the VOC boundary value 255 - are treated as void.
        y_pred: Unused; kept for interface compatibility.
        n_classes: Number of classes; defaults to the global ``num_classes``.

    Returns:
        float32 tensor of shape [B, H, W]; 0 at void pixels.
    """
    if n_classes is None:
        n_classes = num_classes
    y_true = tf.cast(y_true, 'int32')
    valid = tf.logical_and(y_true >= 0, y_true < n_classes)
    # one_hot maps out-of-range labels to all-zero rows, so void pixels are not counted.
    counts = tf.reduce_sum(tf.one_hot(y_true, n_classes), axis=[1, 2])  # [B, C]
    tot = tf.reduce_sum(counts, axis=-1, keepdims=True)  # [B, 1]
    class_weights = tf.math.divide_no_nan(tot - counts, tot)  # [B, C]; 0 for all-void images
    # Use a valid index for void pixels so gather stays in range; their weight is zeroed below.
    safe_labels = tf.where(valid, y_true, tf.zeros_like(y_true))
    # batch_dims=1: each pixel looks up its class in ITS OWN image's weight row.
    y_weights = tf.gather(class_weights, safe_labels, batch_dims=1)  # [B, H, W]
    return y_weights * tf.cast(valid, y_weights.dtype)

# 2. Weighted Cross-Entropy Loss (based on Listing 8.14)
def ce_weighted_from_logits(num_classes):
    """Build a class-balanced sparse softmax cross-entropy loss on logits.

    Args:
        num_classes: Size of the logits' last axis; labels outside [0, num_classes)
            (e.g. boundary pixel value 255) are ignored.

    Returns:
        ``loss_fn(y_true, y_pred)`` with y_true [B, H, W] labels and y_pred
        [B, H, W, num_classes] logits, returning the mean weighted per-pixel
        cross-entropy over the valid pixels.
    """
    def loss_fn(y_true, y_pred):
        y_true = tf.cast(y_true, 'int32')
        # Mask out-of-bounds labels (e.g., boundary pixel value 255)
        valid = tf.logical_and(y_true >= 0, y_true < num_classes)
        valid_f = tf.cast(valid, y_pred.dtype)
        # sparse_softmax_cross_entropy_with_logits errors (CPU) or yields NaN (GPU)
        # on out-of-range labels, so feed class 0 for void pixels and drop their
        # contribution via valid_f.
        safe_labels = tf.where(valid, y_true, tf.zeros_like(y_true))
        y_weights = get_label_weights(y_true, y_pred, n_classes=num_classes)

        loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=safe_labels,
            logits=y_pred
        )  # [B, H, W]
        weighted = loss * tf.cast(y_weights, loss.dtype) * valid_f
        # Average over valid pixels only; divide_no_nan guards all-void batches.
        return tf.math.divide_no_nan(tf.reduce_sum(weighted), tf.reduce_sum(valid_f))
    return loss_fn

# 3. Dice Loss (based on Listing 8.15)
def dice_loss_from_logits(num_classes):
    """Build a soft Dice loss on logits, averaged over classes.

    For each class c: ``score_c = (2 * sum(p_c * t_c) + 1) / (sum(t_c) + sum(p_c) + 1)``
    over all valid pixels in the batch, with p the softmax probabilities and t
    the one-hot targets; the loss is ``1 - mean_c(score_c)``.

    Args:
        num_classes: Size of the logits' last axis; labels outside
            [0, num_classes) (e.g. 255) are treated as void and ignored.

    Returns:
        ``loss_fn(y_true, y_pred)`` returning a scalar loss.
    """
    def loss_fn(y_true, y_pred):
        smooth = 1.0
        y_true = tf.cast(y_true, 'int32')
        y_prob = tf.nn.softmax(y_pred)
        # Void pixels already have all-zero one-hot targets; also zero their
        # predicted probabilities so they don't inflate the denominator.
        valid = tf.logical_and(y_true >= 0, y_true < num_classes)
        y_prob = y_prob * tf.cast(valid, y_prob.dtype)[..., tf.newaxis]
        y_true_one_hot = tf.one_hot(y_true, num_classes, dtype=y_prob.dtype)

        # Flatten
        y_true_flat = tf.reshape(y_true_one_hot, [-1, num_classes])
        y_pred_flat = tf.reshape(y_prob, [-1, num_classes])

        intersection = tf.reduce_sum(y_true_flat * y_pred_flat, axis=0)
        # |A| + |B| per class (not the set union).
        denominator = tf.reduce_sum(y_true_flat, axis=0) + tf.reduce_sum(y_pred_flat, axis=0)

        score = (2. * intersection + smooth) / (denominator + smooth)
        return 1.0 - tf.reduce_mean(score)
    return loss_fn

# 4. Combined Loss (based on Listing 8.16)
def ce_dice_loss_from_logits(num_classes):
    """Return a loss equal to weighted cross-entropy + Dice loss (both on logits)."""
    ce_term = ce_weighted_from_logits(num_classes)
    dice_term = dice_loss_from_logits(num_classes)

    def loss_fn(y_true, y_pred):
        return ce_term(y_true, y_pred) + dice_term(y_true, y_pred)

    return loss_fn

print("Custom loss functions defined.")

### 8.4.2 Evaluation Metrics

We need custom metrics that understand segmentation:
1.  **`PixelAccuracyMetric`**: Simplest metric. What percentage of pixels were classified correctly? (Can be misleading if 'background' is 99% of the image).
2.  **`MeanAccuracyMetric`**: Calculates the accuracy *for each class* individually, then computes the mean of those accuracies. This is much better for imbalanced datasets.
3.  **`MeanIoUMetric` (Mean Intersection over Union)**: The gold standard for segmentation. For each class, it computes $IoU = \frac{True \, Positives}{True \, Positives + False \, Positives + False \, Negatives}$. It then averages this IoU score across all classes.

In [None]:
# Keras's built-in MeanIoU cannot be used on this model's outputs directly: it
# expects predicted class ids (the model outputs per-class logits) and labels
# below num_classes (VOC masks contain void pixels labelled 255). This small
# subclass, in the spirit of the book's custom metrics (Listings 8.17-8.19),
# takes the argmax of the logits and drops void pixels before updating.
class MaskedMeanIoU(tf.keras.metrics.MeanIoU):
    """Mean IoU from logits that ignores pixels labelled outside [0, num_classes)."""

    def update_state(self, y_true, y_pred, sample_weight=None):
        y_true = tf.cast(tf.reshape(y_true, [-1]), tf.int32)
        y_pred = tf.reshape(tf.argmax(y_pred, axis=-1, output_type=tf.int32), [-1])
        valid = tf.logical_and(y_true >= 0, y_true < self.num_classes)
        if sample_weight is not None:
            # Assumes per-pixel weights ([B, H, W]).
            sample_weight = tf.boolean_mask(tf.reshape(sample_weight, [-1]), valid)
        return super().update_state(
            tf.boolean_mask(y_true, valid),
            tf.boolean_mask(y_pred, valid),
            sample_weight,
        )

optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)

deeplabv3.compile(
    loss=ce_dice_loss_from_logits(num_classes),
    optimizer=optimizer,
    metrics=[
        MaskedMeanIoU(num_classes=num_classes, name='mean_iou')
    ]
)

# Copy weights from the original ResNet-50 conv5 stage into the re-implemented
# (atrous) conv5 layers, which reuse the same layer names. Dilation does not
# change kernel shapes, so every conv5 conv/BN layer (all three blocks) can be
# copied; shapes are still checked before copying.
resnet50_layers = {layer.name: layer for layer in resnet50.layers}
w_dict = {}
for layer in deeplabv3.layers:
    src = resnet50_layers.get(layer.name)
    if src is None or src is layer or not layer.name.startswith('conv5_'):
        continue
    src_weights = src.get_weights()
    if not src_weights:
        continue  # weightless layers (Add/Activation)
    if [w.shape for w in src_weights] == [w.shape for w in layer.get_weights()]:
        w_dict[layer.name] = src_weights
        layer.set_weights(src_weights)

print(f"Copied pretrained weights into {len(w_dict)} conv5 layers.")
print("Model compiled with custom loss and MeanIoU metric.")

---

## 8.5 & 8.6: Training and Evaluating the Model

In [None]:
# Get number of steps per epoch.
# The id lists are counted with plain file reads (pandas' `squeeze=` argument
# was removed in pandas 2.0).
with open(os.path.join(subset_dir, "train.txt")) as f:
    n_train_files = sum(1 for line in f if line.strip())
with open(os.path.join(subset_dir, "val.txt")) as f:
    # get_subset_filenames gives the first len // 2 ids of val.txt to validation.
    n_val_files = sum(1 for line in f if line.strip()) // 2

# Ceil division: the pipelines batch *before* repeating, so every pass over a
# subset ends with a partial batch; counting it keeps each Keras epoch aligned
# with exactly one pass over the data.
n_train_steps = -(-n_train_files // batch_size)
n_valid_steps = -(-n_val_files // batch_size)

print(f"Training steps: {n_train_steps}")
print(f"Validation steps: {n_valid_steps}")

# Define callbacks (referenced through tf.keras.callbacks, since the bare
# class names are not imported anywhere).
os.makedirs('eval', exist_ok=True)  # the CSV log directory must exist before training
csv_logger = tf.keras.callbacks.CSVLogger(os.path.join('eval', '1_pretrained_deeplabv3.log'))
lr_callback = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=3, mode='min')
es_callback = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=6, mode='min')

# Train the model (running for only 1 epoch for demonstration)
print("Starting model training...")
history = deeplabv3.fit(
    x=tr_image_ds,
    steps_per_epoch=n_train_steps,
    validation_data=val_image_ds,
    validation_steps=n_valid_steps,
    epochs=1,  # Book runs for 25
    callbacks=[lr_callback, csv_logger, es_callback]
)
print("Training complete.")

In [None]:
# Evaluate on the test set
import matplotlib.pyplot as plt  # `plt` is used below but not imported in the import cell

print("\nEvaluating on test set...")
# get_subset_filenames gives test the second half of the shuffled val.txt,
# i.e. len - len // 2 ids (one more than validation when the count is odd).
with open(os.path.join(subset_dir, "val.txt")) as f:
    n_val_total = sum(1 for line in f if line.strip())
n_test_files = n_val_total - n_val_total // 2
# test_image_ds is batched without repetition; ceil division includes the last partial batch.
n_test_steps = -(-n_test_files // batch_size)

test_results = deeplabv3.evaluate(test_image_ds, steps=n_test_steps)
test_res_dict = dict(zip(deeplabv3.metrics_names, test_results))
print("Test Results:")
print(test_res_dict)

# Visualize predictions (based on 8.6)
print("\nVisualizing 2 test predictions...")
plt.figure(figsize=(15, 8))
for i, (x, y) in enumerate(test_image_ds.take(2)):
    y_pred = deeplabv3.predict(x)
    y_pred_argmax = tf.argmax(y_pred, axis=-1)
    # Mask void pixels (label >= num_classes, e.g. 255) so they are drawn
    # transparent instead of being clipped to the last class colour.
    y_true_img = y[0].numpy()
    y_true_vis = np.ma.masked_where(y_true_img >= num_classes, y_true_img)

    # Original Image
    plt.subplot(2, 3, i*3 + 1)
    plt.imshow(x[0])
    plt.title("Original Image")
    plt.axis('off')

    # Ground Truth Mask
    plt.subplot(2, 3, i*3 + 2)
    plt.imshow(y_true_vis, vmin=0, vmax=num_classes-1)
    plt.title("Ground Truth Mask")
    plt.axis('off')

    # Predicted Mask
    plt.subplot(2, 3, i*3 + 3)
    plt.imshow(y_pred_argmax[0], vmin=0, vmax=num_classes-1)
    plt.title("Predicted Mask")
    plt.axis('off')

plt.tight_layout()
plt.show()