<a href="https://colab.research.google.com/github/DaraRahma536/TensorFlow-in-Action/blob/main/Chapter_08.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Chapter 8: Image Segmentation dengan DeepLab v3**
Chapter ini membahas implementasi semantic image segmentation menggunakan model DeepLab v3 dengan dataset PASCAL VOC 2012. Fokus utama meliputi:
* Persiapan data untuk tugas segmentasi
* Implementasi pipeline tf.data yang efisien
* Arsitektur DeepLab v3 dengan atrous convolution dan ASPP
* Fungsi loss dan metrik kustom untuk segmentasi
* Pelatihan dan evaluasi model

# **1. Jenis Segmentasi**
---
### **A. Semantic Segmentation**
* Setiap pixel diklasifikasikan ke kategori objek
* Objek yang sama jenisnya mendapat label sama (misal: semua orang = satu class)

### **B. Instance Segmentation**
* Setiap objek individu dipisahkan meskipun jenisnya sama
* Lebih sulit dari semantic segmentation
* Dataset yang digunakan: PASCAL VOC 2012 (21 kelas termasuk background; pixel berlabel 255 menandai boundary/void dan diabaikan pada loss maupun metrik)

# **2. Persiapan Data Pipeline dengan ```tf.data```**
---
### **A. Memuat Data Segmentasi**

In [None]:
# Listing 8.1: Download dataset
import os
import requests
import tarfile

# Download the PASCAL VOC 2012 train/val archive unless a complete copy exists.
if not os.path.exists(os.path.join('data','VOCtrainval_11-May-2012.tar')):
    url = "http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar"

    # exist_ok avoids failing if the directory appears between check and create.
    os.makedirs('data', exist_ok=True)

    # Stream into a temporary ".part" file instead of buffering the whole
    # (large) archive in memory. The file only gets its final name once the
    # download finished, so an interrupted run cannot leave a truncated tar
    # that the existence check above would accept on the next run.
    part_path = os.path.join('data','VOCtrainval_11-May-2012.tar.part')
    with requests.get(url, stream=True, timeout=60) as r:
        # Fail loudly on HTTP errors rather than saving an error page as a tar.
        r.raise_for_status()
        with open(part_path, 'wb') as f:
            for chunk in r.iter_content(chunk_size=1 << 20):
                f.write(chunk)
    os.replace(part_path, os.path.join('data','VOCtrainval_11-May-2012.tar'))

# Extract only if the marker directory is not there yet.
# NOTE(review): confirm the archive's top-level directory really matches this
# marker path; if it does not, the extraction below reruns on every execution.
# NOTE(review): extractall() trusts member paths inside the archive; acceptable
# for this known dataset, but do not reuse as-is for untrusted archives.
if not os.path.exists(os.path.join('data','VOCtrainval_11-May-2012')):
    with tarfile.open(os.path.join('data','VOCtrainval_11-May-2012.tar'), 'r') as tar:
        tar.extractall('data')

### **B. Memuat Gambar Target (Palettized Images)**

In [None]:
# Listing 8.2: Konversi palettized image ke RGB
import numpy as np
from PIL import Image

def rgb_image_from_palette(image):
    """Convert a palettized image into an array of RGB values.

    Args:
        image: An object that exposes ``getpalette()`` (returning a flat
            ``[r, g, b, r, g, b, ...]`` sequence) and that ``np.asarray`` can
            turn into a 2-D array of palette indices, e.g. a PIL image in
            palette mode.

    Returns:
        A float array of shape ``(height, width, 3)``. Pixels whose palette
        index is 0 stay all-zero; every other pixel receives the RGB triple
        stored at its index in the palette.

    Raises:
        ValueError: If the image has no palette or its index data is not 2-D.
    """
    palette = image.getpalette()
    if palette is None:
        # Without this check the reshape below fails with an opaque message.
        raise ValueError("image has no palette; expected a palettized image")
    palette = np.array(palette).reshape(-1, 3)

    # Take height/width from the index array itself so they are always bound,
    # instead of only inside a type-specific branch.
    indices = np.asarray(image)
    if indices.ndim != 2:
        raise ValueError(
            f"expected a 2-D array of palette indices, got shape {indices.shape}"
        )
    h, w = indices.shape

    # Flatten so a single boolean mask can address every pixel.
    flat_indices = indices.reshape(-1)
    non_background = flat_indices != 0

    rgb_image = np.zeros(shape=(flat_indices.shape[0], 3))
    rgb_image[non_background, :] = palette[flat_indices[non_background], :]

    return rgb_image.reshape(h, w, 3)

### **C. Pipeline Data Lengkap dengan ```tf.data```**

In [None]:
# Listing 8.6: Pipeline tf.data lengkap
def get_subset_tf_dataset(
    subset_filename_gen_func, batch_size, epochs,
    input_size=(256, 256), output_size=None,
    resize_to_before_crop=None, augmentation=False, shuffle=False
):
    """Build a batched ``tf.data`` pipeline of (image, segmentation mask) pairs.

    Args:
        subset_filename_gen_func: Generator function yielding
            ``(image_path, mask_path)`` string pairs.
        batch_size: Number of samples per batch.
        epochs: Number of times the batched dataset is repeated.
        input_size: ``(height, width)`` of the samples after resize/crop.
        output_size: Optional ``(height, width)``. When given, masks are
            resized to it with nearest-neighbour interpolation.
        resize_to_before_crop: ``(height, width)`` a sample is resized to
            before a random ``input_size`` crop is taken. Required when
            ``augmentation`` is True.
        augmentation: Enables random crop-or-resize, random horizontal flip
            and random hue/brightness/contrast changes on the image.
        shuffle: Shuffle samples with a buffer of ``batch_size * 5``.

    Returns:
        A ``tf.data.Dataset`` of ``(images, masks)`` batches. Images are
        float32 (pixel values divided by 255); masks are float32 with the
        trailing channel axis removed.

    Raises:
        ValueError: If ``augmentation`` is True and ``resize_to_before_crop``
            is missing or smaller than ``input_size``.
    """
    # Validate up front: otherwise the failure only surfaces deep inside the
    # traced map function with a hard-to-read TensorFlow error.
    if augmentation:
        if resize_to_before_crop is None:
            raise ValueError(
                "resize_to_before_crop must be set when augmentation=True"
            )
        if (resize_to_before_crop[0] < input_size[0]
                or resize_to_before_crop[1] < input_size[1]):
            raise ValueError(
                "resize_to_before_crop must be at least as large as input_size"
            )

    # Dataset of (image filename, mask filename) pairs
    filename_ds = tf.data.Dataset.from_generator(
        subset_filename_gen_func, output_types=(tf.string, tf.string)
    )

    def load_image_func(image):
        # Runs as plain Python via tf.numpy_function, so PIL can open the file.
        img = np.array(Image.open(image))
        return img

    # Decode the JPEG input with TF ops and the mask with PIL, then cache the
    # decoded pairs in memory so later passes skip disk reads and decoding.
    image_ds = filename_ds.map(lambda x, y: (
        tf.image.decode_jpeg(tf.io.read_file(x)),
        tf.numpy_function(load_image_func, [y], [tf.uint8])
    )).cache()

    # Normalise image pixel values
    image_ds = image_ds.map(lambda x, y: (tf.cast(x, 'float32')/255.0, y))

    def randomly_crop_or_resize(x, y):
        """Resize the pair to input_size, or (when augmenting) randomly crop it."""

        def rand_crop(x, y):
            x = tf.image.resize(x, resize_to_before_crop, method='bilinear')
            # The mask arrives with a leading axis of size 1; move it last so
            # it acts as the channel axis. 'nearest' keeps labels discrete.
            y = tf.cast(
                tf.image.resize(
                    tf.transpose(y, [1, 2, 0]),
                    resize_to_before_crop, method='nearest'
                ),
                'float32'
            )
            # maxval is exclusive for integer sampling, hence the +1: it makes
            # the last valid offset reachable and keeps the range non-empty
            # when the crop is exactly as large as the resized sample.
            offset_h = tf.random.uniform(
                [], 0, resize_to_before_crop[0] - input_size[0] + 1, dtype='int32'
            )
            offset_w = tf.random.uniform(
                [], 0, resize_to_before_crop[1] - input_size[1] + 1, dtype='int32'
            )

            # Same window for image and mask so they stay aligned.
            x = tf.image.crop_to_bounding_box(
                x, offset_h, offset_w, input_size[0], input_size[1]
            )
            y = tf.image.crop_to_bounding_box(
                y, offset_h, offset_w, input_size[0], input_size[1]
            )
            return x, y

        def resize(x, y):
            x = tf.image.resize(x, input_size, method='bilinear')
            y = tf.cast(
                tf.image.resize(
                    tf.transpose(y, [1, 2, 0]),
                    input_size, method='nearest'
                ),
                'float32'
            )
            return x, y

        if augmentation:
            # Half of the samples get a random crop, the rest a plain resize.
            rand = tf.random.uniform([], 0.0, 1.0)
            x, y = tf.cond(
                rand < 0.5,
                lambda: rand_crop(x, y),
                lambda: resize(x, y)
            )
        else:
            x, y = resize(x, y)

        return x, y

    image_ds = image_ds.map(randomly_crop_or_resize)

    # Additional augmentation
    if augmentation:
        def randomly_flip_horizontal(x, y):
            rand = tf.random.uniform([], 0.0, 1.0)

            def flip(x, y):
                # Flip image and mask together to keep them aligned.
                return tf.image.flip_left_right(x), tf.image.flip_left_right(y)

            x, y = tf.cond(rand < 0.5, lambda: flip(x, y), lambda: (x, y))
            return x, y

        image_ds = image_ds.map(randomly_flip_horizontal)
        # Colour changes apply to the image only; the mask is left untouched.
        image_ds = image_ds.map(lambda x, y: (tf.image.random_hue(x, 0.1), y))
        image_ds = image_ds.map(lambda x, y: (tf.image.random_brightness(x, 0.1), y))
        image_ds = image_ds.map(lambda x, y: (tf.image.random_contrast(x, 0.8, 1.2), y))

    # Honour output_size instead of silently ignoring it.
    if output_size is not None:
        image_ds = image_ds.map(
            lambda x, y: (x, tf.image.resize(y, output_size, method='nearest'))
        )

    # Shuffle individual samples, then batch
    if shuffle:
        image_ds = image_ds.shuffle(buffer_size=batch_size*5)

    image_ds = image_ds.batch(batch_size).repeat(epochs)
    image_ds = image_ds.prefetch(tf.data.experimental.AUTOTUNE)

    # Drop only the trailing channel axis of the mask. A bare tf.squeeze would
    # also remove the batch axis whenever a batch holds a single sample.
    image_ds = image_ds.map(lambda x, y: (x, tf.squeeze(y, axis=-1)))

    return image_ds

### **D. Optimasi Pipeline**
* ```.cache()```: Menyimpan data di memory setelah load pertama
* ```.prefetch()```: Prefetch data saat training berlangsung
* ```.shuffle()```: Mengacak urutan sampel menggunakan buffer berukuran `batch_size*5` (dilakukan sebelum batching)

# **3. Model DeepLab v3**
---
### **A. Konsep Inti**
* Backbone: ResNet-50 pretrained
* Atrous Convolution: Convolution dengan "holes" untuk receptive field lebih besar tanpa tambahan parameter
* ASPP (Atrous Spatial Pyramid Pooling): Menggabungkan informasi multi-scale

### **B. Atrous Convolution**

In [None]:
# Standard vs Atrous Convolution
# Standard 3x3 conv: receptive field 3x3
# Atrous conv rate=2: receptive field 5x5
# Atrous conv rate=3: receptive field 7x7

### **C. Implementasi DeepLab v3**

In [None]:
# Listing 8.8-8.12: Implementasi DeepLab v3 lengkap
import tensorflow as tf
from tensorflow.keras import layers, models

# Block Level 3 (conv layer dengan batch norm)
def block_level3(inp, filters, kernel_size, rate, block_id, convlayer_id, activation=True):
    """Apply a (possibly dilated) convolution, batch norm and an optional ReLU.

    Args:
        inp: Input tensor.
        filters: Number of convolution filters.
        kernel_size: Convolution kernel size.
        rate: Dilation rate of the convolution.
        block_id: Block identifier used in the layer names.
        convlayer_id: Layer identifier within the block, used in the layer names.
        activation: When truthy, a ReLU follows the batch normalisation.

    Returns:
        The batch-normalised tensor, passed through ReLU if requested.
    """
    convolution = layers.Conv2D(
        filters,
        kernel_size,
        dilation_rate=rate,
        padding='same',
        name=f'conv5_block{block_id}_{convlayer_id}_conv',
    )
    normalisation = layers.BatchNormalization(
        name=f'conv5_block{block_id}_{convlayer_id}_bn'
    )
    normalised = normalisation(convolution(inp))

    # Guard clause: callers that need the pre-activation tensor stop here.
    if not activation:
        return normalised

    relu = layers.Activation(
        'relu', name=f'conv5_block{block_id}_{convlayer_id}_relu'
    )
    return relu(normalised)

# Block Level 2 (3 layer convolution)
def block_level2(inp, rate, block_id):
    """Chain three conv/batch-norm layers: 1x1 (512), 3x3 (512), 1x1 (2048).

    Args:
        inp: Input tensor.
        rate: Dilation rate passed to every convolution in the chain.
        block_id: Block identifier used in the layer names.

    Returns:
        Output of the third layer, which has no ReLU applied.
    """
    # (filters, kernel_size, convlayer_id, apply_relu) for each layer in order.
    layer_specs = (
        (512, (1, 1), 1, True),
        (512, (3, 3), 2, True),
        (2048, (1, 1), 3, False),
    )

    features = inp
    for filters, kernel_size, convlayer_id, apply_relu in layer_specs:
        features = block_level3(
            features, filters, kernel_size, rate, block_id, convlayer_id,
            activation=apply_relu,
        )
    return features

# ASPP Module
def atrous_spatial_pyramid_pooling(inp):
    """Build the ASPP module: parallel atrous convolutions plus global context.

    Args:
        inp: Feature map tensor; the code reduces over axes 1 and 2, which it
            treats as the spatial axes.

    Returns:
        Channel-wise concatenation of five 256-filter branches: a 1x1
        convolution, three 3x3 convolutions with dilation rates 6, 12 and 18,
        and a globally averaged branch upsampled back to the input's spatial
        size.
    """
    # Part A: multi-scale atrous convolutions
    outa_1 = block_level3(inp, 256, (1,1), 1, '_aspp_a', 1, activation=True)
    outa_2 = block_level3(inp, 256, (3,3), 6, '_aspp_a', 2, activation=True)
    outa_3 = block_level3(inp, 256, (3,3), 12, '_aspp_a', 3, activation=True)
    outa_4 = block_level3(inp, 256, (3,3), 18, '_aspp_a', 4, activation=True)

    # Part B: global context (spatial mean, kept as a 1x1 map)
    outb_1_avg = layers.Lambda(
        lambda x: tf.reduce_mean(x, axis=[1,2], keepdims=True)
    )(inp)
    outb_1_conv = block_level3(outb_1_avg, 256, (1,1), 1, '_aspp_b', 1, activation=True)

    # The pooled map is 1x1, so upsampling by the input's own height/width
    # restores the spatial size of the other branches for any input
    # resolution. A fixed factor of 24 only fits a 24x24 feature map; it is
    # kept as the fallback when the static size is unknown.
    static_h, static_w = inp.shape[1], inp.shape[2]
    up_h = int(static_h) if static_h is not None else 24
    up_w = int(static_w) if static_w is not None else 24
    outb_1_up = layers.UpSampling2D((up_h, up_w), interpolation='bilinear')(outb_1_conv)

    # Concatenate all branch outputs
    out_aspp = layers.Concatenate()([outa_1, outa_2, outa_3, outa_4, outb_1_up])
    return out_aspp

# Build Model Lengkap
def build_deeplabv3(input_size=(384, 384, 3), num_classes=21):
    """Assemble the DeepLab v3 model on top of a ResNet50 backbone.

    Args:
        input_size: ``(height, width, channels)`` of the input images.
        num_classes: Number of output channels (one logit per class).

    Returns:
        A Keras model mapping input images to per-pixel class logits; the
        final layer upsamples the class map by a factor of 16.
    """
    inputs = layers.Input(shape=input_size)

    # ResNet50 backbone without the classification head
    resnet50 = tf.keras.applications.ResNet50(
        include_top=False, input_tensor=inputs, pooling=None
    )

    # Walk the layers in order and keep the output of the last layer that
    # precedes the first conv5 convolution.
    out = None
    for layer in resnet50.layers:
        if layer.name == "conv5_block1_1_conv":
            break
        out = layer.output

    resnet50_upto_conv4 = models.Model(resnet50.input, out)

    def resnet_block(inp, rate):
        """Rebuild the conv5 stage as three bottleneck blocks using atrous conv.

        NOTE(review): the original body was a placeholder that returned an
        undefined name. This follows the standard ResNet bottleneck layout
        (projection shortcut in the first block, identity shortcuts after);
        verify against listing 8.10.
        """
        block_output = inp
        for block_id in (1, 2, 3):
            if block_id == 1:
                # The first block changes the channel count, so the shortcut
                # needs a 1x1 projection to make the addition shape-compatible.
                shortcut = block_level3(
                    block_output, 2048, (1, 1), 1, block_id, 0, activation=False
                )
            else:
                shortcut = block_output
            residual = block_level2(block_output, rate, block_id)
            added = layers.Add(name=f'conv5_block{block_id}_add')(
                [shortcut, residual]
            )
            block_output = layers.Activation(
                'relu', name=f'conv5_block{block_id}_out'
            )(added)
        return block_output

    # Conv5 stage with dilation instead of striding, so the feature map keeps
    # the spatial size of the conv4 output.
    resnet_block_out = resnet_block(resnet50_upto_conv4.output, 2)

    # ASPP
    aspp_out = atrous_spatial_pyramid_pooling(resnet_block_out)

    # Final layers: per-pixel class logits, upsampled by 16
    out = layers.Conv2D(num_classes, (1,1), padding='same')(aspp_out)
    final_out = layers.UpSampling2D((16,16), interpolation='bilinear')(out)

    model = models.Model(inputs, final_out)
    return model

# **4. Loss Functions dan Metrics untuk Segmentasi**
---
### **A. Weighted Cross-Entropy Loss**

In [None]:
# Listing 8.13-8.14: Weighted CE Loss
def get_label_weights(y_true, y_pred):
    """Compute a per-pixel weight from each sample's class distribution.

    For every sample, class ``c`` gets the weight
    ``(total_pixels - pixels_of_c) / total_pixels``, so frequent classes weigh
    less. Each pixel then receives the weight of its own label.

    Args:
        y_true: Integer label map; the code reduces over axes 1 and 2, so it
            is treated as ``(batch, height, width)``.
        y_pred: Logits; only the size of the last axis (number of classes)
            is used.

    Returns:
        A flat 1-D tensor with one weight per pixel. Pixels whose label lies
        outside ``[0, num_classes)`` get weight 0.
    """
    # Take the class count from the logits rather than from a global variable.
    num_classes = y_pred.shape[-1]
    if num_classes is None:
        num_classes = tf.shape(y_pred)[-1]

    y_true = tf.cast(y_true, 'int32')

    # Per-sample pixel count of each class. Out-of-range labels one-hot encode
    # to all zeros and therefore do not contribute to any count.
    class_counts = tf.reduce_sum(tf.one_hot(y_true, num_classes), axis=[1,2])
    tot = tf.reduce_sum(class_counts, axis=-1, keepdims=True)
    # divide_no_nan: a sample without any valid pixel yields 0 instead of NaN.
    weights = tf.math.divide_no_nan(tot - class_counts, tot)

    y_true_flat = tf.reshape(y_true, [tf.shape(y_true)[0], -1])

    # Out-of-range indices must not reach tf.gather, whose handling of them
    # differs between devices. Look up a safe index instead and zero the
    # resulting weight afterwards.
    valid = tf.logical_and(y_true_flat >= 0, y_true_flat < num_classes)
    safe_indices = tf.where(valid, y_true_flat, tf.zeros_like(y_true_flat))
    y_weights = tf.gather(params=weights, indices=safe_indices, batch_dims=1)
    y_weights = y_weights * tf.cast(valid, y_weights.dtype)

    return tf.reshape(y_weights, [-1])

def ce_weighted_from_logits(num_classes):
    """Create a class-weighted sparse cross-entropy loss operating on logits.

    Args:
        num_classes: Number of valid classes; labels outside
            ``[0, num_classes - 1]`` are treated as invalid and ignored.

    Returns:
        ``loss_fn(y_true, y_pred)`` returning the mean, over all pixels, of
        the weighted per-pixel cross-entropy with invalid pixels zeroed.
    """
    def loss_fn(y_true, y_pred):
        # 1 for pixels with a valid class id, 0 otherwise; shape (pixels, 1).
        valid_mask = tf.cast(
            tf.reshape(
                tf.logical_and(y_true >= 0, y_true <= num_classes - 1),
                [-1, 1]
            ),
            'int32'
        )
        y_true = tf.cast(y_true, 'int32')
        y_true.set_shape([None, y_pred.shape[1], y_pred.shape[2]])

        y_weights = get_label_weights(y_true, y_pred)
        y_pred_unwrap = tf.reshape(y_pred, [-1, num_classes])
        y_true_unwrap = tf.reshape(y_true, [-1])

        # Squeeze only the trailing axis so a single-pixel input stays 1-D.
        flat_mask = tf.squeeze(valid_mask, axis=-1)

        # Invalid labels are replaced by class 0 (and their logits zeroed) only
        # so the op receives in-range labels; their loss is discarded below.
        per_pixel_ce = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=y_true_unwrap * flat_mask,
            logits=y_pred_unwrap * tf.cast(valid_mask, 'float32')
        )

        # Mask the loss explicitly: zeroed logits still produce a non-zero
        # cross-entropy, so ignoring invalid pixels must not depend on their
        # weight happening to be 0.
        return tf.reduce_mean(
            y_weights * per_pixel_ce * tf.cast(flat_mask, 'float32')
        )
    return loss_fn

### **B. Dice Loss**

In [None]:
# Listing 8.15: Dice Loss
def dice_loss_from_logits(num_classes):
    """Create a class-weighted dice loss operating on logits.

    Args:
        num_classes: Number of valid classes; labels outside
            ``[0, num_classes - 1]`` are excluded from the score.

    Returns:
        ``loss_fn(y_true, y_pred)`` returning ``1 - dice_score``, where the
        score is ``(2 * intersection + 1) / (union + 1)`` over weighted pixels.
    """
    def loss_fn(y_true, y_pred):
        smooth = 1.0
        y_true = tf.cast(y_true, 'int32')
        y_true.set_shape([None, y_pred.shape[1], y_pred.shape[2]])

        # Same validity rule as the cross-entropy loss. Without it, predicted
        # probabilities at invalid pixels would still add to the union unless
        # their weight already happened to be 0.
        valid_mask = tf.cast(
            tf.reshape(
                tf.logical_and(y_true >= 0, y_true <= num_classes - 1),
                [-1, 1]
            ),
            'float32'
        )
        y_weights = tf.reshape(get_label_weights(y_true, y_pred), [-1, 1]) * valid_mask

        # Turn logits into class probabilities
        y_pred = tf.nn.softmax(y_pred)

        y_true_unwrap = tf.reshape(y_true, [-1])
        y_true_onehot = tf.one_hot(y_true_unwrap, num_classes)
        y_pred_unwrap = tf.reshape(y_pred, [-1, num_classes])

        intersection = tf.reduce_sum(y_true_onehot * y_pred_unwrap * y_weights)
        union = tf.reduce_sum((y_true_onehot + y_pred_unwrap) * y_weights)

        # `smooth` keeps the ratio defined when both sums are zero.
        score = (2. * intersection + smooth) / (union + smooth)
        return 1 - score
    return loss_fn

### **C. Combined Loss**

In [None]:
# Listing 8.16: Combined CE + Dice Loss
def ce_dice_loss_from_logits(num_classes):
    """Create a loss that sums the weighted cross-entropy and the dice loss.

    Args:
        num_classes: Number of classes, forwarded to both component losses.

    Returns:
        ``loss_fn(y_true, y_pred)`` returning ``ce_loss + dice_loss``.
    """
    def loss_fn(y_true, y_pred):
        cross_entropy_fn = ce_weighted_from_logits(num_classes)
        dice_fn = dice_loss_from_logits(num_classes)

        cross_entropy_term = cross_entropy_fn(y_true, y_pred)
        dice_term = dice_fn(y_true, y_pred)
        return cross_entropy_term + dice_term

    return loss_fn

### **D. Evaluation Metrics**

In [None]:
# Listing 8.17-8.19: Metrics untuk segmentasi

# Pixel Accuracy
class PixelAccuracyMetric(tf.keras.metrics.Accuracy):
    """Accuracy over all pixels whose label is a valid class id.

    Pixels whose label is greater than ``num_classes - 1`` are removed before
    the base ``Accuracy`` metric is updated.
    """

    def __init__(self, num_classes, name='pixel_accuracy', **kwargs):
        super().__init__(name=name, **kwargs)
        self.num_classes = num_classes

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate accuracy for one batch; ``sample_weight`` is not used."""
        height, width = y_pred.shape[1], y_pred.shape[2]
        y_true.set_shape([None, height, width])

        # Flatten the labels and the arg-max class predictions to 1-D.
        labels = tf.reshape(y_true, [-1])
        predictions = tf.reshape(tf.argmax(y_pred, axis=-1), [-1])

        # Keep only pixels carrying a valid class id.
        keep = tf.reshape((labels <= self.num_classes - 1), [-1])
        kept_labels = tf.boolean_mask(labels, keep)
        kept_predictions = tf.boolean_mask(predictions, keep)

        super().update_state(kept_labels, kept_predictions)

# Mean Accuracy
class MeanAccuracyMetric(tf.keras.metrics.Mean):
    """Running mean of the per-batch, class-averaged accuracy.

    For each batch a confusion matrix is built from the valid pixels, the
    smoothed ratio ``(true_pos + 1) / (row_sum + 1)`` is computed per class,
    and the mean over classes is fed to the base ``Mean`` metric.
    """

    def __init__(self, num_classes, name='mean_accuracy', **kwargs):
        super().__init__(name=name, **kwargs)
        self.num_classes = num_classes

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate the metric for one batch; ``sample_weight`` is not used."""
        smooth = 1
        height, width = y_pred.shape[1], y_pred.shape[2]
        y_true.set_shape([None, height, width])

        # Flatten the labels and the arg-max class predictions to 1-D.
        labels = tf.reshape(y_true, [-1])
        predictions = tf.reshape(tf.argmax(y_pred, axis=-1), [-1])

        # Keep only pixels carrying a valid class id.
        keep = tf.reshape((labels <= self.num_classes - 1), [-1])
        labels = tf.boolean_mask(labels, keep)
        predictions = tf.boolean_mask(predictions, keep)

        conf_matrix = tf.math.confusion_matrix(
            labels, predictions, num_classes=self.num_classes
        )
        # Diagonal holds the correctly classified pixels of each class; the
        # sum along axis 1 is the total per confusion-matrix row.
        true_pos = tf.linalg.diag_part(conf_matrix)
        row_totals = tf.reduce_sum(conf_matrix, axis=1)
        per_class_accuracy = (true_pos + smooth) / (row_totals + smooth)

        super().update_state(tf.reduce_mean(per_class_accuracy))

# Mean IoU
class MeanIoUMetric(tf.keras.metrics.MeanIoU):
    """Mean IoU computed from logits, ignoring pixels without a valid class id.

    Predictions are reduced with arg-max over the last axis, and pixels whose
    label is greater than ``num_classes - 1`` are removed before the base
    ``MeanIoU`` metric is updated.
    """

    def __init__(self, num_classes, name='mean_iou', **kwargs):
        super().__init__(num_classes=num_classes, name=name, **kwargs)

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate the metric for one batch; ``sample_weight`` is not used."""
        height, width = y_pred.shape[1], y_pred.shape[2]
        y_true.set_shape([None, height, width])

        # Flatten the labels and the arg-max class predictions to 1-D.
        labels = tf.reshape(y_true, [-1])
        predictions = tf.reshape(tf.argmax(y_pred, axis=-1), [-1])

        # self.num_classes is not assigned in this class; it is expected to be
        # provided by the base class, which receives num_classes in __init__.
        keep = tf.reshape((labels <= self.num_classes - 1), [-1])
        kept_labels = tf.boolean_mask(labels, keep)
        kept_predictions = tf.boolean_mask(predictions, keep)

        super().update_state(kept_labels, kept_predictions)

# **5. Training dan Evaluasi**
---
### **A. Compile Model**

In [None]:
# Compile with the custom loss and metrics.
# NOTE(review): `deeplabv3` is not assigned anywhere in the visible source;
# presumably it is the model returned by build_deeplabv3() -- confirm.
deeplabv3.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    # Sum of the weighted cross-entropy and dice losses over 21 classes.
    loss=ce_dice_loss_from_logits(num_classes=21),
    # The evaluation printout further down indexes its results in this order
    # (after the loss), so keep the two in sync when editing this list.
    metrics=[
        MeanIoUMetric(num_classes=21),
        MeanAccuracyMetric(num_classes=21),
        PixelAccuracyMetric(num_classes=21)
    ]
)

### **B. Training dengan Callbacks**

In [None]:
# Listing 8.20: Training model
import os

# Set up callbacks.
# Create the log directory first: the CSV log below is written into 'eval',
# which nothing else in this script creates.
os.makedirs('eval', exist_ok=True)

# Logs the per-epoch training history as CSV.
csv_logger = tf.keras.callbacks.CSVLogger(
    os.path.join('eval', 'deeplabv3_training.log')
)

# Multiply the learning rate by 0.1 after 3 epochs without val_loss improvement.
lr_callback = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss', factor=0.1, patience=3, min_lr=1e-8
)

# Stop after 6 epochs without val_loss improvement and restore the best weights.
es_callback = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss', patience=6, restore_best_weights=True
)

# Training
# NOTE(review): train_ds, val_ds, n_train_steps and n_val_steps are not defined
# in the visible source; presumably the datasets come from
# get_subset_tf_dataset() -- confirm.
history = deeplabv3.fit(
    x=train_ds,
    steps_per_epoch=n_train_steps,
    validation_data=val_ds,
    validation_steps=n_val_steps,
    epochs=25,
    callbacks=[csv_logger, lr_callback, es_callback]
)

### **C. Evaluasi dan Visualisasi**

In [None]:
# Evaluate on test set
# NOTE(review): test_ds and n_test_steps are not defined in the visible source.
test_results = deeplabv3.evaluate(test_ds, steps=n_test_steps)
# Index 0 is read as the loss and indices 1-3 as the metrics, in the order in
# which the metrics were passed to compile() (mean IoU, mean accuracy, pixel
# accuracy).
print(f"Test Results - Loss: {test_results[0]:.4f}, "
      f"Mean IoU: {test_results[1]:.4f}, "
      f"Mean Accuracy: {test_results[2]:.4f}, "
      f"Pixel Accuracy: {test_results[3]:.4f}")

# Visualize predictions
def visualize_predictions(model, dataset, n_samples=5):
    """Plot input image, ground-truth mask and predicted mask side by side.

    Args:
        model: Model whose ``predict`` output has the class scores on the
            last axis.
        dataset: Dataset yielding ``(image, mask)`` pairs. A batch axis is
            added to each image before prediction, so the pairs are expected
            to be unbatched.
        n_samples: Number of rows to plot (one sample per row).
    """
    # NOTE(review): `plt` is not imported in the visible source; presumably
    # matplotlib.pyplot is imported elsewhere -- confirm.
    # squeeze=False keeps `axes` two-dimensional even for n_samples == 1;
    # otherwise the axes[i, j] indexing below fails for a single row.
    fig, axes = plt.subplots(
        n_samples, 3, figsize=(15, 5*n_samples), squeeze=False
    )

    # Hide every axis up front so rows stay blank when the dataset yields
    # fewer than n_samples elements.
    for ax in axes.ravel():
        ax.axis('off')

    for i, (image, mask) in enumerate(dataset.take(n_samples)):
        pred = model.predict(tf.expand_dims(image, axis=0))
        # Most likely class per pixel for the single image in the batch.
        pred_mask = tf.argmax(pred[0], axis=-1)

        panels = (
            (image, "Original"),
            (mask, "Ground Truth"),
            (pred_mask, "Prediction"),
        )
        for column, (content, title) in enumerate(panels):
            axes[i, column].imshow(content)
            axes[i, column].set_title(title)

    plt.tight_layout()
    plt.show()

# **6. Key Takeaways**
---
* tf.data pipeline penting untuk efisiensi I/O
* Atrous convolution meningkatkan receptive field tanpa tambahan parameter
* ASPP mengatasi masalah multi-scale information
* Combined loss (CE + Dice) bekerja baik untuk class imbalance
* Custom metrics diperlukan untuk evaluasi segmentasi yang akurat