# PoseCNN: A Convolutional Neural Network for 6D Object Pose Estimation
---
This notebook explains **PoseCNN**, a convolutional architecture for **6D object pose estimation** in cluttered scenes, with a focus on how it handles occlusion and object symmetry.

## 1. Introduction
PoseCNN is a convolutional neural network developed to estimate the **6D pose** of known objects. It decouples the problem into **semantic labeling**, **3D translation estimation**, and **3D rotation regression**. Its main contributions include:

- A robust framework for 6D pose estimation in occluded and cluttered scenes.
- Introduction of **ShapeMatch Loss** to handle symmetric objects.
- Creation of the **YCB-Video dataset** with 133,827 frames for 6D pose estimation.

## 2. Dataset
### YCB-Video Dataset
This dataset contains **133,827 RGB-D frames** across 92 videos. It provides:
- 21 objects with 6D pose annotations.
- Severe occlusions and symmetric objects.

### OccludedLINEMOD Dataset
- A benchmark for 6D pose estimation with significant occlusions.
- 1,214 frames with 6D poses for 8 objects.

## 3. PoseCNN Architecture (Loss)

In [None]:
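import torch

def loss_cross_entropy(scores, labels):
    """
    Cross-entropy between log-probabilities and one-hot labels.

    scores: log-softmax output, e.g. [batch_size, num_classes, height, width]
    labels: one-hot labels (or per-class weights) with the same shape as scores
    """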

    cross_entropy = -torch.sum(labels * scores, dim=1)
    loss = torch.div(torch.sum(cross_entropy), torch.sum(labels)+1e-10)

    return loss


def smooth_l1_loss(vertex_pred, vertex_targets, vertex_weights, sigma=1.0):
    sigma_2 = sigma ** 2
    vertex_diff = vertex_pred - vertex_targets
    diff = torch.mul(vertex_weights, vertex_diff)
    abs_diff = torch.abs(diff)
    smoothL1_sign = torch.lt(abs_diff, 1. / sigma_2).float().detach()
    in_loss = torch.pow(diff, 2) * (sigma_2 / 2.) * smoothL1_sign \
            + (abs_diff - (0.5 / sigma_2)) * (1. - smoothL1_sign)
    loss = torch.div( torch.sum(in_loss), torch.sum(vertex_weights) + 1e-10 )
    return loss

# Compute the network output and the total training loss.
# Excerpt from the training loop: `network`, `cfg`, the input tensors,
# `vertex_targets`, and `vertex_weights` are defined elsewhere.
if cfg.TRAIN.VERTEX_REG:
    if cfg.TRAIN.POSE_REG:
        out_logsoftmax, out_weight, out_vertex, out_logsoftmax_box, \
            bbox_labels, bbox_pred, bbox_targets, bbox_inside_weights, loss_pose_tensor, poses_weight \
            = network(inputs, labels, meta_data, extents, gt_boxes, poses, points, symmetry)

        loss_label = loss_cross_entropy(out_logsoftmax, out_weight) # semantic segmentation: predicted probabilities vs. ground-truth labels
        loss_vertex = cfg.TRAIN.VERTEX_W * smooth_l1_loss(out_vertex, vertex_targets, vertex_weights) # regression loss on the object vertex positions
        loss_box = loss_cross_entropy(out_logsoftmax_box, bbox_labels) # predicted bbox class vs. ground-truth bbox class
        loss_location = smooth_l1_loss(bbox_pred, bbox_targets, bbox_inside_weights) # regression loss on the bbox location
        loss_pose = torch.mean(loss_pose_tensor) # point matching loss of the object (PLOSS and SLOSS below)
        loss = loss_label + loss_vertex + loss_box + loss_location + loss_pose
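
The piecewise rule inside `smooth_l1_loss` is easier to see on single values. The sketch below restates that rule for one scalar residual; the helper name `smooth_l1_scalar` is hypothetical and not part of the PoseCNN code. Residuals smaller than `1 / sigma**2` are penalized quadratically and larger ones linearly, and the two branches meet at the threshold.

```python
def smooth_l1_scalar(diff, sigma=1.0):
    """Scalar form of the rule used in smooth_l1_loss (hypothetical helper)."""
    sigma_2 = sigma ** 2
    abs_diff = abs(diff)
    if abs_diff < 1.0 / sigma_2:
        # Quadratic branch for small residuals
        return 0.5 * sigma_2 * diff * diff
    # Linear branch for large residuals
    return abs_diff - 0.5 / sigma_2

small = smooth_l1_scalar(0.5)             # quadratic branch
large = smooth_l1_scalar(2.0)             # linear branch
sharp = smooth_l1_scalar(0.5, sigma=2.0)  # threshold drops to 0.25, so linear branch
print(small, large, sharp)
```

A larger `sigma` narrows the quadratic region, so the loss behaves like plain L1 for all but very small residuals.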

### Pose Loss (PLOSS)
Pose Loss is used for **asymmetric objects**. It minimizes the average squared distance between corresponding 3D model points rotated by the ground-truth rotation and by the predicted rotation.

#### Pose Loss Formula:
$$ PLOSS(q, \tilde{q}) = \frac{1}{2m} \sum_{x \in M} ||R(q)x - R(\tilde{q})x||^2 $$
Where:
- $q$ is the ground-truth quaternion.
- $\tilde{q}$ is the predicted quaternion.
- $R(q)$ is the rotation matrix derived from $q$.
- $M$ is the set of 3D model points and $m$ is the number of points in $M$.
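
The formula can be transcribed almost literally into NumPy. This is a minimal sketch of the equation, not the CUDA layer used by PoseCNN; the function names `quat_to_rotmat` and `ploss` and the toy points are assumptions made for illustration. Quaternions are taken as unit quaternions in scalar-first order $(s, u, v, w)$.

```python
import numpy as np

def quat_to_rotmat(q):
    """Rotation matrix of a unit quaternion q = (s, u, v, w), scalar first."""
    s, u, v, w = q
    return np.array([
        [s*s + u*u - v*v - w*w, 2*(u*v - s*w),         2*(u*w + s*v)],
        [2*(u*v + s*w),         s*s - u*u + v*v - w*w, 2*(v*w - s*u)],
        [2*(u*w - s*v),         2*(v*w + s*u),         s*s - u*u - v*v + w*w],
    ])

def ploss(q_gt, q_pred, points):
    """Half the mean squared distance between corresponding rotated points."""
    diff = points @ quat_to_rotmat(q_gt).T - points @ quat_to_rotmat(q_pred).T
    return 0.5 * np.mean(np.sum(diff ** 2, axis=1))

# Toy model points (made up for illustration)
points = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
q_identity = np.array([1.0, 0.0, 0.0, 0.0])
q_z90 = np.array([np.cos(np.pi / 4), 0.0, 0.0, np.sin(np.pi / 4)])  # 90 degrees about z

loss_same = ploss(q_identity, q_identity, points)
loss_z90 = ploss(q_identity, q_z90, points)
print(loss_same, loss_z90)
```

In this toy case two of the three points move by a squared distance of 2 and the point on the rotation axis does not move, so the loss is $\frac{1}{2} \cdot \frac{4}{3} = \frac{2}{3}$.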

### ShapeMatch Loss (SLOSS)
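ShapeMatch Loss is used for **symmetric objects**. Instead of comparing each point with its fixed counterpart, it uses the distance to the closest point on the ground-truth model, so rotations that are equivalent under the object's symmetry are not penalized.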

#### ShapeMatch Loss Formula:
$$ SLOSS(q, \tilde{q}) = \frac{1}{2m} \sum_{x_1 \in M} \min_{x_2 \in M} ||R(\tilde{q})x_1 - R(q)x_2||^2 $$
Where:
- $x_1$ is a point on the estimated model, rotated by the predicted quaternion $\tilde{q}$.
- $x_2$ is the closest point on the ground-truth model, rotated by the ground-truth quaternion $q$.
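
A small NumPy sketch shows why the closest-point search matters. It is a direct transcription of the formula under the assumption of unit, scalar-first quaternions; the names `quat_to_rotmat` and `sloss` and the toy point set are hypothetical. The four points are symmetric under a 180 degree rotation about the z axis, so that rotation should cost nothing.

```python
import numpy as np

def quat_to_rotmat(q):
    """Rotation matrix of a unit quaternion q = (s, u, v, w), scalar first."""
    s, u, v, w = q
    return np.array([
        [s*s + u*u - v*v - w*w, 2*(u*v - s*w),         2*(u*w + s*v)],
        [2*(u*v + s*w),         s*s - u*u + v*v - w*w, 2*(v*w - s*u)],
        [2*(u*w - s*v),         2*(v*w + s*u),         s*s - u*u - v*v + w*w],
    ])

def sloss(q_gt, q_pred, points):
    """Half the mean squared distance from each predicted point to its closest GT point."""
    pred = points @ quat_to_rotmat(q_pred).T  # x1 rotated by the predicted quaternion
    gt = points @ quat_to_rotmat(q_gt).T      # x2 rotated by the ground-truth quaternion
    # Pairwise squared distances, shape [m, m]: rows index x1, columns index x2
    d2 = np.sum((pred[:, None, :] - gt[None, :, :]) ** 2, axis=2)
    return 0.5 * np.mean(d2.min(axis=1))

# Toy point set that maps onto itself under a 180 degree rotation about z
points = np.array([[1.0, 0.0, 0.0], [-1.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, -2.0, 0.0]])
q_identity = np.array([1.0, 0.0, 0.0, 0.0])
q_z180 = np.array([0.0, 0.0, 0.0, 1.0])
q_z90 = np.array([np.cos(np.pi / 4), 0.0, 0.0, np.sin(np.pi / 4)])

loss_z180 = sloss(q_identity, q_z180, points)  # symmetric pose: no penalty
loss_z90 = sloss(q_identity, q_z90, points)    # not a symmetry of this set: penalized
print(loss_z180, loss_z90)
```

The brute-force distance matrix is quadratic in the number of points, which is fine for a sketch but is the reason real implementations run this search on the GPU or with a spatial index.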


In [None]:
# Pose Loss and ShapeMatch Loss Implementation
import torch
import math

POSE_CHANNELS = 4

def pml_forward(bottom_prediction: torch.Tensor,
                bottom_target: torch.Tensor,
                bottom_weight: torch.Tensor,
                points: torch.Tensor,
                symmetry: torch.Tensor,
                hard_angle: float):
    """
    Conceptual Python/PyTorch port of pml_cuda_forward.
    Follows the logic and math of the original .cu code.

    Parameters
    ----------
    bottom_prediction : torch.Tensor
        [batch, num_classes*4] predicted quaternions
    bottom_target : torch.Tensor
        [batch, num_classes*4] ground-truth quaternions
    bottom_weight : torch.Tensor
        [batch, num_classes*4] per-class weights; a positive entry marks the class in use
    points : torch.Tensor
        [num_classes, num_points, 3] object model points
    symmetry : torch.Tensor
        [num_classes], whether each object is symmetric
    hard_angle : float
        angle threshold (in degrees) above which a sample counts as hard

    Returns
    -------
    top_data: shape [1], total loss
    bottom_diff: [batch, num_classes*POSE_CHANNELS], gradient with respect to the rotation (quaternion)
    """

    device = bottom_prediction.device
    batch_size = bottom_prediction.size(0)
    num_classes = points.size(0)
    num_points = points.size(1)

    # rotations: [batch, num_points, 54], 54 = 6*9
    # For each point:
    # 0-8: GT rotation matrix
    # 9-17: predicted rotation matrix
    # 18-26, 27-35, 36-44, 45-53: four derivative matrices (9 elements each, 36 in total)
    rotations = torch.zeros(batch_size, num_points, 6*9, device=device)
    losses = torch.zeros(batch_size, num_points, device=device)
    diffs = torch.zeros(batch_size, num_points, POSE_CHANNELS * num_classes, device=device)
    angles_batch = torch.zeros(batch_size, device=device)

    # Flatten points to one 3D point per row: [num_classes * num_points, 3].
    # (The CUDA kernel instead indexes a flat float array as
    # point[index_cls * num_points * 3 + p * 3 + k].)
    points_flat = points.view(-1, 3)

    for n in range(batch_size):
        # find class
        index_cls = -1
        for c in range(num_classes):
            if bottom_weight[n, c*POSE_CHANNELS].item() > 0:
                index_cls = c
                break
        if index_cls == -1:
            continue

        # GT quaternion
        s_gt = bottom_target[n, index_cls*POSE_CHANNELS+0].item()
        u_gt = bottom_target[n, index_cls*POSE_CHANNELS+1].item()
        v_gt = bottom_target[n, index_cls*POSE_CHANNELS+2].item()
        w_gt = bottom_target[n, index_cls*POSE_CHANNELS+3].item()

        # Pred quaternion
        s_pr = bottom_prediction[n, index_cls*POSE_CHANNELS+0].item()
        u_pr = bottom_prediction[n, index_cls*POSE_CHANNELS+1].item()
        v_pr = bottom_prediction[n, index_cls*POSE_CHANNELS+2].item()
        w_pr = bottom_prediction[n, index_cls*POSE_CHANNELS+3].item()

        # For every model point p, rotations[n, p, :] stores the GT rotation,
        # the predicted rotation, and the four derivative matrices.
        # (CUDA: ind = n * num_points * 6 * 9 + p * 6 * 9)
        for p in range(num_points):
            # GT rotation
            rotations[n, p, 0] = s_gt*s_gt + u_gt*u_gt - v_gt*v_gt - w_gt*w_gt
            rotations[n, p, 1] = 2*(u_gt*v_gt - s_gt*w_gt)
            rotations[n, p, 2] = 2*(u_gt*w_gt + s_gt*v_gt)
            rotations[n, p, 3] = 2*(u_gt*v_gt + s_gt*w_gt)
            rotations[n, p, 4] = s_gt*s_gt - u_gt*u_gt + v_gt*v_gt - w_gt*w_gt
            rotations[n, p, 5] = 2*(v_gt*w_gt - s_gt*u_gt)
            rotations[n, p, 6] = 2*(u_gt*w_gt - s_gt*v_gt)
            rotations[n, p, 7] = 2*(v_gt*w_gt + s_gt*u_gt)
            rotations[n, p, 8] = s_gt*s_gt - u_gt*u_gt - v_gt*v_gt + w_gt*w_gt

            # predicted rotation
            rotations[n, p, 9] = s_pr*s_pr + u_pr*u_pr - v_pr*v_pr - w_pr*w_pr
            rotations[n, p,10] = 2*(u_pr*v_pr - s_pr*w_pr)
            rotations[n, p,11] = 2*(u_pr*w_pr + s_pr*v_pr)
            rotations[n, p,12] = 2*(u_pr*v_pr + s_pr*w_pr)
            rotations[n, p,13] = s_pr*s_pr - u_pr*u_pr + v_pr*v_pr - w_pr*w_pr
            rotations[n, p,14] = 2*(v_pr*w_pr - s_pr*u_pr)
            rotations[n, p,15] = 2*(u_pr*w_pr - s_pr*v_pr)
            rotations[n, p,16] = 2*(v_pr*w_pr + s_pr*u_pr)
            rotations[n, p,17] = s_pr*s_pr - u_pr*u_pr - v_pr*v_pr + w_pr*w_pr

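            # Compute the angle between GT and prediction once per sample (p == 0)
            if p == 0:
                d = s_gt*s_pr + u_gt*u_pr + v_gt*v_pr + w_gt*w_pr # quaternion dot product, d = cos(θ/2)
                angle = math.acos(max(min(2*d*d - 1,1),-1))*180.0/math.pi # 2d² - 1 = cos θ, converted to degrees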
                if angle > hard_angle:
                    angles_batch[n] = 1.0

            # Derivatives of the predicted rotation matrix with respect to the
            # quaternion components, as in the original code. The four blocks
            # 18-26, 27-35, 36-44, 45-53 hold dR/ds, dR/du, dR/dv, dR/dw.
            # 18-26
            idx_deriv = 18
            rotations[n, p, idx_deriv+0] = 2 * s_pr
            rotations[n, p, idx_deriv+1] = -2 * w_pr
            rotations[n, p, idx_deriv+2] = 2 * v_pr
            rotations[n, p, idx_deriv+3] = 2 * w_pr
            rotations[n, p, idx_deriv+4] = 2 * s_pr
            rotations[n, p, idx_deriv+5] = -2 * u_pr
            rotations[n, p, idx_deriv+6] = -2 * v_pr
            rotations[n, p, idx_deriv+7] = 2 * u_pr
            rotations[n, p, idx_deriv+8] = 2 * s_pr

            # 27-35
            idx_deriv = 27
            rotations[n, p, idx_deriv+0] = 2 * u_pr
            rotations[n, p, idx_deriv+1] = 2 * v_pr
            rotations[n, p, idx_deriv+2] = 2 * w_pr
            rotations[n, p, idx_deriv+3] = 2 * v_pr
            rotations[n, p, idx_deriv+4] = -2 * u_pr
            rotations[n, p, idx_deriv+5] = -2 * s_pr
            rotations[n, p, idx_deriv+6] = 2 * w_pr
            rotations[n, p, idx_deriv+7] = 2 * s_pr
            rotations[n, p, idx_deriv+8] = -2 * u_pr

            # 36-44
            idx_deriv = 36
            rotations[n, p, idx_deriv+0] = -2 * v_pr
            rotations[n, p, idx_deriv+1] = 2 * u_pr
            rotations[n, p, idx_deriv+2] = 2 * s_pr
            rotations[n, p, idx_deriv+3] = 2 * u_pr
            rotations[n, p, idx_deriv+4] = 2 * v_pr
            rotations[n, p, idx_deriv+5] = 2 * w_pr
            rotations[n, p, idx_deriv+6] = -2 * s_pr
            rotations[n, p, idx_deriv+7] = 2 * w_pr
            rotations[n, p, idx_deriv+8] = -2 * v_pr

            # 45-53
            idx_deriv = 45
            rotations[n, p, idx_deriv+0] = -2 * w_pr
            rotations[n, p, idx_deriv+1] = -2 * s_pr
            rotations[n, p, idx_deriv+2] = 2 * u_pr
            rotations[n, p, idx_deriv+3] = 2 * s_pr
            rotations[n, p, idx_deriv+4] = -2 * w_pr
            rotations[n, p, idx_deriv+5] = 2 * v_pr
            rotations[n, p, idx_deriv+6] = 2 * u_pr
            rotations[n, p, idx_deriv+7] = 2 * v_pr
            rotations[n, p, idx_deriv+8] = 2 * w_pr

            # Row of points_flat that holds point p of this class. The CUDA kernel
            # indexes a flat float array with index_cls * num_points * 3 + p * 3;
            # points_flat has one 3D point per row, so the factor 3 is dropped.
            index = index_cls * num_points + p
            # rotate the point with the predicted rotation:
            x1 = rotations[n, p, 9+0]*points_flat[index][0]+rotations[n, p, 9+1]*points_flat[index][1]+rotations[n, p, 9+2]*points_flat[index][2]
            y1 = rotations[n, p, 9+3]*points_flat[index][0]+rotations[n, p, 9+4]*points_flat[index][1]+rotations[n, p, 9+5]*points_flat[index][2]
            z1 = rotations[n, p, 9+6]*points_flat[index][0]+rotations[n, p, 9+7]*points_flat[index][1]+rotations[n, p, 9+8]*points_flat[index][2]

            # Symmetric objects: search the GT-rotated model for the point closest to (x1, y1, z1)
            if symmetry[index_cls].item() > 0:
                dmin = float('inf')
                index_min = index
                for i in range(num_points):
                    index2 = index_cls * num_points + i
                    x2 = rotations[n, p, 0]*points_flat[index2][0]+rotations[n, p, 1]*points_flat[index2][1]+rotations[n, p, 2]*points_flat[index2][2]
                    y2 = rotations[n, p, 3]*points_flat[index2][0]+rotations[n, p, 4]*points_flat[index2][1]+rotations[n, p, 5]*points_flat[index2][2]
                    z2 = rotations[n, p, 6]*points_flat[index2][0]+rotations[n, p, 7]*points_flat[index2][1]+rotations[n, p, 8]*points_flat[index2][2]
                    dist = (x1 - x2)**2+(y1 - y2)**2+(z1 - z2)**2
                    if dist < dmin:
                        dmin = dist
                        index_min = index2
            else:
                # Asymmetric objects: compare with the same model point
                index_min = index

            # rotate the matched point with the GT rotation
            x2 = rotations[n, p, 0]*points_flat[index_min][0]+rotations[n, p, 1]*points_flat[index_min][1]+rotations[n, p, 2]*points_flat[index_min][2]
            y2 = rotations[n, p, 3]*points_flat[index_min][0]+rotations[n, p, 4]*points_flat[index_min][1]+rotations[n, p, 5]*points_flat[index_min][2]
            z2 = rotations[n, p, 6]*points_flat[index_min][0]+rotations[n, p, 7]*points_flat[index_min][1]+rotations[n, p, 8]*points_flat[index_min][2]

            # smooth L1 loss on the three coordinate differences
            distance = 0.0
            coords_diff = [(x1 - x2).item(), (y1 - y2).item(), (z1 - z2).item()]
            for j in range(3):
                diff_val = coords_diff[j]
                abs_diff = abs(diff_val)
                if abs_diff < 1:
                    distance += 0.5 * diff_val*diff_val
                    df = diff_val
                else:
                    distance += abs_diff - 0.5
                    df = 1.0 if diff_val > 0 else -1.0

                # Chain rule: accumulate the gradient with respect to each quaternion
                # component over the coordinates k_ (0 = x, 1 = y, 2 = z) of the model point.
                # Original CUDA code:
                # diffs[index_diff + 0] += df * point[index + k_] * rotations[ind + 18 + j*3 + k_] / num_points;
                for k_ in range(3):
                    # Entry (j, k_) of each of the four derivative matrices
                    r18 = rotations[n, p, 18 + j*3 + k_].item()
                    r27 = rotations[n, p, 27 + j*3 + k_].item()
                    r36 = rotations[n, p, 36 + j*3 + k_].item()
                    r45 = rotations[n, p, 45 + j*3 + k_].item()

                    # Coordinate k_ of the model point
                    val_point = points_flat[index][k_]

                    # Divide by num_points, as in the original code
                    scale = df * val_point.item() / num_points

                    diffs[n, p, index_cls*POSE_CHANNELS + 0] += scale * r18
                    diffs[n, p, index_cls*POSE_CHANNELS + 1] += scale * r27
                    diffs[n, p, index_cls*POSE_CHANNELS + 2] += scale * r36
                    diffs[n, p, index_cls*POSE_CHANNELS + 3] += scale * r45

            losses[n, p] = distance / num_points

    # angles sum
    batch_hard = angles_batch.sum().item()

    # sum diffs and losses for bottom_diff, losses_batch
    bottom_diff = torch.zeros(batch_size, POSE_CHANNELS * num_classes, device=device)
    losses_batch = torch.zeros(batch_size, device=device)
    if batch_hard > 0:
        for n in range(batch_size):
            if angles_batch[n].item() > 0:
                # diffs sum
                # bottom_diff[index] = sum over p of diffs / batch_hard
                # losses_batch[n] = sum(losses[n,:]) / batch_hard
                bottom_diff[n, :] = diffs[n, :, :].sum(dim=0) / batch_hard
                losses_batch[n] = losses[n, :].sum() / batch_hard

    top_data = losses_batch.sum().unsqueeze(0)

    return top_data, bottom_diff


def pml_backward(grad_loss: torch.Tensor, bottom_diff: torch.Tensor):
    """
    Conceptual Python/PyTorch implementation of pml_cuda_backward.
    grad_loss: [1]
    bottom_diff: [batch, num_classes*POSE_CHANNELS]

    output:
    grad_rotation: [batch, num_classes*POSE_CHANNELS]
    grad_rotation = grad_loss[0]*bottom_diff
    """
    grad_rotation = grad_loss[0] * bottom_diff
    return grad_rotation
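
The hard-example test inside `pml_forward` relies on the relation between the quaternion dot product and the rotation angle: for unit quaternions, $d = q \cdot \tilde{q} = \cos(\theta/2)$, so $\cos\theta = 2d^2 - 1$. The standalone sketch below checks this on a known rotation. The helper name `quaternion_angle_deg` and the threshold value are made up for illustration.

```python
import math

def quaternion_angle_deg(q_gt, q_pred):
    """Rotation angle in degrees between two unit quaternions (s, u, v, w)."""
    d = sum(a * b for a, b in zip(q_gt, q_pred))
    # d = cos(theta / 2), so cos(theta) = 2 d^2 - 1; clamp against rounding error
    cos_theta = max(min(2.0 * d * d - 1.0, 1.0), -1.0)
    return math.degrees(math.acos(cos_theta))

q_identity = (1.0, 0.0, 0.0, 0.0)
q_z90 = (math.cos(math.pi / 4), 0.0, 0.0, math.sin(math.pi / 4))  # 90 degrees about z

angle = quaternion_angle_deg(q_identity, q_z90)
# q and -q encode the same rotation; squaring d makes the angle sign-invariant
angle_flipped = quaternion_angle_deg(q_identity, tuple(-c for c in q_z90))

hard_angle = 15.0  # made-up threshold, not a value from the PoseCNN config
is_hard = angle > hard_angle
print(angle, angle_flipped, is_hard)
```

Because only samples flagged this way contribute to `bottom_diff` and `losses_batch`, predictions that are already within `hard_angle` of the ground truth receive no gradient from this layer.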



## 4. Evaluation Metrics (Extended)
PoseCNN uses two primary evaluation metrics for 6D object pose estimation:

### 4.1. ADD (Average Distance)
ADD measures the mean distance between corresponding points of the predicted and ground-truth 3D model poses.

$$ ADD = \frac{1}{m} \sum_{x \in M} \| (Rx + T) - (\tilde{R}x + \tilde{T}) \| $$
- $R$ and $T$ are the ground-truth rotation and translation.
- $\tilde{R}$ and $\tilde{T}$ are the predicted rotation and translation.
- $x$ is a point on the 3D model with $m$ total points.

ADD is effective for **asymmetric objects**.

In [None]:
# ADD Metric Implementation
import numpy as np
def transform_pts_Rt(pts, R, t):
    """
    Applies a rigid transformation to 3D points.

    :param pts: nx3 ndarray with 3D points.
    :param R: 3x3 rotation matrix.
    :param t: 3x1 translation vector.
    :return: nx3 ndarray with transformed 3D points.
    """
    assert(pts.shape[1] == 3)
    pts_t = R.dot(pts.T) + t.reshape((3, 1))
    return pts_t.T

def add(R_est, t_est, R_gt, t_gt, pts):
    """
    Average Distance of Model Points for objects with no indistinguishable views
    - by Hinterstoisser et al. (ACCV 2012).

    :param R_est, t_est: Estimated pose (3x3 rot. matrix and 3x1 trans. vector).
    :param R_gt, t_gt: GT pose (3x3 rot. matrix and 3x1 trans. vector).
    :param pts: nx3 ndarray with 3D model points.
    :return: Error of pose_est w.r.t. pose_gt.
    """
    pts_est = transform_pts_Rt(pts, R_est, t_est)
    pts_gt = transform_pts_Rt(pts, R_gt, t_gt)
    e = np.linalg.norm(pts_est - pts_gt, axis=1).mean()
    return e
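
As a quick sanity check, a pose that is wrong only by a translation offset shifts every model point by the same amount, so ADD equals the length of that offset. The sketch below restates the computation of `add()` in a self-contained form; the name `add_error` and the toy points are assumptions for illustration.

```python
import numpy as np

def add_error(R_est, t_est, R_gt, t_gt, pts):
    """Same computation as add() above, restated so this snippet runs on its own."""
    pts_est = pts @ R_est.T + t_est.reshape(1, 3)
    pts_gt = pts @ R_gt.T + t_gt.reshape(1, 3)
    return np.linalg.norm(pts_est - pts_gt, axis=1).mean()

# Toy model points in metres (made up)
pts = np.array([[0.0, 0.0, 0.0], [0.1, 0.0, 0.0], [0.0, 0.1, 0.0]])
R = np.eye(3)
t_gt = np.array([0.0, 0.0, 1.0])
t_est = t_gt + np.array([0.01, 0.0, 0.0])  # 1 cm offset along x

error = add_error(R, t_est, R, t_gt, pts)
print(error)
```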

### 4.2. ADD-S (ADD-Symmetric)
ADD-S extends ADD to handle symmetric objects by taking the closest point distance:

$$ \text{ADD-S} = \frac{1}{m} \sum_{x_1 \in M} \min_{x_2 \in M} \| (Rx_1 + T) - (\tilde{R}x_2 + \tilde{T}) \| $$
- $x_1$ is a model point placed with the ground-truth pose $(R, T)$.
- $x_2$ is the model point that lands closest to it under the predicted pose $(\tilde{R}, \tilde{T})$.

ADD-S ensures robustness for symmetric objects, where exact point correspondence is ambiguous.

In [None]:
# ADD-S Metric Implementation
from scipy import spatial

def transform_pts_Rt(pts, R, t):
    """
    Applies a rigid transformation to 3D points.

    :param pts: nx3 ndarray with 3D points.
    :param R: 3x3 rotation matrix.
    :param t: 3x1 translation vector.
    :return: nx3 ndarray with transformed 3D points.
    """
    assert(pts.shape[1] == 3)
    pts_t = R.dot(pts.T) + t.reshape((3, 1))
    return pts_t.T

def adi(R_est, t_est, R_gt, t_gt, pts):
    """
    Average Distance of Model Points for objects with indistinguishable views
    - by Hinterstoisser et al. (ACCV 2012).

    :param R_est, t_est: Estimated pose (3x3 rot. matrix and 3x1 trans. vector).
    :param R_gt, t_gt: GT pose (3x3 rot. matrix and 3x1 trans. vector).
    :param pts: nx3 ndarray with 3D model points.
    :return: Error of pose_est w.r.t. pose_gt.
    """
    pts_est = transform_pts_Rt(pts, R_est, t_est)
    pts_gt = transform_pts_Rt(pts, R_gt, t_gt)

    # Calculate distances to the nearest neighbors from pts_gt to pts_est
    nn_index = spatial.cKDTree(pts_est) # KD-Tree
    nn_dists, _ = nn_index.query(pts_gt, k=1)

    e = nn_dists.mean()
    return e
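
The difference between the two metrics shows up on a symmetric point set. The sketch below compares them for a prediction that is off by a 180 degree rotation about the z axis, which maps the toy points onto themselves. It uses a brute-force NumPy nearest-neighbour search in place of the KD-tree, and the names `add_error` and `adds_error` are hypothetical.

```python
import numpy as np

def add_error(R_est, t_est, R_gt, t_gt, pts):
    """Mean distance between corresponding points (ADD)."""
    pts_est = pts @ R_est.T + t_est.reshape(1, 3)
    pts_gt = pts @ R_gt.T + t_gt.reshape(1, 3)
    return np.linalg.norm(pts_est - pts_gt, axis=1).mean()

def adds_error(R_est, t_est, R_gt, t_gt, pts):
    """Mean distance from each GT point to its nearest estimated point (ADD-S)."""
    pts_est = pts @ R_est.T + t_est.reshape(1, 3)
    pts_gt = pts @ R_gt.T + t_gt.reshape(1, 3)
    # Brute-force nearest neighbour; rows index GT points, columns estimated points
    dists = np.linalg.norm(pts_gt[:, None, :] - pts_est[None, :, :], axis=2)
    return dists.min(axis=1).mean()

# Toy points that map onto themselves under a 180 degree rotation about z
pts = np.array([[1.0, 0.0, 0.0], [-1.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, -2.0, 0.0]])
R_gt = np.eye(3)
R_est = np.diag([-1.0, -1.0, 1.0])  # 180 degrees about z
t = np.array([0.0, 0.0, 1.0])

add_value = add_error(R_est, t, R_gt, t, pts)
adds_value = adds_error(R_est, t, R_gt, t, pts)
print(add_value, adds_value)
```

ADD reports a large error because each point is compared with its own counterpart on the opposite side, while ADD-S reports zero because the two point clouds coincide.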

## 5. PoseCNN Architecture (Overview)
PoseCNN consists of three main branches:

1. **Semantic Labeling:** Identifies object pixels in the input image.
2. **3D Translation Estimation:** Localizes the object center and estimates the distance from the camera.
3. **3D Rotation Regression:** Estimates the orientation of the object using a quaternion representation.
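
For the second branch, the object's 2D center in the image together with its distance from the camera is enough to recover the full 3D translation under a pinhole camera model. A minimal sketch, assuming intrinsics `fx`, `fy` (focal lengths) and `px`, `py` (principal point); the function name `backproject_center` and all numeric values are made up for illustration.

```python
def backproject_center(cx, cy, tz, fx, fy, px, py):
    """Recover (Tx, Ty, Tz) from the 2D center (cx, cy) and depth tz (pinhole model)."""
    tx = (cx - px) * tz / fx
    ty = (cy - py) * tz / fy
    return tx, ty, tz

# Made-up intrinsics and prediction: center 100 px right of the principal point, 0.8 m away
tx, ty, tz = backproject_center(cx=420.0, cy=240.0, tz=0.8,
                                fx=1000.0, fy=1000.0, px=320.0, py=240.0)
print(tx, ty, tz)
```

Predicting the 2D center and the depth separately, rather than the 3D translation directly, keeps the regression targets tied to image evidence.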

In [None]:
# roi_target_layer (generates the classification labels and box regression targets used for training)

from typing import Tuple
import torch
import numpy as np
from fcn.config import cfg
from utils.bbox_transform import bbox_transform
from utils.cython_bbox import bbox_overlaps

def roi_target_layer(rpn_rois, gt_boxes):
    """
    Assign RoIs to ground truth boxes and compute classification labels and
    bounding box regression targets.
    
    Args:
        rpn_rois: Tensor of shape (N, 6) - [batch_id, class, x1, y1, x2, y2].
        gt_boxes: Tensor of shape (batch, num_classes, 5) - [x1, y1, x2, y2, class].

    Returns:
        label_blob: One-hot encoded classification labels.
        bbox_targets: Bounding box regression targets.
        bbox_inside_weights: Weights for inside boxes.
        bbox_outside_weights: Weights for outside boxes.
    """

    # Remember the input device, then convert tensors to NumPy arrays for processing
    device = rpn_rois.device
    rpn_rois = rpn_rois.detach().cpu().numpy()
    gt_boxes = gt_boxes.detach().cpu().numpy()
    num_classes = gt_boxes.shape[1]

    # Prepare RoI and GT blobs for processing
    roi_blob = rpn_rois[:, (0, 2, 3, 4, 5, 1)]  # Reorder to [batch_id, x1, y1, x2, y2, class]
    gt_rows = [
        np.hstack([
            np.full((gt_boxes.shape[1], 1), i),  # Batch index
            gt_boxes[i, :, :4],
            gt_boxes[i, :, 4:5]  # Class
        ])
        for i in range(gt_boxes.shape[0])
        if np.any(gt_boxes[i, :, -1] > 0)
    ]
    # np.vstack raises on an empty list, so fall back to an empty (0, 6) blob
    gt_box_blob = np.vstack(gt_rows) if gt_rows else np.zeros((0, 6), dtype=np.float32)

    # Sample RoIs to create classification labels and regression targets
    labels, bbox_targets, bbox_inside_weights = _sample_rois(roi_blob, gt_box_blob, num_classes)
    bbox_outside_weights = (bbox_inside_weights > 0).astype(np.float32)  # Binary outside weights

    # Convert foreground labels to one-hot encoding (background rows stay all zero)
    label_blob = np.zeros((labels.shape[0], num_classes), dtype=np.float32)
    valid_indices = labels > 0
    label_blob[valid_indices, labels[valid_indices].astype(int)] = 1.0

    # Convert outputs back to PyTorch tensors on the original device
    return (
        torch.tensor(label_blob, device=device, dtype=torch.float32),
        torch.tensor(bbox_targets, device=device, dtype=torch.float32),
        torch.tensor(bbox_inside_weights, device=device, dtype=torch.float32),
        torch.tensor(bbox_outside_weights, device=device, dtype=torch.float32),
    )

def _get_bbox_regression_labels(bbox_target_data, num_classes):
    """
    Expand compact bounding-box regression targets into a per-class format.

    Args:
        bbox_target_data: (N, 5) np.array - [class, tx, ty, tw, th].
        num_classes: Total number of classes. (int)

    Returns:
        bbox_targets: (N, 4 * num_classes) regression targets.
        bbox_inside_weights: (N, 4 * num_classes) inside weights.
    """

    clss = bbox_target_data[:, 0].astype(int)
    bbox_targets = np.zeros((clss.size, 4 * num_classes), dtype=np.float32)
    bbox_inside_weights = np.zeros_like(bbox_targets, dtype=np.float32)

    # Fill in regression targets for foreground classes
    for i, cls in enumerate(clss):
        if cls > 0:
            start = 4 * cls
            bbox_targets[i, start:start+4] = bbox_target_data[i, 1:]
            bbox_inside_weights[i, start:start+4] = cfg.TRAIN.BBOX_INSIDE_WEIGHTS

    return bbox_targets, bbox_inside_weights



def _compute_targets(ex_rois, gt_rois, labels):
    """
    Compute bounding-box regression targets for RoIs.
    
    Args:
        ex_rois: (N, 4) array of proposed RoIs.
        gt_rois: (N, 4) array of ground truth boxes.
        labels: (N,) array of class labels.

    Returns:
        targets: (N, 5) array of regression targets [class, tx, ty, tw, th].
    """

    # Ensure inputs have consistent dimensions
    assert ex_rois.shape[0] == gt_rois.shape[0]
    assert ex_rois.shape[1] == 4
    assert gt_rois.shape[1] == 4

    # Compute the bounding-box transformation
    targets = bbox_transform(ex_rois, gt_rois)

    # Normalize the targets if configured
    if cfg.TRAIN.BBOX_NORMALIZE_TARGETS_PRECOMPUTED:
        targets = ((targets - np.array(cfg.TRAIN.BBOX_NORMALIZE_MEANS))
                   / np.array(cfg.TRAIN.BBOX_NORMALIZE_STDS))

    return np.hstack((labels[:, None], targets)).astype(np.float32)


def _sample_rois(all_rois, gt_boxes, num_classes):
    """
    Generate a random sample of RoIs for training.
    
    Args:
        all_rois: (N, 6) array of proposed RoIs.
        gt_boxes: (M, 6) array of ground truth boxes.

    Returns:
        labels: (N,) classification labels.
        bbox_targets: (N, 4 * num_classes) regression targets.
        bbox_inside_weights: (N, 4 * num_classes) inside weights.
    """

    if gt_boxes.shape[0] == 0:  # If no ground truth boxes exist
        num = all_rois.shape[0]
        return (
            np.zeros(num, dtype=np.float32),
            np.zeros((num, 4 * num_classes), dtype=np.float32),
            np.zeros((num, 4 * num_classes), dtype=np.float32),
        )
        
    # Compute IoU overlaps between RoIs and GT boxes
    overlaps = bbox_overlaps(all_rois[:, 1:5], gt_boxes[:, 1:5])

    # Assign the best matching GT box to each RoI
    gt_assignment = overlaps.argmax(axis=1)
    max_overlaps = overlaps.max(axis=1)
    labels = gt_boxes[gt_assignment, -1]

    # Mark RoIs with low IoU as background
    bg_inds = max_overlaps < cfg.TRAIN.FG_THRESH
    labels[bg_inds] = 0

    # Compute regression targets
    bbox_target_data = _compute_targets(all_rois[:, 1:5], gt_boxes[gt_assignment, 1:5], labels)
    bbox_targets, bbox_inside_weights = _get_bbox_regression_labels(bbox_target_data, num_classes)

    return labels, bbox_targets, bbox_inside_weights
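
The label assignment in `_sample_rois` depends on the IoU matrix returned by the compiled `bbox_overlaps`, which is not shown in this notebook. The NumPy sketch below illustrates the same assignment steps with a hand-written IoU. The function `iou_matrix`, the toy boxes, and the threshold are assumptions; the compiled routine may use a different pixel convention, such as adding 1 to widths and heights.

```python
import numpy as np

def iou_matrix(boxes, gt):
    """IoU between boxes [N, 4] and gt [M, 4], both given as (x1, y1, x2, y2)."""
    x1 = np.maximum(boxes[:, None, 0], gt[None, :, 0])
    y1 = np.maximum(boxes[:, None, 1], gt[None, :, 1])
    x2 = np.minimum(boxes[:, None, 2], gt[None, :, 2])
    y2 = np.minimum(boxes[:, None, 3], gt[None, :, 3])
    inter = np.clip(x2 - x1, 0, None) * np.clip(y2 - y1, 0, None)
    area_b = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    area_g = (gt[:, 2] - gt[:, 0]) * (gt[:, 3] - gt[:, 1])
    return inter / (area_b[:, None] + area_g[None, :] - inter)

# Toy RoIs and a single ground-truth box of class 3 (all values made up)
rois = np.array([[0.0, 0.0, 10.0, 10.0],     # identical to the GT box
                 [0.0, 0.0, 10.0, 5.0],      # covers half of the GT box
                 [20.0, 20.0, 30.0, 30.0]])  # no overlap
gt = np.array([[0.0, 0.0, 10.0, 10.0]])
gt_classes = np.array([3.0])
fg_thresh = 0.6  # made-up stand-in for cfg.TRAIN.FG_THRESH

overlaps = iou_matrix(rois, gt)
gt_assignment = overlaps.argmax(axis=1)            # best GT box for each RoI
labels = gt_classes[gt_assignment]                 # inherit its class
labels[overlaps.max(axis=1) < fg_thresh] = 0       # low IoU becomes background
print(overlaps.ravel(), labels)
```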



In [None]:
# pose_target_layer (generates the pose labels and targets used for training)
from typing import Tuple
import torch
import numpy as np
from fcn.config import cfg
from utils.bbox_transform import bbox_transform_inv
from utils.cython_bbox import bbox_overlaps

def pose_target_layer(
    rois: torch.Tensor,
    bbox_prob: torch.Tensor,
    bbox_pred: torch.Tensor,
    gt_boxes: torch.Tensor,
    poses: torch.Tensor,
    is_training: bool
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Generate pose targets for training.

    Args:
        rois: Proposed Regions of Interest (RoIs) with at least 7 columns,
            [batch_id, class, x1, y1, x2, y2, score]; column 6 is overwritten below.
        bbox_prob: Predicted class probabilities for each RoI.
        bbox_pred: Predicted bounding box coordinates.
        gt_boxes: Ground truth bounding boxes.
        poses: Ground truth pose data.
        is_training: Whether the model is in training mode.

    Returns:
        Updated RoIs, pose targets, and pose weights.
    """
    # Convert PyTorch tensors to NumPy arrays for processing
    rois = rois.detach().cpu().numpy()
    bbox_prob = bbox_prob.detach().cpu().numpy()
    bbox_pred = bbox_pred.detach().cpu().numpy()
    gt_boxes = gt_boxes.detach().cpu().numpy()
    num_classes = bbox_prob.shape[1]  # Number of classes

    # Undo the target normalization on the bounding box predictions if configured
    if cfg.TRAIN.BBOX_NORMALIZE_TARGETS_PRECOMPUTED:
        stds = np.tile(cfg.TRAIN.BBOX_NORMALIZE_STDS, num_classes)
        means = np.tile(cfg.TRAIN.BBOX_NORMALIZE_MEANS, num_classes)
        bbox_pred *= stds
        bbox_pred += means

    # Decode bounding box predictions into actual coordinates
    boxes = rois[:, 2:6].copy()  # Extract current box coordinates
    pred_boxes = bbox_transform_inv(boxes, bbox_pred)  # Transform predictions into coordinates

    # Update RoIs with predicted boxes and probabilities
    for i in range(rois.shape[0]):
        cls = int(rois[i, 1])  # Get the predicted class for this RoI
        rois[i, 2:6] = pred_boxes[i, cls*4:cls*4+4]  # Assign predicted box
        rois[i, 6] = bbox_prob[i, cls]  # Assign predicted class probability

    # Prepare blobs for RoIs and ground truth
    roi_blob = rois[:, [0, 2, 3, 4, 5, 1]]  # Format: [batch_id, x1, y1, x2, y2, class]
    gt_box_blob = np.zeros((0, 6), dtype=np.float32)  # Initialize ground truth blob
    pose_blob = np.zeros((0, 9), dtype=np.float32)  # Initialize pose blob

    # Loop through ground truth boxes and prepare blobs
    for i in range(gt_boxes.shape[0]):  # Batch-wise
        for j in range(gt_boxes.shape[1]):  # Per-object
            if gt_boxes[i, j, -1] > 0:  # Check if the object exists (class > 0)
                gt_box = np.zeros((1, 6), dtype=np.float32)
                gt_box[0, 0] = i  # Batch index
                gt_box[0, 1:5] = gt_boxes[i, j, :4]  # Box coordinates
                gt_box[0, 5] = gt_boxes[i, j, 4]  # Class label
                gt_box_blob = np.vstack([gt_box_blob, gt_box])  # Add to blob
                poses[i, j, 0] = i  # Assign batch index for poses
                # Convert the pose row to NumPy explicitly before stacking
                pose_row = poses[i, j, :].detach().cpu().numpy().reshape(1, 9)
                pose_blob = np.vstack([pose_blob, pose_row])  # Add pose data

    # If no ground truth boxes exist, create empty targets and weights
    if gt_box_blob.shape[0] == 0:
        num = rois.shape[0]
        poses_target = np.zeros((num, 4 * num_classes), dtype=np.float32)
        poses_weight = np.zeros((num, 4 * num_classes), dtype=np.float32)
    else:
        # Compute overlaps between RoIs and ground truth boxes
        overlaps = bbox_overlaps(
            roi_blob[:, :5].astype(np.float32),
            gt_box_blob[:, :5].astype(np.float32)
        )
        # Match each RoI to the best ground truth box
        gt_assignment = overlaps.argmax(axis=1)
        max_overlaps = overlaps.max(axis=1)
        labels = gt_box_blob[gt_assignment, 5]  # Assign labels
        quaternions = pose_blob[gt_assignment, 2:6]  # Assign poses (quaternions)

        # Mark RoIs with low overlap as background
        bg_inds = np.where(max_overlaps < cfg.TRAIN.FG_THRESH_POSE)[0]
        labels[bg_inds] = 0

        # Further filter based on mismatched class predictions
        bg_inds = np.where(roi_blob[:, -1] != labels)[0]
        labels[bg_inds] = 0

        # In training, only keep positive samples for pose regression
        if is_training:
            fg_inds = np.where(labels > 0)[0]  # Positive RoIs
            if len(fg_inds) > 0:
                rois = rois[fg_inds]  # Filter RoIs
                quaternions = quaternions[fg_inds]  # Filter quaternions
                labels = labels[fg_inds]  # Filter labels

        # Compute pose targets and weights
        poses_target, poses_weight = _compute_pose_targets(quaternions, labels, num_classes)

    # Convert NumPy arrays back to PyTorch tensors. `rois` is a NumPy array at this
    # point, so the target device is taken from `poses`, which is still a tensor.
    device = poses.device
    return (
        torch.tensor(rois, device=device, dtype=torch.float32),
        torch.tensor(poses_target, device=device, dtype=torch.float32),
        torch.tensor(poses_weight, device=device, dtype=torch.float32),
    )

def _compute_pose_targets(quaternions, labels, num_classes):
    """
    Compute pose regression targets for an image.

    quaternions: Ground truth quaternions (rotation information).
    labels: Class labels for each RoI.
    num_classes: Number of classes.
    """
    num = quaternions.shape[0]  # Number of RoIs
    poses_target = np.zeros((num, 4 * num_classes), dtype=np.float32)  # Initialize targets
    poses_weight = np.zeros((num, 4 * num_classes), dtype=np.float32)  # Initialize weights

    for i in range(num):
        cls = labels[i]  # Class of this RoI
        if cls > 0 and np.linalg.norm(quaternions[i, :]) > 0:  # Skip invalid quaternions
            start = int(4 * cls)  # Start index for this class
            poses_target[i, start:start+4] = quaternions[i]  # Assign quaternion
            poses_weight[i, start:start+4] = 1.0  # Assign weight for this class

    return poses_target, poses_weight
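
The class-specific layout produced by `_compute_pose_targets` can be checked on a toy input. The sketch below is a vectorized NumPy restatement of the same slot layout; `pack_pose_targets` and `unpack_pose_targets` are hypothetical helper names used only for illustration, not part of the PoseCNN code. It shows that each RoI's quaternion occupies the four columns belonging to its class and that background RoIs stay all-zero.

```python
import numpy as np

def pack_pose_targets(quaternions, labels, num_classes):
    """Place each RoI's quaternion in columns [4*cls, 4*cls + 4) of its row."""
    num = quaternions.shape[0]
    targets = np.zeros((num, num_classes, 4), dtype=np.float32)
    weights = np.zeros((num, num_classes, 4), dtype=np.float32)
    labels = labels.astype(int)
    # Only foreground RoIs with a non-zero quaternion receive a target
    valid = (labels > 0) & (np.linalg.norm(quaternions, axis=1) > 0)
    rows = np.nonzero(valid)[0]
    targets[rows, labels[rows]] = quaternions[rows]
    weights[rows, labels[rows]] = 1.0
    return targets.reshape(num, -1), weights.reshape(num, -1)

def unpack_pose_targets(flat, labels, num_classes):
    """Read back the 4 values stored in each RoI's own class slot."""
    per_class = flat.reshape(flat.shape[0], num_classes, 4)
    return per_class[np.arange(flat.shape[0]), labels.astype(int)]

quaternions = np.array([[1, 0, 0, 0],
                        [0, 1, 0, 0],
                        [0.5, 0.5, 0.5, 0.5]], dtype=np.float32)
labels = np.array([2, 0, 1])  # the second RoI is background
poses_target, poses_weight = pack_pose_targets(quaternions, labels, num_classes=3)
recovered = unpack_pose_targets(poses_target, labels, num_classes=3)
print(poses_target.shape)  # one row per RoI, 4 columns per class
print(recovered)
```

Because the weights are 1 only inside the RoI's own class slot, multiplying a `4 * num_classes` prediction by `poses_weight` masks out every other class before the loss is computed.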

In [None]:
# PoseCNN Architecture
import torch
import torch.nn as nn
import torchvision.models as models
import math
import sys
import copy
from torch.nn.init import kaiming_normal_
from layers.hard_label import HardLabel
from layers.hough_voting import HoughVoting
from layers.roi_pooling import RoIPool
from layers.point_matching_loss import PMLoss
from layers.roi_target_layer import roi_target_layer
from layers.pose_target_layer import pose_target_layer
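from fcn.config import cfg

# VGG16 backbone that PoseCNN.__init__ reads its layers from (no pretrained weights are loaded here)
vgg16 = models.vgg16()
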
def log_softmax_high_dimension(input):
    """
    Compute the log softmax over a high-dimensional input tensor for numerical stability.
    
    This function calculates:
    log_softmax(x) = x_i - log(Σ_j exp(x_j))
    
    To ensure numerical stability, the maximum value `m` is subtracted from `input` before
    computing the exponential to prevent overflow issues.
    
    Args:
        input (torch.Tensor): Input tensor of shape (N, C, H, W) or (N, C), where N is the batch size,
                              C is the number of classes, and H, W are the height and width of the input.
        
    Returns:
        torch.Tensor: Log softmax of the input tensor, same shape as input.
    """
       
    num_classes = input.size()[1]  # Number of classes (size of dim 1)
    # Maximum over the class dimension, kept as (N, 1, H, W), subtracted for numerical stability
    m = torch.max(input, dim=1, keepdim=True)[0]

    if input.dim() == 4:  # (batch, class, height, width)
        d = input - m.repeat(1, num_classes, 1, 1)
    else:  # (batch, class)
        d = input - m.repeat(1, num_classes)
    e = torch.exp(d)
    s = torch.sum(e, dim=1, keepdim=True)  # Sum of the exponentials over the class dimension
    if input.dim() == 4:
        output = d - torch.log(s.repeat(1, num_classes, 1, 1))  # Log-softmax
    else:
        output = d - torch.log(s.repeat(1, num_classes))
    return output


def softmax_high_dimension(input):
    """
    Compute the softmax over a high-dimensional input tensor for numerical stability.
    
    This function calculates:
    softmax(x) = exp(x_i - m) / Σ_j exp(x_j - m)
    
    Similar to `log_softmax_high_dimension`, the maximum value `m` is subtracted from `input`
    for numerical stability during the exponential calculation.
    
    Args:
        input (torch.Tensor): Input tensor of shape (N, C, H, W) or (N, C), where N is the batch size,
                              C is the number of classes, and H, W are the height and width of the input.
        
    Returns:
        torch.Tensor: Softmax of the input tensor, same shape as input.
    """

    num_classes = input.size()[1]
    m = torch.max(input, dim=1, keepdim=True)[0]
    if input.dim() == 4:
        e = torch.exp(input - m.repeat(1, num_classes, 1, 1))
    else:
        e = torch.exp(input - m.repeat(1, num_classes))
    s = torch.sum(e, dim=1, keepdim=True)
    if input.dim() == 4:
        output = torch.div(e, s.repeat(1, num_classes, 1, 1))
    else:
        output = torch.div(e, s.repeat(1, num_classes))
    return output

def conv(in_planes, out_planes, kernel_size=3, stride=1, relu=True):
    if relu:
        return nn.Sequential(
            nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, bias=True),
            nn.ReLU(inplace=True))
    else:
        return nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, bias=True)


def fc(in_planes, out_planes, relu=True):
    if relu:
        return nn.Sequential(
            nn.Linear(in_planes, out_planes),
            nn.LeakyReLU(0.1, inplace=True))
    else:
        return nn.Linear(in_planes, out_planes)


def upsample(scale_factor):
    return nn.Upsample(scale_factor=scale_factor, mode='bilinear')


class PoseCNN(nn.Module):
    """
    PoseCNN network class for 6D pose estimation.
    This model is based on a VGG16 backbone and includes branches for semantic labeling, vertex regression,
    and pose estimation through RoI pooling and Hough voting.
    """
    def __init__(self, num_classes, num_units):

        """
        Initializes the PoseCNN model with feature extraction and specific branches for pose estimation.
        The model adapts a VGG16 backbone and adds custom layers for semantic labeling and vertex regression.

        Args:
            num_classes (int): Number of classes for classification and pose estimation.
            num_units (int): Number of feature embedding units in the intermediate layers.
        """
        super(PoseCNN, self).__init__()
        self.num_classes = num_classes 

        # conv features
        features = list(vgg16.features)[:30]
        
        # change the first conv layer for RGBD
        if cfg.INPUT == 'RGBD':  # RGB-D input: 6 channels (3 color + 3 depth)
            # Widen the first conv layer from 3 to 6 input channels, reusing the RGB filters for both halves
            conv0 = conv(6, 64, kernel_size=3, relu=False)
            conv0.weight.data[:, :3, :, :] = features[0].weight.data # RGB
            conv0.weight.data[:, 3:, :, :] = features[0].weight.data # Depth
            conv0.bias.data = features[0].bias.data
            features[0] = conv0

        self.features = nn.ModuleList(features) # store the modified feature extraction layers to allow iteration during forward
        self.classifier = vgg16.classifier[:-1] # using vgg16 classifier 
        if cfg.TRAIN.SLIM:
            dim_fc = 256
            self.classifier[0] = nn.Linear(512*7*7, 256)
            self.classifier[3] = nn.Linear(256, 256)
        else:
            dim_fc = 4096
            
        print(self.features)
        print(self.classifier)

        # freeze some layers
        if cfg.TRAIN.FREEZE_LAYERS:
            for i in [0, 2, 5, 7, 10, 12, 14]:  # conv layers of VGG16 blocks 1-3 stay fixed during fine-tuning
                self.features[i].weight.requires_grad = False
                self.features[i].bias.requires_grad = False

        # semantic labeling branch
        self.conv4_embed = conv(512, num_units, kernel_size=1)
        self.conv5_embed = conv(512, num_units, kernel_size=1)
        self.upsample_conv5_embed = upsample(2.0)  # Upsample conv5 features 2x so they can be added to the conv4 features
        self.upsample_embed = upsample(8.0)  # Upsample 8x to match the original image size
        self.conv_score = conv(num_units, num_classes, kernel_size=1)  # Per-pixel class scores (1x1 convolution)
        self.hard_label = HardLabel(threshold=cfg.TRAIN.HARD_LABEL_THRESHOLD, sample_percentage=cfg.TRAIN.HARD_LABEL_SAMPLING)  # Hard-label weighting, intended to improve generalization
        self.dropout = nn.Dropout()

        if cfg.TRAIN.VERTEX_REG:
            # center regression branch
            self.conv4_vertex_embed = conv(512, 2*num_units, kernel_size=1, relu=False)
            self.conv5_vertex_embed = conv(512, 2*num_units, kernel_size=1, relu=False)
            self.upsample_conv5_vertex_embed = upsample(2.0)
            self.upsample_vertex_embed = upsample(8.0)
            self.conv_vertex_score = conv(2*num_units, 3*num_classes, kernel_size=1, relu=False)  # Per class: center direction (x, y) and center distance, used for 3D translation estimation
            # hough voting
            self.hough_voting = HoughVoting(is_train=0, skip_pixels=10, label_threshold=100, \
                                            inlier_threshold=0.9, voting_threshold=-1, per_threshold=0.01)

            self.roi_pool_conv4 = RoIPool(pool_height=7, pool_width=7, spatial_scale=1.0 / 8.0)
            self.roi_pool_conv5 = RoIPool(pool_height=7, pool_width=7, spatial_scale=1.0 / 16.0)
            self.fc8 = fc(dim_fc, num_classes)
            self.fc9 = fc(dim_fc, 4 * num_classes, relu=False)  # Bounding-box regression: 4 box values per class

            if cfg.TRAIN.POSE_REG:
                self.fc10 = fc(dim_fc, 4 * num_classes, relu=False)  # Rotation regression: 4-D quaternion per class
                self.pml = PMLoss(hard_angle=cfg.TRAIN.HARD_ANGLE)

        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
                kaiming_normal_(m.weight.data)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()


    def forward(self, x, label_gt, meta_data, extents, gt_boxes, poses, points, symmetry):
        
        """
        Defines the forward pass of the PoseCNN network.
        This method processes input through the feature extractor, semantic labeling branch, vertex regression branch,
        and optionally through the bounding box and rotation regression branches.

        Args:
            x (torch.Tensor): Input image tensor.
            label_gt (torch.Tensor): Ground truth labels for semantic segmentation.
            meta_data (torch.Tensor): Metadata for each input.
            extents (torch.Tensor): 3D extents of the object.
            gt_boxes (torch.Tensor): Ground truth bounding boxes.
            poses (torch.Tensor): Ground truth object poses.
            points (torch.Tensor): Points representing the object's 3D structure.
            symmetry (torch.Tensor): Symmetry information for objects.

        Returns:
            Depending on training/testing mode and enabled branches, returns various outputs including
            semantic segmentation logits, vertex predictions, bounding box outputs, and rotation outputs.
        """
        
        # Extract convolutional features from the input using VGG16 layers
        for i, model in enumerate(self.features):
            x = model(x)
            if i == 22:
                out_conv4_3 = x # Store the output after the 4th block (used for semantic labeling and vertex regression)
            if i == 29:
                out_conv5_3 = x # Store the output after the 5th block (used for semantic labeling and vertex regression)


        ##### semantic labeling branch #####
        out_conv4_embed = self.conv4_embed(out_conv4_3)
        out_conv5_embed = self.conv5_embed(out_conv5_3)
        out_conv5_embed_up = self.upsample_conv5_embed(out_conv5_embed)
        out_embed = self.dropout(out_conv4_embed + out_conv5_embed_up)  # Element-wise sum of the two embeddings
        out_embed_up = self.upsample_embed(out_embed)  # Upsample to the original image size
        out_score = self.conv_score(out_embed_up)  # Classify each pixel
        out_logsoftmax = log_softmax_high_dimension(out_score)  # Log-softmax
        out_prob = softmax_high_dimension(out_score)  # Per-class probabilities
        out_label = torch.max(out_prob, dim=1)[1].type(torch.IntTensor).cuda()  # Predicted class = index of the highest probability in out_prob
        out_weight = self.hard_label(out_prob, label_gt, torch.rand(out_prob.size()).cuda())  # Hard labeling: puts more weight on low-probability pixels

        if cfg.TRAIN.VERTEX_REG:
            ##### center regression branch #####
            # Extract features for vertex prediction from the conv4 and conv5 feature maps
            out_conv4_vertex_embed = self.conv4_vertex_embed(out_conv4_3)
            out_conv5_vertex_embed = self.conv5_vertex_embed(out_conv5_3)
            out_conv5_vertex_embed_up = self.upsample_conv5_vertex_embed(out_conv5_vertex_embed)
            out_vertex_embed = self.dropout(out_conv4_vertex_embed + out_conv5_vertex_embed_up)
            out_vertex_embed_up = self.upsample_vertex_embed(out_vertex_embed)
            out_vertex = self.conv_vertex_score(out_vertex_embed_up)

            # hough voting
            if self.training:
                self.hough_voting.is_train = 1
                self.hough_voting.label_threshold = cfg.TRAIN.HOUGH_LABEL_THRESHOLD
                self.hough_voting.voting_threshold = cfg.TRAIN.HOUGH_VOTING_THRESHOLD
                self.hough_voting.skip_pixels = cfg.TRAIN.HOUGH_SKIP_PIXELS
                self.hough_voting.inlier_threshold = cfg.TRAIN.HOUGH_INLIER_THRESHOLD
            else:
                self.hough_voting.is_train = 0
                self.hough_voting.label_threshold = cfg.TEST.HOUGH_LABEL_THRESHOLD
                self.hough_voting.voting_threshold = cfg.TEST.HOUGH_VOTING_THRESHOLD
                self.hough_voting.skip_pixels = cfg.TEST.HOUGH_SKIP_PIXELS
                self.hough_voting.inlier_threshold = cfg.TEST.HOUGH_INLIER_THRESHOLD
            out_box, out_pose = self.hough_voting(out_label, out_vertex, meta_data, extents) 

            ##### bounding box classification and regression branch #####
            bbox_labels, bbox_targets, bbox_inside_weights, bbox_outside_weights = roi_target_layer(out_box, gt_boxes)
            # Perform RoI pooling on conv4 and conv5 feature maps using the predicted bounding boxes
            out_roi_conv4 = self.roi_pool_conv4(out_conv4_3, out_box)
            out_roi_conv5 = self.roi_pool_conv5(out_conv5_3, out_box)
            # Combine the pooled features from conv4 and conv5
            out_roi = out_roi_conv4 + out_roi_conv5
            # Flatten the combined RoI features for fully connected layer input
            out_roi_flatten = out_roi.view(out_roi.size(0), -1)
            # Pass the flattened features through the classifier
            out_fc7 = self.classifier(out_roi_flatten)
            out_fc8 = self.fc8(out_fc7)
            out_logsoftmax_box = log_softmax_high_dimension(out_fc8)
            bbox_prob = softmax_high_dimension(out_fc8)
            bbox_label_weights = self.hard_label(bbox_prob, bbox_labels, torch.rand(bbox_prob.size()).cuda())
            bbox_pred = self.fc9(out_fc7)

            ##### rotation regression branch #####
            rois, poses_target, poses_weight = pose_target_layer(out_box, bbox_prob, bbox_pred, gt_boxes, poses, self.training)
            if cfg.TRAIN.POSE_REG:    
                out_qt_conv4 = self.roi_pool_conv4(out_conv4_3, rois)
                out_qt_conv5 = self.roi_pool_conv5(out_conv5_3, rois)
                out_qt = out_qt_conv4 + out_qt_conv5
                out_qt_flatten = out_qt.view(out_qt.size(0), -1)
                out_qt_fc7 = self.classifier(out_qt_flatten)
                out_quaternion = self.fc10(out_qt_fc7)
                # point matching loss
                poses_pred = nn.functional.normalize(torch.mul(out_quaternion, poses_weight))
                if self.training:
                    loss_pose = self.pml(poses_pred, poses_target, poses_weight, points, symmetry)

        if self.training:
            if cfg.TRAIN.VERTEX_REG:
                if cfg.TRAIN.POSE_REG:
                    return out_logsoftmax, out_weight, out_vertex, out_logsoftmax_box, bbox_label_weights, \
                           bbox_pred, bbox_targets, bbox_inside_weights, loss_pose, poses_weight
                else:
                    return out_logsoftmax, out_weight, out_vertex, out_logsoftmax_box, bbox_label_weights, \
                           bbox_pred, bbox_targets, bbox_inside_weights
            else:
                return out_logsoftmax, out_weight
        else:
            if cfg.TRAIN.VERTEX_REG:
                if cfg.TRAIN.POSE_REG:
                    return out_label, out_vertex, rois, out_pose, out_quaternion
                else:
                    return out_label, out_vertex, rois, out_pose
            else:
                return out_label

    def weight_parameters(self):
        return [param for name, param in self.named_parameters() if 'weight' in name]

    def bias_parameters(self):
        return [param for name, param in self.named_parameters() if 'bias' in name]
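
The indices 22 and 29 used in `forward`, and the `spatial_scale` values 1/8 and 1/16 passed to `RoIPool`, follow from the VGG16 layer layout. The sketch below assumes the standard torchvision VGG16 configuration (13 conv layers, each followed by a ReLU, with five 2x2 max-pooling layers); `vgg16_layer_table` is a hypothetical helper written only for this illustration.

```python
# Standard VGG16 ("configuration D") channel layout; 'M' marks a 2x2 max-pool
VGG16_CFG = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M',
             512, 512, 512, 'M', 512, 512, 512, 'M']

def vgg16_layer_table(cfg=VGG16_CFG):
    """Return (index, layer name, cumulative stride) for every module in `features`."""
    table, index, stride = [], 0, 1
    block, conv_in_block = 1, 0
    for item in cfg:
        if item == 'M':
            stride *= 2  # each max-pool halves the spatial resolution
            table.append((index, f'pool{block}', stride))
            index += 1
            block, conv_in_block = block + 1, 0
        else:
            conv_in_block += 1
            name = f'{block}_{conv_in_block}'
            table.append((index, f'conv{name}', stride))
            table.append((index + 1, f'relu{name}', stride))
            index += 2
    return table

table = vgg16_layer_table()
for index in (22, 29):
    _, name, stride = table[index]
    print(index, name, 'stride', stride)
# 22 relu4_3 stride 8
# 29 relu5_3 stride 16

frozen = [table[i][1] for i in [0, 2, 5, 7, 10, 12, 14]]
print(frozen)
```

The stride of 8 at `relu4_3` and 16 at `relu5_3` is why the conv5 embedding is upsampled by 2 before it is added to the conv4 embedding, and why the sum is then upsampled by 8 to reach the input resolution.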


### Hough Voting
Hough Voting is used to localize the 2D center of objects in an image. Each pixel votes for the object center using a predicted unit vector direction.

### Key Steps:
1. **Direction Prediction:** Each pixel regresses to a unit vector pointing towards the object center.
2. **Voting Process:** Each pixel casts votes for potential object center locations.
3. **Center Selection:** The object center is chosen as the location with the highest accumulated votes.

#### Illustration:
- A pixel at $(x, y)$ votes for points along the ray towards the predicted object center.
- Non-maximum suppression (NMS) is applied to handle multiple object instances.
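
The voting steps above can be sketched for a single object in vectorized NumPy. This is a simplified illustration, not the CUDA layer used by PoseCNN: it assumes one object mask, ignores the predicted distance, and skips NMS. The function name `hough_vote_center` and the toy scene are made up for the example.

```python
import numpy as np

def hough_vote_center(directions, mask, inlier_threshold=0.9):
    """Vote for a 2D object center.

    directions: (H, W, 2) array of unit vectors (dx, dy) predicted per pixel.
    mask: (H, W) boolean array marking the pixels that belong to the object.
    Returns the (x, y) cell with the most votes and the full vote map.
    """
    height, width = mask.shape
    ys, xs = np.nonzero(mask)                      # voting pixels
    dirs = directions[ys, xs]                      # (P, 2)
    grid_y, grid_x = np.mgrid[0:height, 0:width]   # candidate center cells

    # Offset from every voting pixel to every candidate cell: shape (P, H, W)
    off_x = grid_x[None, :, :] - xs[:, None, None]
    off_y = grid_y[None, :, :] - ys[:, None, None]
    dist = np.sqrt(off_x ** 2 + off_y ** 2)

    # Cosine between the predicted direction and the direction to the candidate
    dot = dirs[:, 0, None, None] * off_x + dirs[:, 1, None, None] * off_y
    cosine = np.divide(dot, dist, out=np.zeros_like(dot), where=dist > 0)

    # A pixel votes for every cell that lies close enough to its predicted ray
    votes = (cosine > inlier_threshold).sum(axis=0)
    cy, cx = np.unravel_index(np.argmax(votes), votes.shape)
    return (int(cx), int(cy)), votes

# Toy scene: every pixel points exactly at the true center
height, width = 20, 24
true_center = (12, 8)  # (x, y)
grid_y, grid_x = np.mgrid[0:height, 0:width]
to_center = np.stack([true_center[0] - grid_x, true_center[1] - grid_y], axis=-1).astype(float)
norm = np.linalg.norm(to_center, axis=-1, keepdims=True)
directions = np.divide(to_center, norm, out=np.zeros_like(to_center), where=norm > 0)
mask = np.zeros((height, width), dtype=bool)
mask[4:14, 6:18] = True  # rectangular "object" containing the center

center, votes = hough_vote_center(directions, mask)
print(center)
```

With perfect directions, every object pixel except the center itself votes for the true center, so that cell collects the maximum number of votes. With noisy predictions the peak broadens, which is where the inlier threshold and NMS matter.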

In [None]:
# Hough voting
import torch
import numpy as np

def hough_voting_cuda_forward(
    bottom_label, bottom_vertex, bottom_meta_data, extents,
    is_train, skip_pixels, label_threshold, inlier_threshold,
    voting_threshold, per_threshold
):
    """
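    Simplified Python port of the Hough voting layer, using PyTorch tensors on a CUDA device.
    The nested per-pixel Python loops show the logic only and are far too slow for real use.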

    Args:
        bottom_label: Tensor of shape (batch_size, height, width).
        bottom_vertex: Tensor of shape (batch_size, num_classes * 3, height, width).
        bottom_meta_data: Tensor containing metadata.
        extents: Tensor containing 3D extents of object classes.
        is_train: Flag (0 or 1) indicating training mode.
        skip_pixels: Step size for pixel sampling.
        label_threshold: Minimum number of pixels for a class to be considered.
        inlier_threshold: Threshold for inlier votes.
        voting_threshold: Minimum number of votes for a valid hypothesis.
        per_threshold: Percentage threshold for voting area.

    Returns:
        top_box_final: Final bounding box tensor.
        top_pose_final: Final pose tensor.
    """
    batch_size = bottom_vertex.size(0)
    num_classes = bottom_vertex.size(1) // 3  # Vertex channels divided by 3
    height, width = bottom_vertex.size(2), bottom_vertex.size(3)
    num_meta_data = bottom_meta_data.size(1)

    # Constants
    max_rois = 128
    index_size = max_rois // batch_size

    top_box = torch.zeros((max_rois * 9, 7), device='cuda')
    top_pose = torch.zeros((max_rois * 9, 7), device='cuda')
    num_rois = torch.zeros(1, dtype=torch.int32, device='cuda')

    for batch_index in range(batch_size):
        labelmap = bottom_label[batch_index].to('cuda')
        vertmap = bottom_vertex[batch_index].to('cuda')
        meta_data = bottom_meta_data[batch_index].to('cuda')

        # Step 1: Compute a label index array for each class
        arrays = torch.zeros((num_classes, height * width), dtype=torch.int32, device='cuda')
        array_sizes = torch.zeros((num_classes,), dtype=torch.int32, device='cuda')

        for y in range(height):
            for x in range(width):
                cls = labelmap[y, x]
                if cls > 0:
                    index = y * width + x
                    array_sizes[cls] += 1
                    arrays[cls, array_sizes[cls] - 1] = index

        # Step 2: Compute valid class indexes
        class_indexes = []
        for c in range(1, num_classes):
            if array_sizes[c] > label_threshold:
                class_indexes.append(c)

        if not class_indexes:
            continue

        class_indexes = torch.tensor(class_indexes, dtype=torch.int32, device='cuda')

        # Step 3: Compute Hough space and data
        hough_space = torch.zeros((len(class_indexes), height, width), device='cuda')
        hough_data = torch.zeros((len(class_indexes), height, width, 3), device='cuda')

        for i, cls in enumerate(class_indexes):
            for y in range(height):
                for x in range(width):
                    for j in range(0, array_sizes[cls], skip_pixels):
                        location = arrays[cls, j]
                        px, py = location % width, location // width

                        # Read direction and distance
                        u = vertmap[cls * 3, py, px]
                        v = vertmap[cls * 3 + 1, py, px]
                        d = torch.exp(vertmap[cls * 3 + 2, py, px])

                        # Voting
                        dx, dy = x - px, y - py
                        norm1 = torch.sqrt(u ** 2 + v ** 2)
                        norm2 = torch.sqrt(dx ** 2 + dy ** 2)
                        dot = u * dx + v * dy
                        angle_dist = dot / (norm1 * norm2)

                        if angle_dist > inlier_threshold:
                            hough_space[i, y, x] += 1
                            hough_data[i, y, x, 0] += d

        # Normalize distances
        non_zero_votes = hough_space > 0
        hough_data[..., 0][non_zero_votes] /= hough_space[non_zero_votes]

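        # Step 4: Select candidate centers (every cell whose vote count exceeds the threshold;
        # this simplified port does not apply non-maximum suppression)
        maxima = (hough_space > voting_threshold).nonzero(as_tuple=False)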

        for m in maxima:
            cls_idx, cy, cx = m
            cls = class_indexes[cls_idx]

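            bb_distance = hough_data[cls_idx, cy, cx, 0]
            # NOTE: channels 1 and 2 of hough_data are never written in the voting loop above,
            # so the box height and width read here are always zero in this simplified port.
            bb_width = hough_data[cls_idx, cy, cx, 2]
            bb_height = hough_data[cls_idx, cy, cx, 1]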

            # Add box
            roi_index = num_rois.item()
            top_box[roi_index, :] = torch.tensor([
                batch_index, cls, cx - bb_width / 2, cy - bb_height / 2,
                cx + bb_width / 2, cy + bb_height / 2, hough_space[cls_idx, cy, cx]
            ], device='cuda')

            num_rois += 1

    # Prepare final outputs
    num_rois = num_rois.item()
    if num_rois == 0:
        num_rois = 1

    top_box_final = top_box[:num_rois]
    top_pose_final = top_pose[:num_rois]

    return top_box_final, top_pose_final


## 6. Refinement with ICP
PoseCNN employs **Iterative Closest Point (ICP)** to refine the initial 6D pose predictions using depth data. This process significantly improves accuracy, particularly for challenging cases with occlusions or symmetric objects.

### Key Steps in Refinement:
1. **Initial Pose Prediction:** The network predicts an initial 6D pose (translation and rotation).
2. **ICP Refinement:**
   - Matches observed depth points to the rendered 3D model points.
   - Minimizes the point-to-plane residual between the observed and predicted depth points.
3. **Final Pose Selection:**
   - Multiple refined poses are evaluated.
   - The pose with the best alignment metric is selected.
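
A minimal sketch of the refinement loop is shown here. It deliberately swaps in a simpler variant than the one described above: point-to-point ICP with brute-force nearest-neighbor matching and a closed-form (Kabsch) update, instead of a point-to-plane residual with rendered model points. The function names and the toy point set are assumptions made for the example.

```python
import numpy as np

def best_rigid_transform(src, dst):
    """Least-squares rotation and translation with dst ≈ src @ rot.T + trans (Kabsch)."""
    src_mean, dst_mean = src.mean(axis=0), dst.mean(axis=0)
    cov = (src - src_mean).T @ (dst - dst_mean)
    u, _, vt = np.linalg.svd(cov)
    # Flip the last axis if needed so that the result is a proper rotation (det = +1)
    sign = 1.0 if np.linalg.det(vt.T @ u.T) >= 0 else -1.0
    rot = vt.T @ np.diag([1.0, 1.0, sign]) @ u.T
    trans = dst_mean - rot @ src_mean
    return rot, trans

def icp_point_to_point(model, observed, rot, trans, num_iters=10):
    """Refine an initial pose (rot, trans) that maps model points onto observed points."""
    for _ in range(num_iters):
        moved = model @ rot.T + trans
        # Match every transformed model point to its nearest observed point
        sq_dist = ((moved[:, None, :] - observed[None, :, :]) ** 2).sum(axis=-1)
        matches = observed[sq_dist.argmin(axis=1)]
        # Re-estimate the pose from the current matches
        rot, trans = best_rigid_transform(model, matches)
    return rot, trans

# Toy model: a 3x3x3 grid of points, observed under a small rigid motion
axis = np.arange(3.0)
model = np.stack(np.meshgrid(axis, axis, axis, indexing='ij'), axis=-1).reshape(-1, 3)
angle = np.deg2rad(5.0)
true_rot = np.array([[np.cos(angle), -np.sin(angle), 0.0],
                     [np.sin(angle),  np.cos(angle), 0.0],
                     [0.0, 0.0, 1.0]])
true_trans = np.array([0.05, -0.03, 0.02])
observed = model @ true_rot.T + true_trans

rot, trans = icp_point_to_point(model, observed, np.eye(3), np.zeros(3))
print(np.allclose(rot, true_rot), np.allclose(trans, true_trans))
```

The example converges because the initial pose is already close to the true one, so the nearest-neighbor matches are correct. This mirrors the role of the network in PoseCNN: it supplies the initial pose that ICP then refines.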

In [None]:
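# Depth-based vertex targets (ycb_object.py)
# Note: this cell builds per-pixel regression targets from depth; it does not run ICP itself.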
import numpy as np
import torch

class PoseRefiner:
    def __init__(self, extents, intrinsic_matrix):
        self._extents = extents
        self._intrinsic_matrix = intrinsic_matrix

    def refine_pose_with_depth(self, im_label, im_depth, cls_indexes, center, poses, classes, num_classes):
        """
        Refines pose using depth information by computing vertex deltas.

        Args:
            im_label: 2D label image (height x width).
            im_depth: Depth image with shape (height x width x 3).
            cls_indexes: List of object class indices.
            center: Center positions of objects.
            poses: List or array of object poses.
            classes: List of all classes.
            num_classes: Number of classes.

        Returns:
            vertex_targets: Vertex deltas for regression.
            vertex_weights: Weights for regression loss.
        """
        # Extract X, Y, Z depth channels
        x_image = im_depth[:, :, 0]
        y_image = im_depth[:, :, 1]
        z_image = im_depth[:, :, 2]

        # Image dimensions
        height, width = im_label.shape

        # Initialize outputs for vertex deltas and weights
        vertex_targets = np.zeros((3 * num_classes, height, width), dtype=np.float32)
        vertex_weights = np.zeros((3 * num_classes, height, width), dtype=np.float32)

        # Iterate through each class to calculate vertex deltas
        for i in range(1, num_classes):
            # Mask for valid depth and corresponding class label
            valid_mask = (z_image != 0.0)  # Ensure depth is valid
            label_mask = (im_label == classes[i])  # Match specific class label
            combined_mask = valid_mask & label_mask

            # Get pixel indices where both masks are true
            y, x = np.where(combined_mask)

            # Check if valid pixels exist for the current class
            ind = np.where(cls_indexes == classes[i])[0]
            if len(x) > 0 and len(ind) > 0:

                # Get object extent and compute half diameter for normalization
                extents_here = self._extents[i, :]
                largest_dim = np.sqrt(np.sum(extents_here**2))
                half_diameter = largest_dim / 2.0

                # Retrieve the 3D object center from the first matching pose entry
                c_x, c_y = center[ind, 0], center[ind, 1]  # 2D center (not used below)
                idx = int(ind[0])  # int(ind) would fail if more than one instance matched
                if isinstance(poses, list):
                    x_center_coord = poses[idx][0]
                    y_center_coord = poses[idx][1]
                    z_center_coord = poses[idx][2]
                else:
                    x_center_coord = poses[idx, -3]
                    y_center_coord = poses[idx, -2]
                    z_center_coord = poses[idx, -1]

                # Compute vertex deltas normalized by object size
                targets_x = (x_image[y, x] - x_center_coord) / half_diameter
                targets_y = (y_image[y, x] - y_center_coord) / half_diameter
                targets_z = (z_image[y, x] - z_center_coord) / half_diameter

                # Assign vertex deltas to the target array
                vertex_targets[3 * i + 0, y, x] = targets_x
                vertex_targets[3 * i + 1, y, x] = targets_y
                vertex_targets[3 * i + 2, y, x] = targets_z

                # Assign uniform weights for valid regions
                vertex_weights[3 * i + 0, y, x] = 1.0  # Weight for X
                vertex_weights[3 * i + 1, y, x] = 1.0  # Weight for Y
                vertex_weights[3 * i + 2, y, x] = 1.0  # Weight for Z

        # Return the computed vertex targets and weights
        return vertex_targets, vertex_weights
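
`refine_pose_with_depth` expects `im_depth` to already hold per-pixel X, Y, Z coordinates. The notebook does not show how that image is built, so the sketch below is an assumption: a standard pinhole back-projection of a depth map using the camera intrinsics. `backproject_depth` and the intrinsic values are made up for illustration.

```python
import numpy as np

def backproject_depth(depth, intrinsic_matrix):
    """Convert a depth map (H, W) into an XYZ image (H, W, 3) in the camera frame."""
    fx, fy = intrinsic_matrix[0, 0], intrinsic_matrix[1, 1]
    cx, cy = intrinsic_matrix[0, 2], intrinsic_matrix[1, 2]
    height, width = depth.shape
    v, u = np.mgrid[0:height, 0:width]  # pixel rows (v) and columns (u)
    x = (u - cx) * depth / fx
    y = (v - cy) * depth / fy
    return np.stack([x, y, depth], axis=-1)

# Made-up intrinsics and a flat depth map with one missing measurement
intrinsics = np.array([[500.0, 0.0, 2.0],
                       [0.0, 500.0, 1.5],
                       [0.0, 0.0, 1.0]])
depth = np.full((4, 5), 2.0)
depth[0, 0] = 0.0  # missing depth stays at Z = 0
xyz = backproject_depth(depth, intrinsics)
print(xyz.shape)
print(xyz[3, 4])
```

Pixels without a depth reading keep Z = 0, which is the case the `z_image != 0.0` mask in the cell above filters out.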



## 7. Results and Visualization
### Quantitative Results
- The PoseCNN paper reports **state-of-the-art accuracy** (at the time of its publication) on the YCB-Video and OccludedLINEMOD datasets.

### Visualization
Below is an example of semantic labeling and pose estimation.

In [None]:
# Visualization
import numpy as np
from fcn.config import cfg

def _vis_test(inputs, labels, out_label, out_vertex, rois, poses, poses_refined, sample, points, classes, class_colors):
    """
    Visualize a mini-batch for debugging purposes.

    Args:
        inputs (torch.Tensor): Input image tensor of shape (N, C, H, W).
        labels (torch.Tensor): Ground truth labels tensor of shape (N, C, H, W).
        out_label (torch.Tensor): Predicted labels tensor of shape (N, H, W).
        out_vertex (torch.Tensor): Predicted vertex regression tensor.
        rois (torch.Tensor): Region of interest tensor.
        poses (torch.Tensor): Predicted poses tensor.
        poses_refined (torch.Tensor): Refined predicted poses tensor.
        sample (dict): Dictionary containing ground truth poses, metadata, and other information.
        points (numpy.ndarray): 3D points for object models.
        classes (list): List of class names.
        class_colors (list): List of RGB colors for each class.

    Visualizes the following:
        - Input image.
        - Ground truth and predicted labels.
        - Predicted bounding boxes and poses.
        - Refined poses if enabled.
        - Ground truth and predicted vertex targets.

    """
    import matplotlib.pyplot as plt

    im_blob = inputs.cpu().numpy()
    label_blob = labels.cpu().numpy()
    label_pred = out_label.cpu().numpy()
    gt_poses = sample['poses'].numpy()
    meta_data_blob = sample['meta_data'].numpy()
    metadata = meta_data_blob[0, :]
    intrinsic_matrix = metadata[:9].reshape((3, 3))
    gt_boxes = sample['gt_boxes'].numpy()
    extents = sample['extents'][0, :, :].numpy()

    if cfg.TRAIN.VERTEX_REG or cfg.TRAIN.VERTEX_REG_DELTA:
        vertex_targets = sample['vertex_targets'].numpy()
        vertex_pred = out_vertex.detach().cpu().numpy()

    m = 4
    n = 4
    for i in range(im_blob.shape[0]):
        fig = plt.figure()
        start = 1

        # show image
        im = im_blob[i, :, :, :].copy()
        im = im.transpose((1, 2, 0)) * 255.0
        im += cfg.PIXEL_MEANS
        im = im[:, :, (2, 1, 0)]
        im = im.astype(np.uint8)
        ax = fig.add_subplot(m, n, 1)
        plt.imshow(im)
        ax.set_title('color')
        start += 1

        # show gt boxes
        boxes = gt_boxes[i]
        for j in range(boxes.shape[0]):
            if boxes[j, 4] == 0:
                continue
            x1 = boxes[j, 0]
            y1 = boxes[j, 1]
            x2 = boxes[j, 2]
            y2 = boxes[j, 3]
            plt.gca().add_patch(
                plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='g', linewidth=3))

        # show gt label
        label_gt = label_blob[i, :, :, :]
        label_gt = label_gt.transpose((1, 2, 0))
        height = label_gt.shape[0]
        width = label_gt.shape[1]
        num_classes = label_gt.shape[2]
        im_label_gt = np.zeros((height, width, 3), dtype=np.uint8)
        for j in range(num_classes):
            I = np.where(label_gt[:, :, j] > 0)
            im_label_gt[I[0], I[1], :] = class_colors[j]

        ax = fig.add_subplot(m, n, start)
        start += 1
        plt.imshow(im_label_gt)
        ax.set_title('gt labels')

        # show predicted label
        label = label_pred[i, :, :]
        height = label.shape[0]
        width = label.shape[1]
        im_label = np.zeros((height, width, 3), dtype=np.uint8)
        for j in range(num_classes):
            I = np.where(label == j)
            im_label[I[0], I[1], :] = class_colors[j]

        ax = fig.add_subplot(m, n, start)
        start += 1
        plt.imshow(im_label)
        ax.set_title('predicted labels')

        if cfg.TRAIN.VERTEX_REG or cfg.TRAIN.VERTEX_REG_DELTA:

            # show predicted boxes
            ax = fig.add_subplot(m, n, start)
            start += 1
            plt.imshow(im)

            ax.set_title('predicted boxes')
            for j in range(rois.shape[0]):
                if rois[j, 0] != i or rois[j, -1] < cfg.TEST.DET_THRESHOLD:
                    continue
                cls = rois[j, 1]
                x1 = rois[j, 2]
                y1 = rois[j, 3]
                x2 = rois[j, 4]
                y2 = rois[j, 5]
                plt.gca().add_patch(
                    plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor=np.array(class_colors[int(cls)]) / 255.0, linewidth=3))

                cx = (x1 + x2) / 2
                cy = (y1 + y2) / 2
                plt.plot(cx, cy, 'yo')

            # Additional visualizations for gt poses, predicted poses, refined poses, and vertex targets omitted for brevity.

        plt.show()

## 8. Conclusion and Limitations
PoseCNN is a robust framework for 6D pose estimation in challenging conditions such as occlusions and symmetry. Its contributions include:
- Introduction of ShapeMatch Loss for symmetric objects.
- Development of the YCB-Video dataset for pose estimation tasks.
- State-of-the-art results (at the time of publication) on the YCB-Video and OccludedLINEMOD benchmarks.

--------------------------------------------------

## Appendix