<a href="https://colab.research.google.com/github/Monaa48/TensorFlow-in-Action-starter/blob/main/notebooks/Ch08_Telling_things_apart_Image_segmentation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Chapter 08 — Telling Things Apart: Image Segmentation

This chapter moves from **image classification** to **image segmentation**, where the target is not a single label per image, but a **label for every pixel**.  
Instead of asking “What is in the image?”, segmentation asks “Which pixels belong to which object category?”

In the book, the task is built around **PASCAL VOC 2012** segmentation data and an advanced model, **DeepLab v3**, which uses:
- a strong feature extractor (pretrained ResNet-50 backbone),
- **atrous / dilated convolution** to preserve spatial detail while keeping a large receptive field,
- an **ASPP (Atrous Spatial Pyramid Pooling)** module to combine context at multiple effective scales,
- and a final upsampling step to produce a full-resolution segmentation mask.

In this notebook, I reproduce the chapter workflow end-to-end:
1) download + inspect the VOC2012 segmentation data,
2) build a proper `tf.data` pipeline (with aligned augmentation for image + mask),
3) implement a DeepLab v3–style model (ResNet50 + ASPP),
4) compile with segmentation-aware loss/metrics (including masking the “void/border” label),
5) train, evaluate, and visualize predictions (qualitative inspection).


## 1) Summary

### 1.1 Segmentation vs classification
Classification outputs a single label, but segmentation outputs a **dense grid** of labels.  
For a model, this changes several things at once:

- **Targets are images too.**  
  In VOC2012, the target masks are stored as palettized PNGs: each pixel is an index into a fixed color palette. That index corresponds to a class such as *person*, *dog*, *chair*, etc.

- **Augmentation must be coordinated.**  
  If an input image is randomly cropped or flipped, the target mask must be cropped or flipped in exactly the same way. Otherwise the labels no longer correspond to the pixels.

- **Metrics should reflect spatial correctness.**  
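  Pixel accuracy can be misleading when background dominates. The most common segmentation metric is **IoU (Intersection-over-Union)**: for each class, the number of pixels that are in both the prediction and the ground truth, divided by the number of pixels that are in either. It penalizes both false positives and false negatives.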
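
To make the difference between pixel accuracy and IoU concrete, here is a minimal NumPy sketch on a toy mask. The helper names `pixel_accuracy` and `class_iou` are illustrative and not part of the notebook; the example assumes a single small object on a background-dominated image.

```python
import numpy as np

def pixel_accuracy(y_true, y_pred):
    """Fraction of pixels whose predicted class equals the true class."""
    return float(np.mean(y_true == y_pred))

def class_iou(y_true, y_pred, class_id):
    """IoU for one class: |pred AND true| / |pred OR true|."""
    t = (y_true == class_id)
    p = (y_pred == class_id)
    union = np.logical_or(t, p).sum()
    if union == 0:
        return float("nan")  # class absent from both masks
    return float(np.logical_and(t, p).sum() / union)

# Toy 10x10 mask: background (0) everywhere except a 2x2 object of class 1.
y_true = np.zeros((10, 10), dtype=np.int32)
y_true[4:6, 4:6] = 1

# A degenerate model that predicts background for every pixel.
y_pred = np.zeros_like(y_true)

acc = pixel_accuracy(y_true, y_pred)    # 96 of 100 pixels match
iou_obj = class_iou(y_true, y_pred, 1)  # no overlap with the object
print(acc, iou_obj)
```

The all-background prediction scores 0.96 on pixel accuracy while its IoU for the object class is 0.0, which is why IoU is the more informative number here.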

### 1.2 VOC2012 segmentation label structure (important detail)
VOC2012 segmentation masks typically contain:
- class indices (0…20) for known categories + background,
- a special value (`255`) for “border/void” pixels, which should usually be excluded from loss/metrics.

In this notebook, I map `255 → 21` and treat `21` as the **void** class id.  
During training/evaluation, I ignore void pixels so that the model is not rewarded or punished for ambiguous boundary regions.
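
The remapping and the resulting "valid pixel" mask can be sketched in NumPy as follows. The 3×3 array is made up for illustration; the two constants mirror the ids described above.

```python
import numpy as np

VOID_RAW_ID = 255    # value stored in the VOC PNG for border/void pixels
VOID_CLASS_ID = 21   # id used for void in this notebook

# Made-up raw mask: background (0), person (15), and a void border (255).
raw = np.array([[0, 15, 255],
                [15, 15, 255],
                [0, 0, 0]], dtype=np.uint8)

mask = raw.copy()
mask[mask == VOID_RAW_ID] = VOID_CLASS_ID   # 255 -> 21

valid = mask != VOID_CLASS_ID               # True where loss/metrics should count
print(np.unique(mask))
print(int(valid.sum()), "of", mask.size, "pixels are valid")
```

Only the pixels where `valid` is `True` contribute to the loss and metrics; the two void pixels are skipped.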

### 1.3 DeepLab v3 in a practical way
DeepLab v3 is a design that balances **semantic context** and **spatial precision**:

- The backbone (ResNet) provides strong feature extraction.
- Atrous convolution helps keep a large receptive field without aggressive downsampling.
- ASPP uses multiple dilation rates + global pooling to gather information from different spatial extents.
- A final upsampling step restores full resolution for per-pixel prediction.

The final output for each image is a tensor shaped `(H, W, num_classes)`, and training uses a per-pixel categorical cross-entropy loss.
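
The "large receptive field without downsampling" point can be checked with simple arithmetic. A kernel of size `k` with dilation rate `r` spans `k + (k - 1)(r - 1)` input positions while still using only `k` weights per axis. The sketch below applies that formula to a 3×3 kernel; the rates 6, 12, and 18 are the ones passed to the ASPP function in the model cell of this notebook, and the helper name is illustrative.

```python
def effective_kernel_size(kernel_size: int, rate: int) -> int:
    """Span (in input positions) covered by a dilated kernel along one axis."""
    return kernel_size + (kernel_size - 1) * (rate - 1)

for rate in (1, 6, 12, 18):
    span = effective_kernel_size(3, rate)
    print(f"rate={rate:2d} -> 3x3 kernel spans {span}x{span} positions")
```

A rate of 1 is an ordinary convolution (span 3), while rates 6, 12, and 18 span 13, 25, and 37 positions respectively.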


## 2) Setup

Imports, seed, and small utilities.

In [1]:
import os
import tarfile
import random
from pathlib import Path

import numpy as np
import pandas as pd
import tensorflow as tf
from PIL import Image

import matplotlib.pyplot as plt

SEED = 4321
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

print("TensorFlow:", tf.__version__)


TensorFlow: 2.19.0


## 3) Download and extract PASCAL VOC 2012

The chapter uses PASCAL VOC 2012.  
The official download is hosted by Oxford:
- `http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar`

This is a large archive, so the code below avoids re-downloading if the file already exists.


In [3]:
from pathlib import Path
import tarfile

tar_path = Path("data/VOCtrainval_11-May-2012.tar")

print("Exists:", tar_path.exists())
if tar_path.exists():
    size_mb = tar_path.stat().st_size / (1024 * 1024)
    print("Size (MB):", round(size_mb, 2))
    print("is_tarfile:", tarfile.is_tarfile(tar_path))

    if size_mb < 500:  # safe threshold for flagging a file that is far too small to be the real archive
        with open(tar_path, "rb") as f:
            head = f.read(200)
        print("First 200 bytes preview:")
        print(head)


Exists: True
Size (MB): 0.0
is_tarfile: False
First 200 bytes preview:
b''


In [4]:
import shutil
from pathlib import Path

DATA_ROOT = Path("data")
bad_tar = DATA_ROOT / "VOCtrainval_11-May-2012.tar"
bad_extract = DATA_ROOT / "VOCdevkit"

if bad_tar.exists():
    bad_tar.unlink()

shutil.rmtree(bad_extract, ignore_errors=True)

print("Cleaned old/corrupted VOC files.")


Cleaned old/corrupted VOC files.


In [5]:
import urllib.request
import tarfile

DATA_ROOT = Path("data")
DATA_ROOT.mkdir(parents=True, exist_ok=True)

tar_name = "VOCtrainval_11-May-2012.tar"
tar_path = DATA_ROOT / tar_name
extract_dir = DATA_ROOT / "VOCtrainval_11-May-2012"

# The mirror URL used earlier returned a 404 error, so use the official Oxford server
url = [
    "https://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar",
    "http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar",
]


# 1. Clean up corrupted file if present
if tar_path.exists():
    try:
        if not tarfile.is_tarfile(tar_path):
            print("Existing file is not a valid tar (likely corrupted download). Deleting...")
            tar_path.unlink()
        else:
            print("Archive exists and seems valid:", tar_path)
    except Exception:
        print("Error checking file. Deleting...")
        tar_path.unlink()

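# 2. Download if missing
# NOTE: `url` is a list here, but urlretrieve() expects a single URL string, so this call
# fails (see the recorded output below). The mirror loop in section 4 tries one URL at a time.
if not tar_path.exists():
    print("Downloading:", url)
    try:
        urllib.request.urlretrieve(url, tar_path)
        print("Saved to:", tar_path)
    except Exception as e:
        print(f"Download failed: {e}")
        if tar_path.exists():
            tar_path.unlink()  # delete the partial file if the download failed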

# 3. Extract
if not extract_dir.exists():
    if tar_path.exists():
        print("Extracting tar (this can take a while)...")
        try:
            with tarfile.open(tar_path) as tf_tar:
                tf_tar.extractall(extract_dir)
            print("Extracted to:", extract_dir)
        except Exception as e:
            print(f"Error during extraction: {e}")
            print("You may need to delete the file and try again.")
    else:
        print("Skipping extraction (archive not found).")
else:
    print("Extracted folder already exists:", extract_dir)

# VOC folder structure
VOC_DIR = extract_dir / "VOCdevkit" / "VOC2012"
JPEG_DIR = VOC_DIR / "JPEGImages"
MASK_DIR = VOC_DIR / "SegmentationClass"
SET_DIR = VOC_DIR / "ImageSets" / "Segmentation"

print("VOC_DIR :", VOC_DIR)
print("JPEG_DIR:", JPEG_DIR)
print("MASK_DIR:", MASK_DIR)
print("SET_DIR :", SET_DIR)

Downloading: ['https://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar', 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar']
Download failed: expected string or bytes-like object, got 'list'
Skipping extraction (archive not found).
VOC_DIR : data/VOCtrainval_11-May-2012/VOCdevkit/VOC2012
JPEG_DIR: data/VOCtrainval_11-May-2012/VOCdevkit/VOC2012/JPEGImages
MASK_DIR: data/VOCtrainval_11-May-2012/VOCdevkit/VOC2012/SegmentationClass
SET_DIR : data/VOCtrainval_11-May-2012/VOCdevkit/VOC2012/ImageSets/Segmentation


In [6]:
print("Tar exists?", tar_path.exists())
if tar_path.exists():
    print("Size (MB):", tar_path.stat().st_size / (1024 * 1024))
else:
    print("File is missing (likely deleted after failed validation).")


Tar exists? False
File is missing (likely deleted after failed validation).


In [7]:
with open(tar_path, "rb") as f:
    head = f.read(200)

print(head[:200])


FileNotFoundError: [Errno 2] No such file or directory: 'data/VOCtrainval_11-May-2012.tar'

## 4) Load split files and build file pairs (image, mask)

VOC provides predefined train/val file lists in:
- `ImageSets/Segmentation/train.txt`
- `ImageSets/Segmentation/val.txt`

The book also uses a train/val/test setup.  
VOC2012 does not ship a labeled test set publicly, so a common approach is to split the provided `val.txt` list into:
- validation subset
- test subset

I reproduce that idea by shuffling `val.txt` with a fixed seed and splitting it in half.


In [None]:
import os
import urllib.request
import tarfile
import numpy as np
from pathlib import Path

# Setup paths
DATA_ROOT = Path("data")
DATA_ROOT.mkdir(parents=True, exist_ok=True)

extract_dir = DATA_ROOT / "VOCtrainval_11-May-2012"
VOC_DIR = extract_dir / "VOCdevkit" / "VOC2012"
JPEG_DIR = VOC_DIR / "JPEGImages"
MASK_DIR = VOC_DIR / "SegmentationClass"
SET_DIR = VOC_DIR / "ImageSets" / "Segmentation"

tar_name = "VOCtrainval_11-May-2012.tar"
tar_path = DATA_ROOT / tar_name

# --- Robust Download & Extraction Logic ---

# List of mirrors to try if one fails
mirrors = [
    "http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar", # Official
    "https://pjreddie.com/media/files/VOCtrainval_11-May-2012.tar",             # Mirror 1
    "http://data.lip6.fr/cadene/VOCtrainval_11-May-2012.tar"                     # Mirror 2
]

def download_file(url, path):
    print(f"Attempting download from: {url}")
    try:
        # User-agent to avoid 403 Forbidden on some servers
        opener = urllib.request.build_opener()
        opener.addheaders = [('User-agent', 'Mozilla/5.0 (Windows NT 10.0; Win64; x64)')]
        urllib.request.install_opener(opener)
        urllib.request.urlretrieve(url, path)
        print("Download finished.")
        return True
    except Exception as e:
        print(f"Download failed: {e}")
        return False

def validate_and_extract(path, target_dir):
    if not path.exists():
        return False

    print("Verifying file integrity...")
    try:
        if not tarfile.is_tarfile(path):
            print("Error: File is not a valid tar archive.")
            return False
    except Exception:
        return False

    print("Extracting...")
    try:
        # Use a name other than `tf` so the TensorFlow alias is not shadowed inside this function
        with tarfile.open(path) as tar:
            tar.extractall(target_dir)
        print("Extraction successful.")
        return True
    except Exception as e:
        print(f"Extraction failed: {e}")
        return False

# Check if data is already ready
train_txt_path = SET_DIR / "train.txt"

if not train_txt_path.exists():
    print("Dataset not ready. Starting setup...")

    # 1. If file exists, check if it works. If not, delete it.
    if tar_path.exists():
        if not validate_and_extract(tar_path, extract_dir):
            print("Existing file corrupted. Deleting...")
            tar_path.unlink()

    # 2. If we still don't have the data, loop through mirrors
    if not train_txt_path.exists():
        success = False
        for url in mirrors:
            if download_file(url, tar_path):
                if validate_and_extract(tar_path, extract_dir):
                    success = True
                    break
                else:
                    print("Downloaded file invalid. Deleting and trying next mirror...")
                    if tar_path.exists():
                        tar_path.unlink()

        if not success:
            raise RuntimeError("Failed to download and extract dataset from all available mirrors.")
else:
    print("Dataset already extracted.")

# --- Load Data Helper ---

def read_id_list(path: Path):
    if not path.exists():
        raise FileNotFoundError(f"{path} not found. Dataset download might have failed.")
    ids = path.read_text().strip().split()
    return ids

train_ids = read_id_list(SET_DIR / "train.txt")
val_ids = read_id_list(SET_DIR / "val.txt")

print("train ids:", len(train_ids))
print("val ids  :", len(val_ids))

# Split val into val/test
random_seed = 4321 # explicit local seed if needed
rng = np.random.RandomState(random_seed)
val_ids_shuffled = val_ids.copy()
rng.shuffle(val_ids_shuffled)

mid = len(val_ids_shuffled) // 2
valid_ids = val_ids_shuffled[:mid]
test_ids  = val_ids_shuffled[mid:]

print("valid ids:", len(valid_ids))
print("test ids :", len(test_ids))

def make_pairs(ids):
    x_paths = [str(JPEG_DIR / f"{i}.jpg") for i in ids]
    y_paths = [str(MASK_DIR / f"{i}.png") for i in ids]
    return x_paths, y_paths

train_x, train_y = make_pairs(train_ids)
valid_x, valid_y = make_pairs(valid_ids)
test_x,  test_y  = make_pairs(test_ids)

if len(train_x) > 0:
    print("Sample path:", train_x[0])
else:
    print("Warning: No data loaded.")

## 5) Visual inspection: input image + mask

VOC masks are stored as palettized PNGs.  
If we load them with Pillow, `np.array(mask)` returns a 2D array of **class indices** (not RGB).  
For visualization, it helps to convert those indices into an RGB image using the VOC colormap.
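
The index-to-RGB conversion is a plain lookup-table operation: indexing a `(256, 3)` palette with an `(H, W)` index array yields an `(H, W, 3)` image. The sketch below fills in only three palette entries by hand (aeroplane, person, and the raw void value) to keep it short; the variable names are illustrative.

```python
import numpy as np

# A 256-entry lookup table; only a few entries are filled in for this toy example.
palette = np.zeros((256, 3), dtype=np.uint8)
palette[1] = (128, 0, 0)         # aeroplane
palette[15] = (192, 128, 128)    # person
palette[255] = (224, 224, 192)   # void/border (raw value in VOC masks)

# Toy 2x2 mask of class indices, as returned by np.array() on a palettized PNG.
mask = np.array([[0, 1],
                 [15, 255]], dtype=np.uint8)

rgb = palette[mask]              # fancy indexing: one RGB triple per pixel
print(rgb.shape)
print(rgb[1, 0])
```

The result has shape `(2, 2, 3)`, and the pixel labeled 15 picks up the "person" color.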


In [None]:
def voc_colormap():
    # Standard PASCAL VOC color map, generated for all 256 palette indices with the usual bit trick.
    # Index 21 (the remapped void id in this notebook) simply gets the generated color for 21.
    def bitget(byteval, idx):
        return (byteval & (1 << idx)) != 0

    cmap = np.zeros((256, 3), dtype=np.uint8)
    for i in range(256):
        r = g = b = 0
        c = i
        for j in range(8):
            r |= (bitget(c, 0) << (7 - j))
            g |= (bitget(c, 1) << (7 - j))
            b |= (bitget(c, 2) << (7 - j))
            c >>= 3
        cmap[i] = [r, g, b]
    return cmap

VOC_CMAP = voc_colormap()

VOID_RAW_ID = 255   # typical "border/void" in VOC masks
VOID_CLASS_ID = 21  # mapped void id in this notebook

def mask_to_rgb(mask_index):
    # mask_index: (H, W) int
    # Map 255 to 21 for visualization consistency
    m = mask_index.copy()
    m[m == VOID_RAW_ID] = VOID_CLASS_ID
    rgb = VOC_CMAP[m]
    return rgb

def show_image_and_mask(img_path, mask_path):
    img = np.array(Image.open(img_path).convert("RGB"))
    mask = np.array(Image.open(mask_path))
    mask_rgb = mask_to_rgb(mask)

    plt.figure(figsize=(10, 4))
    plt.subplot(1, 2, 1)
    plt.imshow(img)
    plt.axis("off")
    plt.title("Input image", fontsize=11)

    plt.subplot(1, 2, 2)
    plt.imshow(mask_rgb)
    plt.axis("off")
    plt.title("Segmentation mask (colorized)", fontsize=11)
    plt.tight_layout()
    plt.show()

# Show a few examples
for i in range(3):
    show_image_and_mask(train_x[i], train_y[i])


## 6) Data preparation: a `tf.data` pipeline for segmentation

### 6.1 Why `tf.data` is useful here
Segmentation datasets can be large, and masks are images too.  
A streaming pipeline helps because:
- it avoids loading everything into RAM,
- it enables parallelism and prefetching,
- it keeps augmentation consistent and repeatable.

### 6.2 Key requirements for segmentation pipelines
1) input and mask must be transformed together (crop/flip),
2) mask resizing must use **nearest-neighbor** (to keep class indices),
3) normalize inputs (the chapter uses `(x - 128) / 255`).
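
Requirements 1 and 2 can be illustrated without TensorFlow. The NumPy sketch below uses a toy image whose label can be recomputed from the pixel value, so any misalignment would be detectable; the function name `aligned_crop_flip` is hypothetical and the toy data is made up.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy "image" and "mask": the label is a function of the pixel value under it.
h, w = 6, 8
image = np.arange(h * w, dtype=np.float32).reshape(h, w)
mask = image.astype(np.int32) % 3

def aligned_crop_flip(image, mask, crop_h, crop_w, rng):
    """Draw ONE crop offset and ONE flip decision, then apply both to image and mask."""
    top = int(rng.integers(0, image.shape[0] - crop_h + 1))
    left = int(rng.integers(0, image.shape[1] - crop_w + 1))
    img_c = image[top:top + crop_h, left:left + crop_w]
    msk_c = mask[top:top + crop_h, left:left + crop_w]
    if rng.random() < 0.5:
        img_c, msk_c = img_c[:, ::-1], msk_c[:, ::-1]
    return img_c, msk_c

img_c, msk_c = aligned_crop_flip(image, mask, 4, 4, rng)
# Alignment check: each label can still be recomputed from the pixel under it.
still_aligned = np.array_equal(msk_c, img_c.astype(np.int32) % 3)

# Requirement 2: copying labels vs. averaging them when resizing.
labels = np.array([0, 0, 20, 20])            # two classes side by side
nearest_up = np.repeat(labels, 2)            # 2x upsampling by copying labels
averaged = (labels[:-1] + labels[1:]) / 2    # what linear interpolation does at midpoints
print(still_aligned, nearest_up, averaged)
```

Copying labels keeps only the classes 0 and 20, whereas averaging produces the value 10 at the boundary, a class id that never existed in the mask.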


In [None]:
IMG_SIZE = (384, 384)              # the chapter uses a large input size
RESIZE_BEFORE_CROP = (444, 444)    # for random crop augmentation
BATCH_SIZE = 16                    # adjust based on GPU memory
EPOCHS = 10                        # change for longer training

AUTOTUNE = tf.data.AUTOTUNE

def load_image_np(path):
    # path is a bytes object when called via tf.numpy_function
    path = path.decode("utf-8")
    img = np.array(Image.open(path).convert("RGB"), dtype=np.uint8)
    return img

def load_mask_np(path):
    path = path.decode("utf-8")
    m = np.array(Image.open(path), dtype=np.uint8)  # palettized indices
    # Map VOC void border 255 -> 21
    m[m == VOID_RAW_ID] = VOID_CLASS_ID
    return m

def tf_load_pair(img_path, mask_path):
    img = tf.numpy_function(load_image_np, [img_path], Tout=tf.uint8)
    mask = tf.numpy_function(load_mask_np, [mask_path], Tout=tf.uint8)
    img.set_shape([None, None, 3])
    mask.set_shape([None, None])
    mask = mask[..., tf.newaxis]  # (H, W, 1)
    return img, mask

def random_resize_and_crop(img, mask, crop_size=IMG_SIZE, resize_to=RESIZE_BEFORE_CROP):
    # Resize first, then random crop
    img = tf.image.resize(img, resize_to, method="bilinear")
    mask = tf.image.resize(mask, resize_to, method="nearest")

    h = tf.shape(img)[0]
    w = tf.shape(img)[1]
    ch, cw = crop_size

    offset_h = tf.random.uniform([], 0, h - ch + 1, dtype=tf.int32)
    offset_w = tf.random.uniform([], 0, w - cw + 1, dtype=tf.int32)

    img = tf.image.crop_to_bounding_box(img, offset_h, offset_w, ch, cw)
    mask = tf.image.crop_to_bounding_box(mask, offset_h, offset_w, ch, cw)
    return img, mask

def resize_only(img, mask, size=IMG_SIZE):
    img = tf.image.resize(img, size, method="bilinear")
    mask = tf.image.resize(mask, size, method="nearest")
    return img, mask

def augment(img, mask):
    # Horizontal flip
    if tf.random.uniform([]) < 0.5:
        img = tf.image.flip_left_right(img)
        mask = tf.image.flip_left_right(mask)

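    # Color jitter for image only. The image is float32 in [0, 255] here (after tf.image.resize),
    # and random_brightness adds its delta directly to float images, so scale the delta to that range.
    img = tf.image.random_hue(img, 0.08)
    img = tf.image.random_brightness(img, 0.15 * 255.0)
    img = tf.image.random_contrast(img, 0.8, 1.2)
    img = tf.clip_by_value(img, 0.0, 255.0)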
    return img, mask

def normalize(img, mask):
    # Chapter normalization: (x - 128) / 255
    img = tf.cast(img, tf.float32)
    img = (img - 128.0) / 255.0

    mask = tf.cast(mask, tf.int32)  # keep integer labels
    return img, mask

def make_dataset(x_paths, y_paths, training, batch_size=BATCH_SIZE):
    ds = tf.data.Dataset.from_tensor_slices((x_paths, y_paths))
    if training:
        ds = ds.shuffle(buffer_size=min(len(x_paths), 2048), seed=SEED, reshuffle_each_iteration=True)

    ds = ds.map(tf_load_pair, num_parallel_calls=AUTOTUNE)

    if training:
        ds = ds.map(lambda x, y: random_resize_and_crop(x, y), num_parallel_calls=AUTOTUNE)
        ds = ds.map(augment, num_parallel_calls=AUTOTUNE)
    else:
        ds = ds.map(resize_only, num_parallel_calls=AUTOTUNE)

    ds = ds.map(normalize, num_parallel_calls=AUTOTUNE)

    ds = ds.batch(batch_size)
    ds = ds.prefetch(AUTOTUNE)
    return ds

train_ds = make_dataset(train_x, train_y, training=True)
valid_ds = make_dataset(valid_x, valid_y, training=False)
test_ds  = make_dataset(test_x,  test_y,  training=False)

# Sanity check shapes
x_batch, y_batch = next(iter(train_ds))
print("x batch:", x_batch.shape, x_batch.dtype)
print("y batch:", y_batch.shape, y_batch.dtype)
print("unique labels (sample):", np.unique(y_batch.numpy())[:15])


## 7) Model: DeepLab v3 style (ResNet50 backbone + ASPP)

### 7.1 Backbone feature map choice
DeepLab v3 typically uses a backbone where the output stride is ~16.
A simple practical way is to take the output of the ResNet50 **conv4** block as the feature map,
then apply ASPP on top of that feature map and upsample back to the input resolution.
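
As a quick sanity check on what "output stride 16" means for this notebook's 384×384 inputs, the sketch below computes the feature-map size. It assumes the input size is divisible by the stride; the helper name is illustrative.

```python
def feature_map_size(input_size, output_stride):
    """Spatial size of a feature map that is `output_stride` times smaller than the input."""
    h, w = input_size
    return (h // output_stride, w // output_stride)

IMG_SIZE = (384, 384)
print("stride 16:", feature_map_size(IMG_SIZE, 16))
print("stride 32:", feature_map_size(IMG_SIZE, 32))
```

At stride 16 the ASPP module works on a 24×24 grid; the final ResNet50 block (stride 32) would leave only 12×12 positions to upsample from.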

### 7.2 ASPP (Atrous Spatial Pyramid Pooling)
ASPP combines:
- 1×1 conv (local features),
- 3×3 conv with multiple dilation rates (multi-scale context),
- image-level pooling branch (global context),
then concatenates and projects the combined features.
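
The branch structure is easier to see in one dimension. The following NumPy sketch is a toy, not the model used in this notebook: it uses a fixed averaging kernel instead of learned weights, small rates (1, 2, 4), and no final projection. It only shows how branches with different dilation rates, plus a global-average branch, look at different extents around the same position.

```python
import numpy as np

def dilated_conv1d(x, kernel, rate):
    """'Same'-padded 1D convolution whose taps are spaced `rate` samples apart."""
    k = len(kernel)
    pad = rate * (k - 1) // 2
    xp = np.pad(x, pad)
    out = np.zeros_like(x, dtype=np.float64)
    for i in range(len(x)):
        for j in range(k):
            out[i] += kernel[j] * xp[i + j * rate]
    return out

def toy_aspp(x, rates=(1, 2, 4)):
    """Stack one branch per dilation rate plus a global-average branch."""
    kernel = np.ones(3) / 3.0                                      # fixed, not learned
    branches = [dilated_conv1d(x, kernel, r) for r in rates]
    branches.append(np.full_like(x, x.mean(), dtype=np.float64))  # image-level pooling
    return np.stack(branches, axis=-1)                             # (length, n_branches)

x = np.zeros(17)
x[8] = 1.0                                   # a single impulse in the middle
features = toy_aspp(x)
print(features.shape)
print(np.nonzero(features[:, 0])[0])         # rate-1 branch responds at 7, 8, 9
print(np.nonzero(features[:, 2])[0])         # rate-4 branch responds at 4, 8, 12
```

Each branch has the same number of weights, but the rate-4 branch reaches four positions to either side, and the pooling branch carries the same global value at every position.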


In [None]:
from tensorflow.keras import layers, Model
from tensorflow.keras.applications import ResNet50

NUM_CLASSES = 22  # 0..20 + void=21

def conv_bn_relu(x, filters, k, rate=1):
    x = layers.Conv2D(filters, k, padding="same", dilation_rate=rate, use_bias=False)(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)
    return x

def aspp(x, filters=256, rates=(6, 12, 18), dropout=0.1):
    # 1x1
    b0 = conv_bn_relu(x, filters, 1, rate=1)

    # 3x3 atrous
    b1 = conv_bn_relu(x, filters, 3, rate=rates[0])
    b2 = conv_bn_relu(x, filters, 3, rate=rates[1])
    b3 = conv_bn_relu(x, filters, 3, rate=rates[2])

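    # Image-level pooling: global average -> 1x1 conv -> upsample back to the feature-map size.
    # Built-in layers with static shapes are used instead of Lambda closures over `x`,
    # which Keras 3 (bundled with recent TensorFlow releases) does not handle reliably.
    pool = layers.GlobalAveragePooling2D(keepdims=True)(x)
    pool = conv_bn_relu(pool, filters, 1, rate=1)
    pool = layers.UpSampling2D(size=(x.shape[1], x.shape[2]), interpolation="bilinear")(pool)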

    y = layers.Concatenate(axis=-1)([b0, b1, b2, b3, pool])
    y = conv_bn_relu(y, filters, 1, rate=1)
    y = layers.Dropout(dropout)(y)
    return y

def build_deeplab_v3(input_size=IMG_SIZE, num_classes=NUM_CLASSES):
    inputs = layers.Input(shape=(input_size[0], input_size[1], 3))

    base = ResNet50(include_top=False, weights="imagenet", input_tensor=inputs)
    feat = base.get_layer("conv4_block6_out").output  # stride ~16

    x = aspp(feat, filters=256, rates=(6, 12, 18), dropout=0.2)
    x = conv_bn_relu(x, 256, 3, rate=1)

    # Upsample to input size (bilinear). A built-in Resizing layer avoids a Lambda layer,
    # which complicates saving and reloading the model.
    x = layers.Resizing(input_size[0], input_size[1], interpolation="bilinear")(x)

    logits = layers.Conv2D(num_classes, 1, padding="same")(x)
    outputs = layers.Softmax(axis=-1)(logits)

    model = Model(inputs, outputs, name="deeplabv3_resnet50_aspp")
    return model

model = build_deeplab_v3()
model.summary()


## 8) Loss and metrics (masking void pixels)

### 8.1 Why masking is necessary
VOC masks contain a “void/border” label (mapped to class id `21` here).  
These pixels are ambiguous boundaries and should not dominate optimization.

So for loss and metrics:
- create a boolean mask: `y_true != VOID_CLASS_ID`,
- compute values only on valid pixels.
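
The effect of the mask on the loss can be seen in a small NumPy sketch. It uses three made-up pixels and hand-built probability rows; `masked_cross_entropy` is an illustrative helper, not the function compiled into the model.

```python
import numpy as np

VOID_CLASS_ID = 21

def masked_cross_entropy(y_true, probs, void_id=VOID_CLASS_ID):
    """Mean of -log p(true class), computed over non-void pixels only."""
    valid = y_true != void_id
    true_valid = y_true[valid]       # (N,)
    probs_valid = probs[valid]       # (N, C)
    picked = probs_valid[np.arange(true_valid.size), true_valid]
    return float(-np.log(picked).mean())

# Three pixels, 22 classes. The last pixel is void.
y_true = np.array([0, 1, VOID_CLASS_ID])
probs = np.full((3, 22), 0.5 / 21)   # spread half the mass over the other 21 classes
probs[0, 0] = 0.5                    # pixel 0: 0.5 on the true class
probs[1, 1] = 0.5                    # pixel 1: 0.5 on the true class
probs[2, 5] = 0.5                    # void pixel: the prediction here is irrelevant

loss = masked_cross_entropy(y_true, probs)
unmasked = float(-np.log(probs[np.arange(3), y_true]).mean())
print(round(loss, 4), round(unmasked, 4))
```

The masked loss equals `-log(0.5)` because only the two valid pixels count; including the void pixel would raise the loss even though no correct answer exists for it.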

### 8.2 Metrics included
- masked pixel accuracy (useful for monitoring),
- masked mean IoU (more meaningful for segmentation quality).
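
Mean IoU is usually derived from a confusion matrix. The sketch below assumes void pixels have already been removed and uses a made-up 3-class matrix in which the third class never appears; the helper name is illustrative.

```python
import numpy as np

def mean_iou_from_confusion(cm):
    """cm[i, j] = number of pixels with true class i that were predicted as class j."""
    tp = np.diag(cm).astype(np.float64)
    union = cm.sum(axis=1) + cm.sum(axis=0) - tp   # TP + FN + FP per class
    present = union > 0                            # skip classes absent from both
    return float((tp[present] / union[present]).mean())

cm = np.array([[8, 2, 0],
               [1, 4, 0],
               [0, 0, 0]])

pixel_acc = float(np.trace(cm) / cm.sum())
miou = mean_iou_from_confusion(cm)
print(round(pixel_acc, 4), round(miou, 4))
```

Here the per-class IoUs are 8/11 and 4/7, and the absent third class is left out of the average rather than counted as zero.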


In [None]:
def masked_sparse_cce(y_true, y_pred, void_id=VOID_CLASS_ID):
    # y_true: (B, H, W, 1), y_pred: (B, H, W, C) probabilities
    y_true = tf.cast(tf.squeeze(y_true, axis=-1), tf.int32)
    mask = tf.not_equal(y_true, void_id)

    y_true_m = tf.boolean_mask(y_true, mask)
    y_pred_m = tf.boolean_mask(y_pred, mask)

    # sparse categorical cross-entropy expects y_true shape (N,) and y_pred shape (N, C)
    return tf.reduce_mean(tf.keras.losses.sparse_categorical_crossentropy(y_true_m, y_pred_m))

def masked_pixel_accuracy(y_true, y_pred, void_id=VOID_CLASS_ID):
    y_true = tf.cast(tf.squeeze(y_true, axis=-1), tf.int32)
    y_hat = tf.argmax(y_pred, axis=-1, output_type=tf.int32)

    mask = tf.not_equal(y_true, void_id)
    y_true_m = tf.boolean_mask(y_true, mask)
    y_hat_m  = tf.boolean_mask(y_hat, mask)

    return tf.reduce_mean(tf.cast(tf.equal(y_true_m, y_hat_m), tf.float32))

def masked_mean_iou(y_true, y_pred, void_id=VOID_CLASS_ID, num_classes=NUM_CLASSES):
    y_true = tf.cast(tf.squeeze(y_true, axis=-1), tf.int32)
    y_hat = tf.argmax(y_pred, axis=-1, output_type=tf.int32)

    mask = tf.not_equal(y_true, void_id)
    y_true_m = tf.boolean_mask(y_true, mask)
    y_hat_m  = tf.boolean_mask(y_hat, mask)

    cm = tf.math.confusion_matrix(y_true_m, y_hat_m, num_classes=num_classes, dtype=tf.float32)

    # Exclude void row/col for IoU averaging
    cm = cm[:void_id, :void_id]  # classes 0..20

    diag = tf.linalg.diag_part(cm)
    rowsum = tf.reduce_sum(cm, axis=1)
    colsum = tf.reduce_sum(cm, axis=0)
    denom = rowsum + colsum - diag

    iou = diag / (denom + 1e-8)

    # Average only over classes that appear (denom > 0)
    valid = tf.greater(denom, 0.0)
    iou = tf.boolean_mask(iou, valid)
    return tf.reduce_mean(iou)

model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss=masked_sparse_cce,
    metrics=[masked_pixel_accuracy, masked_mean_iou],
)


## 9) Training

This is a compute-heavy task. The defaults are chosen to be runnable, but you can scale them:
- increase `EPOCHS`,
- increase `BATCH_SIZE` if GPU memory allows,
- or reduce `IMG_SIZE` to speed up iterations.

I also add callbacks to make training more controlled:
- ModelCheckpoint to save best validation IoU,
- EarlyStopping to reduce wasted epochs when validation stops improving.
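
To see how a patience setting plays out, here is a simplified pure-Python simulation of the stopping rule for a "higher is better" metric. It is a sketch of the idea, not Keras's implementation, and the validation IoU values are hypothetical.

```python
def simulate_early_stopping(val_scores, patience):
    """Return (stop_epoch, best_epoch) for a metric where higher is better.

    Simplified rule: an epoch counts as an improvement only if it strictly
    beats the best score seen so far.
    """
    best_score = float("-inf")
    best_epoch = -1
    wait = 0
    for epoch, score in enumerate(val_scores):
        if score > best_score:
            best_score, best_epoch, wait = score, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch
    return len(val_scores) - 1, best_epoch

# Hypothetical validation IoU values, for illustration only.
val_iou = [0.30, 0.42, 0.47, 0.46, 0.45, 0.47, 0.44, 0.43]
stop_epoch, best_epoch = simulate_early_stopping(val_iou, patience=4)
print("stopped after epoch index", stop_epoch, "| best epoch index", best_epoch)
```

With a patience of 4, training stops four epochs after the last improvement, and the weights from the best epoch are the ones worth keeping.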


In [None]:
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping

ckpt_path = "deeplabv3_voc2012_best.keras"

callbacks = [
    ModelCheckpoint(ckpt_path, monitor="val_masked_mean_iou", mode="max", save_best_only=True),
    EarlyStopping(monitor="val_masked_mean_iou", mode="max", patience=4, restore_best_weights=True),
]

history = model.fit(
    train_ds,
    validation_data=valid_ds,
    epochs=EPOCHS,
    callbacks=callbacks,
    verbose=1,
)


### 9.1 Plot learning curves

In [None]:
def plot_history(hist, keys):
    plt.figure(figsize=(10, 4))
    for k in keys:
        if k in hist.history:
            plt.plot(hist.history[k], label=k)
    plt.xlabel("Epoch")
    plt.legend()
    plt.grid(True)
    plt.show()

plot_history(history, ["loss", "val_loss"])
plot_history(history, ["masked_pixel_accuracy", "val_masked_pixel_accuracy"])
plot_history(history, ["masked_mean_iou", "val_masked_mean_iou"])


## 10) Evaluation on the held-out test split

After training, evaluate on the split created from VOC `val.txt`.  
The reported IoU excludes void pixels, so ambiguous boundary regions neither inflate nor deflate the score.


In [None]:
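# return_dict=True yields one entry per metric. In Keras 3, model.metrics_names can collapse
# the compiled metrics into a single "compile_metrics" entry, which would truncate a zip().
test_metrics = model.evaluate(test_ds, verbose=1, return_dict=True)
for name, value in test_metrics.items():
    print(f"{name:25s}: {value:.4f}")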


## 11) Qualitative results: visualize predictions

Quantitative metrics summarize performance, but segmentation also benefits from direct inspection:
- are object boundaries roughly correct?
- does the model confuse similar categories?
- does it mostly predict background?

Below, I:
1) run inference on a few test images,
2) convert predicted class indices to RGB for visualization,
3) show input / ground-truth / prediction side-by-side.
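
Displaying the input also requires undoing the `(x - 128) / 255` normalization. The NumPy sketch below checks the round trip on a few pixel values. It rounds before casting to `uint8`, which is an assumption of this sketch to guard against floating-point error; the function names are illustrative.

```python
import numpy as np

def normalize(x_uint8):
    """Chapter normalization: (x - 128) / 255."""
    return (x_uint8.astype(np.float32) - 128.0) / 255.0

def denormalize(x_norm):
    """Inverse mapping for display; rounds before casting to avoid off-by-one truncation."""
    return np.clip(np.rint(x_norm * 255.0 + 128.0), 0, 255).astype(np.uint8)

pixels = np.array([0, 1, 127, 128, 255], dtype=np.uint8)
x = normalize(pixels)
restored = denormalize(x)

print(float(x.min()), float(x.max()))   # roughly -0.502 and 0.498
print(restored)
```

The normalized values sit roughly in the range −0.5 to 0.5, and the inverse mapping recovers the original pixel values.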


In [None]:
def predict_mask(model, img):
    # img: (H, W, 3) float32 normalized
    pred = model.predict(img[None, ...], verbose=0)[0]  # (H, W, C)
    pred_idx = np.argmax(pred, axis=-1).astype(np.uint8)
    return pred_idx

def denormalize_for_display(x):
    # x was (x - 128)/255; undo approximately for display
    x_disp = (x * 255.0) + 128.0
    x_disp = np.clip(x_disp, 0, 255).astype(np.uint8)
    return x_disp

def show_prediction_triplet(img, true_mask, pred_mask):
    img_disp = denormalize_for_display(img)
    true_rgb = VOC_CMAP[true_mask]
    pred_rgb = VOC_CMAP[pred_mask]

    plt.figure(figsize=(12, 4))
    plt.subplot(1, 3, 1)
    plt.imshow(img_disp)
    plt.axis("off")
    plt.title("Input", fontsize=11)

    plt.subplot(1, 3, 2)
    plt.imshow(true_rgb)
    plt.axis("off")
    plt.title("Ground truth", fontsize=11)

    plt.subplot(1, 3, 3)
    plt.imshow(pred_rgb)
    plt.axis("off")
    plt.title("Prediction", fontsize=11)
    plt.tight_layout()
    plt.show()

# Run on a few test batches
for x_b, y_b in test_ds.take(1):
    x_np = x_b.numpy()
    y_np = y_b.numpy().squeeze(-1).astype(np.uint8)

    for i in range(min(3, x_np.shape[0])):
        pred = predict_mask(model, x_np[i])
        show_prediction_triplet(x_np[i], y_np[i], pred)


## 12) Takeaways

- Segmentation requires treating labels as structured spatial data, not simple categories.
- A good pipeline matters: resizing, cropping, and flipping must keep the input image and mask aligned.
- Nearest-neighbor resizing for masks is non-negotiable, otherwise class labels get corrupted.
- DeepLab v3’s ASPP is a practical way to combine local and global context without an encoder-decoder structure.
- Masking void pixels makes training and metrics more stable and interpretable on VOC-style datasets.


## 13) References

- Thushan Ganegedara, *TensorFlow in Action* (Chapter 8).
- PASCAL VOC 2012 dataset (segmentation benchmark).
- DeepLab v3: Chen et al., “Rethinking Atrous Convolution for Semantic Image Segmentation.”
