In [1]:
import os
import glob
import cv2
import numpy as np
import pandas as pd
from PIL import Image
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import torch
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
import segmentation_models_pytorch as smp
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import segmentation_models_pytorch as smp
from tqdm import tqdm
import copy

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# === Section 2: Load CoralSeg Dataset ===

BASE_PATH = "../benthic_data"
CORALSEG_PATH = os.path.join(BASE_PATH, "Coralseg")

splits = ["train", "val", "test"]
coralseg_data = []

for split in splits:
    img_dir = os.path.join(CORALSEG_PATH, split, "Image")
    mask_dir = os.path.join(CORALSEG_PATH, split, "Mask")

    img_files = sorted(glob.glob(os.path.join(img_dir, "*.jpg")))
    mask_files = sorted(glob.glob(os.path.join(mask_dir, "*.png")))

    # Match by filename
    for img_path in img_files:
        fname = os.path.basename(img_path).replace(".jpg", "")
        mask_path = os.path.join(mask_dir, fname + ".png")
        if os.path.exists(mask_path):
            coralseg_data.append({
                "dataset": "CoralSeg",
                "split": split,
                "image_path": img_path,
                "mask_path": mask_path
            })

coralseg_df = pd.DataFrame(coralseg_data)
print(f"✅ CoralSeg loaded: {len(coralseg_df)} total samples")
print(coralseg_df.sample(3))


✅ CoralSeg loaded: 4922 total samples
       dataset  split                                         image_path  \
2784  CoralSeg  train  ../benthic_data\Coralseg\train\Image\PALWave14...   
2209  CoralSeg  train  ../benthic_data\Coralseg\train\Image\PALStrawn...   
3950  CoralSeg    val  ../benthic_data\Coralseg\val\Image\FR3_8704_92...   

                                              mask_path  
2784  ../benthic_data\Coralseg\train\Mask\PALWave14_...  
2209  ../benthic_data\Coralseg\train\Mask\PALStrawn_...  
3950  ../benthic_data\Coralseg\val\Mask\FR3_8704_921...  


In [3]:
# === Section 3: Load reef_support datasets ===

REEF_SUPPORT_PATH = os.path.join(BASE_PATH, "reef_support")
print(REEF_SUPPORT_PATH)
reef_data = []

# Loop through each reef site
for site in sorted(os.listdir(REEF_SUPPORT_PATH)):
    site_dir = os.path.join(REEF_SUPPORT_PATH, site)
    img_dir = os.path.join(site_dir, "images")
    stitched_dir = os.path.join(site_dir, "masks_stitched")
    masks_dir = os.path.join(site_dir, "masks")

    if not os.path.isdir(img_dir):
        continue

    print(f"📂 Processing site: {site}")

    # Prefer stitched masks (cleaner)
    stitched_masks = sorted(glob.glob(os.path.join(stitched_dir, "*.png")))
    for mask_path in stitched_masks:
        fname = os.path.basename(mask_path).replace("_mask.png", "").replace(".png", "")
        img_candidates = glob.glob(os.path.join(img_dir, f"{fname}.*"))
        if len(img_candidates) == 0:
            continue
        img_path = img_candidates[0]

        reef_data.append({
            "dataset": site,
            "split": "train",  # no official split, will randomize later
            "image_path": img_path,
            "mask_path": mask_path
        })

reef_df = pd.DataFrame(reef_data)
print(f"✅ reef_support loaded: {len(reef_df)} samples across {reef_df['dataset'].nunique()} sites")
reef_df.sample(5)


../benthic_data\reef_support
📂 Processing site: SEAFLOWER_BOLIVAR
📂 Processing site: SEAFLOWER_COURTOWN
📂 Processing site: SEAVIEW_ATL
📂 Processing site: SEAVIEW_IDN_PHL
📂 Processing site: SEAVIEW_PAC_AUS
📂 Processing site: SEAVIEW_PAC_USA
📂 Processing site: TETES_PROVIDENCIA
📂 Processing site: UNAL_BLEACHING_TAYRONA
✅ reef_support loaded: 3311 samples across 8 sites


Unnamed: 0,dataset,split,image_path,mask_path
2766,UNAL_BLEACHING_TAYRONA,train,../benthic_data\reef_support\UNAL_BLEACHING_TA...,../benthic_data\reef_support\UNAL_BLEACHING_TA...
677,SEAVIEW_ATL,train,../benthic_data\reef_support\SEAVIEW_ATL\image...,../benthic_data\reef_support\SEAVIEW_ATL\masks...
1954,SEAVIEW_PAC_AUS,train,../benthic_data\reef_support\SEAVIEW_PAC_AUS\i...,../benthic_data\reef_support\SEAVIEW_PAC_AUS\m...
978,SEAVIEW_ATL,train,../benthic_data\reef_support\SEAVIEW_ATL\image...,../benthic_data\reef_support\SEAVIEW_ATL\masks...
2650,TETES_PROVIDENCIA,train,../benthic_data\reef_support\TETES_PROVIDENCIA...,../benthic_data\reef_support\TETES_PROVIDENCIA...


In [4]:
SAVE_UNION_DIR = "../coral_project_outputs/union_masks"
os.makedirs(SAVE_UNION_DIR, exist_ok=True)


In [5]:
SAVE_UNION_DIR = "../coral_project_outputs/union_masks"
os.makedirs(SAVE_UNION_DIR, exist_ok=True)

merged_data = []

def make_union_mask(mask_dir, target_name):
    """Combine all *_mask_*.png files into one binary union mask."""
    masks = glob.glob(os.path.join(mask_dir, f"{target_name}_mask_*.png"))
    if not masks:
        return None

    combined = None
    for mpath in masks:
        mask = cv2.imread(mpath, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            continue
        mask = (mask > 0).astype(np.uint8)
        combined = mask if combined is None else np.maximum(combined, mask)

    if combined is None:
        return None

    save_path = os.path.join(SAVE_UNION_DIR, f"{target_name}_union.png")
    cv2.imwrite(save_path, combined * 255)
    return save_path

# Process reef_support sites (including union masks)
for _, row in tqdm(reef_df.iterrows(), total=len(reef_df)):
    img_path = row["image_path"]
    site_dir = os.path.dirname(os.path.dirname(img_path))
    masks_dir = os.path.join(site_dir, "masks")
    fname = os.path.splitext(os.path.basename(img_path))[0]

    # Try to create or find best mask
    if os.path.exists(os.path.join(site_dir, "masks_stitched", f"{fname}_mask.png")):
        mask_path = os.path.join(site_dir, "masks_stitched", f"{fname}_mask.png")
    else:
        mask_path = make_union_mask(masks_dir, fname)

    if mask_path and os.path.exists(mask_path):
        merged_data.append({
            "dataset": row["dataset"],
            "split": "train",
            "image_path": img_path,
            "mask_path": mask_path
        })
print(1)
# # Add CoralSeg dataset
# for _, row in tqdm(coralseg_df.iterrows(), total=len(coralseg_df)):
#     merged_data.append({
#         "dataset": row["dataset"],
#         "split": row["split"],
#         "image_path": row["image_path"],
#         "mask_path": row["mask_path"]
#     })
print(2)
merged_df = pd.DataFrame(merged_data)

# Clean — remove empties
def valid_mask(path):
    if not os.path.exists(path):
        return False
    mask = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    return mask is not None and mask.sum() > 0

merged_df = merged_df[merged_df["mask_path"].apply(valid_mask)].reset_index(drop=True)

print(f"✅ Final merged dataset size: {len(merged_df)} samples")
print(merged_df.groupby("dataset").size())

# Save CSV metadata for reuse
CSV_PATH = "../coral_project_outputs/merged_dataset.csv"
merged_df.to_csv(CSV_PATH, index=False)
print(f"💾 Saved metadata to {CSV_PATH}")


100%|██████████| 3311/3311 [00:00<00:00, 4909.57it/s]


1
2
✅ Final merged dataset size: 3276 samples
dataset
SEAFLOWER_BOLIVAR         245
SEAFLOWER_COURTOWN        241
SEAVIEW_ATL               651
SEAVIEW_IDN_PHL           466
SEAVIEW_PAC_AUS           657
SEAVIEW_PAC_USA           276
TETES_PROVIDENCIA         105
UNAL_BLEACHING_TAYRONA    635
dtype: int64
💾 Saved metadata to ../coral_project_outputs/merged_dataset.csv


In [6]:
csv_path = "../coral_project_outputs/merged_dataset.csv"

merged_df = pd.read_csv(csv_path)
print(f"✅ Reloaded merged dataset: {len(merged_df)} samples")
print(merged_df.groupby("dataset").size())

# Optional sanity check
sample = merged_df.sample(3, random_state=42)
for _, row in sample.iterrows():
    assert os.path.exists(row["image_path"]), f"Missing image {row['image_path']}"
    assert os.path.exists(row["mask_path"]), f"Missing mask {row['mask_path']}"
print("✅ Random sample files verified")


✅ Reloaded merged dataset: 3276 samples
dataset
SEAFLOWER_BOLIVAR         245
SEAFLOWER_COURTOWN        241
SEAVIEW_ATL               651
SEAVIEW_IDN_PHL           466
SEAVIEW_PAC_AUS           657
SEAVIEW_PAC_USA           276
TETES_PROVIDENCIA         105
UNAL_BLEACHING_TAYRONA    635
dtype: int64
✅ Random sample files verified


In [7]:
# -------------------
# Split data
train_df, val_df = train_test_split(merged_df, test_size=0.2, random_state=42, stratify=None)
print(f"📊 Train: {len(train_df)} | Val: {len(val_df)}")

# -------------------
# Augmentations
train_transform = A.Compose([
    A.Resize(256, 256),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomBrightnessContrast(p=0.5),
    A.Normalize(),
    ToTensorV2(),
])

val_transform = A.Compose([
    A.Resize(256, 256),
    A.Normalize(),
    ToTensorV2(),
])

📊 Train: 2620 | Val: 656


In [8]:


# -------------------
# Custom Dataset
class CoralDataset(Dataset):
    def __init__(self, df, transform=None):
        self.df = df.reset_index(drop=True)
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        img_path = self.df.loc[idx, "image_path"]
        mask_path = self.df.loc[idx, "mask_path"]

        # image = np.array(Image.open(img_path).convert("RGB"))
        # mask = np.array(Image.open(mask_path).convert("L"))  # grayscale
        # mask = (mask > 0).astype(np.float32)  # binary coral/non-coral
        image_bgr = cv2.imread(img_path, cv2.IMREAD_COLOR)
        image = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        mask = (mask > 0).astype('uint8')


        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented["image"]
            mask = augmented["mask"].unsqueeze(0).float().div(255)

        return image, mask




In [9]:
# ===== Datasets & Dataloaders (Notebook-friendly, CUDA-aware) =====
import os, platform, time
import torch
from torch.utils.data import DataLoader
import cv2

# --- Assumes these already exist in your notebook ---
# CoralDataset, train_df, val_df, train_transform, val_transform

# 1) Basic sanity + CUDA setup
assert torch.cuda.is_available(), "CUDA not available but you said you're using it."
device = torch.device("cuda")
torch.backends.cudnn.benchmark = True  # speed up on fixed-size inputs

# 2) Keep thread oversubscription in check (helps a lot with OpenCV + PyTorch workers)
os.environ.setdefault("OMP_NUM_THREADS", "1")
os.environ.setdefault("MKL_NUM_THREADS", "1")
try:
    cv2.setNumThreads(0)
except Exception:
    pass

# 3) Notebook/Windows-aware worker policy:
#    - In notebooks (and on Windows spawn), many workers make the *first batch* very slow.
#    - Start conservative; you can bump to 2 later if iter time looks good.
ON_WINDOWS = platform.system() == "Windows"
IN_NOTEBOOK = True  # you're in a notebook here

# Toggle: set to "dev" for snappy startup during experimentation,
# switch to "train" when doing longer runs (can set to 2 workers).
MODE = "dev"  # "dev" or "train"

if MODE == "dev":
    num_workers = 0  # fastest startup in notebooks
    persistent_workers = False
    prefetch_factor = None
else:
    # modest, safe default for notebooks; increase to 2 if disk & CPU keep up
    num_workers = 1 if (IN_NOTEBOOK or ON_WINDOWS) else min(4, max(1, os.cpu_count() // 2))
    persistent_workers = num_workers > 0
    prefetch_factor = 2 if num_workers > 0 else None

# 4) Dataloader builder
def make_loader(df, transform, batch_size=8, shuffle=False):
    kwargs = dict(
        dataset=CoralDataset(df, transform=transform),
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=True,               # good with CUDA; remember to use non_blocking=True on .to(device)
        drop_last=False,
    )
    # only set these when workers > 0 (PyTorch will error otherwise)
    if num_workers > 0:
        kwargs.update(dict(
            persistent_workers=persistent_workers,
            prefetch_factor=prefetch_factor,
        ))
    return DataLoader(**kwargs)

# 5) Create loaders
BATCH_SIZE = 8
train_loader = make_loader(train_df, train_transform, batch_size=BATCH_SIZE, shuffle=True)
val_loader   = make_loader(val_df,   val_transform,   batch_size=BATCH_SIZE, shuffle=False)

# 6) Quick timing probe to understand where the wait happens
def time_first_batch(loader, label):
    t0 = time.perf_counter()
    it = iter(loader)
    t1 = time.perf_counter()
    batch = next(it)
    t2 = time.perf_counter()
    print(f"[{label}] iter() construction: {(t1 - t0):.3f}s | first batch: {(t2 - t1):.3f}s")
    # Optional: move to GPU once to verify pinning benefits (use non_blocking)
    imgs, masks = batch
    if isinstance(imgs, torch.Tensor):
        _ = imgs.to(device, non_blocking=True)
    if isinstance(masks, torch.Tensor):
        _ = masks.to(device, non_blocking=True)

print("CUDA device:", torch.cuda.get_device_name(device.index))
time_first_batch(train_loader, "train")
time_first_batch(val_loader,   "val")


CUDA device: NVIDIA GeForce RTX 3080
[train] iter() construction: 0.002s | first batch: 0.942s
[val] iter() construction: 0.000s | first batch: 0.688s


In [10]:
# =========================
# Section 6 — Model setup, losses, metrics, AMP training
# =========================
import os, math, time
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

# ---- Device & performance knobs ----
assert torch.cuda.is_available(), "CUDA not available; please enable GPU."
device = torch.device("cuda")
torch.backends.cudnn.benchmark = True
try:
    torch.set_float32_matmul_precision("high")  # PyTorch >= 2.0
except Exception:
    pass

# Ensure a model exists and move to GPU
assert "model" in globals(), "Expected a `model` to be defined earlier in the notebook."
model = model.to(device)

# ---- Optimizer & (optional) scheduler ----
LR = 3e-4  # adjust if you like
optimizer = optim.Adam(model.parameters(), lr=LR, weight_decay=1e-4)
# (Optional) scheduler example:
# scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=5, factor=0.5, verbose=True)

# ---- Metrics computed on logits (apply sigmoid only for metrics) ----
@torch.no_grad()
def binary_metrics_from_logits(logits, targets, thr=0.5):
    """
    logits: (B,1,H,W) raw outputs; targets: (B,1,H,W) float {0,1}
    returns: (IoU, Accuracy)
    """
    probs = torch.sigmoid(logits)
    preds = (probs > thr)
    targets_b = (targets > 0.5)

    inter = (preds & targets_b).sum(dim=(1,2,3)).float()
    union = (preds | targets_b).sum(dim=(1,2,3)).float()
    iou = torch.where(union > 0, inter / union, torch.zeros_like(union)).mean().item()

    acc = (preds == targets_b).float().mean().item()
    return iou, acc

# ---- Class imbalance: estimate pos_weight for BCEWithLogitsLoss ----
@torch.no_grad()
def estimate_pos_weight(loader, max_batches=50, device=device):
    pos = 0.0
    neg = 0.0
    for i, (_, y) in enumerate(loader):
        y = y.to(device, non_blocking=True).float()
        pos += y.sum().item()
        neg += y.numel() - y.sum().item()
        if i + 1 >= max_batches:
            break
    pos = max(pos, 1.0); neg = max(neg, 1.0)
    pw = neg / pos
    print(f"[pos_weight] positive fraction ≈ {pos/(pos+neg):.6f}  ->  pos_weight={pw:.2f}")
    return torch.tensor([pw], device=device, dtype=torch.float32)

pos_weight = estimate_pos_weight(train_loader)  # run once

# ---- Losses: BCE logits (with pos_weight) + Dice on probabilities ----
bce_loss = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

def dice_loss_from_probs(probs, targets, eps=1e-7):
    targets = targets.float()
    inter = (probs * targets).sum(dim=(1,2,3))
    union = probs.sum(dim=(1,2,3)) + targets.sum(dim=(1,2,3))
    dice = (2 * inter + eps) / (union + eps)
    return 1 - dice.mean()

def criterion(logits, targets):
    probs = torch.sigmoid(logits)
    return 0.5 * bce_loss(logits, targets) + 0.5 * dice_loss_from_probs(probs, targets)

# ---- AMP training & validation steps ----
scaler = torch.cuda.amp.GradScaler()

def train_one_epoch(model, loader, optimizer, device):
    model.train()
    total_loss = total_iou = total_acc = 0.0
    n_batches = 0

    pbar = tqdm(loader, desc="Train", leave=False)
    for imgs, masks in pbar:
        imgs  = imgs.to(device, non_blocking=True)
        masks = masks.to(device, non_blocking=True).float()

        optimizer.zero_grad(set_to_none=True)
        with torch.cuda.amp.autocast(dtype=torch.float16):
            logits = model(imgs)
            loss = criterion(logits, masks)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        iou, acc = binary_metrics_from_logits(logits, masks)
        total_loss += loss.item(); total_iou += iou; total_acc += acc; n_batches += 1

    return (total_loss / n_batches, total_iou / n_batches, total_acc / n_batches)

@torch.no_grad()
def validate(model, loader, device):
    model.eval()
    total_loss = total_iou = total_acc = 0.0
    n_batches = 0

    pbar = tqdm(loader, desc="Val", leave=False)
    for imgs, masks in pbar:
        imgs  = imgs.to(device, non_blocking=True)
        masks = masks.to(device, non_blocking=True).float()
        with torch.cuda.amp.autocast(dtype=torch.float16):
            logits = model(imgs)
            loss = criterion(logits, masks)
        iou, acc = binary_metrics_from_logits(logits, masks)
        total_loss += loss.item(); total_iou += iou; total_acc += acc; n_batches += 1

    return (total_loss / n_batches, total_iou / n_batches, total_acc / n_batches)

# ---- Full training loop with early stopping & best-model saving ----
def train_model(model, train_loader, val_loader, optimizer, epochs=150, patience=20, save_path="best_model.pth"):
    best_iou = -1.0
    best_state = None
    stale = 0

    print(f"Using device: {torch.cuda.get_device_name(0)}  |  LR={LR:.1e}")
    for epoch in range(1, epochs + 1):
        # Train
        t0 = time.time()
        tr_loss, tr_iou, tr_acc = train_one_epoch(model, train_loader, optimizer, device)
        # Validate
        va_loss, va_iou, va_acc = validate(model, val_loader, device)
        t1 = time.time()

        # (Optional) scheduler on val loss:
        # scheduler.step(va_loss)

        # Logging
        print(f"\nEpoch {epoch}/{epochs}")
        print(f" Train Loss: {tr_loss:.4f} | IoU: {tr_iou:.4f} | Acc: {tr_acc:.4f}")
        print(f" Val   Loss: {va_loss:.4f} | IoU: {va_iou:.4f} | Acc: {va_acc:.4f}")
        print(f"  ⏱️  epoch time: {t1 - t0:.1f}s")

        # Early stopping on IoU (use what you care about most)
        improved = va_iou > best_iou + 1e-6
        if improved:
            best_iou = va_iou
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            torch.save(best_state, save_path)
            print(f"  ✅ Saved new best model with IoU={best_iou:.4f} -> {save_path}")
            stale = 0
        else:
            stale += 1
            if stale >= patience:
                print(f"Early stopping (no improvement for {patience} epochs).")
                break

    # Load best weights back into the model
    if best_state is not None:
        model.load_state_dict(best_state)
    return model

# -------------------
# Run training (edit epochs/patience as you wish)
final_model = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    epochs=150,
    patience=20,
    save_path="best_model.pth",
)


AssertionError: Expected a `model` to be defined earlier in the notebook.

In [11]:
# === Section 6: Model setup and training (best config) ===
import torch
import torch.nn as nn
import segmentation_models_pytorch as smp
from tqdm import tqdm
import copy

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# -------------------
# Model
model = smp.Unet(
    encoder_name="efficientnet-b0",
    encoder_weights="imagenet",
    in_channels=3,
    classes=1,
    activation=None
).to(device)

# -------------------
# Loss: BCE + Dice combo
bce_loss = nn.BCEWithLogitsLoss()
dice_loss = smp.losses.DiceLoss(mode="binary")

def criterion(y_pred, y_true):
    return 0.5 * bce_loss(y_pred, y_true) + 0.5 * dice_loss(y_pred, y_true)

# -------------------
# Metrics
def iou_score(y_pred, y_true, threshold=0.5):
    y_pred_bin = (torch.sigmoid(y_pred) > threshold).float()
    intersection = (y_pred_bin * y_true).sum()
    union = y_pred_bin.sum() + y_true.sum() - intersection
    return (intersection / union).item() if union > 0 else 1.0

def pixel_accuracy(y_pred, y_true, threshold=0.5):
    y_pred_bin = (torch.sigmoid(y_pred) > threshold).float()
    correct = (y_pred_bin == y_true).float().sum()
    total = torch.numel(y_true)
    return (correct / total).item()

# -------------------
# Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# -------------------
# Training loop
def train_model(model, train_loader, val_loader, criterion, optimizer, epochs=30, patience=5, save_path="../coral_project_outputs/best_merged_model.pth"):
    best_iou = 0.0
    best_model_wts = copy.deepcopy(model.state_dict())
    patience_counter = 0

    for epoch in range(epochs):
        model.train()
        train_loss, train_iou, train_acc = 0, 0, 0

        for imgs, masks in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} - Train"):
            imgs, masks = imgs.to(device), masks.to(device)

            optimizer.zero_grad()
            outputs = model(imgs)
            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            train_iou += iou_score(outputs, masks)
            train_acc += pixel_accuracy(outputs, masks)

        # Validation
        model.eval()
        val_loss, val_iou, val_acc = 0, 0, 0
        with torch.no_grad():
            for imgs, masks in tqdm(val_loader, desc=f"Epoch {epoch+1}/{epochs} - Val"):
                imgs, masks = imgs.to(device), masks.to(device)
                outputs = model(imgs)
                loss = criterion(outputs, masks)
                val_loss += loss.item()
                val_iou += iou_score(outputs, masks)
                val_acc += pixel_accuracy(outputs, masks)

        # Averages
        train_loss /= len(train_loader)
        train_iou /= len(train_loader)
        train_acc /= len(train_loader)
        val_loss /= len(val_loader)
        val_iou /= len(val_loader)
        val_acc /= len(val_loader)

        print(f"\nEpoch {epoch+1}/{epochs}")
        print(f" Train Loss: {train_loss:.4f} | IoU: {train_iou:.4f} | Acc: {train_acc:.4f}")
        print(f" Val   Loss: {val_loss:.4f} | IoU: {val_iou:.4f} | Acc: {val_acc:.4f}")

        # Save best model
        if val_iou > best_iou:
            best_iou = val_iou
            best_model_wts = copy.deepcopy(model.state_dict())
            torch.save(model.state_dict(), save_path)
            print(f"  ✅ Saved new best model with IoU={best_iou:.4f}")
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"⏹️ Early stopping at epoch {epoch+1}")
                break

    model.load_state_dict(best_model_wts)
    print(f"Training complete. Best IoU: {best_iou:.4f}")
    return model

# -------------------
# Run training
final_model = train_model(
    model,
    train_loader,
    val_loader,
    criterion,
    optimizer,
    epochs=25,
    patience=4
)


Using device: cuda


Epoch 1/25 - Train:   0%|          | 0/328 [00:02<?, ?it/s]


KeyboardInterrupt: 