# 09 image segmentation keras
**Location: TensorVerseHub/notebooks/03_computer_vision/09_image_segmentation_keras.ipynb**

# Image Segmentation with tf.keras

**File Location:** `notebooks/03_computer_vision/09_image_segmentation_keras.ipynb`

Master image segmentation with the tf.keras Functional API by implementing U-Net, Attention U-Net, DeepLab v3+, PSPNet, and custom segmentation architectures. Build semantic segmentation, instance segmentation, and medical imaging models, and train them with segmentation-specific loss functions and evaluation metrics.

## Learning Objectives
- Implement U-Net architecture with tf.keras Functional API
- Build DeepLab and advanced segmentation models
- Master segmentation-specific loss functions and metrics
- Handle multi-class and binary segmentation tasks
- Create custom segmentation datasets and data augmentation
- Deploy segmentation models for real-world applications
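Every model in this notebook treats segmentation as per-pixel classification: the network outputs a `(H, W, num_classes)` map of class probabilities, and the predicted mask is the per-pixel argmax. The toy NumPy sketch below, with made-up shapes and hypothetical helper names, shows the round trip between an integer label map, its one-hot encoding (the format `utils.to_categorical` produces for training), and the argmax decoding used at prediction time. For binary tasks the usual alternative is a single sigmoid channel thresholded at 0.5.

```python
import numpy as np

def to_one_hot(label_map, num_classes):
    """Integer mask (H, W) -> one-hot target (H, W, num_classes)."""
    return np.eye(num_classes, dtype=np.float32)[label_map]

def decode_prediction(probs):
    """Per-pixel class probabilities (H, W, C) -> integer mask (H, W)."""
    return np.argmax(probs, axis=-1)

def decode_binary(prob_map, threshold=0.5):
    """Single-channel sigmoid output (H, W, 1) -> 0/1 mask (H, W)."""
    return (prob_map[..., 0] > threshold).astype(np.uint8)

# Toy 2x3 mask with classes 0 (background), 1 and 2
label_map = np.array([[0, 1, 1],
                      [2, 2, 0]])
one_hot = to_one_hot(label_map, num_classes=3)
print(one_hot.shape)               # (2, 3, 3)
print(decode_prediction(one_hot))  # recovers label_map

# Binary case: one probability per pixel
binary_probs = np.array([[[0.9], [0.2]],
                         [[0.4], [0.7]]])
print(decode_binary(binary_probs))
```

Decoding a one-hot target returns the original labels exactly, which is a quick way to check that targets and model outputs use the same class ordering.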

---

## 1. Segmentation Fundamentals and Data Preparation

```python
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras import layers, models, utils
import cv2
from sklearn.model_selection import train_test_split
import os
import warnings
warnings.filterwarnings('ignore')

print(f"TensorFlow version: {tf.__version__}")
tf.random.set_seed(42)
np.random.seed(42)

# Create synthetic segmentation dataset
def create_synthetic_segmentation_data(num_samples=1000, image_size=128):
    """Create synthetic dataset for segmentation experiments"""
    
    print("Creating synthetic segmentation dataset...")
    
    images = []
    masks = []
    
    for i in range(num_samples):
        # Create synthetic image with geometric shapes
        img = np.zeros((image_size, image_size, 3), dtype=np.uint8)
        mask = np.zeros((image_size, image_size), dtype=np.uint8)
        
        # Add background noise
        noise = np.random.randint(0, 50, (image_size, image_size, 3), dtype=np.uint8)
        img = img + noise
        
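        # Add 1-3 random shapes. The mask label encodes the shape type
        # (1 = circle, 2 = rectangle, 3 = triangle), so every class has a
        # consistent appearance the model can actually learn
        num_shapes = np.random.randint(1, 4)
        
        for _ in range(num_shapes):
            shape_type = np.random.choice(['circle', 'rectangle', 'triangle'])
            class_id = {'circle': 1, 'rectangle': 2, 'triangle': 3}[shape_type]
            color = np.random.randint(100, 255, 3).tolist()
            
            if shape_type == 'circle':
                # OpenCV expects plain Python ints for coordinates
                center = (int(np.random.randint(20, image_size - 20)),
                          int(np.random.randint(20, image_size - 20)))
                radius = int(np.random.randint(10, 30))
                cv2.circle(img, center, radius, color, -1)
                cv2.circle(mask, center, radius, class_id, -1)
                
            elif shape_type == 'rectangle':
                pt1 = (int(np.random.randint(0, image_size // 2)),
                       int(np.random.randint(0, image_size // 2)))
                pt2 = (pt1[0] + int(np.random.randint(20, 50)),
                       pt1[1] + int(np.random.randint(20, 50)))
                cv2.rectangle(img, pt1, pt2, color, -1)
                cv2.rectangle(mask, pt1, pt2, class_id, -1)
                
            else:  # triangle: three random vertices, filled
                pts = np.random.randint(10, image_size - 10, size=(3, 2)).astype(np.int32)
                cv2.fillPoly(img, [pts], color)
                cv2.fillPoly(mask, [pts], class_id)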
        
        images.append(img)
        masks.append(mask)
    
    images = np.array(images, dtype=np.float32) / 255.0
    masks = np.array(masks, dtype=np.uint8)
    
    print(f"Created dataset: {images.shape} images, {masks.shape} masks")
    print(f"Unique mask values: {np.unique(masks)}")
    
    return images, masks

# Create datasets
images, masks = create_synthetic_segmentation_data(num_samples=800, image_size=128)

# Split into train/validation/test
X_train, X_temp, y_train, y_temp = train_test_split(images, masks, test_size=0.3, random_state=42)
X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, random_state=42)

print(f"Training set: {X_train.shape}, {y_train.shape}")
print(f"Validation set: {X_val.shape}, {y_val.shape}")
print(f"Test set: {X_test.shape}, {y_test.shape}")

# Visualization function
def visualize_segmentation(images, masks, predictions=None, num_samples=4):
    """Visualize segmentation results"""
    
    num_cols = 3 if predictions is None else 4
    # squeeze=False keeps `axes` 2-D even when num_samples == 1
    fig, axes = plt.subplots(num_samples, num_cols, figsize=(15, num_samples * 3), squeeze=False)
    
    for i in range(min(num_samples, len(images))):
        # Original image
        axes[i, 0].imshow(images[i])
        axes[i, 0].set_title('Original Image')
        axes[i, 0].axis('off')
        
        # Ground truth mask
        axes[i, 1].imshow(masks[i], cmap='viridis')
        axes[i, 1].set_title('Ground Truth Mask')
        axes[i, 1].axis('off')
        
        # Overlay
        overlay = images[i].copy()
        # max(..., 1) avoids dividing by zero for an all-background mask
        colored_mask = plt.cm.viridis(masks[i] / max(masks[i].max(), 1))[:, :, :3]
        overlay = 0.7 * overlay + 0.3 * colored_mask
        axes[i, 2].imshow(overlay)
        axes[i, 2].set_title('Overlay')
        axes[i, 2].axis('off')
        
        # Predictions if provided
        if predictions is not None:
            pred_mask = np.argmax(predictions[i], axis=-1)
            axes[i, 3].imshow(pred_mask, cmap='viridis')
            axes[i, 3].set_title('Prediction')
            axes[i, 3].axis('off')
    
    plt.tight_layout()
    plt.show()

# Visualize sample data
visualize_segmentation(X_train, y_train, num_samples=3)

# Data augmentation for segmentation
class SegmentationDataGenerator:
    """Custom data generator for segmentation with synchronized augmentations"""
    
    def __init__(self, rotation_range=20, width_shift_range=0.1, height_shift_range=0.1,
                 zoom_range=0.1, horizontal_flip=True, brightness_range=None):
        self.rotation_range = rotation_range
        self.width_shift_range = width_shift_range
        self.height_shift_range = height_shift_range
        self.zoom_range = zoom_range
        self.horizontal_flip = horizontal_flip
        self.brightness_range = brightness_range or [0.8, 1.2]
    
    def augment(self, image, mask):
        """Apply the same random geometric transform to image and mask.
        
        Flip, rotation, zoom and shift are shared; brightness changes only the
        image. The mask is warped with nearest-neighbour interpolation so its
        values remain valid class ids.
        """
        h, w = mask.shape[:2]
        
        # Random horizontal flip
        if self.horizontal_flip and np.random.rand() < 0.5:
            image = image[:, ::-1]
            mask = mask[:, ::-1]
        
        # One affine matrix covering rotation (in degrees), zoom and shift
        angle = np.random.uniform(-self.rotation_range, self.rotation_range)
        scale = np.random.uniform(1 - self.zoom_range, 1 + self.zoom_range)
        matrix = cv2.getRotationMatrix2D((w / 2, h / 2), angle, scale)
        matrix[0, 2] += np.random.uniform(-self.width_shift_range, self.width_shift_range) * w
        matrix[1, 2] += np.random.uniform(-self.height_shift_range, self.height_shift_range) * h
        
        # Bilinear for the image, nearest-neighbour for the mask
        image = cv2.warpAffine(np.ascontiguousarray(image, dtype=np.float32), matrix, (w, h),
                               flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT, borderValue=0)
        mask = cv2.warpAffine(np.ascontiguousarray(mask, dtype=np.uint8), matrix, (w, h),
                              flags=cv2.INTER_NEAREST, borderMode=cv2.BORDER_CONSTANT, borderValue=0)
        
        # Multiplicative brightness adjustment (image only)
        if self.brightness_range:
            factor = np.random.uniform(self.brightness_range[0], self.brightness_range[1])
            image = np.clip(image * factor, 0.0, 1.0)
        
        return image.astype(np.float32), mask.astype(np.uint8)
    
    def flow(self, images, masks, batch_size=32):
        """Generate augmented batches"""
        
        num_samples = len(images)
        indices = np.arange(num_samples)
        
        while True:
            np.random.shuffle(indices)
            
            for start_idx in range(0, num_samples, batch_size):
                end_idx = min(start_idx + batch_size, num_samples)
                batch_indices = indices[start_idx:end_idx]
                
                batch_images = []
                batch_masks = []
                
                for idx in batch_indices:
                    img, mask = self.augment(images[idx], masks[idx])
                    batch_images.append(img)
                    batch_masks.append(mask)
                
                yield np.array(batch_images), np.array(batch_masks)

# Test data augmentation
augmenter = SegmentationDataGenerator()
aug_gen = augmenter.flow(X_train[:8], y_train[:8], batch_size=4)
aug_images, aug_masks = next(aug_gen)

print("Testing data augmentation:")
visualize_segmentation(aug_images, aug_masks, num_samples=2)
```
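The generator above runs its augmentations in Python, one sample at a time. An alternative is to keep augmentation inside a `tf.data` pipeline; the requirement is unchanged: every geometric transform must hit the image and the mask identically, while photometric changes touch the image only. The sketch below assumes float images in `[0, 1]`, integer masks of shape `(H, W)`, and square inputs. The helper names `augment_pair` and `make_augmented_dataset` are hypothetical, and only flips and 90-degree rotations are shown because they move pixels without interpolation, so mask labels stay exact.

```python
import tensorflow as tf

def augment_pair(image, mask, seed):
    """Apply identical random geometry to an image and its mask (hypothetical helper).

    image: float32 (H, W, 3) in [0, 1]; mask: integer (H, W); seed: shape-[2] int tensor.
    Assumes square inputs so 90-degree rotations keep the shape unchanged.
    """
    seed = tf.cast(seed, tf.int32)
    mask = mask[..., tf.newaxis]  # tf.image ops expect a channel axis

    # Same seed -> same flip decision for image and mask
    image = tf.image.stateless_random_flip_left_right(image, seed=seed)
    mask = tf.image.stateless_random_flip_left_right(mask, seed=seed)

    # One random number of quarter turns (0-3), shared by both tensors
    k = tf.random.stateless_uniform([], seed=seed + 1, minval=0, maxval=4, dtype=tf.int32)
    image = tf.image.rot90(image, k=k)
    mask = tf.image.rot90(mask, k=k)

    # Photometric change on the image only
    image = tf.image.stateless_random_brightness(image, max_delta=0.1, seed=seed + 2)
    image = tf.clip_by_value(image, 0.0, 1.0)
    return image, tf.squeeze(mask, axis=-1)

def make_augmented_dataset(images, masks, batch_size=16):
    """Shuffle, augment with a fresh seed per example and epoch, then batch (hypothetical)."""
    def _augment(image, mask):
        seed = tf.random.uniform([2], maxval=2**30, dtype=tf.int32)
        return augment_pair(image, mask, seed)

    ds = tf.data.Dataset.from_tensor_slices((images, masks))
    ds = ds.shuffle(len(images)).map(_augment, num_parallel_calls=tf.data.AUTOTUNE)
    return ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)
```

Stateless ops with a shared seed are what keep the two tensors in sync here; drawing a new seed inside `_augment` means each epoch sees different augmentations.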

## 2. U-Net Architecture Implementation

```python
# U-Net implementation with tf.keras Functional API
def conv_block(inputs, filters, dropout_rate=0.0):
    """Convolutional block for U-Net"""
    
    x = layers.Conv2D(filters, 3, activation='relu', padding='same',
                      kernel_initializer='he_normal')(inputs)
    x = layers.BatchNormalization()(x)
    
    if dropout_rate > 0:
        x = layers.Dropout(dropout_rate)(x)
    
    x = layers.Conv2D(filters, 3, activation='relu', padding='same',
                      kernel_initializer='he_normal')(x)
    x = layers.BatchNormalization()(x)
    
    return x

def encoder_block(inputs, filters, pool=True, dropout_rate=0.0):
    """Encoder block with optional pooling"""
    
    x = conv_block(inputs, filters, dropout_rate)
    
    if pool:
        pooled = layers.MaxPooling2D(pool_size=(2, 2))(x)
        return x, pooled
    else:
        return x

def decoder_block(inputs, skip_features, filters, dropout_rate=0.0):
    """Decoder block with skip connections"""
    
    # Upsampling
    x = layers.Conv2DTranspose(filters, (2, 2), strides=2, padding='same')(inputs)
    
    # Concatenate with skip connection
    x = layers.Concatenate()([x, skip_features])
    
    # Convolutional block
    x = conv_block(x, filters, dropout_rate)
    
    return x

def build_unet(input_shape, num_classes, filters=64, dropout_rate=0.1):
    """Build U-Net model for segmentation"""
    
    inputs = layers.Input(shape=input_shape)
    
    # Encoder path (Contracting path)
    e1, p1 = encoder_block(inputs, filters, pool=True, dropout_rate=dropout_rate)
    e2, p2 = encoder_block(p1, filters * 2, pool=True, dropout_rate=dropout_rate)
    e3, p3 = encoder_block(p2, filters * 4, pool=True, dropout_rate=dropout_rate)
    e4, p4 = encoder_block(p3, filters * 8, pool=True, dropout_rate=dropout_rate)
    
    # Bottleneck
    bottleneck = encoder_block(p4, filters * 16, pool=False, dropout_rate=dropout_rate)
    
    # Decoder path (Expanding path)
    d4 = decoder_block(bottleneck, e4, filters * 8, dropout_rate=dropout_rate)
    d3 = decoder_block(d4, e3, filters * 4, dropout_rate=dropout_rate)
    d2 = decoder_block(d3, e2, filters * 2, dropout_rate=dropout_rate)
    d1 = decoder_block(d2, e1, filters, dropout_rate=dropout_rate)
    
    # Output layer
    if num_classes == 1:
        activation = 'sigmoid'
    else:
        activation = 'softmax'
    
    outputs = layers.Conv2D(num_classes, 1, activation=activation, padding='same')(d1)
    
    model = models.Model(inputs, outputs, name='U-Net')
    return model

# Attention U-Net implementation
def attention_gate(gate_signal, skip_connection, filters):
    """Attention gate for focusing on relevant features"""
    
    # Gate signal processing
    gate = layers.Conv2D(filters, 1, padding='same')(gate_signal)
    gate = layers.BatchNormalization()(gate)
    
    # Skip connection processing  
    skip = layers.Conv2D(filters, 1, padding='same')(skip_connection)
    skip = layers.BatchNormalization()(skip)
    
    # Attention mechanism
    attention = layers.Add()([gate, skip])
    attention = layers.Activation('relu')(attention)
    attention = layers.Conv2D(1, 1, activation='sigmoid', padding='same')(attention)
    
    # Apply attention
    attended = layers.Multiply()([skip_connection, attention])
    
    return attended

def build_attention_unet(input_shape, num_classes, filters=64):
    """Build Attention U-Net model"""
    
    inputs = layers.Input(shape=input_shape)
    
    # Encoder
    e1, p1 = encoder_block(inputs, filters, pool=True)
    e2, p2 = encoder_block(p1, filters * 2, pool=True)
    e3, p3 = encoder_block(p2, filters * 4, pool=True)
    e4, p4 = encoder_block(p3, filters * 8, pool=True)
    
    # Bottleneck
    bottleneck = encoder_block(p4, filters * 16, pool=False)
    
    # Decoder with attention gates
    # Upsampling bottleneck
    up4 = layers.Conv2DTranspose(filters * 8, (2, 2), strides=2, padding='same')(bottleneck)
    att4 = attention_gate(up4, e4, filters * 4)
    merge4 = layers.Concatenate()([up4, att4])
    d4 = conv_block(merge4, filters * 8)
    
    up3 = layers.Conv2DTranspose(filters * 4, (2, 2), strides=2, padding='same')(d4)
    att3 = attention_gate(up3, e3, filters * 2)
    merge3 = layers.Concatenate()([up3, att3])
    d3 = conv_block(merge3, filters * 4)
    
    up2 = layers.Conv2DTranspose(filters * 2, (2, 2), strides=2, padding='same')(d3)
    att2 = attention_gate(up2, e2, filters)
    merge2 = layers.Concatenate()([up2, att2])
    d2 = conv_block(merge2, filters * 2)
    
    up1 = layers.Conv2DTranspose(filters, (2, 2), strides=2, padding='same')(d2)
    att1 = attention_gate(up1, e1, filters // 2)
    merge1 = layers.Concatenate()([up1, att1])
    d1 = conv_block(merge1, filters)
    
    # Output
    outputs = layers.Conv2D(num_classes, 1, activation='softmax', padding='same')(d1)
    
    model = models.Model(inputs, outputs, name='Attention-U-Net')
    return model

# Build and test U-Net models
print("=== Building U-Net Models ===")

# Standard U-Net
# Labels are 0 (background) plus one id per shape type, so use the max label + 1
num_classes = int(y_train.max()) + 1
input_shape = X_train.shape[1:]

unet_model = build_unet(input_shape, num_classes, filters=32, dropout_rate=0.1)
print(f"U-Net model built: {unet_model.count_params():,} parameters")

# Attention U-Net
attention_unet_model = build_attention_unet(input_shape, num_classes, filters=32)
print(f"Attention U-Net model built: {attention_unet_model.count_params():,} parameters")

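# Display model architectures (plot_model needs the optional pydot and graphviz packages)
try:
    utils.plot_model(unet_model, show_shapes=True, show_layer_names=True,
                     to_file='unet_architecture.png', dpi=150)
    plt.figure(figsize=(12, 8))
    plt.imshow(plt.imread('unet_architecture.png'))
    plt.axis('off')
    plt.title('U-Net Architecture')
    plt.show()
except (ImportError, OSError) as err:
    print(f"Skipping architecture plot ({err}); printing a text summary instead")
    unet_model.summary()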
```
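Each encoder level halves the resolution with 2×2 `MaxPooling2D` and each decoder level doubles it with a stride-2 `Conv2DTranspose`, so the skip-connection `Concatenate` layers only line up when the input height and width are divisible by 2 to the power of the depth. For the four-level U-Nets above that means multiples of 16, which the 128×128 synthetic images satisfy. The helper below is a hypothetical sanity check, not part of Keras: it lists the feature-map size at each level and raises an error for sizes that would break the skips.

```python
def unet_level_sizes(size, depth=4):
    """Feature-map sizes from the input down to the bottleneck (hypothetical helper).

    MaxPooling2D floors odd sizes, and the stride-2 Conv2DTranspose in the decoder
    then produces a map one pixel smaller than its skip connection, so `size`
    must be divisible by 2**depth.
    """
    factor = 2 ** depth
    if size % factor != 0:
        suggested = -(-size // factor) * factor  # next multiple of 2**depth
        raise ValueError(f"size {size} is not divisible by {factor}; "
                         f"pad or resize to {suggested}")
    return [size // (2 ** level) for level in range(depth + 1)]

print(unet_level_sizes(128))  # [128, 64, 32, 16, 8]

try:
    unet_level_sizes(100)
except ValueError as err:
    print(err)  # size 100 is not divisible by 16; pad or resize to 112
```

Running a check like this before building a model turns a confusing `Concatenate` shape error into an explicit message about the input size.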

## 3. DeepLab and Advanced Segmentation Models

```python
# DeepLab v3+ implementation
def dilated_conv_block(inputs, filters, dilation_rate, dropout_rate=0.1):
    """Dilated convolution block for DeepLab"""
    
    x = layers.Conv2D(filters, 3, padding='same', dilation_rate=dilation_rate,
                      kernel_initializer='he_normal')(inputs)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    
    if dropout_rate > 0:
        x = layers.Dropout(dropout_rate)(x)
    
    return x

def atrous_spatial_pyramid_pooling(inputs, filters=256):
    """ASPP module for DeepLab"""
    
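    # Image pooling branch: global context broadcast back to the feature-map size.
    # Keras layers instead of raw tf ops keep the graph valid under Keras 3;
    # this assumes a static spatial size, which holds for a fixed input_shape.
    height, width = inputs.shape[1], inputs.shape[2]
    pool = layers.GlobalAveragePooling2D()(inputs)
    pool = layers.Reshape((1, 1, -1))(pool)
    pool = layers.Conv2D(filters, 1, activation='relu')(pool)
    pool = layers.Resizing(height, width, interpolation='bilinear')(pool)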
    
    # Dilated convolutions with different rates
    conv1x1 = layers.Conv2D(filters, 1, activation='relu', padding='same')(inputs)
    conv3x3_1 = dilated_conv_block(inputs, filters, dilation_rate=6)
    conv3x3_2 = dilated_conv_block(inputs, filters, dilation_rate=12)
    conv3x3_3 = dilated_conv_block(inputs, filters, dilation_rate=18)
    
    # Concatenate all features
    concat = layers.Concatenate()([pool, conv1x1, conv3x3_1, conv3x3_2, conv3x3_3])
    
    # Final 1x1 convolution
    output = layers.Conv2D(filters, 1, activation='relu', padding='same')(concat)
    output = layers.Dropout(0.1)(output)
    
    return output

def build_deeplab_v3plus(input_shape, num_classes, backbone='mobilenetv2'):
    """Build DeepLab v3+ model"""
    
    inputs = layers.Input(shape=input_shape)
    
    # Backbone encoder
    if backbone == 'mobilenetv2':
        base_model = tf.keras.applications.MobileNetV2(
            input_tensor=inputs, weights='imagenet', include_top=False
        )
        
        # Extract features at different scales (for a 128x128 input,
        # block_3_expand_relu is 32x32 and out_relu is 4x4)
        low_level_features = base_model.get_layer('block_3_expand_relu').output  # 1/4
        high_level_features = base_model.get_layer('out_relu').output  # 1/32
        
    else:  # Custom lightweight encoder for our synthetic data
        # Encoder
        x = layers.Conv2D(32, 3, strides=2, padding='same', activation='relu')(inputs)
        x = layers.BatchNormalization()(x)
        low_level_features = x  # 1/2
        
        x = layers.Conv2D(64, 3, strides=2, padding='same', activation='relu')(x)
        x = layers.BatchNormalization()(x)
        
        x = layers.Conv2D(128, 3, strides=2, padding='same', activation='relu')(x)
        x = layers.BatchNormalization()(x)
        
        x = layers.Conv2D(256, 3, strides=2, padding='same', activation='relu')(x)
        x = layers.BatchNormalization()(x)
        
        high_level_features = x  # 1/16
    
    # ASPP module
    aspp_features = atrous_spatial_pyramid_pooling(high_level_features)
    
    # Upsample ASPP features to the low-level feature resolution
    aspp_upsampled = layers.Resizing(low_level_features.shape[1], low_level_features.shape[2],
                                     interpolation='bilinear')(aspp_features)
    
    # Process low-level features
    low_level_processed = layers.Conv2D(48, 1, activation='relu', 
                                      padding='same')(low_level_features)
    
    # Combine features
    combined = layers.Concatenate()([aspp_upsampled, low_level_processed])
    
    # Decoder
    x = layers.Conv2D(256, 3, activation='relu', padding='same')(combined)
    x = layers.BatchNormalization()(x)
    x = layers.Dropout(0.1)(x)
    
    x = layers.Conv2D(256, 3, activation='relu', padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Dropout(0.1)(x)
    
    # Final upsampling to the input resolution, then per-pixel classification
    x = layers.Resizing(input_shape[0], input_shape[1], interpolation='bilinear')(x)
    outputs = layers.Conv2D(num_classes, 1, activation='softmax', 
                          padding='same')(x)
    
    model = models.Model(inputs, outputs, name='DeepLab-v3+')
    return model

# PSPNet (Pyramid Scene Parsing Network)
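def pyramid_pooling_module(inputs, pool_sizes=(1, 2, 3, 6), filters=512):
    """Pyramid pooling module; each entry of pool_sizes is an output grid size (bins)"""
    
    # Assumes a static feature-map size, which holds for a fixed input_shape
    height, width = inputs.shape[1], inputs.shape[2]
    feature_maps = []
    
    for bins in pool_sizes:
        # Approximate adaptive average pooling: a (size // bins) window gives roughly
        # bins x bins cells (exact only when bins divides the feature-map size)
        pool_h, pool_w = max(height // bins, 1), max(width // bins, 1)
        pooled = layers.AveragePooling2D(pool_size=(pool_h, pool_w),
                                         strides=(pool_h, pool_w))(inputs)
        
        # 1x1 conv to reduce channels
        conv = layers.Conv2D(filters // len(pool_sizes), 1,
                             activation='relu', padding='same')(pooled)
        
        # Upsample back to the input feature-map size
        upsampled = layers.Resizing(height, width, interpolation='bilinear')(conv)
        feature_maps.append(upsampled)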
    
    # Add original features
    feature_maps.append(inputs)
    
    # Concatenate all features
    concatenated = layers.Concatenate()(feature_maps)
    
    # Final convolution
    output = layers.Conv2D(filters, 3, activation='relu', 
                         padding='same')(concatenated)
    
    return output

def build_pspnet(input_shape, num_classes, backbone_filters=512):
    """Build PSPNet model"""
    
    inputs = layers.Input(shape=input_shape)
    
    # Backbone (simplified ResNet-like)
    x = layers.Conv2D(64, 7, strides=2, padding='same', activation='relu')(inputs)
    x = layers.BatchNormalization()(x)
    x = layers.MaxPooling2D(3, strides=2, padding='same')(x)
    
    # Residual blocks (simplified)
    for filters in [64, 128, 256, backbone_filters]:
        residual = x
        
        x = layers.Conv2D(filters, 3, padding='same', activation='relu')(x)
        x = layers.BatchNormalization()(x)
        x = layers.Conv2D(filters, 3, padding='same', activation='relu')(x)
        x = layers.BatchNormalization()(x)
        
        # Adjust residual if needed
        if residual.shape[-1] != filters:
            residual = layers.Conv2D(filters, 1, padding='same')(residual)
        
        x = layers.Add()([x, residual])
        x = layers.Activation('relu')(x)
        
        # Downsample for some blocks
        if filters in [128, 256]:
            x = layers.MaxPooling2D(2)(x)
    
    # Pyramid pooling module
    pyramid_features = pyramid_pooling_module(x, filters=backbone_filters)
    
    # Classification head
    x = layers.Conv2D(256, 3, activation='relu', padding='same')(pyramid_features)
    x = layers.BatchNormalization()(x)
    x = layers.Dropout(0.1)(x)
    
    # Upsample to the input resolution and output
    x = layers.Resizing(input_shape[0], input_shape[1], interpolation='bilinear')(x)
    outputs = layers.Conv2D(num_classes, 1, activation='softmax', 
                          padding='same')(x)
    
    model = models.Model(inputs, outputs, name='PSPNet')
    return model

# Build advanced models
print("\n=== Building Advanced Segmentation Models ===")

# DeepLab v3+
deeplab_model = build_deeplab_v3plus(input_shape, num_classes, backbone='custom')
print(f"DeepLab v3+ model built: {deeplab_model.count_params():,} parameters")

# PSPNet
pspnet_model = build_pspnet(input_shape, num_classes)
print(f"PSPNet model built: {pspnet_model.count_params():,} parameters")

# Model comparison
models_dict = {
    'U-Net': unet_model,
    'Attention U-Net': attention_unet_model,
    'DeepLab v3+': deeplab_model,
    'PSPNet': pspnet_model
}

print("\nModel Comparison:")
for name, model in models_dict.items():
    print(f"{name}: {model.count_params():,} parameters")
```
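ASPP enlarges the receptive field by spacing the taps of a 3×3 kernel `dilation_rate` pixels apart, so one layer spans a window of k + (k − 1)(d − 1) pixels per side. Checking this against the feature-map size is worthwhile: with the custom encoder the ASPP input is only 128 / 16 = 8 pixels wide, so at rates 12 and 18 every off-center tap lands in the zero padding and those branches behave much like 1×1 convolutions. The hypothetical helpers below compute the window size and, along one axis, the fraction of taps that fall inside the feature map under `'same'` padding.

```python
def effective_kernel_size(kernel_size, dilation_rate):
    """Width of the window covered by one dilated convolution (hypothetical helper)."""
    return kernel_size + (kernel_size - 1) * (dilation_rate - 1)

def fraction_of_valid_taps(feature_size, kernel_size, dilation_rate):
    """Average fraction of 1-D kernel taps that land inside a feature map of width
    `feature_size` with 'same' zero padding (stride 1, odd kernel; hypothetical helper)."""
    half = kernel_size // 2
    offsets = [dilation_rate * t for t in range(-half, half + 1)]
    valid = 0
    for pos in range(feature_size):
        valid += sum(0 <= pos + off < feature_size for off in offsets)
    return valid / (feature_size * kernel_size)

aspp_input = 128 // 16  # spatial size reaching ASPP with the custom encoder
for rate in (1, 6, 12, 18):
    print(f"rate={rate:2d}  window={effective_kernel_size(3, rate):2d}  "
          f"valid taps per axis={fraction_of_valid_taps(aspp_input, 3, rate):.2f}")
```

For an 8-pixel input, rates 12 and 18 leave only the center tap valid (1/3 of the taps per axis), which suggests smaller rates, or less downsampling before ASPP, for inputs this small.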

## 4. Advanced Loss Functions and Metrics

```python
# Segmentation-specific loss functions
class SegmentationLosses:
    """Collection of segmentation loss functions"""
    
    @staticmethod
    def dice_coefficient(y_true, y_pred, smooth=1e-6):
        """Dice coefficient; for one-hot targets this pools all classes into one score"""
        y_true_flat = tf.reshape(tf.cast(y_true, tf.float32), [-1])
        y_pred_flat = tf.reshape(tf.cast(y_pred, tf.float32), [-1])
        
        intersection = tf.reduce_sum(y_true_flat * y_pred_flat)
        union = tf.reduce_sum(y_true_flat) + tf.reduce_sum(y_pred_flat)
        
        dice = (2.0 * intersection + smooth) / (union + smooth)
        return dice
    
    @staticmethod
    def dice_loss(y_true, y_pred, smooth=1e-6):
        """Dice loss function"""
        return 1 - SegmentationLosses.dice_coefficient(y_true, y_pred, smooth)
    
    @staticmethod
    def tversky_coefficient(y_true, y_pred, alpha=0.7, beta=0.3, smooth=1e-6):
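        """Tversky coefficient: alpha weights false negatives, beta false positives;
        alpha = beta = 0.5 reduces to the Dice coefficient"""
        y_true_flat = tf.reshape(tf.cast(y_true, tf.float32), [-1])
        y_pred_flat = tf.reshape(tf.cast(y_pred, tf.float32), [-1])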
        
        true_pos = tf.reduce_sum(y_true_flat * y_pred_flat)
        false_neg = tf.reduce_sum(y_true_flat * (1 - y_pred_flat))
        false_pos = tf.reduce_sum((1 - y_true_flat) * y_pred_flat)
        
        tversky = (true_pos + smooth) / (true_pos + alpha * false_neg + beta * false_pos + smooth)
        return tversky
    
    @staticmethod
    def tversky_loss(y_true, y_pred, alpha=0.7, beta=0.3):
        """Tversky loss - good for imbalanced datasets"""
        return 1 - SegmentationLosses.tversky_coefficient(y_true, y_pred, alpha, beta)
    
    @staticmethod
    def focal_loss(y_true, y_pred, alpha=0.8, gamma=2.0):
        """Focal loss for handling class imbalance"""
        epsilon = tf.keras.backend.epsilon()
        y_true = tf.cast(y_true, y_pred.dtype)
        y_pred = tf.clip_by_value(y_pred, epsilon, 1 - epsilon)
        
        alpha_factor = y_true * alpha + (1 - y_true) * (1 - alpha)
        focal_weight = y_true * (1 - y_pred) ** gamma + (1 - y_true) * y_pred ** gamma
        
        focal_loss = alpha_factor * focal_weight * (-y_true * tf.math.log(y_pred) - 
                                                   (1 - y_true) * tf.math.log(1 - y_pred))
        
        return tf.reduce_mean(focal_loss)
    
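    @staticmethod
    def combo_loss(y_true, y_pred, ce_ratio=0.5):
        """Weighted sum of cross-entropy and Dice loss (ce_ratio weights the CE term)"""
        y_true = tf.cast(y_true, y_pred.dtype)
        dice = SegmentationLosses.dice_loss(y_true, y_pred)
        # Per-channel binary CE works for sigmoid outputs and one-hot softmax targets;
        # reduce it to a scalar so it matches the scalar Dice term
        ce = tf.reduce_mean(tf.keras.losses.binary_crossentropy(y_true, y_pred))
        
        combo = (ce_ratio * ce) + ((1 - ce_ratio) * dice)
        return combo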
    
    @staticmethod
    def boundary_loss(y_true, y_pred):
        """Edge-matching loss: MSE between gradient magnitudes of target and prediction.
        Expects 4-D tensors (batch, height, width, channels)."""
        def compute_gradient(tensor):
            dy, dx = tf.image.image_gradients(tensor)
            # Small epsilon keeps sqrt differentiable where the image gradient is zero
            return tf.sqrt(dx**2 + dy**2 + 1e-8)
        
        true_grad = compute_gradient(tf.cast(y_true, y_pred.dtype))
        pred_grad = compute_gradient(y_pred)
        
        boundary_loss = tf.reduce_mean(tf.square(true_grad - pred_grad))
        return boundary_loss

# Segmentation metrics
class SegmentationMetrics:
    """Collection of segmentation evaluation metrics"""
    
    @staticmethod
    def iou_coefficient(y_true, y_pred, smooth=1e-6):
        """Intersection over Union (IoU) / Jaccard Index"""
        y_true_flat = tf.reshape(y_true, [-1])
        y_pred_flat = tf.reshape(y_pred, [-1])
        
        intersection = tf.reduce_sum(y_true_flat * y_pred_flat)
        union = tf.reduce_sum(y_true_flat) + tf.reduce_sum(y_pred_flat) - intersection
        
        iou = (intersection + smooth) / (union + smooth)
        return iou
    
    @staticmethod
    def _label_maps(y_true, y_pred):
        """Integer label maps from sparse or one-hot targets and per-class predictions"""
        if len(y_true.shape) == len(y_pred.shape):  # one-hot targets
            y_true = tf.argmax(y_true, axis=-1)
        return tf.cast(y_true, tf.int64), tf.argmax(y_pred, axis=-1)
    
    @staticmethod
    def mean_iou(y_true, y_pred, num_classes):
        """Mean IoU across all classes (a class absent from both maps scores 1 via `smooth`)"""
        true_labels, pred_labels = SegmentationMetrics._label_maps(y_true, y_pred)
        ious = []
        
        for class_idx in range(num_classes):
            true_class = tf.cast(tf.equal(true_labels, class_idx), tf.float32)
            pred_class = tf.cast(tf.equal(pred_labels, class_idx), tf.float32)
            
            iou = SegmentationMetrics.iou_coefficient(true_class, pred_class)
            ious.append(iou)
        
        return tf.reduce_mean(ious)
    
    @staticmethod
    def pixel_accuracy(y_true, y_pred):
        """Pixel-wise accuracy"""
        true_labels, pred_labels = SegmentationMetrics._label_maps(y_true, y_pred)
        correct_pixels = tf.cast(tf.equal(true_labels, pred_labels), tf.float32)
        return tf.reduce_mean(correct_pixels)
    
    @staticmethod
    def frequency_weighted_iou(y_true, y_pred, num_classes):
        """Frequency weighted IoU"""
        true_labels, pred_labels = SegmentationMetrics._label_maps(y_true, y_pred)
        total_pixels = tf.cast(tf.size(true_labels), tf.float32)
        ious = []
        frequencies = []
        
        for class_idx in range(num_classes):
            true_class = tf.cast(tf.equal(true_labels, class_idx), tf.float32)
            pred_class = tf.cast(tf.equal(pred_labels, class_idx), tf.float32)
            
            iou = SegmentationMetrics.iou_coefficient(true_class, pred_class)
            frequency = tf.reduce_sum(true_class) / total_pixels
            
            ious.append(iou)
            frequencies.append(frequency)
        
        ious = tf.stack(ious)
        frequencies = tf.stack(frequencies)
        
        weighted_iou = tf.reduce_sum(frequencies * ious)
        return weighted_iou

# Custom metric classes for Keras
class MeanIoU(tf.keras.metrics.Metric):
    """Mean IoU metric for Keras"""
    
    def __init__(self, num_classes, name='mean_iou', **kwargs):
        super().__init__(name=name, **kwargs)
        self.num_classes = num_classes
        self.total_iou = self.add_weight(name='total_iou', initializer='zeros')
        self.count = self.add_weight(name='count', initializer='zeros')
    
    def update_state(self, y_true, y_pred, sample_weight=None):
        mean_iou = SegmentationMetrics.mean_iou(y_true, y_pred, self.num_classes)
        self.total_iou.assign_add(mean_iou)
        self.count.assign_add(1)
    
    def result(self):
        # Average of per-batch scores; divide_no_nan returns 0 before any update
        return tf.math.divide_no_nan(self.total_iou, self.count)
    
    def reset_state(self):
        self.total_iou.assign(0)
        self.count.assign(0)

class DiceCoefficient(tf.keras.metrics.Metric):
    """Dice coefficient metric for Keras"""
    
    def __init__(self, name='dice_coefficient', **kwargs):
        super().__init__(name=name, **kwargs)
        self.total_dice = self.add_weight(name='total_dice', initializer='zeros')
        self.count = self.add_weight(name='count', initializer='zeros')
    
    def update_state(self, y_true, y_pred, sample_weight=None):
        dice = SegmentationLosses.dice_coefficient(y_true, y_pred)
        self.total_dice.assign_add(dice)
        self.count.assign_add(1)
    
    def result(self):
        # Average of per-batch scores; divide_no_nan returns 0 before any update
        return tf.math.divide_no_nan(self.total_dice, self.count)
    
    def reset_state(self):
        self.total_dice.assign(0)
        self.count.assign(0)

# Test loss functions and metrics
print("\n=== Testing Loss Functions and Metrics ===")

# Create sample predictions for testing
sample_true = tf.random.uniform([8, 64, 64], maxval=4, dtype=tf.int32)
sample_pred = tf.random.uniform([8, 64, 64, 4])

# Test metrics
mean_iou_metric = MeanIoU(num_classes=4)
mean_iou_metric.update_state(sample_true, sample_pred)
print(f"Mean IoU: {mean_iou_metric.result():.4f}")

dice_metric = DiceCoefficient()
sample_true_binary = tf.cast(sample_true > 0, tf.float32)
sample_pred_binary = sample_pred[:, :, :, 1]  # Take one class
dice_metric.update_state(sample_true_binary, sample_pred_binary)
print(f"Dice Coefficient: {dice_metric.result():.4f}")

# Test loss functions
dice_loss_val = SegmentationLosses.dice_loss(sample_true_binary, sample_pred_binary)
focal_loss_val = SegmentationLosses.focal_loss(sample_true_binary, sample_pred_binary)
print(f"Dice Loss: {dice_loss_val:.4f}")
print(f"Focal Loss: {focal_loss_val:.4f}")
```
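The custom `MeanIoU` metric above averages per-batch scores, and its `smooth` term scores a class that is absent from both target and prediction as IoU = 1. A common alternative accumulates one confusion matrix over the whole evaluation set and computes IoU per class from it; the built-in `tf.keras.metrics.MeanIoU` also works from a confusion matrix, though by default it expects integer class ids rather than one-hot targets and softmax outputs. The NumPy sketch below is a reference for cross-checking results; the function names are hypothetical, and skipping classes that never appear (rather than scoring them 1 or 0) is a choice made here, not something the notebook specifies.

```python
import numpy as np

def confusion_matrix(y_true, y_pred, num_classes):
    """(num_classes, num_classes) counts; rows = ground truth, columns = prediction.

    y_true, y_pred: integer label maps of matching shape (hypothetical helper).
    """
    idx = num_classes * y_true.ravel().astype(np.int64) + y_pred.ravel().astype(np.int64)
    return np.bincount(idx, minlength=num_classes ** 2).reshape(num_classes, num_classes)

def per_class_iou(cm):
    """IoU_c = TP / (TP + FP + FN); NaN for classes absent from both maps."""
    tp = np.diag(cm).astype(np.float64)
    fp = cm.sum(axis=0) - tp  # predicted as c, labelled otherwise
    fn = cm.sum(axis=1) - tp  # labelled c, predicted otherwise
    denom = tp + fp + fn
    with np.errstate(divide='ignore', invalid='ignore'):
        return np.where(denom > 0, tp / denom, np.nan)

def mean_iou_from_cm(cm):
    """Mean IoU over the classes that actually occur."""
    return float(np.nanmean(per_class_iou(cm)))

# Toy example: class 3 never appears, so it is excluded from the mean
y_true = np.array([[0, 0, 1],
                   [1, 2, 2]])
y_pred = np.array([[0, 1, 1],
                   [1, 2, 0]])
cm = confusion_matrix(y_true, y_pred, num_classes=4)
print(per_class_iou(cm))
print(mean_iou_from_cm(cm))
```

For model outputs, apply `np.argmax(..., axis=-1)` to the softmax predictions (and to one-hot targets) first, then sum the per-batch confusion matrices before computing IoU.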

## 5. Model Training and Evaluation

```python
# Prepare data for multi-class segmentation
def prepare_multiclass_data(X, y, num_classes):
    """Convert masks to one-hot encoding"""
    y_categorical = utils.to_categorical(y, num_classes=num_classes)
    return X, y_categorical

# Prepare training data
X_train_prep, y_train_prep = prepare_multiclass_data(X_train, y_train, num_classes)
X_val_prep, y_val_prep = prepare_multiclass_data(X_val, y_val, num_classes)
X_test_prep, y_test_prep = prepare_multiclass_data(X_test, y_test, num_classes)

print(f"Prepared data shapes:")
print(f"X_train: {X_train_prep.shape}, y_train: {y_train_prep.shape}")
print(f"X_val: {X_val_prep.shape}, y_val: {y_val_prep.shape}")

# Training configuration
def compile_segmentation_model(model, loss_type='categorical_crossentropy', learning_rate=0.001):
    """Compile segmentation model with appropriate loss and metrics"""
    
    # Select loss function
    if loss_type == 'dice':
        loss_fn = SegmentationLosses.dice_loss
    elif loss_type == 'focal':
        loss_fn = SegmentationLosses.focal_loss
    elif loss_type == 'combo':
        loss_fn = SegmentationLosses.combo_loss
    else:
        loss_fn = loss_type
    
    # Compile model
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        loss=loss_fn,
        metrics=[
            'accuracy',
            MeanIoU(num_classes=num_classes),
            tf.keras.metrics.CategoricalAccuracy()
        ]
    )
    
    return model

# Training callbacks
def get_training_callbacks(model_name, patience=7):
    """Get training callbacks for segmentation models"""
    
    callbacks = [
        tf.keras.callbacks.EarlyStopping(
            monitor='val_mean_iou',
            patience=patience,
            restore_best_weights=True,
            mode='max',
            verbose=1
        ),
        tf.keras.callbacks.ReduceLROnPlateau(
            monitor='val_mean_iou',
            factor=0.5,
            patience=3,
            min_lr=1e-7,
            mode='max',
            verbose=1
        ),
        tf.keras.callbacks.ModelCheckpoint(
            f'{model_name}_best.h5',
            monitor='val_mean_iou',
            save_best_only=True,
            mode='max',
            verbose=1
        ),
        tf.keras.callbacks.CSVLogger(f'{model_name}_training.log')
    ]
    
    return callbacks

# Train models
def train_segmentation_model(model, model_name, X_train, y_train, X_val, y_val, 
                           epochs=30, batch_size=16):
    """Train segmentation model"""
    
    print(f"\n=== Training {model_name} ===")
    
    # Compile model
    model = compile_segmentation_model(model, loss_type='categorical_crossentropy')
    
    # Get callbacks
    callbacks = get_training_callbacks(model_name)
    
    # Train model
    history = model.fit(
        X_train, y_train,
        validation_data=(X_val, y_val),
        epochs=epochs,
        batch_size=batch_size,
        callbacks=callbacks,
        verbose=1
    )
    
    return model, history

# Model comparison training
training_results = {}

# Train U-Net (fastest to train for demonstration)
unet_trained, unet_history = train_segmentation_model(
    unet_model, 'unet', X_train_prep, y_train_prep, X_val_prep, y_val_prep,
    epochs=5, batch_size=8  # Reduced for demo
)
training_results['U-Net'] = unet_history

print("U-Net training completed!")

# Evaluate models
def evaluate_segmentation_model(model, X_test, y_test, model_name):
    """Comprehensive model evaluation"""
    
    print(f"\n=== Evaluating {model_name} ===")
    
    # Make predictions
    predictions = model.predict(X_test, verbose=0)
    
    # Calculate metrics
    test_loss, test_acc, test_miou, test_cat_acc = model.evaluate(
        X_test, y_test, verbose=0
    )
    
    print(f"Test Loss: {test_loss:.4f}")
    print(f"Test Accuracy: {test_acc:.4f}")
    print(f"Test Mean IoU: {test_miou:.4f}")
    print(f"Test Categorical Accuracy: {test_cat_acc:.4f}")
    
    # Visualize predictions
    visualize_segmentation(X_test[:4], np.argmax(y_test[:4], axis=-1), 
                         predictions[:4], num_samples=4)
    
    return {
        'loss': test_loss,
        'accuracy': test_acc,
        'mean_iou': test_miou,
        'categorical_accuracy': test_cat_acc,
        'predictions': predictions
    }

# Evaluate U-Net
unet_results = evaluate_segmentation_model(
    unet_trained, X_test_prep, y_test_prep, 'U-Net'
)

# Training history visualization
def plot_training_history(history, model_name):
    """Plot training history"""
    
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    
    # Loss
    axes[0, 0].plot(history.history['loss'], label='Train Loss')
    axes[0, 0].plot(history.history['val_loss'], label='Val Loss')
    axes[0, 0].set_title(f'{model_name} - Loss')
    axes[0, 0].set_xlabel('Epoch')
    axes[0, 0].set_ylabel('Loss')
    axes[0, 0].legend()
    axes[0, 0].grid(True)
    
    # Accuracy
    axes[0, 1].plot(history.history['accuracy'], label='Train Accuracy')
    axes[0, 1].plot(history.history['val_accuracy'], label='Val Accuracy')
    axes[0, 1].set_title(f'{model_name} - Accuracy')
    axes[0, 1].set_xlabel('Epoch')
    axes[0, 1].set_ylabel('Accuracy')
    axes[0, 1].legend()
    axes[0, 1].grid(True)
    
    # Mean IoU
    axes[1, 0].plot(history.history['mean_iou'], label='Train Mean IoU')
    axes[1, 0].plot(history.history['val_mean_iou'], label='Val Mean IoU')
    axes[1, 0].set_title(f'{model_name} - Mean IoU')
    axes[1, 0].set_xlabel('Epoch')
    axes[1, 0].set_ylabel('Mean IoU')
    axes[1, 0].legend()
    axes[1, 0].grid(True)
    
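    # Learning rate (logged by callbacks such as ReduceLROnPlateau; the history key is
    # 'lr' in older Keras versions and 'learning_rate' in newer ones)
    lr_key = 'learning_rate' if 'learning_rate' in history.history else 'lr'
    if lr_key in history.history:
        axes[1, 1].plot(history.history[lr_key], label='Learning Rate')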
        axes[1, 1].set_title(f'{model_name} - Learning Rate')
        axes[1, 1].set_xlabel('Epoch')
        axes[1, 1].set_ylabel('Learning Rate')
        axes[1, 1].set_yscale('log')
        axes[1, 1].legend()
        axes[1, 1].grid(True)
    else:
        axes[1, 1].axis('off')
    
    plt.tight_layout()
    plt.show()

# Plot training history
plot_training_history(unet_history, 'U-Net')
```
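The callbacks above monitor mean IoU. The notebook's `MeanIoU` class is defined in an earlier section that is not shown here. As a reference for what that number means, here is a minimal NumPy sketch. It assumes integer label maps, so one-hot targets and softmax outputs are first reduced with `np.argmax(..., axis=-1)`. The helper name `mean_iou_from_labels` is hypothetical and is not the notebook's metric.

```python
import numpy as np

def mean_iou_from_labels(y_true, y_pred, num_classes):
    """Mean intersection-over-union across classes for integer label maps.

    Classes that appear in neither y_true nor y_pred are excluded from the mean.
    """
    y_true = np.asarray(y_true, dtype=np.int64).ravel()
    y_pred = np.asarray(y_pred, dtype=np.int64).ravel()
    # confusion[i, j] = number of pixels of true class i predicted as class j
    confusion = np.bincount(
        y_true * num_classes + y_pred, minlength=num_classes ** 2
    ).reshape(num_classes, num_classes)
    intersection = np.diag(confusion)                                     # TP per class
    union = confusion.sum(axis=0) + confusion.sum(axis=1) - intersection  # TP + FP + FN
    present = union > 0
    return float(np.mean(intersection[present] / union[present]))

# One-hot targets / softmax outputs would first be reduced with np.argmax(..., axis=-1)
y_true = np.array([[0, 0, 1], [1, 2, 2]])
y_pred = np.array([[0, 1, 1], [1, 2, 0]])
print(mean_iou_from_labels(y_true, y_pred, num_classes=3))  # per-class IoU 1/3, 2/3, 1/2 -> 0.5
```

Classes absent from both maps are left out of the average, so a rare class missing from a batch does not pull the score toward zero. Check that the metric used during training follows the same convention before comparing numbers.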

## 6. Real-World Application: Medical Image Segmentation

```python
# Medical image segmentation example
def create_medical_dataset(num_samples=200, image_size=256):
    """Create synthetic medical imaging dataset"""
    
    print("Creating synthetic medical dataset (simulating organ segmentation)...")
    
    images = []
    masks = []
    
    for i in range(num_samples):
        # Create medical-like image with organ structures
        img = np.random.normal(0.3, 0.1, (image_size, image_size, 1))
        img = np.clip(img, 0, 1)
        
        # Create organ-like structures
        mask = np.zeros((image_size, image_size), dtype=np.uint8)
        
        # Add organ structures (simplified)
        # Large organ (like liver)
        center1 = (image_size//3, image_size//3)
        axes1 = (image_size//6, image_size//8)
        
        # Create elliptical structure
        y, x = np.ogrid[:image_size, :image_size]
        ellipse1 = ((x - center1[0])/axes1[0])**2 + ((y - center1[1])/axes1[1])**2 <= 1
        mask[ellipse1] = 1  # Organ class
        img[ellipse1] = img[ellipse1] + 0.2
        
        # Smaller organ (like kidney)
        center2 = (2*image_size//3, 2*image_size//3)
        axes2 = (image_size//12, image_size//10)
        
        ellipse2 = ((x - center2[0])/axes2[0])**2 + ((y - center2[1])/axes2[1])**2 <= 1
        mask[ellipse2] = 2  # Different organ class
        img[ellipse2] = img[ellipse2] + 0.3
        
        # Add noise and contrast variations
        img = img + np.random.normal(0, 0.05, img.shape)
        img = np.clip(img, 0, 1)
        
        # Convert to 3-channel for consistency
        img = np.repeat(img, 3, axis=-1)
        
        images.append(img)
        masks.append(mask)
    
    images = np.array(images, dtype=np.float32)
    masks = np.array(masks, dtype=np.uint8)
    
    print(f"Medical dataset created: {images.shape} images, {masks.shape} masks")
    print(f"Classes: {np.unique(masks)} (0=background, 1=organ1, 2=organ2)")
    
    return images, masks

# Create medical dataset
med_images, med_masks = create_medical_dataset(num_samples=150, image_size=128)

# Split medical data
med_X_train, med_X_test, med_y_train, med_y_test = train_test_split(
    med_images, med_masks, test_size=0.2, random_state=42
)

# Visualize medical data
print("Medical segmentation dataset:")
visualize_segmentation(med_X_train, med_y_train, num_samples=3)

# Build specialized medical segmentation model
def build_medical_unet(input_shape, num_classes, filters=64):
    """Build U-Net optimized for medical imaging"""
    
    inputs = layers.Input(shape=input_shape)
    
    # Encoder: four downsampling levels, with dropout increasing with depth
    e1, p1 = encoder_block(inputs, filters, pool=True, dropout_rate=0.1)
    e2, p2 = encoder_block(p1, filters * 2, pool=True, dropout_rate=0.1)
    e3, p3 = encoder_block(p2, filters * 4, pool=True, dropout_rate=0.2)
    e4, p4 = encoder_block(p3, filters * 8, pool=True, dropout_rate=0.2)
    
    # Bottleneck with higher capacity
    bottleneck = encoder_block(p4, filters * 16, pool=False, dropout_rate=0.3)
    
    # Decoder: each level merges the skip features from the matching encoder level
    d4 = decoder_block(bottleneck, e4, filters * 8, dropout_rate=0.2)
    d3 = decoder_block(d4, e3, filters * 4, dropout_rate=0.2)
    d2 = decoder_block(d3, e2, filters * 2, dropout_rate=0.1)
    d1 = decoder_block(d2, e1, filters, dropout_rate=0.1)
    
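    # Multi-scale output supervision (deep supervision)
    # Main output at full resolution
    main_output = layers.Conv2D(num_classes, 1, activation='softmax', 
                              padding='same', name='main_output')(d1)
    
    # Auxiliary outputs from coarser decoder stages, resized to the input resolution.
    # A named Resizing layer keeps 'aux_output1'/'aux_output2' as the model's output
    # names, which the loss and metric dictionaries in compile() are keyed on;
    # tf.image.resize on a symbolic tensor would give the outputs generated names.
    aux_output1 = layers.Conv2D(num_classes, 1, activation='softmax', 
                               padding='same', name='aux_conv1')(d2)
    aux_output1 = layers.Resizing(input_shape[0], input_shape[1], interpolation='bilinear',
                                  name='aux_output1')(aux_output1)
    
    aux_output2 = layers.Conv2D(num_classes, 1, activation='softmax', 
                               padding='same', name='aux_conv2')(d3)
    aux_output2 = layers.Resizing(input_shape[0], input_shape[1], interpolation='bilinear',
                                  name='aux_output2')(aux_output2)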
    
    model = models.Model(inputs, [main_output, aux_output1, aux_output2], 
                        name='Medical-U-Net')
    return model

# Build medical model
med_num_classes = len(np.unique(med_y_train))
med_input_shape = med_X_train.shape[1:]

medical_unet = build_medical_unet(med_input_shape, med_num_classes, filters=32)
print(f"Medical U-Net built: {medical_unet.count_params():,} parameters")

# Prepare medical data
med_X_train_prep, med_y_train_prep = prepare_multiclass_data(med_X_train, med_y_train, med_num_classes)
med_X_test_prep, med_y_test_prep = prepare_multiclass_data(med_X_test, med_y_test, med_num_classes)

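# Reference form of the deep-supervision objective: a weighted sum of the
# per-output cross-entropies. compile() below expresses the same weighting through
# `loss_weights`, so this helper is shown for clarity and not passed to Keras.
def deep_supervision_loss(y_true, y_pred_list, weights=(1.0, 0.5, 0.25)):
    """Deep supervision loss combining multiple outputs"""
    
    total_loss = 0.0
    for y_pred, weight in zip(y_pred_list, weights):
        loss = tf.keras.losses.categorical_crossentropy(y_true, y_pred)
        total_loss += weight * loss
    
    return total_loss

# Compile with one cross-entropy loss per output, weighted 1.0 / 0.5 / 0.25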
medical_unet.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    loss={
        'main_output': 'categorical_crossentropy',
        'aux_output1': 'categorical_crossentropy', 
        'aux_output2': 'categorical_crossentropy'
    },
    loss_weights={
        'main_output': 1.0,
        'aux_output1': 0.5,
        'aux_output2': 0.25
    },
    metrics={
        'main_output': [MeanIoU(num_classes=med_num_classes)],
        'aux_output1': [MeanIoU(num_classes=med_num_classes)],
        'aux_output2': [MeanIoU(num_classes=med_num_classes)]
    }
)

# Train medical model (quick demo)
print("\n=== Training Medical Segmentation Model ===")

med_callbacks = [
    tf.keras.callbacks.EarlyStopping(
        monitor='val_main_output_mean_iou',
        patience=5,
        restore_best_weights=True,
        mode='max'
    ),
    tf.keras.callbacks.ReduceLROnPlateau(
        monitor='val_main_output_mean_iou',
        factor=0.5,
        patience=2,
        min_lr=1e-7,
        mode='max'
    )
]

# Prepare target data for multiple outputs
med_y_train_dict = {
    'main_output': med_y_train_prep,
    'aux_output1': med_y_train_prep,
    'aux_output2': med_y_train_prep
}

med_y_test_dict = {
    'main_output': med_y_test_prep,
    'aux_output1': med_y_test_prep,
    'aux_output2': med_y_test_prep
}

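# The test split doubles as validation data to keep the demo short, so early
# stopping sees it and the test metrics below are optimistic
med_history = medical_unet.fit(
    med_X_train_prep, med_y_train_dict,
    validation_data=(med_X_test_prep, med_y_test_dict),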
    epochs=3,  # Reduced for demo
    batch_size=4,
    callbacks=med_callbacks,
    verbose=1
)

# Evaluate medical model
med_predictions = medical_unet.predict(med_X_test_prep, verbose=0)
main_predictions = med_predictions[0]  # Use main output

# Visualize medical segmentation results
print("\nMedical Segmentation Results:")
visualize_segmentation(med_X_test[:4], med_y_test[:4], 
                      main_predictions[:4], num_samples=4)

# Calculate medical-specific metrics
def calculate_medical_metrics(y_true, y_pred, class_names=['Background', 'Organ1', 'Organ2']):
    """Calculate medical imaging specific metrics"""
    
    metrics = {}
    
    for class_idx, class_name in enumerate(class_names):
        # Binary masks for current class
        true_binary = (y_true == class_idx).astype(np.float32)
        pred_binary = (np.argmax(y_pred, axis=-1) == class_idx).astype(np.float32)
        
        # Dice coefficient
        dice = SegmentationLosses.dice_coefficient(true_binary, pred_binary).numpy()
        
        # Sensitivity (Recall)
        true_pos = np.sum(true_binary * pred_binary)
        false_neg = np.sum(true_binary * (1 - pred_binary))
        sensitivity = true_pos / (true_pos + false_neg + 1e-6)
        
        # Specificity
        true_neg = np.sum((1 - true_binary) * (1 - pred_binary))
        false_pos = np.sum((1 - true_binary) * pred_binary)
        specificity = true_neg / (true_neg + false_pos + 1e-6)
        
        metrics[class_name] = {
            'dice': dice,
            'sensitivity': sensitivity,
            'specificity': specificity
        }
    
    return metrics

# Calculate medical metrics
medical_metrics = calculate_medical_metrics(med_y_test, main_predictions)

print("\nMedical Segmentation Metrics:")
for class_name, metrics in medical_metrics.items():
    print(f"{class_name}:")
    print(f"  Dice: {metrics['dice']:.4f}")
    print(f"  Sensitivity: {metrics['sensitivity']:.4f}")
    print(f"  Specificity: {metrics['specificity']:.4f}")
```
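`calculate_medical_metrics` relies on `SegmentationLosses.dice_coefficient`, which comes from an earlier section. The sketch below is a self-contained NumPy alternative that derives all three scores from per-class confusion counts. It uses Dice = 2TP / (2TP + FP + FN), which equals 2|A∩B| / (|A| + |B|). It assumes integer label maps, so softmax outputs such as `main_predictions` would first go through `np.argmax(..., axis=-1)`. The function name is hypothetical.

```python
import numpy as np

def per_class_overlap_metrics(y_true, y_pred, num_classes, eps=1e-6):
    """Dice, sensitivity and specificity per class, from integer label maps."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    results = {}
    for c in range(num_classes):
        t, p = (y_true == c), (y_pred == c)
        tp = np.sum(t & p)      # class c predicted as c
        fp = np.sum(~t & p)     # other classes predicted as c
        fn = np.sum(t & ~p)     # class c missed
        tn = np.sum(~t & ~p)    # correctly left out of class c
        results[c] = {
            'dice': (2 * tp + eps) / (2 * tp + fp + fn + eps),
            'sensitivity': tp / (tp + fn + eps),
            'specificity': tn / (tn + fp + eps),
        }
    return results

# Usage: for softmax outputs, pass np.argmax(main_predictions, axis=-1) as y_pred
y_true = np.array([[0, 1], [1, 1]])
y_pred = np.array([[0, 1], [0, 1]])
for cls, m in per_class_overlap_metrics(y_true, y_pred, num_classes=2).items():
    print(cls, {name: round(float(v), 3) for name, v in m.items()})
```

With the `eps` smoothing, a class absent from both maps scores a Dice of 1. Some reporting conventions exclude such classes instead.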

## 7. Model Deployment and Inference

```python
# Model deployment utilities
class SegmentationInference:
    """Segmentation model inference pipeline"""
    
    def __init__(self, model_path, input_shape, num_classes):
        self.model = tf.keras.models.load_model(model_path, compile=False)
        self.input_shape = input_shape
        self.num_classes = num_classes
        
    def preprocess_image(self, image_path):
        """Preprocess single image for inference"""
        
        if isinstance(image_path, str):
            image = cv2.imread(image_path)
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        else:
            image = image_path
        
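        # Resize to model input size (cv2.resize takes (width, height))
        height, width = self.input_shape[:2]
        image = cv2.resize(image, (width, height))
        # Scale 8-bit images to [0, 1]; float inputs such as X_test are assumed already scaled
        if image.dtype == np.uint8:
            image = image.astype(np.float32) / 255.0
        else:
            image = image.astype(np.float32)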
        
        # Add batch dimension
        image = np.expand_dims(image, axis=0)
        
        return image
    
    def predict(self, image):
        """Make prediction on single image"""
        
        preprocessed = self.preprocess_image(image)
        prediction = self.model.predict(preprocessed, verbose=0)
        
        # Handle multiple outputs (deep supervision)
        if isinstance(prediction, list):
            prediction = prediction[0]  # Use main output
        
        return prediction[0]  # Remove batch dimension
    
    def predict_batch(self, images):
        """Make predictions on batch of images"""
        
        batch = np.stack([self.preprocess_image(img)[0] for img in images])
        predictions = self.model.predict(batch, verbose=0)
        
        if isinstance(predictions, list):
            predictions = predictions[0]
        
        return predictions
    
    def postprocess_prediction(self, prediction, threshold=0.5):
        """Postprocess prediction to get final mask"""
        
        if self.num_classes > 1:
            # Multi-class: use argmax
            mask = np.argmax(prediction, axis=-1)
        else:
            # Binary: use threshold
            mask = (prediction[:, :, 0] > threshold).astype(np.uint8)
        
        return mask.astype(np.uint8)
    
    def visualize_prediction(self, image, prediction, save_path=None):
        """Visualize prediction result"""
        
        mask = self.postprocess_prediction(prediction)
        
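        # Load the image if a file path was given
        if isinstance(image, str):
            image = cv2.imread(image)
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        
        # Resize (cv2.resize takes (width, height)) and bring pixels to float [0, 1]
        height, width = self.input_shape[:2]
        image = cv2.resize(image, (width, height))
        if image.dtype == np.uint8:
            image = image.astype(np.float32) / 255.0
        else:
            image = image.astype(np.float32)
        
        # Apply colormap to mask (max(..., 1) avoids dividing by zero on an all-background mask)
        colored_mask = plt.cm.viridis(mask / max(mask.max(), 1))[:, :, :3]
        # Blend image and colormap on a common [0, 1] scale
        overlay = np.clip(0.7 * image + 0.3 * colored_mask, 0.0, 1.0)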
        
        # Plot results
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        
        axes[0].imshow(image)
        axes[0].set_title('Original Image')
        axes[0].axis('off')
        
        axes[1].imshow(mask, cmap='viridis')
        axes[1].set_title('Segmentation Mask')
        axes[1].axis('off')
        
        axes[2].imshow(overlay)
        axes[2].set_title('Overlay')
        axes[2].axis('off')
        
        if save_path:
            plt.savefig(save_path, bbox_inches='tight', dpi=150)
        
        plt.show()

# Save trained model for deployment
unet_trained.save('segmentation_unet_model.h5')
print("Model saved for deployment!")

# Create inference pipeline
inference_pipeline = SegmentationInference(
    model_path='segmentation_unet_model.h5',
    input_shape=input_shape,
    num_classes=num_classes
)

# Test inference on new data
test_image = X_test[0]
prediction = inference_pipeline.predict(test_image)

print("Inference test completed!")
inference_pipeline.visualize_prediction(test_image, prediction)

# Batch inference example
batch_predictions = inference_pipeline.predict_batch(X_test[:4])
print(f"Batch inference completed: {batch_predictions.shape}")

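# Performance optimization for deployment: graph-compiled forward pass
@tf.function
def optimized_inference(model, inputs):
    """Run the model inside a tf.function; training=False keeps dropout and batch norm in inference mode"""
    return model(inputs, training=False)

# The first call traces a graph; later calls with the same input shape reuse it
optimized_output = optimized_inference(unet_trained, tf.constant(X_test[:1], dtype=tf.float32))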

# Convert to TensorFlow Lite for mobile deployment
def convert_to_tflite(model, save_path):
    """Convert model to TensorFlow Lite format"""
    
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    
    # Optimization settings
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.target_spec.supported_types = [tf.float16]
    
    # Convert model
    tflite_model = converter.convert()
    
    # Save model
    with open(save_path, 'wb') as f:
        f.write(tflite_model)
    
    print(f"TFLite model saved to: {save_path}")
    
    return tflite_model

# Convert U-Net to TFLite
tflite_model = convert_to_tflite(unet_trained, 'segmentation_model.tflite')

# TensorFlow Lite inference example
def tflite_inference(tflite_model_path, input_image):
    """Run inference with TensorFlow Lite model"""
    
    # Load TFLite model
    interpreter = tf.lite.Interpreter(model_path=tflite_model_path)
    interpreter.allocate_tensors()
    
    # Get input and output tensors
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    
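    # Preprocess input: the input tensor shape is [1, height, width, channels],
    # and cv2.resize takes (width, height)
    input_shape = input_details[0]['shape']
    input_image = cv2.resize(input_image, (int(input_shape[2]), int(input_shape[1])))
    input_image = np.expand_dims(input_image, axis=0).astype(np.float32)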
    
    # Set input tensor
    interpreter.set_tensor(input_details[0]['index'], input_image)
    
    # Run inference
    interpreter.invoke()
    
    # Get output
    output_data = interpreter.get_tensor(output_details[0]['index'])
    
    return output_data[0]

# Test TFLite inference
test_tflite_pred = tflite_inference('segmentation_model.tflite', X_test[0])
print(f"TFLite inference shape: {test_tflite_pred.shape}")

# Production deployment utilities
class SegmentationAPI:
    """REST API wrapper for segmentation model"""
    
    def __init__(self, model_path, input_shape, num_classes):
        self.inference = SegmentationInference(model_path, input_shape, num_classes)
        
    def create_flask_app(self):
        """Create Flask API for segmentation model"""
        
        from flask import Flask, request, jsonify
        import base64
        from io import BytesIO
        from PIL import Image
        
        app = Flask(__name__)
        
        @app.route('/health', methods=['GET'])
        def health_check():
            return jsonify({'status': 'healthy'})
        
        @app.route('/segment', methods=['POST'])
        def segment_image():
            try:
                # Get image from request
                data = request.json
                image_data = base64.b64decode(data['image'])
                image = Image.open(BytesIO(image_data)).convert('RGB')  # drop alpha, expand grayscale
                image = np.array(image)
                
                # Make prediction
                prediction = self.inference.predict(image)
                mask = self.inference.postprocess_prediction(prediction)
                
                # Convert mask to a base64 PNG (cast to float first: uint8 * 255 would overflow)
                mask_pil = Image.fromarray(
                    (mask.astype(np.float32) * 255.0 / max(int(mask.max()), 1)).astype(np.uint8)
                )
                buffer = BytesIO()
                mask_pil.save(buffer, format='PNG')
                mask_b64 = base64.b64encode(buffer.getvalue()).decode()
                
                return jsonify({
                    'status': 'success',
                    'mask': mask_b64,
                    'shape': mask.shape,
                    'classes_found': np.unique(mask).tolist()
                })
                
            except Exception as e:
                # Return an HTTP error status rather than 200 OK on failure
                return jsonify({'status': 'error', 'message': str(e)}), 500
        
        return app

# Create API instance
api = SegmentationAPI('segmentation_unet_model.h5', input_shape, num_classes)
flask_app = api.create_flask_app()

print("Flask API created for segmentation model!")

# Performance benchmarking
def benchmark_model(model, test_data, num_runs=50):
    """Benchmark model inference speed"""
    
    import time
    
    # Warm up
    for _ in range(5):
        _ = model.predict(test_data[:1], verbose=0)
    
    # Benchmark (perf_counter is a monotonic, high-resolution clock suited to timing)
    times = []
    for _ in range(num_runs):
        start_time = time.perf_counter()
        _ = model.predict(test_data[:1], verbose=0)
        end_time = time.perf_counter()
        times.append(end_time - start_time)
    
    avg_time = np.mean(times)
    std_time = np.std(times)
    
    print(f"Average inference time: {avg_time*1000:.2f} ± {std_time*1000:.2f} ms")
    print(f"Throughput: {1/avg_time:.2f} images/second")
    
    return avg_time, std_time

# Benchmark U-Net
print("\n=== Model Benchmarking ===")
avg_time, std_time = benchmark_model(unet_trained, X_test[:10])

# Model comparison summary
def create_model_comparison_report():
    """Create comprehensive model comparison report"""
    
    comparison_data = {
        'Model': ['U-Net', 'Attention U-Net', 'DeepLab v3+', 'PSPNet'],
        'Parameters': [
            unet_model.count_params(),
            attention_unet_model.count_params(), 
            deeplab_model.count_params(),
            pspnet_model.count_params()
        ],
        'Architecture': ['Encoder-Decoder', 'Attention + Skip', 'ASPP + Decoder', 'Pyramid Pooling'],
        'Best_For': [
            'General segmentation',
            'Fine detail preservation', 
            'Multi-scale features',
            'Scene parsing'
        ],
        'Complexity': ['Low', 'Medium', 'High', 'Medium']
    }
    
    import pandas as pd
    df = pd.DataFrame(comparison_data)
    
    print("\n=== Model Architecture Comparison ===")
    print(df.to_string(index=False))
    
    return df

comparison_report = create_model_comparison_report()

print("\n=== Advanced Applications ===")

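# Instance segmentation preview
def build_mask_rcnn_backbone(input_shape, num_classes):
    """Simplified encoder-decoder loosely inspired by Mask R-CNN's mask head.

    It predicts one sigmoid mask per class, not one mask per detected instance.
    """
    
    inputs = layers.Input(shape=input_shape)
    
    # Convolutional backbone at three scales (no FPN lateral connections)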
    # Level 1
    c1 = layers.Conv2D(64, 3, padding='same', activation='relu')(inputs)
    c1 = layers.Conv2D(64, 3, padding='same', activation='relu')(c1)
    p1 = layers.MaxPooling2D(2)(c1)
    
    # Level 2  
    c2 = layers.Conv2D(128, 3, padding='same', activation='relu')(p1)
    c2 = layers.Conv2D(128, 3, padding='same', activation='relu')(c2)
    p2 = layers.MaxPooling2D(2)(c2)
    
    # Level 3
    c3 = layers.Conv2D(256, 3, padding='same', activation='relu')(p2)
    c3 = layers.Conv2D(256, 3, padding='same', activation='relu')(c3)
    p3 = layers.MaxPooling2D(2)(c3)
    
    # Simplified mask head (normally would include RPN + ROI pooling)
    mask_head = layers.Conv2D(256, 3, padding='same', activation='relu')(p3)
    mask_head = layers.Conv2DTranspose(128, 2, strides=2, padding='same', activation='relu')(mask_head)
    mask_head = layers.Conv2DTranspose(64, 2, strides=2, padding='same', activation='relu')(mask_head)
    mask_head = layers.Conv2DTranspose(32, 2, strides=2, padding='same', activation='relu')(mask_head)
    
    # Output
    masks = layers.Conv2D(num_classes, 1, activation='sigmoid', padding='same')(mask_head)
    
    model = models.Model(inputs, masks, name='SimpleInstanceSeg')
    return model

# Real-time segmentation optimization
def create_lightweight_segmentation_model(input_shape, num_classes):
    """Create lightweight model for real-time applications"""
    
    inputs = layers.Input(shape=input_shape)
    
    # Efficient encoder using depthwise separable convolutions
    def depthwise_conv_block(x, filters, strides=1):
        x = layers.DepthwiseConv2D(3, strides=strides, padding='same', use_bias=False)(x)
        x = layers.BatchNormalization()(x)
        x = layers.ReLU()(x)
        
        x = layers.Conv2D(filters, 1, padding='same', use_bias=False)(x)
        x = layers.BatchNormalization()(x)
        x = layers.ReLU()(x)
        
        return x
    
    # Encoder
    x = layers.Conv2D(32, 3, strides=2, padding='same', activation='relu')(inputs)
    e1 = depthwise_conv_block(x, 64)
    
    e2 = depthwise_conv_block(e1, 128, strides=2)
    e3 = depthwise_conv_block(e2, 256, strides=2)
    e4 = depthwise_conv_block(e3, 512, strides=2)
    
    # Lightweight decoder
    d3 = layers.Conv2DTranspose(256, 2, strides=2, padding='same', activation='relu')(e4)
    d3 = layers.Add()([d3, e3])
    
    d2 = layers.Conv2DTranspose(128, 2, strides=2, padding='same', activation='relu')(d3)
    d2 = layers.Add()([d2, e2])
    
    d1 = layers.Conv2DTranspose(64, 2, strides=2, padding='same', activation='relu')(d2)
    d1 = layers.Add()([d1, e1])
    
    # Output
    outputs = layers.Conv2DTranspose(num_classes, 2, strides=2, padding='same', 
                                   activation='softmax')(d1)
    
    model = models.Model(inputs, outputs, name='LightweightSegNet')
    return model

# Build lightweight model
lightweight_model = create_lightweight_segmentation_model(input_shape, num_classes)
print(f"Lightweight model: {lightweight_model.count_params():,} parameters")

print("\n=== Deployment Checklist ===")
deployment_checklist = [
    "✓ Model trained and validated",
    "✓ Inference pipeline created", 
    "✓ TensorFlow Lite conversion completed",
    "✓ API wrapper implemented",
    "✓ Performance benchmarked",
    "✓ Lightweight variant available"
]

for item in deployment_checklist:
    print(item)
```
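`visualize_prediction` colors the mask by normalizing with `mask.max()`, so the same class can receive different colors in different images. A fixed palette indexed by class id keeps colors stable. The sketch below shows one way to do this in NumPy. `overlay_mask` and the example palette are illustrative assumptions and are not part of the notebook.

```python
import numpy as np

def overlay_mask(image, mask, palette, alpha=0.3):
    """Blend a class-index mask onto an image using a fixed color per class.

    image:   H x W x 3, uint8 in [0, 255] or float in [0, 1]
    mask:    H x W integer class ids
    palette: num_classes x 3 RGB colors as floats in [0, 1]
    Returns a float H x W x 3 image in [0, 1].
    """
    image = np.asarray(image)
    img = image.astype(np.float32)
    if np.issubdtype(image.dtype, np.integer):
        img /= 255.0  # put 8-bit images on the palette's [0, 1] scale
    colors = np.asarray(palette, dtype=np.float32)[np.asarray(mask)]  # one color per pixel
    return np.clip((1.0 - alpha) * img + alpha * colors, 0.0, 1.0)

# Example palette for the 3-class medical data (arbitrary choice):
# background = black, organ 1 = red, organ 2 = green
palette = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
image = np.full((4, 4, 3), 128, dtype=np.uint8)
mask = np.zeros((4, 4), dtype=np.uint8)
mask[1:3, 1:3] = 2
blended = overlay_mask(image, mask, palette)
print(blended.shape, blended.dtype)
```

Because colors come from the palette rather than from `mask.max()`, an all-background mask needs no special case.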

## Summary

This notebook covers building, training, evaluating, and deploying image segmentation models with tf.keras:

**Core Architectures Implemented:**
- **U-Net**: Standard encoder-decoder with skip connections for general segmentation
- **Attention U-Net**: Enhanced U-Net with attention gates for fine detail preservation  
- **DeepLab v3+**: ASPP module with multi-scale feature extraction
- **PSPNet**: Pyramid pooling for scene understanding
- **Medical U-Net**: Deep supervision for medical imaging applications
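Each U-Net variant above halves the spatial resolution at every pooling step and concatenates skip features on the way back up. The input height and width must therefore survive repeated halving without a remainder, which means they must be divisible by 2^depth. For a four-level encoder such as the medical U-Net, that is 16, assuming `encoder_block` pools by 2. Otherwise pooling floors an odd size and the upsampled decoder map no longer matches its skip connection. A small hypothetical helper makes the check explicit:

```python
def unet_feature_sizes(input_size, depth):
    """Spatial size at each encoder level of a U-Net that halves resolution `depth` times.

    Raises ValueError when a level would have an odd size, since pooling would floor it
    and the upsampled decoder map would no longer match its skip connection.
    """
    sizes = [input_size]
    for _ in range(depth):
        if sizes[-1] % 2:
            raise ValueError(
                f"input size {input_size} is not divisible by 2**{depth} = {2 ** depth}"
            )
        sizes.append(sizes[-1] // 2)
    return sizes

print(unet_feature_sizes(128, 4))  # [128, 64, 32, 16, 8]
```

The 128 x 128 inputs used for the medical model pass this check. A 100 x 100 image would first need resizing or padding to a multiple of 16, such as 112 or 128.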

**Advanced Features:**
- **Synchronized data augmentation** for image-mask pairs
- **Comprehensive loss functions**: Dice, Focal, Tversky, Boundary losses
- **Medical-specific metrics**: Sensitivity, Specificity, class-wise Dice coefficients
- **Multi-output training** with deep supervision: auxiliary losses on intermediate decoder outputs give those stages a direct training signal
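Synchronized augmentation means that every geometric transform is applied identically to the image and its mask, while photometric changes affect only the image. The notebook's augmentation code lives in an earlier section. The sketch below is a standalone NumPy illustration. The `augment_pair` function is hypothetical, and the choice of transforms (flips, 90-degree rotations, brightness jitter) is arbitrary.

```python
import numpy as np

def augment_pair(image, mask, rng):
    """Apply identical random geometric transforms to an image (H, W, C) and its mask (H, W).

    Flips and rotations are shared so each mask pixel keeps labeling the same image
    pixel; the brightness jitter is applied to the image only, because mask values
    are class indices.
    """
    if rng.random() < 0.5:            # horizontal flip
        image, mask = image[:, ::-1], mask[:, ::-1]
    if rng.random() < 0.5:            # vertical flip
        image, mask = image[::-1], mask[::-1]
    k = int(rng.integers(4))          # rotate by k * 90 degrees
    image = np.rot90(image, k, axes=(0, 1))
    mask = np.rot90(mask, k, axes=(0, 1))
    image = np.clip(image * rng.uniform(0.9, 1.1), 0.0, 1.0)  # assumes images in [0, 1]
    return np.ascontiguousarray(image), np.ascontiguousarray(mask)

# Usage on a random image/mask pair
rng = np.random.default_rng(42)
image = rng.random((4, 6, 3)).astype(np.float32)
mask = rng.integers(0, 3, size=(4, 6)).astype(np.uint8)
aug_image, aug_mask = augment_pair(image, mask, rng)
print(aug_image.shape, aug_mask.shape)  # height and width swap when k is odd
```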

**Production Deployment:**
- **Optimized inference pipeline** with preprocessing and postprocessing
- **TensorFlow Lite conversion** for mobile deployment
- **REST API wrapper** with Flask for web deployment
- **Performance benchmarking** tools for production optimization
- **Lightweight models** for real-time applications
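The TFLite conversion sets `converter.target_spec.supported_types = [tf.float16]`, which stores weights as 16-bit floats. TensorFlow's post-training quantization guide describes this as reducing model size by up to half. The arithmetic behind that estimate is simple. The parameter count below is a hypothetical placeholder, and the estimate covers weights only, not graph structure or metadata.

```python
def estimated_weight_size_mib(num_params, bytes_per_param):
    """Approximate storage for the weights alone, in MiB (graph and metadata ignored)."""
    return num_params * bytes_per_param / 2**20

num_params = 8 * 2**20  # hypothetical placeholder; use model.count_params() in practice
for dtype_name, nbytes in [('float32', 4), ('float16', 2)]:
    print(f"{dtype_name}: ~{estimated_weight_size_mib(num_params, nbytes):.1f} MiB")
```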

**Key Technical Achievements:**
- **Modular architecture** supporting easy model comparison
- **Medical imaging pipeline** with specialized evaluation metrics
- **Real-time inference** optimization with TensorFlow functions
- **Cross-platform deployment** ready for mobile and web
- **Comprehensive evaluation framework** with visualization tools
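`benchmark_model` reports the mean ± standard deviation of `model.predict` calls. For serving, tail latency (p95) often tells you more than the mean. The Keras documentation also recommends calling the model directly (`model(x)`) rather than `predict` for small inputs that fit in one batch. The sketch below times any zero-argument callable. `latency_stats` is a hypothetical helper name.

```python
import time
import numpy as np

def latency_stats(fn, num_runs=50, warmup=5):
    """Time repeated calls to `fn` (no arguments) and summarize latency in milliseconds."""
    for _ in range(warmup):
        fn()  # early calls can include one-off costs such as graph tracing
    times_ms = []
    for _ in range(num_runs):
        start = time.perf_counter()  # monotonic, high-resolution clock
        fn()
        times_ms.append((time.perf_counter() - start) * 1000.0)
    times_ms = np.asarray(times_ms)
    mean_ms = float(times_ms.mean())
    return {
        'mean_ms': mean_ms,
        'p50_ms': float(np.percentile(times_ms, 50)),
        'p95_ms': float(np.percentile(times_ms, 95)),
        'calls_per_s': 1000.0 / mean_ms if mean_ms > 0 else float('inf'),
    }

# Usage with a stand-in workload (replace with a model call in practice)
stats = latency_stats(lambda: sum(range(100_000)), num_runs=20)
print({k: round(v, 3) for k, v in stats.items()})
```

In this notebook it could be called as `latency_stats(lambda: unet_trained(X_test[:1], training=False))`.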

The notebook provides a foundation for building, training, and deploying segmentation models across domains such as medical imaging, autonomous driving, and general computer vision. Its code and deployment strategies can serve as a starting point for production use.