<a href="https://colab.research.google.com/github/Taramas73/DS-final-project/blob/irusha/Multiclass_Segmentation_Implementation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 3. Multiclass Segmentation Implementation
Here's how to implement multiclass segmentation with one-hot encoding and softmax output:



*   Make sure your masks are properly one-hot encoded.
*   Balance your dataset if class distributions are uneven.
*   Consider using class weights if some classes are rare but important.
*   The combined loss function (Dice + Categorical Crossentropy) usually gives better results than either loss alone.



In [1]:
import tensorflow as tf
import numpy as np
from tensorflow.keras.utils import to_categorical

def preprocess_multiclass_masks(mask_paths, class_mapping, img_height=256, img_width=256):
    """
    Load mask images from disk and turn them into one-hot encoded arrays.

    Each mask is read as a grayscale image resized to (img_height, img_width).
    Every pixel whose value equals a key of ``class_mapping`` is replaced by
    the corresponding class index; pixels whose value is not a key keep
    index 0. The index map is then one-hot encoded with
    ``len(class_mapping)`` channels.

    Args:
        mask_paths: List of paths to mask images
        class_mapping: Dictionary mapping pixel values to class indices
        img_height: Height to resize masks to
        img_width: Width to resize masks to

    Returns:
        NumPy array stacking the one-hot masks in the order of ``mask_paths``
    """
    n_channels = len(class_mapping)
    target_size = (img_height, img_width)

    def _encode_one(path):
        # Read the mask as a single-channel image and drop singleton axes
        raw = np.array(
            tf.keras.preprocessing.image.load_img(
                path, target_size=target_size, color_mode='grayscale'
            )
        ).squeeze()

        # Start from all zeros, then write the class index for every mapped value;
        # comparisons use the untouched `raw` so mappings cannot chain
        labels = np.zeros_like(raw)
        for pixel_value, label in class_mapping.items():
            labels[raw == pixel_value] = label

        return to_categorical(labels, num_classes=n_channels)

    return np.array([_encode_one(path) for path in mask_paths])

def multiclass_dice_coefficient(y_true, y_pred, smooth=1e-6):
    """
    Dice coefficient for multiclass segmentation, averaged over classes.

    Both inputs are flattened to shape (N, C), where C is the size of their
    last axis, so each class score is computed over every position of the
    whole batch at once.

    Args:
        y_true: One-hot ground truth whose last axis indexes the classes.
            It is cast to the dtype of ``y_pred`` so integer or boolean
            masks can be passed directly.
        y_pred: Per-class predictions with the same shape as ``y_true``
        smooth: Constant added to numerator and denominator so a class that
            is absent from both inputs does not produce 0/0

    Returns:
        Scalar tensor: the mean over classes of the per-class Dice score
    """
    # Align dtypes first: multiplying tensors of different dtypes raises in TensorFlow
    y_pred = tf.convert_to_tensor(y_pred)
    y_true = tf.cast(y_true, y_pred.dtype)

    # Collapse every axis except the class axis
    y_true_flat = tf.reshape(y_true, [-1, tf.shape(y_true)[-1]])
    y_pred_flat = tf.reshape(y_pred, [-1, tf.shape(y_pred)[-1]])

    # Per-class overlap and per-class total mass (|A| + |B|, the Dice denominator)
    intersection = tf.reduce_sum(y_true_flat * y_pred_flat, axis=0)
    cardinality = tf.reduce_sum(y_true_flat, axis=0) + tf.reduce_sum(y_pred_flat, axis=0)

    # Dice score for each class
    dice = (2. * intersection + smooth) / (cardinality + smooth)

    # Return mean dice over all classes
    return tf.reduce_mean(dice)

def multiclass_dice_loss(y_true, y_pred):
    """
    Dice loss for multiclass segmentation.

    Computed as one minus ``multiclass_dice_coefficient(y_true, y_pred)``
    (called with its default smoothing), so a higher Dice score gives a
    lower loss.
    """
    dice_score = multiclass_dice_coefficient(y_true, y_pred)
    return 1 - dice_score

# Combine with categorical crossentropy for better performance
def combined_multiclass_loss(y_true, y_pred):
    """
    Sum of the multiclass Dice loss and categorical crossentropy.

    Args:
        y_true: One-hot ground truth masks
        y_pred: Predicted class probabilities

    Returns:
        ``multiclass_dice_loss(y_true, y_pred)`` plus
        ``tf.keras.losses.categorical_crossentropy(y_true, y_pred)``
    """
    # NOTE(review): the Dice term looks like a scalar while the crossentropy
    # term is presumably per-position, so the sum relies on broadcasting - confirm
    crossentropy_term = tf.keras.losses.categorical_crossentropy(y_true, y_pred)
    dice_term = multiclass_dice_loss(y_true, y_pred)
    return dice_term + crossentropy_term

# Data generator for multiclass segmentation
class MulticlassDataGenerator(tf.keras.utils.Sequence):
    """
    Keras Sequence yielding (images, one-hot masks) batches for multiclass segmentation.

    Images are loaded from ``image_paths``, resized to (img_height, img_width)
    and scaled to [0, 1] as float32. Masks are loaded as grayscale, their pixel
    values are translated to class indices through ``class_mapping`` (unmapped
    values stay at index 0) and one-hot encoded with ``len(class_mapping)``
    channels.

    Args:
        image_paths: List of paths to input images
        mask_paths: List of paths to mask images, aligned with ``image_paths``
        class_mapping: Dictionary mapping mask pixel values to class indices
        batch_size: Number of samples per batch (the last batch may be smaller)
        img_height: Height to resize images and masks to
        img_width: Width to resize images and masks to
        augmentation: Optional callable invoked as
            ``augmentation(image=..., mask=...)`` that returns a mapping with
            ``'image'`` and ``'mask'`` keys
        shuffle: Whether to shuffle sample order at construction and after
            every epoch

    Raises:
        ValueError: If the two path lists differ in length or ``batch_size``
            is not positive
    """

    def __init__(self, image_paths, mask_paths, class_mapping, batch_size=8,
                 img_height=256, img_width=256, augmentation=None, shuffle=True):
        # Newer Keras versions expect the base-class constructor to run
        # (it sets up worker/queue settings) and warn when it is skipped
        super().__init__()

        # A length mismatch would otherwise surface as a mid-epoch IndexError
        # or as masks that are silently never used
        if len(image_paths) != len(mask_paths):
            raise ValueError(
                f"image_paths and mask_paths must have the same length "
                f"(got {len(image_paths)} and {len(mask_paths)})"
            )
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive (got {batch_size})")

        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.class_mapping = class_mapping
        self.num_classes = len(class_mapping)
        self.batch_size = batch_size
        self.img_height = img_height
        self.img_width = img_width
        self.augmentation = augmentation
        self.shuffle = shuffle
        self.indexes = np.arange(len(image_paths))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __len__(self):
        """Return the number of batches per epoch, counting a final partial batch."""
        return int(np.ceil(len(self.image_paths) / self.batch_size))

    def _load_image(self, path):
        """Load one image, resize it and scale pixel values to [0, 1] as float32."""
        img = tf.keras.preprocessing.image.load_img(
            path, target_size=(self.img_height, self.img_width)
        )
        # float32 halves memory versus the float64 that `np.array(img) / 255.0`
        # would produce, and is the float dtype image pipelines commonly expect
        return np.asarray(img, dtype=np.float32) / 255.0

    def _load_mask(self, path):
        """Load one mask, map pixel values to class indices and one-hot encode it."""
        mask = tf.keras.preprocessing.image.load_img(
            path, target_size=(self.img_height, self.img_width),
            color_mode='grayscale'
        )
        mask = np.array(mask).squeeze()

        # Compare against the untouched `mask` so one mapping cannot feed another
        encoded_mask = np.zeros_like(mask)
        for original_value, class_idx in self.class_mapping.items():
            encoded_mask[mask == original_value] = class_idx

        return to_categorical(encoded_mask, num_classes=self.num_classes)

    def __getitem__(self, idx):
        """Build batch number ``idx`` as a tuple ``(images, one_hot_masks)``."""
        start = idx * self.batch_size
        batch_indexes = self.indexes[start:start + self.batch_size]
        batch_images = []
        batch_masks = []

        for i in batch_indexes:
            img = self._load_image(self.image_paths[i])
            one_hot_mask = self._load_mask(self.mask_paths[i])

            # Apply augmentation if specified; image and mask go through the
            # same call so the augmentation sees them as a pair
            if self.augmentation:
                augmented = self.augmentation(image=img, mask=one_hot_mask)
                img = augmented['image']
                one_hot_mask = augmented['mask']

            batch_images.append(img)
            batch_masks.append(one_hot_mask)

        return np.array(batch_images), np.array(batch_masks)

    def on_epoch_end(self):
        """Reshuffle the sample order after each epoch when shuffling is enabled."""
        if self.shuffle:
            np.random.shuffle(self.indexes)

# Example of setting up and training a multiclass segmentation model
def train_multiclass_model(train_img_paths, train_mask_paths,
                           val_img_paths, val_mask_paths,
                           class_mapping, batch_size=8, epochs=50,
                           img_height=256, img_width=256):
    """
    Build, compile and train a multiclass segmentation model.

    Args:
        train_img_paths: List of paths to training images
        train_mask_paths: List of paths to training masks
        val_img_paths: List of paths to validation images
        val_mask_paths: List of paths to validation masks
        class_mapping: Dictionary mapping mask pixel values to class indices;
            its length is used as the number of output classes
        batch_size: Batch size for both generators
        epochs: Maximum number of training epochs
        img_height: Height used for the generators and the model input
        img_width: Width used for the generators and the model input

    Returns:
        Tuple ``(model, history)`` with the trained model and the object
        returned by ``model.fit``
    """
    # One source of truth for the spatial size: the generators and the model
    # input must agree, otherwise fit() fails on a shape mismatch
    train_gen = MulticlassDataGenerator(
        train_img_paths, train_mask_paths, class_mapping, batch_size=batch_size,
        img_height=img_height, img_width=img_width
    )
    val_gen = MulticlassDataGenerator(
        val_img_paths, val_mask_paths, class_mapping, batch_size=batch_size,
        img_height=img_height, img_width=img_width, shuffle=False
    )

    # Create model - you can use the EfficientNet + U-Net hybrid or regular U-Net
    # The only difference is the output layer should use softmax and output num_classes channels
    # NOTE(review): build_efficient_unet is not defined in this cell; it must
    # already be in scope when this function is called
    model = build_efficient_unet(
        input_shape=(img_height, img_width, 3), n_classes=len(class_mapping)
    )

    # Compile model
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
        loss=combined_multiclass_loss,
        metrics=[multiclass_dice_coefficient, 'categorical_accuracy']
    )

    # Setup callbacks
    callbacks = [
        # Keep the weights with the highest validation Dice; the monitored name
        # is the metric function's name with the 'val_' prefix
        tf.keras.callbacks.ModelCheckpoint(
            'best_multiclass_model.h5',
            save_best_only=True,
            monitor='val_multiclass_dice_coefficient',
            mode='max'
        ),
        tf.keras.callbacks.ReduceLROnPlateau(
            monitor='val_loss', factor=0.2, patience=5, min_lr=1e-6
        ),
        tf.keras.callbacks.EarlyStopping(
            monitor='val_loss', patience=15, restore_best_weights=True
        )
    ]

    # Train model
    history = model.fit(
        train_gen,
        validation_data=val_gen,
        epochs=epochs,
        callbacks=callbacks
    )

    return model, history