# DenseFusion Framework
---
This notebook provides an in-depth explanation of the DenseFusion framework, including the training and evaluation pipelines, and detailed descriptions of auxiliary modules.

## 1. Introduction
DenseFusion is a 6D object pose estimation approach that fuses color and depth information in a pixel-wise manner. Unlike methods that treat RGB-D data as a single input, DenseFusion processes each modality in a separate branch that preserves its native structure, enabling more nuanced feature extraction. It then performs dense pixel-level fusion, exploiting the known correspondence between each color pixel and its depth measurement. Finally, instead of relying on time-consuming ICP post-processing, DenseFusion incorporates a differentiable iterative refinement module directly into its pipeline and trains it jointly with the main network to achieve robust and efficient pose estimation.

## 2. Dataset

### Dataset: YCB-Video
This dataset contains **133,827 RGB-D frames** across 92 videos. It provides:
- 21 objects with 6D pose annotations.
- Severe occlusions and symmetric objects.

Dataset URL: [YCB-Video Dataset](https://rse-lab.cs.washington.edu/projects/posecnn/)


## 3. Training Pipeline
The `train.py` script implements the training process for DenseFusion. The steps include:

1. **Dataset Preparation**: Load and preprocess the YCB dataset.
2. **Model Initialization**: Define `PoseNet` for initial pose estimation and `PoseRefineNet` for iterative refinement.
3. **Optimization**: Train the model using pose estimation and refinement loss functions.
4. **Training Loop**: Perform forward pass, compute losses, backpropagate, and update weights.
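The loss in step 3 can be sketched as follows. This is a minimal reconstruction of a per-point, confidence-weighted average-distance loss of the kind described for DenseFusion. Each pixel-wise pose hypothesis is scored by the mean distance between the transformed model points and `target`, and each confidence `c_i` is regularized by a `-w * log(c_i)` term. Several details are assumptions for illustration: the function names, the `(w, x, y, z)` quaternion order, treating `pred_t` as an offset from each observed point, and `w = 0.015`. The nearest-neighbor variant used for symmetric objects is omitted.

```python
import torch


def quaternion_to_matrix(q: torch.Tensor) -> torch.Tensor:
    """Convert (N, 4) quaternions, assumed (w, x, y, z) order, to (N, 3, 3) rotation matrices."""
    q = q / q.norm(dim=1, keepdim=True)  # normalize so every matrix is a proper rotation
    w, x, y, z = q.unbind(dim=1)
    return torch.stack([
        1 - 2 * (y * y + z * z), 2 * (x * y - z * w), 2 * (x * z + y * w),
        2 * (x * y + z * w), 1 - 2 * (x * x + z * z), 2 * (y * z - x * w),
        2 * (x * z - y * w), 2 * (y * z + x * w), 1 - 2 * (x * x + y * y),
    ], dim=1).view(-1, 3, 3)


def confidence_weighted_add_loss(pred_q, pred_t, pred_c, points, model_points, target, w=0.015):
    """
    Hypothetical sketch of a confidence-weighted ADD loss over N per-point hypotheses.

    pred_q: (N, 4) quaternions, pred_t: (N, 3) translation offsets, pred_c: (N,) confidences in (0, 1),
    points: (N, 3) observed cloud points, model_points and target: (M, 3).
    """
    R = quaternion_to_matrix(pred_q)                                      # (N, 3, 3)
    t = points + pred_t                                                   # assumed: offset from each point to the object translation
    pred = torch.einsum("nij,mj->nmi", R, model_points) + t[:, None, :]   # (N, M, 3): R p + t per hypothesis
    dis = (pred - target[None]).norm(dim=2).mean(dim=1)                   # (N,) mean point distance per hypothesis
    return torch.mean(dis * pred_c - w * torch.log(pred_c))
```

The `-w * log(c)` term grows without bound as a confidence approaches zero, so the network cannot simply switch off every hypothesis; it can only lower confidence where the distance term is large.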


#### 3-1. Dataset Preparation

- Input Data

The following input files are required for each data sample:

| **File Type**    | **Description**                                                                                |
|-------------------|------------------------------------------------------------------------------------------------|
| `-color.png`      | RGB image containing objects in the scene.                                                    |
| `-depth.png`      | Depth map representing the distance of each pixel in the scene from the camera.               |
| `-label.png`      | Label map where each pixel indicates the class ID of the object it belongs to.                |
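| `-meta.mat`       | Metadata: class IDs of the objects in the frame (`cls_indexes`), object poses as 3x4 rotation/translation matrices (`poses`), the depth scale factor (`factor_depth`), and intrinsic camera parameters. |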
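The following minimal sketch shows how the `-meta.mat` fields are indexed later in this notebook. It only includes the keys that `PoseDataset` reads (`cls_indexes`, `poses`, `factor_depth`), and all values are made up. The round trip through `scipy.io.savemat`/`loadmat` with an in-memory buffer stands in for a real file.

```python
import io

import numpy as np
import scipy.io as scio

# Hypothetical metadata for a frame containing two objects (class IDs 2 and 5)
fake_meta = {
    "cls_indexes": np.array([[2], [5]], dtype=np.float64),
    "poses": np.zeros((3, 4, 2)),          # one 3x4 [R | t] matrix per object, stacked on the last axis
    "factor_depth": np.array([[10000.0]]),
}
fake_meta["poses"][:, :3, 0] = np.eye(3)
fake_meta["poses"][:, 3, 0] = [0.1, -0.05, 0.8]

buf = io.BytesIO()
scio.savemat(buf, fake_meta)
buf.seek(0)
meta = scio.loadmat(buf)

# Same indexing as PoseDataset.__getitem__
obj = meta["cls_indexes"].flatten().astype(np.int32)   # class IDs in the frame
idx = 0                                                 # first object
target_r = meta["poses"][:, :, idx][:, :3]              # 3x3 rotation
target_t = meta["poses"][:, :, idx][:, 3:4].flatten()   # translation vector
cam_scale = meta["factor_depth"][0][0]                  # divide raw depth by this
```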

---

- Processing Steps

The following steps are applied to prepare the dataset for training and evaluation:

1. **Load Data**  
   Load RGB (`-color.png`), depth (`-depth.png`), label (`-label.png`), and metadata (`-meta.mat`) files for the selected index.

2. **Camera Parameter Selection**  
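   Choose the intrinsic camera parameters (`cx`, `cy`, `fx`, `fy`): real sequences with a video index of 60 or higher use the second camera (`cam_2`); all other frames, including synthetic ones, use the first camera (`cam_1`).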

3. **Background Mask Creation**  
   Create a mask for the background by detecting pixels in the label map with a value of `0`.

4. **Synthetic Noise Augmentation (Optional)**  
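   Pick a random synthetic frame, take two of its objects as occluders, paste their pixels over the current image, and remove the occluded pixels from the label map (the occlusion is kept only if more than 1000 labeled pixels remain).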

5. **Object Selection**  
   Randomly select an object (`obj[idx]`) from the metadata and build its mask (`mask`) by combining the label mask (`mask_label`) with the valid-depth mask; resample until the mask contains more than `minimum_num_pt` pixels.

6. **Bounding Box Calculation**  
   Compute the bounding box for the selected object using `get_bbox(mask_label)` to crop the relevant region.

7. **Crop and Normalize Image**  
   Crop the RGB image based on the bounding box and apply color normalization for input to the model.

8. **3D Point Cloud Generation**  
   Convert depth values within the bounding box to 3D points using the camera parameters (`cx`, `cy`, `fx`, `fy`).

9. **Noise in Point Cloud (Optional)**  
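   Add a random translation, drawn uniformly from `[-noise_trans, noise_trans]` per axis, to the 3D point cloud (the same offset is added to `target`). Gaussian noise is applied only to the pixels of synthetic images.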

10. **Model Point Sampling**  
    Sample a subset of 3D points from the CAD model of the selected object.

11. **Ground-Truth Pose Transformation**  
    Apply the ground-truth pose (rotation matrix `target_r` and translation vector `target_t`) to the sampled model points.

12. **Point Selection**  
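    Select exactly `num_pt` pixel indices (`choose`) from the object mask: a random subset if there are more valid pixels, or wrap-around padding (repeated indices) if there are fewer. These must be the same pixels that are back-projected in step 8, so that each point in `cloud` lines up with its pixel.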
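Steps 8 and 12 can be sketched together on toy arrays. `sample_and_backproject` is a hypothetical helper, not part of the original code. It picks `num_pt` pixel indices inside the bounding box using the same rule as `PoseDataset` (random subset, or wrap-around padding) and back-projects them with the pinhole camera model. Drawing the subset with `rng.choice` is an assumed, equivalent replacement for the shuffled-mask trick used in the dataset class.

```python
import numpy as np


def sample_and_backproject(depth, mask, bbox, cam, cam_scale, num_pt, rng):
    """Hypothetical helper: choose num_pt pixels in the box and back-project them to 3D."""
    rmin, rmax, cmin, cmax = bbox
    # Indices into the flattened crop where the object mask (label AND valid depth) is set
    choose = mask[rmin:rmax, cmin:cmax].flatten().nonzero()[0]
    if len(choose) > num_pt:
        # Random subset without replacement, kept in raster order
        choose = np.sort(rng.choice(choose, num_pt, replace=False))
    else:
        # Too few pixels: repeat the list cyclically until it has num_pt entries
        choose = np.pad(choose, (0, num_pt - len(choose)), mode="wrap")

    rows, cols = np.indices(depth.shape)  # rows = v (like xmap), cols = u (like ymap)
    z = depth[rmin:rmax, cmin:cmax].flatten()[choose] / cam_scale
    u = cols[rmin:rmax, cmin:cmax].flatten()[choose]
    v = rows[rmin:rmax, cmin:cmax].flatten()[choose]
    x = (u - cam["cx"]) * z / cam["fx"]  # pinhole model: X = (u - cx) * Z / fx
    y = (v - cam["cy"]) * z / cam["fy"]  #                Y = (v - cy) * Z / fy
    return np.stack((x, y, z), axis=-1), choose


# Toy example: a 4x4 depth map with a single valid object pixel at row 1, column 2
depth = np.full((4, 4), 2000)
mask = np.zeros((4, 4), dtype=bool)
mask[1, 2] = True
cam = {"cx": 1.0, "cy": 1.0, "fx": 2.0, "fy": 2.0}
cloud, choose = sample_and_backproject(depth, mask, (0, 4, 0, 4), cam, 1000.0, 3, np.random.default_rng(0))
# choose == [6, 6, 6]; every row of cloud is [1.0, 0.0, 2.0]
```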

---

- Output Data

The processed dataset outputs the following information:

| **Name**         | **Shape**            | **Description**                                                                                  |
|-------------------|----------------------|--------------------------------------------------------------------------------------------------|
| `cloud`          | `(num_pt, 3)`        | 3D points back-projected from the object's depth pixels, in camera coordinates.                  |
| `choose`         | `(num_pt,)`          | Indices (into the flattened bounding-box crop) of the pixels used to build `cloud`.              |
| `img_masked`     | `(3, H, W)`          | Cropped, normalized RGB image of the object's bounding box (`H`, `W` set by `get_bbox`).         |
| `target`         | `(num_pt_mesh, 3)`   | Model points transformed by the ground-truth pose.                                               |
| `model_points`   | `(num_pt_mesh, 3)`   | 3D model points sampled from the object's CAD model (size depends on `refine` setting).          |
| `class_idx`      | `(1,)`               | Class ID of the selected object (zero-indexed).                                                 |
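The relation between `model_points` and `target` is a single rigid transform. The following minimal sketch uses made-up numbers: the CAD model is a random point set, and the pose is a 90° rotation about the camera Z axis followed by a translation. Sampling without replacement via `rng.choice` is an assumed stand-in for the `np.delete`-based sampling in `PoseDataset`. The transform itself mirrors `np.dot(model_points, target_r.T) + target_t`.

```python
import numpy as np

rng = np.random.default_rng(0)
cad_points = rng.normal(size=(3000, 3))   # stand-in for a points.xyz model
num_pt_mesh = 500                         # num_pt_mesh_small (refine=False)

# Keep a random subset of model points (order preserved, like np.delete of the rest)
keep = np.sort(rng.choice(len(cad_points), num_pt_mesh, replace=False))
model_points = cad_points[keep]           # (500, 3)

# Hypothetical ground-truth pose: 90 degrees about Z, then a translation
target_r = np.array([[0.0, -1.0, 0.0],
                     [1.0,  0.0, 0.0],
                     [0.0,  0.0, 1.0]])
target_t = np.array([0.1, 0.0, 0.9])

# Row-vector form of p' = R p + t, applied to every model point at once
target = model_points @ target_r.T + target_t   # (500, 3)
```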


In [None]:
from pathlib import Path
from PIL import Image
import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision import transforms
import numpy.ma as ma
import random
import scipy.io as scio

def get_bbox(label: np.ndarray, img_width: int=480, img_length: int=640) -> tuple[int, int, int, int]:
    """
    Calculate the bounding box for the non-zero region of a label mask.

    This function determines the minimal and maximal rows and columns that enclose 
    the non-zero values in the label array. It then enlarges the box so that its height
    and width round up to the next size in a fixed list (multiples of 40 px), keeps it
    centered on the object, and shifts it so that it stays within the image boundaries.

    Args:
        label (np.ndarray): A 2D array representing the label mask where non-zero values 
                            indicate the region of interest.
        img_width (int): Number of image rows (the image height, despite the name).
        img_length (int): Number of image columns (the image width).

    Returns:
        tuple[int, int, int, int]: The adjusted bounding box as (rmin, rmax, cmin, cmax),
                                   representing the top, bottom, left, and right boundaries
                                   (rmax and cmax are exclusive).
    """
    # Candidate crop sizes: each box dimension is rounded up to the next value in this list
    border_list = [-1, 40, 80, 120, 160, 200, 240, 280, 320, 360, 400, 440, 480, 520, 560, 600, 640, 680]
    rows, cols = np.any(label, axis=1), np.any(label, axis=0)
    rmin, rmax = np.where(rows)[0][[0, -1]]
    cmin, cmax = np.where(cols)[0][[0, -1]]
    rmax, cmax = rmax + 1, cmax + 1 # Make the upper bounds exclusive for slicing

    # Current box height and width
    r_b, c_b = rmax - rmin, cmax - cmin
    # Expand each dimension to the next larger size in border_list
    for tt in range(len(border_list)):
        if r_b > border_list[tt] and r_b < border_list[tt + 1]:
            r_b = border_list[tt + 1]
            break
    for tt in range(len(border_list)):
        if c_b > border_list[tt] and c_b < border_list[tt + 1]:
            c_b = border_list[tt + 1]
            break
    
    # Recompute rmin, rmax, cmin, cmax around the box center
    center = [(rmin + rmax) // 2, (cmin + cmax) // 2]
    rmin, rmax = center[0] - int(r_b // 2), center[0] + int(r_b // 2)
    cmin, cmax = center[1] - int(c_b // 2), center[1] + int(c_b // 2)

    # Shift the box back inside the image if it extends past a border
    if rmin < 0:
        delta = -rmin
        rmin, rmax = 0, rmax + delta
    if cmin < 0:
        delta = -cmin
        cmin, cmax = 0, cmax + delta
    if rmax > img_width:
        delta = rmax - img_width
        rmin, rmax = max(0, rmin - delta), img_width
    if cmax > img_length:
        delta = cmax - img_length
        cmin, cmax = max(0, cmin - delta), img_length

    return rmin, rmax, cmin, cmax


class PoseDataset(Dataset):
    def __init__(self, mode: str, num_pt: int, add_noise: bool, root: str, noise_trans: float, refine: bool):
        """
        Initialize the PoseDataset class.

        Args:
            mode (str): The mode of the dataset, either 'train' or 'test'.
            num_pt (int): The number of points to sample for each object.
            add_noise (bool): Whether to add noise to the image and point cloud.
            root (str): Root directory of the dataset.
            noise_trans (float): Maximum per-axis magnitude of the random translation added to the point cloud.
            refine (bool): If True, sample num_pt_mesh_large model points instead of num_pt_mesh_small
                           (used for the refinement stage).
        """
        self.mode = mode 
        self.num_pt = num_pt 
        self.add_noise = add_noise 
        self.noise_trans = noise_trans
        self.refine = refine
        self.root = Path(root)

        # Load dataset list
        data_list_path = Path(f"datasets/ycb/dataset_config/{mode}_data_list.txt")
        self.list = data_list_path.read_text().splitlines()
        self.real = [line for line in self.list if line.startswith("data/")]
        self.syn = [line for line in self.list if line.startswith("data_syn/")]

        # Load 3D model points for each class
        class_file_path = Path("datasets/ycb/dataset_config/classes.txt")
        self.cld = self._load_classes(class_file_path)

        # Camera intrinsic parameters for two cameras
        self.cam_params = {
            "cam_1": {"cx": 312.9869, "cy": 241.3109, "fx": 1066.778, "fy": 1067.487},
            "cam_2": {"cx": 323.7872, "cy": 279.6921, "fx": 1077.836, "fy": 1078.189},
        }

        # Image transformations
        self.trancolor = transforms.ColorJitter(0.2, 0.2, 0.2, 0.05)
        self.norm = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

        # Constants for symmetry, point sampling, and minimum points
        self.symmetry_obj_idx = [12, 15, 18, 19, 20] # Objects with rotational symmetry (0-based class indices)
        self.num_pt_mesh_small = 500 # Number of model points sampled when refine=False
        self.num_pt_mesh_large = 2600 # Number of model points sampled when refine=True
        self.minimum_num_pt = 50 # Minimum number of valid points in an object

        # Precomputed 480 x 640 pixel-coordinate grids: xmap holds each pixel's row index (v),
        # ymap its column index (u). They are used to back-project depth into 3D point-cloud coordinates.
        self.xmap = np.array([[j for i in range(640)] for j in range(480)])
        self.ymap = np.array([[i for i in range(640)] for j in range(480)])

        # Number of objects taken from a synthetic frame as occluders when adding synthetic noise
        self.front_num = 2

    def _load_classes(self, class_file_path: Path):
        cld = {}
        with class_file_path.open() as class_file:
            for class_id, class_name in enumerate(class_file, start=1):
                points_path = self.root / f"models/{class_name.strip()}/points.xyz"  # strip() also handles a missing final newline
                points = np.loadtxt(points_path)
                cld[class_id] = points
        return cld

    def __getitem__(self, index):
        '''
        Returns:
            tuple: A tuple containing the following processed data:
                - cloud (torch.Tensor): A (num_pt, 3) tensor of 3D point cloud coordinates.
                    Represents the sampled 3D points of the object in the scene.
                - choose (torch.Tensor): A (num_pt,) tensor of selected point indices
                    from the 2D mask, ensuring a consistent number of points.
                - img_masked (torch.Tensor): A (3, H, W) tensor of the cropped and normalized RGB image
                    containing the object of interest.
                - target (torch.Tensor): A (num_pt_mesh_small or num_pt_mesh_large, 3) tensor of the
                    model points transformed by the ground-truth pose.
                - model_points (torch.Tensor): A (num_pt_mesh_small or num_pt_mesh_large, 3) tensor
                    of the object's 3D model points, sampled from the CAD model.
                - obj[idx]-1 (torch.Tensor): A single-element tensor containing the class index (0-based)
                    of the selected object in the current sample.
        '''
        # Load RGB image, depth map, Label map and metadata
        img_path = self.root / f"{self.list[index]}-color.png"
        depth_path = self.root / f"{self.list[index]}-depth.png"
        label_path = self.root / f"{self.list[index]}-label.png"
        meta_path = self.root / f"{self.list[index]}-meta.mat"

        img = Image.open(img_path).convert("RGB")
        depth = np.array(Image.open(depth_path))
        label = np.array(Image.open(label_path))
        meta = scio.loadmat(meta_path)

        # Select camera parameters based on the data type
        cam_params = (
            self.cam_params["cam_2"]
            if self.list[index][:8] != "data_syn" and int(self.list[index][5:9]) >= 60
            else self.cam_params["cam_1"]
        )
        cam_cx, cam_cy, cam_fx, cam_fy = cam_params.values()

        # Background mask: True where the label is 0 (no object)
        mask_back = ma.getmaskarray(ma.masked_equal(label, 0))

        # Add synthetic noise by overlaying random objects
        add_front = False
        if self.add_noise:
            for _ in range(5): # Try up to 5 random synthetic frames
                seed = random.choice(self.syn) # Pick a random frame from the synthetic dataset
                front = np.array(self.trancolor(Image.open(f"{self.root}/{seed}-color.png").convert("RGB")))
                front = np.transpose(front, (2, 0, 1))
                f_label = np.array(Image.open(f"{self.root}/{seed}-label.png"))
                front_label = np.unique(f_label).tolist()[1:] # Object IDs in the frame, excluding 0 (background)
                if len(front_label) < self.front_num: # Skip frames with fewer than front_num (2) objects
                    continue
                front_label = random.sample(front_label, self.front_num) # Pick front_num (2) objects as occluders
                for f_i in front_label: # Build a mask that is False on the chosen occluders
                    mk = ma.getmaskarray(ma.masked_not_equal(f_label, f_i))
                    if f_i == front_label[0]:
                        mask_front = mk
                    else:
                        mask_front = mask_front * mk # AND the masks: True only where neither occluder is present
                # Remove occluded pixels from the label map; keep the result if enough labeled pixels remain
                t_label = label * mask_front
                if len(t_label.nonzero()[0]) > 1000:
                    label = t_label
                    add_front = True
                    break

        # Retrieve the class IDs of all objects in the scene
        obj = meta['cls_indexes'].flatten().astype(np.int32)

        # Randomly select a valid object and compute its mask
        while True:
            idx = np.random.randint(0, len(obj))
            mask_depth = ma.getmaskarray(ma.masked_not_equal(depth, 0))
            mask_label = ma.getmaskarray(ma.masked_equal(label, obj[idx]))
            mask = mask_label * mask_depth
            if len(mask.nonzero()[0]) > self.minimum_num_pt:
                break

        # Apply color jitter if noise is enabled
        if self.add_noise:
            img = self.trancolor(img)

        # Compute the bounding box for the selected object and crop the image to it
        rmin, rmax, cmin, cmax = get_bbox(mask_label)
        img = np.transpose(np.array(img)[:, :, :3], (2, 0, 1))[:, rmin:rmax, cmin:cmax]

        # Synthesize the background if using synthetic data
        if self.list[index][:8] == "data_syn":
            seed = random.choice(self.real)
            back = np.array(self.trancolor(Image.open(f"{self.root}/{seed}-color.png").convert("RGB")))
            back = np.transpose(back, (2, 0, 1))[:, rmin:rmax, cmin:cmax] # Crop the background image to the current object's bbox
            img_masked = back * mask_back[rmin:rmax, cmin:cmax] + img
        else:
            img_masked = img


        if self.add_noise and add_front:
            # Keep the current crop where there is no occluder and paste the occluder pixels from the synthetic frame (similar to CutMix)
            img_masked = img_masked * mask_front[rmin:rmax, cmin:cmax] + front[:, rmin:rmax, cmin:cmax] * ~(mask_front[rmin:rmax, cmin:cmax])

        if self.list[index][:8] == "data_syn":
            # Add Gaussian pixel noise to synthetic images
            img_masked = img_masked + np.random.normal(loc=0.0, scale=7.0, size=img_masked.shape)

        # Select exactly num_pt valid pixel indices (into the flattened crop):
        # a random subset if there are too many, wrap-around padding if there are too few
        choose = mask[rmin:rmax, cmin:cmax].flatten().nonzero()[0]
        if len(choose) > self.num_pt:
            c_mask = np.zeros(len(choose), dtype=int)
            c_mask[:self.num_pt] = 1
            np.random.shuffle(c_mask)
            choose = choose[c_mask.nonzero()]
        else:
            choose = np.pad(choose, (0, self.num_pt - len(choose)), 'wrap')

        # Extract depth and pixel coordinates for the chosen pixels only
        # (indices must refer to the cropped region, not the full-image mask)
        depth_masked = depth[rmin:rmax, cmin:cmax].flatten()[choose]
        xmap_masked = self.xmap[rmin:rmax, cmin:cmax].flatten()[choose]  # row index (v)
        ymap_masked = self.ymap[rmin:rmax, cmin:cmax].flatten()[choose]  # column index (u)

        # Back-project to a 3D point cloud in camera coordinates (pt0: X, pt1: Y, pt2: Z)
        # Z = depth / factor_depth, X = (u - cx) * Z / fx, Y = (v - cy) * Z / fy
        cam_scale = meta['factor_depth'][0][0]
        pt2 = depth_masked / cam_scale
        pt0 = (ymap_masked - cam_cx) * pt2 / cam_fx
        pt1 = (xmap_masked - cam_cy) * pt2 / cam_fy
        cloud = np.stack((pt0, pt1, pt2), axis=-1)  # (num_pt, 3)

        # Add a random global translation to the point cloud
        if self.add_noise:
            add_t = np.random.uniform(-self.noise_trans, self.noise_trans, 3)
            cloud += add_t

        # Subsample the CAD model points (the full point sets can be large):
        # keep num_pt_mesh_large points if refine is set, otherwise num_pt_mesh_small
        dellist = [j for j in range(0, len(self.cld[obj[idx]]))]
        if self.refine:
            dellist = random.sample(dellist, len(self.cld[obj[idx]]) - self.num_pt_mesh_large)
        else:
            dellist = random.sample(dellist, len(self.cld[obj[idx]]) - self.num_pt_mesh_small)
        model_points = np.delete(self.cld[obj[idx]], dellist, axis=0)

        # Apply the ground-truth pose (3x4 matrix: target_r is the 3x3 rotation, target_t the translation)
        target_r = meta['poses'][:, :, idx][:, :3]
        target_t = meta['poses'][:, :, idx][:, 3:4].flatten()
        target = np.dot(model_points, target_r.T)
        if self.add_noise:
            # Shift the target by the same offset that was added to the cloud
            target = np.add(target, target_t + add_t)
        else:
            target = np.add(target, target_t)

        return (
            torch.tensor(cloud.astype(np.float32)),
            torch.tensor(choose.astype(np.int32), dtype=torch.long),
            self.norm(torch.tensor(img_masked.astype(np.float32))),
            torch.tensor(target.astype(np.float32)),
            torch.tensor(model_points.astype(np.float32)),
            torch.tensor([obj[idx] - 1], dtype=torch.long),
        )
        
    def __len__(self):
        return len(self.list)

    def get_sym_list(self):
        return self.symmetry_obj_idx

    def get_num_points_mesh(self):
        return self.num_pt_mesh_large if self.refine else self.num_pt_mesh_small
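The `border_list` loop in `get_bbox` rounds each box dimension up to a multiple of 40 px. For the sizes that can occur in a 480 x 640 image, it appears to be equivalent to the closed form below. `snap_size_loop` and `snap_size` are hypothetical helpers written only to illustrate this behavior; neither is a function in the original code.

```python
border_list = [-1, 40, 80, 120, 160, 200, 240, 280, 320, 360, 400, 440, 480, 520, 560, 600, 640, 680]


def snap_size_loop(size: int) -> int:
    """Same rounding rule as the loop inside get_bbox."""
    for tt in range(len(border_list)):
        if size > border_list[tt] and size < border_list[tt + 1]:
            return border_list[tt + 1]
    return size  # sizes already equal to a listed value are kept


def snap_size(size: int) -> int:
    """Closed form: round up to a multiple of 40, minimum 40 (valid for 0 <= size <= 680)."""
    return max(40, -(-size // 40) * 40)


# Sizes exactly on a multiple of 40 are kept; anything else grows to the next multiple
for s in (0, 1, 39, 40, 41, 479, 480, 640):
    assert snap_size(s) == snap_size_loop(s)
```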


#### 3-2. Model Initialization

PoseNet is the primary network designed to estimate the pose of objects in 3D space. It uses a ResNet-based PSPNet as a backbone for feature extraction and processes the combined features from RGB images and point cloud data.

1. **PSPNet**: A pixel-wise semantic segmentation network, used here as the image backbone that produces per-pixel color embeddings.
2. **PoseNet**: The primary network for pose estimation.
3. **PoseRefineNet**: An optional refinement network that improves the initial pose predictions.


In [None]:
# PSPNet 
import torch
from torch import nn
from torch.nn import functional as F
import lib.extractors as extractors  # Module that provides backbones such as ResNet

# PSPModule: Pyramid Scene Parsing Module
class PSPModule(nn.Module):
    def __init__(self, features, out_features=1024, sizes=(1, 2, 3, 6)):
        super(PSPModule, self).__init__()
        # Create a stage for each pooling size
        self.stages = nn.ModuleList([self._make_stage(features, size) for size in sizes])
        # Bottleneck layer to combine all features (4 * stages + original)
        self.bottleneck = nn.Conv2d(features * (len(sizes) + 1), out_features, kernel_size=1)
        self.relu = nn.ReLU()

    def _make_stage(self, features, size):
        prior = nn.AdaptiveAvgPool2d(output_size=(size, size))  # Adaptive average pooling (keeps channels, shrinks H and W to size x size)
        conv = nn.Conv2d(features, features, kernel_size=1, bias=False)  # 1x1 convolution
        return nn.Sequential(prior, conv)

    def forward(self, feats):
        h, w = feats.size(2), feats.size(3)  # Get the height and width of the input
        # Apply each stage and upsample to the original size
        priors = [F.interpolate(input=stage(feats), size=(h, w), mode='bilinear', align_corners=False) for stage in self.stages] + [feats]
        # Concatenate original features and pyramid features
        bottle = self.bottleneck(torch.cat(priors, 1))
        return self.relu(bottle)  # Apply ReLU activation
    
# PSPUpsample: Upsampling Module for PSPNet
class PSPUpsample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(PSPUpsample, self).__init__()
        self.conv = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),  # Upsampling by a factor of 2
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),  # 3x3 convolution
            nn.PReLU()  # Parametric ReLU activation
        )

    def forward(self, x):
        return self.conv(x)


# PSPNet: Pyramid Scene Parsing Network
class PSPNet(nn.Module):
    """
    Implements the Pyramid Scene Parsing Network (PSPNet).
    Combines a feature extractor, a PSPModule, and an upsampling module
    to perform pixel-wise classification (segmentation).
    """
    def __init__(self, n_classes=21, sizes=(1, 2, 3, 6), psp_size=2048, deep_features_size=1024,
                    backend='resnet18', pretrained=False):
        """
        Args:
            n_classes: Number of output classes.
            sizes: Pooling sizes for the PSPModule.
            psp_size: Number of channels in the PSPModule's input.
            deep_features_size: Number of channels for the deep feature classifier.
            backend: Backbone model (e.g., 'resnet18', 'resnet50').
            pretrained: Whether to use pretrained weights for the backbone.
        """
        super(PSPNet, self).__init__()
        # Load the feature extractor backend (e.g., resnet18) from extractors
        self.feats = getattr(extractors, backend)(pretrained)
        # PSP module for multi-scale feature aggregation
        self.psp = PSPModule(psp_size, 1024, sizes)
        self.drop_1 = nn.Dropout2d(p=0.3)  # Dropout after the PSP module

        # Upsampling modules for progressively refining the feature map
        self.up_1 = PSPUpsample(1024, 256)
        self.up_2 = PSPUpsample(256, 64)
        self.up_3 = PSPUpsample(64, 64)

        self.drop_2 = nn.Dropout2d(p=0.15)  # Dropout during upsampling
        # Final 1x1 convolution (64 -> 32 channels) followed by log-softmax over channels;
        # PoseNet uses this 32-channel per-pixel output as its color embedding
        self.final = nn.Sequential(
            nn.Conv2d(64, 32, kernel_size=1),  # Reduce to 32 channels
            nn.LogSoftmax(dim=1)  # Logarithm of softmax for multi-class probabilities
        )

        # Optional classifier for deep features (not used in the main segmentation pipeline)
        self.classifier = nn.Sequential(
            nn.Linear(deep_features_size, 256),  # Fully connected layer
            nn.ReLU(),  # Activation
            nn.Linear(256, n_classes)  # Output layer for classification
        )

    def forward(self, x):
        # Extract features using the backbone
        f, class_f = self.feats(x) 
        # Apply the PSP module
        p = self.psp(f)
        p = self.drop_1(p)  # Apply dropout

        # Upsample and refine the feature map
        p = self.up_1(p)
        p = self.drop_2(p)
        p = self.up_2(p)
        p = self.drop_2(p)
        p = self.up_3(p)

        # Compute final pixel-wise class probabilities
        return self.final(p)
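To make `PSPModule`'s effect on tensor shapes concrete, the sketch below runs its pooling, upsampling, and concatenation steps on a random feature map. It omits the per-stage 1x1 convolutions and the bottleneck, and the tensor sizes are arbitrary choices for illustration.

```python
import torch
from torch import nn
from torch.nn import functional as F

torch.manual_seed(0)
feats = torch.randn(2, 8, 12, 16)   # (batch, channels, H, W), arbitrary sizes
sizes = (1, 2, 3, 6)                # same pyramid levels as PSPModule

# Pool the map to 1x1, 2x2, 3x3 and 6x6 grids (channels unchanged)
pooled = [nn.AdaptiveAvgPool2d(output_size=(s, s))(feats) for s in sizes]
# Upsample every level back to H x W so the maps can be stacked channel-wise
upsampled = [F.interpolate(p, size=(12, 16), mode="bilinear", align_corners=False) for p in pooled]
# Concatenate the pyramid levels with the original features: 8 * (len(sizes) + 1) channels
fused = torch.cat(upsampled + [feats], dim=1)
```

The resulting `8 * 5 = 40` channels correspond to `features * (len(sizes) + 1)`, the input width of `PSPModule`'s bottleneck convolution.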

**PoseNet**
1. **Input Data**:
   - RGB image (`img`): Shape `(batch_size, 3, height, width)`
   - Point cloud (`x`): Shape `(batch_size, num_points, 3)`, transposed inside `forward` to `(batch_size, 3, num_points)` for `Conv1d`.
   - Point-pixel mapping (`choose`): Maps point cloud to corresponding RGB pixels.
   - Object indices (`obj`): Category indices for the objects.

2. **Semantic Segmentation Features**:
   - Pass the RGB image through `ModifiedResnet` (PSPNet) to obtain a per-pixel feature map.
   - Output shape: `(batch_size, di, height, width)`, where `di` = 32 is the embedding (channel) dimension.

3. **Point-Pixel Feature Mapping**:
   - Flatten the feature map and gather the embeddings (`emb`) of the pixels indexed by `choose`.
   - `emb`: Shape `(batch_size, di, num_points)`.

4. **Point Cloud and RGB Fusion**:
   - Pass the point cloud (`x`) and gathered RGB features (`emb`) to `PoseNetFeat` for fusion.
   - `PoseNetFeat` extracts a dense, fused feature map: Shape `(batch_size, 1408, num_points)`.

5. **Pose Prediction**:
   - Pass the fused feature map through fully connected layers to predict:
     - **Rotation (`rx`)**: Quaternion, shape `(batch_size, num_objects, 4, num_points)`
     - **Translation (`tx`)**: 3D vector, shape `(batch_size, num_objects, 3, num_points)`
     - **Confidence (`cx`)**: Scalar, shape `(batch_size, num_objects, 1, num_points)`

6. **Object-Specific Output**:
   - Select the outputs (`rx`, `tx`, `cx`) for the target object using the object indices (`obj`). Only the first batch element is used (`b = 0`), so the network assumes a batch size of 1.

7. **Return Outputs**:
   - Return rotation, translation, confidence, and the detached per-point RGB embeddings (`emb`), which `PoseRefineNet` takes as input.
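Steps 3 and 4 form the core of the pixel-wise ("dense") fusion. The toy sketch below replaces the real CNN output and point features with random tensors. Only the `view`/`gather` indexing mirrors `PoseNet.forward`; the channel sizes (128 + 256 + 1024 = 1408) follow `PoseNetFeat`. Reshaping `choose` to `(bs, 1, num_points)` before `repeat` is an assumption made so that the indexing also works for a `(bs, num_points)` index tensor.

```python
import torch

torch.manual_seed(0)
bs, di, H, W, num_points = 1, 32, 8, 10, 5
out_img = torch.randn(bs, di, H, W)              # stand-in for the PSPNet feature map
choose = torch.tensor([[0, 7, 13, 42, 79]])      # (bs, num_points) flat pixel indices in the crop

# Step 3: pick the color embedding of the pixel behind each 3D point
emb = out_img.view(bs, di, -1)                   # (bs, di, H*W)
index = choose.view(bs, 1, -1).repeat(1, di, 1)  # (bs, di, num_points)
emb = torch.gather(emb, 2, index)                # (bs, di, num_points)

# Step 4 (shapes only): per-point features at two depths plus a repeated global feature
pointfeat_1 = torch.randn(bs, 128, num_points)   # 64 geometric + 64 color channels
pointfeat_2 = torch.randn(bs, 256, num_points)   # 128 geometric + 128 color channels
global_feat = torch.randn(bs, 1024, num_points).mean(dim=2, keepdim=True).repeat(1, 1, num_points)
fused = torch.cat([pointfeat_1, pointfeat_2, global_feat], dim=1)  # (bs, 1408, num_points)
```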

In [None]:
# PoseNet
import torch
import torch.nn as nn
import torch.nn.functional as F
from lib.pspnet import PSPNet

# PSPNet model constructors (image backbones that produce per-pixel features)
psp_models = {
    'resnet18': lambda: PSPNet(sizes=(1, 2, 3, 6), psp_size=512, deep_features_size=256, backend='resnet18'),
    'resnet34': lambda: PSPNet(sizes=(1, 2, 3, 6), psp_size=512, deep_features_size=256, backend='resnet34'),
    'resnet50': lambda: PSPNet(sizes=(1, 2, 3, 6), psp_size=2048, deep_features_size=1024, backend='resnet50'),
    'resnet101': lambda: PSPNet(sizes=(1, 2, 3, 6), psp_size=2048, deep_features_size=1024, backend='resnet101'),
    'resnet152': lambda: PSPNet(sizes=(1, 2, 3, 6), psp_size=2048, deep_features_size=1024, backend='resnet152')
}

# Image backbone: PSPNet producing a per-pixel feature map (semantic-segmentation style)
class ModifiedResnet(nn.Module):
    def __init__(self, usegpu=True):
        super().__init__()
        # Use PSPNet with resnet18 backbone
        self.model = psp_models['resnet18'.lower()]()
        self.model = nn.DataParallel(self.model)  # Multi-GPU support

    def forward(self, x):
        # Forward pass through the backbone
        return self.model(x)

# Feature Extraction for PoseNet
class PoseNetFeat(nn.Module):
    def __init__(self, num_points):
        """
        Feature extraction module for PoseNet. 
        This module fuses RGB features (emb) and point cloud features (x) at a dense, per-point level.
        
        Args:
            num_points (int): The number of points in the input point cloud.
        """
        super().__init__()
        self.num_points = num_points
        # Per-point feature extractors for the point cloud and the color embedding
        self.conv1 = nn.Conv1d(3, 64, 1)  # 3D point cloud -> 64 channels
        self.conv2 = nn.Conv1d(64, 128, 1)  # 64 -> 128 channels
        self.e_conv1 = nn.Conv1d(32, 64, 1)  # Image embedding -> 64 channels
        self.e_conv2 = nn.Conv1d(64, 128, 1)  # 64 -> 128 channels
        # Layers for the combined features
        self.conv5 = nn.Conv1d(256, 512, 1)  # Combined features -> 512 channels
        self.conv6 = nn.Conv1d(512, 1024, 1)  # 512 -> 1024 channels
        # Global feature via average pooling over all points
        self.ap1 = nn.AvgPool1d(num_points)

    def forward(self, x, emb):
        x = F.relu(self.conv1(x)) # Process point cloud, Shape: (batch_size, 64, num_points)
        emb = F.relu(self.e_conv1(emb)) # Process RGB embedding features, Shape: (batch_size, 64, num_points)
        pointfeat_1 = torch.cat((x, emb), dim=1) # Fuse point cloud and RGB features, Shape: (batch_size, 128, num_points)

        # Second layer of feature extraction
        x = F.relu(self.conv2(x))
        emb = F.relu(self.e_conv2(emb))
        pointfeat_2 = torch.cat((x, emb), dim=1) # Shape: (batch_size, 256, num_points)

        # Process combined features
        x = F.relu(self.conv5(pointfeat_2))
        x = F.relu(self.conv6(x))

        # Global feature by average pooling
        ap_x = self.ap1(x).view(-1, 1024, 1).repeat(1, 1, self.num_points)  # Shape: (batch_size, 1024, num_points)

        return torch.cat([pointfeat_1, pointfeat_2, ap_x], 1)  # Shape: (batch_size, 128+256+1024, num_points)

# PoseNet: Main Pose Estimation Network
class PoseNet(nn.Module):
    def __init__(self, num_points, num_obj):
        """
        Main pose estimation network that predicts rotation, translation, and confidence for objects.

        Args:
            num_points (int): The number of points in the point cloud.
            num_obj (int): The number of object categories.
        """
        super().__init__()
        self.num_points = num_points
        self.num_obj = num_obj

        # Submodules: Backbone and Feature Extractor
        self.cnn = ModifiedResnet()  # ResNet-based feature extractor (pixel-wise segmentation)
        self.feat = PoseNetFeat(num_points)  # Point cloud feature extractor

        # Per-point fully connected layers (1x1 Conv1d) for rotation, translation, and confidence estimation
        self.conv1_r = nn.Conv1d(1408, 640, 1) # For rotation
        self.conv1_t = nn.Conv1d(1408, 640, 1) # For translation
        self.conv1_c = nn.Conv1d(1408, 640, 1) # For confidence
        self.conv2_r = nn.Conv1d(640, 256, 1)
        self.conv2_t = nn.Conv1d(640, 256, 1)
        self.conv2_c = nn.Conv1d(640, 256, 1)
        self.conv3_r = nn.Conv1d(256, 128, 1)
        self.conv3_t = nn.Conv1d(256, 128, 1)
        self.conv3_c = nn.Conv1d(256, 128, 1)
        self.conv4_r = nn.Conv1d(128, num_obj * 4, 1)  # Outputs quaternions
        self.conv4_t = nn.Conv1d(128, num_obj * 3, 1)  # Outputs translations
        self.conv4_c = nn.Conv1d(128, num_obj * 1, 1)  # Outputs confidence scores

    def forward(self, img, x, choose, obj):
        # Generate semantic segmentation-like feature map from the input image
        out_img = self.cnn(img)
        bs, di, _, _ = out_img.size()  # Batch size, embedding dim (32), height, width

        # Embed image features into the point cloud space
        emb = out_img.view(bs, di, -1) # Flatten spatial dimensions
        choose = choose.view(bs, 1, -1).repeat(1, di, 1)  # (bs, di, num_points) indices for gather
        emb = torch.gather(emb, 2, choose).contiguous() # Gather RGB features for selected points

        # Process point cloud data
        x = x.transpose(2, 1).contiguous()  # Transpose to match Conv1d input
        ap_x = self.feat(x, emb) # Fused feature map

        # Compute rotation, translation, and confidence
        rx = F.relu(self.conv1_r(ap_x))
        tx = F.relu(self.conv1_t(ap_x))
        cx = F.relu(self.conv1_c(ap_x))

        rx = F.relu(self.conv2_r(rx))
        tx = F.relu(self.conv2_t(tx))
        cx = F.relu(self.conv2_c(cx))

        rx = F.relu(self.conv3_r(rx))
        tx = F.relu(self.conv3_t(tx))
        cx = F.relu(self.conv3_c(cx))

        # Final outputs reshaped for the given number of objects and points
        rx = self.conv4_r(rx).view(bs, self.num_obj, 4, self.num_points)
        tx = self.conv4_t(tx).view(bs, self.num_obj, 3, self.num_points)
        cx = torch.sigmoid(self.conv4_c(cx)).view(bs, self.num_obj, 1, self.num_points)

        # Select outputs for the given object indices
        b = 0 # Only the first batch element is used (assumes batch size 1)
        out_rx = torch.index_select(rx[b], 0, obj[b]) # Rotation
        out_tx = torch.index_select(tx[b], 0, obj[b]) # Translation
        out_cx = torch.index_select(cx[b], 0, obj[b]) # Confidence

        # Transpose for consistency
        out_rx = out_rx.contiguous().transpose(2, 1).contiguous()
        out_cx = out_cx.contiguous().transpose(2, 1).contiguous()
        out_tx = out_tx.contiguous().transpose(2, 1).contiguous()

        return out_rx, out_tx, out_cx, emb.detach()
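The embedding step in `forward` picks, for each sampled point, the CNN feature of the pixel that point came from. The standalone sketch below uses made-up sizes (all values are hypothetical). It shows that `torch.gather` over the flattened feature map, with `choose` repeated across channels, gives the same result as plain advanced indexing:

```python
import torch

# Hypothetical sizes: batch 1, 32 feature channels, 4x5 feature map, 6 sampled points
bs, di, h, w, num_points = 1, 32, 4, 5, 6
out_img = torch.randn(bs, di, h, w)                     # per-pixel CNN features
choose = torch.randint(0, h * w, (bs, 1, num_points))   # flat pixel index of each sampled point

emb = out_img.view(bs, di, -1)                          # (bs, di, h*w)
emb = torch.gather(emb, 2, choose.repeat(1, di, 1))     # (bs, di, num_points)

# Same result with advanced indexing: column k holds the features of pixel choose[0, 0, k]
reference = out_img.view(bs, di, -1)[:, :, choose[0, 0]]
print(emb.shape)  # torch.Size([1, 32, 6])
```

Because `choose` indexes the flattened crop, each column of `emb` lines up with the same column of the point cloud `x`. This alignment is what makes the later per-point fusion pixel-wise.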


**PoseRefineNet**
1. **Input Data**:
   - Point cloud (`x`): Shape `(batch_size, num_points, 3)` (transposed to `(batch_size, 3, num_points)` inside `forward`).
   - Embedded RGB features (`emb`): Shape `(batch_size, 32, num_points)`.
   - Object indices (`obj`): Indices of objects to refine.

2. **Point Cloud and RGB Feature Fusion**:
   - Pass the point cloud and embedded RGB features into `PoseRefineNetFeat`.
   - Extract a global, fused feature vector of shape `(batch_size, 1024)` using average pooling.

3. **Refinement Layers**:
   - Pass the fused feature vector through fully connected layers to refine:
     - **Rotation (`rx`)**: Shape `(batch_size, num_objects, 4)` (quaternion).
     - **Translation (`tx`)**: Shape `(batch_size, num_objects, 3)` (3D vector).

4. **Output**:
   - Return refined rotation (`rx`) and translation (`tx`) predictions for each object.
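The global feature in step 2 comes from an `nn.AvgPool1d` whose kernel is as wide as the point cloud. The sketch below uses hypothetical sizes to show that this pooling is simply the mean over points:

```python
import torch
import torch.nn as nn

# Hypothetical sizes for illustration
batch_size, channels, num_points = 2, 1024, 500
x = torch.randn(batch_size, channels, num_points)  # per-point fused features

# A kernel as wide as the point cloud turns AvgPool1d into a global average over points
ap1 = nn.AvgPool1d(num_points)
ap_x = ap1(x).view(-1, channels)                   # (batch_size, 1024)

print(ap_x.shape)  # torch.Size([2, 1024])
```

The kernel size is fixed when the module is constructed, so every input must contain exactly `num_points` points. This is why the dataset pads or subsamples each object to that count.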

In [None]:
# PoseRefineNet
# Feature Extraction for PoseRefineNet
class PoseRefineNetFeat(nn.Module):
    def __init__(self, num_points):
        """
        Extracts global features from fused point cloud and RGB data.

        Args:
            num_points (int): The number of points in the input point cloud.
        """
        super().__init__()
        self.num_points = num_points
        # Convolution layers for point cloud features
        self.conv1 = nn.Conv1d(3, 64, 1)  # (3 -> 64)
        self.conv2 = nn.Conv1d(64, 128, 1)  # (64 -> 128)

        # Convolution layers for RGB embedding features
        self.e_conv1 = nn.Conv1d(32, 64, 1)  # (32 -> 64)
        self.e_conv2 = nn.Conv1d(64, 128, 1)  # (64 -> 128)

        # Convolution layers for combined features
        self.conv5 = nn.Conv1d(384, 512, 1)  # (128+256 -> 512)
        self.conv6 = nn.Conv1d(512, 1024, 1)  # (512 -> 1024)

        # Global average pooling
        self.ap1 = nn.AvgPool1d(num_points)

    def forward(self, x, emb):
        # Process point cloud features
        x = F.relu(self.conv1(x))  # Shape: (batch_size, 64, num_points)
        # Process RGB embedding features
        emb = F.relu(self.e_conv1(emb))  # Shape: (batch_size, 64, num_points)
        # Fuse point cloud and RGB features
        pointfeat_1 = torch.cat([x, emb], dim=1)  # Shape: (batch_size, 128, num_points)

        # Second layer of feature extraction
        x = F.relu(self.conv2(x))  # Shape: (batch_size, 128, num_points)
        emb = F.relu(self.e_conv2(emb))  # Shape: (batch_size, 128, num_points)
        # Fuse again
        pointfeat_2 = torch.cat([x, emb], dim=1)  # Shape: (batch_size, 256, num_points)

        # Combine features and extract global features
        pointfeat_3 = torch.cat([pointfeat_1, pointfeat_2], dim=1)  # Shape: (batch_size, 384, num_points)
        x = F.relu(self.conv5(pointfeat_3))  # Shape: (batch_size, 512, num_points)
        x = F.relu(self.conv6(x))  # Shape: (batch_size, 1024, num_points)

        # Global average pooling
        ap_x = self.ap1(x).view(-1, 1024)  # Shape: (batch_size, 1024)

        return ap_x

# PoseRefineNet: Refinement Network
class PoseRefineNet(nn.Module):
    def __init__(self, num_points, num_obj):
        super().__init__()
        self.num_points = num_points
        self.num_obj = num_obj
        self.feat = PoseRefineNetFeat(num_points)

        # Fully connected layers for refining rotation
        self.conv1_r = nn.Linear(1024, 512)
        self.conv2_r = nn.Linear(512, 128)
        self.conv3_r = nn.Linear(128, num_obj * 4)  # Outputs refined quaternions

        # Fully connected layers for refining translation
        self.conv1_t = nn.Linear(1024, 512)
        self.conv2_t = nn.Linear(512, 128)
        self.conv3_t = nn.Linear(128, num_obj * 3)  # Outputs refined translations

    def forward(self, x, emb, obj):
        bs = x.size(0)  # Batch size

        # Transpose point cloud for Conv1d processing
        x = x.transpose(2, 1).contiguous()  # Shape: (batch_size, 3, num_points)
        # Fuse per-point geometry and RGB features, then pool them into one global feature
        ap_x = self.feat(x, emb)  # Shape: (batch_size, 1024)

        # Refine rotation
        rx = F.relu(self.conv1_r(ap_x))  # Shape: (batch_size, 512)
        rx = F.relu(self.conv2_r(rx))  # Shape: (batch_size, 128)
        rx = self.conv3_r(rx).view(bs, self.num_obj, 4)  # Shape: (batch_size, num_objects, 4)

        # Refine translation
        tx = F.relu(self.conv1_t(ap_x))  # Shape: (batch_size, 512)
        tx = F.relu(self.conv2_t(tx))  # Shape: (batch_size, 128)
        tx = self.conv3_t(tx).view(bs, self.num_obj, 3)  # Shape: (batch_size, num_objects, 3)

        # Select the prediction for the requested object (first sample only, as in PoseNet)
        b = 0
        out_rx = torch.index_select(rx[b], 0, obj[b])  # Shape: (1, 4)
        out_tx = torch.index_select(tx[b], 0, obj[b])  # Shape: (1, 3)

        return out_rx, out_tx



#### 3-3. Optimization (with loss)

**Optimizer**

In [None]:
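import torch
import torch.optim as optim

# YCB-Video settings (the training script sets these per dataset in main())
num_points, num_objects, lr, resume_refinenet = 1000, 21, 0.0001, ''
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
estimator = PoseNet(num_points=num_points, num_obj=num_objects).to(device)
refiner = PoseRefineNet(num_points=num_points, num_obj=num_objects).to(device)

# Only one network is optimized at a time: PoseNet first, PoseRefineNet once refinement training starts
optimizer = optim.Adam(refiner.parameters() if resume_refinenet else estimator.parameters(), lr=lr)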

**Loss**
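The objective computed by `loss_calculation` below can be written out as follows. This is our transcription of the code, not a formula quoted from elsewhere. Here $N$ is `num_p` (predicted points) and $M$ is `num_point_mesh` (model points). $x_j$ are the `model_points` and $y_j$ the `target` points. $p_i$ is input point $i$ (`points`), $\hat{R}_i$ the rotation from `pred_r[i]`, $\hat{t}_i$ the translation offset `pred_t[i]`, $c_i$ the confidence `pred_c[i]`, and $w$ the argument `w`:

$$
L_i = \frac{1}{M}\sum_{j=1}^{M}\left\lVert \hat{R}_i x_j + p_i + \hat{t}_i - y_j \right\rVert_2,
\qquad
L = \frac{1}{N}\sum_{i=1}^{N}\left( L_i\, c_i - w \log c_i \right)
$$

For symmetric objects (and only outside refinement mode), each transformed model point is compared with its nearest target point instead of its corresponding one:

$$
L_i^{\mathrm{sym}} = \frac{1}{M}\sum_{j=1}^{M}\min_{1 \le k \le M}\left\lVert \hat{R}_i x_j + p_i + \hat{t}_i - y_k \right\rVert_2
$$

The code adds $10^{-8}$ inside the logarithm as a numerical guard. The $-w \log c_i$ term keeps the network from driving all confidences to zero.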

In [None]:
from torch.nn.modules.loss import _Loss
import torch
import torch.nn.functional as F
from lib.knn.__init__ import KNearestNeighbor

def loss_calculation(pred_r, pred_t, pred_c, target, model_points, idx, points, w, refine, num_point_mesh, sym_list):

    knn = KNearestNeighbor(1)  # KNN module for finding the closest target point (symmetric objects)
    bs, num_p, _ = pred_c.size()

    # Normalize predicted rotation to a unit quaternion
    pred_r = pred_r / torch.norm(pred_r, dim=2, keepdim=True)

    # Convert quaternions to 3x3 rotation matrix
    # https://gofo-coding.tistory.com/entry/Orientation-Rotation
    # Quaternion: [q0 (scalar), q1, q2, q3 (vector parts)]
    # Rotation equation: v' = q * v * q^*
    #   - v: Input vector (extended to a pure quaternion, v = [0, x, y, z])
    #   - q^*: Conjugate of quaternion q
    # Result: v' = R * v, where R is the rotation matrix derived from q
    
    # Rotation matrix R:
    # R = [[1 - 2(q2^2 + q3^2), 2(q1q2 - q0q3),   2(q0q2 + q1q3)],
    #      [2(q1q2 + q0q3),     1 - 2(q1^2 + q3^2), 2(q2q3 - q0q1)],
    #      [2(q1q3 - q0q2),     2(q0q1 + q2q3),   1 - 2(q1^2 + q2^2)]]
    base = torch.cat([
        (1.0 - 2.0 * (pred_r[..., 2]**2 + pred_r[..., 3]**2)).unsqueeze(-1),
        (2.0 * pred_r[..., 1] * pred_r[..., 2] - 2.0 * pred_r[..., 0] * pred_r[..., 3]).unsqueeze(-1),
        (2.0 * pred_r[..., 0] * pred_r[..., 2] + 2.0 * pred_r[..., 1] * pred_r[..., 3]).unsqueeze(-1),
        (2.0 * pred_r[..., 1] * pred_r[..., 2] + 2.0 * pred_r[..., 0] * pred_r[..., 3]).unsqueeze(-1),
        (1.0 - 2.0 * (pred_r[..., 1]**2 + pred_r[..., 3]**2)).unsqueeze(-1),
        (-2.0 * pred_r[..., 0] * pred_r[..., 1] + 2.0 * pred_r[..., 2] * pred_r[..., 3]).unsqueeze(-1),
        (-2.0 * pred_r[..., 0] * pred_r[..., 2] + 2.0 * pred_r[..., 1] * pred_r[..., 3]).unsqueeze(-1),
        (2.0 * pred_r[..., 0] * pred_r[..., 1] + 2.0 * pred_r[..., 2] * pred_r[..., 3]).unsqueeze(-1),
        (1.0 - 2.0 * (pred_r[..., 1]**2 + pred_r[..., 2]**2)).unsqueeze(-1)
    ], dim=-1).view(bs * num_p, 3, 3)  # Shape (bs, num_p, 9) -> (bs * num_p, 3, 3)

    # Prepare inputs for loss calculation (repeat model_points and target once per predicted point)
    # model_points: 3D points sampled from the CAD model (object frame)
    model_points = model_points.view(bs, 1, num_point_mesh, 3).expand(-1, num_p, -1, -1).reshape(bs * num_p, num_point_mesh, 3)
    # target: ground truth, i.e. the model points transformed by the true rotation and translation
    target = target.view(bs, 1, num_point_mesh, 3).expand(-1, num_p, -1, -1).reshape(bs * num_p, num_point_mesh, 3)
    pred_t = pred_t.view(bs * num_p, 1, 3)  # Predicted translation (offset from each input point)
    points = points.view(bs * num_p, 1, 3)  # Input point cloud
    pred_c = pred_c.view(bs * num_p)  # Predicted confidence

    # Apply the rotation and translation to model points (base: predicted rotation)
    # torch.bmm: (bs*num_p,num_point_mesh,3) x (bs*num_p,3,3) = (bs*num_p,num_point_mesh,3)
    pred = torch.bmm(model_points, base.transpose(2, 1)) + points + pred_t

    # Keep the unmatched target for the frame update at the end
    ori_target = target

    # Handle symmetric objects
    if not refine and idx[0].item() in sym_list:
        # 1. Flatten target and prediction into (3, N) point sets
        target = target[0].permute(1, 0).reshape(3, -1)  # (num_point_mesh, 3) -> (3, num_point_mesh)
        pred = pred.permute(2, 0, 1).reshape(3, -1)  # (bs*num_p, num_point_mesh, 3) -> (3, bs*num_p*num_point_mesh)

        # 2. Find the closest target point for every predicted point
        # knn expects (batch, dim, num_points) inputs and returns 1-based indices
        inds = knn(target.unsqueeze(0), pred.unsqueeze(0)).view(-1) - 1

        # 3. Reorder the target so each predicted point is compared with its nearest neighbor
        target = target[:, inds].reshape(3, bs * num_p, num_point_mesh).permute(1, 2, 0)  # Shape (bs*num_p, num_point_mesh, 3)
        pred = pred.reshape(3, bs * num_p, num_point_mesh).permute(1, 2, 0)  # Shape (bs*num_p, num_point_mesh, 3)

    # Calculate loss: confidence-weighted distance plus a log-confidence regularizer
    dis = torch.norm(pred - target, dim=2).mean(dim=1)  # (bs * num_p)
    loss = torch.mean(dis * pred_c - w * torch.log(pred_c + 1e-8))

    # Select the best matching prediction (the pipeline assumes one sample per forward pass, bs = 1)
    pred_c = pred_c.view(bs, num_p)
    dis = dis.view(bs, num_p)
    _, which_max = torch.max(pred_c, dim=1)  # Index of the most confident point in each sample

    t = pred_t[which_max[0]] + points[which_max[0]]  # (1, 3): translation predicted at the most confident point
    ori_base = base[which_max[0]].unsqueeze(0)  # (1, 3, 3): rotation predicted at the most confident point

    # Transform points and target into the frame of this estimate (origin at t, axes of ori_base).
    # For row vectors, (p - t) @ R computes R^T (p - t), i.e. the inverse of the predicted pose.
    points = points.view(1, bs * num_p, 3)
    new_points = torch.bmm(points - t.unsqueeze(1), ori_base)
    new_target = torch.bmm(ori_target[0].unsqueeze(0) - t.unsqueeze(1), ori_base)

    return loss, dis[0][which_max[0]], new_points.detach(), new_target.detach()

class Loss(_Loss):
    def __init__(self, num_points_mesh, sym_list):
        super().__init__(reduction='mean')
        self.num_pt_mesh = num_points_mesh
        self.sym_list = sym_list

    def forward(self, pred_r, pred_t, pred_c, target, model_points, idx, points, w, refine):
        return loss_calculation(pred_r, pred_t, pred_c, target, model_points, idx, points, w, refine, self.num_pt_mesh, self.sym_list)
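The rotation-matrix comment in `loss_calculation` states that `base` rotates a vector exactly like the quaternion sandwich $v' = q v q^*$. The NumPy sketch below checks this numerically for one random unit quaternion. `quat_to_matrix` and `quat_mul` are hypothetical helpers written only for this check, and both use the scalar-first layout `[q0, q1, q2, q3]`:

```python
import numpy as np

def quat_to_matrix(q):
    """Rotation matrix from a unit quaternion q = [q0, q1, q2, q3] (scalar first), same layout as `base`."""
    q0, q1, q2, q3 = q
    return np.array([
        [1 - 2 * (q2**2 + q3**2), 2 * (q1 * q2 - q0 * q3), 2 * (q0 * q2 + q1 * q3)],
        [2 * (q1 * q2 + q0 * q3), 1 - 2 * (q1**2 + q3**2), 2 * (q2 * q3 - q0 * q1)],
        [2 * (q1 * q3 - q0 * q2), 2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1**2 + q2**2)],
    ])

def quat_mul(a, b):
    """Hamilton product of two quaternions (scalar first)."""
    a0, a1, a2, a3 = a
    b0, b1, b2, b3 = b
    return np.array([
        a0 * b0 - a1 * b1 - a2 * b2 - a3 * b3,
        a0 * b1 + a1 * b0 + a2 * b3 - a3 * b2,
        a0 * b2 - a1 * b3 + a2 * b0 + a3 * b1,
        a0 * b3 + a1 * b2 - a2 * b1 + a3 * b0,
    ])

rng = np.random.default_rng(0)
q = rng.normal(size=4)
q /= np.linalg.norm(q)                      # normalize, as loss_calculation does
v = rng.normal(size=3)

R = quat_to_matrix(q)
q_conj = q * np.array([1.0, -1.0, -1.0, -1.0])
v_rot = quat_mul(quat_mul(q, np.concatenate(([0.0], v))), q_conj)[1:]  # v' = q v q*

print(np.allclose(R @ v, v_rot))                                        # True
print(np.allclose(R.T @ R, np.eye(3)), np.isclose(np.linalg.det(R), 1.0))  # True True
```

The normalization step matters. For a non-unit quaternion, the same formula no longer yields an orthonormal matrix.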


**Loss_refine**
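Both loss functions finish by expressing the input points and the target in the coordinate frame of the current estimate. As a result, `PoseRefineNet` only has to predict a residual pose. With points stored as rows, `(p - t) @ R` computes $R^\top (p - t)$, which is the inverse of the pose $(R, t)$. The NumPy check below uses a random pose; all values are hypothetical:

```python
import numpy as np

rng = np.random.default_rng(1)

# Random rotation via QR decomposition (sign-fixed so det(R) = +1) and a random translation
Q, _ = np.linalg.qr(rng.normal(size=(3, 3)))
R = Q * np.sign(np.linalg.det(Q))
t = rng.normal(size=3)

model_points = rng.normal(size=(100, 3))   # points in the object frame
observed = model_points @ R.T + t          # same points seen by the camera: R x + t

# Refinement-frame update used in the losses: (p - t) @ R, i.e. R^T (p - t) per point
recovered = (observed - t) @ R

print(np.allclose(recovered, model_points))  # True
```

If the current estimate were exact, the transformed target would coincide with the model points. The ideal residual pose for the refiner would then be the identity.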

In [None]:
from torch.nn.modules.loss import _Loss
import torch
from lib.knn.__init__ import KNearestNeighbor

# Renamed from `loss_calculation` so it does not overwrite the PoseNet loss defined in the previous cell
def loss_calculation_refine(pred_r, pred_t, target, model_points, idx, points, num_point_mesh, sym_list):
    knn = KNearestNeighbor(1)

    # The refiner predicts a single pose per sample: reshape to (1, 1, 4) and (1, 1, 3)
    pred_r = pred_r.view(1, 1, -1)
    pred_t = pred_t.view(1, 1, -1)
    num_p = pred_r.size(1)

    # Normalize predicted rotation to a unit quaternion
    pred_r = pred_r / torch.norm(pred_r, dim=2, keepdim=True)

    # Construct rotation matrices from quaternions (same formula as in Loss)
    base = torch.cat([
        (1.0 - 2.0 * (pred_r[..., 2]**2 + pred_r[..., 3]**2)).unsqueeze(-1),
        (2.0 * pred_r[..., 1] * pred_r[..., 2] - 2.0 * pred_r[..., 0] * pred_r[..., 3]).unsqueeze(-1),
        (2.0 * pred_r[..., 0] * pred_r[..., 2] + 2.0 * pred_r[..., 1] * pred_r[..., 3]).unsqueeze(-1),
        (2.0 * pred_r[..., 1] * pred_r[..., 2] + 2.0 * pred_r[..., 0] * pred_r[..., 3]).unsqueeze(-1),
        (1.0 - 2.0 * (pred_r[..., 1]**2 + pred_r[..., 3]**2)).unsqueeze(-1),
        (-2.0 * pred_r[..., 0] * pred_r[..., 1] + 2.0 * pred_r[..., 2] * pred_r[..., 3]).unsqueeze(-1),
        (-2.0 * pred_r[..., 0] * pred_r[..., 2] + 2.0 * pred_r[..., 1] * pred_r[..., 3]).unsqueeze(-1),
        (2.0 * pred_r[..., 0] * pred_r[..., 1] + 2.0 * pred_r[..., 2] * pred_r[..., 3]).unsqueeze(-1),
        (1.0 - 2.0 * (pred_r[..., 1]**2 + pred_r[..., 2]**2)).unsqueeze(-1)
    ], dim=-1).view(-1, 3, 3)

    # Prepare inputs for loss calculation
    model_points = model_points.view(1, 1, num_point_mesh, 3).expand(-1, num_p, -1, -1).reshape(-1, num_point_mesh, 3)
    target = target.view(1, 1, num_point_mesh, 3).expand(-1, num_p, -1, -1).reshape(-1, num_point_mesh, 3)
    ori_target = target  # Unmatched target, used for the frame update below
    pred_t = pred_t.view(-1, 1, 3)
    pred = torch.bmm(model_points, base.transpose(2, 1)) + pred_t  # No per-point offset in the refiner

    # Handle symmetric objects: compare each predicted point with its nearest target point
    if idx[0].item() in sym_list:
        target = target[0].permute(1, 0).reshape(3, -1)
        pred = pred.permute(2, 0, 1).reshape(3, -1)
        inds = knn(target.unsqueeze(0), pred.unsqueeze(0)).view(-1) - 1  # knn returns 1-based indices
        target = target[:, inds].reshape(3, num_p, num_point_mesh).permute(1, 2, 0)
        pred = pred.reshape(3, num_p, num_point_mesh).permute(1, 2, 0)

    # Calculate distance
    dis = torch.mean(torch.norm(pred - target, dim=2), dim=1)

    # Transform points and target into the frame of the refined estimate: (p - t) @ R = R^T (p - t)
    t = pred_t[0]  # (1, 3)
    ori_base = base[0].unsqueeze(0)  # (1, 3, 3)
    points = points.view(1, -1, 3)
    new_points = torch.bmm(points - t.unsqueeze(1), ori_base)
    new_target = torch.bmm(ori_target[0].unsqueeze(0) - t.unsqueeze(1), ori_base)

    del knn
    return dis, new_points.detach(), new_target.detach()

class Loss_refine(_Loss):
    def __init__(self, num_points_mesh, sym_list):
        super().__init__(reduction='mean')
        self.num_pt_mesh = num_points_mesh
        self.sym_list = sym_list

    def forward(self, pred_r, pred_t, target, model_points, idx, points):
        return loss_calculation_refine(pred_r, pred_t, target, model_points, idx, points, self.num_pt_mesh, self.sym_list)


#### 3-4. Training Loop

In [None]:
import _init_paths
import argparse
import os
import random
import time

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from datasets.ycb.dataset import PoseDataset as PoseDataset_ycb
from datasets.linemod.dataset import PoseDataset as PoseDataset_linemod
from lib.network import PoseNet, PoseRefineNet
from lib.loss import Loss
from lib.loss_refiner import Loss_refine
from lib.utils import setup_logger

def parse_arguments() -> argparse.Namespace:
    """Parse command-line arguments."""
    parser = argparse.ArgumentParser(description="DenseFusion Training")
    parser.add_argument('--dataset', type=str, default='ycb', help='Dataset to use: ycb or linemod')
    parser.add_argument('--dataset_root', type=str, default='data/', help='Root directory of the dataset')
    parser.add_argument('--batch_size', type=int, default=8, help='Number of samples per optimizer step (gradient accumulation)')
    parser.add_argument('--workers', type=int, default=4, help='Number of data loading workers')
    parser.add_argument('--lr', type=float, default=0.0001, help='Initial learning rate')
    parser.add_argument('--lr_rate', type=float, default=0.3, help='Learning rate decay factor')
    parser.add_argument('--w', type=float, default=0.015, help='Weight of the confidence regularization term (w) in the loss')
    parser.add_argument('--w_rate', type=float, default=0.3, help='Decay factor applied to w')
    parser.add_argument('--decay_margin', type=float, default=0.016, help='Test distance below which lr and w are decayed')
    parser.add_argument('--refine_margin', type=float, default=0.013, help='Test distance below which refinement training starts')
    parser.add_argument('--noise_trans', type=float, default=0.03, help='Translation noise range for training data')
    parser.add_argument('--iteration', type=int, default=2, help='Number of refinement iterations')
    parser.add_argument('--epoch', type=int, default=500, help='Number of epochs to train')
    parser.add_argument('--resume_posenet', type=str, default='', help='PoseNet checkpoint to resume from (relative to the output directory)')
    parser.add_argument('--resume_refinenet', type=str, default='', help='PoseRefineNet checkpoint to resume from (relative to the output directory)')
    parser.add_argument('--seed', type=int, default=7, help='Random seed')
    return parser.parse_args()

def initialize_model(arg: argparse.Namespace) -> tuple[PoseNet, PoseRefineNet, optim.Optimizer]:
    """Initialize models and optimizer."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    estimator = PoseNet(num_points=arg.num_points, num_obj=arg.num_objects).to(device)
    refiner = PoseRefineNet(num_points=arg.num_points, num_obj=arg.num_objects).to(device)

    # Load pre-trained weights if specified
    if arg.resume_posenet:
        estimator.load_state_dict(torch.load(os.path.join(arg.outf, arg.resume_posenet), map_location=device))
    if arg.resume_refinenet:
        refiner.load_state_dict(torch.load(os.path.join(arg.outf, arg.resume_refinenet), map_location=device))

    # Setup optimizer
    optimizer = optim.Adam(refiner.parameters() if arg.resume_refinenet else estimator.parameters(), lr=arg.lr)
    return estimator, refiner, optimizer

def load_dataset(arg: argparse.Namespace, phase: str) -> DataLoader:
    """Load the appropriate dataset."""
    match arg.dataset:
        case 'ycb':
            dataset_class = PoseDataset_ycb
        case 'linemod':
            dataset_class = PoseDataset_linemod
        case _:
            raise ValueError(f"Unsupported dataset: {arg.dataset}")

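    # Each forward pass handles one object instance, so the loaders yield one sample at a time;
    # `batch_size` is used for gradient accumulation in the training loop. No noise is added at test time.
    noise_trans = arg.noise_trans if phase == 'train' else 0.0
    dataset = dataset_class(phase, arg.num_points, phase == 'train', arg.dataset_root, noise_trans, arg.refine_start)
    return DataLoader(dataset, batch_size=1, shuffle=(phase == 'train'), num_workers=arg.workers)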

def test_phase(arg: argparse.Namespace, estimator: PoseNet, refiner: PoseRefineNet, test_loader: DataLoader, criterion: Loss, criterion_refine: Loss_refine) -> float:
    """Perform testing and return the average distance."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    estimator.eval()
    refiner.eval()
    total_dis = 0.0
    test_count = 0

    # No gradients are needed during evaluation
    with torch.no_grad():
        for data in test_loader:
            points, choose, img, target, model_points, idx = (x.to(device) for x in data)

            # Forward pass for PoseNet
            pred_r, pred_t, pred_c, emb = estimator(img, points, choose, idx)
            _, dis, new_points, new_target = criterion(pred_r, pred_t, pred_c, target, model_points, idx, points, arg.w, arg.refine_start)

            # Refinement iterations
            if arg.refine_start:
                for _ in range(arg.iteration):
                    pred_r, pred_t = refiner(new_points, emb, idx)
                    dis, new_points, new_target = criterion_refine(pred_r, pred_t, new_target, model_points, idx, new_points)

            total_dis += dis.item()
            test_count += 1

    avg_dis = total_dis / test_count
    print(f"Test complete: Average distance = {avg_dis:.6f}")
    return avg_dis

def main():
    arg = parse_arguments()

    # Seed initialization for reproducibility
    random.seed(arg.seed)
    np.random.seed(arg.seed)
    torch.manual_seed(arg.seed)

    # Dataset-specific parameters
    match arg.dataset:
        case 'ycb':
            arg.num_objects = 21
            arg.num_points = 1000
            arg.outf = 'trained_models/ycb'
            arg.log_dir = 'experiments/logs/ycb'
            arg.repeat_epoch = 1
        case 'linemod':
            arg.num_objects = 13
            arg.num_points = 500
            arg.outf = 'trained_models/linemod'
            arg.log_dir = 'experiments/logs/linemod'
            arg.repeat_epoch = 20
        case _:
            raise ValueError(f"Unknown dataset: {arg.dataset}")

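    # Training-mode flags: resuming a PoseRefineNet checkpoint means refinement training starts right away
    arg.refine_start = bool(arg.resume_refinenet)
    arg.decay_start = False

    # Initialize models, datasets, and optimizer
    estimator, refiner, optimizer = initialize_model(arg)
    train_loader = load_dataset(arg, 'train')
    test_loader = load_dataset(arg, 'test')

    # Symmetric object IDs and the number of sampled mesh points are provided by the dataset
    arg.sym_list = train_loader.dataset.get_sym_list()
    arg.num_points_mesh = train_loader.dataset.get_num_points_mesh()
    criterion = Loss(arg.num_points_mesh, arg.sym_list)
    criterion_refine = Loss_refine(arg.num_points_mesh, arg.sym_list)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    os.makedirs(arg.outf, exist_ok=True)

    # Training loop
    best_test = float('inf')
    start_time_str = time.strftime("%Y%m%d_%H%M%S", time.localtime())  # Start time in readable format

    for epoch in range(arg.epoch):
        # Training phase: PoseNet is frozen once refinement training starts
        if arg.refine_start:
            estimator.eval()
            refiner.train()
        else:
            estimator.train()
        optimizer.zero_grad()
        train_count = 0

        for _ in range(arg.repeat_epoch):
            for data in train_loader:
                # Unpack data and send to device
                points, choose, img, target, model_points, idx = (x.to(device) for x in data)

                # Forward pass
                pred_r, pred_t, pred_c, emb = estimator(img, points, choose, idx)
                loss, dis, new_points, new_target = criterion(pred_r, pred_t, pred_c, target, model_points, idx, points, arg.w, arg.refine_start)

                # Backward pass: gradients accumulate until the next optimizer step
                if arg.refine_start:
                    for _ in range(arg.iteration):
                        pred_r, pred_t = refiner(new_points, emb, idx)
                        dis, new_points, new_target = criterion_refine(pred_r, pred_t, new_target, model_points, idx, new_points)
                        dis.backward()
                else:
                    loss.backward()

                # Update the weights after every `batch_size` samples, then clear the gradients
                train_count += 1
                if train_count % arg.batch_size == 0:
                    optimizer.step()
                    optimizer.zero_grad()

        print(f"Epoch {epoch + 1} finished.")

        if (epoch + 1) % 5 == 0:
            # Testing phase
            avg_dis = test_phase(arg, estimator, refiner, test_loader, criterion, criterion_refine)
            if avg_dis < best_test:
                best_test = avg_dis
                # Save the network that is currently being trained
                model, prefix = (refiner, 'pose_refine_model') if arg.refine_start else (estimator, 'pose_model')
                save_path = f"{arg.outf}/{prefix}_{start_time_str}_avgdis_{avg_dis:.6f}.pth"
                torch.save(model.state_dict(), save_path)
                print(f"New best model saved at {save_path}")

            # Decay the learning rate and the loss weight w once the test distance is below decay_margin
            if best_test < arg.decay_margin and not arg.decay_start:
                arg.decay_start = True
                arg.lr *= arg.lr_rate
                arg.w *= arg.w_rate
                optimizer = optim.Adam(refiner.parameters() if arg.refine_start else estimator.parameters(), lr=arg.lr)

            # Switch to refinement training once the test distance is below refine_margin
            if best_test < arg.refine_margin and not arg.refine_start:
                arg.refine_start = True
                arg.batch_size = max(1, arg.batch_size // arg.iteration)  # each sample now backpropagates `iteration` times
                optimizer = optim.Adam(refiner.parameters(), lr=arg.lr)
                # Reload the datasets in refine mode and rebuild the losses for the new mesh size
                train_loader = load_dataset(arg, 'train')
                test_loader = load_dataset(arg, 'test')
                arg.num_points_mesh = train_loader.dataset.get_num_points_mesh()
                criterion = Loss(arg.num_points_mesh, arg.sym_list)
                criterion_refine = Loss_refine(arg.num_points_mesh, arg.sym_list)

    print("Training and testing complete.")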


## 4. Evaluation Pipeline
The `eval.py` script handles evaluation of the DenseFusion framework. Key steps:

1. **Load Pretrained Model**: Initialize and load weights for `PoseNet` and `PoseRefineNet`.
2. **Dataset Preparation**: Load the YCB test data.
3. **Pose Prediction**: Perform inference to predict 6D object poses.
4. **Metric Calculation**: Evaluate performance using ADD and ADD-S metrics.
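ADD compares each transformed model point with its *corresponding* ground-truth point. ADD-S compares it with the *closest* ground-truth point, so poses that are indistinguishable for a symmetric object are not penalized. The sketch below uses standalone re-implementations of the same formulas as `calculate_add` and `calculate_add_s` (`add_metric` and `add_s_metric` are hypothetical names). It evaluates them on a made-up ring-shaped object that is symmetric under 45° rotations:

```python
import numpy as np

def add_metric(R, t, model_points, target_points):
    """ADD: mean distance between corresponding points under the predicted and true poses."""
    pred = model_points @ R.T + t
    return np.mean(np.linalg.norm(pred - target_points, axis=1))

def add_s_metric(R, t, model_points, target_points):
    """ADD-S: mean distance from each predicted point to its closest true point."""
    pred = model_points @ R.T + t
    dists = np.linalg.norm(pred[:, None, :] - target_points[None, :, :], axis=2)
    return np.mean(dists.min(axis=1))

# Hypothetical symmetric object: 8 points on a 5 cm ring in the xy-plane
angles = np.arange(8) * np.pi / 4
model_points = 0.05 * np.stack([np.cos(angles), np.sin(angles), np.zeros(8)], axis=1)
target_points = model_points.copy()        # ground-truth pose: identity

theta = np.pi / 4                          # prediction rotated by exactly one symmetry step
R = np.array([[np.cos(theta), -np.sin(theta), 0.0],
              [np.sin(theta),  np.cos(theta), 0.0],
              [0.0, 0.0, 1.0]])
t = np.zeros(3)

print(add_metric(R, t, model_points, target_points))    # about 0.0383 (chord between neighboring ring points)
print(add_s_metric(R, t, model_points, target_points))  # about 0
```

Under ADD, the rotated prediction is penalized even though it looks identical to the ground truth. ADD-S scores it as a perfect match, which is why ADD-S is reported for symmetric objects.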


In [None]:
import argparse
import numpy as np
from pathlib import Path
from PIL import Image
import scipy.io as scio
import numpy.ma as ma
import torch
import torch.nn.functional as F
from torchvision import transforms
from datasets.ycb.dataset import PoseDataset
from lib.network import PoseNet, PoseRefineNet
from lib.transformations import quaternion_matrix, quaternion_from_matrix

def parse_arguments():
    """
    Parse command-line arguments for model paths and dataset settings.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset_root', type=str, default='', help='Dataset root directory')
    parser.add_argument('--model', type=str, default='', help='Path to the PoseNet model')
    parser.add_argument('--refine_model', type=str, default='', help='Path to the PoseRefineNet model')
    return parser.parse_args()

def get_bbox(posecnn_rois, idx, img_width, img_length, border_list):
    """
    Compute bounding box for the given ROI (Region of Interest).
    """
    rmin = int(posecnn_rois[idx][3]) + 1
    rmax = int(posecnn_rois[idx][5]) - 1
    cmin = int(posecnn_rois[idx][2]) + 1
    cmax = int(posecnn_rois[idx][4]) - 1

    # Snap the box height and width up to the next size in `border_list`
    r_b, c_b = rmax - rmin, cmax - cmin
    for tt in range(len(border_list) - 1):
        if border_list[tt] < r_b < border_list[tt + 1]:
            r_b = border_list[tt + 1]
            break
    for tt in range(len(border_list) - 1):
        if border_list[tt] < c_b < border_list[tt + 1]:
            c_b = border_list[tt + 1]
            break

    # Center-based adjustment to bounding box
    center = [int((rmin + rmax) / 2), int((cmin + cmax) / 2)]
    rmin, rmax = center[0] - int(r_b / 2), center[0] + int(r_b / 2)
    cmin, cmax = center[1] - int(c_b / 2), center[1] + int(c_b / 2)

    # Ensure bounding box stays within image boundaries
    rmin, rmax = max(rmin, 0), min(rmax, img_width)
    cmin, cmax = max(cmin, 0), min(cmax, img_length)

    return rmin, rmax, cmin, cmax

def initialize_models(arg, num_points, num_obj):
    """
    Initialize PoseNet and PoseRefineNet models with pre-trained weights.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    estimator = PoseNet(num_points=num_points, num_obj=num_obj).to(device)
    estimator.load_state_dict(torch.load(arg.model, map_location=device))
    estimator.eval()

    refiner = PoseRefineNet(num_points=num_points, num_obj=num_obj).to(device)
    refiner.load_state_dict(torch.load(arg.refine_model, map_location=device))
    refiner.eval()

    return estimator, refiner

def prepare_data(img, depth, label, posecnn_rois, idx, itemid, xmap, ymap, cam_params, num_points, border_list):
    """
    Prepare point cloud and masked RGB image for pose estimation.
    """
    # Camera intrinsics
    cam_cx, cam_cy, cam_fx, cam_fy, cam_scale = cam_params['cx'], cam_params['cy'], cam_params['fx'], cam_params['fy'], cam_params['scale']

    # Get bounding box coordinates
    rmin, rmax, cmin, cmax = get_bbox(posecnn_rois, idx, 480, 640, border_list)

    # Mask processing
    mask_depth = ma.getmaskarray(ma.masked_not_equal(depth, 0))
    mask_label = ma.getmaskarray(ma.masked_equal(label, itemid))
    mask = mask_label * mask_depth

    # Select valid points
    choose = mask[rmin:rmax, cmin:cmax].flatten().nonzero()[0]
    if len(choose) > num_points:
        np.random.shuffle(choose)
        choose = choose[:num_points]
    else:
        choose = np.pad(choose, (0, num_points - len(choose)), 'wrap')

    # Compute 3D points from depth and camera intrinsics
    depth_masked = depth[rmin:rmax, cmin:cmax].flatten()[choose][:, None].astype(np.float32)
    xmap_masked = xmap[rmin:rmax, cmin:cmax].flatten()[choose][:, None].astype(np.float32)
    ymap_masked = ymap[rmin:rmax, cmin:cmax].flatten()[choose][:, None].astype(np.float32)

    # Back-project pixels to 3D camera coordinates with the pinhole model (xmap: row index, ymap: column index)
    pt2 = depth_masked / cam_scale
    pt0 = (ymap_masked - cam_cx) * pt2 / cam_fx
    pt1 = (xmap_masked - cam_cy) * pt2 / cam_fy
    cloud = np.concatenate((pt0, pt1, pt2), axis=1)

    # Process RGB image for the ROI
    img_masked = np.array(img)[:, :, :3]
    img_masked = np.transpose(img_masked, (2, 0, 1))[:, rmin:rmax, cmin:cmax]

    # Convert to tensors
    cloud = torch.tensor(cloud, dtype=torch.float32).unsqueeze(0).to(cam_params['device'])
    choose = torch.tensor(choose, dtype=torch.int64).unsqueeze(0).to(cam_params['device'])
    img_masked = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(
        torch.tensor(img_masked, dtype=torch.float32).unsqueeze(0)
    ).to(cam_params['device'])

    return cloud, choose, img_masked

def refine_pose(refiner, cloud, emb, index, my_r, my_t, iterations=2):
    """
    Iteratively refine a pose with PoseRefineNet.

    my_r: current rotation as a unit quaternion [w, x, y, z] (numpy, shape (4,))
    my_t: current translation (numpy, shape (3,))
    Each iteration expresses the cloud in the frame of the current estimate, predicts a
    residual pose, and composes it with the current estimate.
    """
    device = cloud.device

    for _ in range(iterations):
        # Current estimate as a 4x4 homogeneous matrix
        my_mat = quaternion_matrix(my_r)
        my_mat[:3, 3] = my_t
        R = torch.tensor(my_mat[:3, :3], dtype=torch.float32, device=device).unsqueeze(0)  # (1, 3, 3)
        T = torch.tensor(my_t, dtype=torch.float32, device=device).view(1, 1, 3)

        # Express the cloud in the current pose frame: (p - t) @ R = R^T (p - t)
        new_cloud = torch.bmm(cloud - T, R)

        # Predict the residual pose
        with torch.no_grad():
            pred_r, pred_t = refiner(new_cloud, emb, index)
        pred_r = pred_r.view(-1)
        pred_r = pred_r / torch.norm(pred_r)  # Normalize quaternion
        my_mat_2 = quaternion_matrix(pred_r.cpu().numpy())
        my_mat_2[:3, 3] = pred_t.view(-1).cpu().numpy()

        # Compose the current estimate with the residual and convert back to (quaternion, translation)
        my_mat_final = np.dot(my_mat, my_mat_2)
        my_r_final = my_mat_final.copy()
        my_r_final[:3, 3] = 0
        my_r = quaternion_from_matrix(my_r_final, True)
        my_t = my_mat_final[:3, 3].copy()

    return my_r, my_t

def _load_classes(class_file_path, dataset_root):
    """
    Load object model points from .xyz files.
    """
    cld = {}
    with class_file_path.open() as class_file:
        for class_id, class_name in enumerate(class_file, start=1):
            points_path = Path(dataset_root) / f"models/{class_name.strip()}/points.xyz"
            points = np.loadtxt(points_path)
            cld[class_id] = points
    return cld

def calculate_add(pred_r, pred_t, model_points, target_points):
    """
    Calculate Average Distance of Model Points (ADD) metric.
    """
    pred_points = np.dot(model_points, pred_r.T) + pred_t
    return np.mean(np.linalg.norm(pred_points - target_points, axis=1))

def calculate_add_s(pred_r, pred_t, model_points, target_points):
    """
    Calculate ADD-S (symmetric) metric: mean distance from each predicted point to the closest target point.
    """
    # Transform model points with the predicted pose
    pred_points = np.dot(model_points, pred_r.T) + pred_t

    # Pairwise distances computed with PyTorch on the CPU
    pred_points = torch.as_tensor(pred_points, dtype=torch.float64)
    target_points = torch.as_tensor(target_points, dtype=torch.float64)
    distances = torch.cdist(pred_points.unsqueeze(0), target_points.unsqueeze(0), p=2)
    min_distances = distances.min(dim=2)[0]
    return min_distances.mean().item()

def main():
    arg = parse_arguments()

    # Camera parameters and constants
    cam_params = {
        'cx': 312.9869,
        'cy': 241.3109,
        'fx': 1066.778,
        'fy': 1067.487,
        'scale': 10000.0,
        'device': torch.device("cuda" if torch.cuda.is_available() else "cpu")
    }
    num_obj, num_points, iterations = 21, 1000, 2
    border_list = [-1, 40, 80, 120, 160, 200, 240, 280, 320, 360, 400, 440, 480, 520, 560, 600, 640, 680]
    device = cam_params['device']

    # Initialize models
    estimator, refiner = initialize_models(arg, num_points, num_obj)

    # Test frame list, object model points, and pixel-coordinate grids (xmap: row index, ymap: column index)
    config_dir = Path('datasets/ycb/dataset_config')
    testlist = (config_dir / 'test_data_list.txt').read_text().splitlines()
    cld = _load_classes(config_dir / 'classes.txt', arg.dataset_root)
    xmap = np.array([[j for i in range(640)] for j in range(480)])
    ymap = np.array([[i for i in range(640)] for j in range(480)])
    Path('experiments/results').mkdir(parents=True, exist_ok=True)
    add_scores, add_s_scores = [], []

    # Iterate over test frames
    for now in range(len(testlist)):
        try:
            # Load data
            img = Image.open(f'{arg.dataset_root}/{testlist[now]}-color.png')
            depth = np.array(Image.open(f'{arg.dataset_root}/{testlist[now]}-depth.png'))
            meta = scio.loadmat(f'{arg.dataset_root}/{testlist[now]}-meta.mat')  # Ground-truth poses
            posecnn_meta = scio.loadmat(f'YCB_Video_toolbox/results_PoseCNN_RSS2018/{now:06d}.mat')
            posecnn_rois = np.array(posecnn_meta['rois'])
            label = np.array(posecnn_meta['labels'])
            lst = posecnn_rois[:, 1:2].flatten()  # Object IDs detected by PoseCNN

            # Store results for each object
            my_result_wo_refine = []
            my_result = []

            for idx, itemid in enumerate(lst):  # For each detected object
                itemid = int(itemid)
                index = torch.tensor([itemid - 1], dtype=torch.int64, device=device)

                cloud, choose, img_masked = prepare_data(
                    img, depth, label, posecnn_rois, idx, itemid,
                    xmap, ymap, cam_params, num_points, border_list
                )

                # Initial pose estimation
                with torch.no_grad():
                    pred_r, pred_t, pred_c, emb = estimator(img_masked, cloud, choose, index)
                pred_r = pred_r / torch.norm(pred_r, dim=2, keepdim=True)

                # Keep the pose of the most confident point; PoseNet predicts translation as an offset from each point
                which_max = torch.argmax(pred_c.view(-1))
                my_r = pred_r[0, which_max].cpu().numpy()
                my_t = (cloud[0, which_max] + pred_t[0, which_max]).cpu().numpy()
                my_result_wo_refine.append(np.append(my_r, my_t).tolist())

                # Refinement
                my_r, my_t = refine_pose(refiner, cloud, emb, index, my_r, my_t, iterations)
                my_result.append(np.append(my_r, my_t).tolist())

                # ADD / ADD-S against the ground-truth pose of the same object (if annotated in this frame)
                gt_idx = np.where(meta['cls_indexes'].flatten() == itemid)[0]
                if len(gt_idx) > 0:
                    gt_pose = meta['poses'][:, :, gt_idx[0]]
                    model_points = cld[itemid]
                    target_points = np.dot(model_points, gt_pose[:, :3].T) + gt_pose[:, 3]
                    pred_r_matrix = quaternion_matrix(my_r)[:3, :3]
                    add_scores.append(calculate_add(pred_r_matrix, my_t, model_points, target_points))
                    add_s_scores.append(calculate_add_s(pred_r_matrix, my_t, model_points, target_points))

            # Save results
            scio.savemat(f'experiments/results/{now:06d}_wo_refine.mat', {'poses': my_result_wo_refine})
            scio.savemat(f'experiments/results/{now:06d}_refine.mat', {'poses': my_result})

        except Exception as e:
            print(f"Error processing frame {now}: {e}")

    if add_scores:
        print(f"Mean ADD: {np.mean(add_scores):.4f} m, mean ADD-S: {np.mean(add_s_scores):.4f} m")
