In [None]:
# Enable IPython's autoreload extension in mode 2 so imported modules are
# re-imported before each cell executes (handy while editing project code).
%load_ext autoreload
%autoreload 2

In [None]:
import os
import sys

# Put the parent directory on the import path; presumably this is what lets the
# later `models` / `utils` imports resolve when the notebook lives in a subfolder.
sys.path.append("..")

In [None]:
# Shell escape: runs the `module list` command (presumably the cluster's
# environment-modules tool) to record which modules are loaded.
!module list

In [None]:
import os
# Make CUDA kernel launches synchronous so device-side errors surface at the
# offending call. Debugging aid only: it slows execution down.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

In [None]:
from models import Res18FPNCEASC  # Adjust as needed
from utils.visdrone_dataloader import get_dataset
from utils.losses import Lnorm, Lamm, LDet, DetectionLoss  # Adjust as needed

In [None]:
import torch
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
import torch
import torchvision.transforms.functional as TF
import torchvision.ops as ops
import torch.nn as nn

import torch.nn.functional as F

In [None]:
def safe_shape(x):
    """Describe the shape of ``x`` for debug printouts.

    Tensors yield their ``.shape``; lists and tuples are walked recursively
    (the result is always a list); any other object yields its type.
    """
    if isinstance(x, (list, tuple)):
        return list(map(safe_shape, x))
    return x.shape if isinstance(x, torch.Tensor) else type(x)

In [None]:
# Experiment configuration for this debugging session.
# NOTE(review): "root_dir" is a hard-coded absolute, user-specific path; consider
# reading it from an environment variable so the notebook runs elsewhere.
config = {
        "root_dir": "/home/soroush1/scratch/eecs_project",
        "batch_size": 4,
        "num_workers": 4,
        "num_epochs": 1,
        "lr": 1e-3,
        "config_path": "../configs/resnet18_fpn_feature_extractor.py",
    }

# Unpack config
root_dir = config["root_dir"]
batch_size = config["batch_size"]
num_workers = config["num_workers"]
num_epochs = config["num_epochs"]
learning_rate = config["lr"]
# Prefer the GPU when one is visible, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Dataset and loader (train split, unshuffled so the inspected batch is stable)
dataloader = get_dataset(
    root_dir=root_dir,
    split="train",
    transform=None,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
)

# Number of object classes; also used below to clamp labels and size the losses.
num_classes = 10

# Model
# NOTE(review): the model is put in eval() mode here, yet the forward pass later
# in the notebook is called with stage="train" — confirm this is intended.
model = Res18FPNCEASC(config_path=config["config_path"], num_classes=num_classes)
model.to(device)
model.eval()

# Losses
l_norm = Lnorm()
l_amm = Lamm()
l_det = LDet(num_classes=num_classes, num_bins=16)

# Grab a single batch; every following cell works on this one batch.
batch = next(iter(dataloader))

In [None]:
# Move the image batch to the compute device; the targets stay where the loader put them.
images = batch["image"].to(device)
# Per-image ground truth, indexed by position in the batch.
# NOTE(review): clamp() silently remaps any label outside [0, num_classes - 1]
# onto the nearest valid id instead of dropping the box — verify that is wanted.
targets = {
    "boxes": batch["boxes"],
    "labels": [lbl.clamp(0, num_classes - 1) for lbl in batch["labels"]],
    "image_id": batch["image_id"],
    "orig_size": batch["orig_size"],
}

# Dump every sample's ground truth for manual inspection.
print("\n🔍 Inspecting `targets` structure:")
for i in range(len(targets["boxes"])):
    print(f"--- Sample {i} ---")
    print(f"Image ID:         {targets['image_id'][i]}")
    print(f"Original Size:    {targets['orig_size'][i]}")
    print(f"Boxes shape:      {targets['boxes'][i].shape}")  # [N_i, 4]
    print(f"Labels shape:     {targets['labels'][i].shape}")  # [N_i]
    print(f"Boxes:            {targets['boxes'][i]}")
    print(f"Labels:           {targets['labels'][i]}")

In [None]:
# Forward pass
outputs = model(images, stage="train")
(
    cls_outs,
    reg_outs,
    soft_mask_outs,
    sparse_cls_feats_outs,
    sparse_reg_feats_outs,
    dense_cls_feats_outs,
    dense_reg_feats_outs,
    feats,
    anchors,
) = outputs

In [None]:
# Print the shape of every model output, one FPN level at a time.
print("\n🔍 Output shapes from model:")
for i in range(len(cls_outs)):
    print(f"--- FPN Level {i} ---")
    print(f"cls_outs[{i}]:              {safe_shape(cls_outs[i])}")  # channels = anchors * num_classes
    print(f"reg_outs[{i}]:              {safe_shape(reg_outs[i])}")  # channels = anchors * 4 * num_bins
    print(
        f"soft_mask_outs[{i}]:    {safe_shape(soft_mask_outs[i])}"
    )
    print(
        f"sparse_cls_feats[{i}]:      {safe_shape(sparse_cls_feats_outs[i])}"
    )
    print(
        f"sparse_reg_feats[{i}]:      {safe_shape(sparse_reg_feats_outs[i])}"
    )
    print(
        f"dense_cls_feats[{i}]:       {safe_shape(dense_cls_feats_outs[i])}"
    )
    print(
        f"dense_reg_feats[{i}]:       {safe_shape(dense_reg_feats_outs[i])}"
    )
    print(f"feats[{i}]:                 {safe_shape(feats[i])}")

# Anchors are labelled P3, P4, ... (level index offset by 3).
for i, anchor in enumerate(anchors):
    print(f"P{i+3} Anchors shape: {anchor.shape}")

In [None]:
from mmdet.models.losses import QualityFocalLoss

In [None]:
# Interactive IPython introspection (shows the docstring); only meaningful in a
# live session — consider removing before sharing the notebook.
QualityFocalLoss?

In [None]:
# === Calculate detection loss ===
loss_fn = DetectionLoss(num_bins=16, num_classes=10, num_anchors=6)
losses = loss_fn(cls_outs, reg_outs, anchors, targets, device=device)
print(f"{losses = }")
print(losses['total_loss'])

In [None]:
# Mask loss from the soft masks and the GT boxes.
# NOTE(review): the image size is hard-coded as 1024 x 540 — confirm it matches
# the actual tensor size of `images`.
loss_amm = l_amm(
        soft_mask_outs, targets["boxes"], im_dimx=1024, im_dimy=540
    )  # used the soft masks in this version, might be incorrect

loss_amm

In [None]:
# Norm loss computed from the sparse classification features, the soft masks and
# the dense classification features (argument order as expected by Lnorm).
loss_norm = l_norm(
                sparse_cls_feats_outs, soft_mask_outs, dense_cls_feats_outs
            )

loss_norm

In [None]:
# Weights of the auxiliary terms in the combined objective:
# alpha scales the norm loss, beta scales the mask (AMM) loss.
alpha = 1.0
beta = 10.0

# final = detection total + alpha * norm + beta * amm
final_loss = (
    losses["total_loss"] + 
    alpha * loss_norm + 
    beta * loss_amm
)

final_loss

In [None]:
def plot_anchors_on_image(image_tensor, anchors, num_to_plot=100, title="Anchors", color="red"):
    """
    Plot anchor boxes on top of an image.

    Args:
        image_tensor (Tensor | array-like): image as a (3, H, W) tensor, or an
            already displayable (H, W, 3) array / PIL image.
        anchors (Tensor | array-like): shape (N, 4), format (x1, y1, x2, y2)
        num_to_plot (int): maximum number of anchor boxes to draw (the first
            ``num_to_plot`` rows of ``anchors`` are used)
        title (str): title of the plot
        color (str): edge color of the anchor rectangles
    """
    if isinstance(image_tensor, torch.Tensor):
        # CHW tensor -> HWC numpy array for matplotlib.
        img = image_tensor.permute(1, 2, 0).detach().cpu().numpy()
        # Undo normalization — assumes the ImageNet mean/std, the same constants
        # visualize_item uses; TODO confirm against the dataloader transform.
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        img = np.clip(img * std + mean, 0, 1)
    else:
        # Non-tensor input is shown as-is.
        img = np.asarray(image_tensor)

    if isinstance(anchors, torch.Tensor):
        # detach() so anchors that carry gradients can still be converted.
        anchors_np = anchors.detach().cpu().numpy()
    else:
        anchors_np = np.asarray(anchors)

    fig, ax = plt.subplots(1, figsize=(12, 8))
    # Show the un-normalized image (previously it was computed but never displayed).
    ax.imshow(img)
    ax.set_title(title)

    for x1, y1, x2, y2 in anchors_np[:max(num_to_plot, 0)]:
        ax.add_patch(
            patches.Rectangle(
                (x1, y1),
                x2 - x1,
                y2 - y1,
                linewidth=1,
                edgecolor=color,
                facecolor='none',
            )
        )

    ax.axis("off")
    plt.show()

# `ds` is not used again in the visible cells; kept for interactive poking.
ds = dataloader.dataset
# Draw up to 5000 anchors of the first FPN level on the 4th image of the batch.
plot_anchors_on_image(images[3], anchors[0], num_to_plot=5000, title="Anchors at FPN Level 0")

In [None]:
def visualize_item(image_tensor, boxes, figsize=(10, 10), save_path="test.png"):
    """
    Draw ground-truth boxes on an image and save the figure to disk.

    Args:
        image_tensor (Tensor | PIL.Image | array-like): a (3, H, W) tensor, or
            an 8-bit image convertible with ``np.asarray``.
        boxes (iterable): boxes in (x1, y1, x2, y2) format, e.g. a [N, 4] tensor
        figsize (tuple): Figure size
        save_path (str): where the rendered figure is written
    """
    import matplotlib.pyplot as plt
    import matplotlib.patches as patches

    if isinstance(image_tensor, torch.Tensor):
        # CHW tensor -> HWC numpy array for matplotlib.
        img = image_tensor.permute(1, 2, 0).detach().cpu().numpy()
        # Undo normalization — assumes ImageNet mean/std; TODO confirm against
        # the dataloader transform.
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        img = np.clip(img * std + mean, 0, 1)
    else:
        # PIL / uint8 array input: scale to [0, 1].
        # (The previous code read an undefined name `sample` here.)
        img = np.asarray(image_tensor) / 255.0

    if isinstance(boxes, torch.Tensor):
        boxes = boxes.detach().cpu().numpy()

    fig, ax = plt.subplots(1, figsize=figsize)
    ax.imshow(img)

    # Every box is drawn in red; per-class coloring would need the labels,
    # which this function does not receive.
    for x1, y1, x2, y2 in boxes:
        ax.add_patch(
            patches.Rectangle(
                (x1, y1), x2 - x1, y2 - y1, linewidth=2, edgecolor="red", facecolor="none"
            )
        )

    ax.axis("off")
    fig.tight_layout()

    # Save instead of show, then release the figure.
    fig.savefig(save_path)
    plt.close(fig)

# Render the 4th image of the batch with its ground-truth boxes (written to test.png).
visualize_item(images[3], targets["boxes"][3])

# Implementing the detection loss

In [None]:
class ATSSMatcher:
    """ATSS-style anchor-to-ground-truth assigner.

    For every GT box the ``top_k`` anchors with the closest centers (taken over
    all levels together) provide an adaptive IoU threshold ``mean + std``.
    Anchors whose IoU with the GT reaches that threshold and whose center lies
    inside the GT become positives.

    NOTE(review): the reference ATSS procedure selects the top-k candidates
    *per pyramid level* and only considers those candidates as positives; this
    implementation keeps the notebook's simpler global variant.
    """

    def __init__(self, top_k=9):
        self.top_k = top_k  # number of closest anchors used for the IoU statistics

    def __call__(self, anchors_per_level, gt_boxes, device=None):
        """
        anchors_per_level: List[Tensor[N_i, 4]] in (x1, y1, x2, y2) format
        gt_boxes: Tensor[M, 4]
        device: optional device both inputs are moved to before matching
        Returns:
            matched_idxs: Tensor[N_total] with GT index or -1
            max_ious: Tensor[N_total] IoU with the matched GT (0 for negatives)
        """
        all_anchors = torch.cat(anchors_per_level, dim=0)  # [N_total, 4]
        num_anchors = all_anchors.size(0)
        num_gt = gt_boxes.size(0)

        if device is not None:
            all_anchors = all_anchors.to(device)
            gt_boxes = gt_boxes.to(device)

        matched_idxs = torch.full((num_anchors,), -1, dtype=torch.long, device=gt_boxes.device)
        max_ious = torch.zeros(num_anchors, dtype=torch.float, device=gt_boxes.device)

        # Nothing to match: every anchor stays negative.
        if num_gt == 0 or num_anchors == 0:
            return matched_idxs, max_ious

        # 1. IoU between all anchors and GTs
        ious = ops.box_iou(all_anchors, gt_boxes)  # [N_total, M]

        # 2. Centers used for the distance ranking
        anchor_centers = (all_anchors[:, :2] + all_anchors[:, 2:]) / 2  # [N, 2]
        gt_centers = (gt_boxes[:, :2] + gt_boxes[:, 2:]) / 2  # [M, 2]

        # topk() raises if k exceeds the number of anchors.
        k = min(self.top_k, num_anchors)

        for gt_idx in range(num_gt):
            gt_ious = ious[:, gt_idx]  # [N]

            # Distance from this GT center to every anchor center
            distances = torch.norm(anchor_centers - gt_centers[gt_idx][None, :], dim=1)
            topk_idxs = torch.topk(distances, k, largest=False).indices  # [k]

            topk_ious = gt_ious[topk_idxs]
            # std() of a single sample is NaN, which would reject every anchor.
            iou_std = topk_ious.std() if k > 1 else topk_ious.new_zeros(())
            dynamic_thresh = topk_ious.mean() + iou_std

            # Positive = IoU >= adaptive threshold and anchor center inside the GT
            pos_mask = (gt_ious >= dynamic_thresh) & self.anchor_inside_box(
                all_anchors, gt_boxes[gt_idx]
            )

            # An anchor positive for several GTs keeps the GT with the highest
            # IoU instead of being overwritten by whichever GT comes last.
            take = pos_mask & ((matched_idxs < 0) | (gt_ious > max_ious))
            matched_idxs[take] = gt_idx
            max_ious[take] = gt_ious[take]

        return matched_idxs, max_ious

    def anchor_inside_box(self, anchors, gt_box):
        """
        Return a mask of anchors whose center is inside the GT box (borders included).
        """
        cx = (anchors[:, 0] + anchors[:, 2]) / 2
        cy = (anchors[:, 1] + anchors[:, 3]) / 2

        return (
            (cx >= gt_box[0]) & (cx <= gt_box[2]) &
            (cy >= gt_box[1]) & (cy <= gt_box[3])
        )

In [None]:
# Run the matcher independently on every image of the batch. The anchors are
# shared across images, so each call returns vectors of the same length and the
# per-image results can be stacked into [B, N_total] tensors.
matcher = ATSSMatcher(top_k=9)
matched_idxs = []
max_ious = []

for batch_id in range(len(targets["boxes"])):
    
    matched_idx, iou = matcher(anchors, targets["boxes"][batch_id], device=device)  # for image batch_id
    matched_idxs.append(matched_idx)
    max_ious.append(iou)


matched_idxs = torch.stack(matched_idxs, dim=0)
max_ious = torch.stack(max_ious, dim=0)

matched_idxs.size(), max_ious.size()

In [None]:
# Per-anchor classification targets, [B, N_total].
# NOTE(review): unmatched anchors keep the initial value 0, which is also a valid
# class id after the clamp(0, num_classes - 1) above — background and class 0
# are therefore indistinguishable in `cls_targets` alone.
all_anchors = torch.cat(anchors, dim=0)  # [N_total, 4]
num_total_anchors = all_anchors.size(0)
batch_size = len(targets["boxes"])
cls_targets = torch.zeros((batch_size, num_total_anchors), dtype=torch.long, device=device)

for batch_idx in range(batch_size):
    # For positive anchors, assign the GT class label
    pos_mask = matched_idxs[batch_idx] >= 0
    if pos_mask.any():
        gt_indices = matched_idxs[batch_idx][pos_mask]
        target = targets["labels"][batch_idx].to(device)
        cls_targets[batch_idx, pos_mask] = target[gt_indices]

cls_targets.size()

In [None]:
# Sanity check: which class ids actually occur among the assigned targets.
cls_targets.unique()

In [None]:
# Per-anchor quality targets, [B, N_total]: IoU with the matched GT for positive
# anchors, 0 for everything else.
iou_targets = torch.zeros((batch_size, num_total_anchors), dtype=torch.float, device=device)
for batch_idx in range(batch_size):
    pos_mask = matched_idxs[batch_idx] >= 0
    if pos_mask.any():
        iou_targets[batch_idx, pos_mask] = max_ious[batch_idx][pos_mask]

iou_targets.size()

In [None]:
# Per-anchor regression targets, [B, N_total, 4]: row n holds the GT box matched
# to anchor n (copied verbatim from targets["boxes"]), zeros for negatives.
# The tensor is indexed by ANCHOR, not by GT index.
reg_targets = torch.zeros((batch_size, num_total_anchors, 4), dtype=torch.float, device=device)
for batch_idx in range(batch_size):
    pos_mask = matched_idxs[batch_idx] >= 0
    if pos_mask.any():
        gt_indices = matched_idxs[batch_idx][pos_mask]
        bbox = targets["boxes"][batch_idx].to(device)
        
        gt_boxes_matched = bbox[gt_indices]
        reg_targets[batch_idx, pos_mask] = gt_boxes_matched

reg_targets.size()

In [None]:
def prepare_predictions(cls_outs, reg_outs, num_classes=10, num_bins=16, num_anchors=6):
    """
    Flatten per-level head outputs into per-anchor prediction tensors.

    Args:
        cls_outs: list of [B, A*num_classes, H, W] tensors, one per FPN level
        reg_outs: list of [B, A*4*num_bins, H, W] tensors, one per FPN level
        num_classes: number of object classes
        num_bins: number of DFL bins per box side
        num_anchors: anchors per spatial location (A); defaults to 6

    Returns:
        cls_preds: [B, sum(H*W*A), num_classes]
        reg_preds: [B, sum(H*W*A), 4*num_bins]
        Within a level, anchors are ordered (h, w, a).
    """
    all_cls_preds = []
    all_reg_preds = []

    for cls_out, reg_out in zip(cls_outs, reg_outs):
        B, _, H, W = cls_out.shape

        # [B, A*C, H, W] -> [B, A, C, H, W] -> [B, H, W, A, C] -> [B, H*W*A, C]
        # reshape (not view) so non-contiguous head outputs are accepted too.
        flat_cls = (
            cls_out.reshape(B, num_anchors, num_classes, H, W)
            .permute(0, 3, 4, 1, 2)
            .reshape(B, H * W * num_anchors, num_classes)
        )
        all_cls_preds.append(flat_cls)

        # Same layout change for the regression logits (4*num_bins per anchor).
        flat_reg = (
            reg_out.reshape(B, num_anchors, 4 * num_bins, H, W)
            .permute(0, 3, 4, 1, 2)
            .reshape(B, H * W * num_anchors, 4 * num_bins)
        )
        all_reg_preds.append(flat_reg)

    # Concatenate across FPN levels along the anchor dimension
    cls_preds = torch.cat(all_cls_preds, dim=1)
    reg_preds = torch.cat(all_reg_preds, dim=1)

    return cls_preds, reg_preds

# Flatten the head outputs so they line up with the per-anchor targets built above.
cls_preds, reg_preds = prepare_predictions(cls_outs, reg_outs, num_classes=10, num_bins=16)
cls_preds.size(), reg_preds.size()

In [None]:
def quality_focal_loss(pred, target, iou_targets, beta=2.0):
    """
    Quality Focal Loss (Generalized Focal Loss) for dense object detection.

    The training target is a *soft* one-hot vector: the IoU quality at the
    assigned class and 0 everywhere else. Each element contributes
    ``BCE(pred, soft_target) * |soft_target - sigmoid(pred)| ** beta``.

    Anchors with ``iou_targets == 0`` are treated as background (all-zero
    target), so class id 0 is never confused with background and positives of
    class 0 are counted in the normalizer.

    Args:
        pred: [B, N, C] logits tensor
        target: [B, N] class indices tensor (ignored where iou_targets == 0)
        iou_targets: [B, N] IoU between each anchor and its matched GT box,
            0 for unmatched anchors
        beta: exponent of the modulating factor

    Returns:
        loss: scalar tensor, summed over all elements and divided by the number
        of positive anchors (at least 1)
    """
    num_classes = pred.shape[-1]
    quality = iou_targets.to(pred.dtype)

    # Soft target: IoU at the assigned class, zeros for background anchors.
    soft_target = F.one_hot(target, num_classes).to(pred.dtype) * quality.unsqueeze(-1)

    # Down-weight elements whose predicted score is already close to the target.
    modulating = (soft_target - pred.sigmoid()).abs().pow(beta)

    bce_loss = F.binary_cross_entropy_with_logits(pred, soft_target, reduction='none')

    # Normalize by the number of positives (anchors with a non-zero quality).
    num_positive = (quality > 0).sum().item()

    return (bce_loss * modulating).sum() / max(1, num_positive)

# Hand-written QFL over all anchors of the batch (positives and background).
quality_focal_loss(cls_preds, cls_targets, iou_targets)

In [None]:
def distribution_focal_loss(pred, target, pos_mask=None):
    """
    Distribution Focal Loss for bounding box regression.

    Each coordinate is predicted as a discrete distribution (softmax over
    ``bins`` logits). The continuous target ``y`` falls between two adjacent
    bins ``y_l <= y <= y_r``; the loss is the cross-entropy towards those two
    bins weighted by proximity:
    ``-((y_r - y) * log p(y_l) + (y - y_l) * log p(y_r))``.

    Args:
        pred: [B, N, 4*bins] regression logits (bins >= 2)
        target: [B, N, 4] regression targets in [0, 1]; values outside the
            range are clamped
        pos_mask: [B, N] boolean mask for positive samples, or None to use all samples

    Returns:
        loss: scalar tensor, averaged over the selected coordinates
    """
    B, N, _ = pred.shape
    num_bins = pred.shape[-1] // 4

    # Log-probabilities of the per-coordinate bin distribution: [B, N, 4, bins]
    log_prob = F.log_softmax(pred.reshape(B, N, 4, num_bins), dim=-1)

    # Position of the target on the bin axis, in [0, bins - 1].
    scaled = target.to(pred.dtype).clamp(0, 1) * (num_bins - 1)
    # Left neighbour bin; capped at bins - 2 so the right neighbour always exists.
    left_bin = scaled.long().clamp(0, num_bins - 2)
    offset = scaled - left_bin.to(scaled.dtype)  # in [0, 1]

    logp_left = torch.gather(log_prob, dim=3, index=left_bin.unsqueeze(-1)).squeeze(-1)
    logp_right = torch.gather(log_prob, dim=3, index=(left_bin + 1).unsqueeze(-1)).squeeze(-1)

    # The closer the target is to a bin, the larger that bin's weight.
    loss = -((1 - offset) * logp_left + offset * logp_right)  # [B, N, 4]

    if pos_mask is not None:
        loss = loss * pos_mask.unsqueeze(-1).to(loss.dtype)
        num_positive = pos_mask.sum().item() * 4  # 4 coordinates per positive anchor
    else:
        num_positive = B * N * 4  # all samples and all coordinates

    return loss.sum() / max(1, num_positive)

# NOTE(review): distribution_focal_loss documents targets in [0, 1], but
# `reg_targets` was filled with GT boxes copied straight from targets["boxes"]
# (presumably absolute coordinates) — encode/normalize them before trusting this value.
pos_mask = matched_idxs >= 0  # [B, N]
distribution_focal_loss(reg_preds, reg_targets, pos_mask)

In [None]:
def decode_dfl_bins(reg_preds, num_bins=16):
    """
    Turn DFL bin logits into one continuous value per box coordinate.

    The logits of each coordinate are soft-maxed over the bins and the
    expectation of the bin index is taken, then rescaled to [0, 1].

    Args:
        reg_preds: [B, N, 4 * num_bins] raw logits
        num_bins: int, number of bins per coordinate

    Returns:
        Tensor [B, N, 4] with values in [0, 1]
    """
    batch, count, _ = reg_preds.shape
    logits = reg_preds.view(batch, count, 4, num_bins)  # [B, N, 4, bins]

    # Bin indices 0 .. num_bins-1 act as the support of the distribution.
    bin_index = torch.arange(num_bins, dtype=torch.float32, device=logits.device)
    weights = F.softmax(logits, dim=-1)
    expectation = torch.sum(weights * bin_index, dim=-1)  # [B, N, 4]

    # Dividing by the largest bin index maps the expectation onto [0, 1].
    return expectation / (num_bins - 1)

def delta2bbox(anchors, deltas):
    """
    Apply center/size regression deltas to anchors.

    Centers are shifted by ``delta * anchor_size``; widths and heights are
    scaled by ``exp(delta)``.

    Args:
        anchors: [B, N, 4] in (x1, y1, x2, y2)
        deltas:  [B, N, 4] in (tx, ty, tw, th)

    Returns:
        boxes: [B, N, 4] in (x1, y1, x2, y2)
    """
    ax1, ay1, ax2, ay2 = (anchors[:, :, k] for k in range(4))
    tx, ty, tw, th = (deltas[:, :, k] for k in range(4))

    # Anchor geometry in center/size form.
    anchor_w = ax2 - ax1
    anchor_h = ay2 - ay1
    anchor_cx = ax1 + 0.5 * anchor_w
    anchor_cy = ay1 + 0.5 * anchor_h

    # Decoded box in center/size form.
    box_cx = tx * anchor_w + anchor_cx
    box_cy = ty * anchor_h + anchor_cy
    box_w = torch.exp(tw) * anchor_w
    box_h = torch.exp(th) * anchor_h

    # Back to corner form, stacked on the last axis -> [B, N, 4].
    corners = [
        box_cx - 0.5 * box_w,
        box_cy - 0.5 * box_h,
        box_cx + 0.5 * box_w,
        box_cy + 0.5 * box_h,
    ]
    return torch.stack(corners, dim=2)


def giou_loss(pred_deltas, target_boxes, anchors, pos_mask):
    """
    Mean (1 - GIoU) between decoded predictions and targets over positive anchors.

    Args:
        pred_deltas: [B, N, 4] deltas, decoded against ``anchors`` with delta2bbox
        target_boxes: [B, N, 4] target box per anchor in (x1, y1, x2, y2)
        anchors: [N, 4] anchors shared by every image of the batch
        pos_mask: [B, N] boolean mask selecting the positive anchors

    Returns:
        Sum of (1 - GIoU) over all positives divided by their count
        (0.0 when the batch has no positives).
    """
    batch, count, _ = pred_deltas.shape
    # Broadcast the shared anchors over the batch before decoding.
    decoded = delta2bbox(anchors.unsqueeze(0).expand(batch, count, 4), pred_deltas)

    loss_sum = 0.0
    n_pos = 0

    for b in range(batch):
        keep = pos_mask[b]
        if keep.sum() == 0:
            continue  # image without positives contributes nothing

        # generalized_box_iou is pairwise; the diagonal holds the matched pairs.
        pairwise = ops.generalized_box_iou(decoded[b][keep], target_boxes[b][keep])
        per_anchor = 1.0 - torch.diagonal(pairwise)

        loss_sum += per_anchor.sum()
        n_pos += per_anchor.numel()

    return loss_sum / max(n_pos, 1)


# 1. Positive anchors are those the matcher assigned to some GT
pos_mask = matched_idxs >= 0  # [B, N]

# 2. Decode DFL bins to one value per coordinate
# NOTE(review): decode_dfl_bins returns values in [0, 1], which giou_loss then
# interprets as (tx, ty, tw, th) deltas — confirm this parameterization is intended.
reg_deltas = decode_dfl_bins(reg_preds, num_bins=16)  # [B, N, 4]

# 3. Prepare anchor tensor (all levels concatenated, on the compute device)
all_anchors = torch.cat(anchors, dim=0).to(device)  # [N, 4]

# 4. Compute GIoU Loss
giou = giou_loss(reg_deltas, reg_targets, all_anchors, pos_mask)
giou

In [None]:
from mmdet.models.losses import QualityFocalLoss   # or use your custom one

# Reference value from mmdet's QualityFocalLoss for comparison with the
# hand-written version above.
qfl = QualityFocalLoss(use_sigmoid=True, beta=2.0, reduction='mean', loss_weight=1.0)

# Example for a single image (i-th in batch)
i = 0
# Filter positive anchors
# NOTE(review): only positive anchors are passed to the loss here, so background
# anchors do not contribute — this differs from the all-anchor call earlier.
pos_mask = matched_idxs[i] >= 0
pos_inds = pos_mask.nonzero(as_tuple=True)[0]

# Prepare predictions and targets for QFL
qfl_pred = cls_preds[i][pos_inds]             # [num_pos, num_classes]
qfl_labels = cls_targets[i][pos_inds]         # [num_pos] class labels (long)
qfl_scores = iou_targets[i][pos_inds]         # [num_pos] IoU scores (float)

# Format target as a tuple for QFL
qfl_target = (qfl_labels, qfl_scores)

loss = qfl(qfl_pred, qfl_target)

loss

In [None]:
from mmdet.models.losses import DistributionFocalLoss
class GFLBBoxCoder:
    """Encode GT boxes as per-side distances from an anchor, in DFL bin units.

    The encoding is the inverse of ``x1 = a_x1 - l``, ``y1 = a_y1 - t``,
    ``x2 = a_x2 + r``, ``y2 = a_y2 + b`` where each distance is a fraction of
    the anchor width/height multiplied by ``num_bins - 1``.
    """

    def __init__(self, num_bins=16, eps=0.1):
        self.num_bins = num_bins
        # Margin kept below the last bin index: a DFL target is split between
        # floor(t) and floor(t) + 1, so t must stay strictly below num_bins - 1.
        self.eps = eps

    def encode(self, anchors, gt_boxes):
        """
        Args:
            anchors: [N, 4] in (x1, y1, x2, y2)
            gt_boxes: [N, 4] in (x1, y1, x2, y2), row-aligned with ``anchors``
        Returns:
            reg_targets: [N, 4] (l, t, r, b) float values in
            [0, num_bins - 1 - eps]
        """
        scale = self.num_bins - 1
        anchor_w = anchors[:, 2] - anchors[:, 0]
        anchor_h = anchors[:, 3] - anchors[:, 1]

        # How far the GT extends beyond each anchor side, relative to the
        # anchor size. Left/top grow towards smaller coordinates, hence
        # anchor - gt; right/bottom grow towards larger ones, hence gt - anchor.
        l = (anchors[:, 0] - gt_boxes[:, 0]) / anchor_w * scale
        t = (anchors[:, 1] - gt_boxes[:, 1]) / anchor_h * scale
        r = (gt_boxes[:, 2] - anchors[:, 2]) / anchor_w * scale
        b = (gt_boxes[:, 3] - anchors[:, 3]) / anchor_h * scale

        # GT sides lying inside the anchor (negative distance) clamp to 0.
        reg_targets = torch.stack([l, t, r, b], dim=-1).clamp(0, scale - self.eps)
        return reg_targets

# Reference value from mmdet's DistributionFocalLoss for the first image.
i = 0
pos_mask = matched_idxs[i] >= 0
pos_inds = pos_mask.nonzero(as_tuple=True)[0]

# Ensure all_anchors is a tensor
if not isinstance(all_anchors, torch.Tensor):
    all_anchors = torch.tensor(all_anchors).to(reg_preds.device).float()

reg_pred = reg_preds[i][pos_inds]           # [num_pos, 64]
pos_anchors = all_anchors[pos_inds]         # [num_pos, 4]
# reg_targets is indexed per ANCHOR (row n = GT box matched to anchor n), so it
# must be sliced with the anchor indices, not with the GT indices in matched_idxs.
pos_gt_boxes = reg_targets[i][pos_inds]     # [num_pos, 4]

bbox_coder = GFLBBoxCoder(num_bins=16)
encoded_reg_target = bbox_coder.encode(pos_anchors, pos_gt_boxes)  # [num_pos, 4]

# Flatten for DFL: one row of bin logits and one scalar target per box side
reg_pred = reg_pred.view(-1, bbox_coder.num_bins)   # [num_pos * 4, 16]
reg_target = encoded_reg_target.view(-1)            # [num_pos * 4]

dfl = DistributionFocalLoss(reduction='mean', loss_weight=1.0)
loss = dfl(reg_pred, reg_target)

loss

In [None]:
from mmdet.models.losses import GIoULoss

def decode_dfl_bbox(reg_pred_logits, anchors, num_bins=16):
    """
    Decode DFL logits into boxes by growing each anchor side outwards.

    Args:
        reg_pred_logits: [N, 4 * num_bins] logits ordered (l, t, r, b)
        anchors: [N, 4] in (x1, y1, x2, y2)
        num_bins: number of bins per side

    Returns:
        [N, 4] boxes in (x1, y1, x2, y2); every side is moved away from the
        anchor by ``fraction * anchor_size`` with the fraction in [0, 1].
    """
    count, _ = reg_pred_logits.shape
    side_probs = torch.softmax(reg_pred_logits.view(count, 4, num_bins), dim=2)
    bin_ids = torch.arange(num_bins, dtype=side_probs.dtype, device=side_probs.device)
    # Expected bin index per side, rescaled to [0, 1].
    frac = torch.sum(side_probs * bin_ids, dim=2) / (num_bins - 1)

    ax1, ay1, ax2, ay2 = anchors[:, 0], anchors[:, 1], anchors[:, 2], anchors[:, 3]
    width = ax2 - ax1
    height = ay2 - ay1

    return torch.stack(
        [
            ax1 - frac[:, 0] * width,   # left
            ay1 - frac[:, 1] * height,  # top
            ax2 + frac[:, 2] * width,   # right
            ay2 + frac[:, 3] * height,  # bottom
        ],
        dim=1,
    )

# === Setup ===
i = 0
pos_mask = matched_idxs[i] >= 0
pos_inds = pos_mask.nonzero(as_tuple=True)[0]

if not isinstance(all_anchors, torch.Tensor):
    all_anchors = torch.tensor(all_anchors).to(reg_preds.device).float()

reg_pred_logits = reg_preds[i][pos_inds]
anchors_pos = all_anchors[pos_inds]

pred_boxes = decode_dfl_bbox(reg_pred_logits, anchors_pos, num_bins=16)
target_boxes = reg_targets[i][matched_idxs[i][pos_inds]]

giou_loss_fn = GIoULoss(reduction='mean', loss_weight=1.0)
loss = giou_loss_fn(pred_boxes, target_boxes)

loss
