In [None]:
# 2_DeepLabV3_Training.ipynb

import os
import numpy as np
import cv2
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, Activation, AveragePooling2D, UpSampling2D, concatenate
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.applications import MobileNetV2 # --- ACCURACY IMPROVEMENT: Added MobileNetV2 ---
from tensorflow.keras import backend as K

# --- CONFIGURATION ---
# Network input size: every image and mask is resized to (height, width).
IMG_HEIGHT, IMG_WIDTH = 256, 256
# Dataset root; load_data reads an "images" folder plus "labels" (or "masks").
DATA_PATH = '../data/train'
# Where the trained Keras model is written (".h5" extension).
MODEL_SAVE_PATH = '../saved_models/deeplabv3_mobilenet_improved.h5'

# --- ACCURACY IMPROVEMENT: New Metrics and Loss Function Definitions ---
# Intersection over Union (IoU) - Standard segmentation metric
def iou_metric(y_true, y_pred, smooth=1e-6):
    """Batch-mean Intersection-over-Union for binary masks.

    ``y_pred`` is binarised at 0.5 before comparison; ``y_true`` is used as
    given. IoU is computed per sample by summing over axes 1-3, then averaged
    over the batch axis. ``smooth`` keeps the ratio defined (it evaluates to 1)
    when a sample has both an empty target and an empty prediction.

    The hard threshold means this is an evaluation metric only: it provides no
    gradient.
    """
    reduce_axes = [1, 2, 3]
    hard_pred = K.cast(K.greater(y_pred, 0.5), K.floatx())

    overlap = K.sum(K.abs(y_true * hard_pred), axis=reduce_axes)
    total = K.sum(y_true, axis=reduce_axes) + K.sum(hard_pred, axis=reduce_axes)

    # union = |A| + |B| - |A ∩ B|
    per_sample = (overlap + smooth) / (total - overlap + smooth)
    return K.mean(per_sample, axis=0)

# Dice coefficient (F1 score) on 0.5-thresholded predictions - evaluation metric.
def dice_coeff(y_true, y_pred, smooth=1e-6):
    """Batch-mean Dice coefficient for binary masks.

    ``y_pred`` is binarised at 0.5; ``y_true`` is used as given. Dice is
    ``2|A ∩ B| / (|A| + |B|)`` per sample (sums over axes 1-3), averaged over
    the batch axis. ``smooth`` keeps the ratio defined for empty masks.

    Because of the hard threshold this function has no gradient with respect
    to ``y_pred``; it is suitable as a metric, not as a training loss.
    """
    axes = [1, 2, 3]
    binarized = K.cast(K.greater(y_pred, 0.5), K.floatx())

    overlap = K.sum(y_true * binarized, axis=axes)
    denominator = K.sum(y_true, axis=axes) + K.sum(binarized, axis=axes)

    scores = (2. * overlap + smooth) / (denominator + smooth)
    return K.mean(scores, axis=0)

# Dice Loss: overlap-based, so errors on small foreground objects are not
# drowned out by the (much larger) background as in per-pixel losses.
def dice_loss(y_true, y_pred, smooth=1e-6):
    """Differentiable (soft) Dice loss for binary segmentation.

    Computed directly on the predicted probabilities (the model's sigmoid
    output), without thresholding. ``dice_coeff`` binarises predictions with
    ``K.greater`` + ``K.cast``, an operation with no gradient, so
    ``1 - dice_coeff`` cannot be used to train: every trainable variable would
    receive a ``None`` gradient.

    Args:
        y_true: Ground-truth mask tensor; cast to ``y_pred``'s dtype.
        y_pred: Predicted probabilities in [0, 1], same shape as ``y_true``.
        smooth: Small constant that keeps the ratio defined for empty masks.

    Returns:
        Scalar tensor: ``1 - mean_over_batch(soft_dice)``.
    """
    y_true = K.cast(y_true, y_pred.dtype)
    axes = [1, 2, 3]

    # Soft intersection / cardinalities: products and sums of probabilities
    # keep a gradient path back to the network output.
    intersection = K.sum(y_true * y_pred, axis=axes)
    denominator = K.sum(y_true, axis=axes) + K.sum(y_pred, axis=axes)

    dice = (2. * intersection + smooth) / (denominator + smooth)
    return 1. - K.mean(dice, axis=0)


# --- DATA LOADER ---
def load_data(path, limit=500):
    """Load image/mask pairs from ``path`` and scale them to [0, 1].

    Images are read from ``path/images``. Masks are read from ``path/labels``,
    or from ``path/masks`` when ``labels`` does not exist. Each image's mask is
    looked up by the same file name, falling back to the same stem with a
    ``.png`` extension. Entries whose image or mask OpenCV cannot decode are
    skipped.

    Args:
        path: Dataset root directory.
        limit: Maximum number of (sorted) entries of ``path/images`` to
            consider. The default of 500 matches the previous hard-coded cap;
            ``None`` loads everything.

    Returns:
        ``(images, masks)`` as float32 arrays of shape
        ``(N, IMG_HEIGHT, IMG_WIDTH, 3)`` and ``(N, IMG_HEIGHT, IMG_WIDTH, 1)``.
        Images keep OpenCV's BGR channel order.

    Raises:
        FileNotFoundError: if ``path/images`` does not exist (from ``os.listdir``).
    """
    img_dir = os.path.join(path, 'images')
    mask_dir = os.path.join(path, 'labels')
    if not os.path.exists(mask_dir):
        mask_dir = os.path.join(path, 'masks')

    # os.listdir order is arbitrary; sorting makes the capped subset (and so
    # the downstream train/val split) identical on every run.
    files = sorted(os.listdir(img_dir))[:limit]

    images, masks = [], []
    for file_name in files:
        img = cv2.imread(os.path.join(img_dir, file_name))
        if img is None:
            continue

        mask_path = os.path.join(mask_dir, file_name)
        if not os.path.exists(mask_path):
            mask_path = os.path.join(mask_dir, os.path.splitext(file_name)[0] + ".png")
        mask = cv2.imread(mask_path, 0)
        if mask is None:
            continue

        images.append(cv2.resize(img, (IMG_WIDTH, IMG_HEIGHT)))
        # Nearest-neighbour keeps mask values within the original label set;
        # the default bilinear filter invents blended values along object edges.
        masks.append(cv2.resize(mask, (IMG_WIDTH, IMG_HEIGHT),
                                interpolation=cv2.INTER_NEAREST))

    # float32 (Keras' default floatx) instead of float64 halves memory use.
    # reshape(-1, ...) also yields correctly-ranked empty arrays when N == 0.
    # NOTE(review): MobileNetV2's ImageNet weights were trained on RGB input
    # scaled to [-1, 1]; this loader feeds BGR in [0, 1]. Any change here must
    # be mirrored in inference code, so it is left as-is.
    images = np.asarray(images, dtype=np.float32).reshape(-1, IMG_HEIGHT, IMG_WIDTH, 3) / 255.0
    # NOTE(review): assumes masks are stored as 0/255. Masks stored as 0/1
    # labels would become ~0.004 here; confirm against the dataset.
    masks = np.asarray(masks, dtype=np.float32).reshape(-1, IMG_HEIGHT, IMG_WIDTH, 1) / 255.0
    return images, masks

# Load the dataset and hold out 20% for validation (fixed seed for a
# reproducible split). On any failure the error is printed and this cell does
# not (re)assign X_train / X_val / y_train / y_val.
try:
    X, y = load_data(DATA_PATH)
    X_train, X_val, y_train, y_val = train_test_split(
        X, y, test_size=0.2, random_state=42
    )
except Exception as e:
    print(f"Data loading failed: {e}")
else:
    print(f"Loaded {len(X)} images.")


# --- DEEPLAB V3+ BLOCKS (Modified to use simpler conv block and BN/ReLU activation) ---
def convolution_block(block_input, num_filters=256, kernel_size=3, dilation_rate=1, padding="same", use_bias=False):
    """Apply Conv2D -> BatchNormalization -> ReLU to ``block_input``.

    Args:
        block_input: Keras tensor to convolve.
        num_filters: Number of output channels of the convolution.
        kernel_size: Convolution kernel size.
        dilation_rate: Dilation (atrous) rate of the convolution.
        padding: Padding mode passed to ``Conv2D``. The default "same" keeps
            the spatial size unchanged.
        use_bias: Whether the convolution has its own bias. It is usually
            redundant because the BatchNormalization that follows adds an
            offset.

    Returns:
        The ReLU-activated output tensor.
    """
    x = Conv2D(
        num_filters,
        kernel_size=kernel_size,
        dilation_rate=dilation_rate,
        # Honour the caller's choice; this was previously hard-coded to
        # "same", so the ``padding`` argument was silently ignored.
        padding=padding,
        use_bias=use_bias,
    )(block_input)
    x = BatchNormalization()(x)
    return Activation("relu")(x)

def DilatedSpatialPyramidPooling(dspp_input):
    """Atrous Spatial Pyramid Pooling head.

    Runs five parallel branches over ``dspp_input`` and concatenates them on
    the last axis:
      * a 1x1 ``convolution_block``;
      * three 3x3 ``convolution_block``s with dilation rates 6, 12 and 18;
      * an image-level branch: average-pool over the full height x width,
        1x1 ``convolution_block`` (with bias), then bilinear upsampling back
        to the input's height x width.

    Every branch uses ``convolution_block``'s default filter count, so the
    result has five times that many channels at the input's spatial size.

    Args:
        dspp_input: Channels-last 4-D Keras tensor with static height/width.

    Returns:
        The concatenated multi-scale feature tensor.
    """
    height, width = dspp_input.shape[-3], dspp_input.shape[-2]

    # Same layer-creation order as the classic layout: 1x1 first, then the
    # atrous rates in increasing order.
    branches = [convolution_block(dspp_input, kernel_size=1, dilation_rate=1)]
    for rate in (6, 12, 18):
        branches.append(convolution_block(dspp_input, kernel_size=3, dilation_rate=rate))

    # Image-level context: collapse to a single cell, project, then stretch
    # back so it can be concatenated with the other branches.
    pooled = AveragePooling2D(pool_size=(height, width))(dspp_input)
    pooled = convolution_block(pooled, kernel_size=1, use_bias=True)
    _, pooled_h, pooled_w, _ = K.int_shape(pooled)
    pooled = UpSampling2D(
        size=(height // pooled_h, width // pooled_w),
        interpolation="bilinear",
    )(pooled)
    branches.append(pooled)

    return concatenate(branches)

# --- ACCURACY IMPROVEMENT: DeepLabV3+ with MobileNetV2 Backbone ---
def improved_deeplabv3_plus(input_shape, n_classes=1):
    """Build a DeepLabV3+ segmentation model on a MobileNetV2 encoder.

    Layout:
      * Encoder: ImageNet-initialised MobileNetV2 without its classifier head.
      * High-level features: ``block_13_expand_relu`` (1/16 of the input
        resolution), fed through ``DilatedSpatialPyramidPooling``.
      * Decoder: upsample 4x (to 1/4 resolution) and concatenate with a
        48-channel 1x1 projection of ``block_3_expand_relu`` (1/4 resolution).
        Then apply two 3x3 conv blocks, upsample 4x to full resolution, and
        finish with a 1x1 conv using a sigmoid activation.

    Args:
        input_shape: ``(height, width, 3)``. The fixed 4x upsampling factors
            assume height and width are multiples of 16; otherwise the skip
            concatenation shapes will not match.
        n_classes: Number of output channels, each with an independent sigmoid.

    Returns:
        An uncompiled ``Model`` mapping images to per-pixel probabilities.
    """
    inputs = Input(input_shape)

    # Encoder built directly on ``inputs`` so its layers share our input tensor.
    mobilenet = MobileNetV2(input_shape=input_shape,
                            include_top=False,
                            weights="imagenet",
                            input_tensor=inputs)

    # Tap the *activated* expand outputs. "block_N_expand" is the raw 1x1
    # Conv2D; reading it bypasses the pretrained "block_N_expand_BN" and
    # ReLU6 that follow it, and those layers then drop out of the model.
    # The "_relu" layers have the same spatial size and channel count.
    high_level_features = mobilenet.get_layer("block_13_expand_relu").output  # 1/16 res

    # ASPP module for multi-scale context.
    x = DilatedSpatialPyramidPooling(high_level_features)

    # Decoder: 1/16 -> 1/4 so it lines up with the low-level features.
    x = UpSampling2D((4, 4), interpolation="bilinear")(x)

    # Low-level features for the skip connection, projected to 48 channels
    # before concatenation (as in DeepLabV3+).
    low_level_features = mobilenet.get_layer("block_3_expand_relu").output  # 1/4 res
    low_level_conv = Conv2D(48, (1, 1), padding="same", use_bias=False)(low_level_features)
    low_level_conv = BatchNormalization()(low_level_conv)
    low_level_conv = Activation("relu")(low_level_conv)

    # Fuse and refine.
    x = concatenate([x, low_level_conv])
    x = convolution_block(x, 256)
    x = convolution_block(x, 256)

    # 1/4 -> full input resolution.
    x = UpSampling2D((4, 4), interpolation="bilinear")(x)

    # Per-pixel sigmoid probabilities (binary / multi-label segmentation).
    outputs = Conv2D(n_classes, (1, 1), activation="sigmoid")(x)

    return Model(inputs, outputs)

if 'X_train' not in locals():
    print("Data not loaded, skipping training.")
else:
    # Encoder starts from ImageNet-pretrained MobileNetV2 weights.
    model = improved_deeplabv3_plus((IMG_HEIGHT, IMG_WIDTH, 3))

    # Train on Dice loss; report pixel accuracy, IoU and Dice each epoch.
    model.compile(
        optimizer=Adam(learning_rate=0.001),
        loss=dice_loss,
        metrics=['accuracy', iou_metric, dice_coeff],
    )
    model.summary()

    print("Starting DeepLab Training...")
    # Small batch size (8) for this comparatively large model; 30 epochs.
    history = model.fit(
        X_train,
        y_train,
        batch_size=8,
        epochs=30,
        validation_data=(X_val, y_val),
    )
    model.save(MODEL_SAVE_PATH)
    print(f"DeepLab model saved to {MODEL_SAVE_PATH}")

    # Training vs. validation loss curves.
    plt.plot(history.history['loss'], label='train_loss')
    plt.plot(history.history['val_loss'], label='val_loss')
    plt.title('Loss Over Epochs')
    plt.xlabel('Epoch')
    plt.ylabel('Dice Loss')
    plt.legend()
    plt.show()