In [None]:
#| default_exp box_regression
#| export
from typing import Tuple, List, Union
import torch
from qct_3d_nod_detect.structures import Boxes3D
from fvcore.nn import smooth_l1_loss

class Box3DTransform:
    """
    Box regression transform for axis-aligned 3D boxes stored as
    (x0, y0, z0, x1, y1, z1).

    Boxes are encoded relative to a source box as six deltas
    (dx, dy, dz, dw, dh, dd): centre offsets normalised by the source box
    size, and log-ratios of the box sizes, each multiplied by a weight.
    """

    def __init__(
        self, weights: Tuple[float, float, float, float, float, float], scale_clamp: float
    ):
        """
        Args:
            weights (6-element tuple): Scaling factors applied to the
                (dx, dy, dz, dw, dh, dd) deltas. In Fast R-CNN these were
                originally chosen so that the deltas have unit variance; they
                are now treated as hyperparameters of the system.
            scale_clamp (float): Upper bound applied to the predicted
                log-scale deltas (dw, dh and dd) when decoding, so that
                ``exp(delta)`` cannot explode.
        """
        self.weights = weights
        self.scale_clamp = scale_clamp

    @staticmethod
    def _extents_and_centers(boxes):
        """
        Split ``boxes`` (indexed as ``boxes[:, k]`` for k in 0..5) into
        per-axis sizes and per-axis centres.

        Returns:
            ((size_x, size_y, size_z), (ctr_x, ctr_y, ctr_z))
        """
        size_x = boxes[:, 3] - boxes[:, 0]
        size_y = boxes[:, 4] - boxes[:, 1]
        size_z = boxes[:, 5] - boxes[:, 2]

        ctr_x = boxes[:, 0] + 0.5 * size_x
        ctr_y = boxes[:, 1] + 0.5 * size_y
        ctr_z = boxes[:, 2] + 0.5 * size_z

        return (size_x, size_y, size_z), (ctr_x, ctr_y, ctr_z)

    def get_deltas(self, src_boxes, target_boxes):
        """
        Compute the regression deltas (dx, dy, dz, dw, dh, dd) that map
        `src_boxes` onto `target_boxes`, i.e. the inverse of
        :meth:`apply_deltas` (up to the clamping done when decoding).

        Args:
            src_boxes (Tensor): source boxes, e.g., object proposals
            target_boxes (Tensor): target of the transformation, e.g.,
                ground-truth boxes.

        Returns:
            Tensor: one row of six weighted deltas per source box.

        Raises:
            AssertionError: if an input is not a Tensor, or if any source box
                has a non-positive size along some axis.
        """
        assert isinstance(src_boxes, torch.Tensor), type(src_boxes)
        assert isinstance(target_boxes, torch.Tensor), type(target_boxes)

        (src_w, src_h, src_d), (src_cx, src_cy, src_cz) = self._extents_and_centers(src_boxes)
        (tgt_w, tgt_h, tgt_d), (tgt_cx, tgt_cy, tgt_cz) = self._extents_and_centers(target_boxes)

        wx, wy, wz, ww, wh, wd = self.weights

        # Centre offsets, normalised by the source box size along each axis.
        center_deltas = (
            wx * (tgt_cx - src_cx) / src_w,
            wy * (tgt_cy - src_cy) / src_h,
            wz * (tgt_cz - src_cz) / src_d,
        )
        # Size changes, expressed as log-ratios.
        size_deltas = (
            ww * torch.log(tgt_w / src_w),
            wh * torch.log(tgt_h / src_h),
            wd * torch.log(tgt_d / src_d),
        )

        deltas = torch.stack(center_deltas + size_deltas, dim=1)

        # Degenerate source boxes produce inf/nan deltas above; reject them.
        assert (src_w > 0).all().item(),  "Invalid source box widths"
        assert (src_h > 0).all().item(), "Invalid source box heights"
        assert (src_d > 0).all().item(),  "Invalid source box depths"

        return deltas

    def apply_deltas(self, deltas: torch.Tensor, boxes: torch.Tensor) -> torch.Tensor:
        """
        Decode predicted deltas against source boxes to obtain refined boxes.

        Args:
            deltas (Tensor): predicted deltas, shape (N, K*6) where K is the
                number of predictions per box
            boxes (Tensor): source boxes, shape (N, 6)

        Returns:
            pred_boxes (Tensor): decoded boxes with the same shape as
                ``deltas``; each consecutive group of six values is one
                (x0, y0, z0, x1, y1, z1) box.
        """
        deltas = deltas.float()  # decode in fp32 for numerical stability
        boxes = boxes.to(deltas.dtype)

        sizes, centers = self._extents_and_centers(boxes)
        wx, wy, wz, ww, wh, wd = self.weights

        min_corners = []
        max_corners = []
        per_axis = zip(sizes, centers, (wx, wy, wz), (ww, wh, wd))
        for axis, (size, center, shift_weight, scale_weight) in enumerate(per_axis):
            # Every 6th column starting at `axis` holds this axis' centre
            # delta; the matching log-scale delta sits three columns later.
            shift = deltas[:, axis::6] / shift_weight
            log_scale = deltas[:, axis + 3::6] / scale_weight

            # Cap the log-scale so exp() cannot blow up.
            log_scale = torch.clamp(log_scale, max=self.scale_clamp)

            new_center = shift * size[:, None] + center[:, None]
            new_size = torch.exp(log_scale) * size[:, None]

            min_corners.append(new_center - 0.5 * new_size)
            max_corners.append(new_center + 0.5 * new_size)

        # (N, K, 6) in (x0, y0, z0, x1, y1, z1) order.
        pred_boxes = torch.stack(min_corners + max_corners, dim=-1)

        # Flatten back to the layout of the input deltas.
        return pred_boxes.reshape(deltas.shape)

def pairwise_intersection(boxes1: Boxes3D, boxes2: Boxes3D) -> torch.Tensor:
    """
    Compute the intersection volume between every box in `boxes1` and every
    box in `boxes2`.

    The underlying ``.tensor`` of each argument is read with the first three
    columns as the minimum corner and the remaining columns as the maximum
    corner.

    Returns:
        Tensor: pairwise intersection volumes, one row per box in `boxes1`
            and one column per box in `boxes2`.
    """
    first, second = boxes1.tensor, boxes2.tensor

    # Broadcast (N, 1, 3) against (M, 3) to get all N x M pairs.
    upper = torch.min(first[:, None, 3:], second[:, 3:])
    lower = torch.max(first[:, None, :3], second[:, :3])

    # Per-axis overlap; negative values mean the boxes are disjoint on that axis.
    overlap = upper - lower  # (N, M, 3)
    overlap.clamp_(min=0)

    return overlap.prod(dim=2)

def pairwise_iou_3d(boxes1: Boxes3D, boxes2: Boxes3D) -> torch.Tensor:
    """
    Given two sets of boxes of sizes N and M, compute the IoU
    (Intersection over Union) between **all** N x M pairs of boxes.

    Box coordinates are consumed by `pairwise_intersection`, which reads the
    first three columns as the minimum corner and the rest as the maximum
    corner.

    Returns:
        Tensor: IoU values, shape (N, M). Pairs with no overlap get exactly 0.
    """
    volumes1 = boxes1.volume()
    volumes2 = boxes2.volume()
    overlap = pairwise_intersection(boxes1, boxes2)

    union = volumes1[:, None] + volumes2 - overlap

    # Only keep the ratio where boxes actually overlap; everywhere else the
    # IoU is defined as zero (this also masks any 0/0 from degenerate boxes).
    no_overlap = torch.zeros(1, dtype=overlap.dtype, device=overlap.device)
    return torch.where(overlap > 0, overlap / union, no_overlap)

def _volume(boxes: torch.Tensor) -> torch.Tensor:
    """
    boxes: (..., 6)
    """
    return (
        (boxes[..., 3] - boxes[..., 0]).clamp(min=0) *
        (boxes[..., 4] - boxes[..., 1]).clamp(min=0) *
        (boxes[..., 5] - boxes[..., 2]).clamp(min=0)
    )

def _intersection_3d(boxes1: torch.Tensor, boxes2: torch.Tensor) -> torch.Tensor:
    """
    boxes1, boxes2: (..., 6)
    """
    max_xyz = torch.min(boxes1[..., 3:], boxes2[..., 3:])
    min_xyz = torch.max(boxes1[..., :3], boxes2[..., :3])
    inter = (max_xyz - min_xyz).clamp(min=0)
    return inter[..., 0] * inter[..., 1] * inter[..., 2]


def _enclosing_box_3d(boxes1: torch.Tensor, boxes2: torch.Tensor) -> torch.Tensor:
    """
    Smallest enclosing box of boxes1 and boxes2
    """
    min_xyz = torch.min(boxes1[..., :3], boxes2[..., :3])
    max_xyz = torch.max(boxes1[..., 3:], boxes2[..., 3:])
    return torch.cat([min_xyz, max_xyz], dim=-1)

def giou_loss_3d(
    boxes1: torch.Tensor,
    boxes2: torch.Tensor,
    reduction: str="none",
    eps: float = 1e-7,
):

    """
    Generalized IoU loss for axis-aligned 3D boxes.

    ``loss = 1 - GIoU`` with ``GIoU = IoU - (V_c - V_union) / V_c``, where
    ``V_c`` is the volume of the smallest box enclosing both inputs.

    Args:
        boxes1, boxes2: (N, 6) boxes, minimum corner first, maximum corner
            after; compared element-wise.
        reduction: "none" returns the per-box loss, "sum" and "mean" reduce
            it to a scalar. Any other value behaves like "none".
        eps: small constant added to the denominators to avoid division by
            zero.

    Returns:
        Tensor: per-box loss of shape (N,), or a scalar when reduced.
    """

    inter = _intersection_3d(boxes1, boxes2)
    vol1, vol2 = _volume(boxes1), _volume(boxes2)

    union = vol1 + vol2 - inter + eps
    iou = inter / union

    enclosing = _enclosing_box_3d(boxes1, boxes2)
    vol_c = _volume(enclosing) + eps

    giou = iou - (vol_c - union) / vol_c
    loss = 1.0 - giou

    if reduction == "sum":
        return loss.sum()
    if reduction == "mean":
        # mean() over zero boxes is NaN and would poison training; return a
        # zero that is still attached to the autograd graph instead.
        return loss.mean() if loss.numel() > 0 else 0.0 * loss.sum()
    return loss

def diou_loss_3d(
    boxes1: torch.Tensor,
    boxes2: torch.Tensor,
    reduction: str = "none",
    eps: float = 1e-7
):

    """
    Distance IoU loss for axis-aligned 3D boxes.

    ``loss = 1 - IoU + d^2 / c^2``, where ``d`` is the distance between the
    two box centres and ``c`` is the diagonal length of the smallest box
    enclosing both inputs.

    Args:
        boxes1, boxes2: (..., 6) boxes, minimum corner first, maximum corner
            after; compared element-wise.
        reduction: "none" returns the per-box loss, "sum" and "mean" reduce
            it to a scalar. Any other value behaves like "none".
        eps: small constant added to the denominators to avoid division by
            zero.

    Returns:
        Tensor: per-box loss, or a scalar when reduced.
    """

    inter = _intersection_3d(boxes1, boxes2)
    vol1, vol2 = _volume(boxes1), _volume(boxes2)

    union = vol1 + vol2 - inter + eps
    iou = inter / union

    # Squared distance between box centres.
    c1 = (boxes1[..., :3] + boxes1[..., 3:]) / 2
    c2 = (boxes2[..., :3] + boxes2[..., 3:]) / 2
    center_dist_sq = ((c1 - c2) ** 2).sum(dim=-1)

    # Squared diagonal of the enclosing box: max corner minus MIN corner.
    # (Subtracting the max corner from itself would always give zero and
    # leave only eps in the denominator.)
    enclosing = _enclosing_box_3d(boxes1, boxes2)
    diag_sq = ((enclosing[..., 3:] - enclosing[..., :3]) ** 2).sum(dim=-1) + eps

    diou = iou - center_dist_sq / diag_sq
    loss = 1.0 - diou

    if reduction == "sum":
        return loss.sum()
    if reduction == "mean":
        # mean() over zero boxes is NaN; return a graph-connected zero instead.
        return loss.mean() if loss.numel() > 0 else 0.0 * loss.sum()
    return loss

def ciou_loss_3d(
    boxes1: torch.Tensor,
    boxes2: torch.Tensor,
    reduction: str = "none",
    eps: float = 1e-7,
):
    """
    Complete IoU loss for axis-aligned 3D boxes.

    ``loss = 1 - IoU + d^2 / c^2 + alpha * v``, where ``d`` is the distance
    between box centres, ``c`` the diagonal of the smallest enclosing box,
    ``v`` an aspect-ratio consistency term summed over the three pairs of
    box dimensions, and ``alpha`` a trade-off weight computed without
    gradient.

    Args:
        boxes1, boxes2: (..., 6) boxes, minimum corner first, maximum corner
            after; compared element-wise.
        reduction: "none" returns the per-box loss, "sum" and "mean" reduce
            it to a scalar. Any other value behaves like "none".
        eps: small constant guarding divisions and the minimum box side
            length used for aspect ratios.

    Returns:
        Tensor: per-box loss, or a scalar when reduced.
    """
    inter = _intersection_3d(boxes1, boxes2)
    vol1 = _volume(boxes1)
    vol2 = _volume(boxes2)

    union = vol1 + vol2 - inter + eps
    iou = inter / union

    # Squared distance between box centres.
    c1 = (boxes1[..., :3] + boxes1[..., 3:]) / 2
    c2 = (boxes2[..., :3] + boxes2[..., 3:]) / 2
    center_dist_sq = ((c1 - c2) ** 2).sum(dim=-1)

    # Squared diagonal of the smallest enclosing box.
    enclosing = _enclosing_box_3d(boxes1, boxes2)
    diag_sq = ((enclosing[..., 3:] - enclosing[..., :3]) ** 2).sum(dim=-1) + eps

    # Side lengths, floored at eps so the ratios below never divide by zero.
    d1 = (boxes1[..., 3:] - boxes1[..., :3]).clamp(min=eps)
    d2 = (boxes2[..., 3:] - boxes2[..., :3]).clamp(min=eps)

    # Aspect-ratio consistency: one squared arctan difference per pair of
    # dimensions (0/1, 1/2, 0/2).
    v = (
        (torch.atan(d1[..., 0] / d1[..., 1]) - torch.atan(d2[..., 0] / d2[..., 1])) ** 2 +
        (torch.atan(d1[..., 1] / d1[..., 2]) - torch.atan(d2[..., 1] / d2[..., 2])) ** 2 +
        (torch.atan(d1[..., 0] / d1[..., 2]) - torch.atan(d2[..., 0] / d2[..., 2])) ** 2
    )

    # alpha is treated as a constant: no gradient flows through it.
    with torch.no_grad():
        alpha = v / (1 - iou + v + eps)

    ciou = iou - center_dist_sq / diag_sq - alpha * v
    loss = 1.0 - ciou

    if reduction == "sum":
        return loss.sum()
    if reduction == "mean":
        # mean() over zero boxes is NaN; return a graph-connected zero instead.
        return loss.mean() if loss.numel() > 0 else 0.0 * loss.sum()
    return loss

def _dense_box_regression_loss_3d(
    anchors: List[Union[Boxes3D, torch.Tensor]],
    box3d2box3d_transform,
    pred_anchor_deltas: List[torch.Tensor],
    gt_boxes: List[torch.Tensor],
    fg_mask: torch.Tensor,
    box_reg_loss_type="smooth_l1",
    smooth_l1_beta=0.0,
):
    """
    Compute the summed box regression loss for dense multi-level 3D
    predictions, restricted to the entries selected by ``fg_mask``.

    Args:
        anchors: #lvl anchor boxes, each is (HixWixA, 6); either all
            `Boxes3D` or all plain tensors (decided from the first element).
        box3d2box3d_transform: object providing ``get_deltas(anchors, gt)``
            and ``apply_deltas(deltas, anchors)``.
        pred_anchor_deltas: #lvl predictions, each is (N, HixWixA, 6)
        gt_boxes: N ground truth boxes, each has shape (R, 6)
        fg_mask: foreground boolean mask of shape (N, R)
        box_reg_loss_type (str): "smooth_l1", "giou", "diou", "ciou"
        smooth_l1_beta (float): beta for Smooth L1 loss; only used when
            ``box_reg_loss_type`` is "smooth_l1".

    Returns:
        Tensor: the loss summed over all foreground entries.

    Raises:
        ValueError: if ``box_reg_loss_type`` is not one of the supported names.
    """
    # Merge anchors from all feature levels into a single (R, 6) tensor.
    first_level = anchors[0]
    if isinstance(first_level, Boxes3D):
        anchors = type(first_level).cat(anchors).tensor
    else:
        anchors = torch.cat(anchors, dim=0)

    # Merge predictions from all levels: (N, R, 6).
    pred_anchor_deltas = torch.cat(pred_anchor_deltas, dim=1)

    if box_reg_loss_type == "smooth_l1":
        # Regress in delta space: encode every image's GT against the anchors.
        target_deltas = torch.stack(
            [box3d2box3d_transform.get_deltas(anchors, per_image_gt) for per_image_gt in gt_boxes]
        )  # (N, R, 6)
        return smooth_l1_loss(
            pred_anchor_deltas[fg_mask],
            target_deltas[fg_mask],
            beta=smooth_l1_beta,
            reduction="sum",
        )

    iou_losses = {
        "giou": giou_loss_3d,
        "diou": diou_loss_3d,
        "ciou": ciou_loss_3d,
    }
    iou_loss = iou_losses.get(box_reg_loss_type)
    if iou_loss is None:
        raise ValueError(f"Invalid dense box regression loss type '{box_reg_loss_type}'")

    # IoU-style losses compare boxes, so decode the predictions per image.
    decoded = torch.stack(
        [box3d2box3d_transform.apply_deltas(per_image_deltas, anchors)
         for per_image_deltas in pred_anchor_deltas]
    )  # (N, R, 6)
    targets = torch.stack(gt_boxes)  # (N, R, 6)

    return iou_loss(decoded[fg_mask], targets[fg_mask], reduction="sum")