In [None]:
# ------------------------------------------------------------
# STEP 6D — DeepLabV3+ (R50) — PAPER BASELINE
# Cell 1/4 — Imports & tiny wrapper to keep API consistent
# ------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import segmentation_models_pytorch as smp

class SMPWrap(nn.Module):
    """Adapt a segmentation_models_pytorch model to a dict-returning API.

    Calling the wrapper returns ``{"logits": tensor}``, where the tensor is
    whatever the wrapped model produced. If the wrapped model returns a list
    or tuple, only its first element is kept.

    Args:
        core: the model to wrap; stored as ``self.core``, so its parameters
            appear under the ``core.`` prefix in ``state_dict()``.
    """

    def __init__(self, core: nn.Module):
        super().__init__()
        self.core = core

    def forward(self, x):
        raw_output = self.core(x)  # raw logits when the SMP model uses activation=None
        # Some SMP configurations (e.g. an auxiliary classification head) return a
        # sequence; the segmentation logits come first.
        is_sequence = isinstance(raw_output, (list, tuple))
        logits = raw_output[0] if is_sequence else raw_output
        return {"logits": logits}


In [None]:
# ------------------------------------------------------------
# STEP 6D — DeepLabV3+ (R50) definition — PAPER BASELINE
# ------------------------------------------------------------
# Cell 2/4 — Model builder (ResNet-50 encoder, ImageNet weights, 1 output channel)
# ------------------------------------------------------------
def build_deeplabv3plus_r50_binary():
    """Build the paper-baseline DeepLabV3+ with a ResNet-50 encoder.

    The model takes 3-channel input, produces a single output channel and
    applies no final activation, so it emits raw logits. It is returned
    wrapped in ``SMPWrap`` so that ``forward(x)`` yields ``{"logits": ...}``.

    Returns:
        SMPWrap: the wrapped ``smp.DeepLabV3Plus`` model.
    """
    baseline_config = dict(
        encoder_name="resnet50",
        encoder_weights="imagenet",  # ImageNet-pretrained encoder weights
        in_channels=3,
        classes=1,                   # one logit channel for binary segmentation
        activation=None,             # raw logits, as SMPWrap expects
    )
    deeplab = smp.DeepLabV3Plus(**baseline_config)
    return SMPWrap(deeplab)


In [None]:
# ------------------------------------------------------------
# STEP 6D — Instantiate & check
# ------------------------------------------------------------
# Cell 3/4 — Build the model and run one forward pass on a training batch
# ------------------------------------------------------------
dlab_r50_model = build_deeplabv3plus_r50_binary()
xb, yb = next(iter(train_loader))

# Smoke-test the forward pass without side effects on the model:
# - eval() stops BatchNorm layers from updating their running statistics before
#   training has even started. In train mode, BatchNorm also raises on a batch of
#   size 1 in DeepLabV3+'s pooled ASPP branch.
# - no_grad() avoids building an autograd graph we never backpropagate through.
# The previous train/eval mode is restored afterwards.
_was_training = dlab_r50_model.training
dlab_r50_model.eval()
with torch.no_grad():
    _device = next(dlab_r50_model.parameters()).device
    out = dlab_r50_model(xb.to(_device))
dlab_r50_model.train(_was_training)

print("DeepLabV3+ (R50) logits:", tuple(out["logits"].shape))
# Expect one logit channel per image, at the input's spatial resolution.
assert tuple(out["logits"].shape) == (xb.shape[0], 1, *xb.shape[-2:]), "unexpected logits shape"


In [None]:
# ------------------------------------------------------------
# STEP 6D — Params
# ------------------------------------------------------------
# Cell 4/4 — Count trainable parameters (in millions)
# ------------------------------------------------------------
# Total element count over parameters with requires_grad=True (frozen ones are skipped).
num_params = 0
for param in dlab_r50_model.parameters():
    if param.requires_grad:
        num_params += param.numel()
print(f"[PARAMS] DeepLabV3+ (R50): {num_params/1e6:.2f} M")
