In [None]:
import os
import numpy as np
import cv2
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, Activation, AveragePooling2D, UpSampling2D, concatenate
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras import backend as K

# --- CONFIGURATION ---
# Size every image and mask is resized to in load_data; the same values are
# used as the model's input height/width when the network is built.
IMG_HEIGHT = 256
IMG_WIDTH = 256
# Dataset root: load_data reads '<DATA_PATH>/images' and '<DATA_PATH>/labels'
# (falling back to '<DATA_PATH>/masks' when 'labels' does not exist).
DATA_PATH = '../data/train'
# Destination handed to model.save() after training finishes.
MODEL_SAVE_PATH = '../saved_models/deeplabv3_oil_spill.h5'

# --- 1. CRITICAL FIX: DIFFERENTIABLE LOSS FUNCTIONS ---

# Soft Dice Coefficient (Use this for LOSS)
# It uses the raw probabilities (0.0 to 1.0) instead of thresholding.
# This allows the gradient to flow back through the network.
def dice_coeff_soft(y_true, y_pred, smooth=1e-6):
    """Return the soft Dice coefficient between *y_true* and *y_pred*.

    Both tensors are flattened completely, so the score is computed over
    every element passed in at once rather than per sample. No thresholding
    is applied to *y_pred*, which keeps the expression differentiable.

    Args:
        y_true: Ground-truth tensor.
        y_pred: Predicted tensor with the same number of elements as *y_true*.
        smooth: Constant added to numerator and denominator so the ratio is
            defined when both tensors sum to zero.

    Returns:
        Scalar tensor ``(2 * sum(t * p) + smooth) / (sum(t) + sum(p) + smooth)``.
    """
    truth = K.flatten(y_true)
    probs = K.flatten(y_pred)
    numerator = 2. * K.sum(truth * probs) + smooth
    denominator = K.sum(truth) + K.sum(probs) + smooth
    return numerator / denominator

# Dice Loss function (Minimize this)
def dice_loss(y_true, y_pred):
    """Return ``1 - dice_coeff_soft(y_true, y_pred)`` as the quantity to minimize."""
    soft_dice = dice_coeff_soft(y_true, y_pred)
    return 1 - soft_dice

# --- METRICS (For Human Monitoring) ---
# It is okay to use thresholding here because metrics are not used for backpropagation.

# Intersection over Union (IoU)
def iou_metric(y_true, y_pred, smooth=1e-6):
    """Return intersection-over-union after thresholding *y_pred* at 0.5.

    Predictions strictly greater than 0.5 become 1.0 and everything else 0.0;
    *y_true* is used as given. Both tensors are flattened completely, so the
    score covers every element passed in at once.

    Args:
        y_true: Ground-truth tensor.
        y_pred: Predicted tensor with the same number of elements as *y_true*.
        smooth: Constant added to numerator and denominator so the ratio is
            defined when the union is empty.

    Returns:
        Scalar tensor ``(intersection + smooth) / (union + smooth)``.
    """
    truth = K.flatten(y_true)
    hard_pred = K.flatten(K.cast(K.greater(y_pred, 0.5), K.floatx()))
    overlap = K.sum(truth * hard_pred)
    # union = |truth| + |pred| - |overlap|
    combined = K.sum(truth) + K.sum(hard_pred) - overlap
    return (overlap + smooth) / (combined + smooth)

# Hard Dice Score (F1 Score) - Strictly for reporting accuracy
def dice_metric_hard(y_true, y_pred, smooth=1e-6):
    """Return the Dice score after thresholding *y_pred* at 0.5.

    Predictions strictly greater than 0.5 become 1.0 and everything else 0.0;
    *y_true* is used as given. Both tensors are flattened completely, so the
    score covers every element passed in at once.

    Args:
        y_true: Ground-truth tensor.
        y_pred: Predicted tensor with the same number of elements as *y_true*.
        smooth: Constant added to numerator and denominator so the ratio is
            defined when both tensors sum to zero.

    Returns:
        Scalar tensor ``(2 * intersection + smooth) / (sum(t) + sum(p) + smooth)``.
    """
    truth = K.flatten(y_true)
    hard_pred = K.flatten(K.cast(K.greater(y_pred, 0.5), K.floatx()))
    overlap = K.sum(truth * hard_pred)
    numerator = 2. * overlap + smooth
    denominator = K.sum(truth) + K.sum(hard_pred) + smooth
    return numerator / denominator

# --- 2. DATA LOADER ---
def load_data(path, max_files=500):
    """Load image/mask pairs from *path* and return them as scaled arrays.

    Images are read from ``<path>/images``. Masks are read from
    ``<path>/labels`` or, when that directory does not exist, from
    ``<path>/masks``. For each image the mask is looked up under the same
    file name first and then under the same stem with a ``.png`` extension.
    Pairs where either file cannot be decoded are skipped.

    Args:
        path: Dataset root directory.
        max_files: Maximum number of entries of the image directory to
            consider, taken from the sorted listing. ``None`` means no limit.

    Returns:
        Tuple ``(images, masks)`` of NumPy arrays divided by 255.0. With at
        least one pair loaded, ``images`` has shape
        ``(N, IMG_HEIGHT, IMG_WIDTH, 3)`` and ``masks`` has shape
        ``(N, IMG_HEIGHT, IMG_WIDTH, 1)``; with no pairs both have length 0.

    Raises:
        OSError: If the image directory cannot be listed.
    """
    img_dir = os.path.join(path, 'images')
    mask_dir = os.path.join(path, 'labels')
    if not os.path.exists(mask_dir):
        mask_dir = os.path.join(path, 'masks')

    # os.listdir returns entries in arbitrary order; sort before slicing so
    # the selected subset (and the seeded split built on it) is reproducible.
    file_names = sorted(os.listdir(img_dir))[:max_files]

    target_size = (IMG_WIDTH, IMG_HEIGHT)  # cv2.resize expects (width, height)
    images = []
    masks = []
    for file_name in file_names:
        # NOTE(review): cv2.imread decodes colour images in BGR order and the
        # array is returned as-is - confirm downstream code expects BGR.
        img = cv2.imread(os.path.join(img_dir, file_name))
        if img is None:
            continue

        mask_path = os.path.join(mask_dir, file_name)
        if not os.path.exists(mask_path):
            stem = os.path.splitext(file_name)[0]
            mask_path = os.path.join(mask_dir, stem + ".png")

        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            continue

        images.append(cv2.resize(img, target_size))
        # Nearest-neighbour keeps label values intact; the default bilinear
        # interpolation would blend foreground and background along edges.
        masks.append(
            cv2.resize(mask, target_size, interpolation=cv2.INTER_NEAREST)
        )

    images = np.array(images) / 255.0
    # Assumes foreground pixels are stored as 255 so that masks end up in
    # [0, 1] - TODO confirm against the dataset.
    masks = np.array(masks) / 255.0
    masks = np.expand_dims(masks, axis=-1)
    return images, masks

# Load the dataset and hold out 20% of it for validation.
# Any failure is reported and signalled to later code via X_train = None.
try:
    print("Loading data...")
    X, y = load_data(DATA_PATH)
    if not len(X):
        raise ValueError(f"No images found in {DATA_PATH}. Check your paths.")

    X_train, X_val, y_train, y_val = train_test_split(
        X,
        y,
        test_size=0.2,
        random_state=42,
    )
    print(f"Loaded {len(X)} images successfully.")
except Exception as e:
    print(f"FATAL ERROR during data loading: {e}")
    # The training section checks this sentinel and skips itself.
    X_train = None

# --- 3. MODEL ARCHITECTURE ---
def convolution_block(block_input, num_filters=256, kernel_size=3, dilation_rate=1, padding="same", use_bias=False):
    """Apply Conv2D -> BatchNormalization -> ReLU to *block_input*.

    Args:
        block_input: Input tensor for the convolution.
        num_filters: Number of convolution filters.
        kernel_size: Convolution kernel size.
        dilation_rate: Dilation rate of the convolution.
        padding: Padding mode forwarded to ``Conv2D``.
        use_bias: Whether the convolution has a bias term.

    Returns:
        The activated output tensor.
    """
    # Forward the caller's padding; it was previously ignored in favour of a
    # hard-coded "same" (which is still the default).
    x = Conv2D(
        num_filters,
        kernel_size=kernel_size,
        dilation_rate=dilation_rate,
        padding=padding,
        use_bias=use_bias,
    )(block_input)
    x = BatchNormalization()(x)
    return Activation("relu")(x)

def DilatedSpatialPyramidPooling(dspp_input):
    """Build the dilated spatial pyramid pooling head on *dspp_input*.

    Five branches are computed from the same input and concatenated:
    a 1x1 convolution, three 3x3 convolutions with dilation rates 6, 12 and
    18, and an image-level branch (average pooling over the full spatial
    extent, a 1x1 convolution, then bilinear upsampling back to the input's
    spatial size). Every branch goes through ``convolution_block``.

    Args:
        dspp_input: Feature-map tensor whose axes 1 and 2 are treated as
            height and width; both must be statically known.

    Returns:
        The concatenation of the five branch outputs.

    Raises:
        ValueError: If the height or width of *dspp_input* is not statically
            known.
    """
    dims = dspp_input.shape
    target_h = dims[1]
    target_w = dims[2]
    # The pooling window and the upsampling factor are both derived from the
    # static spatial size. Fail early with a clear message instead of passing
    # None into the pooling layer or skipping the upsample, which would leave
    # the branches with mismatched shapes at the concatenation.
    if target_h is None or target_w is None:
        raise ValueError(
            "DilatedSpatialPyramidPooling requires statically known spatial "
            f"dimensions, got input shape {dims}."
        )

    x = convolution_block(dspp_input, kernel_size=1, dilation_rate=1)
    out_6 = convolution_block(dspp_input, kernel_size=3, dilation_rate=6)
    out_12 = convolution_block(dspp_input, kernel_size=3, dilation_rate=12)
    out_18 = convolution_block(dspp_input, kernel_size=3, dilation_rate=18)

    # Image-level branch: pool over the whole feature map, then project.
    x_pool = AveragePooling2D(pool_size=(target_h, target_w))(dspp_input)
    x_pool = convolution_block(x_pool, kernel_size=1, use_bias=True)

    # Scale the pooled map back up to the input's spatial size.
    pool_shape = K.int_shape(x_pool)
    scale_h = target_h // pool_shape[1]
    scale_w = target_w // pool_shape[2]
    x_pool = UpSampling2D(size=(scale_h, scale_w), interpolation="bilinear")(x_pool)

    x = concatenate([x, out_6, out_12, out_18, x_pool])
    return x

def improved_deeplabv3_plus(input_shape, n_classes=1):
    """Build a DeepLabV3+-style segmentation model on a MobileNetV2 encoder.

    The encoder is MobileNetV2 without its classification top, initialised
    with ImageNet weights. The output of ``block_13_expand`` feeds the
    pyramid-pooling head and the output of ``block_3_expand`` is used as the
    low-level skip connection. The decoder upsamples the head output 4x,
    concatenates it with the projected low-level features, applies two
    convolution blocks, upsamples 4x again and ends in a 1x1 sigmoid
    convolution.

    NOTE(review): the fixed 4x factors presumably rely on the two tapped
    layers sitting at 1/16 and 1/4 of the input resolution - confirm before
    changing the tap points or the input size.

    Args:
        input_shape: Shape of one input image, passed to ``Input`` and
            ``MobileNetV2``.
        n_classes: Number of output channels of the final convolution.

    Returns:
        The assembled, uncompiled Keras ``Model``.
    """
    inputs = Input(input_shape)

    # Encoder
    backbone = MobileNetV2(
        input_shape=input_shape,
        include_top=False,
        weights="imagenet",
        input_tensor=inputs,
    )
    deep_features = backbone.get_layer("block_13_expand").output
    shallow_features = backbone.get_layer("block_3_expand").output

    # Pyramid pooling on the deep features, brought up to the skip resolution
    encoded = DilatedSpatialPyramidPooling(deep_features)
    encoded = UpSampling2D((4, 4), interpolation="bilinear")(encoded)

    # 1x1 projection of the skip connection down to 48 channels
    shallow = convolution_block(shallow_features, num_filters=48, kernel_size=1)

    # Decoder: fuse, refine twice, restore resolution, classify per pixel
    decoded = concatenate([encoded, shallow])
    for _ in range(2):
        decoded = convolution_block(decoded, 256)
    decoded = UpSampling2D((4, 4), interpolation="bilinear")(decoded)
    outputs = Conv2D(n_classes, (1, 1), activation="sigmoid")(decoded)

    return Model(inputs, outputs)

# --- 4. TRAINING ---
if X_train is not None:
    # Create the output directory before training so that a missing folder
    # cannot make model.save() fail after all epochs have already run.
    save_dir = os.path.dirname(MODEL_SAVE_PATH)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    # Build Model
    model = improved_deeplabv3_plus((IMG_HEIGHT, IMG_WIDTH, 3))

    # Optimise the differentiable soft Dice loss; the thresholded IoU and
    # Dice functions are attached as metrics for monitoring only.
    print("Compiling model with Soft Dice Loss...")
    model.compile(optimizer=Adam(learning_rate=0.001),
                  loss=dice_loss,
                  metrics=['accuracy', iou_metric, dice_metric_hard])

    # model.summary() # Optional: Uncomment to see architecture

    print("Starting DeepLab Training...")
    history = model.fit(X_train, y_train,
                        batch_size=8,
                        epochs=30,
                        validation_data=(X_val, y_val))

    # Save Model
    model.save(MODEL_SAVE_PATH)
    print(f"DeepLab model saved to {MODEL_SAVE_PATH}")

    # Plot Results: loss curves on the left, hard Dice curves on the right
    plt.figure(figsize=(12, 4))
    plt.subplot(1, 2, 1)
    plt.plot(history.history['loss'], label='Train Loss')
    plt.plot(history.history['val_loss'], label='Val Loss')
    plt.title('Loss (Soft Dice)')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(history.history['dice_metric_hard'], label='Train F1 (Dice)')
    plt.plot(history.history['val_dice_metric_hard'], label='Val F1 (Dice)')
    plt.title('Accuracy (Hard Dice/F1)')
    plt.legend()

    plt.show()
else:
    # The data-loading step set X_train to None after reporting its error.
    print("Skipping training due to data loading errors.")