In [None]:
# ------------------------------------------------------------
# STEP 6C — U-Net++ (R34, ImageNet encoder) — PAPER BASELINE
# Cell 1/4 — Imports & tiny wrapper to keep API consistent
# ------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import segmentation_models_pytorch as smp

class SMPWrap(nn.Module):
    """Adapt a segmentation_models_pytorch model to a dict-returning API.

    ``forward(x)`` returns ``{'logits': tensor}``. If the wrapped model
    returns a list or tuple (presumably an SMP model with an auxiliary head —
    confirm against how the core is built), only its first element is kept.
    For the binary models built in this notebook the logits are expected to
    be (B, 1, H, W).
    """

    def __init__(self, core: nn.Module):
        super().__init__()
        # Assigned as an attribute so nn.Module registers it as a submodule
        # (its parameters/state_dict entries appear under the "core." prefix).
        self.core = core

    def forward(self, x):
        # Raw scores straight from the wrapped model (built with activation=None).
        raw = self.core(x)
        primary = raw[0] if isinstance(raw, (list, tuple)) else raw
        return {"logits": primary}


In [None]:
# ------------------------------------------------------------
# STEP 6C — U-Net++ definition (R34 encoder, nested skips)
# ------------------------------------------------------------
# Cell 2/4 — Model builder: U-Net++ with a ResNet-34 encoder, wrapped in SMPWrap
# ------------------------------------------------------------
def build_unetpp_r34_binary(encoder_weights="imagenet", in_channels: int = 3):
    """Build a binary-segmentation U-Net++ with a ResNet-34 encoder.

    Args:
        encoder_weights: Pretrained weights for the encoder, passed unchanged
            to ``smp.UnetPlusPlus``. Defaults to ``"imagenet"`` (the paper
            baseline). Pass ``None`` for a randomly initialised encoder (per
            the SMP docs), e.g. to skip the weight download in offline smoke
            tests.
        in_channels: Number of input image channels (default 3).

    Returns:
        SMPWrap: model whose ``forward(x)`` returns ``{"logits": tensor}``.
        The tensor has a single output channel (``classes=1``) and no
        activation applied, so callers receive raw logits.

    NOTE(review): SMP ResNet encoders generally need H and W divisible by 32 —
    confirm the data loader's crop size satisfies this.
    """
    model = smp.UnetPlusPlus(
        encoder_name="resnet34",
        encoder_weights=encoder_weights,
        in_channels=in_channels,
        classes=1,         # binary task: one logit channel
        activation=None,   # keep raw logits; loss/threshold applied elsewhere
    )
    return SMPWrap(model)


In [None]:
# ------------------------------------------------------------
# STEP 6C — Instantiate & check
# ------------------------------------------------------------
# Cell 3/4 — Build the model and run one training batch through it as a shape check
# ------------------------------------------------------------
# Smoke test: push one real training batch through the freshly built model
# to confirm the output contract ({'logits': (B, 1, H, W)}) before training.
unetpp_r34_model = build_unetpp_r34_binary()
xb, yb = next(iter(train_loader))  # yb is not used in this cell

# Run the check in eval mode under no_grad. This way it neither updates the
# BatchNorm running statistics with an arbitrary batch nor builds an autograd
# graph. Afterwards restore whatever train/eval mode the model was in.
_was_training = unetpp_r34_model.training
unetpp_r34_model.eval()
with torch.no_grad():
    out = unetpp_r34_model(xb)
unetpp_r34_model.train(_was_training)

logits_shape = tuple(out["logits"].shape)
# Fail loudly instead of only printing: batch size must be preserved, there
# must be one logit channel, and the spatial size must match the input.
assert logits_shape[0] == xb.shape[0] and logits_shape[1] == 1, logits_shape
assert logits_shape[-2:] == tuple(xb.shape[-2:]), (logits_shape, tuple(xb.shape))
print("U-Net++ (R34) logits:", logits_shape)


In [None]:
# ------------------------------------------------------------
# STEP 6C — Params
# ------------------------------------------------------------
# Cell 4/4 — Count trainable parameters
# ------------------------------------------------------------
# Trainable parameter count only: tensors with requires_grad=False (frozen
# layers) are filtered out before their element counts are summed.
num_params = sum(
    map(torch.numel, filter(lambda param: param.requires_grad, unetpp_r34_model.parameters()))
)
print(f"[PARAMS] U-Net++ (R34): {num_params / 1e6:.2f} M")
