In [1]:
"""
Lightweight U-Net for segmentation with:
- RepConv blocks (encoder/decoder/head)
- SimAM attention on skip-add fusion
- Bottleneck: SPPF-like (1x1 -> Depthwise DCNv2 -> 1x1)
- Cheap DW+PW offset/mask heads for deformable conv

Includes a minimal synthetic segmentation example (circles).
"""

import math
import inspect
import argparse
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

# =========================
# Optional: torchvision deform_conv2d
# =========================
try:
    import torchvision.ops as tvops
    _HAS_TV = True
except Exception:
    tvops = None
    _HAS_TV = False

def _supports_groups_in_deform() -> bool:
    if not _HAS_TV: 
        return False
    try:
        sig = inspect.signature(tvops.deform_conv2d)
        return "groups" in sig.parameters
    except Exception:
        return False

_HAS_GROUPS = _supports_groups_in_deform()

# =========================
# SimAM: parameter-free attention
# =========================
class SimAM(nn.Module):
    def __init__(self, lambda_val: float = 1e-4):
        super().__init__()
        self.lambda_val = lambda_val

    def forward(self, x):
        # x: [B, C, H, W]
        b, c, h, w = x.shape
        n = h * w - 1
        mean = x.mean(dim=(2, 3), keepdim=True)
        var = ((x - mean) ** 2).sum(dim=(2, 3), keepdim=True) / (n + 1e-6)
        e = (x - mean) ** 2 / (4 * (var + self.lambda_val)) + 0.5
        w = torch.sigmoid(e)
        return x * w

# =========================
# RepConv (train-time multi-branch, deploy single 3x3)
# =========================
class RepConv(nn.Module):
    def __init__(self, c_in, c_out, s=1, deploy=False, act=True):
        super().__init__()
        self.deploy = deploy
        self.stride = s
        self.act = nn.ReLU6(inplace=True) if act else nn.Identity()

        if deploy:
            self.rbr_reparam = nn.Conv2d(c_in, c_out, 3, s, 1, bias=True)
        else:
            self.rbr_identity = nn.BatchNorm2d(c_in) if (c_out == c_in and s == 1) else None
            self.rbr_dense = nn.Sequential(
                nn.Conv2d(c_in, c_out, 3, s, 1, bias=False),
                nn.BatchNorm2d(c_out),
            )
            self.rbr_1x1 = nn.Sequential(
                nn.Conv2d(c_in, c_out, 1, s, 0, bias=False),
                nn.BatchNorm2d(c_out),
            )

    def forward(self, x):
        if hasattr(self, "rbr_reparam"):
            out = self.rbr_reparam(x)
        else:
            id_out = self.rbr_identity(x) if self.rbr_identity is not None else 0
            out = self.rbr_dense(x) + self.rbr_1x1(x) + id_out
        return self.act(out)

    def get_equivalent_kernel_bias(self):
        """Fuse 3x3, 1x1 and (optional) identity branches into one 3x3 Conv2d kernel+bias."""
        if hasattr(self, "rbr_reparam"):
            return self.rbr_reparam.weight, self.rbr_reparam.bias

        # Helper to fuse conv+bn
        def fuse_conv_bn(conv, bn):
            w = conv.weight
            if conv.kernel_size == (1, 1):
                # pad to 3x3
                w = F.pad(w, [1, 1, 1, 1])
            gamma = bn.weight
            beta = bn.bias
            mean = bn.running_mean
            var = bn.running_var
            eps = bn.eps
            std = torch.sqrt(var + eps)
            w_fused = w * (gamma / std).reshape(-1, 1, 1, 1)
            b_fused = beta - mean * gamma / std
            return w_fused, b_fused

        w3, b3 = fuse_conv_bn(self.rbr_dense[0], self.rbr_dense[1])
        w1, b1 = fuse_conv_bn(self.rbr_1x1[0], self.rbr_1x1[1])

        if self.rbr_identity is not None:
            # Identity BN as 3x3 kernel (delta kernel)
            id_ch = self.rbr_identity.num_features
            id_kernel = torch.zeros((id_ch, id_ch, 3, 3), device=w3.device, dtype=w3.dtype)
            for i in range(id_ch):
                id_kernel[i, i, 1, 1] = 1.0
            gamma = self.rbr_identity.weight
            beta = self.rbr_identity.bias
            mean = self.rbr_identity.running_mean
            var = self.rbr_identity.running_var
            eps = self.rbr_identity.eps
            std = torch.sqrt(var + eps)
            w_id = id_kernel * (gamma / std).reshape(-1, 1, 1, 1)
            b_id = beta - mean * gamma / std
        else:
            w_id = torch.zeros_like(w3)
            b_id = torch.zeros_like(b3)

        w = w3 + w1 + w_id
        b = b3 + b1 + b_id
        return w, b

    def fuse_reparam(self):
        if hasattr(self, "rbr_reparam"):
            return
        W, B = self.get_equivalent_kernel_bias()
        self.rbr_reparam = nn.Conv2d(
            in_channels=self.rbr_dense[0].in_channels,
            out_channels=self.rbr_dense[0].out_channels,
            kernel_size=3,
            stride=self.stride,
            padding=1,
            bias=True,
        )
        self.rbr_reparam.weight.data = W
        self.rbr_reparam.bias.data = B
        # Delete old branches
        del self.rbr_dense, self.rbr_1x1
        if self.rbr_identity is not None:
            del self.rbr_identity

# =========================
# Cheap DW+PW block for offset/modulator
# =========================
class CheapOffsetConv(nn.Module):
    """Depthwise 3x3 + BN + ReLU6 + Pointwise 1x1 (DW-PW)."""
    def __init__(self, in_channels, out_channels, k=3, stride=1, padding=1):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, kernel_size=k, stride=stride,
                      padding=padding, groups=in_channels, bias=False),
            nn.BatchNorm2d(in_channels),
            nn.ReLU6(inplace=True),
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=True),
        )
        # init last conv to zero: start at "no deformation"
        nn.init.constant_(self.block[-1].weight, 0.)
        nn.init.constant_(self.block[-1].bias, 0.)

    def forward(self, x):
        return self.block(x)

# =========================
# Depthwise Deformable Conv2d (DW-DCNv2)
# =========================
class DeformableDWConv2d(nn.Module):
    """
    Depthwise deformable conv:
      - offset head: CheapOffsetConv -> 2*k*k
      - mask head:   CheapOffsetConv -> k*k (scaled to [0,2])
      - weight: [C, 1, kH, kW], groups=C
    Fast path uses torchvision.deform_conv2d with groups.
    Fallback: per-channel loop (slower, but correct).
    """
    def __init__(self, in_channels, kernel_size=3, stride=1, padding=1, dilation=1, bias=False):
        super().__init__()
        k = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size)
        self.k = k
        self.stride = stride if isinstance(stride, tuple) else (stride, stride)
        self.padding = padding if isinstance(padding, tuple) else (padding, padding)
        self.dilation = dilation
        self.in_channels = in_channels
        self.bias_flag = bias
        ks2 = k[0] * k[1]

        self.offset_conv = CheapOffsetConv(in_channels, 2 * ks2, k=3, stride=self.stride, padding=1)
        self.modulator_conv = CheapOffsetConv(in_channels, ks2, k=3, stride=self.stride, padding=1)

        self.weight = nn.Parameter(torch.empty(in_channels, 1, *k))
        nn.init.kaiming_uniform_(self.weight, a=5**0.5)
        self.bias = nn.Parameter(torch.zeros(in_channels)) if bias else None

    def forward(self, x):
        if not _HAS_TV:
            raise RuntimeError("torchvision.ops.deform_conv2d is required for this module.")

        offset = self.offset_conv(x)                      # [B, 2*k*k, H', W']
        mask = 2.0 * torch.sigmoid(self.modulator_conv(x))  # [B, k*k, H', W']

        if _HAS_GROUPS:
            return tvops.deform_conv2d(
                input=x,
                offset=offset,
                weight=self.weight,
                bias=self.bias if self.bias_flag else None,
                stride=self.stride,
                padding=self.padding,
                dilation=self.dilation,
                mask=mask,
                groups=self.in_channels,
            )
        else:
            # Per-channel fallback
            outs = []
            for c in range(self.in_channels):
                xc = x[:, c:c+1]
                wc = self.weight[c:c+1]
                bc = self.bias[c:c+1] if self.bias is not None else None
                yc = tvops.deform_conv2d(
                    xc, offset, wc, bias=bc,
                    stride=self.stride, padding=self.padding,
                    dilation=self.dilation, mask=mask
                )
                outs.append(yc)
            return torch.cat(outs, dim=1)

# =========================
# SPPF-like bottleneck (1x1 -> DW-DCNv2 -> 1x1)
# =========================
class SPPF_DW_DCN(nn.Module):
    def __init__(self, c):
        super().__init__()
        self.reduce = nn.Sequential(
            nn.Conv2d(c, c, 1, bias=False),
            nn.BatchNorm2d(c),
            nn.ReLU6(inplace=True),
        )
        self.dw_dcn = DeformableDWConv2d(c, kernel_size=3, stride=1, padding=1, dilation=1, bias=True)
        self.expand = nn.Sequential(
            nn.Conv2d(c, c, 1, bias=False),
            nn.BatchNorm2d(c),
            nn.ReLU6(inplace=True),
        )

    def forward(self, x):
        x = self.reduce(x)
        x = self.dw_dcn(x)
        x = self.expand(x)
        return x

# =========================
# U-Net building blocks
# =========================
class Down(nn.Module):
    def __init__(self, c_in, c_out):
        super().__init__()
        self.conv = nn.Sequential(
            RepConv(c_in, c_out, s=2),  # downsample
            RepConv(c_out, c_out),
        )

    def forward(self, x):
        return self.conv(x)

class Up(nn.Module):
    def __init__(self, c_in, c_skip, c_out, use_simam=True):
        super().__init__()
        self.proj = nn.Conv2d(c_in, c_out, 1, bias=False)
        self.bn = nn.BatchNorm2d(c_out)
        self.skip_proj = nn.Conv2d(c_skip, c_out, 1, bias=False)
        self.skip_bn = nn.BatchNorm2d(c_out)
        self.attn = SimAM() if use_simam else nn.Identity()
        self.block = RepConv(c_out, c_out)

    def forward(self, x, skip):
        x = F.interpolate(x, size=skip.shape[-2:], mode="bilinear", align_corners=False)
        x = F.relu6(self.bn(self.proj(x)), inplace=True)
        s = F.relu6(self.skip_bn(self.skip_proj(skip)), inplace=True)
        x = self.attn(x + s)  # add-fusion + SimAM
        return self.block(x)

# =========================
# The Lightweight U-Net
# =========================
class UNetLite_Rep_SimAM_DWDCN(nn.Module):
    def __init__(self, in_ch=3, n_classes=1, width=1.0, use_simam=True):
        super().__init__()
        def C(c): return max(8, int(c * width))

        # Stem
        self.stem = nn.Sequential(
            RepConv(in_ch, C(16)),
            RepConv(C(16), C(16)),
        )

        # Encoder
        self.d1 = Down(C(16), C(32))
        self.d2 = Down(C(32), C(64))
        self.d3 = Down(C(64), C(128))
        self.d4 = Down(C(128), C(192))

        # Bottleneck
        self.bot = SPPF_DW_DCN(C(192))

        # Decoder
        self.u4 = Up(C(192), C(128), C(96), use_simam)
        self.u3 = Up(C(96),  C(64),  C(64), use_simam)
        self.u2 = Up(C(64),  C(32),  C(48), use_simam)
        self.u1 = Up(C(48),  C(16),  C(32), use_simam)

        # Head
        self.head = nn.Sequential(
            RepConv(C(32), C(32)),
            nn.Conv2d(C(32), n_classes, 1),
        )

    def forward(self, x):
        s0 = self.stem(x)       # C16
        s1 = self.d1(s0)        # C32
        s2 = self.d2(s1)        # C64
        s3 = self.d3(s2)        # C128
        s4 = self.d4(s3)        # C192
        b  = self.bot(s4)
        x  = self.u4(b, s3)
        x  = self.u3(x, s2)
        x  = self.u2(x, s1)
        x  = self.u1(x, s0)
        return self.head(x)

    def fuse_reparam(self):
        """Fuse RepConv branches for deployment."""
        for m in self.modules():
            if isinstance(m, RepConv):
                m.fuse_reparam()

# =========================
# Toy dataset: synthetic circles
# =========================
def make_circles_batch(B: int, H: int, W: int, device="cpu") -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Generates a batch of simple images with filled circles and corresponding masks.
    Images: 3-channel, masks: 1-channel (0/1).
    """
    yy, xx = torch.meshgrid(torch.linspace(-1, 1, H, device=device),
                            torch.linspace(-1, 1, W, device=device), indexing="ij")
    imgs, masks = [], []
    for _ in range(B):
        cx = torch.empty(1, device=device).uniform_(-0.4, 0.4).item()
        cy = torch.empty(1, device=device).uniform_(-0.4, 0.4).item()
        r  = torch.empty(1, device=device).uniform_(0.15, 0.35).item()
        circle = ((xx - cx)**2 + (yy - cy)**2) <= (r**2)
        mask = circle.float().unsqueeze(0)  # [1,H,W]
        # Simple texture: mask + noise
        noise = 0.2 * torch.randn(1, H, W, device=device)
        img = mask * 0.8 + (1 - mask) * 0.2 + noise
        img = img.clamp(0, 1).repeat(3, 1, 1)  # 3-channel
        imgs.append(img)
        masks.append(mask)
    return torch.stack(imgs, 0), torch.stack(masks, 0)

# =========================
# Loss: BCE + Soft Dice
# =========================
def dice_loss(logits, targets, eps=1e-6):
    probs = torch.sigmoid(logits)
    dims = (2, 3)
    num = 2 * (probs * targets).sum(dim=dims) + eps
    den = (probs.pow(2) + targets.pow(2)).sum(dim=dims) + eps
    return 1 - (num / den).mean()

class BCEDiceLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.bce = nn.BCEWithLogitsLoss()

    def forward(self, logits, targets):
        return self.bce(logits, targets) + dice_loss(logits, targets)

# =========================
# Example training loop
# =========================
def demo_train(args):
    device = torch.device("cuda" if torch.cuda.is_available() and not args.cpu else "cpu")

    model = UNetLite_Rep_SimAM_DWDCN(in_ch=3, n_classes=1, width=args.width).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    criterion = BCEDiceLoss()

    print(f"Using device: {device}")
    print(f"Fast grouped deform_conv2d available: {_HAS_GROUPS}")

    model.train()
    for it in range(args.iters):
        imgs, masks = make_circles_batch(args.batch_size, args.size, args.size, device=device)
        logits = model(imgs)
        loss = criterion(logits, masks)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (it + 1) % max(1, args.iters // 10) == 0:
            with torch.no_grad():
                dice = 1 - dice_loss(logits, masks).item()
            print(f"[{it+1:04d}/{args.iters}] loss={loss.item():.4f}  dice={dice:.4f}")

    # Optionally fuse RepConv branches for deployment
    if args.fuse:
        model.fuse_reparam()
        print("RepConv branches fused for deployment.")

    # Quick sanity check on a batch
    model.eval()
    with torch.no_grad():
        imgs, masks = make_circles_batch(2, args.size, args.size, device=device)
        logits = model(imgs)
        preds = (torch.sigmoid(logits) > 0.5).float()
        print("Output shape:", logits.shape, " | Pred mask unique values:", preds.unique())

    return model

# =========================
# CLI
# =========================
def parse_args():
    p = argparse.ArgumentParser()
    p.add_argument("--size", type=int, default=128, help="input H=W")
    p.add_argument("--batch_size", type=int, default=4)
    p.add_argument("--iters", type=int, default=200)
    p.add_argument("--lr", type=float, default=3e-3)
    p.add_argument("--width", type=float, default=0.75, help="global width multiplier")
    p.add_argument("--cpu", action="store_true", help="force CPU")
    p.add_argument("--fuse", action="store_true", help="fuse RepConv for deploy after training")
    return p.parse_args([])  # change to [] for notebook, or remove to parse CLI

if __name__ == "__main__":
    args = parse_args()
    demo_train(args)


Using device: cuda
Fast grouped deform_conv2d available: False
[0020/200] loss=0.1100  dice=0.9489
[0040/200] loss=0.0255  dice=0.9917
[0060/200] loss=0.0159  dice=0.9944
[0080/200] loss=0.0121  dice=0.9956
[0100/200] loss=0.0088  dice=0.9973
[0120/200] loss=0.0081  dice=0.9970
[0140/200] loss=0.0087  dice=0.9960
[0160/200] loss=0.0114  dice=0.9934
[0180/200] loss=0.0072  dice=0.9964
[0200/200] loss=0.0069  dice=0.9964
Output shape: torch.Size([2, 1, 128, 128])  | Pred mask unique values: tensor([0., 1.], device='cuda:0')


In [5]:
import time
import torch
from torch import nn

# =========================
# Helpers
# =========================
def count_params(model: nn.Module) -> int:
    return sum(p.numel() for p in model.parameters())

@torch.no_grad()
def benchmark_latency(model: nn.Module, H=512, W=512, *, runs=30, device='cuda') -> float:
    model.eval().to(device)
    model.to(memory_format=torch.channels_last)
    x = torch.randn(1, 3, H, W, device=device).to(memory_format=torch.channels_last)

    # warmup
    for _ in range(8):
        _ = model(x)
    if device == 'cuda':
        torch.cuda.synchronize()

    t0 = time.time()
    for _ in range(runs):
        _ = model(x)
    if device == 'cuda':
        torch.cuda.synchronize()
    return (time.time() - t0) / runs

def try_profile_macs_flops(model: nn.Module, H=512, W=512, device='cuda'):
    """
    Returns (params_M, GMACs, GFLOPs) if ptflops or thop is available.
    If neither available, returns (params_M, None, None).
    """
    params_m = count_params(model) / 1e6
    model.eval().to(device).to(memory_format=torch.channels_last)
    x = torch.randn(1, 3, H, W, device=device).to(memory_format=torch.channels_last)

    # Prefer ptflops on Python 3.12
    try:
        from ptflops import get_model_complexity_info
        with torch.cuda.amp.autocast(enabled=False):
            macs, params = get_model_complexity_info(
                model, (3, H, W), as_strings=False, print_per_layer_stat=False, verbose=False
            )
        gmacs = macs / 1e9
        gflops = 2.0 * gmacs  # (multiply + add) convention
        return params_m, gmacs, gflops
    except Exception:
        pass

    # Fallback to thop if present
    try:
        from thop import profile
        macs, _ = profile(model, inputs=(x,), verbose=False)
        gmacs = macs / 1e9
        gflops = 2.0 * gmacs
        return params_m, gmacs, gflops
    except Exception:
        return params_m, None, None

# =========================
# Vanilla UNet (fixed decoder channel sizes)
# =========================
def conv3x3(ci, co): return nn.Conv2d(ci, co, 3, padding=1, bias=False)
def bn(c): return nn.BatchNorm2d(c)
def relu(): return nn.ReLU(inplace=True)

class DoubleConv(nn.Module):
    def __init__(self, ci, co):
        super().__init__()
        self.m = nn.Sequential(
            conv3x3(ci, co), bn(co), relu(),
            conv3x3(co, co), bn(co), relu()
        )
    def forward(self, x): return self.m(x)

class VanillaUNet(nn.Module):
    def __init__(self, in_ch=3, n_classes=1, base=64):
        super().__init__()
        C0, C1, C2, C3, C4 = base, base*2, base*4, base*8, base*16

        # Encoder
        self.enc1 = DoubleConv(in_ch, C0); self.pool1 = nn.MaxPool2d(2)
        self.enc2 = DoubleConv(C0, C1);    self.pool2 = nn.MaxPool2d(2)
        self.enc3 = DoubleConv(C1, C2);    self.pool3 = nn.MaxPool2d(2)
        self.enc4 = DoubleConv(C2, C3);    self.pool4 = nn.MaxPool2d(2)

        # Bottleneck
        self.bott = DoubleConv(C3, C4)

        # Decoder (note: concat doubles channels → 2*C? → C?)
        self.up4  = nn.ConvTranspose2d(C4, C3, 2, 2)
        self.dec4 = DoubleConv(C3 + C3, C3)

        self.up3  = nn.ConvTranspose2d(C3, C2, 2, 2)
        self.dec3 = DoubleConv(C2 + C2, C2)

        self.up2  = nn.ConvTranspose2d(C2, C1, 2, 2)
        self.dec2 = DoubleConv(C1 + C1, C1)

        self.up1  = nn.ConvTranspose2d(C1, C0, 2, 2)
        self.dec1 = DoubleConv(C0 + C0, C0)

        self.head = nn.Conv2d(C0, n_classes, 1)

    def forward(self, x):
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool1(e1))
        e3 = self.enc3(self.pool2(e2))
        e4 = self.enc4(self.pool3(e3))
        b  = self.bott(self.pool4(e4))

        d4 = self.up4(b); d4 = torch.cat([d4, e4], dim=1); d4 = self.dec4(d4)
        d3 = self.up3(d4); d3 = torch.cat([d3, e3], dim=1); d3 = self.dec3(d3)
        d2 = self.up2(d3); d2 = torch.cat([d2, e2], dim=1); d2 = self.dec2(d2)
        d1 = self.up1(d2); d1 = torch.cat([d1, e1], dim=1); d1 = self.dec1(d1)
        return self.head(d1)

# =========================
# Your proposed model
# =========================
# Make sure this class is already defined in your session or import it:
# from your_file import UNetLite_Rep_SimAM_DWDCN

# =========================
# Run the comparison
# =========================
if __name__ == "__main__":
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Instantiate models
    proposed = UNetLite_Rep_SimAM_DWDCN(in_ch=3, n_classes=1, width=0.75)
    vanilla  = VanillaUNet(in_ch=3, n_classes=1, base=64)  # classic UNet capacity

    # Params
    p_params_m = count_params(proposed) / 1e6
    v_params_m = count_params(vanilla)  / 1e6

    # Latency @ 512x512
    p_lat = benchmark_latency(proposed, 512, 512, runs=30, device=device)
    v_lat = benchmark_latency(vanilla,  512, 512, runs=30, device=device)

    print(f"[Params]  Proposed: {p_params_m:.2f} M | Vanilla: {v_params_m:.2f} M")
    print(f"[Latency] Proposed: {p_lat*1000:.2f} ms | Vanilla: {v_lat*1000:.2f} ms  (512x512, bs=1)")

    # MACs / GFLOPs if profiler available
    p_params_m2, p_gmacs, p_gflops = try_profile_macs_flops(proposed, 512, 512, device)
    v_params_m2, v_gmacs, v_gflops = try_profile_macs_flops(vanilla,  512, 512, device)

    if p_gmacs is not None and v_gmacs is not None:
        print(f"[Complexity] Proposed: MACs={p_gmacs:.2f} G | GFLOPs≈{p_gflops:.2f}")
        print(f"[Complexity] Vanilla : MACs={v_gmacs:.2f} G | GFLOPs≈{v_gflops:.2f}")
    else:
        print("[Complexity] Skipped MACs/GFLOPs (install ptflops or thop to enable).")


[Params]  Proposed: 0.71 M | Vanilla: 31.04 M
[Latency] Proposed: 26.70 ms | Vanilla: 20.23 ms  (512x512, bs=1)
[Complexity] Proposed: MACs=7.94 G | GFLOPs≈15.87
[Complexity] Vanilla : MACs=193.28 G | GFLOPs≈386.55


  with torch.cuda.amp.autocast(enabled=False):
