# Oil Spill Segmentation - Training & Evaluation

**Train and evaluate U-Net for oil spill segmentation, with early stopping based on validation IoU/F1.**

- Training stops when validation mIoU stops improving for several epochs (patience).
- All key metrics shown (IoU, Dice, F1, accuracy, precision, recall).
- False alarm (FP) rates: aim for F1 > 0.7 and IoU > 0.6, false alarms 5-10% acceptable.
- Continue training >20 epochs if metrics keep improving.
- Save best model by validation mIoU.

Run this after dataset and model notebooks.

In [13]:
# Imports & Setup
import os, random, time
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from tqdm import tqdm
import json
import gc
import cv2
import albumentations as A
from albumentations.pytorch import ToTensorV2
from PIL import Image
# Load config from previous notebook
with open('config/model_config.json', 'r') as f:
    cfg = json.load(f)

# Make results dir
os.makedirs('results', exist_ok=True)

# Device
device = torch.device(cfg.get('device','cuda' if torch.cuda.is_available() else 'cpu'))
print('Using device:', device)
gc.collect()
if torch.cuda.is_available(): torch.cuda.empty_cache()

Using device: cpu


In [14]:
# Set seed for reproducibility
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
set_seed(42)

In [15]:
# Define dataset class (should match your previous dataset pipeline exactly)
from torch.utils.data import Dataset
class OilSpillDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None, mask_transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.mask_transform = mask_transform
        self.images = []
        if os.path.exists(image_dir):
            self.images = [f for f in os.listdir(image_dir) if f.lower().endswith(('.jpg','.jpeg','.png'))]
            self.images.sort()
    def __len__(self):
        return len(self.images)
    def __getitem__(self, idx):
        img_name = self.images[idx]
        img_path = os.path.join(self.image_dir, img_name)
        mask_path = os.path.join(self.mask_dir, img_name)
        image = Image.open(img_path).convert('RGB')
        mask = Image.open(mask_path).convert('L') if os.path.exists(mask_path) else Image.new('L', image.size, 0)
        if self.transform:
            image = self.transform(image)
        if self.mask_transform:
            mask = self.mask_transform(mask)
        mask = (mask > 0.5).float()
        return image, mask.squeeze(0).long(), img_name

In [16]:
# Transforms
from torchvision import transforms
def get_transforms(is_training=True):
    if is_training:
        return transforms.Compose([
            transforms.Resize((cfg['image_size'], cfg['image_size'])),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
        ])
    else:
        return transforms.Compose([
            transforms.Resize((cfg['image_size'], cfg['image_size'])),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
        ])
def get_mask_transforms():
    return transforms.Compose([
        transforms.Resize((cfg['image_size'], cfg['image_size'])),
        transforms.ToTensor()
    ])

In [17]:
# DataLoader setup
train_dataset = OilSpillDataset(
    image_dir=cfg['data_paths']['train_images'],
    mask_dir=cfg['data_paths']['train_masks'],
    transform=get_transforms(is_training=True),
    mask_transform=get_mask_transforms()
)
val_dataset = OilSpillDataset(
    image_dir=cfg['data_paths']['val_images'],
    mask_dir=cfg['data_paths']['val_masks'],
    transform=get_transforms(is_training=False),
    mask_transform=get_mask_transforms()
)
train_loader = DataLoader(train_dataset, batch_size=cfg['batch_size'], shuffle=True, num_workers=cfg['num_workers'], pin_memory=cfg['pin_memory'])
val_loader = DataLoader(val_dataset, batch_size=cfg['batch_size'], shuffle=False, num_workers=cfg['num_workers'], pin_memory=cfg['pin_memory'])
print(f"Train samples: {len(train_dataset)}  Val samples: {len(val_dataset)}")

Train samples: 811  Val samples: 203


In [18]:
# Load U-Net model (from previous notebook)
import torch.nn.functional as F
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
    def forward(self, x):
        return self.conv(x)
class Down(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )
    def forward(self, x):
        return self.maxpool_conv(x)
class Up(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, 2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)
    def forward(self, x1, x2):
        x1 = self.up(x1)
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)
class UNet(nn.Module):
    def __init__(self, n_channels=3, n_classes=1):
        super().__init__()
        self.inc = DoubleConv(n_channels,64)
        self.down1 = Down(64,128)
        self.down2 = Down(128,256)
        self.down3 = Down(256,512)
        self.down4 = Down(512,1024)
        self.up1 = Up(1024,512)
        self.up2 = Up(512,256)
        self.up3 = Up(256,128)
        self.up4 = Up(128,64)
        self.outc = nn.Conv2d(64, n_classes, 1)
    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits
# Instantiate model
model = UNet(n_channels=3, n_classes=1).to(device)

In [19]:
# Define binary loss: BCE + Dice
class DiceLoss(nn.Module):
    def __init__(self, smooth=1e-6):
        super().__init__()
        self.smooth = smooth
    def forward(self, preds, targets):
        preds = torch.sigmoid(preds)
        preds = preds.view(-1)
        targets = targets.view(-1)
        intersection = (preds * targets).sum()
        dice = (2.*intersection + self.smooth) / (preds.sum() + targets.sum() + self.smooth)
        return 1 - dice
class CombinedLoss(nn.Module):
    def __init__(self, bce_weight=0.5, dice_weight=0.5):
        super().__init__()
        self.bce = nn.BCEWithLogitsLoss()
        self.dice = DiceLoss()
        self.bce_weight = bce_weight
        self.dice_weight = dice_weight
    def forward(self, preds, targets):
        return self.bce_weight*self.bce(preds, targets.float()) + self.dice_weight*self.dice(preds, targets.float())
criterion = CombinedLoss()

In [20]:
# Optimizer & Scheduler
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', patience=3, factor=0.5)
scaler = torch.cuda.amp.GradScaler(enabled=(device.type=='cuda'))

  scaler = torch.cuda.amp.GradScaler(enabled=(device.type=='cuda'))


In [21]:
# Metrics function (IoU, Dice, F1, accuracy, precision, recall)
def calc_metrics(preds, targets):
    preds = preds.cpu().numpy().astype(np.uint8)
    targets = targets.cpu().numpy().astype(np.uint8)
    acc = (preds == targets).mean()
    tp = np.logical_and(preds==1, targets==1).sum()
    fp = np.logical_and(preds==1, targets==0).sum()
    fn = np.logical_and(preds==0, targets==1).sum()
    precision = tp/(tp+fp+1e-8)
    recall = tp/(tp+fn+1e-8)
    f1 = 2*precision*recall/(precision+recall+1e-8)
    inter = np.logical_and(preds==1, targets==1).sum()
    union = np.logical_or(preds==1, targets==1).sum()
    iou = inter/union if union>0 else 1.0
    dice = (2*inter)/(preds.sum()+targets.sum()+1e-8) if (preds.sum()+targets.sum())>0 else 1.0
    return {
        'acc': float(acc), 'iou': float(iou), 'dice': float(dice),
        'precision': float(precision), 'recall': float(recall), 'f1': float(f1)
    }

In [22]:
# Early stopping by validation IoU
class EarlyStopper:
    def __init__(self, patience=6):
        self.patience = patience
        self.counter = 0
        self.best_score = None
    def step(self, score):
        if self.best_score is None or score > self.best_score:
            self.best_score = score
            self.counter = 0
            return True  # Improved
        else:
            self.counter += 1
            return False
    def should_stop(self):
        return self.counter >= self.patience

In [23]:
# Training and validation loop
def run_epoch(loader, model, criterion, optimizer=None, scaler=None, train=False):
    if train:
        model.train()
    else:
        model.eval()
    tot_loss, tot_metrics, n = 0.0, {'acc':0,'iou':0,'dice':0,'precision':0,'recall':0,'f1':0}, 0
    pbar = tqdm(loader, desc='Train' if train else 'Val', leave=False)
    for img, mask, _ in pbar:
        img, mask = img.to(device), mask.to(device)
        if train:
            optimizer.zero_grad(set_to_none=True)
            with torch.cuda.amp.autocast(enabled=(device.type=='cuda')):
                out = model(img)
                loss = criterion(out, mask.unsqueeze(1))
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            with torch.no_grad():
                out = model(img)
                loss = criterion(out, mask.unsqueeze(1))
        with torch.no_grad():
            prob = torch.sigmoid(out)
            pred = (prob > 0.5).long().squeeze(1)
            metrics = calc_metrics(pred, mask)
        tot_loss += loss.item()
        for k in tot_metrics: tot_metrics[k] += metrics[k]
        n += 1
    avg_loss = tot_loss / n
    avg_metrics = {k: tot_metrics[k]/n for k in tot_metrics}
    return avg_loss, avg_metrics

In [None]:
# Train loop with early stopping on val IoU
n_epochs = 40  # or higher: stop early if IoU/Dice plateau
early_stopper = EarlyStopper(patience=6)
best_iou = 0.0
history = {'train_loss': [], 'val_loss': [], 'train_iou': [], 'val_iou': [], 'val_f1': [], 'lr': []}
model_save_path = 'results/best_unet_oilspill.pth'

for epoch in range(1, n_epochs+1):
    print(f"\nEpoch {epoch}/{n_epochs}")
    train_loss, train_metrics = run_epoch(train_loader, model, criterion, optimizer, scaler, train=True)
    val_loss, val_metrics = run_epoch(val_loader, model, criterion, train=False)
    scheduler.step(val_metrics['iou'])

    print(f"Train Loss: {train_loss:.4f}  mIoU: {train_metrics['iou']:.4f}  mDice: {train_metrics['dice']:.4f}  F1: {train_metrics['f1']:.4f}")
    print(f"Val   Loss: {val_loss:.4f}  mIoU: {val_metrics['iou']:.4f}  mDice: {val_metrics['dice']:.4f}  F1: {val_metrics['f1']:.4f}")
    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['train_iou'].append(train_metrics['iou'])
    history['val_iou'].append(val_metrics['iou'])
    history['val_f1'].append(val_metrics['f1'])
    history['lr'].append(optimizer.param_groups[0]['lr'])

    improved = early_stopper.step(val_metrics['iou'])
    if improved:
        best_iou = val_metrics['iou']
        torch.save(model.state_dict(), model_save_path)
        print(f"✅ Saved new best model (val mIoU={best_iou:.4f})")
    if early_stopper.should_stop():
        print("⏹️ Early stopping: validation IoU did not improve for several epochs.")
        break

print(f"\nBest validation mIoU: {best_iou:.4f}")


Epoch 1/40


  with torch.cuda.amp.autocast(enabled=(device.type=='cuda')):
                                                        

Train Loss: 0.4424  mIoU: 0.6369  mDice: 0.7656  F1: 0.7656
Val   Loss: 0.3647  mIoU: 0.7067  mDice: 0.7959  F1: 0.7959
✅ Saved new best model (val mIoU=0.7067)

Epoch 2/40


                                                        

Train Loss: 0.4182  mIoU: 0.6661  mDice: 0.7830  F1: 0.7830
Val   Loss: 0.3421  mIoU: 0.7291  mDice: 0.8146  F1: 0.8146
✅ Saved new best model (val mIoU=0.7291)

Epoch 3/40


                                                        

Train Loss: 0.3802  mIoU: 0.6997  mDice: 0.8092  F1: 0.8092
Val   Loss: 0.3634  mIoU: 0.7201  mDice: 0.8017  F1: 0.8017

Epoch 4/40


                                                        

Train Loss: 0.3786  mIoU: 0.7000  mDice: 0.8089  F1: 0.8089
Val   Loss: 0.3545  mIoU: 0.7308  mDice: 0.8065  F1: 0.8065
✅ Saved new best model (val mIoU=0.7308)

Epoch 5/40


                                                            

Train Loss: 0.3593  mIoU: 0.7126  mDice: 0.8193  F1: 0.8193
Val   Loss: 0.3273  mIoU: 0.7307  mDice: 0.8124  F1: 0.8124

Epoch 6/40


                                                        

Train Loss: 0.3552  mIoU: 0.7216  mDice: 0.8229  F1: 0.8229
Val   Loss: 0.3010  mIoU: 0.7398  mDice: 0.8142  F1: 0.8142
✅ Saved new best model (val mIoU=0.7398)

Epoch 7/40


                                                        

Train Loss: 0.3413  mIoU: 0.7315  mDice: 0.8304  F1: 0.8304
Val   Loss: 0.2974  mIoU: 0.7431  mDice: 0.8235  F1: 0.8235
✅ Saved new best model (val mIoU=0.7431)

Epoch 8/40


                                                        

Train Loss: 0.3284  mIoU: 0.7403  mDice: 0.8375  F1: 0.8375
Val   Loss: 0.2802  mIoU: 0.7576  mDice: 0.8287  F1: 0.8287
✅ Saved new best model (val mIoU=0.7576)

Epoch 9/40


                                                        

Train Loss: 0.3290  mIoU: 0.7354  mDice: 0.8336  F1: 0.8336
Val   Loss: 0.3030  mIoU: 0.7502  mDice: 0.8216  F1: 0.8216

Epoch 10/40


                                                        

Train Loss: 0.3125  mIoU: 0.7510  mDice: 0.8430  F1: 0.8430
Val   Loss: 0.2835  mIoU: 0.7483  mDice: 0.8243  F1: 0.8243

Epoch 11/40


                                                        

Train Loss: 0.3025  mIoU: 0.7602  mDice: 0.8526  F1: 0.8526
Val   Loss: 0.2932  mIoU: 0.7527  mDice: 0.8195  F1: 0.8195

Epoch 12/40


                                                        

Train Loss: 0.2948  mIoU: 0.7655  mDice: 0.8553  F1: 0.8553
Val   Loss: 0.2578  mIoU: 0.7749  mDice: 0.8404  F1: 0.8404
✅ Saved new best model (val mIoU=0.7749)

Epoch 13/40


                                                        

Train Loss: 0.2880  mIoU: 0.7695  mDice: 0.8573  F1: 0.8573
Val   Loss: 0.2754  mIoU: 0.7648  mDice: 0.8341  F1: 0.8341

Epoch 14/40


                                                        

Train Loss: 0.2841  mIoU: 0.7742  mDice: 0.8608  F1: 0.8608
Val   Loss: 0.2814  mIoU: 0.7562  mDice: 0.8250  F1: 0.8250

Epoch 15/40


                                                        

Train Loss: 0.2753  mIoU: 0.7836  mDice: 0.8672  F1: 0.8672
Val   Loss: 0.2777  mIoU: 0.7652  mDice: 0.8290  F1: 0.8290

Epoch 16/40


                                                        

Train Loss: 0.2727  mIoU: 0.7845  mDice: 0.8666  F1: 0.8666
Val   Loss: 0.2829  mIoU: 0.7614  mDice: 0.8270  F1: 0.8270

Epoch 17/40


                                                        

Train Loss: 0.2364  mIoU: 0.8210  mDice: 0.8936  F1: 0.8936
Val   Loss: 0.2288  mIoU: 0.7891  mDice: 0.8505  F1: 0.8505
✅ Saved new best model (val mIoU=0.7891)

Epoch 18/40


                                                        

Train Loss: 0.2342  mIoU: 0.8174  mDice: 0.8885  F1: 0.8885
Val   Loss: 0.2556  mIoU: 0.7737  mDice: 0.8349  F1: 0.8349

Epoch 19/40


                                                        

Train Loss: 0.2124  mIoU: 0.8394  mDice: 0.9057  F1: 0.9057
Val   Loss: 0.2633  mIoU: 0.7840  mDice: 0.8421  F1: 0.8421

Epoch 20/40


                                                        

Train Loss: 0.2093  mIoU: 0.8415  mDice: 0.9057  F1: 0.9057
Val   Loss: 0.2199  mIoU: 0.8009  mDice: 0.8596  F1: 0.8596
✅ Saved new best model (val mIoU=0.8009)

Epoch 21/40


                                                        

Train Loss: 0.2283  mIoU: 0.8225  mDice: 0.8903  F1: 0.8903
Val   Loss: 0.2409  mIoU: 0.7866  mDice: 0.8437  F1: 0.8437

Epoch 22/40


                                                        

Train Loss: 0.2125  mIoU: 0.8359  mDice: 0.9032  F1: 0.9032
Val   Loss: 0.2332  mIoU: 0.8024  mDice: 0.8597  F1: 0.8597
✅ Saved new best model (val mIoU=0.8024)

Epoch 23/40


                                                        

Train Loss: 0.2013  mIoU: 0.8462  mDice: 0.9101  F1: 0.9101
Val   Loss: 0.2576  mIoU: 0.7813  mDice: 0.8385  F1: 0.8385

Epoch 24/40


                                                        

Train Loss: 0.2027  mIoU: 0.8460  mDice: 0.9082  F1: 0.9082
Val   Loss: 0.2111  mIoU: 0.7990  mDice: 0.8566  F1: 0.8566

Epoch 25/40


                                                        

Train Loss: 0.1970  mIoU: 0.8470  mDice: 0.9088  F1: 0.9088
Val   Loss: 0.2264  mIoU: 0.7916  mDice: 0.8488  F1: 0.8488

Epoch 26/40


                                                        

Train Loss: 0.1886  mIoU: 0.8556  mDice: 0.9143  F1: 0.9143
Val   Loss: 0.2322  mIoU: 0.7758  mDice: 0.8345  F1: 0.8345

Epoch 27/40


                                                        

Train Loss: 0.1779  mIoU: 0.8671  mDice: 0.9214  F1: 0.9214
Val   Loss: 0.2272  mIoU: 0.7892  mDice: 0.8479  F1: 0.8479

Epoch 28/40


                                                        

Train Loss: 0.1699  mIoU: 0.8708  mDice: 0.9255  F1: 0.9255
Val   Loss: 0.2156  mIoU: 0.7945  mDice: 0.8516  F1: 0.8516
⏹️ Early stopping: validation IoU did not improve for several epochs.

Best validation mIoU: 0.8024




: 

In [None]:
# Plot IoU, F1, and loss curves
plt.figure(figsize=(12,4))
plt.subplot(1,3,1)
plt.plot(history['train_loss'], label='Train')
plt.plot(history['val_loss'], label='Val')
plt.title('Loss')
plt.legend()
plt.subplot(1,3,2)
plt.plot(history['train_iou'], label='Train IoU')
plt.plot(history['val_iou'], label='Val IoU')
plt.title('Mean IoU')
plt.legend()
plt.subplot(1,3,3)
plt.plot(history['val_f1'], label='Val F1')
plt.title('Val F1')
plt.legend()
plt.tight_layout()
plt.show()

## Load Best Model and Evaluate on Validation/Test Sets

In [None]:
# Load best model
model.load_state_dict(torch.load(model_save_path, map_location=device))
model.eval()

In [None]:
# Evaluate on validation set with all metrics
def evaluate_model(model, loader):
    model.eval()
    tot_metrics = {'acc':0,'iou':0,'dice':0,'precision':0,'recall':0,'f1':0}
    n = 0
    with torch.no_grad():
        for img, mask, _ in tqdm(loader, desc='Eval'):
            img, mask = img.to(device), mask.to(device)
            out = model(img)
            prob = torch.sigmoid(out)
            pred = (prob > 0.5).long().squeeze(1)
            m = calc_metrics(pred, mask)
            for k in tot_metrics: tot_metrics[k] += m[k]
            n += 1
    return {k: tot_metrics[k]/n for k in tot_metrics}

val_metrics = evaluate_model(model, val_loader)
print("Validation metrics:")
for k,v in val_metrics.items(): print(f"{k}: {v:.4f}")

## Visualize Predictions

In [None]:
def visualize_predictions(model, loader, num_samples=3):
    model.eval()
    img_batch, mask_batch, names = next(iter(loader))
    img_batch, mask_batch = img_batch.to(device), mask_batch.to(device)
    with torch.no_grad():
        out = model(img_batch)
        prob = torch.sigmoid(out)
        pred = (prob > 0.5).long().squeeze(1)
    for i in range(min(num_samples, img_batch.size(0))):
        img = img_batch[i].cpu().permute(1,2,0).numpy()
        img = (img - img.min()) / (img.max()-img.min() + 1e-8)
        mask = mask_batch[i].cpu().numpy()
        pr = pred[i].cpu().numpy()
        plt.figure(figsize=(12,3))
        plt.subplot(1,3,1); plt.imshow(img); plt.title('Image'); plt.axis('off')
        plt.subplot(1,3,2); plt.imshow(mask, cmap='gray'); plt.title('GT Mask'); plt.axis('off')
        plt.subplot(1,3,3); plt.imshow(pr, cmap='gray'); plt.title('Pred Mask'); plt.axis('off')
        plt.tight_layout(); plt.show()
visualize_predictions(model, val_loader, num_samples=3)

## Notes
- **Stop training if validation IoU or F1 stops improving for >6 epochs (early stopping).**
- **IoU > 0.6 and F1 > 0.7** are good targets, but always check qualitative results for false alarms.
- If validation metrics still improve, continue training >20 epochs (don't stop arbitrarily at 20).
- Acceptable false alarm (FP) rate: **5-10%** (as per mentor advice).