In [2]:
import os
import glob
import cv2
import numpy as np
import pandas as pd
from PIL import Image
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import torch
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
import segmentation_models_pytorch as smp
import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# === Section 2: Load CoralSeg Dataset ===

BASE_PATH = "..\\benthic_data"
CORALSEG_PATH = os.path.join(BASE_PATH, "Coralseg")

splits = ["train", "val", "test"]
coralseg_data = []

for split in splits:
    img_dir = os.path.join(CORALSEG_PATH, split, "Image")
    mask_dir = os.path.join(CORALSEG_PATH, split, "Mask")

    img_files = sorted(glob.glob(os.path.join(img_dir, "*.jpg")))
    mask_files = sorted(glob.glob(os.path.join(mask_dir, "*.png")))

    # Match by filename
    for img_path in img_files:
        fname = os.path.basename(img_path).replace(".jpg", "")
        mask_path = os.path.join(mask_dir, fname + ".png")
        if os.path.exists(mask_path):
            coralseg_data.append({
                "dataset": "CoralSeg",
                "split": split,
                "image_path": img_path,
                "mask_path": mask_path
            })

coralseg_df = pd.DataFrame(coralseg_data)
print(f"✅ CoralSeg loaded: {len(coralseg_df)} total samples")
print(coralseg_df.sample(3))

✅ CoralSeg loaded: 4922 total samples
       dataset  split                                         image_path  \
3968  CoralSeg    val  ..\benthic_data\Coralseg\val\Image\FR5_6656_71...   
1982  CoralSeg  train  ..\benthic_data\Coralseg\train\Image\PAL69_972...   
1098  CoralSeg  train  ..\benthic_data\Coralseg\train\Image\PAL132_25...   

                                              mask_path  
3968  ..\benthic_data\Coralseg\val\Mask\FR5_6656_716...  
1982  ..\benthic_data\Coralseg\train\Mask\PAL69_9728...  
1098  ..\benthic_data\Coralseg\train\Mask\PAL132_256...  


In [4]:
# === Section 3: Load reef_support datasets ===

REEF_SUPPORT_PATH = os.path.join(BASE_PATH, "reef_support")

reef_data = []

# Loop through each reef site
for site in sorted(os.listdir(REEF_SUPPORT_PATH)):
    site_dir = os.path.join(REEF_SUPPORT_PATH, site)
    img_dir = os.path.join(site_dir, "images")
    stitched_dir = os.path.join(site_dir, "masks_stitched")
    masks_dir = os.path.join(site_dir, "masks")

    if not os.path.isdir(img_dir):
        continue

    print(f"📂 Processing site: {site}")

    # Prefer stitched masks (cleaner)
    stitched_masks = sorted(glob.glob(os.path.join(stitched_dir, "*.png")))
    for mask_path in stitched_masks:
        fname = os.path.basename(mask_path).replace("_mask.png", "").replace(".png", "")
        img_candidates = glob.glob(os.path.join(img_dir, f"{fname}.*"))
        if len(img_candidates) == 0:
            continue
        img_path = img_candidates[0]

        reef_data.append({
            "dataset": site,
            "split": "train",  # no official split, will randomize later
            "image_path": img_path,
            "mask_path": mask_path
        })

reef_df = pd.DataFrame(reef_data)
print(f"✅ reef_support loaded: {len(reef_df)} samples across {reef_df['dataset'].nunique()} sites")
reef_df.sample(5)


📂 Processing site: SEAFLOWER_BOLIVAR
📂 Processing site: SEAFLOWER_COURTOWN
📂 Processing site: SEAVIEW_ATL
📂 Processing site: SEAVIEW_IDN_PHL
📂 Processing site: SEAVIEW_PAC_AUS
📂 Processing site: SEAVIEW_PAC_USA
📂 Processing site: TETES_PROVIDENCIA
📂 Processing site: UNAL_BLEACHING_TAYRONA
✅ reef_support loaded: 3311 samples across 8 sites


Unnamed: 0,dataset,split,image_path,mask_path
769,SEAVIEW_ATL,train,..\benthic_data\reef_support\SEAVIEW_ATL\image...,..\benthic_data\reef_support\SEAVIEW_ATL\masks...
1661,SEAVIEW_PAC_AUS,train,..\benthic_data\reef_support\SEAVIEW_PAC_AUS\i...,..\benthic_data\reef_support\SEAVIEW_PAC_AUS\m...
3307,UNAL_BLEACHING_TAYRONA,train,..\benthic_data\reef_support\UNAL_BLEACHING_TA...,..\benthic_data\reef_support\UNAL_BLEACHING_TA...
1893,SEAVIEW_PAC_AUS,train,..\benthic_data\reef_support\SEAVIEW_PAC_AUS\i...,..\benthic_data\reef_support\SEAVIEW_PAC_AUS\m...
1600,SEAVIEW_IDN_PHL,train,..\benthic_data\reef_support\SEAVIEW_IDN_PHL\i...,..\benthic_data\reef_support\SEAVIEW_IDN_PHL\m...


In [5]:
import cv2
from tqdm import tqdm

SAVE_UNION_DIR = "../coral_project_outputs/union_masks"
os.makedirs(SAVE_UNION_DIR, exist_ok=True)

merged_data = []

def make_union_mask(mask_dir, target_name):
    """Combine all *_mask_*.png files into one binary union mask."""
    masks = glob.glob(os.path.join(mask_dir, f"{target_name}_mask_*.png"))
    if not masks:
        return None

    combined = None
    for mpath in masks:
        mask = cv2.imread(mpath, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            continue
        mask = (mask > 0).astype(np.uint8)
        combined = mask if combined is None else np.maximum(combined, mask)

    if combined is None:
        return None

    save_path = os.path.join(SAVE_UNION_DIR, f"{target_name}_union.png")
    cv2.imwrite(save_path, combined * 255)
    return save_path

# Process reef_support sites (including union masks)
for _, row in tqdm(reef_df.iterrows(), total=len(reef_df)):
    img_path = row["image_path"]
    site_dir = os.path.dirname(os.path.dirname(img_path))
    masks_dir = os.path.join(site_dir, "masks")
    fname = os.path.splitext(os.path.basename(img_path))[0]

    # Try to create or find best mask
    if os.path.exists(os.path.join(site_dir, "masks_stitched", f"{fname}_mask.png")):
        mask_path = os.path.join(site_dir, "masks_stitched", f"{fname}_mask.png")
    else:
        mask_path = make_union_mask(masks_dir, fname)

    if mask_path and os.path.exists(mask_path):
        merged_data.append({
            "dataset": row["dataset"],
            "split": "train",
            "image_path": img_path,
            "mask_path": mask_path
        })

# Add CoralSeg dataset
for _, row in tqdm(coralseg_df.iterrows(), total=len(coralseg_df)):
    merged_data.append({
        "dataset": row["dataset"],
        "split": row["split"],
        "image_path": row["image_path"],
        "mask_path": row["mask_path"]
    })

merged_df = pd.DataFrame(merged_data)

# Clean — remove empties
def valid_mask(path):
    if not os.path.exists(path):
        return False
    mask = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    return mask is not None and mask.sum() > 0

merged_df = merged_df[merged_df["mask_path"].apply(valid_mask)].reset_index(drop=True)

print(f"✅ Final merged dataset size: {len(merged_df)} samples")
print(merged_df.groupby("dataset").size())

# Save CSV metadata for reuse
csv_path = "../coral_project_outputs/merged_dataset.csv"
merged_df.to_csv(csv_path, index=False)
print(f"💾 Saved metadata to {csv_path}")


100%|██████████| 3311/3311 [00:00<00:00, 5692.17it/s]
100%|██████████| 4922/4922 [00:00<00:00, 32095.84it/s]


✅ Final merged dataset size: 3276 samples
dataset
SEAFLOWER_BOLIVAR         245
SEAFLOWER_COURTOWN        241
SEAVIEW_ATL               651
SEAVIEW_IDN_PHL           466
SEAVIEW_PAC_AUS           657
SEAVIEW_PAC_USA           276
TETES_PROVIDENCIA         105
UNAL_BLEACHING_TAYRONA    635
dtype: int64
💾 Saved metadata to ../coral_project_outputs/merged_dataset.csv


In [1]:
merged_df

NameError: name 'merged_df' is not defined

In [6]:
csv_path = "../coral_project_outputs/merged_dataset.csv"

merged_df = pd.read_csv(csv_path)
print(f"✅ Reloaded merged dataset: {len(merged_df)} samples")
print(merged_df.groupby("dataset").size())

# Optional sanity check
sample = merged_df.sample(3, random_state=42)
for _, row in sample.iterrows():
    assert os.path.exists(row["image_path"]), f"Missing image {row['image_path']}"
    assert os.path.exists(row["mask_path"]), f"Missing mask {row['mask_path']}"
print("✅ Random sample files verified")


✅ Reloaded merged dataset: 3276 samples
dataset
SEAFLOWER_BOLIVAR         245
SEAFLOWER_COURTOWN        241
SEAVIEW_ATL               651
SEAVIEW_IDN_PHL           466
SEAVIEW_PAC_AUS           657
SEAVIEW_PAC_USA           276
TETES_PROVIDENCIA         105
UNAL_BLEACHING_TAYRONA    635
dtype: int64
✅ Random sample files verified


In [9]:
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import numpy as np
from PIL import Image
import torch

# -------------------
# Split data
train_df, val_df = train_test_split(merged_df, test_size=0.2, random_state=42, stratify=None)
print(f"📊 Train: {len(train_df)} | Val: {len(val_df)}")

# -------------------
# Augmentations
train_transform = A.Compose([
    A.Resize(256, 256),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomBrightnessContrast(p=0.5),
    A.Normalize(),
    ToTensorV2(),
])

val_transform = A.Compose([
    A.Resize(256, 256),
    A.Normalize(),
    ToTensorV2(),
])

# -------------------
# Custom Dataset
class CoralDataset(Dataset):
    def __init__(self, df, transform=None):
        self.df = df.reset_index(drop=True)
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        img_path = self.df.loc[idx, "image_path"]
        mask_path = self.df.loc[idx, "mask_path"]

        image = np.array(Image.open(img_path).convert("RGB"))
        mask = np.array(Image.open(mask_path).convert("L"))  # grayscale
        mask = (mask > 0).astype(np.float32)  # binary coral/non-coral

        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented["image"]
            mask = augmented["mask"].unsqueeze(0)

        return image, mask


📊 Train: 2620 | Val: 656


In [13]:

# -------------------
# Datasets & Dataloaders
train_dataset = CoralDataset(train_df, transform=train_transform)
val_dataset = CoralDataset(val_df, transform=val_transform)

train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=2)

# # -------------------
# # Quick check
# imgs, masks = next(iter(train_loader))
# print(f"✅ Dataloader OK — batch shapes: imgs {imgs.shape}, masks {masks.shape}")


In [14]:
merged_df = pd.read_csv("../coral_project_outputs/merged_dataset.csv")


In [None]:
# === Section 6: Model setup and training (best config) ===
import torch
import torch.nn as nn
import segmentation_models_pytorch as smp
from tqdm import tqdm
import copy

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# -------------------
# Model
model = smp.Unet(
    encoder_name="efficientnet-b0",
    encoder_weights="imagenet",
    in_channels=3,
    classes=1,
    activation=None
).to(device)

# -------------------
# Loss: BCE + Dice combo
bce_loss = nn.BCEWithLogitsLoss()
dice_loss = smp.losses.DiceLoss(mode="binary")

def criterion(y_pred, y_true):
    return 0.5 * bce_loss(y_pred, y_true) + 0.5 * dice_loss(y_pred, y_true)

# -------------------
# Metrics
def iou_score(y_pred, y_true, threshold=0.5):
    y_pred_bin = (torch.sigmoid(y_pred) > threshold).float()
    intersection = (y_pred_bin * y_true).sum()
    union = y_pred_bin.sum() + y_true.sum() - intersection
    return (intersection / union).item() if union > 0 else 1.0

def pixel_accuracy(y_pred, y_true, threshold=0.5):
    y_pred_bin = (torch.sigmoid(y_pred) > threshold).float()
    correct = (y_pred_bin == y_true).float().sum()
    total = torch.numel(y_true)
    return (correct / total).item()

# -------------------
# Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# -------------------
# Training loop
def train_model(model, train_loader, val_loader, criterion, optimizer, epochs=30, patience=5, save_path="../coral_project_outputs/best_merged_model.pth"):
    best_iou = 0.0
    best_model_wts = copy.deepcopy(model.state_dict())
    patience_counter = 0

    for epoch in range(epochs):
        model.train()
        train_loss, train_iou, train_acc = 0, 0, 0

        for imgs, masks in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} - Train"):
            imgs, masks = imgs.to(device), masks.to(device)

            optimizer.zero_grad()
            outputs = model(imgs)
            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            train_iou += iou_score(outputs, masks)
            train_acc += pixel_accuracy(outputs, masks)

        # Validation
        model.eval()
        val_loss, val_iou, val_acc = 0, 0, 0
        with torch.no_grad():
            for imgs, masks in tqdm(val_loader, desc=f"Epoch {epoch+1}/{epochs} - Val"):
                imgs, masks = imgs.to(device), masks.to(device)
                outputs = model(imgs)
                loss = criterion(outputs, masks)
                val_loss += loss.item()
                val_iou += iou_score(outputs, masks)
                val_acc += pixel_accuracy(outputs, masks)

        # Averages
        train_loss /= len(train_loader)
        train_iou /= len(train_loader)
        train_acc /= len(train_loader)
        val_loss /= len(val_loader)
        val_iou /= len(val_loader)
        val_acc /= len(val_loader)

        print(f"\nEpoch {epoch+1}/{epochs}")
        print(f" Train Loss: {train_loss:.4f} | IoU: {train_iou:.4f} | Acc: {train_acc:.4f}")
        print(f" Val   Loss: {val_loss:.4f} | IoU: {val_iou:.4f} | Acc: {val_acc:.4f}")

        # Save best model
        if val_iou > best_iou:
            best_iou = val_iou
            best_model_wts = copy.deepcopy(model.state_dict())
            torch.save(model.state_dict(), save_path)
            print(f"  ✅ Saved new best model with IoU={best_iou:.4f}")
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"⏹️ Early stopping at epoch {epoch+1}")
                break

    model.load_state_dict(best_model_wts)
    print(f"Training complete. Best IoU: {best_iou:.4f}")
    return model

# -------------------
# Run training
final_model = train_model(
    model,
    train_loader,
    val_loader,
    criterion,
    optimizer,
    epochs=25,
    patience=4
)


Using device: cuda


Epoch 1/25 - Train:   0%|          | 0/655 [00:00<?, ?it/s]

In [None]:
# ===== Datasets & Dataloaders (optimized for CUDA in notebooks) =====
import os, platform, time
import torch
from torch.utils.data import DataLoader
import cv2

# Assumes: CoralDataset, train_df, val_df, train_transform, val_transform exist
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cudnn.benchmark = True
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    pass

# Prevent thread oversubscription (helps with OpenCV and albumentations)
os.environ.setdefault("OMP_NUM_THREADS", "1")
os.environ.setdefault("MKL_NUM_THREADS", "1")
try:
    cv2.setNumThreads(0)
except Exception:
    pass

# Throughput knobs (tune if needed)
NUM_WORKERS = 2
PERSISTENT = True
PREFETCH   = 2
BATCH_SIZE = 8

def make_loader(df, transform, batch_size=BATCH_SIZE, shuffle=False):
    return DataLoader(
        dataset=CoralDataset(df, transform=transform),
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=NUM_WORKERS,
        pin_memory=True,                 # benefit with CUDA when using non_blocking=True
        drop_last=False,
        persistent_workers=PERSISTENT,
        prefetch_factor=PREFETCH,
    )

train_loader = make_loader(train_df, train_transform, shuffle=True)
val_loader   = make_loader(val_df,   val_transform,   shuffle=False)

# (Optional) quick timing probe
def time_first_batch(loader, label):
    t0 = time.perf_counter()
    it = iter(loader)
    t1 = time.perf_counter()
    batch = next(it)
    t2 = time.perf_counter()
    print(f"[{label}] iter(): {(t1 - t0):.3f}s | first batch: {(t2 - t1):.3f}s")

if torch.cuda.is_available():
    print("CUDA device:", torch.cuda.get_device_name(0))
time_first_batch(train_loader, "train")
time_first_batch(val_loader,   "val")

CUDA device: NVIDIA GeForce RTX 3080
