In [None]:
import tensorflow as tf
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.layers import (Input, Conv2D, MaxPooling2D, Dropout, UpSampling2D,
                                   concatenate, BatchNormalization, Activation,
                                   GlobalAveragePooling2D, Dense, Flatten)
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.metrics import Precision, Recall, AUC
import numpy as np
import matplotlib.pyplot as plt
import cv2
import os
from datetime import datetime
import json
from sklearn.metrics import (precision_score, recall_score, f1_score,
                           jaccard_score, confusion_matrix, classification_report,
                           roc_curve, auc)
from tqdm import tqdm
import gc

# Seed TensorFlow and NumPy so weight initialisation and the np.random-based
# sampling used later in this notebook are repeatable across runs.
SEED = 42
tf.random.set_seed(SEED)
np.random.seed(SEED)

# Let TensorFlow allocate GPU memory on demand instead of reserving (nearly)
# all of it at start-up, so other processes/notebooks can share the GPU.
# TensorFlow raises RuntimeError if this is attempted after the GPUs have
# already been initialised; that error is printed rather than re-raised.
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        print(e)

In [None]:
class OilSpillPipeline:
    def __init__(self, input_shape=(256, 256, 3)):
        self.input_shape = input_shape
        self.classifier = None
        self.segmenter = None
        self.classifier_metrics = {}
        self.segmentation_metrics = {}

    def build_classifier(self):
        """Create the lightweight binary CNN classifier and store it on ``self.classifier``.

        Three stages of Conv(3x3, ReLU, same padding) -> MaxPool(2x2) -> BatchNorm
        with 32, 64 and 128 filters feed a head of global average pooling,
        Dense(128, ReLU), Dropout(0.3), Dense(64, ReLU) and a single sigmoid unit.

        Returns:
            The uncompiled Keras ``Model``.
        """
        inputs = Input(self.input_shape)

        features = inputs
        for n_filters in (32, 64, 128):
            features = Conv2D(n_filters, (3, 3), activation='relu', padding='same')(features)
            features = MaxPooling2D((2, 2))(features)
            features = BatchNormalization()(features)

        head = GlobalAveragePooling2D()(features)
        head = Dense(128, activation='relu')(head)
        head = Dropout(0.3)(head)
        head = Dense(64, activation='relu')(head)
        outputs = Dense(1, activation='sigmoid')(head)

        self.classifier = Model(inputs=inputs, outputs=outputs)
        return self.classifier

    def build_segmenter(self, filters=32):
        """Create a compact three-level U-Net and store it on ``self.segmenter``.

        Encoder levels use ``filters``, ``2*filters`` and ``4*filters`` channels.
        The bottleneck uses ``8*filters`` channels followed by Dropout(0.3). Each
        decoder level upsamples, applies a 2x2 conv, concatenates the matching
        encoder output and runs another double-conv block. A final 1x1 sigmoid
        conv produces a single-channel mask.

        Because there are three 2x2 poolings, the skip concatenations only line
        up when the input height and width are divisible by 8.

        Args:
            filters: Channel count of the first encoder level.

        Returns:
            The uncompiled Keras ``Model``.
        """
        inputs = Input(self.input_shape)

        def double_conv(tensor, n_filters, dropout_rate=0.0):
            # Two (Conv 3x3 -> BatchNorm -> ReLU) units, optionally followed by dropout.
            for _ in range(2):
                tensor = Conv2D(n_filters, 3, padding='same')(tensor)
                tensor = BatchNormalization()(tensor)
                tensor = Activation('relu')(tensor)
            if dropout_rate > 0:
                tensor = Dropout(dropout_rate)(tensor)
            return tensor

        # Contracting path: remember each level's output for the skip connections.
        skips = []
        x = inputs
        for multiplier in (1, 2, 4):
            x = double_conv(x, filters * multiplier)
            skips.append(x)
            x = MaxPooling2D((2, 2))(x)

        # Bottleneck
        x = double_conv(x, filters * 8, dropout_rate=0.3)

        # Expanding path, pairing each level with its encoder counterpart (deepest first).
        for multiplier, skip in zip((4, 2, 1), reversed(skips)):
            x = UpSampling2D((2, 2))(x)
            x = Conv2D(filters * multiplier, 2, padding='same')(x)
            x = concatenate([x, skip])
            x = double_conv(x, filters * multiplier)

        outputs = Conv2D(1, 1, activation='sigmoid')(x)

        self.segmenter = Model(inputs=inputs, outputs=outputs)
        return self.segmenter

    def dice_coefficient(self, y_true, y_pred, smooth=1e-6):
        """Soft Dice coefficient between ground-truth and predicted masks.

        Both tensors are flattened, so the score is computed over the whole
        batch at once rather than averaged per image.

        Args:
            y_true: Ground-truth mask tensor (any numeric/bool dtype).
            y_pred: Predicted mask tensor of probabilities.
            smooth: Small constant added to numerator and denominator so that
                two empty masks do not divide by zero.

        Returns:
            Scalar tensor ``(2*sum(t*p) + smooth) / (sum(t) + sum(p) + smooth)``.
        """
        y_pred_f = tf.reshape(y_pred, [-1])
        # Cast the labels to the prediction dtype. Integer/bool masks
        # (e.g. uint8) would otherwise make the multiplication below fail.
        # Keras casts y_true for metrics, but not necessarily when this runs
        # inside a custom loss.
        y_true_f = tf.reshape(tf.cast(y_true, y_pred_f.dtype), [-1])
        intersection = tf.reduce_sum(y_true_f * y_pred_f)
        return (2. * intersection + smooth) / (tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + smooth)

    def dice_loss(self, y_true, y_pred):
        """Dice loss for segmentation"""
        return 1 - self.dice_coefficient(y_true, y_pred)

    def bce_dice_loss(self, y_true, y_pred):
        """Segmentation loss: binary cross-entropy plus Dice loss.

        The cross-entropy term is computed on the flattened tensors, giving a
        single value over every pixel in the batch. The Dice term comes from
        ``self.dice_loss``.

        Returns:
            Scalar tensor ``bce + dice_loss``.
        """
        # Convert once so both terms (including the multiplication inside
        # dice_coefficient) see y_true in y_pred's float dtype, even when the
        # masks arrive as integers.
        y_pred = tf.convert_to_tensor(y_pred)
        y_true = tf.cast(y_true, y_pred.dtype)

        y_true_flat = tf.reshape(y_true, [-1])
        y_pred_flat = tf.reshape(y_pred, [-1])
        bce = tf.keras.losses.binary_crossentropy(y_true_flat, y_pred_flat)
        dice = self.dice_loss(y_true, y_pred)
        return bce + dice

    def compile_models(self, classifier_lr=1e-4, segmenter_lr=1e-4):
        """Compile both models"""
        if self.classifier is None or self.segmenter is None:
            raise ValueError("Models not built yet. Please build models first.")

        # Compile classifier
        self.classifier.compile(
            optimizer=Adam(learning_rate=classifier_lr),
            loss='binary_crossentropy',
            metrics=['accuracy', Precision(), Recall(), AUC()]
        )

        # Compile segmenter
        self.segmenter.compile(
            optimizer=Adam(learning_rate=segmenter_lr),
            loss=self.bce_dice_loss,
            metrics=['accuracy', self.dice_coefficient]
        )

In [None]:
    def train_classifier(self, X_train, y_train, X_val, y_val, epochs=20, batch_size=32, class_weight=None):
        """Train the classifier model"""
        if self.classifier is None:
            raise ValueError("Classifier not built yet.")

        callbacks = [
            ModelCheckpoint(
                'best_classifier.h5',
                monitor='val_accuracy',
                save_best_only=True,
                mode='max',
                verbose=1
            ),
            EarlyStopping(
                monitor='val_accuracy',
                patience=8,
                restore_best_weights=True,
                verbose=1
            ),
            ReduceLROnPlateau(
                monitor='val_loss',
                factor=0.2,
                patience=4,
                min_lr=1e-7,
                verbose=1
            )
        ]

        history = self.classifier.fit(
            X_train, y_train,
            batch_size=batch_size,
            epochs=epochs,
            validation_data=(X_val, y_val),
            callbacks=callbacks,
            class_weight=class_weight,
            verbose=1
        )

        return history

    def train_segmenter(self, X_train, y_train, X_val, y_val, epochs=15, batch_size=8):
        """Train the segmentation model"""
        if self.segmenter is None:
            raise ValueError("Segmenter not built yet.")

        callbacks = [
            ModelCheckpoint(
                'best_segmenter.h5',
                monitor='val_loss',
                save_best_only=True,
                mode='min',
                verbose=1
            ),
            EarlyStopping(
                monitor='val_loss',
                patience=6,
                restore_best_weights=True,
                verbose=1
            ),
            ReduceLROnPlateau(
                monitor='val_loss',
                factor=0.2,
                patience=3,
                min_lr=1e-7,
                verbose=1
            )
        ]

        history = self.segmenter.fit(
            X_train, y_train,
            batch_size=batch_size,
            epochs=epochs,
            validation_data=(X_val, y_val),
            callbacks=callbacks,
            verbose=1
        )

        return history

In [None]:
    def evaluate_classifier(self, X_test, y_test, threshold=0.5):
        """Evaluate the classifier model"""
        if self.classifier is None:
            raise ValueError("Classifier not trained yet.")

        # Keras evaluation
        results = self.classifier.evaluate(X_test, y_test, verbose=0)
        metrics = {
            'loss': results[0],
            'accuracy': results[1],
            'precision': results[2],
            'recall': results[3],
            'auc': results[4]
        }

        # Additional metrics
        y_pred_probs = self.classifier.predict(X_test, verbose=0)
        y_pred = (y_pred_probs > threshold).astype(int)

        metrics['f1_score'] = f1_score(y_test, y_pred)
        metrics['confusion_matrix'] = confusion_matrix(y_test, y_pred).tolist()

        # ROC curve data
        fpr, tpr, _ = roc_curve(y_test, y_pred_probs)
        metrics['roc_curve'] = {'fpr': fpr.tolist(), 'tpr': tpr.tolist()}
        metrics['roc_auc'] = auc(fpr, tpr)

        self.classifier_metrics = metrics
        return metrics

    def evaluate_segmenter(self, X_test, y_test, batch_size=8):
        """Evaluate the segmentation model"""
        if self.segmenter is None:
            raise ValueError("Segmenter not trained yet.")

        # Evaluate in batches
        num_batches = int(np.ceil(len(X_test) / batch_size))
        metrics_accum = {
            'loss': 0,
            'accuracy': 0,
            'dice_coefficient': 0
        }

        for i in range(num_batches):
            start_idx = i * batch_size
            end_idx = min((i + 1) * batch_size, len(X_test))

            X_batch = X_test[start_idx:end_idx]
            y_batch = y_test[start_idx:end_idx]

            results = self.segmenter.evaluate(X_batch, y_batch, verbose=0)
            metrics_accum['loss'] += results[0] * len(X_batch)
            metrics_accum['accuracy'] += results[1] * len(X_batch)
            metrics_accum['dice_coefficient'] += results[2] * len(X_batch)

        # Average metrics
        metrics = {
            'loss': float(metrics_accum['loss'] / len(X_test)),
            'accuracy': float(metrics_accum['accuracy'] / len(X_test)),
            'dice_coefficient': float(metrics_accum['dice_coefficient'] / len(X_test))
        }

        # Calculate additional metrics on a subset
        subset_size = min(50, len(X_test))
        indices = np.random.choice(len(X_test), subset_size, replace=False)
        X_subset = X_test[indices]
        y_subset = y_test[indices]

        y_pred = self.segmenter.predict(X_subset, verbose=0)
        y_pred_binary = (y_pred > 0.5).astype(np.uint8)

        y_test_flat = y_subset.squeeze().flatten()
        y_pred_flat = y_pred_binary.flatten()

        metrics['f1_score'] = float(f1_score(y_test_flat, y_pred_flat, zero_division=1))
        metrics['jaccard_score'] = float(jaccard_score(y_test_flat, y_pred_flat, zero_division=1))
        metrics['confusion_matrix'] = confusion_matrix(y_test_flat, y_pred_flat).tolist()

        self.segmentation_metrics = metrics
        return metrics

In [None]:
    def predict_pipeline(self, images, classifier_threshold=0.5, segmenter_threshold=0.5):
        """
        End-to-end prediction: first classify, then segment if oil spill detected
        Returns: classifications, segmentations, confidence scores
        """
        if self.classifier is None or self.segmenter is None:
            raise ValueError("Models not trained yet.")

        # Add batch dimension if needed
        if len(images.shape) == 3:
            images = np.expand_dims(images, axis=0)

        # Classify images
        class_probs = self.classifier.predict(images, verbose=0).flatten()
        classifications = (class_probs > classifier_threshold).astype(int)

        # Only segment images classified as having oil spills
        segmentations = np.zeros((len(images), *self.input_shape[:2]), dtype=np.uint8)

        oil_spill_indices = np.where(classifications == 1)[0]
        if len(oil_spill_indices) > 0:
            oil_images = images[oil_spill_indices]
            oil_masks = self.segmenter.predict(oil_images, verbose=0)
            oil_masks_binary = (oil_masks.squeeze() > segmenter_threshold).astype(np.uint8)

            for i, idx in enumerate(oil_spill_indices):
                segmentations[idx] = oil_masks_binary[i]

        return classifications, segmentations, class_probs

    def visualize_results(self, images, true_classifications=None, true_masks=None,
                         num_samples=5, save_path=None):
        """
        Visualize pipeline results with optional ground truth comparison
        """
        classifications, segmentations, confidences = self.predict_pipeline(images)

        # Select random samples
        if len(images) > num_samples:
            indices = np.random.choice(len(images), num_samples, replace=False)
        else:
            indices = range(len(images))

        fig, axes = plt.subplots(len(indices), 4, figsize=(16, 4 * len(indices)))
        if len(indices) == 1:
            axes = axes.reshape(1, -1)

        for i, idx in enumerate(indices):
            # Original image
            axes[i, 0].imshow(images[idx])
            axes[i, 0].set_title('Original Image')
            axes[i, 0].axis('off')

            # Classification result
            class_text = f"Oil Spill: {classifications[idx]} (Conf: {confidences[idx]:.3f})"
            axes[i, 1].imshow(images[idx])
            axes[i, 1].set_title(class_text, color='green' if classifications[idx] else 'red')
            axes[i, 1].axis('off')

            # Segmentation result (if applicable)
            if classifications[idx] == 1:
                axes[i, 2].imshow(segmentations[idx], cmap='jet')
                axes[i, 2].set_title('Predicted Oil Spill')
            else:
                axes[i, 2].imshow(np.zeros_like(segmentations[idx]), cmap='gray')
                axes[i, 2].set_title('No Oil Spill Predicted')
            axes[i, 2].axis('off')

            # Ground truth (if available)
            if true_classifications is not None and true_masks is not None:
                gt_class = true_classifications[idx]
                axes[i, 3].imshow(images[idx])

                if gt_class == 1:
                    # Show ground truth mask
                    axes[i, 3].imshow(true_masks[idx].squeeze(), alpha=0.5, cmap='jet')
                    title_text = f"Ground Truth: Oil Spill"
                else:
                    title_text = f"Ground Truth: No Oil Spill"

                axes[i, 3].set_title(title_text)
                axes[i, 3].axis('off')
            else:
                axes[i, 3].axis('off')

        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, dpi=150, bbox_inches='tight')
        plt.show()

    def save_models(self, classifier_path='oil_spill_classifier.h5',
                   segmenter_path='oil_spill_segmenter.h5'):
        """Save both models"""
        if self.classifier is not None:
            self.classifier.save(classifier_path)
            print(f"Classifier saved to {classifier_path}")

        if self.segmenter is not None:
            self.segmenter.save(segmenter_path, custom_objects={
                'dice_coefficient': self.dice_coefficient,
                'dice_loss': self.dice_loss,
                'bce_dice_loss': self.bce_dice_loss
            })
            print(f"Segmenter saved to {segmenter_path}")

    def load_models(self, classifier_path='oil_spill_classifier.h5',
                   segmenter_path='oil_spill_segmenter.h5'):
        """Load both models from disk, replacing ``self.classifier`` and ``self.segmenter``.

        The segmenter was compiled with this class's custom loss and metric, so
        those functions are registered with ``load_model`` via ``custom_objects``
        under the names it looks up.
        """
        self.classifier = load_model(classifier_path)
        print(f"Classifier loaded from {classifier_path}")

        segmenter_custom_objects = {
            'dice_coefficient': self.dice_coefficient,
            'dice_loss': self.dice_loss,
            'bce_dice_loss': self.bce_dice_loss,
        }
        self.segmenter = load_model(segmenter_path, custom_objects=segmenter_custom_objects)
        print(f"Segmenter loaded from {segmenter_path}")