In [None]:
# Mount Google Drive under /content/drive (requires a Colab runtime).
# NOTE(review): nothing in the visible part of this notebook reads from this mount.
from google.colab import drive
drive.mount('/content/drive')


In [None]:
!python --version

!pip install torch==2.3.1+cu121 torchvision --extra-index-url https://download.pytorch.org/whl/cu121


In [None]:

!pip install -q  pycocotools tqdm torchmetrics lxml scipy
!pip install -U "datasets>=2.17.0"  "pyarrow>=14.0"
!pip install -U cmake ninja wheel
!pip install --no-binary=mmcv mmcv==2.2.0

In [None]:
!pip uninstall -y numpy

!pip install numpy==1.26.4

In [None]:
%%bash
# Build NATTEN v0.14.6 from source against the torch installed in this runtime.
# Abort on the first failure instead of running `pip install .` in the wrong directory.
set -e

pip install -U cmake ninja wheel

# Remove any previous clone so the cell can be re-run (git clone refuses an existing dir).
rm -rf NATTEN
git clone --depth 1 --branch v0.14.6 https://github.com/SHI-Labs/NATTEN.git
cd NATTEN

export FORCE_CUDA=1
# Defaults to sm_80 (A100). Override from the notebook before running this cell,
# e.g. `%env TORCH_CUDA_ARCH_LIST=7.5` for a T4.
export TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST:-8.0}"

# Same as the DCNv4 build: compile against the installed torch, not an isolated build env.
pip install . --no-build-isolation

cd ..

In [None]:
# Fresh clone of the DAT repository (the rm -rf makes the cell re-runnable).
# NOTE(review): this tracks the default branch, not a pinned commit — pin one for reproducibility.
!rm -rf DAT && git clone -q https://github.com/LeapLabTHU/DAT.git

In [None]:
%%bash
# Build the DCNv4 CUDA op from source. Abort on the first failing command.
set -e
# Remove any previous clone so the cell can be re-run (git clone refuses an existing dir).
rm -rf DCNv4
git clone --depth 1 https://github.com/OpenGVLab/DCNv4.git
cd DCNv4/DCNv4_op
export FORCE_CUDA=1
# Defaults to sm_80 (A100). Override from the notebook before running this cell,
# e.g. `%env TORCH_CUDA_ARCH_LIST=7.5` for a T4.
export TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST:-8.0}"
python -m pip install . --no-build-isolation -v
cd ../..


In [None]:
!mkdir -p ~/.kaggle && chmod 600 ~/.kaggle/kaggle.json

!pip -q install opendatasets
import opendatasets as od
od.download(
    "https://www.kaggle.com/datasets/vijayabhaskar96/pascal-voc-2007-and-2012",
    data_dir="/content/voc")          # → /content/voc/VOCdevkit/{VOC2007,VOC2012}

In [None]:
from torchvision.datasets import VOCDetection
from torch.utils.data import ConcatDataset

# Folder created by opendatasets for the Kaggle dataset; VOCDetection looks for
# <ROOT>/VOCdevkit/VOC{year} inside it.
ROOT = "/content/voc/pascal-voc-2007-and-2012"

train07 = VOCDetection(ROOT, year="2007", image_set="trainval")
train12 = VOCDetection(ROOT, year="2012", image_set="trainval")
combined = ConcatDataset([train07, train12])   # VOC07+12 trainval
test07 = VOCDetection(ROOT, year="2007", image_set="test")

# Standard VOC detection split sizes: fail fast on an incomplete or unexpected download.
EXPECTED_SPLIT_SIZES = (5011, 11540, 16551, 4952)
split_sizes = (len(train07), len(train12), len(combined), len(test07))
print(*split_sizes)
assert split_sizes == EXPECTED_SPLIT_SIZES, f"unexpected VOC split sizes: {split_sizes}"

In [None]:


from __future__ import annotations  # must be the first statement of the cell

# --- standard library ---
import importlib
import math
from typing import Optional, List, Dict, Tuple, Union

# --- third-party ---
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from timm.models.layers import DropPath, create_act_layer
from mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention

# --- repositories cloned/built above (DAT, DCNv4) ---
# NOTE: LayerScale imported here is shadowed by the LayerScale class defined later in this cell.
from DAT.models.dat import DAT, LayerScale, TransformerStage
from DAT.models.dat_blocks import LayerNormProxy, TransformerMLP, TransformerMLPWithConv
from DAT.models.dat_blocks import LocalAttention, DAttentionBaseline, ShiftWindowAttention, PyramidAttention
from DCNv4.modules.dcnv4 import DCNv4
# Module handle so DAT's module-level names can be patched later in this cell.
dat_mod = importlib.import_module("DAT.models.dat")





class NormFactory:
    """Callable that builds a 2-D normalization layer for a given channel count.

    Supported kinds (case-insensitive):
      * ``"bn"``  → ``nn.BatchNorm2d(num_feat)``
      * ``"gn"``  → ``nn.GroupNorm(gcd(gn_groups, num_feat) or 1, num_feat)``; the gcd
        guarantees the group count divides ``num_feat``.
      * ``"lnp"`` → ``LayerNormProxy(num_feat)`` (from DAT).

    Raises:
        ValueError: if ``kind`` is not one of the supported kinds.
    """
    SUPPORTED = {"bn", "gn", "lnp"}

    def __init__(self, kind: str = "gn", gn_groups: int = 32):
        normalized_kind = kind.lower()
        if normalized_kind not in self.SUPPORTED:
            raise ValueError(f"kind must be one of {self.SUPPORTED}")
        self.kind = normalized_kind
        self.gn_groups = gn_groups

    def __call__(self, num_feat: int) -> nn.Module:
        # Lazily-evaluated builders: only the selected layer is ever constructed.
        builders = {
            "bn": lambda: nn.BatchNorm2d(num_feat),
            "gn": lambda: nn.GroupNorm(math.gcd(self.gn_groups, num_feat) or 1, num_feat),
            "lnp": lambda: LayerNormProxy(num_feat),
        }
        build = builders.get(self.kind)
        # An unknown kind (only possible if `kind` was mutated after __init__) yields None.
        return build() if build is not None else None




# -----------------------------------------------------------------------------
# 1. CrossScaleInjection
# -----------------------------------------------------------------------------

class CrossScaleInjection(nn.Module):
    """Cross-scale feature injection with a per-channel gate.

    The low-resolution input is projected to ``high_ch`` channels with a bias-free
    1x1 conv, bilinearly resized to the high-resolution spatial size, normalized
    with ``LayerNormProxy``, and added to the high-resolution input scaled by
    ``sigmoid(weight)`` (one gate value per channel, in (0, 1)).
    """
    def __init__(self, low_ch: int, high_ch: int):
        super().__init__()
        self.align = nn.Conv2d(low_ch, high_ch, 1, bias=False)
        self.norm = LayerNormProxy(high_ch)

        # Per-channel gate *logits*: raw value 0.15, i.e. an effective gate of
        # sigmoid(0.15) ≈ 0.54 at initialization.
        self.weight = nn.Parameter(torch.ones(1, high_ch, 1, 1) * 0.15)

        nn.init.kaiming_normal_(self.align.weight, mode="fan_out", nonlinearity="relu")

    def forward(self, low_res: torch.Tensor, high_res: torch.Tensor) -> torch.Tensor:
        target_hw = high_res.shape[-2:]
        projected = self.align(low_res)
        resized = F.interpolate(projected, size=target_hw,
                                mode='bilinear', align_corners=False)
        injected = self.norm(resized)
        gate = torch.sigmoid(self.weight)
        return high_res + gate * injected

# -----------------------------------------------------------------------------
# PGI– Programmable Gradient Injection
# -----------------------------------------------------------------------------



class PGIModule(nn.Module):
    """
    Programmable Gradient Injection (v2).

    Output: ``x + DropPath(main(x) + λ · gate · aux(adapter(aux_input)))`` in training
    mode when ``aux_input`` is given, otherwise ``x + DropPath(main(x))``. The
    auxiliary term is never added in eval mode.

    * ``main``: grouped 3x3 conv (groups = max(1, channels // 4)) + optional norm.
    * ``aux``: 1x1 bottleneck (hidden width max(channels // reduction, 32)) with SiLU.
    * ``aux_adapter``: 1x1 conv created once in ``__init__`` when ``aux_channels``
      differs from ``channels`` (identity otherwise). The parameter set is therefore
      fixed at construction time.
    * λ is learnable and kept inside ``lambda_bounds`` via a scaled sigmoid.

    Args:
        channels: channel count of ``x`` and of the output.
        reduction: bottleneck reduction factor of the aux branch.
        lambda_bounds: ``(min, max)`` range for λ.
        init_lambda: initial λ; must lie strictly inside ``lambda_bounds``.
        use_bn: if False, every norm layer is replaced by ``nn.Identity``.
        drop_path: DropPath rate on the residual branch (0 disables it).
        norm_factory: callable building the norm layers.
        aux_channels: channel count of ``aux_input``; ``None`` means ``channels``.

    Raises:
        ValueError: if ``init_lambda`` is not strictly between the bounds (this
            also rejects bounds with ``min >= max``).
    """
    def __init__(
        self,
        channels: int,
        reduction: int = 4,
        lambda_bounds: tuple[float, float] = (0.2, 0.7),
        init_lambda: float = 0.5,
        use_bn: bool = True,
        drop_path: float = 0.0,
        norm_factory: NormFactory = NormFactory("gn"),
        aux_channels: int | None = None,
    ) -> None:
        super().__init__()
        lambda_min, lambda_max = lambda_bounds
        # Validate up front: otherwise math.log below fails with an opaque
        # "math domain error" or a ZeroDivisionError (init_lambda == lambda_max).
        if not lambda_min < init_lambda < lambda_max:
            raise ValueError(
                f"init_lambda={init_lambda} must lie strictly inside "
                f"lambda_bounds=({lambda_min}, {lambda_max})"
            )
        self.channels = channels
        self.lambda_min, self.lambda_max = lambda_min, lambda_max

        # main branch: grouped 3x3 conv (+ norm)
        self.main = nn.Sequential(
            nn.Conv2d(channels, channels, 3, 1, 1,
                      groups=max(1, channels // 4), bias=False),
            norm_factory(channels) if use_bn else nn.Identity(),
        )

        # aux branch: 1x1 bottleneck with SiLU
        mid = max(channels // reduction, 32)
        self.aux = nn.Sequential(
            nn.Conv2d(channels, mid, 1, bias=False),
            norm_factory(mid) if use_bn else nn.Identity(),
            nn.SiLU(inplace=True),
            nn.Conv2d(mid, channels, 1, bias=False),
            norm_factory(channels) if use_bn else nn.Identity(),
        )

        # Aux channel adapter, created once here (identity when channel counts already match).
        if aux_channels is None or aux_channels == channels:
            self.aux_adapter = nn.Identity()
        else:
            self.aux_adapter = nn.Conv2d(aux_channels, channels, 1, bias=False)
            nn.init.kaiming_normal_(self.aux_adapter.weight, mode='fan_out', nonlinearity='relu')

        # λ (learnable): store the logit of init_lambda's relative position inside the
        # bounds, so that `lambda_val` equals init_lambda at initialization.
        init_logit = math.log((init_lambda - lambda_min) / (lambda_max - init_lambda))
        self._lambda_logit = nn.Parameter(torch.tensor(init_logit))

        # Per-channel gate on the aux term: a raw multiplier (no sigmoid), initialized to 0.5.
        self.gate = nn.Parameter(torch.ones(1, channels, 1, 1) * 0.5)

        self.dp = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self._init_weights()

    @property
    def lambda_val(self) -> torch.Tensor:
        """Current λ, mapped into [lambda_min, lambda_max] by a sigmoid."""
        s = torch.sigmoid(self._lambda_logit)
        return self.lambda_min + (self.lambda_max - self.lambda_min) * s

    def _init_weights(self):
        """Kaiming-init every conv (zero bias) and reset affine norm params to 1/0."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.BatchNorm2d, nn.LayerNorm)):
                # Non-affine norms register weight/bias as None: skip them instead of crashing.
                if getattr(m, 'weight', None) is not None:
                    nn.init.ones_(m.weight)
                if getattr(m, 'bias', None) is not None:
                    nn.init.zeros_(m.bias)

    def _align_spatial(self, feat: torch.Tensor, target_hw: tuple[int, int]) -> torch.Tensor:
        """Bilinearly resize ``feat`` to ``target_hw`` when its spatial size differs."""
        if feat.shape[-2:] != target_hw:
            feat = F.interpolate(feat, size=target_hw, mode='bilinear', align_corners=False)
        return feat

    def forward(self, x: torch.Tensor, aux_input: torch.Tensor | None = None):
        main_out = self.main(x)

        # Auxiliary injection only while training and only when an aux feature is supplied.
        if aux_input is not None and self.training:
            aux = self.aux_adapter(aux_input)
            aux = self._align_spatial(aux, x.shape[-2:])
            aux = self.aux(aux) * self.gate
            main_out = main_out + self.lambda_val * aux

        return x + self.dp(main_out)

def get_drop_path_rates(num_layers: int, max_rate: float) -> List[float]:
    """Per-layer DropPath rates rising from 0 to ``max_rate`` along a half-cosine curve.

    Layer ``i`` of ``n`` gets ``max_rate * (1 - cos(pi * i / (n - 1))) * 0.5``, so the
    first layer is 0.0 and the last is ``max_rate``.

    Args:
        num_layers: total number of layers; values <= 1 yield ``[0.0] * num_layers``.
        max_rate: DropPath rate of the last layer.

    Returns:
        List with one rate per layer.
    """
    if num_layers <= 1:
        return [0.0] * num_layers

    last_index = num_layers - 1
    schedule: List[float] = []
    for layer_idx in range(num_layers):
        ramp = 1.0 - math.cos(math.pi * layer_idx / last_index)
        schedule.append(max_rate * ramp * 0.5)
    return schedule


def init_weights_improved(model: nn.Module):
    """Re-initialize, in place, every supported sub-module of ``model``.

    * Conv2d: depthwise/grouped convs (``groups > 1`` or ``'dw'`` in the module name)
      get Kaiming-normal with ``fan_in`` / linear gain. Every other conv (pointwise or
      regular) gets Kaiming-normal with ``fan_out`` / ReLU gain. Biases are zeroed.
    * BatchNorm2d / LayerNormProxy: weight → 1, bias → 0. Layers without affine
      parameters (weight/bias registered as None) are skipped.
    * Linear: truncated normal (std 0.02), bias zeroed.
    * Embedding: normal (std 0.02).

    Note: this overwrites any initialization these module types received in their
    constructors.
    """
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d):
            if 'dw' in name or module.groups > 1:  # depthwise / grouped
                nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='linear')
            else:  # pointwise and regular convs share the same init
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')

            if module.bias is not None:
                nn.init.zeros_(module.bias)

        elif isinstance(module, (nn.BatchNorm2d, LayerNormProxy)):
            # affine=False norms register weight/bias as None; nn.init would crash on them.
            if getattr(module, 'weight', None) is not None:
                nn.init.ones_(module.weight)
            if getattr(module, 'bias', None) is not None:
                nn.init.zeros_(module.bias)

        elif isinstance(module, nn.Linear):
            nn.init.trunc_normal_(module.weight, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, std=0.02)


# -----------------------------------------------------------------------------
# 4. Channel-wise Attention (SE)
# -----------------------------------------------------------------------------
class ChannelAttention(nn.Module):
    """Squeeze-and-Excitation channel gate blended with an identity path.

    ``out = x * (1 - s) + x * g * s`` where
    ``g = sigmoid(fc2(relu(fc1(avgpool(x)))))`` is a per-channel gate and
    ``s = sigmoid(residual_scale)`` is a learnable blend factor. ``residual_scale``
    starts at -2, so ``s ≈ 0.12`` and the block starts close to identity.

    Args:
        channels: number of input/output channels.
        reduction: bottleneck reduction; hidden width is ``max(channels // reduction, 16)``.
    """

    def __init__(self, channels: int, reduction: int = 8):
        super().__init__()
        hidden = max(channels // reduction, 16)

        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Conv2d(channels, hidden, 1, bias=True)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Conv2d(hidden, channels, 1, bias=True)
        self.sigmoid = nn.Sigmoid()

        # Kaiming weights with the gain of the activation that follows each conv; zero biases.
        for conv, gain in ((self.fc1, 'relu'), (self.fc2, 'sigmoid')):
            nn.init.kaiming_normal_(conv.weight, mode='fan_out', nonlinearity=gain)
            nn.init.zeros_(conv.bias)

        # Blend logit: sigmoid(-2.0) ≈ 0.12.
        self.residual_scale = nn.Parameter(torch.tensor(-2.0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        squeezed = self.avgpool(x)
        gate = self.sigmoid(self.fc2(self.relu(self.fc1(squeezed))))

        blend = torch.sigmoid(self.residual_scale)   # in (0, 1)
        identity_part = x * (1 - blend)
        return identity_part + x * gate * blend


# 5. Gradient-clip helpers
# -----------------------------------------------------------------------------

def clip_dcnv4_grads(model: nn.Module, max_norm: float = 0.2) -> None:
    """Clip, as a single group, the gradients of every parameter whose (lower-cased)
    name contains ``"offset"`` or ``"mask"`` to a total L2 norm of ``max_norm``.

    Parameters without a gradient are ignored; if none match, nothing happens.
    Note the name match is global: any parameter in ``model`` with such a name is
    included, not only DCNv4 ones.
    """
    keywords = ("offset", "mask")
    selected = [
        param
        for pname, param in model.named_parameters()
        if param.grad is not None and any(k in pname.lower() for k in keywords)
    ]
    if not selected:
        return
    torch.nn.utils.clip_grad_norm_(selected, max_norm=max_norm)

# -----------------------------------------------------------------------------
# Helper classes
# -----------------------------------------------------------------------------

def fix_zero_init_params(model):
    """Re-initialize every non-scalar parameter whose entries are all exactly zero.

    Parameters whose name contains ``'bias'`` are set to the constant 0.01; all
    others are drawn from N(0, 0.01²). Each fixed parameter name is printed.
    Scalar (0-dim) parameters are left untouched.

    NOTE(review): this also overrides biases that were zero-initialized on purpose.
    """
    for pname, tensor in model.named_parameters():
        is_all_zero = tensor.dim() > 0 and torch.all(tensor == 0)
        if not is_all_zero:
            continue
        print(f"Fixing zero-initialized parameter: {pname}")
        if 'bias' in pname:
            nn.init.constant_(tensor, 0.01)
        else:
            nn.init.normal_(tensor, std=0.01)

class _CheckpointedSequential(nn.Sequential):
    """nn.Sequential whose children each run under activation checkpointing.

    Subclassing nn.Sequential keeps the child names ("0", "1", ...) and therefore
    the state_dict keys identical to those of the un-wrapped stage.
    """

    def forward(self, x):
        # Checkpointing only pays off when a backward pass will follow.
        use_ckpt = torch.is_grad_enabled()
        for layer in self:
            x = checkpoint(layer, x, use_reentrant=False) if use_ckpt else layer(x)
        return x


def apply_gradient_checkpointing(model):
    """Wrap ``model.backbone.st0 .. st3`` so each block is activation-checkpointed.

    Only stages that are ``nn.Sequential`` / ``nn.ModuleList`` are wrapped; missing
    stages are skipped. Child names are preserved, so checkpoints saved with or
    without this wrapper load into either model. Calling the function again is a no-op
    for already-wrapped stages.
    """
    from collections import OrderedDict

    backbone = getattr(model, 'backbone', None)
    if backbone is None:
        return
    for stage_name in ('st0', 'st1', 'st2', 'st3'):
        stage = getattr(backbone, stage_name, None)
        if stage is None or isinstance(stage, _CheckpointedSequential):
            continue
        if not isinstance(stage, (nn.Sequential, nn.ModuleList)):
            continue
        wrapped = _CheckpointedSequential(OrderedDict(stage.named_children()))
        setattr(backbone, stage_name, wrapped)

def init_conv(layer: nn.Conv2d, fan_out: bool = True):
    """Kaiming-normal (ReLU gain) init of ``layer.weight``; zero ``layer.bias`` if present.

    Args:
        layer: convolution to initialize in place.
        fan_out: use ``mode="fan_out"`` when True, ``mode="fan_in"`` otherwise.
    """
    mode = "fan_out" if fan_out else "fan_in"
    nn.init.kaiming_normal_(layer.weight, mode=mode, nonlinearity="relu")

    bias = layer.bias
    if bias is not None:
        nn.init.zeros_(bias)


class LNAct(nn.Module):
    """Normalization followed by activation: ``silu(LayerNormProxy(x))``."""

    def __init__(self, channels: int):
        super().__init__()
        self.norm = LayerNormProxy(channels)
        self.act = nn.SiLU(inplace=True)

    def forward(self, x):
        normalized = self.norm(x)
        return self.act(normalized)


class ConvLNAct(nn.Sequential):
    """Bias-free Conv2d followed by ``LNAct`` (LayerNormProxy + SiLU).

    Args:
        in_ch, out_ch: input / output channels.
        k: kernel size.
        s: stride.
        p: padding; ``None`` means ``k // 2`` (size-preserving for odd ``k`` at stride 1).
        g: conv groups.
    """
    def __init__(self, in_ch: int, out_ch: int, k: int = 3,
                 s: int = 1, p: int = None, g: int = 1):
        padding = k // 2 if p is None else p
        conv = nn.Conv2d(in_ch, out_ch, k, s, padding, groups=g, bias=False)
        super().__init__(conv, LNAct(out_ch))
        nn.init.kaiming_normal_(conv.weight, mode="fan_out", nonlinearity="relu")


# -----------------------------------------------------------------------------
# LayerScale
# -----------------------------------------------------------------------------
class LayerScale(nn.Module):
    """
    Learnable scalar per layer (γ) – Broadcast-Safe version.
    Works with both (B,C,H,W) and (B,N,C) tensor layouts.
    """
    def __init__(
        self,
        dim: int,
        inplace: bool = False,
        init_values: float = 1.0,
        depth: int | None = None,
    ):
        super().__init__()
        self.inplace = inplace
        if depth is not None:
            init_values = min(1.0, 0.5 * depth / 3)
        self.weight = nn.Parameter(torch.ones(dim) * init_values)

    def _shape_for(self, x: torch.Tensor) -> tuple[int, ...]:
        """
        Determines the shape to broadcast the weight based on x's placement.
        Priority: channel-end (… , C) → channel-first (B, C, …) → singular cases.
        """
        C = self.weight.numel()
        nd = x.dim()

        # 4D  image: (B, C, H, W) veya (B, H, W, C)
        if nd == 4:
            if x.shape[1] == C:          # (B, C, H, W)
                return (1, C, 1, 1)
            if x.shape[-1] == C:         # (B, H, W, C) - nadir
                return (1, 1, 1, C)

        # 3D: (B, N, C) or (B, C, N) or (C, H, W)
        if nd == 3:
            if x.shape[-1] == C:         # (B, N, C)
                return (1, 1, C)
            if x.shape[1] == C:          # (B, C, N)
                return (1, C, 1)
            if x.shape[0] == C:          # (C, H, W) / (C, N, ?)such as extreme cases
                return (C, 1, 1)

        # 2D: (B, C) veya (C, B)
        if nd == 2:
            if x.shape[-1] == C:         # (B, C)
                return (1, C)
            if x.shape[0] == C:          # (C, B)
                return (C, 1)


        shape = [1] * max(1, nd)
        shape[-1] = C
        return tuple(shape)

    def forward(self, x: torch.Tensor):
        w = self.weight.view(*self._shape_for(x))
        if self.inplace:
            return x.mul_(w)
        return x * w

# Point DAT's module-level `LayerScale` name at the broadcast-safe class defined above, so
# DAT code that looks the name up in `DAT.models.dat` after this line gets this version.
# NOTE(review): already-constructed modules, and modules that bound LayerScale through their
# own import, are not affected — confirm DAT resolves the name via `DAT.models.dat` at build time.
dat_mod.LayerScale = LayerScale


# -----------------------------------------------------------------------------
# Stem
# -----------------------------------------------------------------------------

class Stem(nn.Module):
    """Stem with three bias-free 3x3 convs, strides (2, 1, 2): overall /4 downsampling.

    conv1 (3 → out/2, s2) → GELU → conv2 (out/2 → out/2, s1) → GELU →
    conv3 (out/2 → out, s2) → LayerNormProxy.
    """
    def __init__(self, out_channels: int = 128):
        super().__init__()
        half = out_channels // 2

        self.conv1 = nn.Conv2d(3, half, kernel_size=3, stride=2, padding=1, bias=False)
        self.conv2 = nn.Conv2d(half, half, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv3 = nn.Conv2d(half, out_channels, kernel_size=3, stride=2, padding=1, bias=False)

        self.norm = LayerNormProxy(out_channels)
        self.act = nn.GELU()

        self._init_weights()

    def _init_weights(self):
        """Kaiming-normal (fan_out, ReLU gain) for the three convs."""
        convs = (self.conv1, self.conv2, self.conv3)
        for conv in convs:
            nn.init.kaiming_normal_(conv.weight, mode="fan_out", nonlinearity="relu")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The two leading convs are each followed by GELU; the last one by the norm only.
        for conv in (self.conv1, self.conv2):
            x = self.act(conv(x))
        return self.norm(self.conv3(x))


# -----------------------------------------------------------------------------
# DCNv4Lite
# -----------------------------------------------------------------------------
## FULL DCNV4


class DCNv4Lite(nn.Module):
    """
    (B, C, H, W) wrapper around DCNv4, which operates on (B, H*W, C) tokens.

    • Pure DCNv4 path (no baseline conv / residual mix).
    • Offset/mask parameters are re-initialized with a small-std normal.
    • Entries from the older mixed implementation that sit directly under this
      module (``mix_weight``, ``mix_logit``, ``_mix_*``, ``conv.weight``,
      ``conv.bias``) are silently dropped when loading a state dict.

    Args:
        channels (int)            : number of input/output channels
        group (int)               : DCNv4 group count
        kernel_size (int)         : kernel size
        stride (int)              : stride
        pad (int)                 : padding
        dilation (int)            : dilation
        offset_scale (float)      : DCNv4 offset scale
        without_pointwise (bool)  : forwarded to DCNv4 (default: True)
        init_offset_mask_std (float): std of the offset/mask init, floored at 1e-5 (default: 0.01)
        **kwargs                  : forwarded unchanged to DCNv4
    """

    # Keys (relative to this module) written by the old mixed conv/DCN implementation.
    _OBSOLETE_KEYS = (
        "mix_weight", "mix_logit", "_mix_tau", "_mix_eps", "_mix_clamp",
        "conv.weight", "conv.bias",
    )

    def __init__(self,
                 channels: int,
                 group: int = 8,
                 kernel_size: int = 3,
                 stride: int = 1,
                 pad: int = 1,
                 dilation: int = 1,
                 offset_scale: float = 0.1,
                 *,
                 without_pointwise: bool = True,
                 init_offset_mask_std: float = 0.01,
                 **kwargs):
        super().__init__()
        self.channels = int(channels)

        self.dcn = DCNv4(
            channels=channels, kernel_size=kernel_size, stride=stride, pad=pad,
            dilation=dilation, group=group, offset_scale=offset_scale,
            center_feature_scale=False, remove_center=False,
            output_bias=True, without_pointwise=without_pointwise, **kwargs
        )

        self._init_dcn_params(std=init_offset_mask_std)

        # Geometry kept only for extra_repr (debugging).
        self._geom = dict(k=kernel_size, s=stride, p=pad, d=dilation, g=group,
                          wop=without_pointwise)

    # ---------------- init helpers ----------------
    @staticmethod
    def _is_offset_mask_name(name: str) -> bool:
        """True for parameter names containing 'offset' or 'mask' (case-insensitive)."""
        lowered = name.lower()
        return ('offset' in lowered) or ('mask' in lowered)

    def _init_dcn_params(self, std: float = 0.01):
        """
        Re-draw the DCNv4 offset/mask parameters from N(0, std²); all other DCNv4
        parameters keep DCNv4's own initialization.
        """
        std = float(max(std, 1e-5))
        for n, p in self.dcn.named_parameters():
            if self._is_offset_mask_name(n):
                nn.init.normal_(p, std=std)

    # ---------------- forward ----------------------
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (B, C, H, W) → y: (B, C, H, W), via DCNv4's (B, H*W, C) token layout.

        NOTE(review): the reshape back to (B, H, W, C) assumes DCNv4 returns H*W
        tokens, i.e. a size-preserving geometry such as the default k=3, s=1, p=1, d=1.
        """
        B, C, H, W = x.shape
        x3d = x.permute(0, 2, 3, 1).contiguous().view(B, H * W, C)   # (B, HW, C)
        y3d = self.dcn(x3d, shape=(H, W))                            # (B, HW, C)
        y   = y3d.view(B, H, W, C).permute(0, 3, 1, 2).contiguous()  # (B, C, H, W)
        return y

    # --------------- helpers -------------------
    @torch.no_grad()
    def freeze_offsets(self, flag: bool = True):
        """Freeze (flag=True) or unfreeze the offset/mask parameters, e.g. for warm-up."""
        for n, p in self.dcn.named_parameters():
            if self._is_offset_mask_name(n):
                p.requires_grad = (not flag)

    def offset_mask_parameters(self):
        """Yield the offset/mask parameters (e.g. for a separate optimizer LR/clip group)."""
        for n, p in self.dcn.named_parameters():
            if self._is_offset_mask_name(n):
                yield p

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        """Drop obsolete keys of the old implementation, then load normally."""
        # Exact match on keys that belong directly to this module. A suffix test would also
        # discard legitimate nested keys such as "<prefix>dcn.<name>conv.weight".
        for obsolete in self._OBSOLETE_KEYS:
            state_dict.pop(prefix + obsolete, None)
        return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

    def extra_repr(self) -> str:
        g = self._geom
        return (f"channels={self.channels}, k={g['k']}, s={g['s']}, p={g['p']}, "
                f"d={g['d']}, group={g['g']}, without_pointwise={g['wop']}")


def boost_relative_position_bias(model: nn.Module, std: float = 0.02):
    """
    Re-draw every ``relative_position_bias_table`` found on any sub-module of
    ``model`` (including ``model`` itself) from a truncated normal with the given std.
    """
    holders = (mod for mod in model.modules() if hasattr(mod, "relative_position_bias_table"))
    for holder in holders:
        nn.init.trunc_normal_(holder.relative_position_bias_table, std=std)

def init_decoder_sampling_offsets(decoder: nn.Module, bias_val: float = 0.1):
    """
    Fill every parameter of ``decoder`` whose name contains ``"sampling_offsets.bias"``
    with the constant ``bias_val``. The intent is that a small non-zero shift gives the
    deformable sampling offsets a gradient signal from the start.
    """
    targets = [param for pname, param in decoder.named_parameters()
               if "sampling_offsets.bias" in pname]
    for param in targets:
        nn.init.constant_(param, bias_val)



def apply_hybrid_fixup(model: nn.Module, rel_pos_std: float = 0.02,
                       sampling_offset_bias: float = 0.1):
    """
    Apply post-construction re-initializations to ``model``:

    a) re-draw every ``relative_position_bias_table`` (truncated normal, ``rel_pos_std``);
    b) if ``model`` has a ``decoder`` attribute, set every ``sampling_offsets.bias``
       inside it to ``sampling_offset_bias``.

    A DCNv4 offset/mask learning-rate boost is not applied here; build a separate
    optimizer parameter group (e.g. from ``DCNv4Lite.offset_mask_parameters()``) for that.

    Args:
        model: model to modify in place.
        rel_pos_std: std for the relative-position bias tables (default 0.02).
        sampling_offset_bias: constant for the decoder sampling-offset biases (default 0.1).
    """
    boost_relative_position_bias(model, std=rel_pos_std)
    if hasattr(model, "decoder"):
        init_decoder_sampling_offsets(model.decoder, sampling_offset_bias)
# -----------------------------------------------------------------------------
# Stage-0 Block
# -----------------------------------------------------------------------------

class StageZeroBlock(nn.Module):
    """Stage-0 residual block: ``x + DropPath(LayerScale(norm(SE(conv3x3(x)))))``.

    Args:
        channels: input/output channels.
        drop_path_rate: DropPath rate (0 → identity).
        use_layer_scale: wrap the branch in LayerScale (otherwise identity).
        layer_scale_init: LayerScale initial value.
        norm_factory: builds the norm layer.
    """
    def __init__(self, channels=128, drop_path_rate=0.0,
                 use_layer_scale=True, layer_scale_init=1.0,
                 norm_factory=NormFactory("gn")):
        super().__init__()
        self.conv3 = nn.Conv2d(channels, channels, 3, 1, 1, bias=False)
        nn.init.kaiming_normal_(self.conv3.weight, mode='fan_out', nonlinearity='relu')

        self.se = ChannelAttention(channels, reduction=8)
        self.norm = norm_factory(channels)
        if use_layer_scale:
            self.scale = LayerScale(channels, init_values=layer_scale_init)
        else:
            self.scale = nn.Identity()
        self.dp = DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()

    def forward(self, x):
        branch = self.norm(self.se(self.conv3(x)))
        return x + self.dp(self.scale(branch))

# -----------------------------------------------------------------------------
# Stage Blocks
# -----------------------------------------------------------------------------

class StageOneBlock(nn.Module):
    """
    Stage-1 residual block: ``x + DropPath(LayerScale(norm(SE(DCNv4Lite(x)))))``.

    The deformable conv uses ``group=8``. LayerScale becomes identity when
    ``use_layer_scale`` is False, and DropPath becomes identity when ``drop_path`` is 0.
    """
    def __init__(self,
                channels: int = 128,
                drop_path: float = 0.0,
                use_layer_scale: bool = True,
                layer_scale_init: float = 1.0,
                norm_factory: NormFactory = NormFactory("gn")):
        super().__init__()
        self.dcn = DCNv4Lite(channels, group=8)
        self.se = ChannelAttention(channels, reduction=8)
        self.post_norm = norm_factory(channels)

        if use_layer_scale:
            self.scale = LayerScale(channels, init_values=layer_scale_init)
        else:
            self.scale = nn.Identity()

        self.dp = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, x):
        branch = self.dcn(x)
        branch = self.post_norm(self.se(branch))
        return x + self.dp(self.scale(branch))


class DownABlock(nn.Module):
    """
    Downsampling block: depthwise 3x3 conv (stride 2) → pointwise 1x1 conv
    (in_ch → out_ch) → norm → SiLU.

    When ``norm_factory`` produces a ``LayerNormProxy``, norm + activation is this
    notebook's ``LNAct`` (LayerNormProxy + SiLU). Otherwise it is
    ``nn.Sequential(norm, SiLU)``.
    """
    def __init__(self,
                in_ch: int = 128,
                out_ch: int = 256,
                norm_factory: NormFactory = NormFactory("gn")):
        super().__init__()

        # 1) Depth-wise strided conv
        self.dw = nn.Conv2d(in_ch, in_ch, 3, stride=2, padding=1,
                            groups=in_ch, bias=False)

        # 2) Point-wise projection
        self.pw = nn.Conv2d(in_ch, out_ch, 1, bias=False)

        # 3) Normalization + activation
        nf_layer = norm_factory(out_ch)
        if isinstance(nf_layer, LayerNormProxy):
            # Use the LNAct defined in this notebook (LayerNormProxy + SiLU). It was
            # previously imported from DAT.models.dat_blocks, which is not where it lives.
            self.norm_act = LNAct(out_ch)
        else:
            self.norm_act = nn.Sequential(nf_layer, nn.SiLU(inplace=True))

        # --- Kaiming init ---
        nn.init.kaiming_normal_(self.dw.weight, mode="fan_out", nonlinearity="relu")
        nn.init.kaiming_normal_(self.pw.weight, mode="fan_out", nonlinearity="relu")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.dw(x)
        x = self.pw(x)
        return self.norm_act(x)


class StageTwoBlock(nn.Module):
    """
    Stage-2 CSP-style residual block using only the DCNv4Lite path (no DAT, no gate fusion).

    The input is split in half along channels:
      * first half  → 1x1 conv + norm (``skip_proj``);
      * second half → DCNv4Lite (group = max(1, dcn_group // 2)) + norm + SiLU.
    The halves are concatenated, fused by a 1x1 conv + norm + SiLU, and added back to
    the input through DropPath. The split keeps FLOPs/latency comparable.
    """
    def __init__(self,
                 channels: int = 256,
                 dcn_group: int = 16,
                 drop_path: float = 0.1,
                 norm_factory: NormFactory = NormFactory("gn")):
        super().__init__()
        assert channels % 2 == 0, "StageTwoBlock expects even channel count."
        half = channels // 2
        self.split_channels = half

        # deformable branch
        self.dcn = DCNv4Lite(half, group=max(1, dcn_group // 2))
        self.post_norm = norm_factory(half)
        self.post_act = nn.SiLU(inplace=True)

        # skip branch (projected so it carries learnable weights / gradient)
        self.skip_proj = nn.Sequential(
            nn.Conv2d(half, half, 1, bias=False),
            norm_factory(half),
        )

        # fusion of the two halves
        self.fusion = nn.Conv2d(channels, channels, 1, bias=False)
        self.ln = norm_factory(channels)
        self.act = nn.SiLU(inplace=True)
        self.dp = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        nn.init.kaiming_normal_(self.fusion.weight, mode="fan_out", nonlinearity="relu")
        nn.init.kaiming_normal_(self.skip_proj[0].weight, mode="fan_out", nonlinearity="relu")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        skip_half, dcn_half = torch.split(x, self.split_channels, dim=1)

        skip_half = self.skip_proj(skip_half)
        dcn_half = self.post_act(self.post_norm(self.dcn(dcn_half)))

        merged = torch.cat([skip_half, dcn_half], dim=1)
        fused = self.act(self.ln(self.fusion(merged)))
        return x + self.dp(fused)

class DownBBlock(nn.Module):
    """Downsampling: depthwise 3x3 (stride 2) → pointwise 1x1 (default 256 → 640) → norm → SiLU."""
    def __init__(self, in_ch: int = 256, out_ch: int = 640, norm_factory: NormFactory = NormFactory("gn")):
        super().__init__()
        self.dw = nn.Conv2d(in_ch, in_ch, 3, stride=2, padding=1,
                            groups=in_ch, bias=False)
        self.pw = nn.Conv2d(in_ch, out_ch, 1, bias=False)
        self.norm = norm_factory(out_ch)
        self.act = nn.SiLU(inplace=True)

        nn.init.kaiming_normal_(self.dw.weight, mode="fan_out", nonlinearity="relu")
        nn.init.kaiming_normal_(self.pw.weight, mode="fan_out", nonlinearity="relu")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.dw(x)
        y = self.pw(y)
        y = self.norm(y)
        return self.act(y)


class StageThreeBlock(nn.Module):
    """Stage-3 block: 1x1 projection + norm + SiLU, two residual DATBlocks, then ChannelAttention.

    Each DATBlock uses ``max(4, out_ch // 64)`` heads and its own DropPath(drop_path).
    NOTE(review): ``DATBlock`` is not imported in the visible imports; it must be
    defined before this class is instantiated.
    """
    def __init__(self, in_ch: int = 640, out_ch: int = 640, drop_path: float = 0.15, norm_factory: NormFactory = NormFactory("gn")):
        super().__init__()
        self.proj = nn.Conv2d(in_ch, out_ch, 1, bias=False)
        self.norm = norm_factory(out_ch)
        self.act = nn.SiLU(inplace=True)

        num_heads = max(4, out_ch // 64)
        self.dat1 = DATBlock(out_ch, heads=num_heads)
        self.dat2 = DATBlock(out_ch, heads=num_heads)
        self.dp1 = DropPath(drop_path)
        self.dp2 = DropPath(drop_path)
        self.ca = ChannelAttention(out_ch, reduction=8)

        nn.init.kaiming_normal_(self.proj.weight, mode="fan_out", nonlinearity="relu")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.act(self.norm(self.proj(x)))
        # Two pre-computed residual steps: x <- DropPath(DATBlock(x)) + x
        for block, drop in ((self.dat1, self.dp1), (self.dat2, self.dp2)):
            x = drop(block(x)) + x
        return self.ca(x)


class DownCBlock(nn.Module):
    """Downsampling: depthwise 3x3 (stride 2) → pointwise 1x1 (default 640 → 768) → norm → SiLU."""
    def __init__(self, in_ch: int = 640, out_ch: int = 768, norm_factory: NormFactory = NormFactory("gn")):
        super().__init__()
        self.dw = nn.Conv2d(in_ch, in_ch, 3, stride=2, padding=1,
                            groups=in_ch, bias=False)
        self.pw = nn.Conv2d(in_ch, out_ch, 1, bias=False)
        self.norm = norm_factory(out_ch)
        self.act = nn.SiLU(inplace=True)

        for conv in (self.dw, self.pw):
            nn.init.kaiming_normal_(conv.weight, mode="fan_out", nonlinearity="relu")

    def forward(self, x):
        return self.act(self.norm(self.pw(self.dw(x))))
# -----------------------------------------------------------------------------
# StageAwareBackbone
# -----------------------------------------------------------------------------

class StageAwareBackbone(nn.Module):
    """
    Four-tier backbone based on DAT + DCN.

    Data flow (channel counts as passed to the constructors below):
      stem(128) → st0 (StageZeroBlock × depths[0])
                → st1 (StageOneBlock × depths[1]) → PGI(aux = st0 output) → p3 (128c)
      p3 → DownA(128→256) → cross-scale injection with p3 → PGI(aux = p3)
         → st2 (StageTwoBlock × depths[2]) → p4 (256c)
      p4 → DownB(256→640) → PGI(aux = p4) → st3 (StageThreeBlock × depths[3]) → p5 (640c)
      p5 → DownC(640→768) → p6 (768c)

    Args:
        depths: number of blocks in st0, st1, st2, st3 (exactly four entries are indexed).
        drop_path_max: maximum rate handed to get_drop_path_rates(); the rates of
            every st0 / st1 block are then forced to 0.
        num_classes: number of class channels in each auxiliary head.
        voc_prior: selects the initial class-bias prior of the aux heads
            (True → bias = -log(19), i.e. sigmoid 0.05; False → -log(99), i.e. sigmoid 0.01).
        norm_factory: normalisation builder shared by all stages.
        use_layer_scale, layer_scale_init: forwarded to the st0 / st1 blocks only.

    forward(x, need_aux=False) returns (p3, p4, p5, p6). In training mode with
    need_aux=True it returns ((p3, p4, p5, p6), aux), where aux maps 's1'/'s2'/'s3'
    to dense (num_classes + 4)-channel maps computed on p3/p4/p5
    (presumably 4 box-regression channels — confirm against the aux loss).
    """
    def __init__(self,
                depths: list[int] = (3, 4, 8, 3),
                drop_path_max: float = 0.0,
                num_classes: int = 80,
                voc_prior: bool = False,
                norm_factory: NormFactory = NormFactory("gn"),
                use_layer_scale: bool = True,
                layer_scale_init: float = 1.0):
        super().__init__()
        self.num_classes = num_classes
        self.voc_prior   = voc_prior

        # ---- DropPath rates ----
        # One rate per block across all four stages; the two shallow stages are
        # explicitly given no stochastic depth.
        total_blocks = sum(depths)
        dpr = get_drop_path_rates(total_blocks, drop_path_max)
        for i in range(depths[0] + depths[1]):   # st0, st1
            dpr[i] = 0.0

        # ---- Stem ----
        self.stem = Stem(128)

        # ---- Stage‑0 ----
        # dp_idx is the running offset into dpr for the current stage.
        dp_idx = 0
        self.st0 = nn.Sequential(*[
          StageZeroBlock(
              128,
              drop_path_rate=dpr[i],
              use_layer_scale=use_layer_scale,
              layer_scale_init=layer_scale_init,
              norm_factory=norm_factory
          )
          for i in range(depths[0])
      ])
        dp_idx += depths[0]

        # ---- Stage‑1 ----
        self.st1 = nn.Sequential(*[
            StageOneBlock(
                128,
                drop_path=dpr[dp_idx + i],
                use_layer_scale=use_layer_scale,
                layer_scale_init=layer_scale_init,
                norm_factory=norm_factory)
            for i in range(depths[1])
        ])
        dp_idx += depths[1]

        # ---- PGI + Down/CSI ----
        self.pgi_s1 = PGIModule(128, norm_factory=norm_factory, aux_channels=128)   # s0_out → s1_raw
        self.da     = DownABlock(128, 256, norm_factory=norm_factory)
        self.csi_s1_s2 = CrossScaleInjection(low_ch=128, high_ch=256)
        self.pgi_s2 = PGIModule(256, norm_factory=norm_factory, aux_channels=128)

        # ---- Stage‑2 ----
        self.st2 = nn.Sequential(*[
            StageTwoBlock(
                256, dcn_group=16,
                drop_path=dpr[dp_idx + i],
                norm_factory=norm_factory)
            for i in range(depths[2])
        ])
        dp_idx += depths[2]

        # ---- DownB / PGI ----
        self.db     = DownBBlock(256, 640, norm_factory=norm_factory)
        self.pgi_s3 = PGIModule(640, norm_factory=norm_factory, aux_channels=256)

        # ---- Stage‑3 ----
        self.st3 = nn.Sequential(*[
            StageThreeBlock(
                640, 640,
                drop_path=dpr[dp_idx + i],
                norm_factory=norm_factory)
            for i in range(depths[3])
        ])

        # ---- DownC ----
        self.dc = DownCBlock(640, 768, norm_factory=norm_factory)

        # ---- Auxiliary dense heads ----
        # 1×1 convs on p3 / p4 / p5 producing num_classes + 4 channels each.
        self.aux_head_s1 = nn.Conv2d(128, num_classes + 4, 1, bias=True)
        self.aux_head_s2 = nn.Conv2d(256, num_classes + 4, 1, bias=True)
        self.aux_head_s3 = nn.Conv2d(640, num_classes + 4, 1, bias=True)
        self._init_aux_heads()

    # --------------------------------------------------
    def _init_aux_heads(self):
        """Prior-probability init of the aux heads (RetinaNet style).

        Class biases are set to -log(prior_value), so sigmoid(bias) = 1 / (1 + prior_value):
        0.05 with voc_prior, 0.01 otherwise. The remaining 4 bias entries are zeroed and
        all weights are drawn from N(0, 0.01²).
        """
        prior_value = 19 if self.voc_prior else 99
        for head in (self.aux_head_s1, self.aux_head_s2, self.aux_head_s3):
            nn.init.normal_(head.weight, std=0.01)
            nn.init.constant_(head.bias[: self.num_classes], -math.log(prior_value))
            nn.init.constant_(head.bias[self.num_classes :], 0.0)

    # --------------------------------------------------
    def forward(self, x: torch.Tensor, need_aux: bool = False):
        """Run the backbone.

        Args:
            x: input image batch (presumably (B, 3, H, W) — confirm against Stem).
            need_aux: also return the auxiliary head outputs; only honoured in training mode.

        Returns:
            (p3, p4, p5, p6), or ((p3, p4, p5, p6), aux) when training and need_aux is True.
        """
        x      = self.stem(x)
        s0_out = self.st0(x)

        # Stage 1 + PGI, with the stage-0 output as the auxiliary input
        s1_raw = self.st1(s0_out)
        p3     = self.pgi_s1(s1_raw, aux_input=s0_out)

        # Downsample, inject p3 information into the coarser map, then PGI
        p3_down = self.da(p3)
        s2_in   = self.csi_s1_s2(p3, p3_down)
        s2_in   = self.pgi_s2(s2_in, aux_input=p3)

        p4 = self.st2(s2_in)
        p4_down = self.db(p4)
        s3_in   = self.pgi_s3(p4_down, aux_input=p4)

        p5 = self.st3(s3_in)
        p6 = self.dc(p5)

        if self.training and need_aux:
            aux = {
                's1': self.aux_head_s1(p3),
                's2': self.aux_head_s2(p4),
                's3': self.aux_head_s3(p5),
            }
            return (p3, p4, p5, p6), aux
        return p3, p4, p5, p6








# ----------------------------------------------------------
# SpatialFuseNode – softplus-normalized, gradient-safe
# ----------------------------------------------------------
class SpatialFuseNode(nn.Module):
    """
    Spatial-aware fuse (group-wise, RPB-friendly).

    Combines K same-shaped feature maps with per-pixel, per-group weights:
      • Group-average energy map per input → (B,G,K,H,W)
      • 1×1 grouped projection per group (K→K) → gate logits
      • softplus(logits / τ), normalised over the K axis
      • Initially fully uniform: proj.weight=0, bias=0 → equal w for each branch
      • Learnable temperature τ (clamped to tau_bounds), optional uniform floor (no dead branches)
      • The forward path is never detached, so gradients reach every input and the
        projection; only the optional copy saved when save_last=True is detached.

    Args:
        n_inputs (int): K, how many features to combine (>=2)
        channels (int): C, number of channels
        groups (int): G, number of groups (automatically drops to 1 if C % G ≠ 0)
        tau_init (float): initial τ (recommended: 0.9)
        learnable_tau (bool): whether τ is a trainable parameter (otherwise a non-persistent buffer)
        eps (float): numerical epsilon added to the normaliser of the K weights
        init_noise (float): backward-compatibility parameter; not used
        gate_floor (float): w ← (1-α)·w + α·(1/K); α∈[0,0.1] recommended, 0 disables it
        tau_bounds (tuple): τ lower/upper clamp bounds (e.g., (0.5, 2.0))
        uniform_init (bool): True ⇒ weight=0, bias=0 (fully uniform initial gates)
        save_last (bool): if True, the last w is stored (detached) for inspection
        prenorm (str): "none" | "rms" — RMS prenorm over the K axis (reduces saturation)

    Notes:
      • extra_loss() returns the mean normalised gate entropy of the last saved weights.
        NOTE(review): the saved weights are detached in forward(), so this value carries
        no gradient — adding it to the training loss does not regularise the gates.
        Confirm whether it is meant for logging only.
      • The last w (B,G,K,H,W) can be inspected with get_last_weights() (if save_last=True).
      • Because eps is added to the normaliser, the K weights sum to slightly less than 1.
    """
    def __init__(self,
                 n_inputs: int,
                 channels: int,
                 groups: int = 4,
                 tau_init: float = 0.9,
                 learnable_tau: bool = True,
                 eps: float = 5e-4,
                 init_noise: float = 1e-3,
                 *,
                 gate_floor: float = 0.0,
                 tau_bounds: tuple = (0.5, 2.0),
                 uniform_init: bool = True,
                 save_last: bool = False,
                 prenorm: str = "none"):
        super().__init__()
        assert n_inputs >= 2, "SpatialFuseNode: should be n_inputs >= 2 ."
        self.K = int(n_inputs)
        self.C = int(channels)
        # Fall back to a single group when channels are not divisible by groups.
        self.G = int(groups) if (groups > 0 and channels % groups == 0) else 1
        self.gC = self.C // self.G
        self.eps = float(eps)
        self.gate_floor = float(gate_floor)
        self.tau_lo, self.tau_hi = float(tau_bounds[0]), float(tau_bounds[1])
        self.save_last = bool(save_last)
        self.prenorm = str(prenorm).lower()
        self._last_w = None  # for debugging / inspection

        # Group-wise K→K projection (produces spatially-varying logit)
        # Takes input divided into G groups (B, G*K, H, W) → Output (B, G*K, H, W)
        self.proj = nn.Conv2d(self.G * self.K, self.G * self.K, kernel_size=1,
                              groups=self.G, bias=True)

        # --- Uniform start (to prevent early arm saturation) ---
        if uniform_init:
            nn.init.zeros_(self.proj.bias)
            with torch.no_grad():
                self.proj.weight.zero_()
        else:
            nn.init.kaiming_uniform_(self.proj.weight, a=math.sqrt(5))
            nn.init.zeros_(self.proj.bias)

        # --- temperature τ (stored in log space) ---
        if learnable_tau:
            self.log_tau = nn.Parameter(torch.log(torch.tensor(float(tau_init))))
        else:
            self.register_buffer("log_tau", torch.log(torch.tensor(float(tau_init))), persistent=False)

    # ------------------ helpers ------------------
    @torch.no_grad()
    def set_tau(self, tau: float):
        """Overwrite τ in place (e.g. for a warm-up schedule); values below 1e-3 are raised to 1e-3.

        forward() still clamps the effective τ to tau_bounds.
        """
        v = max(1e-3, float(tau))
        t = torch.log(torch.tensor(v, device=self.log_tau.device, dtype=self.log_tau.dtype))
        self.log_tau.copy_(t)

    def get_last_weights(self):
        """Return the last computed w (B, G, K, H, W), or None if save_last=False / no forward yet."""
        return self._last_w

    def extra_loss(self) -> dict:
        """
        Gate-entropy statistic of the last saved weights: {"gate_entropy": mean entropy in [0, 1]}.

        Higher values mean more balanced branch weights. Returns {} if nothing has been saved.
        NOTE(review): computed from the detached copy stored in forward(), so it has no
        gradient path back to the gates.
        """
        if self._last_w is None:
            return {}
        p = self._last_w.clamp_min(1e-8)
        # Entropy: -sum_k p log p / log(K)  → [0,1]
        ent = -(p * p.log()).sum(dim=2) / math.log(self.K)   # (B,G,H,W)
        return {"gate_entropy": ent.mean()}

    def _group_pool(self, x: torch.Tensor) -> torch.Tensor:
        # (B,C,H,W) → (B,G,H,W), average per group
        B, C, H, W = x.shape
        return x.view(B, self.G, self.gC, H, W).mean(dim=2)

    # ------------------ forward ------------------
    def forward(self, *features: torch.Tensor) -> torch.Tensor:
        """Fuse K tensors of identical shape (B, C, H, W) into one (B, C, H, W) tensor."""
        # All inputs must be in the same format.
        K = len(features)
        # Assert message (Turkish): "expected {self.K} inputs, got {K}".
        assert K == self.K, f"SpatialFuseNode: {self.K} giriş bekleniyordu, {K} geldi."
        B, C, H, W = features[0].shape
        for f in features:
            # Assert message (Turkish): "all input features must have the same (B,C,H,W) shape".
            assert f.shape == (B, C, H, W), "Tüm giriş feature'lar (B,C,H,W) aynı şekil olmalı."

        # 1) Group pooling → (B,G,K,H,W)
        gp = [self._group_pool(f) for f in features]
        gp_cat = torch.stack(gp, dim=2)  # (B,G,K,H,W)

        # 2) (optional) RMS prenorm → logit scale control on the K axis
        if self.prenorm == "rms":
            rms = gp_cat.pow(2).mean(dim=2, keepdim=True).add(1e-6).sqrt()
            gp_cat = gp_cat / rms

        # 3) Projection and gate logits
        x = gp_cat.flatten(1, 2)                      # (B, G*K, H, W)
        logits = self.proj(x).view(B, self.G, self.K, H, W)

        # 4) Softplus-normalize + τ
        tau = self.log_tau.exp().clamp(self.tau_lo, self.tau_hi)
        w = F.softplus(logits / tau) + 1e-9           # (B,G,K,H,W), > 0 everywhere
        # eps in the denominator keeps the division safe; the K weights then sum to just under 1
        w = w / (w.sum(dim=2, keepdim=True) + self.eps)

        # 5) Uniform floor (no dead leg, gradient flows to every leg)
        if self.gate_floor > 0.0:
            u = 1.0 / float(self.K)
            w = (1.0 - self.gate_floor) * w + self.gate_floor * u  # no need to normalize again

        if self.save_last:
            # detached copy: for inspection only, no gradient (see extra_loss)
            self._last_w = w.detach()

        # 6) Distribute weights to channels efficiently (no repeats, memory-friendly)
        #    w_k: (B,G,1,H,W), f: (B,G,gC,H,W) → contribution (B,C,H,W)
        out = None
        for k, f in enumerate(features):
            wk = w[:, :, k, :, :].unsqueeze(2)                 # (B,G,1,H,W)
            contrib = (f.view(B, self.G, self.gC, H, W) * wk).view(B, C, H, W)
            out = contrib if out is None else (out + contrib)

        return out


class RMSNorm2d(nn.Module):
    """Parameter-free RMS normalisation over the spatial axes of a (B, C, H, W) tensor.

    Each (sample, channel) plane is divided by sqrt(mean_{H,W}(x²) + eps).
    """
    def __init__(self, eps: float = 1e-6):
        super().__init__()
        self.eps = float(eps)

    def forward(self, x):
        mean_sq = x.pow(2).mean(dim=(2, 3), keepdim=True)
        return x / torch.sqrt(mean_sq + self.eps)

def _init_laplacian_dw(dw: nn.Conv2d):
    """Overwrite a (depthwise) 3×3 conv weight with a 4-neighbour Laplacian stencil.

    Every output channel receives [[0,1,0],[1,-4,1],[0,1,0]] in input slot 0 and all
    other weight entries are zeroed. The parameter is written under ``no_grad``.
    """
    stencil = [[0., 1., 0.],
               [1., -4., 1.],
               [0., 1., 0.]]
    with torch.no_grad():
        param = dw.weight
        kernel = torch.tensor(stencil, dtype=param.dtype, device=param.device)
        replacement = torch.zeros_like(param)
        replacement[:, 0] = kernel
        param.copy_(replacement)

def _ste_boost(x: torch.Tensor, gain: float) -> torch.Tensor:
    """Scale the gradient flowing into ``x`` by (1 + gain) without changing its value.

    ``x - x.detach()`` is zero in the forward pass (for finite x) but has derivative 1
    w.r.t. x, so backward sees (1 + gain)·grad. A non-positive gain returns ``x`` itself.
    """
    if gain <= 0:
        return x
    zero_with_grad = x - x.detach()
    return x + gain * zero_with_grad

class _FPNResBlock(nn.Module):
    """Residual 3×3 conv block used inside LightBiFPN: ``x + DropPath(LayerScale(ConvLNAct(x)))``.

    Args:
        channels: input and output channel count.
        drop_path: stochastic-depth rate; 0 replaces DropPath with nn.Identity.
        norm_factory: accepted for API compatibility only; it is not forwarded to
            ConvLNAct (previously it was resolved into an unused local).
    """
    def __init__(self, channels: int, drop_path: float = 0.0, norm_factory: 'NormFactory' = None):
        super().__init__()
        self.conv = ConvLNAct(channels, channels, k=3)
        self.ls   = LayerScale(channels, init_values=0.6)
        self.dp   = DropPath(drop_path) if drop_path > 0 else nn.Identity()

    def forward(self, x):
        return x + self.dp(self.ls(self.conv(x)))

class LightBiFPN(nn.Module):
    """
    LightBiFPN v4.1 — spatial-aware, RPB-friendly BiFPN neck with shared up/edge blocks.

    Each repeat runs
      • top-down  p6→p5→p4→p3: upsample (+ refine) → SpatialFuseNode(2) → _FPNResBlock
        → + LayerScale-weighted pass-through of the lateral input;
      • bottom-up p3→p4→p5→p6: 2×2 max-pool downsampling → SpatialFuseNode(2, or 3 for p5)
        → _FPNResBlock.

    Features:
      • SpatialFuseNode (group-based, spatially varying fusion weights)
      • No-blur in the first repeat's top-down path (bilinear + norm/act only)
      • Up-refine: DW 3×3 + (optional) Laplacian residual (small LayerScale)
      • RMS prenorm of fuse inputs (optional)
      • Low DropPath schedule (0.0 for the first repeat, dp_*_second for later ones)
      • Light straight-through gradient boost for p3 during training
      • 3 up/edge blocks shared between repeats (indexed by top-down step, not by repeat)

    Args:
        in_channels: channels of (p3, p4, p5, p6) coming from the backbone.
        out: common channel width of all output levels.
        repeats: number of top-down + bottom-up passes.
        norm_factory: norm builder for the up-refine blocks (default NormFactory("gn")).
        use_spatial_fuse, init_noise: stored only; currently not used to build layers.
        fuse_groups: groups of every SpatialFuseNode.
        fuse_tau_init: stored only; fuse nodes are built with tau_init=1.4.
        fuse_learn_tau, fuse_eps: forwarded to every SpatialFuseNode.
        prenorm: "rms" to RMS-normalise fuse inputs, "none" to skip.
        no_blur_first: skip the DW/edge refine on the first repeat.
        edge_enhance: add the Laplacian edge residual in the up-refine.
        edge_ls_init: initial LayerScale of the edge residual.
        dp_top_second, dp_bot_second: DropPath rate used from the second repeat on.
        grad_boost_low: gradient gain applied to p3 during training (0 disables).

    forward(p3, p4, p5, p6) returns (p3, p4, p5, p6), each with ``out`` channels.
    """
    def __init__(
        self,
        in_channels: tuple[int,int,int,int] = (128, 256, 640, 768),
        out: int = 256,
        repeats: int = 2,
        *,
        # String annotation: ``'NormFactory' | None`` is evaluated at definition time
        # and raises TypeError (str | None) unless annotations are postponed.
        norm_factory: "NormFactory | None" = None,
        # ---- fuse ----
        use_spatial_fuse: bool = True,
        fuse_groups: int = 4,
        fuse_tau_init: float = 0.9,
        fuse_learn_tau: bool = True,
        fuse_eps: float = 5e-4,
        init_noise: float = 1e-4,
        # ---- refine & norm ----
        prenorm: str = "rms",            # "none" | "rms"
        no_blur_first: bool = True,
        edge_enhance: bool = True,
        edge_ls_init: float = 0.03,
        # ---- drop path ----
        dp_top_second: float = 0.01,
        dp_bot_second: float = 0.01,
        # ---- grad boost ----
        grad_boost_low: float = 0.1,
    ):
        super().__init__()
        in3, in4, in5, in6 = in_channels
        self.repeats = int(repeats)
        self.out = int(out)
        self.norm_factory = norm_factory or NormFactory("gn")
        self.prenorm = prenorm.lower()
        self.no_blur_first = bool(no_blur_first)
        self.edge_enhance = bool(edge_enhance)
        self.grad_boost_low = float(grad_boost_low)
        self.use_spatial_fuse = bool(use_spatial_fuse)   # stored only
        self.fuse_groups = int(fuse_groups)
        self.fuse_tau_init = float(fuse_tau_init)         # stored only (fuse nodes use 1.4)
        self.fuse_learn_tau = bool(fuse_learn_tau)
        self.fuse_eps = float(fuse_eps)
        self.init_noise = float(init_noise)               # stored only

        # 1x1 projections to the common width
        self.p3_in = ConvLNAct(in3, out, k=1, p=0)
        self.p4_in = ConvLNAct(in4, out, k=1, p=0)
        self.p5_in = ConvLNAct(in5, out, k=1, p=0)
        self.p6_in = ConvLNAct(in6, out, k=1, p=0)

        # DropPath schedule: entry 0 for the first repeat, entry 1 for every later repeat.
        # The index is clamped so repeats > 2 reuse the last rate instead of raising IndexError.
        top_dps = [0.0, float(dp_top_second)]
        bot_dps = [0.0, float(dp_bot_second)]
        self.top_convs = nn.ModuleList([
            _FPNResBlock(out, drop_path=top_dps[min(i // 3, len(top_dps) - 1)],
                         norm_factory=self.norm_factory)
            for i in range(3 * self.repeats)
        ])
        self.bot_convs = nn.ModuleList([
            _FPNResBlock(out, drop_path=bot_dps[min(i // 4, len(bot_dps) - 1)],
                         norm_factory=self.norm_factory)
            for i in range(4 * self.repeats)
        ])

        # Fuse nodes: tau_init / tau_bounds / prenorm / gate_floor are fixed here,
        # fuse_groups, fuse_learn_tau and fuse_eps come from the constructor.
        def _make_fuse(K: int):
            return SpatialFuseNode(
                K, channels=out, groups=self.fuse_groups,
                tau_init=1.4, learnable_tau=self.fuse_learn_tau,
                tau_bounds=(0.8, 1.8),
                prenorm="rms",
                gate_floor=0.02,          # uniform floor → no dead branch
                uniform_init=True,
                eps=self.fuse_eps,
                save_last=True,
            )

        self.fuse2_top = nn.ModuleList([_make_fuse(2) for _ in range(3 * self.repeats)])
        self.fuse2_bot = nn.ModuleList([_make_fuse(2) for _ in range(3 * self.repeats)])
        self.fuse3_bot = nn.ModuleList([_make_fuse(3) for _ in range(1 * self.repeats)])

        # --- 3 up/edge blocks, shared between repeats ---
        self.up_dw, self.up_na = nn.ModuleList(), nn.ModuleList()
        self.edge_dw, self.edge_ls = nn.ModuleList(), nn.ModuleList()
        for _ in range(3):  # one per top-down step
            dw = nn.Conv2d(out, out, 3, 1, 1, groups=out, bias=False)
            nn.init.kaiming_normal_(dw.weight, mode="fan_out", nonlinearity="relu")
            self.up_dw.append(dw)

            nf_layer = self.norm_factory(out)
            if isinstance(nf_layer, LayerNormProxy):
                from DAT.models.dat_blocks import LNAct as _LNAct
                self.up_na.append(_LNAct(out))
            else:
                self.up_na.append(nn.Sequential(nf_layer, nn.SiLU(inplace=True)))

            edw = nn.Conv2d(out, out, 3, 1, 1, groups=out, bias=False)
            _init_laplacian_dw(edw)
            self.edge_dw.append(edw)
            self.edge_ls.append(LayerScale(out, init_values=edge_ls_init))

        # High-level passthrough LS (p5, p4, p3)
        self.keep_top = nn.ModuleList([LayerScale(out, init_values=0.10) for _ in range(3)])

        # Input balancing
        self.balance = RMSNorm2d(eps=1e-6) if self.prenorm == "rms" else nn.Identity()

    # --- helpers ---
    @staticmethod
    def _pool_to(x: torch.Tensor, size_hw) -> torch.Tensor:
        """Downsample ``x`` to ``size_hw`` for the bottom-up path.

        Uses the exact 2×2 max-pool when ``x`` is precisely twice the target size
        (the common case). It falls back to adaptive max-pooling when sizes are
        odd, so the fuse inputs still match in shape.
        """
        th, tw = int(size_hw[0]), int(size_hw[1])
        if x.shape[-2] == 2 * th and x.shape[-1] == 2 * tw:
            return F.max_pool2d(x, 2)
        return F.adaptive_max_pool2d(x, (th, tw))

    def _up_refine_shared(self, x: torch.Tensor, size_hw: tuple[int,int], step_id: int, rep_idx: int) -> torch.Tensor:
        """
        Upsample ``x`` to ``size_hw`` and refine it with the shared block of top-down step ``step_id``.

        step_id: [0,1,2] → corresponds to each top-down step, independent of repeat
        """
        x = F.interpolate(x, size=size_hw, mode="bilinear", align_corners=False)
        if self.no_blur_first and rep_idx == 0:
            # no-blur first repeat: only norm/act
            x = self.up_na[step_id](x)
        else:
            x = self.up_dw[step_id](x)
            if self.edge_enhance:
                x = x + self.edge_ls[step_id](self.edge_dw[step_id](x))
            x = self.up_na[step_id](x)
        return x

    def _bal(self, *xs):
        """Apply the input-balancing norm to every tensor (pass-through when prenorm is "none")."""
        return (tuple(self.balance(x) for x in xs) if self.prenorm != "none" else xs)

    def forward(self, p3, p4, p5, p6):
        # Proj
        p3, p4, p5, p6 = self.p3_in(p3), self.p4_in(p4), self.p5_in(p5), self.p6_in(p6)

        # Low-level grad-boost
        if self.training and self.grad_boost_low > 0:
            p3 = _ste_boost(p3, self.grad_boost_low)

        # running indices into the per-repeat module lists
        tconv = bconv = f2t = f2b = f3b = 0

        for rep in range(self.repeats):
            # ---------- Top-down ----------
            # step_id = 0: p6→p5
            u5 = self._up_refine_shared(p6, p5.shape[-2:], step_id=0, rep_idx=rep)
            p5_b, u5_b = self._bal(p5, u5)
            p5_td = self.fuse2_top[f2t](p5_b, u5_b); f2t += 1
            p5_td = self.top_convs[tconv](p5_td); tconv += 1
            p5_td = p5_td + self.keep_top[0](p5)

            # step_id = 1: p5_td→p4
            u4 = self._up_refine_shared(p5_td, p4.shape[-2:], step_id=1, rep_idx=rep)
            p4_b, u4_b = self._bal(p4, u4)
            p4_td = self.fuse2_top[f2t](p4_b, u4_b); f2t += 1
            p4_td = self.top_convs[tconv](p4_td); tconv += 1
            p4_td = p4_td + self.keep_top[1](p4)

            # step_id = 2: p4_td→p3
            u3 = self._up_refine_shared(p4_td, p3.shape[-2:], step_id=2, rep_idx=rep)
            p3_b, u3_b = self._bal(p3, u3)
            p3_td = self.fuse2_top[f2t](p3_b, u3_b); f2t += 1
            p3_td = self.top_convs[tconv](p3_td); tconv += 1
            p3_td = p3_td + self.keep_top[2](p3)

            # ---------- Bottom-up ----------
            p3 = self.fuse2_bot[f2b](*self._bal(p3, p3_td)); f2b += 1
            p3 = self.bot_convs[bconv](p3); bconv += 1

            p4 = self.fuse2_bot[f2b](*self._bal(p4_td, self._pool_to(p3, p4_td.shape[-2:]))); f2b += 1
            p4 = self.bot_convs[bconv](p4); bconv += 1

            p5 = self.fuse3_bot[f3b](*self._bal(p5, self._pool_to(p4, p5.shape[-2:]), p5_td)); f3b += 1
            p5 = self.bot_convs[bconv](p5); bconv += 1

            p6 = self.fuse2_bot[f2b](*self._bal(p6, self._pool_to(p5, p6.shape[-2:]))); f2b += 1
            p6 = self.bot_convs[bconv](p6); bconv += 1

        return p3, p4, p5, p6


# -----------------------------------------------------------------------------
# Positional Encoding
# -----------------------------------------------------------------------------

class Pos2d(nn.Module):
    """2D sinusoidal positional encoding, precomputed on a ``max_h × max_w`` grid.

    Row/column coordinates are ``linspace(0, 1, max_h)`` / ``linspace(0, 1, max_w)``.
    forward() returns the top-left ``H × W`` crop of that table, so a given pixel
    index always receives the same code regardless of the input size.

    Args:
        c: number of channels; must be divisible by 4 (sin/cos for y, then sin/cos for x).
        max_h, max_w: largest supported input height / width.
    """
    def __init__(self, c: int = 320, max_h: int = 640, max_w: int = 640):
        super().__init__()
        assert c % 4 == 0
        self.cq = c // 4
        pos = self._build(max_h, max_w, c)
        self.register_buffer("pos_table", pos, persistent=False)

    def _build(self, H: int, W: int, C: int) -> torch.Tensor:
        """Build the (1, C, H, W) table: channels are [sin(y), cos(y), sin(x), cos(x)] blocks."""
        yv, xv = torch.meshgrid(
            torch.linspace(0, 1, H), torch.linspace(0, 1, W), indexing="ij"
        )
        div = torch.exp(torch.arange(0, self.cq) * (-math.log(10000.0) / self.cq))
        pos_x = (xv[..., None] * div).reshape(H, W, -1)
        pos_y = (yv[..., None] * div).reshape(H, W, -1)
        pos = torch.cat([pos_y.sin(), pos_y.cos(), pos_x.sin(), pos_x.cos()], dim=2)
        return pos.permute(2, 0, 1).unsqueeze(0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the (1, c, H, W) encoding for a 4-D input ``x`` of spatial size H × W.

        Raises:
            ValueError: if H or W exceeds the precomputed table. Previously this
                silently returned a smaller crop that mismatched ``x``.
        """
        _, _, H, W = x.shape
        max_h, max_w = self.pos_table.shape[-2:]
        if H > max_h or W > max_w:
            raise ValueError(
                f"Pos2d: input size {H}x{W} exceeds the precomputed table "
                f"{max_h}x{max_w}; increase max_h/max_w."
            )
        return self.pos_table[:, :, :H, :W]



import inspect

# -----------------------------------------------------------------------------
# DATBlock –  TransformerStage wrapper
# -----------------------------------------------------------------------------
class DATBlock(nn.Module):
    def __init__(self,
                channels: int,
                depth: int = 2,
                heads: int | None = None,
                window_size: int = 8,
                drop: float = 0.0):
        super().__init__()
        heads = heads or max(4, channels // 32)

        cfg = dict(
            fmap_size           = (window_size, window_size),
            window_size         = window_size,
            ns_per_pt           = 4,
            dim_in              = channels,
            dim_embed           = channels,
            depths              = depth,               # INT
            stage_spec          = 'N' * depth,
            n_groups            = 1,
            use_pe              = True,
            sr_ratio            = 1,
            heads               = heads,               # INT
            heads_q             = [heads] * depth,     # LIST
            stride              = 1,
            offset_range_factor = 2,
            dwc_pe              = True,
            no_off              = False,
            fixed_pe            = False,
            attn_drop           = drop,
            proj_drop           = drop,
            expansion           = 4,
            drop                = drop,
            drop_path_rate      = [drop] * depth,      # LIST
            use_dwc_mlp         = False,
            ksize               = 3,
            nat_ksize           = 7,
            k_qna               = 8,
            nq_qna              = 9,
            qna_activation      = "relu",
            layer_scale_value   = 0.3,
            use_lpu             = False,
            log_cpb             = True,
        )

        from DAT.models.dat import TransformerStage
        self.stage = TransformerStage(**cfg)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x = (B,C,H,W) –
        if hasattr(self.stage, 'fmap_size'):
            self.stage.fmap_size = x.shape[-2:]
        try:
            return self.stage(x)
        except TypeError:                           # some versions (x,H,W) require
            H, W = x.shape[-2:]
            return self.stage(x, H, W)


# -----------------------------------------------------------------------------
# Deformable Encoder (production-ready)
# -----------------------------------------------------------------------------

class GLUFFN(nn.Module):
    """
    Gated-linear-unit feed-forward network.

    fc1 maps d_model → 2·inner; its output is split into (value, gate) halves.
    The result is ``fc2(act(gate) * value)``, with act = SiLU ("swiglu") or GELU ("geglu").
    fc2 maps inner → d_model and is followed by optional dropout.

    inner width (never below 32):
      • ffn_dim given → inner = ffn_dim // 2   (so 2·inner == ffn_dim)
      • otherwise     → inner = round(d_model * ffn_mult)

    Raises:
        ValueError: if ``act`` (case-insensitive) is neither "swiglu" nor "geglu".
    """
    def __init__(
        self,
        d_model: int,
        ffn_dim: Optional[int] = None,
        ffn_mult: float = 2.0,
        act: str = "swiglu",
        drop: float = 0.0,
    ):
        super().__init__()
        if ffn_dim is None:
            hidden = int(round(d_model * float(ffn_mult)))
        else:
            hidden = int(ffn_dim) // 2
        hidden = max(32, hidden)

        self.fc1 = nn.Linear(d_model, hidden * 2, bias=True)
        self.fc2 = nn.Linear(hidden, d_model, bias=True)
        self.drop = nn.Identity() if not drop > 0 else nn.Dropout(drop)

        kind = act.lower()
        if kind not in ("swiglu", "geglu"):
            raise ValueError("GLUFFN.act must be 'swiglu' or 'geglu'")
        self.act_kind = kind

        # Xavier weights, zero biases (fc1 first, then fc2)
        for layer in (self.fc1, self.fc2):
            nn.init.xavier_uniform_(layer.weight)
            nn.init.zeros_(layer.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        value, gate = self.fc1(x).chunk(2, dim=-1)
        activation = F.silu if self.act_kind == "swiglu" else F.gelu
        return self.drop(self.fc2(activation(gate) * value))




class DeformEncoderLayer(nn.Module):
    """
    MS-Deformable self-attention + GLU-FFN, Pre-Norm + LayerScale + DropPath.

    Each sub-layer computes ``x = x + DropPath(LayerScale(Dropout(f(LayerNorm(x)))))``.

    Args (backward-compatible):
        d_model, n_heads, n_levels, n_points: forwarded to MultiScaleDeformableAttention
        ffn_dim:   optional (if given, GLU inner width is ffn_dim//2)
        ffn_mult:  used if ffn_dim is not specified (default 2.0)
        ffn_act:   'swiglu' (default) or 'geglu'
        drop:      dropout after each sub-layer (0 → Identity)
        drop_path: stochastic-depth rate of both residual branches (0 → Identity)
    """
    def __init__(
        self,
        d_model: int = 320,
        n_heads: int = 10,
        n_levels: int = 4,
        n_points: int = 4,
        *,
        ffn_dim: Optional[int] = None,
        ffn_mult: float = 2.0,
        ffn_act: str = "swiglu",
        drop: float = 0.0,
        drop_path: float = 0.1,
    ):
        super().__init__()

        self.self_attn = MultiScaleDeformableAttention(
            embed_dims=d_model, num_heads=n_heads,
            num_levels=n_levels, num_points=n_points,
            batch_first=True,
        )

        # NOTE(review): ls1/ls2 scale (B, N, C) token tensors. If the LayerScale in scope
        # is DAT's channel-first version (weight viewed as (C, 1, 1)), it would broadcast
        # incorrectly here — confirm which LayerScale definition is active.
        self.norm1 = nn.LayerNorm(d_model)
        self.ls1   = LayerScale(d_model, init_values=0.4)
        self.drop1 = nn.Dropout(drop) if drop > 0 else nn.Identity()
        self.dp1   = DropPath(drop_path) if drop_path > 0 else nn.Identity()

        self.ffn   = GLUFFN(d_model, ffn_dim=ffn_dim, ffn_mult=ffn_mult, act=ffn_act, drop=drop)
        self.norm2 = nn.LayerNorm(d_model)
        self.ls2   = LayerScale(d_model, init_values=0.4)
        self.drop2 = nn.Dropout(drop) if drop > 0 else nn.Identity()
        self.dp2   = DropPath(drop_path) if drop_path > 0 else nn.Identity()

    @staticmethod
    def _make_encoder_reference_points(
        spatial_shapes: torch.Tensor, B: int, device, dtype
    ) -> torch.Tensor:
        """Pixel-centre reference points normalised to (0, 1), shape (B, sum(H*W), L, 2) in (x, y) order."""
        ref_list = []
        for (H, W) in spatial_shapes.tolist():
            ref_y, ref_x = torch.meshgrid(
                torch.linspace(0.5, H - 0.5, H, device=device, dtype=dtype) / H,
                torch.linspace(0.5, W - 0.5, W, device=device, dtype=dtype) / W,
                indexing="ij",
            )
            ref = torch.stack((ref_x, ref_y), dim=-1).reshape(-1, 2)  # (HW,2)
            ref_list.append(ref)
        # the same points are used for every level
        return torch.cat(ref_list, dim=0)[None, :, None, :].repeat(B, 1, spatial_shapes.size(0), 1)

    def forward(
        self,
        src: torch.Tensor,             # (B, sumHW, C)
        pos: torch.Tensor,             # (B, sumHW, C)
        spatial_shapes: torch.Tensor,  # (L, 2)
        lvl_start_idx: torch.Tensor,   # (L,)
        key_padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:

        B, N, C = src.shape
        ref_pts = self._make_encoder_reference_points(
            spatial_shapes=spatial_shapes, B=B, device=src.device, dtype=src.dtype
        )  # (B,N,L,2)

        x = src
        q = self.norm1(x)
        # mmcv's MultiScaleDeformableAttention returns ``dropout(out) + identity`` and
        # defaults ``identity`` to the query. Without an explicit zero identity the
        # residual branch would also contain LayerNorm(x). The skip connection is added below.
        attn_out = self.self_attn(
            query=q, value=q,
            identity=torch.zeros_like(q),
            reference_points=ref_pts,
            spatial_shapes=spatial_shapes,
            level_start_index=lvl_start_idx,
            key_padding_mask=key_padding_mask,
            query_pos=pos,
        )
        x = x + self.dp1(self.ls1(self.drop1(attn_out)))

        y = self.ffn(self.norm2(x))
        x = x + self.dp2(self.ls2(self.drop2(y)))
        return x

class TinyDeformEncoder(nn.Module):
    """
    Stack of ``num_layers`` DeformEncoderLayer blocks applied sequentially.

    The default builds a single layer; two layers are a recommended lightweight setting.
    Per-layer stochastic-depth rates come from get_drop_path_rates(num_layers, drop_path_max).
    """
    def __init__(self,
                 num_layers: int = 1,
                 d_model: int = 320,
                 n_heads: int = 10,
                 n_levels: int = 4,
                 n_points: int = 4,
                 ffn_dim: int = 1024,
                 drop: float = 0.0,
                 drop_path_max: float = 0.1):
        super().__init__()
        rates = get_drop_path_rates(num_layers, drop_path_max)
        stack = []
        for idx in range(num_layers):
            stack.append(
                DeformEncoderLayer(d_model=d_model,
                                   n_heads=n_heads,
                                   n_levels=n_levels,
                                   n_points=n_points,
                                   ffn_dim=ffn_dim,
                                   drop=drop,
                                   drop_path=rates[idx])
            )
        self.layers = nn.ModuleList(stack)

    def forward(self,
                src: torch.Tensor,             # (B, sumHW, C)
                pos: torch.Tensor,             # (B, sumHW, C)
                spatial_shapes: torch.Tensor,  # (L,2)
                lvl_start_idx: torch.Tensor,   # (L,)
                key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        hidden = src
        for layer in self.layers:
            hidden = layer(hidden, pos, spatial_shapes, lvl_start_idx, key_padding_mask)
        return hidden

# -----------------------------------------------------------------------------
# RT-Deform Decoder (Fixed IoU-aware padding)
# -----------------------------------------------------------------------------




class DeformDecoderLayer(nn.Module):
    """
    One RT-Deform decoder layer.

    Three Pre-Norm sub-blocks, each with LayerScale + Dropout + DropPath on its
    residual branch:
      1. query self-attention (``nn.MultiheadAttention``) with an optional
         attention mask (used for DN <-> MATCH isolation);
      2. multi-scale deformable cross-attention into ``src``;
      3. GLU feed-forward network.
    Afterwards the 2-D reference points are refined in logit space by a
    tanh-bounded delta.

    Args:
        d_model, n_heads, n_levels, n_points: attention geometry.
        ffn_dim:   optional GLU width forwarded to GLUFFN (author note: 2*inner = ffn_dim).
        ffn_mult:  width multiplier forwarded to GLUFFN (author note: used when ffn_dim is absent).
        ffn_act:   'swiglu' | 'geglu'.
        drop:      dropout probability on every residual branch.
        drop_path: stochastic-depth probability on every residual branch.
        refine_scale: float, or (lo, hi). With (lo, hi) and a known position
            (layer_id, num_layers > 1) the scale is lo + (hi - lo) * layer_id / (num_layers - 1);
            otherwise the midpoint of (lo, hi).
        layer_id, num_layers: layer position, only used for the refine_scale schedule.
        grad_eps:  weight of the tiny residual nudge ``grad_eps * mean(delta)``
            added to the query features at the end of forward.
    """
    def __init__(
        self,
        d_model: int = 320,
        n_heads: int = 10,
        n_levels: int = 4,
        n_points: int = 4,
        *,
        ffn_dim: Optional[int] = None,
        ffn_mult: float = 2.0,
        ffn_act: str = "swiglu",
        drop: float = 0.0,
        drop_path: float = 0.1,
        refine_scale: Union[float, Tuple[float, float]] = 0.5,
        layer_id: Optional[int] = None,
        num_layers: Optional[int] = None,
        grad_eps: float = 1e-4,
    ):
        super().__init__()
        self.grad_eps = float(grad_eps)

        # refine scale (optionally scheduled linearly over layer depth)
        if isinstance(refine_scale, (tuple, list)):
            lo, hi = float(refine_scale[0]), float(refine_scale[1])
            if (num_layers is not None) and (layer_id is not None) and (num_layers > 1):
                t = float(layer_id) / float(num_layers - 1)
                self.refine_scale = lo + (hi - lo) * t
            else:
                self.refine_scale = 0.5 * (lo + hi)
        else:
            self.refine_scale = float(refine_scale)

        # Self-Attn (MHA) + Pre-Norm
        self.self_attn = nn.MultiheadAttention(
            embed_dim=d_model, num_heads=n_heads, batch_first=True
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.ls1   = LayerScale(d_model, init_values=0.4)
        self.drop1 = nn.Dropout(drop) if drop > 0 else nn.Identity()
        self.dp1   = DropPath(drop_path) if drop_path > 0 else nn.Identity()

        # Cross-Attn (MS-Deformable) + Pre-Norm
        self.cross_attn = MultiScaleDeformableAttention(
            embed_dims=d_model, num_heads=n_heads,
            num_levels=n_levels, num_points=n_points,
            batch_first=True,
        )
        self.norm2 = nn.LayerNorm(d_model)
        self.ls2   = LayerScale(d_model, init_values=0.4)
        self.drop2 = nn.Dropout(drop) if drop > 0 else nn.Identity()
        self.dp2   = DropPath(drop_path) if drop_path > 0 else nn.Identity()

        # FFN (GLU) + Pre-Norm
        self.ffn   = GLUFFN(d_model, ffn_dim=ffn_dim, ffn_mult=ffn_mult, act=ffn_act, drop=drop)
        self.norm3 = nn.LayerNorm(d_model)
        self.ls3   = LayerScale(d_model, init_values=0.4)
        self.drop3 = nn.Dropout(drop) if drop > 0 else nn.Identity()
        self.dp3   = DropPath(drop_path) if drop_path > 0 else nn.Identity()

        # Ref delta head (output squashed by tanh in forward)
        self.ref_mlp = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Linear(d_model, 2),
        )
        for m in self.ref_mlp:
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight); nn.init.zeros_(m.bias)

    @staticmethod
    def _inv_sigmoid(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        """Logit of ``x`` after clamping it into [eps, 1 - eps]."""
        x = x.clamp(eps, 1.0 - eps)
        return torch.log(x) - torch.log(1.0 - x)

    def forward(
        self,
        tgt: torch.Tensor,                 # (B, N, C)
        ref_pts: torch.Tensor,             # (B, N, L, 2)
        src: torch.Tensor,                 # (B, S, C)
        query_pos: torch.Tensor,           # (B, N, C)
        spatial_shapes: torch.Tensor,      # (L, 2)
        lvl_start_idx: torch.Tensor,       # (L,)
        self_attn_mask: Optional[torch.Tensor] = None  # (N, N) bool or (B*h, N, N)
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns:
            x:       updated query features, same shape as ``tgt``.
            new_ref: refined reference points, same shape as ``ref_pts``; the
                     same (B, N, 2) delta is added in logit space to every level.
        """
        x = tgt

        # 1) self-attention, masked for DN <-> MATCH isolation.
        # NOTE(review): query_pos is not added to q/k here (it is only passed to
        # cross-attention) — confirm this is intentional.
        q = self.norm1(x)
        sa_out = self.self_attn(q, q, q, need_weights=False, attn_mask=self_attn_mask)[0]
        x = x + self.dp1(self.ls1(self.drop1(sa_out)))

        # 2) multi-scale deformable cross-attention
        q = self.norm2(x)
        ca_out = self.cross_attn(
            query=q, value=src,
            reference_points=ref_pts,
            spatial_shapes=spatial_shapes,
            level_start_index=lvl_start_idx,
            query_pos=query_pos,
        )
        x = x + self.dp2(self.ls2(self.drop2(ca_out)))

        # 3) GLU feed-forward
        y = self.ffn(self.norm3(x))
        x = x + self.dp3(self.ls3(self.drop3(y)))

        # 4) tanh-bounded iterative ref update in logit space
        delta = torch.tanh(self.ref_mlp(x)) * self.refine_scale      # (B,N,2)
        new_ref = torch.sigmoid(self._inv_sigmoid(ref_pts) + delta.unsqueeze(2).to(ref_pts.dtype))

        # Tiny nudge weighted by the configured grad_eps; presumably keeps
        # ref_mlp connected to x's autograd graph.
        x = x + self.grad_eps * delta.mean(dim=2, keepdim=True)
        return x, new_ref


class HeadPrep(nn.Module):
    """
    LayerNorm followed by LayerScale (presumably applied to decoder outputs
    before the prediction heads).

    Args:
        d_model: feature width.
        ls_init: initial LayerScale value.
        ln_eps:  LayerNorm epsilon.
    """
    def __init__(self, d_model: int, ls_init: float = 0.80, ln_eps: float = 1e-6):
        super().__init__()
        self.norm = nn.LayerNorm(d_model, eps=ln_eps)
        self.ls   = LayerScale(d_model, init_values=ls_init)

    def forward(self, x):  # x: (B, N, C)
        """Return ``ls(norm(x))`` with the same shape as ``x``."""
        normalized = self.norm(x)
        return self.ls(normalized)

class BoxPosEnc(nn.Module):
    """
    Sinusoidal positional encoding of boxes (cx, cy, w, h) in [0, 1] -> (B, Q, c).

    Each coordinate gets ``c // 4`` channels: sines then cosines of
    ``v * 2*pi * inv_freq`` (no 2*pi factor when ``use_2pi=False``), where
    ``inv_freq[k] = 10000 ** (-k / half)`` and ``half = c // 8``.
    The four per-coordinate encodings are concatenated in the order cx, cy, w, h.

    Args:
        c:       output width; must be divisible by 4, and c // 4 must be even.
        use_2pi: scale coordinates by 2*pi before applying the frequencies.
    """
    def __init__(self, c: int = 320, *, use_2pi: bool = True):
        super().__init__()
        # c must split into four equal per-coordinate chunks
        assert c % 4 == 0, "BoxPosEnc: c, 4'e bölünebilmeli"
        self.c = c
        self.c_per = c // 4           # channels per coordinate (cx, cy, w, h)
        # each chunk must split evenly into a sin half and a cos half
        assert self.c_per % 2 == 0, "c/4 çift olmalı (sin+cos için)"
        half = self.c_per // 2
        exponents = -torch.arange(0, half).float() / max(1, half)
        self.register_buffer("inv_freq", 10000 ** exponents, persistent=True)
        self.use_2pi = use_2pi

    def _pe1d(self, v: torch.Tensor) -> torch.Tensor:
        """Encode one coordinate: (B, Q) -> (B, Q, c_per)."""
        angles = v.unsqueeze(-1)  # (B, Q, 1)
        if self.use_2pi:
            angles = angles * (2.0 * math.pi)
        angles = angles * self.inv_freq  # (B, Q, half)
        return torch.cat((torch.sin(angles), torch.cos(angles)), dim=-1)

    def forward(self, boxes: torch.Tensor) -> torch.Tensor:
        """boxes: (B, Q, 4) in [0, 1] -> (B, Q, c)."""
        per_coord = [self._pe1d(boxes[..., i]) for i in range(4)]  # cx, cy, w, h
        return torch.cat(per_coord, dim=-1)

class RTDeformDecoder(nn.Module):
    """
    RT-Deform Decoder (  DN + Self-Attn Mask)
      • Denoising (group-based POS/NEG, relative noise)
      • DN ↔ MATCH self-attention isolation (attn mask)
      • Proposal + MQS + learned query distribution (same as original API)
      • No IoU‑aware header (use_iou_aware always False; API preserved)

    Output dictionary:
      - ‘pred_logits’, ‘pred_boxes’, ‘final_ref_pts’, 'query_selection_mask'
      - in training: ‘aux_outputs’, ‘dn_meta’ (includes dn_len & dn_queries)
      - optional: ‘pred_obj_logits’ (if train or emit_obj_logits_in_eval=True)
    """

    def __init__(self,
                 num_obj_classes: int = 20,
                 include_background: bool = True,
                 *,
                 num_queries: int = 300,
                 num_layers: int = 4,
                 d_model: int = 320,
                 attn_n_heads: int = 10,
                 dn_queries: int = 100,
                 use_iou_aware: bool = False,
                 iou_k_ratio: float = 0.75,

                 # Encoder
                 use_encoder: bool = True,
                 encoder_layers: int = 1,
                 encoder_drop_path_max: float = 0.1,

                 # MQS
                 mqs_enable: bool = True,
                 mqs_obj_ratio: float = 0.30,
                 mqs_grid_ratio: float = 0.20,
                 mqs_levels: Tuple[int, ...] = (1, 2, 3),
                 mqs_local_max_kernel: int = 3,
                 mqs_train_only: bool = False,

                 # Seed
                 seed_enable: bool = True,
                 seed_alpha_obj_init: float = 0.55,
                 seed_alpha_grid_init: float = 0.4,
                 seed_mlp_expansion: float = 1.0,

                 # Two‑Stage Proposals
                 use_proposals: bool = True,
                 proposal_topk: Optional[int] = None,
                 proposal_ratio: float = 0.70,
                 min_mqs: int = 60,
                 min_left_queries: int = 16,
                 emit_obj_logits_in_eval: bool = False):
        """
        Build the decoder, its optional encoder, and all prediction heads.

        Args:
            num_obj_classes:    number of foreground classes.
            include_background: add one extra "no-object" logit
                                (num_pred = num_obj_classes + 1).
            num_queries:        number of MATCH queries per image.
            num_layers:         number of DeformDecoderLayer blocks; there are
                                num_layers - 1 auxiliary heads.
            d_model:            feature width.
            attn_n_heads:       heads for encoder and decoder attention.
            dn_queries:         maximum number of denoising queries; query_pos /
                                query_feat have num_queries + dn_queries rows.
            use_iou_aware, iou_k_ratio: accepted for API compatibility only; the
                                IoU-aware path is always disabled and
                                iou_k_ratio is not stored.
            use_encoder, encoder_layers, encoder_drop_path_max:
                                optional TinyDeformEncoder over the flattened features.
            mqs_enable, mqs_obj_ratio, mqs_grid_ratio, mqs_levels,
            mqs_local_max_kernel, mqs_train_only:
                                mixed query selection (objectness peaks + regular
                                grid). Level indices follow forward's feats order
                                [p6, p5, p4, p3].
            seed_enable, seed_alpha_obj_init, seed_alpha_grid_init, seed_mlp_expansion:
                                feature-seeded query content. The alphas are stored
                                as logits (seed_logit_obj / seed_logit_grid) and
                                passed through sigmoid in forward. The seed MLP
                                hidden width is at least d_model.
            use_proposals, proposal_topk, proposal_ratio, min_mqs, min_left_queries:
                                two-stage proposal budget. Proposals are only
                                enabled when use_encoder is also True.
            emit_obj_logits_in_eval: also produce objectness logits in eval mode.
        """
        super().__init__()

        # --- Class/header parameters ---
        self.num_obj_classes = num_obj_classes
        self.include_background = include_background
        self.num_pred = num_obj_classes + int(include_background)

        self.num_queries = int(num_queries)
        self.num_layers = int(num_layers)
        self.dn_queries = int(dn_queries)

        # IoU-aware path is disabled regardless of the use_iou_aware argument
        self.use_iou_aware = False
        self.emit_obj_logits_in_eval = bool(emit_obj_logits_in_eval)

        self.d_model = int(d_model)
        self.use_encoder = bool(use_encoder)
        self.head_prep = HeadPrep(self.d_model, ls_init=0.80, ln_eps=1e-6)

        # --- DN: hard-coded defaults read by prepare_denoising (not constructor args)
        self.dn_groups = 5
        self.pos_box_noise = 0.4
        self.neg_box_noise = 1.0
        # NOTE(review): dn_box_pe / dn_box_proj / dn_box_logit are not referenced in
        # the visible part of this class — confirm where they are consumed.
        self.dn_box_pe = BoxPosEnc(self.d_model)  # sin/cos box positional encoding
        self.dn_box_proj = nn.Linear(self.d_model, self.d_model)
        self.dn_box_logit = nn.Parameter(torch.tensor(0.0))
        nn.init.xavier_uniform_(self.dn_box_proj.weight)
        nn.init.zeros_(self.dn_box_proj.bias)

        # --- Proposal  ---
        # proposals are taken from encoder memory, so they require the encoder
        self.use_proposals = bool(use_proposals and use_encoder)
        self.proposal_topk = proposal_topk
        self.proposal_ratio = float(proposal_ratio)
        self.min_mqs = int(min_mqs)
        self.min_left_queries = int(min_left_queries)

        # --- Embeddings  ---
        # one seed-level embedding row per feature level (4 levels hard-coded)
        self.level_embed_seed = nn.Parameter(torch.randn(4, self.d_model))
        self.query_obj_head = nn.Linear(self.d_model, 1)  # objectness (seed/proposal)

        nn.init.xavier_uniform_(self.query_obj_head.weight)
        nn.init.zeros_(self.query_obj_head.bias)

        # rows [0, dn_queries) are reserved for DN; MATCH queries start at dn_queries
        self.query_pos  = nn.Embedding(self.num_queries + self.dn_queries, self.d_model)
        self.query_feat = nn.Embedding(self.num_queries + self.dn_queries, self.d_model)
        self.label_enc  = nn.Embedding(self.num_obj_classes, self.d_model)

        self.pos_embed   = Pos2d(self.d_model)
        self.level_embed = nn.Parameter(torch.randn(4, self.d_model))

        # DN noise (label & box) – pos/neg
        self.dn_label_noise_ratio = 0.5
        self.dn_box_noise_scale   = 0.4  # (backward compatibility only; using relative noise)

        # init ref(cx,cy)
        self.ref_init = nn.Linear(self.d_model, 2)

        # Encoder (optional)
        if self.use_encoder:
            self.encoder = TinyDeformEncoder(
                num_layers=int(encoder_layers),
                d_model=self.d_model, n_heads=int(attn_n_heads),
                n_levels=4, n_points=4,
                ffn_dim=self.d_model * 4,
                drop=0.0,
                drop_path_max=float(encoder_drop_path_max)
            )

        # Decoder: 4 sampling points per layer, except 2 in the last two layers
        n_points_list = [4] * max(0, self.num_layers - 2) + [2, 2] if self.num_layers >= 2 else [4] * self.num_layers
        self.layers = nn.ModuleList([
            DeformDecoderLayer(d_model=self.d_model,
                               n_heads=int(attn_n_heads),
                               n_levels=4,
                               n_points=n_points_list[i] if i < len(n_points_list) else 4,
                               ffn_dim=self.d_model * 4,
                               refine_scale=(0.15, 0.45),
                               layer_id=i, num_layers=self.num_layers)
            for i in range(self.num_layers)
        ])

        # Heads (final layer) + one aux head pair per intermediate layer
        self.cls_head = nn.Linear(self.d_model, self.num_pred)
        self.box_head = nn.Linear(self.d_model, 4)

        self.aux_cls_heads = nn.ModuleList(
            nn.Linear(self.d_model, self.num_pred) for _ in range(self.num_layers - 1))
        self.aux_box_heads = nn.ModuleList(
            nn.Linear(self.d_model, 4) for _ in range(self.num_layers - 1))

        # MQS
        self.mqs_enable = bool(mqs_enable)
        self.mqs_obj_ratio = float(mqs_obj_ratio)
        self.mqs_grid_ratio = float(mqs_grid_ratio)
        self.mqs_levels = tuple(int(i) for i in mqs_levels)
        self.mqs_local_max_kernel = int(mqs_local_max_kernel)
        self.mqs_train_only = bool(mqs_train_only)

        # Seed
        self.seed_enable = bool(seed_enable)
        hid = max(int(self.d_model * float(seed_mlp_expansion)), self.d_model)
        self.seed_mlp = nn.Sequential(
            nn.LayerNorm(self.d_model),
            nn.Linear(self.d_model, hid),
            nn.GELU(),
            nn.Linear(hid, self.d_model)
        )

        # logit(p) with p clamped to [1e-4, 1 - 1e-4]
        def _logit(p: float) -> float:
            p = min(max(float(p), 1e-4), 1.0 - 1e-4)
            return math.log(p / (1.0 - p))

        self.seed_logit_obj  = nn.Parameter(torch.tensor(_logit(float(seed_alpha_obj_init)), dtype=torch.float))
        self.seed_logit_grid = nn.Parameter(torch.tensor(_logit(float(seed_alpha_grid_init)), dtype=torch.float))

        # 2‑Stage Proposals (from encoder memory)
        if self.use_proposals:
            self.prop_obj_head = nn.Linear(self.d_model, 1)
            self.prop_box_head = nn.Linear(self.d_model, 4)
            self.prop_content_proj = nn.Sequential(
                nn.LayerNorm(self.d_model),
                nn.Linear(self.d_model, self.d_model),
                nn.GELU(),
                nn.Linear(self.d_model, self.d_model)
            )

        self._init_weights()

    # ------------ init helpers ------------
    def _init_weights(self):
        nn.init.normal_(self.query_pos.weight,  std=0.02)
        nn.init.normal_(self.query_feat.weight, std=0.02)
        nn.init.normal_(self.level_embed,       std=0.02)
        nn.init.normal_(self.label_enc.weight,  std=0.02)
        nn.init.normal_(self.level_embed_seed,  std=0.02)

        nn.init.xavier_uniform_(self.ref_init.weight)
        nn.init.zeros_(self.ref_init.bias)

        nn.init.xavier_uniform_(self.box_head.weight, gain=1.0); nn.init.zeros_(self.box_head.bias)
        nn.init.xavier_uniform_(self.cls_head.weight, gain=1.0)

        for c, b in zip(self.aux_cls_heads, self.aux_box_heads):
            nn.init.xavier_uniform_(c.weight, gain=1.0)
            nn.init.xavier_uniform_(b.weight, gain=1.0); nn.init.zeros_(b.bias)

        for m in self.seed_mlp:
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.zeros_(m.bias)

        if self.use_proposals:
            nn.init.xavier_uniform_(self.prop_obj_head.weight)
            nn.init.constant_(self.prop_obj_head.bias, -2.0)
            nn.init.xavier_uniform_(self.prop_box_head.weight)
            nn.init.zeros_(self.prop_box_head.bias)
            for m in self.prop_content_proj:
                if isinstance(m, nn.Linear):
                    nn.init.xavier_uniform_(m.weight)
                    nn.init.zeros_(m.bias)

        self.reset_class_bias()

    def _add_level_embed(self, seed: torch.Tensor, lvls: torch.Tensor) -> torch.Tensor:
        lvls_clamped = lvls.clamp(0, self.level_embed_seed.size(0) - 1)
        emb = F.embedding(lvls_clamped, self.level_embed_seed)  # (B,K,C) ya da (K,C)
        emb = emb.to(dtype=seed.dtype, device=seed.device)
        return seed + emb

    def reset_class_bias(self, object_prior: float = .2, no_object_bias: float = -2.):
        obj_bias = -math.log((1. - float(object_prior)) / float(object_prior))
        with torch.no_grad():
            self.cls_head.bias[:] = 0.
            self.cls_head.bias[: self.num_obj_classes] = obj_bias
            if self.include_background:
                self.cls_head.bias[-1] = float(no_object_bias)
            for aux in self.aux_cls_heads:
                aux.bias[:] = self.cls_head.bias

    # ------------ MQS helpers ------------
    @staticmethod
    def _allocate_per_level(k_total: int, shapes: List[Tuple[int, int]], sel_levels: Tuple[int, ...]) -> List[int]:
        L = len(shapes)
        if k_total <= 0:
            return [0] * L
        areas = [shapes[i][0] * shapes[i][1] if i in sel_levels else 0 for i in range(L)]
        S = sum(areas)
        if S == 0:
            out = [0] * L
            if sel_levels:
                out[sel_levels[-1]] = k_total
            return out
        floats = [k_total * (a / S) for a in areas]
        floors = [int(math.floor(x)) for x in floats]
        remain = k_total - sum(floors)
        fracs = [(floats[i] - floors[i]) if areas[i] > 0 else -1 for i in range(L)]
        order = sorted([i for i in range(L) if areas[i] > 0], key=lambda i: fracs[i], reverse=True)
        for i in range(remain):
            floors[order[i % len(order)]] += 1
        return [floors[i] if i in sel_levels else 0 for i in range(L)]

    @staticmethod
    def _energy_map(feat: torch.Tensor) -> torch.Tensor:
        return F.relu(feat, inplace=False).mean(dim=1, keepdim=True)

    def _mqs_objectness_points(self, feats: List[torch.Tensor], k_obj: int, sel_lvls: Tuple[int, ...]):
        B = feats[0].size(0)
        device = feats[0].device
        if k_obj <= 0:
            return torch.zeros(B, 0, 2, device=device), torch.zeros(B, 0, dtype=torch.long, device=device)

        shapes = [(f.shape[2], f.shape[3]) for f in feats]  # (H,W)
        per_level = self._allocate_per_level(k_obj, shapes, sel_lvls)

        coords_list, lvl_list = [], []
        ksize = self.mqs_local_max_kernel; pad = ksize // 2
        with torch.no_grad():
            for li, k_l in enumerate(per_level):
                if k_l <= 0:
                    continue
                f = feats[li]
                B_, C, H, W = f.shape
                en = self._energy_map(f)
                if ksize >= 3:
                    mp = F.max_pool2d(en, kernel_size=ksize, stride=1, padding=pad)
                    keep = (en >= mp - 1e-6)
                    en = en * keep
                en_flat = en.view(B, -1)
                k_sel = min(k_l, H * W)
                vals, idxs = torch.topk(en_flat, k_sel, dim=1)
                ys = (idxs // W).float() + 0.5
                xs = (idxs %  W).float() + 0.5
                cx = (xs / W).clamp(0., 1.)
                cy = (ys / H).clamp(0., 1.)
                coords = torch.stack([cx, cy], dim=-1)
                coords_list.append(coords)
                lvl_list.append(torch.full((B, k_sel), li, dtype=torch.long, device=device))
        if not coords_list:
            return torch.zeros(B, 0, 2, device=device), torch.zeros(B, 0, dtype=torch.long, device=device)
        coords_cat = torch.cat(coords_list, dim=1)
        lvls_cat  = torch.cat(lvl_list,   dim=1)
        return coords_cat, lvls_cat

    def _mqs_grid_points(self, spatial_shapes: torch.Tensor, k_grid: int, sel_lvls: Tuple[int, ...]):
        device = spatial_shapes.device
        if k_grid <= 0:
            return torch.zeros(1, 0, 2, device=device), torch.zeros(1, 0, dtype=torch.long, device=device)
        shapes: List[Tuple[int,int]] = [tuple(map(int, s.tolist())) for s in spatial_shapes]
        per_level = self._allocate_per_level(k_grid, shapes, sel_lvls)
        pts, lvls = [], []
        for li, k_l in enumerate(per_level):
            if k_l <= 0: continue
            H, W = shapes[li]
            ratio = W / max(H, 1)
            nx = max(1, int(round(math.sqrt(max(k_l,1) * ratio))))
            ny = max(1, int(math.ceil(k_l / nx)))
            xs = torch.linspace(0.5 / W, 1. - 0.5 / W, steps=nx, device=device)
            ys = torch.linspace(0.5 / H, 1. - 0.5 / H, steps=ny, device=device)
            grid_y, grid_x = torch.meshgrid(ys, xs, indexing="ij")
            grid = torch.stack([grid_x, grid_y], dim=-1).reshape(-1, 2)[:k_l]
            pts.append(grid); lvls.append(torch.full((k_l,), li, dtype=torch.long, device=device))
        if not pts:
            return torch.zeros(1, 0, 2, device=device), torch.zeros(1, 0, dtype=torch.long, device=device)
        pts = torch.cat(pts, dim=0).unsqueeze(0)
        lvls = torch.cat(lvls, dim=0).unsqueeze(0)
        return pts, lvls

    def _to_grid_xy(self, coords: torch.Tensor) -> torch.Tensor:
        if coords.dim() == 2:
            coords = coords.unsqueeze(0)
        gx = coords[..., 0] * 2.0 - 1.0
        gy = coords[..., 1] * 2.0 - 1.0
        grid = torch.stack([gx, gy], dim=-1)
        return grid.unsqueeze(1)  # (B,1,K,2)

    def _sample_feats_at_points(self, feats: List[torch.Tensor], pts: torch.Tensor, lvls: torch.Tensor) -> torch.Tensor:
        B = feats[0].size(0); C = feats[0].size(1); K = pts.size(1); device = pts.device
        if K == 0:
            return torch.zeros(B, 0, C, device=device, dtype=feats[0].dtype)
        out = torch.zeros(B, K, C, device=device, dtype=feats[0].dtype)
        for b in range(B):
            for li, f in enumerate(feats):
                mask = (lvls[b] == li)
                k_b = int(mask.sum().item())
                if k_b == 0: continue
                idx = torch.nonzero(mask, as_tuple=False).squeeze(-1)
                coords_b = pts[b, idx, :]
                grid_b = self._to_grid_xy(coords_b)
                grid_b = grid_b.to(device=f.device, dtype=f.dtype)
                f_b = f[b:b+1]
                samp = F.grid_sample(f_b, grid_b, mode="bilinear", align_corners=False)
                samp = samp.reshape(C, k_b).permute(1, 0).contiguous()
                out[b, idx, :] = samp
        return out

    # ------------ DN prep  ------------
    def prepare_denoising(
        self,
        targets: List[Dict],
        dn_groups: Optional[int] = None,
        pos_box_noise: Optional[float] = None,
        neg_box_noise: Optional[float] = None,
        label_noise_ratio: Optional[float] = None
    ):
        """
        Creates DINO-style denoising:
          • dn_groups groups; for each group, copies of POS (small deviation, correct/half-wrong label)
            and NEG (large deviation and/or wrong label).
          • Total DN length is fixed throughout the batch (Q):
                Q = min(self.dn_queries, 2 * dn_groups * max(1, max(#GT)))
          • Valid samples are marked with ‘mask’ (if GT=0, all masks are False).
        Return: dict {in_boxes, in_labels, tgt_boxes, tgt_labels, mask}
        """
        if not self.training or self.dn_queries == 0:
            return None

        device = targets[0]["boxes"].device
        B = len(targets)
        g_counts = [int(t["boxes"].size(0)) for t in targets]
        if sum(g_counts) == 0:
            return None

        dn_groups = self.dn_groups if dn_groups is None else int(dn_groups)
        pos_noise = self.pos_box_noise if pos_box_noise is None else float(pos_box_noise)
        neg_noise = self.neg_box_noise if neg_box_noise is None else float(neg_box_noise)
        lbl_noise = self.dn_label_noise_ratio if label_noise_ratio is None else float(label_noise_ratio)

        max_g = max(1, max(g_counts))
        raw_total   = min(self.dn_queries, 2 * dn_groups * max_g)
        q_per_group = max(1, math.ceil(raw_total / (2 * dn_groups)))
        total_q     = min(self.dn_queries, 2 * dn_groups * q_per_group)
        if total_q <= 0:
            return None

        dn_in_boxes  = torch.zeros(B, total_q, 4, device=device)
        dn_in_labels = torch.zeros(B, total_q,   dtype=torch.long, device=device)
        dn_tgt_boxes = torch.zeros(B, total_q, 4, device=device)
        dn_tgt_labels= torch.zeros(B, total_q,   dtype=torch.long, device=device)
        dn_mask      = torch.zeros(B, total_q,   dtype=torch.bool, device=device)
        dn_pos_mask  = torch.zeros(B, total_q,   dtype=torch.bool, device=device)

        q_per_group = max(1, total_q // (2 * dn_groups))

        for b, tgt in enumerate(targets):
            g = int(tgt["boxes"].size(0))
            if g == 0:
                continue

            boxes_gt  = tgt["boxes"]   # (g,4) cx,cy,w,h ∈ [0,1]
            labels_gt = tgt["labels"]  # (g,)


            rep = (q_per_group // max(1, g)) + 1
            base_boxes  = boxes_gt.repeat(rep, 1)[:q_per_group]
            base_labels = labels_gt.repeat(rep)[:q_per_group]

            # --- POS (relatively small deviation) ---

            pos_boxes = base_boxes.clone()
            pos_boxes[:, :2] += (torch.rand_like(pos_boxes[:, :2]) - 0.5) * pos_noise
            scale = 1.0 + (torch.rand_like(pos_boxes[:, 2:]) - 0.5) * pos_noise
            pos_boxes[:, 2:] = (pos_boxes[:, 2:] * scale).clamp(1e-4, 1.0)
            pos_boxes[:, :2] = pos_boxes[:, :2].clamp(0.0, 1.0)

            pos_labels = base_labels.clone()
            m = (torch.rand_like(pos_labels.float()) < lbl_noise)
            if int(m.sum()) > 0:
                pos_labels[m] = torch.randint(0, self.num_obj_classes, (int(m.sum()),), device=device)

            # --- NEG  (aggressive deviation + mostly incorrect classification) ---
            neg_boxes = base_boxes.clone()
            neg_boxes[:, :2] = torch.rand_like(neg_boxes[:, :2])
            scale = 0.5 + torch.rand_like(neg_boxes[:, 2:])
            neg_boxes[:, 2:] = (neg_boxes[:, 2:] * scale).clamp(1e-4, 1.0)

            neg_labels = base_labels.clone()
            m = (torch.rand_like(neg_labels.float()) < 1.0)
            if int(m.sum()) > 0:
                wrong = torch.randint(0, self.num_obj_classes, (int(m.sum()),), device=device)
                same = (wrong == base_labels[m])
                if int(same.sum()) > 0:
                    wrong[same] = (wrong[same] + 1) % self.num_obj_classes
                neg_labels[m] = wrong


            for gi in range(dn_groups):
                pos_base = 2 * gi * q_per_group
                neg_base = (2 * gi + 1) * q_per_group

                dn_in_boxes[b,  pos_base:pos_base+q_per_group]  = pos_boxes
                dn_in_labels[b, pos_base:pos_base+q_per_group]  = pos_labels
                dn_tgt_boxes[b, pos_base:pos_base+q_per_group]  = base_boxes
                dn_tgt_labels[b,pos_base:pos_base+q_per_group]  = base_labels
                dn_mask[b,      pos_base:pos_base+q_per_group]  = True

                dn_pos_mask[b,  pos_base:pos_base+q_per_group]  = True

                dn_in_boxes[b,  neg_base:neg_base+q_per_group]  = neg_boxes
                dn_in_labels[b, neg_base:neg_base+q_per_group]  = neg_labels
                dn_tgt_boxes[b, neg_base:neg_base+q_per_group]  = base_boxes
                dn_tgt_labels[b,neg_base:neg_base+q_per_group]  = base_labels
                dn_mask[b,      neg_base:neg_base+q_per_group]  = True

        return {
          "in_boxes":   dn_in_boxes,
          "in_labels":  dn_in_labels,
          "tgt_boxes":  dn_tgt_boxes,
          "tgt_labels": dn_tgt_labels,
          "mask":       dn_mask,       # cls : POS+NEG
          "pos_mask":   dn_pos_mask,   # bbox : only POS
      }

    @staticmethod
    def _build_self_attn_mask(dn_len, total_len, device, dn_groups=None):
        if dn_len <= 0: return None
        m = torch.zeros(total_len, total_len, dtype=torch.bool, device=device)
        m[:dn_len, dn_len:] = True; m[dn_len:, :dn_len] = True
        if dn_groups and dn_groups > 1:
            qg = dn_len // dn_groups
            for g in range(dn_groups):
                s, e = g*qg, (g+1)*qg
                m[s:e, :s] = True; m[s:e, e:dn_len] = True  # intergroup DN closed
        return m

    # ------------ forward ------------
    def forward(self,
                p3: torch.Tensor, p4: torch.Tensor,
                p5: torch.Tensor, p6: torch.Tensor,
                targets: Optional[List[Dict]] = None
                ) -> Dict[str, torch.Tensor]:

        B = p3.size(0)
        device = p3.device
        feats = [p6, p5, p4, p3]
        L = len(feats)

        # -------------------- feature flatten + pos --------------------
        src_list, pos_list = [], []
        spatial_shapes = torch.zeros(L, 2, dtype=torch.long, device=device)

        for i, f in enumerate(feats):
            B_, C, H, W = f.shape
            pos_lvl = self.pos_embed(f).expand(B_, -1, -1, -1) + self.level_embed[i].view(1, -1, 1, 1)
            src_list.append(f.flatten(2).transpose(1, 2))          # (B, H*W, C)
            pos_list.append(pos_lvl.flatten(2).transpose(1, 2))    # (B, H*W, C)
            spatial_shapes[i, 0] = H
            spatial_shapes[i, 1] = W

        src = torch.cat(src_list, dim=1)                           # (B, S, C)
        pos = torch.cat(pos_list, dim=1).to(dtype=src.dtype)       # (B, S, C)

        numel_per_level = spatial_shapes[:, 0] * spatial_shapes[:, 1]
        lvl_start_idx = torch.cat([numel_per_level.new_zeros(1), numel_per_level.cumsum(0)[:-1]], dim=0)  # (L,)

        # -------------------- encoder --------------------
        memory = self.encoder(src, pos, spatial_shapes, lvl_start_idx) if self.use_encoder else src
        if pos.size(0) != memory.size(0):
            pos = pos.expand(memory.size(0), -1, -1)
        pos = pos.to(dtype=memory.dtype)

        # -------------------- Denoising (DINO style) --------------------
        dn = self.prepare_denoising(targets) if (self.training and self.dn_queries > 0) else None
        has_dn = (dn is not None)
        dn_len = int(dn["in_boxes"].shape[1]) if has_dn else 0


        produce_obj_logits = self.training or self.emit_obj_logits_in_eval

        # -------------------- Query  (MATCH) --------------------
        K_total = int(self.num_queries)
        S = int(memory.size(1))

        proposal_topk  = self.proposal_topk
        proposal_ratio = self.proposal_ratio
        min_mqs        = int(self.min_mqs)
        min_left       = int(self.min_left_queries)
        mqs_enabled    = self.mqs_enable and (not self.mqs_train_only or self.training)

        # 1) Proposals
        Kp_cap = K_total
        if proposal_topk is not None:
            Kp_cap = min(Kp_cap, int(proposal_topk))
        if proposal_ratio is not None:
            Kp_cap = min(Kp_cap, int(round(K_total * float(proposal_ratio))))
        Kp_max = max(0, K_total - (min_mqs + min_left))
        Kp = min(Kp_cap, Kp_max, S) if self.use_proposals else 0

        # 2) remain
        K_rem = K_total - Kp

        # 3) MQS
        K_mqs = max(0, min(min_mqs, K_rem - min_left)) if mqs_enabled else 0
        tot = float(self.mqs_obj_ratio) + float(self.mqs_grid_ratio) + 1e-9
        k_obj  = int(round(K_mqs * float(self.mqs_obj_ratio) / tot)) if K_mqs > 0 else 0
        k_grid = K_mqs - k_obj

        # 4) Learned
        k_left = max(0, K_rem - (k_obj + k_grid))
        assert (Kp + k_obj + k_grid + k_left) == K_total

        def bgather(t: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
            idx_exp = idx.unsqueeze(-1).expand(-1, -1, t.size(-1))
            return torch.gather(t, 1, idx_exp)

        def levels_from_indices(idx: torch.Tensor,
                                spatial_shapes: torch.Tensor,
                                lvl_start_idx: torch.Tensor) -> torch.Tensor:
            starts = lvl_start_idx.view(1, 1, -1)
            sizes  = (spatial_shapes[:, 0] * spatial_shapes[:, 1]).view(1, 1, -1)
            ends   = starts + sizes
            x = idx.unsqueeze(-1)
            mask = (x >= starts) & (x < ends)
            lvls = mask.long().argmax(dim=-1)
            return lvls

        # --- Proposals ---
        content_prop = pos_prop = ref_prop = None
        obj_logits_prop = None
        if Kp > 0:
            prop_scores = self.prop_obj_head(memory).squeeze(-1)      # (B,S)
            prop_boxes  = self.prop_box_head(memory).sigmoid()        # (B,S,4)
            scores_sel, idx_sel = torch.topk(prop_scores, Kp, dim=1)  # (B,Kp)
            feat_sel = bgather(memory, idx_sel)                       # (B,Kp,C)
            pos_sel  = bgather(pos,    idx_sel)                       # (B,Kp,C)
            box_sel  = bgather(prop_boxes, idx_sel)                   # (B,Kp,4)

            ref_sel  = box_sel[..., :2]
            ref_prop = ref_sel.unsqueeze(2).expand(-1, -1, 4, -1)

            lvls_sel = levels_from_indices(idx_sel, spatial_shapes, lvl_start_idx)
            content_prop = self.prop_content_proj(feat_sel)
            content_prop = self._add_level_embed(content_prop, lvls_sel)
            pos_prop     = pos_sel

            if produce_obj_logits:
                obj_logits_prop = scores_sel

        # --- MQS (obj+grid) / Seed ---
        content_obj = content_grid = None
        pos_obj = pos_grid = None
        ref_obj = ref_grid = None
        obj_logits_obj = obj_logits_grid = None

        if (k_obj + k_grid) > 0:
            obj_pts = obj_lvls = None
            if k_obj > 0:
                obj_pts, obj_lvls = self._mqs_objectness_points(feats, k_obj, self.mqs_levels)

            grid_pts = grid_lvls = None
            if k_grid > 0:
                grid_pts_1, grid_lvls_1 = self._mqs_grid_points(spatial_shapes, k_grid, self.mqs_levels)
                grid_pts  = grid_pts_1.repeat(B, 1, 1).to(device)
                grid_lvls = grid_lvls_1.repeat(B, 1).to(device)

            if self.seed_enable:
                if k_obj > 0:
                    obj_seed = self._sample_feats_at_points(feats, obj_pts, obj_lvls)
                    if obj_seed.numel() > 0:
                        obj_seed = self._add_level_embed(obj_seed, obj_lvls)
                    alpha_obj  = torch.sigmoid(self.seed_logit_obj)
                    obj_seed_m = self.seed_mlp(obj_seed) if (obj_seed is not None and obj_seed.numel() > 0) else None
                    base_obj = self.query_feat.weight[self.dn_queries : self.dn_queries + k_obj] \
                    .unsqueeze(0).expand(B, -1, -1)
                    content_obj = (1.0 - alpha_obj) * base_obj + alpha_obj * obj_seed_m
                    pos_obj  = self.query_pos.weight[self.dn_queries : self.dn_queries + k_obj] \
              .unsqueeze(0).expand(B, -1, -1)
                    ref_obj     = obj_pts.unsqueeze(2).expand(-1, -1, 4, -1)
                    if produce_obj_logits:
                        obj_logits_obj = self.query_obj_head(content_obj).squeeze(-1)

                if k_grid > 0:
                    grid_seed = self._sample_feats_at_points(feats, grid_pts, grid_lvls)
                    if grid_seed.numel() > 0:
                        grid_seed = self._add_level_embed(grid_seed, grid_lvls)
                    alpha_grid  = torch.sigmoid(self.seed_logit_grid)
                    grid_seed_m = self.seed_mlp(grid_seed) if (grid_seed is not None and grid_seed.numel() > 0) else None
                    base_grid = self.query_feat.weight[self.dn_queries + k_obj : self.dn_queries + k_obj + k_grid] \
               .unsqueeze(0).expand(B, -1, -1)
                    content_grid = (1.0 - alpha_grid) * base_grid + alpha_grid * grid_seed_m
                    pos_grid  = self.query_pos.weight[self.dn_queries + k_obj : self.dn_queries + k_obj + k_grid] \
               .unsqueeze(0).expand(B, -1, -1)
                    ref_grid     = grid_pts.unsqueeze(2).expand(-1, -1, 4, -1)
                    if produce_obj_logits:
                        obj_logits_grid = self.query_obj_head(content_grid).squeeze(-1)

        # --- Learned remain ---
        content_left = pos_left = ref_left = None
        obj_logits_left = None
        if k_left > 0:
            base_s   = self.dn_queries + k_obj + k_grid
            content_left = self.query_feat.weight[base_s : base_s + k_left] \
                 .unsqueeze(0).expand(B, -1, -1)
            pos_left     = self.query_pos.weight[base_s : base_s + k_left] \
                 .unsqueeze(0).expand(B, -1, -1)
            init_ref_left = torch.sigmoid(self.ref_init(pos_left))
            ref_left = init_ref_left.unsqueeze(2).expand(-1, -1, 4, -1)
            if produce_obj_logits:
                obj_logits_left = self.query_obj_head(content_left).squeeze(-1)

        # ---  combine MATCH set ---
        def cat_safe(parts: list, dim: int) -> torch.Tensor:
            valid = [p for p in parts if p is not None]
            if len(valid) == 0:
                raise RuntimeError("RTDeformDecoder: boş MATCH seti oluştu.")
            return torch.cat(valid, dim=dim)

        parts_content = [content_prop, content_obj, content_grid, content_left]
        parts_pos     = [pos_prop,     pos_obj,     pos_grid,     pos_left]
        parts_ref     = [ref_prop,     ref_obj,     ref_grid,     ref_left]

        content_match = cat_safe(parts_content, dim=1)   # (B, N_match, C)
        pos_match     = cat_safe(parts_pos,     dim=1)   # (B, N_match, C)
        ref_match     = cat_safe(parts_ref,     dim=1)   # (B, N_match, 4, 2)
        content_match = content_match.contiguous()
        pos_match     = pos_match.contiguous()
        ref_match     = ref_match.contiguous()
        obj_logits_parts = [obj_logits_prop, obj_logits_obj, obj_logits_grid, obj_logits_left]
        obj_logits_match = torch.cat([p for p in obj_logits_parts if p is not None], dim=1) \
                           if (produce_obj_logits and any(p is not None for p in obj_logits_parts)) else None

        # --------------------Combine with DN (prefix) --------------------
        if has_dn:
            box_pe    = self.dn_box_pe(dn["in_boxes"].detach())
            alpha     = torch.sigmoid(self.dn_box_logit)
            content_dn = self.label_enc(dn["in_labels"]) + alpha * self.dn_box_proj(box_pe)                               # (B, Q, C)
            pos_dn     = self.query_pos.weight[:dn_len].unsqueeze(0).repeat(B, 1, 1)      # (B, Q, C)
            ref_dn     = dn["in_boxes"][..., :2].unsqueeze(2).expand(-1, -1, 4, -1).clone()  # (B, Q, 4, 2)

            tgt       = torch.cat([content_dn, content_match], 1).contiguous()
            query_pos = torch.cat([pos_dn,     pos_match],     1).contiguous()
            ref_pts   = torch.cat([ref_dn,     ref_match],     1).contiguous()                              # (B, Q+N_match, 4, 2)
        else:
            dn_len = 0
            tgt, query_pos, ref_pts = content_match, pos_match, ref_match

        # -------------------- self-attention mask --------------------
        total_len = tgt.size(1)
        attn_mask = self._build_self_attn_mask(dn_len, total_len, tgt.device)

        # -------------------- decoder --------------------
        aux_out = []
        for lid, layer in enumerate(self.layers):

            try:
                tgt, ref_pts = layer(
                    tgt=tgt, ref_pts=ref_pts, src=memory,
                    query_pos=query_pos, spatial_shapes=spatial_shapes, lvl_start_idx=lvl_start_idx,
                    self_attn_mask=attn_mask
                )
            except TypeError:
                tgt, ref_pts = layer(
                    tgt=tgt, ref_pts=ref_pts, src=memory,
                    query_pos=query_pos, spatial_shapes=spatial_shapes, lvl_start_idx=lvl_start_idx
                )

            if lid < self.num_layers - 1:
                # Aux only from MATCH
                aux_in = tgt[:, dn_len:] if dn_len > 0 else tgt
                aux_h  = self.head_prep(aux_in)
                aux_logits = self.aux_cls_heads[lid](aux_h)
                aux_boxes  = self.aux_box_heads[lid](aux_h).sigmoid()
                aux_out.append({"pred_logits": aux_logits, "pred_boxes": aux_boxes})

        # -------------------- last heads --------------------
        h = self.head_prep(tgt)
        if dn_len > 0:
            h_dn    = h[:, :dn_len]
            h_match = h[:, dn_len:]
            dn_logits    = self.cls_head(h_dn)
            dn_boxes_out = self.box_head(h_dn).sigmoid()
            logits = self.cls_head(h_match)
            boxes  = self.box_head(h_match).sigmoid()
        else:
            logits = self.cls_head(h)
            boxes  = self.box_head(h).sigmoid()

        out: Dict[str, torch.Tensor] = {"pred_logits": logits, "pred_boxes": boxes}

        if aux_out and self.training:
            out["aux_outputs"] = aux_out

        if has_dn and self.training:

            out["dn_meta"] = {
            "dn_logits":   dn_logits,
            "dn_boxes":    dn_boxes_out,
            "dn_labels":   dn["tgt_labels"],
            "dn_gt_boxes": dn["tgt_boxes"],
            "dn_mask":     dn["mask"].bool(),
            "dn_pos_mask": dn["pos_mask"].bool(),   # <-- YENİ
            "dn_len":      dn_len,
            "dn_queries":  dn_len,
          }

        # MATCH‑only final_ref_pts (center loss vb. için)
        ref_match_only = ref_pts[:, dn_len:, ...] if dn_len > 0 else ref_pts
        N_match = out["pred_boxes"].size(1)
        if ref_match_only.size(1) != N_match:
            ref_match_only = ref_match_only[:, -N_match:, ...]
        out["final_ref_pts"] = ref_match_only.contiguous()

        # Obj logits (only MATCH part)
        if produce_obj_logits and (obj_logits_match is not None):
            if obj_logits_match.size(1) != N_match:
                if obj_logits_match.size(1) > N_match:
                    obj_logits_match = obj_logits_match[:, :N_match]
            out["pred_obj_logits"] = obj_logits_match.contiguous()

        # All MATCH queries are “selected”
        out["query_selection_mask"] = torch.ones(B, logits.size(1), dtype=torch.bool, device=logits.device)

        return out





# -----------------------------------------------------------------------------
# MiniBackbone
# -----------------------------------------------------------------------------

class MiniBackbone(nn.Module):
    """
    Thin wrapper around ``StageAwareBackbone`` (the lightweight variant with CSI and PGI).

    Args:
        depths: per-stage block counts, forwarded unchanged to the backbone.
        drop_path_max: maximum drop-path rate, forwarded to the backbone.
        num_classes: number of object classes, forwarded to the backbone.
    """
    def __init__(self, depths: List[int] = [6, 6, 12, 6],
                 drop_path_max: float = 0.2,
                 num_classes: int = 20):
        super().__init__()
        self.backbone = StageAwareBackbone(
            depths=depths,
            drop_path_max=drop_path_max,
            num_classes=num_classes,
        )

    def forward(self, x: torch.Tensor, need_aux: bool = False):
        """Delegate to the wrapped backbone; returns exactly what it returns for ``need_aux``."""
        return self.backbone(x, need_aux=need_aux)



def _looks_like_norm_name(name: str) -> bool:
    low = name.lower()
    # name-based heuristic; used in conjunction with a module-based set
    return any(k in low for k in [
        ".norm", "bn", "groupnorm", "layernorm", "gn", "ln.", "ln_", "lnact", "ln_act", "lnproxy", "layernormproxy", "rmsnorm"
    ])



class HybridDCDATRT(nn.Module):
    """
    End-to-end Hybrid-DCDAT-RT detector (DAT backbone + DCNv4 + RT-Deform decoder).

    Pipeline: ``StageAwareBackbone`` -> ``LightBiFPN`` neck -> ``RTDeformDecoder``.

    NOTE: ``backbone_lr`` / ``head_lr`` are stored in ``self._lr_mult`` but are
    not read by :meth:`param_groups`, which takes its own keyword multipliers.
    """
    def __init__(self,
                 num_classes:   int = 20,
                 num_queries:   int = 100,
                 depths:        List[int] = (2, 2, 2, 2),
                 drop_path_max: float = 0.0,
                 backbone_norm_factory: 'NormFactory' = None,
                 neck_norm_factory:     'NormFactory' = None,
                 use_layer_scale: bool = True,
                 layer_scale_init: float = 1.0,
                 # Decoder
                 d_model:    int = 320,
                 dec_layers: int = 4,
                 dn_queries: int = 100,
                 emit_obj_logits_in_eval: bool = True,
                 # Features
                 use_aux_loss: bool = True,
                 use_iou_aware: bool = False,
                 # LR multipliers
                 backbone_lr: float = 0.1,
                 head_lr:     float = 1.0,
                 # Dataset
                 voc_prior: bool = False,
                 # Decoder Encoder
                 decoder_use_encoder: bool = True,
                 decoder_encoder_layers: int = 2,
                 decoder_encoder_drop_path_max: float = 0.1):
        """
        Args:
            num_classes: number of object classes (backbone and decoder).
            num_queries: number of decoder queries.
            depths: per-stage block counts for the backbone.
            drop_path_max: maximum backbone drop-path rate.
            backbone_norm_factory / neck_norm_factory: norm factories; each
                defaults to ``NormFactory("gn")`` when falsy.
            use_layer_scale / layer_scale_init: backbone LayerScale settings.
            d_model: neck output width and decoder model width.
            dec_layers: number of decoder layers.
            dn_queries: number of denoising queries in the decoder.
            emit_obj_logits_in_eval: forwarded to the decoder.
            use_aux_loss: if True, training forward also returns backbone ``aux_dense``.
            use_iou_aware: forwarded to the decoder.
            backbone_lr / head_lr: stored in ``self._lr_mult`` only.
            voc_prior: forwarded to the backbone.
            decoder_use_encoder / decoder_encoder_layers / decoder_encoder_drop_path_max:
                decoder-side encoder settings.
        """
        super().__init__()

        # Default to GroupNorm when no factory is given.
        backbone_norm_factory = backbone_norm_factory or NormFactory("gn")
        neck_norm_factory     = neck_norm_factory     or NormFactory("gn")

        self.num_classes   = num_classes
        self.use_aux_loss  = use_aux_loss

        # ---------- Backbone (DAT + CSI + PGI) ----------
        self.backbone = StageAwareBackbone(
            depths          = depths,
            drop_path_max   = drop_path_max,
            num_classes     = num_classes,
            voc_prior       = voc_prior,
            norm_factory    = backbone_norm_factory,
            use_layer_scale = use_layer_scale,
            layer_scale_init= layer_scale_init
        )

        # ---------- Neck (Light-BiFPN) ----------
        # NOTE(review): in_channels is hard-coded; presumably it must match the
        # backbone's stage output widths — TODO confirm when changing the backbone.
        self.neck = LightBiFPN(
            in_channels=(128, 256, 640, 768),
            out=d_model,
            repeats=2,
            use_spatial_fuse=True,
            fuse_groups=4,
            fuse_tau_init=1.4,
            fuse_learn_tau=True,
            fuse_eps=5e-4,
            grad_boost_low=0.1,
            dp_top_second=0.01, dp_bot_second=0.01,
            no_blur_first=True,
            edge_enhance=True,
            edge_ls_init=0.03,
            norm_factory=neck_norm_factory,
        )

        # ---------- Decoder (RT-Deform) ----------
        self.decoder = RTDeformDecoder(
            num_obj_classes = num_classes,
            include_background = True,
            num_queries     = num_queries,
            num_layers      = dec_layers,
            d_model         = d_model,
            dn_queries      = dn_queries,
            use_iou_aware   = use_iou_aware,
            use_encoder                 = decoder_use_encoder,
            encoder_layers              = decoder_encoder_layers,
            encoder_drop_path_max       = decoder_encoder_drop_path_max,
            emit_obj_logits_in_eval=emit_obj_logits_in_eval,
        )

        # Kept for reference / external use; param_groups() does not read this.
        self._lr_mult = {
            "backbone": backbone_lr,
            "head":     head_lr,
            "bias":     2.0,
            "dcn_bias": 10.0,
        }

    # ------------------------------------------------------------------ #
    def forward(self,
                x: torch.Tensor,
                targets: Optional[List[Dict]] = None) -> Dict[str, torch.Tensor]:
        """
        Run backbone -> neck -> decoder.

        Args:
            x: input image batch; must be a CUDA tensor.
            targets: per-image target dicts. They are forwarded to the decoder
                only in training mode.

        Returns:
            The decoder's output dict (e.g. ``pred_logits``, ``pred_boxes``; in
            training also ``aux_outputs`` / ``dn_meta`` when produced). When
            training with ``use_aux_loss``, it also holds ``aux_dense`` (auxiliary
            dense outputs from the backbone).

        Raises:
            RuntimeError: if ``x`` is not on CUDA.
        """
        if not x.is_cuda:
            raise RuntimeError("HybridDCDATRT expects CUDA tensor input.")

        # Backbone: auxiliary dense outputs are only requested when they are trained.
        if self.training and self.use_aux_loss:
            (p3, p4, p5, p6), aux_dense = self.backbone(x, need_aux=True)
        else:
            p3, p4, p5, p6 = self.backbone(x, need_aux=False)
            aux_dense      = None

        # Neck
        p3, p4, p5, p6 = self.neck(p3, p4, p5, p6)

        # Decoder
        if self.training and targets is not None:
            dec_out = self.decoder(p3, p4, p5, p6, targets)
        else:
            dec_out = self.decoder(p3, p4, p5, p6)

        if aux_dense is not None:
            dec_out["aux_dense"] = aux_dense
        return dec_out

    def param_groups(self,
                    base_lr: float = 1e-4,
                    *,
                    weight_decay: float = 0.02,
                    bb_mult: float = 0.5,
                    dec_mult: float = 2.0,
                    bias_mult: float = 1.0,      # Bias boost
                    dcn_mult: float = 1.5,       # DCN offset/mask boost
                    rpb_mult: float = 5.0,       # RPB boost
                    gate_mult: float = 1.0,      # Gate params boost
                    ls_mult: float = 0.3,        # LayerScale reduction
                    pos_mult: float = 1.5):
        """
        Split trainable parameters into optimizer groups with per-group LR and WD.

        Routing order (first match wins):
          1. Prefix-independent name rules: RPB tables, embeddings
             (level/query/label/pos tables), LayerScale gammas, scalar gates,
             DAT scale params, neck fusion temperatures, and deformable
             sampling-offset biases.
          2. ``backbone.*``: DCN offset/mask, then norm, bias, weight.
          3. ``neck.*``: norm, bias, weight.
          4. Everything else (decoder): norm, bias, weight.

        A parameter counts as "norm" if it is owned directly by a normalization
        layer (LayerNorm / GroupNorm / BatchNorm / InstanceNorm / RMSNorm) or its
        name contains ``.norm.``, ``.bn.``, ``.ln.`` or ``.gn.``. Norm, bias and
        all special groups get ``weight_decay=0``.

        Args:
            base_lr: base learning rate; each group's LR is ``base_lr`` times its multipliers.
            weight_decay: decay for the weight groups (``bb_w``, ``hd_w``, ``dec_w``).
            bb_mult / dec_mult: LR multipliers for backbone / decoder groups.
            bias_mult: extra multiplier for bias groups.
            dcn_mult: multiplier for DCN offset/mask and deformable sampling-offset biases.
            rpb_mult, gate_mult, ls_mult, pos_mult: multipliers for RPB tables,
                gates/scalars/DAT scales, LayerScale gammas and embeddings.

        Returns:
            List of ``{"params", "lr", "weight_decay", "name"}`` dicts; empty
            buckets are omitted.

        Raises:
            AssertionError: if any trainable parameter was not assigned to a group.
        """
        import re

        buckets = {
            # Backbone
            "bb_w": [], "bb_b": [], "bb_norm": [],
            "dcn_off": [],  # DCN offset/mask (specific LR)

            # Neck / Head
            "hd_w": [], "hd_b": [], "hd_norm": [],
            "fuse_gate": [],  # Fusion gates

            # Decoder
            "dec_w": [], "dec_b": [], "dec_norm": [],
            "deform_off_b": [],  # Deformable attention bias

            # Special
            "pos_embed": [], "rpb_p": [], "ls_gamma": [],
            "dat_scale_p": [], "scalars_gain": [],
        }

        # Parameters owned directly by normalization layers. Name patterns alone
        # miss layers registered as e.g. ``norm1`` / ``norm2`` (no ``.norm.``
        # substring); those would otherwise land in weight-decayed buckets.
        norm_types = (nn.LayerNorm, nn.GroupNorm,
                      nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm,
                      nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d)
        if getattr(nn, "RMSNorm", None) is not None:   # only in newer torch releases
            norm_types = norm_types + (nn.RMSNorm,)
        norm_param_ids = {
            id(p)
            for m in self.modules() if isinstance(m, norm_types)
            for p in m.parameters(recurse=False)
        }

        def is_norm(n: str, p: torch.Tensor) -> bool:
            return id(p) in norm_param_ids or any(x in n for x in [".norm.", ".bn.", ".ln.", ".gn."])

        for name, param in self.named_parameters():
            if not param.requires_grad:
                continue

            # --- Special cases (prefix-independent) ---

            # RPB tables (needs high LR)
            if name.endswith(".rpb") or "relative_position_bias" in name:
                buckets["rpb_p"].append(param)
                continue

            # Positional / query / label embeddings
            if re.search(r"(level_embed(_seed)?|query_pos|query_feat|label_enc|pos_table)", name):
                buckets["pos_embed"].append(param)
                continue

            # LayerScale gamma
            if re.search(r"(layer_scales\.\d+\.weight|\.scale\.weight$|\.ls\.weight$)", name):
                buckets["ls_gamma"].append(param)
                continue

            # Scalar gates
            if any(name.endswith(x) for x in ["residual_scale", "_lambda_logit", "dc_res_logit", "seed_logit_obj", "seed_logit_grid"]):
                buckets["scalars_gain"].append(param)
                continue

            # DAT scale parameters
            if "dat_logit" in name or "dat_log_tau" in name:
                buckets["dat_scale_p"].append(param)
                continue

            # Neck fusion gates
            if "neck.fuse" in name and "log_tau" in name:
                buckets["fuse_gate"].append(param)
                continue

            # Deformable attention sampling bias
            if "sampling_offsets.bias" in name:
                buckets["deform_off_b"].append(param)
                continue

            # --- Module-based categorization ---
            norm = is_norm(name, param)
            is_bias = name.endswith(".bias")

            # Backbone
            if name.startswith("backbone."):
                if "offset_mask" in name and "dcn" in name:
                    buckets["dcn_off"].append(param)
                elif norm:
                    buckets["bb_norm"].append(param)
                elif is_bias:
                    buckets["bb_b"].append(param)
                else:
                    buckets["bb_w"].append(param)
                continue

            # Neck
            if name.startswith("neck."):
                if norm:
                    buckets["hd_norm"].append(param)
                elif is_bias:
                    buckets["hd_b"].append(param)
                else:
                    buckets["hd_w"].append(param)
                continue

            # Decoder (default)
            if norm:
                buckets["dec_norm"].append(param)
            elif is_bias:
                buckets["dec_b"].append(param)
            else:
                buckets["dec_w"].append(param)

        # --- Build parameter groups (order is significant for logging / schedulers) ---
        wd = weight_decay
        specs = [
            # (bucket,        lr,                            weight_decay)
            ("bb_w",          base_lr*bb_mult,               wd),
            ("bb_b",          base_lr*bb_mult*bias_mult,     0.0),
            ("bb_norm",       base_lr*bb_mult,               0.0),
            ("dcn_off",       base_lr*dcn_mult,              0.0),
            ("hd_w",          base_lr,                       wd),
            ("hd_b",          base_lr*bias_mult,             0.0),
            ("hd_norm",       base_lr,                       0.0),
            ("dec_w",         base_lr*dec_mult,              wd),
            ("dec_b",         base_lr*dec_mult*bias_mult,    0.0),
            ("dec_norm",      base_lr*dec_mult,              0.0),
            ("deform_off_b",  base_lr*dcn_mult,              0.0),
            ("fuse_gate",     base_lr*gate_mult,             0.0),
            ("rpb_p",         base_lr*rpb_mult,              0.0),
            ("pos_embed",     base_lr*pos_mult,              0.0),
            ("ls_gamma",      base_lr*ls_mult,               0.0),
            ("dat_scale_p",   base_lr*gate_mult,             0.0),
            ("scalars_gain",  base_lr*gate_mult,             0.0),
        ]
        groups = [
            {"params": buckets[key], "lr": lr, "weight_decay": group_wd, "name": key}
            for key, lr, group_wd in specs
            if buckets[key]
        ]

        # Validation: every trainable parameter must be in exactly one group.
        in_groups = {id(p) for g in groups for p in g["params"]}
        missing = [n for n, p in self.named_parameters() if p.requires_grad and id(p) not in in_groups]
        assert not missing, f"Missing params in groups: {len(missing)} params (e.g., {missing[:5]})"

        # One-time summary print.
        if not hasattr(self, '_param_groups_logged'):
            print("\n" + "="*60)
            print("Parameter Groups Summary")
            print("="*60)
            for g in groups:
                n_params = sum(p.numel() for p in g["params"])
                if n_params > 0:
                    # Guard: a zero base_lr would otherwise divide by zero.
                    lr_mult = g["lr"] / base_lr if base_lr else float("nan")
                    print(f"{g['name']:20s}: {n_params:10,} params | LR: {lr_mult:6.1f}x | WD: {g['weight_decay']:.3f}")
            print("="*60 + "\n")
            self._param_groups_logged = True

        return groups





# -----------------------------------------------------------------------------
# build_optimizer – returns the Optimizer (optionally also an LR scheduler)
# -----------------------------------------------------------------------------
def build_optimizer(model: nn.Module,
                    base_lr: float = 2e-4,
                    weight_decay: float = 0.02,
                    *,
                    return_scheduler: bool = False):
    """
    Build an AdamW optimizer, and optionally a cosine LR scheduler, for ``model``.

    Steps:
      1. Run two dummy forwards through ``model.backbone`` (640 px and 320 px,
         train mode, no grad) so lazily created parameters exist before grouping.
      2. Get LR/WD groups from ``model.param_groups``, forwarding ``weight_decay``.
      3. Move normalization parameters (names containing ``.norm``, ``.ln_act``
         or ``ln_proxy``) into sibling groups with ``weight_decay=0``. The rest of
         their original group keeps its decay.

    Args:
        model: module exposing ``backbone(x, need_aux=...)`` and
            ``param_groups(base_lr, *, weight_decay=...)`` (e.g. HybridDCDATRT).
        base_lr: base learning rate for ``param_groups`` and AdamW.
        weight_decay: decay applied to the weight groups.
        return_scheduler: if True, also return a ``CosineAnnealingLR``
            (``T_max=100``, ``eta_min=0.05 * base_lr``).

    Returns:
        ``optim``, or ``(optim, sched)`` when ``return_scheduler`` is True.
    """
    device = next(model.parameters()).device

    # ---------- 1) Lazy-param warm-up (backbone only) ----------
    # Train mode so train-only branches (PGI) are built too. The original mode is
    # restored even if the warm-up forward raises.
    # NOTE(review): in train mode, BatchNorm running stats (if any) get updated by
    # these all-zero inputs.
    was_training = model.training
    model.train()
    try:
        with torch.no_grad():
            for s in (640, 320):
                dummy = torch.zeros(1, 3, s, s, device=device)
                _ = model.backbone(dummy, need_aux=True)
    finally:
        model.train(was_training)

    # ---------- 2) Param groups ----------
    # Forward `weight_decay`. Without it, every group kept param_groups' own
    # default and this function's argument had no effect.
    groups = model.param_groups(base_lr, weight_decay=weight_decay)

    # ---------- 3) Norm weights get WD=0 (GN / LN-proxy etc.) ----------
    # Split matching params into a sibling no-decay group rather than zeroing the
    # decay of the whole group. Zeroing the whole group would disable WD for every
    # weight that merely shares a group with one norm parameter.
    ln_keys = ('.norm', '.ln_act', 'ln_proxy')
    no_wd_ids = {id(p) for n, p in model.named_parameters()
                 if any(k in n for k in ln_keys)}

    final_groups = []
    for g in groups:
        params = list(g["params"])
        decay = [p for p in params if id(p) not in no_wd_ids]
        no_decay = [p for p in params if id(p) in no_wd_ids]
        if not no_decay or g.get("weight_decay") == 0.0:
            final_groups.append(g)
        elif not decay:
            final_groups.append({**g, "weight_decay": 0.0})
        else:
            final_groups.append({**g, "params": decay})
            split = {**g, "params": no_decay, "weight_decay": 0.0}
            if "name" in g:
                split["name"] = f"{g['name']}_no_wd"
            final_groups.append(split)

    # ---------- 4) Optimizer ----------
    optim = torch.optim.AdamW(
        final_groups, lr=base_lr,
        betas=(0.9, 0.999), eps=1e-6,
        weight_decay=weight_decay
    )

    if not return_scheduler:
        return optim

    sched = torch.optim.lr_scheduler.CosineAnnealingLR(
        optim, T_max=100, eta_min=base_lr * 0.05)
    return optim, sched


def build_model(num_classes: int = 20,
                norm: str = "gn",            # "gn", "lnp" or "bn"
                use_layer_scale: bool = True,
                layer_scale_init: float = 1.0,
                **kwargs) -> HybridDCDATRT:
    """
    Construct a ``HybridDCDATRT`` with this project's default configuration.

    Args:
        num_classes: number of object classes.
        norm: norm type for ``NormFactory``. The same factory instance is used
            for both the backbone and the neck.
        use_layer_scale / layer_scale_init: forwarded to the model.
        **kwargs: override any of the defaults below (e.g. ``depths``,
            ``num_queries``). Any other ``HybridDCDATRT`` keyword is passed through.

    Returns:
        The constructed model.
    """
    config = {
        # Backbone
        "depths":        (2, 2, 4, 1),
        "drop_path_max": 0.2,
        # Decoder
        "d_model":       320,
        "dec_layers":    4,
        "dn_queries":    100,
        "num_queries":   200,
        "use_aux_loss":  True,
        "use_iou_aware": False,
        # Dataset
        "voc_prior":     False,
        # Decoder encoder
        "decoder_use_encoder":           True,
        "decoder_encoder_layers":        2,
        "decoder_encoder_drop_path_max": 0.1,
        # Caller overrides win.
        **kwargs,
    }

    shared_norm = NormFactory(norm)
    model = HybridDCDATRT(
        num_classes           = num_classes,
        backbone_norm_factory = shared_norm,
        neck_norm_factory     = shared_norm,
        use_layer_scale       = use_layer_scale,
        layer_scale_init      = layer_scale_init,
        **config
    )
    # Optional post-construction hooks (currently disabled):
    #   init_weights_improved(model)
    #   apply_hybrid_fixup(model)
    #   boost_relative_position_bias(model, std=0.05)
    return model
# # -----------------------------------------------------------------------------
# # Test code
# # -----------------------------------------------------------------------------

# if __name__ == "__main__":
#     import time

#     assert torch.cuda.is_available(), "CUDA required for this model."
#     device = torch.device("cuda")

#     print("="*80)
#     print("Building Production-Ready Hybrid-DCDAT-RT Model...")
#     print("="*80)

#     model = build_model(
#         num_classes=80,
#         depths=[2, 2, 2, 2],  # ResNet-50 like depth
#         drop_path_max=0,
#         num_queries=200,
#         voc_prior=False,norm="gn"  # Use COCO prior
#     ).to(device)
#     optimizer = build_optimizer(model, base_lr=2e-4)
#     # Test forward pass
#     print("\n🚀 Testing forward pass...")
#     model.eval()
#     dummy = torch.randn(2, 3, 640, 640, device=device)

#     with torch.no_grad():
#         out = model(dummy)
#         print("✅ Forward pass successful!")
#         print(f"Output keys: {list(out.keys())}")
#         print(f"Predictions shape: {out['pred_logits'].shape}, {out['pred_boxes'].shape}")

#     # Test training mode with targets
#     print("\n🚀 Testing training mode...")
#     model.train()

#     # Create more realistic targets with proper box format
#     targets = []
#     for _ in range(2):  # batch size 2
#         num_objs = torch.randint(1, 10, (1,)).item()
#         # Generate boxes in (cx, cy, w, h) format, all normalized to [0, 1]
#         centers = torch.rand(num_objs, 2, device=device)
#         sizes = torch.rand(num_objs, 2, device=device) * 0.5  # max size 0.5
#         boxes = torch.cat([centers, sizes], dim=1)

#         # Ensure boxes are valid (centers must be at least half-size from borders)
#         half_sizes = boxes[:, 2:4] / 2
#         boxes[:, 0] = boxes[:, 0].clamp(min=half_sizes[:, 0], max=1-half_sizes[:, 0])
#         boxes[:, 1] = boxes[:, 1].clamp(min=half_sizes[:, 1], max=1-half_sizes[:, 1])

#         targets.append({
#             'boxes': boxes,
#             'labels': torch.randint(0, 80, (num_objs,), device=device)
#         })

#     out = model(dummy, targets)
#     print("✅ Training mode successful!")
#     print(f"Output keys: {list(out.keys())}")
#     if 'aux_dense' in out:
#         print(f"Auxiliary dense outputs: {list(out['aux_dense'].keys())}")
#     if 'dn_meta' in out:
#         print(f"Denoising queries: {out['dn_meta']['dn_queries']}")

#     # Test gradient flow
#     print("\n🚀 Testing gradient flow...")
#     loss = out['pred_logits'].sum() + out['pred_boxes'].sum()
#     loss.backward()

#     # Check that gradients flow through all parts
#     has_grad = {
#         'backbone': any(p.grad is not None for n, p in model.named_parameters() if 'backbone' in n),
#         'neck': any(p.grad is not None for n, p in model.named_parameters() if 'neck' in n),
#         'decoder': any(p.grad is not None for n, p in model.named_parameters() if 'decoder' in n),
#     }

#     for module, has in has_grad.items():
#         print(f"  {module}: {'✅' if has else '❌'} gradients")

#     # Test parameter groups
#     print("\n🚀 Testing parameter groups...")
#     param_groups = model.param_groups(base_lr=2e-4)
#     print(f"Number of parameter groups: {len(param_groups)}")
#     for i, group in enumerate(param_groups):
#         print(f"  Group {i}: {len(group['params'])} params, lr={group['lr']:.6f}, wd={group['weight_decay']}")

#     print("\n" + "="*80)
#     print("✅ All tests passed! Model is production-ready.")
#     print("="*80)



In [None]:
from __future__ import annotations
from typing import List, Dict, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from scipy.optimize import linear_sum_assignment
from torchvision.ops import generalized_box_iou, box_convert,box_iou


# ─────────────────────────────────────────────────────────────────────────────
# Helpers: focal losses
# ─────────────────────────────────────────────────────────────────────────────

def softmax_focal_loss_weighted(
    logits: torch.Tensor,          # (N, C+1)
    targets: torch.Tensor,         # (N,) int64  (-1 = ignore)
    *,
    alpha: float = 0.25,
    gamma: float = 2.0,
    sample_weight: Optional[torch.Tensor] = None,  # (N,) or None
    reduction: str = "mean",
    ignore_index: int = -1,
) -> torch.Tensor:
    """
    Multi-class (softmax) focal loss with optional per-sample weights.

    The last logit column is the background class. Background samples are
    weighted by ``1 - alpha``; all other classes are weighted by ``alpha``.

    Args:
        logits: (N, C+1) raw scores, background at index C.
        targets: (N,) integer class ids. Entries equal to ``ignore_index`` are dropped.
        alpha: weight for foreground targets (background gets ``1 - alpha``).
        gamma: focusing exponent applied to ``(1 - p_t)``.
        sample_weight: optional (N,) multiplier applied after the focal term
            (e.g. to weight positives and negatives separately).
        reduction: "sum" or "mean"; any other value returns the per-sample loss.
        ignore_index: target value to skip; ``None`` disables filtering.

    Returns:
        A scalar for "sum"/"mean", otherwise a (N_kept,) tensor. If every target
        is ignored, returns a zero scalar that stays attached to ``logits``' graph.
    """
    if ignore_index is not None:
        valid = targets.ne(ignore_index)
        if valid.sum() == 0:
            # Zero that keeps autograd connected to `logits`.
            return logits.sum() * 0.0
        logits = logits[valid]
        targets = targets[valid]
        if sample_weight is not None:
            sample_weight = sample_weight[valid]

    # log_softmax is row-wise, so filtering rows first yields identical values.
    log_probs = F.log_softmax(logits, dim=-1)
    rows = torch.arange(targets.size(0), device=targets.device)
    log_pt = log_probs[rows, targets]
    pt = log_pt.exp()

    background = logits.size(1) - 1
    alpha_t = torch.where(targets == background, 1.0 - alpha, alpha)

    per_sample = -alpha_t * (1.0 - pt).pow(gamma) * log_pt
    if sample_weight is not None:
        per_sample = per_sample * sample_weight

    reducer = {"sum": per_sample.sum, "mean": per_sample.mean}.get(reduction)
    return per_sample if reducer is None else reducer()


def binary_focal_with_logits(
    input: torch.Tensor,           # (..., C)
    target: torch.Tensor,          # (..., C) in {0,1}
    alpha: float = 0.25,
    gamma: float = 2.0,
    reduction: str = "mean",
) -> torch.Tensor:
    """
    Sigmoid focal loss where each channel is an independent binary problem.

    Background is implicit: it is any position whose target is 0 in every channel.

    Args:
        input: raw logits.
        target: same-shape tensor of 0/1 labels.
        alpha: weight for positives (negatives get ``1 - alpha``).
        gamma: focusing exponent applied to ``(1 - p_t)``.
        reduction: "sum" or "mean"; any other value returns the elementwise loss.
    """
    neg_target = 1 - target
    prob = torch.sigmoid(input)
    prob_true = prob * target + (1 - prob) * neg_target
    weight = alpha * target + (1 - alpha) * neg_target
    modulator = (1 - prob_true).pow(gamma)
    bce = F.binary_cross_entropy_with_logits(input, target, reduction="none")
    per_elem = weight * modulator * bce

    reducer = {"sum": per_elem.sum, "mean": per_elem.mean}.get(reduction)
    return per_elem if reducer is None else reducer()


# ─────────────────────────────────────────────────────────────────────────────
# Dense aux (sigmoid focal): BG implicit
# ─────────────────────────────────────────────────────────────────────────────

def focal_loss_dense_sigmoid(logits, targets, num_classes, alpha=.5, gamma=1.5, use_or_map=False, radius: int = 1):
    """
    Dense auxiliary sigmoid focal loss on a class-logit map (background implicit).

    For every GT box, the cells of a (2*radius+1) x (2*radius+1) window around the
    cell containing the box's first two coordinates (treated as normalised centre
    x, y) are marked positive for that box's class.  The window is clipped at the
    map border.  Every other cell/channel is a negative.

    Args:
        logits: (B, C, H, W) per-class logits; C must equal ``num_classes``.
        targets: per-image dicts with "boxes" (assumed cxcywh normalised to [0, 1]
            — TODO confirm against the dataloader) and "labels" (class ids,
            clamped into [0, C-1]).
        num_classes: number of foreground classes.
        alpha, gamma: focal-loss parameters passed to ``binary_focal_with_logits``.
        use_or_map: unused; kept only for backward compatibility of the signature.
        radius: half-size of the positive window, in cells.

    Returns:
        Scalar focal loss averaged over all B*H*W*C elements.
    """
    B, C, H, W = logits.shape
    device = logits.device
    assert C == num_classes
    tgt = torch.zeros(B, C, H, W, device=device)

    # Window offsets shared by all boxes.  Clamping (centre + offset) into the map
    # yields exactly the border-clipped window [max(0, c-r), min(size-1, c+r)],
    # so no per-box Python loop (and no .tolist() host sync) is needed.
    offs = torch.arange(-radius, radius + 1, device=device)
    side = offs.numel()
    scale = torch.tensor([W, H], device=device)

    for b, t in enumerate(targets):
        if len(t["boxes"]) == 0:
            continue
        ctr = t["boxes"][:, :2].to(device) * scale
        gx = ctr[:, 0].long().clamp(0, W - 1)
        gy = ctr[:, 1].long().clamp(0, H - 1)
        cls = t["labels"].to(device).clamp(0, C - 1)

        xs = (gx[:, None] + offs).clamp(0, W - 1)        # (K, S)
        ys = (gy[:, None] + offs).clamp(0, H - 1)        # (K, S)
        k = cls.numel()
        tgt[b,
            cls[:, None, None].expand(k, side, side),
            ys[:, :, None].expand(k, side, side),
            xs[:, None, :].expand(k, side, side)] = 1.0

    logits_flat = logits.permute(0, 2, 3, 1).reshape(-1, C)
    tgt_flat = tgt.permute(0, 2, 3, 1).reshape(-1, C)
    return binary_focal_with_logits(logits_flat, tgt_flat, alpha=alpha, gamma=gamma, reduction="mean")


# ─────────────────────────────────────────────────────────────────────────────
# Hungarian matcher (log‑prob, cost scaling)
# ─────────────────────────────────────────────────────────────────────────────

class HungarianMatcher(nn.Module):
    """One‑to‑one matching between predicted queries and GT boxes.

    The cost of assigning query q to target t is
        cost_class * (-log p_q(label_t))
      + cost_bbox  * L1(box_q, box_t)            (normalised cx, cy, w, h)
      + cost_giou  * (-GIoU(box_q, box_t))
    and each image is solved independently with SciPy's ``linear_sum_assignment``.
    """

    def __init__(self,
                 cost_class: float = 2.,
                 cost_bbox: float = 5.,
                 cost_giou: float = 2.):
        super().__init__()
        self.cost_class = cost_class
        self.cost_bbox = cost_bbox
        self.cost_giou = cost_giou

    @torch.no_grad()
    def forward(self,
                outputs: Dict[str, torch.Tensor],
                targets: List[Dict]) -> List[Tuple[torch.Tensor, torch.Tensor]]:
        """
        Args:
            outputs:
                - pred_logits: (B, N, C+1)
                - pred_boxes : (B, N, 4)  (cx,cy,w,h norm.)
            targets: list of B dicts with "labels" (T_b,) and "boxes" (T_b, 4).

        Returns:
            List of B (query_idx, target_idx) int64 tensor pairs on the logits'
            device.  Both tensors are empty for images without GT boxes.

        NOTE: query_selection_mask is not used here; negatives are ignored on the loss side.
        """
        bs, num_queries = outputs["pred_logits"].shape[:2]
        device = outputs["pred_logits"].device
        sizes = [len(v["boxes"]) for v in targets]
        total = sum(sizes)

        def _no_match() -> Tuple[torch.Tensor, torch.Tensor]:
            return (torch.empty(0, dtype=torch.int64, device=device),
                    torch.empty(0, dtype=torch.int64, device=device))

        # No GT anywhere in the batch: nothing to match.  (Without this guard the
        # .view(bs, num_queries, -1) of a 0-element cost tensor raises.)
        if total == 0:
            return [_no_match() for _ in range(bs)]

        out_prob = outputs["pred_logits"].flatten(0, 1).softmax(-1)   # (B·N, C+1)
        out_bbox = outputs["pred_boxes"].flatten(0, 1)                # (B·N, 4)

        tgt_ids  = torch.cat([v["labels"] for v in targets])          # (ΣT,)
        tgt_bbox = torch.cat([v["boxes"] for v in targets])           # (ΣT, 4)

        cost_class = -out_prob[:, tgt_ids].clamp(1e-8).log()          # (B·N, ΣT)
        cost_bbox  = torch.cdist(out_bbox, tgt_bbox, p=1)             # L1
        cost_giou  = -generalized_box_iou(
            box_convert(out_bbox, "cxcywh", "xyxy"),
            box_convert(tgt_bbox, "cxcywh", "xyxy")
        )

        C = (self.cost_class * cost_class +
             self.cost_bbox  * cost_bbox  +
             self.cost_giou  * cost_giou)                             # (B·N, ΣT)
        # float32 before .numpy(): NumPy has no bfloat16 dtype.
        C = C.view(bs, num_queries, total).float().cpu()

        indices: List[Tuple[torch.Tensor, torch.Tensor]] = []
        # split along the target axis so image b only sees its own GT columns
        for b, cost_b in enumerate(C.split(sizes, dim=-1)):
            if sizes[b] == 0:
                indices.append(_no_match())
                continue
            row_ind, col_ind = linear_sum_assignment(cost_b[b].numpy())
            indices.append((torch.as_tensor(row_ind, dtype=torch.int64, device=device),
                            torch.as_tensor(col_ind, dtype=torch.int64, device=device)))

        return indices


# ─────────────────────────────────────────────────────────────────────────────
# Utilities for QFL‑Softmax (quality weights for positives)
# ─────────────────────────────────────────────────────────────────────────────

def _compute_matched_ious(
    pred_boxes: torch.Tensor,  # (B,N,4)
    targets: List[Dict],       # list of dict with "boxes"
    indices: List[Tuple[torch.Tensor, torch.Tensor]],
) -> torch.Tensor:
    """
    Calculates IoU(pred_box, gt_box) for matched positives.
    Returns: (Σ_matches,) vector, [0,1].
    """
    if len(indices) == 0:
        return pred_boxes.new_zeros((0,))
    b_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
    s_idx = torch.cat([src for (src, _) in indices])
    if b_idx.numel() == 0:
        return pred_boxes.new_zeros((0,))

    src_boxes = pred_boxes[b_idx, s_idx]    # (Σ,4)
    tgt_boxes = torch.cat([t["boxes"][J] for t, (_, J) in zip(targets, indices)], 0)  # (Σ,4)

    ious = torch.diag(generalized_box_iou(
        box_convert(src_boxes, "cxcywh", "xyxy"),
        box_convert(tgt_boxes, "cxcywh", "xyxy")
    )).clamp(min=0., max=1.)
    return ious  # (Σ,)


# ─────────────────────────────────────────────────────────────────────────────
# Main SetCriterion (production‑ready)
# ─────────────────────────────────────────────────────────────────────────────

class SetCriterion(nn.Module):
    """
    Computes all required losses for HybridDCDATRT training (IoU-OFF optional).

    • Class loss: Softmax‑Focal + QFL‑Softmax (quality weight only with GT IoU).
    • Box: L1 + GIoU.
    • Dense aux: Sigmoid‑focal (FG channels).
    • DN (denoising): class + box (L1+GIoU).
    • Aux decoder outputs: labels/boxes, scaled by ``aux_weight``.
    • (Optional) Objectness: if outputs["pred_obj_logits"] exists, a BCE loss is added.
    • (Optional) Center alignment: if outputs["final_ref_pts"] exists.
    • NOTE: IoU regression **removed** (no IoU head in decoder).

    Supervised queries ("keep"): if outputs contain "query_selection_mask", a
    query is supervised when it is selected OR matched; otherwise every query is
    supervised.  The same rule is used by labels, cardinality, obj and center.

    DDP normalisation: count-based denominators (num_boxes and ``_ddp_denom``)
    are the all-reduced global count divided by the world size, i.e. the per-rank
    average.  DDP averages gradients across ranks, so this yields the gradient of
    the global mean (DETR convention).
    """

    def __init__(self,
                 num_classes: int,
                 matcher: 'HungarianMatcher',
                 weight_dict: Dict[str, float],
                 eos_coef: float = .1,
                 losses: Optional[List[str]] = None,
                 aux_weight: float = .4,
                 *,
                 # QFL‑Softmax
                 use_qfl: bool = True,
                 qfl_beta: float = 1.0,            # quality ^ beta
                 qfl_lambda: float = 0.5,
                 qfl_use_pred: bool = False):      # ← IoU‑OFF (opt)
        super().__init__()
        self.num_classes = num_classes
        self.matcher = matcher
        self.weight_dict = weight_dict
        self.aux_weight = aux_weight

        # Loss terms evaluated in forward(); see get_loss() for the name → method map.
        self.losses = losses or [
            "labels", "boxes", "cardinality",
            "dense_aux",
            "dn_label", "dn_box",
            "obj",
            "center",
        ]

        # Per-class weights; only the last (background / no-object) entry is read,
        # as the eos_coef sample weight in loss_labels.
        empty_w = torch.ones(self.num_classes + 1)
        empty_w[-1] = eos_coef
        self.register_buffer("empty_weight", empty_w)

        # QFL params
        self.use_qfl = use_qfl
        self.qfl_beta = float(qfl_beta)
        # IoU‑OFF: the following two are stored for config compatibility but unused;
        # qfl_use_pred is forced to False because the decoder has no IoU head.
        self.qfl_lambda = float(qfl_lambda)
        self.qfl_use_pred = False

    # ------------------------------------------------------------------ #
    #  helpers
    # ------------------------------------------------------------------ #
    def _get_src_permutation_idx(self, indices):
        """Flatten per-image matches into (batch_idx, query_idx) for advanced indexing."""
        batch_idx = torch.cat([torch.full_like(src, i)
                               for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

    def _build_pos_mask(self, outputs, indices):
        """(B,N) bool; matched (positive) queries True"""
        B, N = outputs["pred_logits"].shape[:2]
        pos = outputs["pred_logits"].new_zeros((B, N), dtype=torch.bool)
        b_idx, s_idx = self._get_src_permutation_idx(indices)
        if b_idx.numel():
            pos[b_idx, s_idx] = True
        return pos

    @staticmethod
    def _keep_mask(outputs, pos):
        """(B,N) bool of supervised queries: selection | pos, or all True without a selection mask."""
        if "query_selection_mask" in outputs:
            return outputs["query_selection_mask"].to(torch.bool) | pos
        return torch.ones_like(pos)

    @staticmethod
    def _ddp_denom(count, device) -> float:
        """
        Per-rank normaliser for a sample count: global count / world size, clamped to >= 1.

        Dividing each rank's local sum by (global count / world size) and letting
        DDP average gradients gives the gradient of the global mean; dividing by
        the raw global count would shrink gradients by 1/world_size.
        NOTE: under DDP this is a collective (all_reduce) — every rank must call it
        the same number of times, in the same order.
        """
        if isinstance(count, torch.Tensor):
            # copy=True: all_reduce is in-place and must not touch the caller's tensor
            c = count.to(device=device, dtype=torch.float, copy=True)
        else:
            c = torch.tensor(float(count), device=device, dtype=torch.float)
        world = 1
        if dist.is_available() and dist.is_initialized():
            dist.all_reduce(c, op=dist.ReduceOp.SUM)
            world = dist.get_world_size()
        return max(float(c.item()) / world, 1.0)

    # -------------------------------- labels (QFL‑Softmax) -------------
    def loss_labels(self, outputs, targets, indices, num_boxes, **_):
        """
        Softmax‑Focal + QFL (quality = GT IoU of matched positives).

        Matched queries get their GT label, all others the background id C;
        queries outside keep (sel | pos) are set to -1 and ignored.  The summed
        loss is divided by the DDP-averaged keep count ("mean over supervised
        queries"), so adding DN/extra queries does not change the loss scale.
        """
        src_logits = outputs["pred_logits"]            # (B,N,C+1)
        B, N, K = src_logits.shape
        C = K - 1
        device = src_logits.device
        idx = self._get_src_permutation_idx(indices)

        # all unmatched queries → background
        tgt_classes = torch.full((B, N), C, dtype=torch.int64, device=device)
        if idx[0].numel():
            tgt_classes[idx] = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)], 0)

        pos = self._build_pos_mask(outputs, indices)
        keep = self._keep_mask(outputs, pos)
        tgt_classes = tgt_classes.masked_fill(~keep, -1)   # ignore = -1

        # BG (no-object) down-weighting with eos_coef
        sample_weight = torch.ones((B, N), device=device, dtype=src_logits.dtype)
        sample_weight[tgt_classes == C] = self.empty_weight[-1].item()

        # QFL quality weight — only for positives, from GT IoU
        if self.use_qfl and idx[0].numel():
            with torch.no_grad():
                q = _compute_matched_ious(outputs["pred_boxes"], targets, indices)  # (Σ_pos,)
                q = q.clamp(0., 1.).pow(self.qfl_beta).clamp(min=0.05)
            # cast back: advanced-index assignment needs matching dtypes (e.g. fp16 logits)
            sample_weight[idx] = (sample_weight[idx] * q).to(sample_weight.dtype)

        denom = self._ddp_denom(keep.sum(), device=device)

        loss = softmax_focal_loss_weighted(
            src_logits.reshape(-1, K),
            tgt_classes.reshape(-1),
            alpha=.25, gamma=2., reduction="sum",
            sample_weight=sample_weight.reshape(-1),
            ignore_index=-1,   # explicit: the -1 marker above must be ignored, not read as the last class
        ) / denom

        return {"loss_labels": loss}

    # -------------------------------- cardinality ----------------------
    def loss_cardinality(self, outputs, targets, indices, num_boxes, **_):
        """
        L1 between the soft predicted object count and the GT count, averaged over images.

        Predicted count = Σ over supervised (keep) queries of P(not background).
        It is a count, not a fraction of kept queries, so it is comparable with
        the GT box count.
        """
        probs = outputs["pred_logits"].softmax(-1)   # (B,N,C+1)
        fg = 1. - probs[..., -1]
        pos = self._build_pos_mask(outputs, indices)
        keep = self._keep_mask(outputs, pos).float()
        card_pred = (fg * keep).sum(1)

        tgt_lens = torch.as_tensor([len(t["labels"]) for t in targets],
                                   dtype=torch.float, device=probs.device)
        loss = F.l1_loss(card_pred, tgt_lens)
        return {"loss_cardinality": loss}

    # -------------------------------- boxes ----------------------------
    def loss_boxes(self, outputs, targets, indices, num_boxes, **_):
        """
        L1 and (1 - GIoU) over matched pairs, each summed and divided by num_boxes.
        If there is no match (no GT box) ⇒ 0 loss is returned.
        """
        idx = self._get_src_permutation_idx(indices)
        if idx[0].numel() == 0:                       # zero‑match guard
            z = outputs["pred_boxes"].sum() * 0.0
            return {"loss_bbox": z, "loss_giou": z}

        src_boxes = outputs["pred_boxes"][idx]
        tgt_boxes = torch.cat([t["boxes"][i]
                               for t, (_, i) in zip(targets, indices)], 0)

        l1 = F.l1_loss(src_boxes, tgt_boxes, reduction="none").sum() / max(num_boxes, 1.0)

        # (1 - giou) per pair, then sum / num_boxes (not 1 - Σgiou / num_boxes)
        giou = (1.0 - torch.diag(generalized_box_iou(
            box_convert(src_boxes, "cxcywh", "xyxy"),
            box_convert(tgt_boxes, "cxcywh", "xyxy")
        ))).sum() / max(num_boxes, 1.0)

        return {"loss_bbox": l1, "loss_giou": giou}

    # -------------------------------- dense aux ------------------------
    def loss_dense_aux(self, outputs, targets, *_):
        """Sum of dense sigmoid-focal losses over every map in outputs["aux_dense"]."""
        if "aux_dense" not in outputs:
            return {}
        total = 0.0
        for pred in outputs["aux_dense"].values():           # s1, s2, s3
            logits = pred[:, :self.num_classes]              # (B,C,H,W) — FG channels
            total = total + focal_loss_dense_sigmoid(
                logits, targets, self.num_classes,
                alpha=.25, gamma=2., use_or_map=False
            )
        return {"loss_dense_aux": total}

    # ------------------------------- denoising cls ---------------------
    def loss_dn_label(self, outputs, *_):
        """
        Softmax focal loss on denoising-query logits.

        Reads dn_meta["dn_logits"] (B,Q,C+1), ["dn_labels"] (B,Q) and optional
        ["dn_mask"] (bool; (B,) or broadcastable to (B,Q)).  Masked-out and
        negative labels are ignored; the sum is divided by the DDP-averaged
        valid count.
        """
        if "dn_meta" not in outputs:
            return {}

        m = outputs["dn_meta"]
        logits = m["dn_logits"]        # (B,Q,C+1)
        labels = m["dn_labels"]        # (B,Q)
        mask   = m.get("dn_mask", None)

        if mask is not None:
            mask = mask.to(torch.bool)
            if mask.dim() == 1:
                mask = mask[:, None].expand_as(labels)
            elif mask.shape != labels.shape:
                mask = mask.expand_as(labels)
            labels = labels.where(mask, labels.new_full(labels.shape, -1))  # ignore: -1
        # Any negative label is invalid (matches the ge(0) count below); normalise to -1.
        labels = labels.masked_fill(labels.lt(0), -1)

        # Number of valid samples (global, per-rank average)
        denom = self._ddp_denom(labels.ge(0).sum(), device=logits.device)

        # 'sum' + global denom → stable scale
        loss = softmax_focal_loss_weighted(
            logits.reshape(-1, logits.size(-1)),
            labels.reshape(-1),
            alpha=.25, gamma=2., reduction="sum",
            ignore_index=-1,
        ) / denom

        return {"loss_dn_label": loss}

    # ------------------------------- denoising box ---------------------
    def loss_dn_box(self, outputs, *_):
        """
        L1 + (1 - GIoU) between dn_meta["dn_boxes"] and ["dn_gt_boxes"] (B,Q,4).

        Supervised entries come from "dn_pos_mask", falling back to "dn_mask",
        else all entries.  Sums are divided by the DDP-averaged supervised count.
        """
        if "dn_meta" not in outputs:
            return {}

        m = outputs["dn_meta"]
        pred = m["dn_boxes"]     # (B,Q,4)
        tgt  = m["dn_gt_boxes"]  # (B,Q,4)

        # First try pos_mask, otherwise dn_mask
        posm = m.get("dn_pos_mask", None)
        if posm is None:
            posm = m.get("dn_mask", None)

        if posm is not None:
            keep = posm.to(torch.bool)
            if keep.dim() == 1:
                keep = keep[:, None].expand(pred.shape[:2])
        else:
            keep = torch.ones(pred.shape[:2], dtype=torch.bool, device=pred.device)

        # Collective BEFORE any early return: if only some ranks reached the
        # all_reduce inside _ddp_denom, DDP would hang.
        denom = self._ddp_denom(keep.sum(), device=pred.device)

        if not keep.any():
            z = pred.sum() * 0.0
            return {"loss_dn_bbox": z, "loss_dn_giou": z}

        pred = pred[keep]
        tgt  = tgt[keep]

        l1 = F.l1_loss(pred, tgt, reduction="sum") / denom

        giou_mat = generalized_box_iou(
            box_convert(pred, 'cxcywh', 'xyxy'),
            box_convert(tgt,  'cxcywh', 'xyxy')
        )
        giou = (1.0 - torch.diag(giou_mat)).sum() / denom

        return {"loss_dn_bbox": l1, "loss_dn_giou": giou}

    # ------------------------------- Objectness (optional) ------------
    def loss_obj(self, outputs, targets, indices, num_boxes, **_):
        """
        BCE objectness on outputs["pred_obj_logits"] (B,N), if present; otherwise empty.

        Target is 1 for matched queries and 0 for the remaining supervised ones
        (keep = sel | pos; all queries when no selection mask is given, so
        negatives are always supervised).
        """
        if "pred_obj_logits" not in outputs:
            return {}
        obj_logits = outputs["pred_obj_logits"]  # (B,N)

        pos = self._build_pos_mask(outputs, indices)              # (B,N) True=matched
        keep = self._keep_mask(outputs, pos)

        if not keep.any():
            return {"loss_obj": obj_logits.sum() * 0.0}

        loss = F.binary_cross_entropy_with_logits(
            obj_logits[keep], pos[keep].float(), reduction="mean"
        )
        return {"loss_obj": loss}

    # ------------------------------- Center alignment ------------------
    def loss_center(self, outputs, targets, indices, num_boxes, **_):
        """
        L1 between predicted box centres and the final reference points of kept queries.

        outputs["final_ref_pts"] is (B,N,2) or (B,N,L,2); in the latter case the
        L points are averaged.  Divided by the DDP-averaged keep count.
        """
        if "final_ref_pts" not in outputs:
            return {}

        device = outputs["pred_boxes"].device

        pos = self._build_pos_mask(outputs, indices)              # (B,N)
        keep = self._keep_mask(outputs, pos)                      # (B,N)

        # Collective BEFORE any early return (see loss_dn_box).
        denom = self._ddp_denom(keep.sum(), device=device)

        if not keep.any():
            z = outputs["pred_boxes"].sum() * 0.0
            return {"loss_center": z}

        ref = outputs["final_ref_pts"][keep]                      # (K,L,2) or (K,2)
        ctr_ref = ref if ref.dim() == 2 else ref.mean(dim=1)      # (K,2)
        ctr_pred = outputs["pred_boxes"][..., :2][keep]           # (K,2)

        loss = F.l1_loss(ctr_pred, ctr_ref, reduction="sum") / denom
        return {"loss_center": loss}

    # ------------------------------------------------------------------ #
    def get_loss(self, name, outputs, targets, indices, num_boxes):
        """Dispatch a loss by its short name (see self.losses)."""
        return {
            "labels":      self.loss_labels,
            "cardinality": self.loss_cardinality,
            "boxes":       self.loss_boxes,
            "dense_aux":   self.loss_dense_aux,
            "dn_label":    self.loss_dn_label,
            "dn_box":      self.loss_dn_box,
            "obj":         self.loss_obj,
            "center":      self.loss_center,
        }[name](outputs, targets, indices, num_boxes)

    # ------------------------------------------------------------------ #
    def forward(self,
                outputs: Dict[str, torch.Tensor],
                targets: List[Dict]) -> Dict[str, torch.Tensor]:
        """
        Match, compute every enabled loss (plus aux decoder layers), apply
        weight_dict and return all weighted terms plus their total under "loss".
        """
        # Hungarian matching (main outputs)
        indices = self.matcher(outputs, targets)

        # num_boxes: global GT count / world size (per-rank average), >= 1
        device = outputs["pred_logits"].device
        num_boxes = self._ddp_denom(sum(len(t["labels"]) for t in targets), device)

        # ---- main losses ----
        loss_dict: Dict[str, torch.Tensor] = {}
        for name in self.losses:
            if name.startswith("dn_") and "dn_meta" not in outputs:
                continue
            loss_dict.update(self.get_loss(
                name, outputs, targets, indices, num_boxes))

        # ---- auxiliary decoder outputs ----
        sel_mask = outputs.get("query_selection_mask", None)
        if "aux_outputs" in outputs:
            for i, aux in enumerate(outputs["aux_outputs"]):
                # Pass the mask to aux (shape: (B,N))
                aux_in = aux if sel_mask is None else {**aux, "query_selection_mask": sel_mask}
                aux_idx = self.matcher(aux_in, targets)
                for l in ("labels", "boxes"):
                    l_dict = self.get_loss(l, aux_in, targets, aux_idx, num_boxes)
                    l_dict = {k + f"_{i}": v * self.aux_weight for k, v in l_dict.items()}
                    loss_dict.update(l_dict)

        # ---- apply weight ----
        # A key uses the first weight_dict entry equal to it or prefixing it as "<base>_…"
        # (e.g. "loss_bbox_0" → "loss_bbox"); unknown keys get weight 1.0.
        weighted: Dict[str, torch.Tensor] = {}
        for k, v in loss_dict.items():
            base = None
            for b in self.weight_dict:
                if k == b or k.startswith(b + "_"):
                    base = b
                    break
            weighted[k] = v * self.weight_dict.get(base, 1.0)

        weighted["loss"] = sum(weighted.values())
        return weighted


# ─────────────────────────────────────────────────────────────────────────────
# Convenience builder
# ─────────────────────────────────────────────────────────────────────────────

def build_criterion(
    num_classes: int = 20,
    weight_dict: Optional[Dict[str, float]] = None,
    matcher_costs: Optional[Dict[str, float]] = None,
    *,
    use_qfl: bool = True,
    qfl_beta: float = 1.0,
    qfl_lambda: float = 0.5,   # IoU‑OFF: not used
    qfl_use_pred: bool = False # IoU‑OFF: not used
) -> SetCriterion:
    """
    Convenience factory: a HungarianMatcher plus a SetCriterion with the default recipe.

    Args:
        num_classes: number of foreground classes.
        weight_dict: per-loss weights; ``None`` selects the defaults below.
            'loss_iou' is intentionally absent (the decoder IoU head is disabled
            for this ablation).
        matcher_costs: keyword arguments for HungarianMatcher; ``None`` selects
            cost_class=2, cost_bbox=5, cost_giou=2.
        use_qfl, qfl_beta: QFL-Softmax settings forwarded to SetCriterion.
        qfl_lambda: forwarded to SetCriterion, which stores but does not use it.
        qfl_use_pred: ignored; SetCriterion always receives False (IoU-OFF).

    Returns:
        A SetCriterion with eos_coef=0.1, aux_weight=0.4 and all eight loss terms enabled.
    """
    weights = weight_dict if weight_dict is not None else {
        "loss_labels":      2.0,
        "loss_bbox":        5.0,
        "loss_giou":        2.5,
        "loss_dense_aux":   0.5,
        "loss_dn_label":    1.2,
        "loss_dn_bbox":     0.75,
        "loss_dn_giou":     0.75,
        "loss_cardinality": 0.1,
        "loss_obj":         0.15,   # only contributes if the model emits pred_obj_logits
        "loss_center":      0.1,
    }
    costs = matcher_costs if matcher_costs is not None else {
        "cost_class": 2., "cost_bbox": 5., "cost_giou": 2.,
    }
    enabled_losses = [
        "labels", "boxes", "cardinality",
        "dense_aux", "dn_label", "dn_box",
        "obj", "center",
    ]

    return SetCriterion(
        num_classes=num_classes,
        matcher=HungarianMatcher(**costs),
        weight_dict=weights,
        eos_coef=.1,
        losses=enabled_losses,
        aux_weight=.4,
        use_qfl=use_qfl,
        qfl_beta=qfl_beta,
        qfl_lambda=qfl_lambda,
        qfl_use_pred=False,
    )