In [4]:
import tensorflow as tf
import numpy as np
import cv2
import os
import glob
import matplotlib.pyplot as plt
from datetime import datetime

# Enable mixed precision
if tf.config.list_physical_devices('GPU'):
    tf.keras.mixed_precision.set_global_policy('mixed_float16')
    print("Mixed precision enabled")
else:
    print("Using standard float32")

# Configuration
IMG_SIZE = 512
INPUT_SHAPE = (IMG_SIZE, IMG_SIZE, 2)
BATCH_SIZE = 4
AUTOTUNE = tf.data.AUTOTUNE
EPOCHS_STAGE1 = 40  # Reduced for faster convergence
EPOCHS_STAGE2 = 60  # Reduced
TOTAL_EPOCHS = EPOCHS_STAGE1 + EPOCHS_STAGE2
LR_STAGE1 = 1e-4
LR_STAGE2 = 1e-5

# ====================================================================
# ENHANCED DATA PROCESSING WITH AUGMENTATION
# ====================================================================

def preprocess_image_for_tfrecord(img_path, mask_path):
    try:
        # Read image
        image = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
        if image is None:
            raise ValueError(f"Failed to read image: {img_path}")
        
        # Handle image channels
        if image.ndim == 2:
            image = np.stack([image, image], axis=-1)
        elif image.shape[-1] > 2:
            image = image[:, :, :2]  # Use first 2 channels
        elif image.shape[-1] == 1:
            image = np.concatenate([image, image], axis=-1)

        # Convert to float32 and normalize (safely)
        image = image.astype(np.float32)
        max_val = np.max(image)
        if max_val > 0:
            image /= max_val

        # Read mask - FIXED THRESHOLDING
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise ValueError(f"Failed to read mask: {mask_path}")
        
        # CRITICAL FIX: Proper threshold for 8-bit masks (0-255)
        _, mask = cv2.threshold(mask, 128, 1, cv2.THRESH_BINARY)

        # Resize
        image = cv2.resize(image, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_LINEAR)
        mask = cv2.resize(mask, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_NEAREST)

        return image, np.expand_dims(mask, axis=-1)

    except Exception as e:
        print(f"Error processing {img_path}: {str(e)}")
        raise

def create_tfrecord_dataset(image_dir, mask_dir, output_path):
    """Create TFRecords with enhanced numeric filename matching"""
    example_count = 0
    
    # Check if TFRecord exists
    if tf.io.gfile.exists(output_path):
        dataset = tf.data.TFRecordDataset(output_path)
        example_count = sum(1 for _ in dataset)
        if example_count > 0:
            print(f"TFRecord found with {example_count} examples: {output_path}")
            return example_count
    
    print(f"Creating new TFRecord: {output_path}")
    
    # Get all files with numeric sorting
    image_paths = sorted(glob.glob(os.path.join(image_dir, "*.tif*")), 
                        key=lambda x: int(''.join(filter(str.isdigit, os.path.basename(x))) or '0'))
    mask_paths = sorted(glob.glob(os.path.join(mask_dir, "*.tif*")), 
                       key=lambda x: int(''.join(filter(str.isdigit, os.path.basename(x))) or '0'))
    
    # Extract numeric IDs
    def extract_id(path):
        filename = os.path.basename(path)
        digits = ''.join(filter(str.isdigit, filename))
        return int(digits) if digits else -1
    
    image_ids = {extract_id(p): p for p in image_paths}
    mask_ids = {extract_id(p): p for p in mask_paths}
    
    # Find common numeric IDs
    common_ids = set(image_ids.keys()) & set(mask_ids.keys())
    
    if not common_ids:
        raise ValueError("No matching image-mask pairs found.")
    
    print(f"Found {len(common_ids)} valid image-mask pairs")
    
    # Sort numeric IDs
    sorted_ids = sorted(common_ids)
    
    with tf.io.TFRecordWriter(output_path) as writer:
        for id_val in sorted_ids:
            img_path = image_ids[id_val]
            mask_path = mask_ids[id_val]
            
            try:
                image, mask = preprocess_image_for_tfrecord(img_path, mask_path)
                
                feature = {
                    'image': tf.train.Feature(
                        bytes_list=tf.train.BytesList(value=[tf.io.serialize_tensor(image).numpy()])
                    ),
                    'mask': tf.train.Feature(
                        bytes_list=tf.train.BytesList(value=[tf.io.serialize_tensor(mask).numpy()])
                    )
                }
                example = tf.train.Example(features=tf.train.Features(feature=feature))
                writer.write(example.SerializeToString())
                example_count += 1
            except Exception as e:
                print(f"Skipping ID {id_val} ({img_path}) due to error: {str(e)}")
    
    print(f"Created TFRecord with {example_count} examples")
    return example_count

def parse_tfrecord(example, augment=False):
    feature_description = {
        'image': tf.io.FixedLenFeature([], tf.string),
        'mask': tf.io.FixedLenFeature([], tf.string)
    }
    example = tf.io.parse_single_example(example, feature_description)
    image = tf.io.parse_tensor(example['image'], out_type=tf.float32)
    
    # Convert mask to float32
    mask = tf.io.parse_tensor(example['mask'], out_type=tf.uint8)
    mask = tf.cast(mask, tf.float32)
    
    image.set_shape(INPUT_SHAPE)
    mask.set_shape((IMG_SIZE, IMG_SIZE, 1))
    
    # Apply augmentations only during training
    if augment:
        # Random rotations (0, 90, 180, 270 degrees)
        k = tf.random.uniform([], maxval=4, dtype=tf.int32)
        image = tf.image.rot90(image, k)
        mask = tf.image.rot90(mask, k)
        
        # Random flips
        if tf.random.uniform([]) > 0.5:
            image = tf.image.flip_left_right(image)
            mask = tf.image.flip_left_right(mask)
        if tf.random.uniform([]) > 0.5:
            image = tf.image.flip_up_down(image)
            mask = tf.image.flip_up_down(mask)
        
        # Random brightness/contrast (mild)
        image = tf.image.random_brightness(image, max_delta=0.1)
        image = tf.image.random_contrast(image, lower=0.9, upper=1.1)
        
        # Add Gaussian noise (SAR-specific)
        noise = tf.random.normal(tf.shape(image), mean=0.0, stddev=0.03)
        image = tf.add(image, noise)
    
    return image, mask

# ====================================================================
# ENHANCED MODEL ARCHITECTURE WITH IMPROVED ATTENTION
# ====================================================================

def build_attention_resunet():
    inputs = tf.keras.Input(shape=INPUT_SHAPE, name="input_layer")
    
    # Channel adapter
    x = tf.keras.layers.Conv2D(
        3, (1, 1), 
        padding="same", 
        name="channel_adapter",
        kernel_initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.01)
    )(inputs)
    
    # Create ResNet backbone
    resnet_base = tf.keras.applications.ResNet50V2(
        include_top=False,
        weights="imagenet",
        input_tensor=x,
        input_shape=(IMG_SIZE, IMG_SIZE, 3)
    )
    
    # Freeze initial layers
    for layer in resnet_base.layers[:150]:
        layer.trainable = False
    
    # Get feature maps
    s1 = resnet_base.get_layer("conv1_conv").output
    s2 = resnet_base.get_layer("conv2_block3_1_relu").output
    s3 = resnet_base.get_layer("conv3_block4_1_relu").output
    s4 = resnet_base.get_layer("conv4_block6_1_relu").output
    bridge = resnet_base.get_layer("conv5_block3_1_relu").output

    # Improved attention gate with residual connection
    def attention_gate(g, x, inter_channel):
        g1 = tf.keras.layers.Conv2D(inter_channel, (1, 1), padding='same')(g)
        g1 = tf.keras.layers.BatchNormalization()(g1)
        g1 = tf.keras.layers.Activation('relu')(g1)
        
        x1 = tf.keras.layers.Conv2D(inter_channel, (1, 1), padding='same')(x)
        x1 = tf.keras.layers.BatchNormalization()(x1)
        
        psi = tf.keras.layers.Add()([g1, x1])
        psi = tf.keras.layers.Activation('relu')(psi)
        psi = tf.keras.layers.Conv2D(1, (1, 1), padding='same', activation='sigmoid')(psi)
        
        # Residual connection
        attented = tf.keras.layers.Multiply()([x, psi])
        return tf.keras.layers.Add()([x, attented])  # Residual attention

    # Enhanced decoder block with dropout
    def decoder_block(input_tensor, skip_connection, filters):
        u = tf.keras.layers.Conv2DTranspose(
            filters, (2, 2), strides=2, padding='same')(input_tensor)
        
        if skip_connection is not None:
            # Apply attention
            attn = attention_gate(u, skip_connection, filters)
            u = tf.keras.layers.Concatenate()([u, attn])
        
        # Add spatial dropout for regularization
        u = tf.keras.layers.SpatialDropout2D(0.2)(u)
        
        # Residual block
        conv1 = tf.keras.layers.Conv2D(filters, (3, 3), padding='same')(u)
        conv1 = tf.keras.layers.BatchNormalization()(conv1)
        conv1 = tf.keras.layers.Activation('relu')(conv1)
        
        conv2 = tf.keras.layers.Conv2D(filters, (3, 3), padding='same')(conv1)
        conv2 = tf.keras.layers.BatchNormalization()(conv2)
        
        # Skip connection
        if u.shape[-1] == filters:
            shortcut = u
        else:
            shortcut = tf.keras.layers.Conv2D(filters, (1, 1), padding='same')(u)
        
        res = tf.keras.layers.Add()([conv2, shortcut])
        res = tf.keras.layers.Activation('relu')(res)
        return res

    # Bridge processing with dilation
    b = tf.keras.layers.Conv2D(1024, (3, 3), padding='same', dilation_rate=2)(bridge)
    b = tf.keras.layers.BatchNormalization()(b)
    b = tf.keras.layers.Activation('relu')(b)
    
    # Decoder path
    d1 = decoder_block(b, s4, 512)
    d2 = decoder_block(d1, s3, 256)
    d3 = decoder_block(d2, s2, 128)
    d4 = decoder_block(d3, s1, 64)
    
    # Final upsampling
    u_final = tf.keras.layers.Conv2DTranspose(32, (2, 2), strides=2, padding='same')(d4)
    u_final = tf.keras.layers.Conv2D(32, (3, 3), padding='same', activation='relu')(u_final)
    
    # Output layer
    outputs = tf.keras.layers.Conv2D(
        1, (1, 1), activation='sigmoid', dtype=tf.float32, name="output"
    )(u_final)

    return tf.keras.Model(inputs=inputs, outputs=outputs, name="EnhancedAttentionResUNet")

# ====================================================================
# ENHANCED LOSS & METRICS
# ====================================================================

def focal_loss(gamma=3.0, alpha=0.8):  # More focus on hard examples
    def loss(y_true, y_pred):
        y_pred = tf.clip_by_value(y_pred, 1e-7, 1 - 1e-7)
        ce = -y_true * tf.math.log(y_pred) - (1 - y_true) * tf.math.log(1 - y_pred)
        fl = alpha * tf.pow(1 - y_pred, gamma) * ce
        return tf.reduce_mean(fl)
    return loss

def dice_loss(y_true, y_pred):
    numerator = 2 * tf.reduce_sum(y_true * y_pred)
    denominator = tf.reduce_sum(y_true + y_pred)
    return 1 - (numerator + 1e-7) / (denominator + 1e-7)

def hybrid_loss(y_true, y_pred):
    focal = focal_loss(gamma=3.0, alpha=0.8)(y_true, y_pred)
    dice = dice_loss(y_true, y_pred)
    return focal + dice  # Combine both losses

def iou(y_true, y_pred):
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred > 0.5, tf.float32)
    intersection = tf.reduce_sum(y_true * y_pred)
    union = tf.reduce_sum(y_true) + tf.reduce_sum(y_pred) - intersection
    return (intersection + 1e-7) / (union + 1e-7)

# ====================================================================
# OPTIMIZED TRAINING PIPELINE
# ====================================================================

# Paths (update these with your actual paths)
TRAIN_IMAGES = r"C:\Users\Admin\Documents\train\images"
TRAIN_MASKS = r"C:\Users\Admin\Documents\train\mask"
VAL_IMAGES = r"C:\Users\Admin\Documents\val\images"
VAL_MASKS = r"C:\Users\Admin\Documents\val\mask"
TF_TRAIN_PATH = "/kaggle/input/sar-tfrecords/tfrecords/train.tfrecord"
TF_VAL_PATH = "/kaggle/input/sar-tfrecords/tfrecords/val.tfrecord"
OUTPUT_DIR = "/kaggle/working/"

# Create output directory
experiment_name = f"oilspill_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
output_path = os.path.join(OUTPUT_DIR, experiment_name)
os.makedirs(output_path, exist_ok=True)

# Model paths
CHECKPOINT_PATH = os.path.join(output_path, "checkpoint.weights.h5")
FINAL_MODEL_PATH = os.path.join(output_path, "oil_spill_model.keras")
LOG_FILE = os.path.join(output_path, "training_log.csv")

# Create TFRecords
print("\nCreating training TFRecord...")
train_count = create_tfrecord_dataset(TRAIN_IMAGES, TRAIN_MASKS, TF_TRAIN_PATH)
print("\nCreating validation TFRecord...")
val_count = create_tfrecord_dataset(VAL_IMAGES, VAL_MASKS, TF_VAL_PATH)

# Calculate class weights (critical for imbalance)
print("\nCalculating class weights...")
def calculate_class_weights():
    total_pixels = 0
    oil_pixels = 0
    
    dataset = tf.data.TFRecordDataset(TF_TRAIN_PATH)
    dataset = dataset.map(lambda x: parse_tfrecord(x, augment=False))
    dataset = dataset.take(100).batch(10)  # Sample 100 batches
    
    for images, masks in dataset:
        oil_pixels += tf.reduce_sum(masks)
        total_pixels += tf.size(masks)
    
    oil_weight = (total_pixels - oil_pixels) / total_pixels
    background_weight = oil_pixels / total_pixels
    
    print(f"Weights - Oil: {oil_weight:.4f}, Background: {background_weight:.4f}")
    return {0: background_weight, 1: oil_weight}

class_weights = calculate_class_weights()

# Prepare datasets
steps_per_epoch = train_count // BATCH_SIZE
validation_steps = max(1, val_count // BATCH_SIZE)

# Training dataset with augmentation
train_ds = (
    tf.data.TFRecordDataset(TF_TRAIN_PATH)
    .repeat()
    .map(lambda x: parse_tfrecord(x, augment=True), num_parallel_calls=AUTOTUNE)
    .shuffle(100)
    .batch(BATCH_SIZE)
    .prefetch(AUTOTUNE)
)

# Validation dataset without augmentation
val_ds = (
    tf.data.TFRecordDataset(TF_VAL_PATH)
    .repeat()
    .map(lambda x: parse_tfrecord(x, augment=False), num_parallel_calls=AUTOTUNE)
    .batch(BATCH_SIZE)
    .prefetch(AUTOTUNE)
)

# Build model
print("\nBuilding model...")
model = build_attention_resunet()
model.summary()

# Diagnostic check
print("\nPerforming diagnostic check...")
for images, masks in train_ds.take(1):
    print("Image range:", tf.reduce_min(images).numpy(), tf.reduce_max(images).numpy())
    print("Mask values:", tf.unique(tf.reshape(masks, [-1]))[0].numpy())
    plt.figure(figsize=(12, 6))
    plt.subplot(121)
    plt.title("SAR Image (VH)")
    plt.imshow(images[0, :, :, 1], cmap='gray')
    plt.subplot(122)
    plt.title("Mask")
    plt.imshow(masks[0, :, :, 0], cmap='jet')
    plt.savefig(os.path.join(output_path, "sample_data.png"))
    plt.close()

# Callbacks
class MemorySavingCSVLogger(tf.keras.callbacks.CSVLogger):
    def on_epoch_end(self, epoch, logs=None):
        super().on_epoch_end(epoch, logs)
        tf.keras.backend.clear_session()

checkpoint = tf.keras.callbacks.ModelCheckpoint(
    CHECKPOINT_PATH,
    save_weights_only=True,
    save_best_only=True,
    monitor='val_iou',
    mode='max',
    verbose=1
)

reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='val_iou',
    mode='max',
    factor=0.5,
    patience=8,  # Reduce LR after 8 epochs without improvement
    min_lr=1e-6,
    verbose=1
)

early_stop = tf.keras.callbacks.EarlyStopping(
    patience=30,  # Reduced patience
    restore_best_weights=True,
    monitor='val_iou',
    mode='max'
)

# Compile model with standard Adam optimizer
print("\nCompiling model...")
model.compile(
    optimizer=tf.keras.optimizers.Adam(LR_STAGE1),
    loss=hybrid_loss,  # Using combined focal + dice loss
    metrics=[
        'accuracy',
        tf.keras.metrics.Precision(name='precision'),
        tf.keras.metrics.Recall(name='recall'),
        iou
    ]
)

# Train model
print("\n" + "="*50)
print(f"STARTING TRAINING ({TOTAL_EPOCHS} EPOCHS)")
print("="*50)

history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=TOTAL_EPOCHS,
    steps_per_epoch=steps_per_epoch,
    validation_steps=validation_steps,
    class_weight=class_weights,  # Critical for imbalance
    callbacks=[
        MemorySavingCSVLogger(LOG_FILE), 
        checkpoint, 
        early_stop,
        reduce_lr
    ]
)

# Load best weights and save full model
print("\nLoading best weights and saving full model...")
model.load_weights(CHECKPOINT_PATH)
model.save(FINAL_MODEL_PATH)
print(f"Full model saved to {FINAL_MODEL_PATH}")

# Clean up checkpoint file
os.remove(CHECKPOINT_PATH)
print(f"Removed temporary checkpoint: {CHECKPOINT_PATH}")

# Plot training history
def plot_history(history, output_dir):
    plt.figure(figsize=(15, 10))
    
    # Loss
    plt.subplot(2, 2, 1)
    plt.plot(history.history['loss'], label='Train Loss')
    plt.plot(history.history['val_loss'], label='Val Loss')
    plt.title('Loss Evolution')
    plt.legend()
    
    # IoU
    plt.subplot(2, 2, 2)
    plt.plot(history.history['iou'], label='Train IoU')
    plt.plot(history.history['val_iou'], label='Val IoU')
    plt.title('IoU Evolution')
    plt.legend()
    
    # Precision-Recall
    plt.subplot(2, 2, 3)
    plt.plot(history.history['precision'], label='Train Precision')
    plt.plot(history.history['val_precision'], label='Val Precision')
    plt.plot(history.history['recall'], label='Train Recall')
    plt.plot(history.history['val_recall'], label='Val Recall')
    plt.title('Precision & Recall')
    plt.legend()
    
    # Learning Rate
    if 'lr' in history.history:
        plt.subplot(2, 2, 4)
        plt.plot(history.history['lr'], label='Learning Rate')
        plt.title('Learning Rate Schedule')
        plt.yscale('log')
    
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'training_history.png'))
    plt.close()

plot_history(history, output_path)
print("Training history plot saved")

Mixed precision enabled

Creating training TFRecord...


I0000 00:00:1754165803.084885      36 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 15513 MB memory:  -> device: 0, name: Tesla P100-PCIE-16GB, pci bus id: 0000:00:04.0, compute capability: 6.0


TFRecord found with 2066 examples: /kaggle/input/sar-tfrecords/tfrecords/train.tfrecord

Creating validation TFRecord...
TFRecord found with 504 examples: /kaggle/input/sar-tfrecords/tfrecords/val.tfrecord

Calculating class weights...


InvalidArgumentError: cannot compute Sub as input #1(zero-based) was expected to be a int32 tensor but is a float tensor [Op:Sub] name: 