In [1]:
# import 
import os
os.environ['CUDA_VISIBLE_DEVICES'] = "1"
import torch
import torch.nn as nn
import torch.nn.functional as F
import functools
import spconv.pytorch as spconv
from typing import List, Tuple, Union, Optional, Dict, Any
from torch.autograd import Variable
from dataclasses import dataclass, fields
import numpy as np
import copy
from epic_ops.reduce import segmented_reduce
from epic_ops.voxelize import voxelize
from epic_ops.ball_query import ball_query
from epic_ops.ccl import connected_components_labeling
from epic_ops.nms import nms
from epic_ops.reduce import segmented_maxpool
from gapartnet.misc.info import get_symmetry_matrix


In [2]:
# Data class
@dataclass
class Segmentation:
    """Container for a batch's semantic-segmentation output.

    Attributes (in constructor order):
        batch_size: number of point clouds in the batch.
        sem_preds: semantic predictions (shape/encoding not fixed by this class).
        sem_labels: optional ground-truth semantic labels.
        all_accu: optional accuracy tensor.
        pixel_accu: optional scalar accuracy.
    """

    # required
    batch_size: int
    sem_preds: torch.Tensor

    # optional ground truth / metrics (None when not available)
    sem_labels: Optional[torch.Tensor] = None
    all_accu: Optional[torch.Tensor] = None
    pixel_accu: Optional[float] = None

@dataclass
class Instances:
    """Container for instance proposals, their predictions and optional ground truth.

    Every field is optional and defaults to ``None`` so callers only populate
    the parts they have.
    """

    # point selection / ordering
    valid_mask: Optional[torch.Tensor] = None
    sorted_indices: Optional[torch.Tensor] = None
    pt_xyz: Optional[torch.Tensor] = None

    # proposal layout
    batch_indices: Optional[torch.Tensor] = None
    proposal_offsets: Optional[torch.Tensor] = None
    proposal_indices: Optional[torch.Tensor] = None
    num_points_per_proposal: Optional[torch.Tensor] = None

    # predictions
    sem_preds: Optional[torch.Tensor] = None
    pt_sem_classes: Optional[torch.Tensor] = None
    score_preds: Optional[torch.Tensor] = None
    npcs_preds: Optional[torch.Tensor] = None

    # ground truth
    sem_labels: Optional[torch.Tensor] = None
    instance_labels: Optional[torch.Tensor] = None
    instance_sem_labels: Optional[torch.Tensor] = None
    num_points_per_instance: Optional[torch.Tensor] = None
    gt_npcs: Optional[torch.Tensor] = None

    # NPCS validity mask
    npcs_valid_mask: Optional[torch.Tensor] = None

    # proposal IoUs
    ious: Optional[torch.Tensor] = None

    # classification predictions / labels
    cls_preds: Optional[torch.Tensor] = None
    cls_labels: Optional[torch.Tensor] = None

    # optional identifier
    name: Optional[str] = None

@dataclass
class Result:
    """Per-cloud output bundle: coordinates, colors and the three prediction tensors.

    All fields are required.
    """

    # input geometry / appearance
    xyz: torch.Tensor
    rgb: torch.Tensor
    # predictions
    sem_preds: torch.Tensor
    ins_preds: torch.Tensor
    npcs_preds: torch.Tensor

@dataclass
class PointCloudBatch:
    """A collated batch of point clouds (produced by ``PointCloud.collate``).

    Per-point tensors from all clouds are concatenated along dim 0, and
    ``batch_indices`` records which cloud each point came from.
    """

    # basic
    pc_ids: List[str]
    points: torch.Tensor
    batch_indices: torch.Tensor
    batch_size: int
    device: Optional[Union[str, torch.device]] = None

    # voxel
    # NOTE: no trailing comma after ``None`` -- ``= None,`` would make the
    # default the tuple ``(None,)``.
    voxel_tensor: Any = None  # e.g. a spconv.SparseConvTensor
    pc_voxel_id: Any = None  # voxel index of every point

    # semantic
    sem_labels: Optional[torch.Tensor] = None

    # instance
    instance_labels: Optional[torch.Tensor] = None
    num_instances: Optional[List[int]] = None
    instance_regions: Optional[torch.Tensor] = None
    num_points_per_instance: Optional[torch.Tensor] = None
    instance_sem_labels: Optional[torch.Tensor] = None

    # npcs
    gt_npcs: Optional[Union[torch.Tensor, np.ndarray]] = None

    # object category label of each cloud. Annotated so it is a real dataclass
    # field, and declared last so the positional order of earlier fields is
    # unchanged.
    obj_cls_labels: Optional[torch.Tensor] = None

@dataclass
class PointCloud:
    """One point cloud with optional labels, instance statistics and voxelization.

    Array fields may be numpy arrays or torch tensors: ``to_tensor`` converts
    numpy arrays, and ``to`` moves tensors to a device. ``collate`` expects the
    ``voxel_*`` and ``pc_voxel_id`` fields to be already filled in.
    """

    pc_id: str

    points: Union[torch.Tensor, np.ndarray]

    # object category label (defaults to -1)
    obj_cat: int = -1

    sem_labels: Optional[Union[torch.Tensor, np.ndarray]] = None
    instance_labels: Optional[Union[torch.Tensor, np.ndarray]] = None

    gt_npcs: Optional[Union[torch.Tensor, np.ndarray]] = None

    # number of instances in this cloud
    num_instances: Optional[int] = None

    # per-instance region statistics: columns 0-3 mean_xyz; 3-6 max_xyz; 6-9 min_xyz
    instance_regions: Optional[Union[torch.Tensor, np.ndarray]] = None

    # number of points in each instance
    num_points_per_instance: Optional[Union[torch.Tensor, np.ndarray]] = None

    # semantic label of each instance
    instance_sem_labels: Optional[torch.Tensor] = None

    voxel_features: Optional[torch.Tensor] = None
    voxel_coords: Optional[torch.Tensor] = None
    voxel_coords_range: Optional[List[int]] = None
    # voxel index of every point; collate leaves negative entries untouched
    pc_voxel_id: Optional[torch.Tensor] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return a shallow ``{field name: value}`` dict of all dataclass fields."""
        return {
            field.name: getattr(self, field.name)
            for field in fields(self)
        }

    def to_tensor(self) -> "PointCloud":
        """Return a new PointCloud with numpy-array fields converted to tensors.

        ``torch.from_numpy`` shares memory with the source arrays.
        """
        return PointCloud(**{
            k: torch.from_numpy(v) if isinstance(v, np.ndarray) else v
            for k, v in self.to_dict().items()
        })  # type: ignore

    def to(self, device: torch.device) -> "PointCloud":
        """Return a new PointCloud with every tensor field moved to ``device``."""
        return PointCloud(**{
            k: v.to(device) if isinstance(v, torch.Tensor) else v
            for k, v in self.to_dict().items()
        })  # type: ignore

    @staticmethod
    def collate(point_clouds: List["PointCloud"]) -> "PointCloudBatch":
        """Merge a list of already-voxelized point clouds into one batch.

        Per-point tensors are concatenated along dim 0 and tagged with
        ``batch_indices``. The per-cloud voxels are merged into a single
        ``spconv.SparseConvTensor`` whose coordinates carry the cloud index in
        column 0. Optional fields are included only if the FIRST cloud has
        them; all clouds are assumed to be consistent.

        The input point clouds are not modified.

        Args:
            point_clouds: point clouds with ``voxel_features``, ``voxel_coords``,
                ``voxel_coords_range`` and ``pc_voxel_id`` set.

        Returns:
            PointCloudBatch with all per-cloud data merged. Its
            ``obj_cls_labels`` attribute holds one ``obj_cat`` per cloud.
        """
        batch_size = len(point_clouds)
        first = point_clouds[0]
        device = first.points.device

        pc_ids = [pc.pc_id for pc in point_clouds]
        cls_labels = torch.tensor([pc.obj_cat for pc in point_clouds], device=device)

        points = torch.cat([pc.points for pc in point_clouds], dim=0)

        # batch (cloud) index of every point
        batch_indices = torch.cat([
            torch.full((pc.points.shape[0],), i, dtype=torch.int32, device=device)
            for i, pc in enumerate(point_clouds)
        ], dim=0)

        def _cat_optional(name: str):
            # Concatenate field `name` over all clouds if the first cloud has it, else None.
            if getattr(first, name) is None:
                return None
            return torch.cat([getattr(pc, name) for pc in point_clouds], dim=0)

        sem_labels = _cat_optional("sem_labels")
        instance_labels = _cat_optional("instance_labels")
        gt_npcs = _cat_optional("gt_npcs")
        instance_regions = _cat_optional("instance_regions")

        # Per-instance info, padded to the largest instance count in the batch.
        # Padding slots get 0 points and semantic label -1.
        if first.num_instances is not None:
            num_instances = [pc.num_instances for pc in point_clouds]
            max_num_instances = max(num_instances)
            num_points_per_instance = torch.zeros(
                batch_size, max_num_instances, dtype=torch.int32, device=device
            )
            instance_sem_labels = torch.full(
                (batch_size, max_num_instances), -1, dtype=torch.int32, device=device
            )
            for i, pc in enumerate(point_clouds):
                num_points_per_instance[i, :pc.num_instances] = pc.num_points_per_instance
                instance_sem_labels[i, :pc.num_instances] = pc.instance_sem_labels
        else:
            num_instances = None
            num_points_per_instance = None
            instance_sem_labels = None

        # Each cloud is already voxelized: concatenate the voxels and prepend
        # the cloud index as coordinate column 0.
        voxel_batch_indices = torch.cat([
            torch.full((pc.voxel_coords.shape[0],), i, dtype=torch.int32, device=device)
            for i, pc in enumerate(point_clouds)
        ], dim=0)
        voxel_coords = torch.cat([pc.voxel_coords for pc in point_clouds], dim=0)
        voxel_coords = torch.cat([voxel_batch_indices[:, None], voxel_coords], dim=-1)
        voxel_features = torch.cat([pc.voxel_features for pc in point_clouds], dim=0)

        # spatial shape = per-axis maximum extent over the batch
        voxel_coords_range = np.max([
            pc.voxel_coords_range for pc in point_clouds
        ], axis=0)
        voxel_tensor = spconv.SparseConvTensor(
            voxel_features, voxel_coords,
            spatial_shape=voxel_coords_range.tolist(),
            batch_size=batch_size,
        )

        # Shift each cloud's non-negative point->voxel ids by the number of
        # voxels in the preceding clouds. torch.where builds a new tensor.
        # An in-place `+=` would shift the caller's tensors, so collating the
        # same clouds twice would produce wrong ids.
        shifted_ids = []
        num_voxel_offset = 0
        for pc in point_clouds:
            ids = pc.pc_voxel_id
            shifted_ids.append(torch.where(ids >= 0, ids + num_voxel_offset, ids))
            num_voxel_offset += pc.voxel_coords.shape[0]
        pc_voxel_id = torch.cat(shifted_ids, dim=0)

        batch = PointCloudBatch(
            pc_ids=pc_ids,
            points=points,
            batch_indices=batch_indices,
            batch_size=batch_size,
            device=device,
            voxel_tensor=voxel_tensor,
            pc_voxel_id=pc_voxel_id,
            sem_labels=sem_labels,
            num_instances=num_instances,
            instance_regions=instance_regions,
            num_points_per_instance=num_points_per_instance,
            instance_sem_labels=instance_sem_labels,
            instance_labels=instance_labels,
            gt_npcs=gt_npcs,
        )
        # Set as an attribute rather than a constructor keyword, so this does
        # not depend on PointCloudBatch.__init__'s signature.
        batch.obj_cls_labels = cls_labels
        return batch

In [3]:
# auxiliary functions
def feature_transform_reguliarzer(trans):
    """Orthogonality regularizer for a batch of square transform matrices.

    Computes ``mean_b || T_b @ T_b^T - I ||_F``, which is zero exactly when
    every transform is orthogonal.

    Args:
        trans: tensor of shape (bs, k, k).

    Returns:
        Scalar tensor with the mean Frobenius-norm deviation.
    """
    d = trans.size(1)
    # Build the identity directly on trans's device/dtype. This works on any
    # backend, not just CPU/CUDA, and needs no extra host->device copy.
    identity = torch.eye(d, device=trans.device, dtype=trans.dtype).unsqueeze(0)
    residual = torch.bmm(trans, trans.transpose(2, 1)) - identity
    return torch.mean(torch.norm(residual, dim=(1, 2)))

def apply_voxelization(  # voxelize a single point cloud
    pc: PointCloud, *, voxel_size: Tuple[float, float, float]
) -> PointCloud:
    """Return a shallow copy of ``pc`` with its voxelization fields filled in.

    The xyz bounding box, padded by 1e-4 on each side, is voxelized with
    ``voxelize``. All point features (``pc.points``) are mean-reduced per voxel.

    Args:
        pc: input point cloud; ``pc.points[:, :3]`` is used as xyz.
        voxel_size: voxel edge length along x, y, z.

    Returns:
        A shallow copy of ``pc`` with these fields set: ``voxel_features``,
        ``voxel_coords``, ``voxel_coords_range`` (per-axis extent, at least 128)
        and ``pc_voxel_id``.
    """
    out = copy.copy(pc)
    xyz = out.points[:, :3]
    dev = xyz.device

    lo = xyz.min(0)[0] - 1e-4
    hi = xyz.max(0)[0] + 1e-4
    # a single "batch" covering all points
    offsets = torch.as_tensor([0, out.points.shape[0]], dtype=torch.int64, device=dev)

    feats, coords, _, point_to_voxel = voxelize(
        xyz, out.points,
        batch_offsets=offsets,
        voxel_size=torch.as_tensor(voxel_size, device=dev),
        points_range_min=torch.as_tensor(lo, device=dev),
        points_range_max=torch.as_tensor(hi, device=dev),
        reduction="mean",
    )
    # every point must land in some voxel
    assert (point_to_voxel >= 0).all()

    # per-axis spatial extent, clamped to at least 128
    extent = (coords.max(0)[0] + 1).clamp(min=128, max=None)

    out.voxel_features = feats
    out.voxel_coords = coords
    out.voxel_coords_range = extent.tolist()
    out.pc_voxel_id = point_to_voxel
    return out
# Per-segment (proposal-wise) normalized voxelization
def segmented_voxelize(
    pt_xyz: torch.Tensor,
    pt_features: torch.Tensor,
    segment_offsets: torch.Tensor,
    segment_indices: torch.Tensor,
    num_points_per_segment: torch.Tensor,
    score_fullscale: float,
    score_scale: float,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Voxelize each segment (proposal) independently in its own normalized grid.

    Each segment is:
      1. centered on its mean;
      2. scaled so its largest bounding-box side fits ``score_fullscale``
         (scale capped at ``score_scale``);
      3. randomly shifted inside the ``[0, score_fullscale]^3`` box;
      4. voxelized with unit voxels and mean reduction.
    For ``voxelize``, the segments play the role of batches.

    Args:
        pt_xyz: point coordinates, (N, 3).
        pt_features: point features, (N, C).
        segment_offsets: (S+1,) offsets. Segment s covers points
            ``segment_offsets[s]:segment_offsets[s+1]``, so points are assumed
            to be grouped by segment (TODO confirm at call sites).
        segment_indices: (N,) segment index of every point.
        num_points_per_segment: (S,) number of points in each segment.
        score_fullscale: side length of the per-segment voxel grid.
        score_scale: upper bound on the per-segment scale factor.

    Returns:
        ``(voxel_features, voxel_coords, pc_voxel_id)``. ``voxel_coords`` has the
        segment index prepended as column 0. All three are placed on
        ``pt_xyz``'s device.
    """
    out_device = pt_xyz.device

    begin = segment_offsets[:-1]
    end = segment_offsets[1:]

    # per-segment centroid
    segment_coords_mean = segmented_reduce(
        pt_xyz, begin, end, mode="sum"
    ) / num_points_per_segment[:, None]

    centered_points = pt_xyz - segment_coords_mean[segment_indices]

    # per-segment bounding box of the centered points
    segment_coords_min = segmented_reduce(centered_points, begin, end, mode="min")
    segment_coords_max = segmented_reduce(centered_points, begin, end, mode="max")

    # Scale so the largest bbox side maps to score_fullscale. The 0.01 margin is
    # subtracted from the scale itself (1/x - 0.01).
    segment_scales = 1. / (
        (segment_coords_max - segment_coords_min) / score_fullscale
    ).max(-1)[0] - 0.01
    segment_scales = torch.clamp(segment_scales, min=None, max=score_scale)

    min_xyz = segment_coords_min * segment_scales[..., None]
    max_xyz = segment_coords_max * segment_scales[..., None]

    # Move each segment's min corner to the origin, then add a random shift.
    # The first term uses the slack left in the [0, score_fullscale] box.
    # The second term is non-zero only when the segment overflows the box.
    range_xyz = max_xyz - min_xyz
    offsets = -min_xyz + torch.clamp(
        score_fullscale - range_xyz - 0.001, min=0
    ) * torch.rand(3, dtype=min_xyz.dtype, device=min_xyz.device) + torch.clamp(
        score_fullscale - range_xyz + 0.001, max=0
    ) * torch.rand(3, dtype=min_xyz.dtype, device=min_xyz.device)

    point_scales = segment_scales[segment_indices]
    scaled_points = centered_points * point_scales[..., None]
    scaled_points += offsets[segment_indices]

    # The voxelization itself runs on CPU tensors.
    voxel_features, voxel_coords, voxel_batch_indices, pc_voxel_id = voxelize(
        scaled_points.cpu(),
        pt_features.cpu(),
        batch_offsets=segment_offsets.cpu().long(),
        voxel_size=torch.as_tensor([1., 1., 1.], device="cpu", dtype=torch.float32),
        points_range_min=torch.as_tensor([0., 0., 0.], device="cpu", dtype=torch.float32),
        points_range_max=torch.as_tensor([score_fullscale, score_fullscale, score_fullscale], device="cpu", dtype=torch.float32),
        reduction="mean",
    )
    # prepend the segment ("batch") index to the voxel coordinates
    voxel_coords = torch.cat([voxel_batch_indices[:, None], voxel_coords], dim=1)

    # Return results on the input's device instead of a hard-coded "cuda"
    # (which fails on CPU and picks the wrong GPU in multi-GPU setups).
    return (
        voxel_features.to(out_device),
        voxel_coords.to(out_device),
        pc_voxel_id.to(out_device),
    )
# Clustering (ball query + connected components)
def cluster_proposals(
    pt_xyz: torch.Tensor,
    batch_indices: torch.Tensor,
    batch_offsets: torch.Tensor,
    sem_preds: torch.Tensor,
    ball_query_radius: float,
    max_num_points_per_query: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Group points into proposals via ball-query neighborhoods + connected components.

    Each point queries neighbors within ``ball_query_radius``. ``sem_preds`` is
    passed to ``ball_query`` as both point and query labels. The resulting
    neighbor graph is labeled with connected components, and the labels are
    sorted.

    Args:
        pt_xyz: point coordinates, (N, 3).
        batch_indices: (N,) batch index of each point. Its dtype is also used
            for the neighbor-range index tensor.
        batch_offsets: (B+1,) start offset of each batch.
        sem_preds: per-point semantic labels used to restrict neighbors
            (presumably (N,) class ids -- TODO confirm against ball_query).
        ball_query_radius: neighborhood radius.
        max_num_points_per_query: neighbor slots reserved per query point.

    Returns:
        ``(sorted_cc_labels, sorted_indices)``: the connected-component labels in
        ascending order, and the permutation of point indices that sorts them.
    """
    neighbor_ids, neighbor_counts = ball_query(
        pt_xyz,
        pt_xyz,
        batch_indices,
        batch_offsets,
        ball_query_radius,
        max_num_points_per_query,
        point_labels=sem_preds,
        query_labels=sem_preds,
    )

    # Query i owns the slot range [i * max, i * max + count_i) of the flattened
    # neighbor buffer. Pack those ranges as an (N, 2) begin/end table.
    slot_begin = max_num_points_per_query * torch.arange(
        pt_xyz.shape[0], dtype=batch_indices.dtype, device=pt_xyz.device
    )
    slot_ranges = torch.stack([slot_begin, slot_begin + neighbor_counts], dim=1)

    component_labels = connected_components_labeling(
        slot_ranges.view(-1), neighbor_ids.view(-1), compacted=False
    )
    sorted_labels, order = torch.sort(component_labels)
    return sorted_labels, order


In [4]:
# basic building blocks
class ResBlock(spconv.SparseModule):
    """Residual block of two 3x3x3 submanifold sparse convolutions.

    ``out = relu(norm(conv2(relu(norm(conv1(x))))) + shortcut(x))``. The
    shortcut is the identity when the channel counts match; otherwise it is a
    1x1x1 submanifold conv followed by a norm layer.

    Args:
        in_channels: input feature width.
        out_channels: output feature width.
        norm_fn: callable that builds a normalization layer from a channel count.
        indice_key: spconv indice key passed to both 3x3x3 convs.
    """

    def __init__(
        self, in_channels: int, out_channels: int, norm_fn: nn.Module, indice_key=None
    ):
        super().__init__()

        if in_channels != out_channels:
            # 1x1x1 projection so the residual sum has matching widths
            self.shortcut = spconv.SparseSequential(
                spconv.SubMConv3d(in_channels, out_channels, kernel_size=1, bias=False),
                norm_fn(out_channels),
            )
        else:
            self.shortcut = nn.Identity()

        self.conv1 = spconv.SparseSequential(
            spconv.SubMConv3d(
                in_channels, out_channels, kernel_size=3,
                padding=1, bias=False, indice_key=indice_key,
            ),
            norm_fn(out_channels),
        )
        self.conv2 = spconv.SparseSequential(
            spconv.SubMConv3d(
                out_channels, out_channels, kernel_size=3,
                padding=1, bias=False, indice_key=indice_key,
            ),
            norm_fn(out_channels),
        )

    def forward(self, x: spconv.SparseConvTensor) -> spconv.SparseConvTensor:
        """Apply the residual block to a SparseConvTensor."""
        identity = self.shortcut(x)

        out = self.conv1(x)
        out = out.replace_feature(F.relu(out.features))

        out = self.conv2(out)
        return out.replace_feature(F.relu(out.features + identity.features))

class UBlock(nn.Module):
    """Recursive sparse U-Net level with encoder/decoder skip connections.

    For ``channels = [c0, c1, ...]`` the level is:
      * encoder: ``block_repeat`` x ``block_fn(c0 -> c0)``
      * if more levels follow:
          - downsample: stride-2 SparseConv3d ``c0 -> c1``
          - nested ``UBlock(channels[1:])``
          - upsample: SparseInverseConv3d ``c1 -> c0``, sharing the
            downsample's indice key
          - concatenate with the encoder output (``2 * c0`` channels)
          - decoder: ``block_fn(2*c0 -> c0)``, then ``block_fn(c0 -> c0)`` repeats

    Args:
        channels: feature width of each level, outermost first.
        block_fn: block constructor ``(in, out, norm_fn, indice_key=...)``.
        block_repeat: number of blocks per encoder/decoder stage.
        norm_fn: callable that builds a normalization layer from a channel count.
        indice_key_id: level number used in the spconv indice key names
            (``subm{id}``, ``spconv{id}``); incremented for each nested level.
    """

    def __init__(
        self,
        channels: List[int],
        block_fn: nn.Module,
        block_repeat: int,
        norm_fn: nn.Module,
        indice_key_id: int = 1,
    ):
        super().__init__()

        self.channels = channels
        width = channels[0]
        subm_key = f"subm{indice_key_id}"
        down_key = f"spconv{indice_key_id}"

        self.encoder_blocks = spconv.SparseSequential(*(
            block_fn(width, width, norm_fn, indice_key=subm_key)
            for _ in range(block_repeat)
        ))

        if len(channels) > 1:
            inner = channels[1]
            self.downsample = spconv.SparseSequential(
                spconv.SparseConv3d(
                    width, inner, kernel_size=2, stride=2,
                    bias=False, indice_key=down_key,
                ),
                norm_fn(inner),
                nn.ReLU(),
            )

            # the next (deeper) level of the same structure
            self.ublock = UBlock(
                channels[1:], block_fn, block_repeat, norm_fn, indice_key_id + 1
            )

            self.upsample = spconv.SparseSequential(
                spconv.SparseInverseConv3d(
                    inner, width, kernel_size=2,
                    bias=False, indice_key=down_key,
                ),
                norm_fn(width),
                nn.ReLU(),
            )

            # The first decoder block consumes the concatenated [upsampled, skip]
            # features (2 * width). There is always at least one decoder block.
            self.decoder_blocks = spconv.SparseSequential(*(
                block_fn(
                    width * 2 if i == 0 else width, width, norm_fn,
                    indice_key=subm_key,
                )
                for i in range(max(block_repeat, 1))
            ))

    def forward(self, x: spconv.SparseConvTensor) -> spconv.SparseConvTensor:
        """Run encoder -> (down -> nested level -> up -> concat skip -> decoder)."""
        x = self.encoder_blocks(x)
        if len(self.channels) <= 1:
            return x  # innermost level: encoder only

        skip = x
        y = self.upsample(self.ublock(self.downsample(x)))
        y = y.replace_feature(torch.cat([y.features, skip.features], dim=-1))
        return self.decoder_blocks(y)
    
class SparseUNet(nn.Module):
    """Optional stem followed by a recursive ``UBlock`` sparse U-Net."""

    def __init__(self, stem: nn.Module, ublock: UBlock):
        super().__init__()

        self.stem = stem
        self.ublock = ublock  # outermost U-Net level

    def forward(self, x):
        """Apply the stem (if any), then the U-Net, to a SparseConvTensor."""
        if self.stem is not None:
            x = self.stem(x)
        x = self.ublock(x)
        return x

    @classmethod
    def build(
        cls,
        in_channels: int,
        channels: List[int],
        block_repeat: int,
        norm_fn: nn.Module,
        without_stem: bool = False,
    ):
        """Alternate constructor that builds the stem and U-Net from hyperparameters.

        Args:
            in_channels: input voxel feature width.
            channels: feature width of each U-Net level, outermost first.
            block_repeat: number of ResBlocks per encoder/decoder stage.
            norm_fn: callable that builds a normalization layer from a channel count.
            without_stem: if True, the stem is only norm + ReLU (no conv), so the
                input feature width is expected to equal ``channels[0]``.

        Returns:
            An instance of ``cls``, so subclasses calling ``build`` get
            themselves back.
        """
        if not without_stem:
            # 3x3x3 submanifold conv mapping in_channels -> channels[0]
            stem = spconv.SparseSequential(
                spconv.SubMConv3d(
                    in_channels, channels[0], kernel_size=3,
                    padding=1, bias=False, indice_key="subm1",
                ),
                norm_fn(channels[0]),
                nn.ReLU(),
            )
        else:
            stem = spconv.SparseSequential(
                norm_fn(channels[0]),
                nn.ReLU(),
            )

        block = UBlock(channels, ResBlock, block_repeat, norm_fn, indice_key_id=1)

        return cls(stem, block)

class UBlock_NoSkip(nn.Module):
    """Sparse U-Net level without the encoder/decoder skip concatenation.

    Same layout as ``UBlock``, except the upsampled features go straight into
    the decoder, so every decoder block is ``block_fn(c0 -> c0)``.

    NOTE(review): the nested level is a plain ``UBlock``, which does use skip
    connections. Only this outermost level is skip-free. If a skip-free network
    at every level is intended, this should recurse into ``UBlock_NoSkip``.
    That change alters parameter shapes, so existing checkpoints would no
    longer load.

    Args:
        channels: feature width of each level, outermost first.
        block_fn: block constructor ``(in, out, norm_fn, indice_key=...)``.
        block_repeat: number of blocks per encoder/decoder stage.
        norm_fn: callable that builds a normalization layer from a channel count.
        indice_key_id: level number used in the spconv indice key names.
    """

    def __init__(
        self,
        channels: List[int],
        block_fn: nn.Module,
        block_repeat: int,
        norm_fn: nn.Module,
        indice_key_id: int = 1,
    ):
        super().__init__()

        self.channels = channels
        width = channels[0]
        subm_key = f"subm{indice_key_id}"
        down_key = f"spconv{indice_key_id}"

        self.encoder_blocks = spconv.SparseSequential(*(
            block_fn(width, width, norm_fn, indice_key=subm_key)
            for _ in range(block_repeat)
        ))

        if len(channels) > 1:
            inner = channels[1]
            self.downsample = spconv.SparseSequential(
                spconv.SparseConv3d(
                    width, inner, kernel_size=2, stride=2,
                    bias=False, indice_key=down_key,
                ),
                norm_fn(inner),
                nn.ReLU(),
            )

            self.ublock = UBlock(
                channels[1:], block_fn, block_repeat, norm_fn, indice_key_id + 1
            )

            self.upsample = spconv.SparseSequential(
                spconv.SparseInverseConv3d(
                    inner, width, kernel_size=2,
                    bias=False, indice_key=down_key,
                ),
                norm_fn(width),
                nn.ReLU(),
            )

            # no skip concat, so every decoder block is width -> width
            # (at least one block)
            self.decoder_blocks = spconv.SparseSequential(*(
                block_fn(width, width, norm_fn, indice_key=subm_key)
                for _ in range(max(block_repeat, 1))
            ))

    def forward(self, x: spconv.SparseConvTensor) -> spconv.SparseConvTensor:
        """Run encoder -> (down -> nested level -> up -> decoder), without a skip concat."""
        x = self.encoder_blocks(x)
        if len(self.channels) > 1:
            x = self.decoder_blocks(self.upsample(self.ublock(self.downsample(x))))
        return x

class SparseUNet_NoSkip(nn.Module):
    """Sparse U-Net variant built around ``UBlock_NoSkip``.

    The outermost U-Net level has no encoder/decoder skip concatenation.
    """

    def __init__(self, stem: nn.Module, ublock: UBlock_NoSkip):
        super().__init__()

        self.stem = stem
        self.ublock = ublock

    def forward(self, x):
        """Apply the stem (if any), then the U-Net, to a SparseConvTensor."""
        if self.stem is not None:
            x = self.stem(x)
        x = self.ublock(x)
        return x

    @classmethod
    def build(
        cls,
        in_channels: int,
        channels: List[int],
        block_repeat: int,
        norm_fn: nn.Module,
        without_stem: bool = False,
    ):
        """Alternate constructor that builds the stem and skip-free U-Net.

        Args:
            in_channels: input voxel feature width.
            channels: feature width of each U-Net level, outermost first.
            block_repeat: number of ResBlocks per encoder/decoder stage.
            norm_fn: callable that builds a normalization layer from a channel count.
            without_stem: if True, the stem is only norm + ReLU (no conv), so the
                input feature width is expected to equal ``channels[0]``.

        Returns:
            An instance of ``cls`` wrapping a ``UBlock_NoSkip``.
        """
        if not without_stem:
            stem = spconv.SparseSequential(
                spconv.SubMConv3d(
                    in_channels, channels[0], kernel_size=3,
                    padding=1, bias=False, indice_key="subm1",
                ),
                norm_fn(channels[0]),
                nn.ReLU(),
            )
        else:
            stem = spconv.SparseSequential(
                norm_fn(channels[0]),
                nn.ReLU(),
            )

        # Must be UBlock_NoSkip + cls. Building UBlock + SparseUNet here would
        # silently produce the skip-connected network. Parameter shapes differ
        # from UBlock's: the first decoder block takes c0 channels, not 2*c0.
        block = UBlock_NoSkip(channels, ResBlock, block_repeat, norm_fn, indice_key_id=1)

        return cls(stem, block)

class STN3d(nn.Module):
    """PointNet input-transform network that predicts one 3x3 matrix per cloud.

    Args:
        channel: number of input channels per point (the first conv's input
            width). The output is always 3x3.
    """

    def __init__(self, channel):
        super(STN3d, self).__init__()
        self.conv1 = torch.nn.Conv1d(channel, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 9)
        self.relu = nn.ReLU()  # unused by forward; kept so module layout/repr are unchanged

        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)

    def forward(self, x):
        """Predict a transform for each cloud.

        Args:
            x: (bs, channel, num_points).

        Returns:
            (bs, 3, 3) transform, parameterized as identity + learned offset.
        """
        batchsize = x.size(0)
        # pointwise MLP lifting features to 1024 channels
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        # max-pool over points -> one global feature per cloud
        x = torch.max(x, 2, keepdim=True)[0]
        x = x.view(-1, 1024)

        x = F.relu(self.bn4(self.fc1(x)))
        x = F.relu(self.bn5(self.fc2(x)))
        x = self.fc3(x)

        # Add a flattened identity so a zero fc3 output yields the identity
        # transform. It is built on x's device/dtype, so no numpy round-trip
        # or explicit .cuda() is needed.
        iden = torch.eye(3, device=x.device, dtype=x.dtype).view(1, 9).repeat(batchsize, 1)
        x = x + iden
        return x.view(-1, 3, 3)

class STNkd(nn.Module):
    """PointNet feature-transform network that predicts one k x k matrix per cloud.

    Args:
        k: feature width of the input, and size of the predicted square matrix.
    """

    def __init__(self, k=64):
        super(STNkd, self).__init__()
        self.conv1 = torch.nn.Conv1d(k, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, k * k)  # flattened k x k matrix
        self.relu = nn.ReLU()  # unused by forward; kept so module layout/repr are unchanged

        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)

        self.k = k

    def forward(self, x):
        """Predict a feature transform for each cloud.

        Args:
            x: (bs, k, num_points).

        Returns:
            (bs, k, k) transform, parameterized as identity + learned offset.
        """
        batchsize = x.size(0)
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        # max-pool over points -> one global feature per cloud
        x = torch.max(x, 2, keepdim=True)[0]
        x = x.view(-1, 1024)

        x = F.relu(self.bn4(self.fc1(x)))
        x = F.relu(self.bn5(self.fc2(x)))
        x = self.fc3(x)

        # Add a flattened k x k identity, built on x's device/dtype.
        iden = torch.eye(self.k, device=x.device, dtype=x.dtype).view(1, self.k * self.k).repeat(batchsize, 1)
        x = x + iden
        return x.view(-1, self.k, self.k)

class PointNetEncoder(nn.Module):
    """PointNet encoder with an input transform and an optional feature transform.

    Args:
        global_feat: if True, return only the (bs, 1024) global feature.
            Otherwise return the global feature tiled over the points and
            concatenated with the 64-d per-point features: (bs, 1088, N).
        feature_transform: if True, apply a learned 64x64 feature transform (STNkd).
        channel: number of input channels per point. The first 3 are treated
            as xyz and rotated by the input transform.
    """

    def __init__(self, global_feat=True, feature_transform=False, channel=3):
        super(PointNetEncoder, self).__init__()
        self.stn = STN3d(channel)
        self.conv1 = torch.nn.Conv1d(channel, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.global_feat = global_feat
        self.feature_transform = feature_transform
        if self.feature_transform:
            self.fstn = STNkd(k=64)

    def forward(self, x):
        """Encode a batch of point clouds.

        Args:
            x: (bs, channel, N).

        Returns:
            ``(features, trans, trans_feat)``. ``trans`` is the (bs, 3, 3)
            input transform. ``trans_feat`` is the (bs, 64, 64) feature
            transform, or None when disabled.
        """
        _, num_channels, num_points = x.size()
        trans = self.stn(x)

        # rotate xyz only; extra channels (if any) are carried along unchanged
        pts = x.transpose(2, 1)  # (bs, N, channel)
        if num_channels > 3:
            xyz, extra = pts[:, :, :3], pts[:, :, 3:]
            pts = torch.cat([torch.bmm(xyz, trans), extra], dim=2)
        else:
            pts = torch.bmm(pts, trans)
        x = pts.transpose(2, 1)  # back to (bs, channel, N)

        x = F.relu(self.bn1(self.conv1(x)))

        trans_feat = None
        if self.feature_transform:
            trans_feat = self.fstn(x)
            x = torch.bmm(x.transpose(2, 1), trans_feat).transpose(2, 1)

        point_feat = x  # 64-d per-point features
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.bn3(self.conv3(x))
        # max over points -> (bs, 1024) global feature
        global_vec = torch.max(x, 2, keepdim=True)[0].view(-1, 1024)

        if self.global_feat:
            return global_vec, trans, trans_feat

        tiled = global_vec.view(-1, 1024, 1).repeat(1, 1, num_points)
        return torch.cat([tiled, point_feat], 1), trans, trans_feat

class PointNetSegBackbone(nn.Module):
    """PointNet segmentation trunk that produces a feature vector per point.

    Args:
        pc_dim: number of extra per-point channels besides xyz, so the input
            has ``3 + pc_dim`` channels.
        fea_dim: output feature width per point.
    """

    def __init__(self, pc_dim, fea_dim):
        super(PointNetSegBackbone, self).__init__()
        self.fea_dim = fea_dim
        self.feat = PointNetEncoder(global_feat=False, feature_transform=True, channel=3+pc_dim)
        # 1088 = 1024 (tiled global feature) + 64 (per-point feature) from the encoder
        self.conv1 = torch.nn.Conv1d(1088, 512, 1)
        self.conv2 = torch.nn.Conv1d(512, 256, 1)
        self.conv3 = torch.nn.Conv1d(256, 256, 1)
        self.conv4 = torch.nn.Conv1d(256, self.fea_dim, 1)  # project to output width
        self.bn1 = nn.BatchNorm1d(512)
        self.bn2 = nn.BatchNorm1d(256)
        self.bn3 = nn.BatchNorm1d(256)

    def forward(self, x):
        """Compute per-point features.

        Args:
            x: (bs, 3 + pc_dim, N), with xyz in the first 3 channels.

        Returns:
            (bs, N, fea_dim) contiguous per-point features.
        """
        # NOTE(review): the input/feature transforms are discarded, so callers
        # cannot apply the feature-transform orthogonality regularizer.
        x, _trans, _trans_feat = self.feat(x)
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = self.conv4(x)
        return x.transpose(2, 1).contiguous()

class get_loss(torch.nn.Module):
    """PointNet segmentation loss: NLL plus a scaled feature-transform orthogonality penalty.

    Args:
        mat_diff_loss_scale: weight of the orthogonality regularizer.
    """

    def __init__(self, mat_diff_loss_scale=0.001):
        super(get_loss, self).__init__()
        self.mat_diff_loss_scale = mat_diff_loss_scale

    def forward(self, pred, target, trans_feat, weight):
        """Return ``nll_loss(pred, target, weight) + scale * regularizer(trans_feat)``.

        Args:
            pred: log-probabilities as expected by ``F.nll_loss``
                (e.g. a log_softmax output).
            target: class indices.
            trans_feat: (bs, k, k) feature-transform matrices.
            weight: optional per-class weights forwarded to ``F.nll_loss``.

        Returns:
            Scalar loss tensor.
        """
        loss = F.nll_loss(pred, target, weight=weight)
        mat_diff_loss = feature_transform_reguliarzer(trans_feat)  # orthogonality penalty
        total_loss = loss + mat_diff_loss * self.mat_diff_loss_scale
        return total_loss

class PointNetBackbone(nn.Module):
    """Thin wrapper that exposes PointNetSegBackbone with a ``(features, extras)`` interface.

    Args:
        pc_dim: number of extra per-point channels besides xyz.
        feature_dim: output feature width per point.
    """

    def __init__(self, pc_dim: int, feature_dim: int):
        super().__init__()
        self.pc_dim = pc_dim
        self.feature_dim = feature_dim
        self.backbone = PointNetSegBackbone(pc_dim, feature_dim)

    def forward(self, input_pc):
        """Return ``(per-point features, {})``; the extras dict is always empty."""
        features = self.backbone(input_pc)
        return features, {}

In [8]:
# data
# Synthetic smoke-test data: two random clouds (xyz + 6-d features). Each is
# voxelized independently, then the results are merged into one batched
# spconv.SparseConvTensor, mirroring what PointCloud.collate does.
torch.manual_seed(0)  # reproducible random clouds
np.random.seed(0)

pc_xyzs = [torch.randn(1024, 3), torch.randn(512, 3)]
feats = [torch.randn(1024, 6), torch.randn(512, 6)]
VOXEL_SIZE = [0.01, 0.01, 0.01]

voxel_feats_list = []
voxel_coords_list = []
voxel_coords_range_list = []
for pc_xyz, feat in zip(pc_xyzs, feats):
    points_range_min = pc_xyz.min(0)[0] - 1e-4
    points_range_max = pc_xyz.max(0)[0] + 1e-4
    num_points = pc_xyz.shape[0]
    # `_` (voxel batch ids) and pc_voxel_id of the last cloud are inspected in a later cell
    voxel_features, voxel_coords, _, pc_voxel_id = voxelize(
        pc_xyz.cuda(), feat.cuda(),
        batch_offsets=torch.as_tensor([0, num_points], dtype=torch.int64, device="cuda"),
        voxel_size=torch.as_tensor(VOXEL_SIZE, device="cuda"),
        points_range_min=torch.as_tensor(points_range_min, device="cuda"),
        points_range_max=torch.as_tensor(points_range_max, device="cuda"),
        reduction="mean",
    )
    # per-axis spatial extent, at least 128
    voxel_coords_range = (voxel_coords.max(0)[0] + 1).clamp(min=128, max=None)
    voxel_coords_range_list.append(voxel_coords_range.cpu().numpy())
    voxel_feats_list.append(voxel_features)
    voxel_coords_list.append(voxel_coords)

# Merge the clouds: prepend each voxel's cloud index as coordinate column 0.
voxel_batch_indices = torch.cat([
    torch.full((coords.shape[0],), i, dtype=torch.int32, device="cuda")
    for i, coords in enumerate(voxel_coords_list)
], dim=0)
voxel_coords = torch.cat([
    voxel_batch_indices[:, None], torch.cat(voxel_coords_list, dim=0)
], dim=-1)
voxel_features = torch.cat(voxel_feats_list, dim=0)

# spatial shape = per-axis maximum extent over all clouds
voxel_coords_ranges = np.max(voxel_coords_range_list, axis=0)

voxel_tensor = spconv.SparseConvTensor(
    voxel_features, voxel_coords,  # both already on CUDA
    spatial_shape=voxel_coords_ranges.tolist(),
    batch_size=len(pc_xyzs),
)

instance_labels = np.random.randint(0, 10, size=100)  # 100 random instance labels in [0, 9]




In [9]:
# Smoke test: PointNet backbone on a random batch of 64 clouds x 2048 points,
# each point having xyz + 64 extra channels (input layout: (bs, 3 + pc_dim, N)).
model = PointNetBackbone(64, 10).to("cuda")
print(model)
inputs = torch.randn(64, 64 + 3, 2048).to("cuda")
outputs = model(inputs)[0]
print(outputs.size())  # expected (64, 2048, 10) -> (bs, N, feature_dim)

PointNetBackbone(
  (backbone): PointNetSegBackbone(
    (feat): PointNetEncoder(
      (stn): STN3d(
        (conv1): Conv1d(67, 64, kernel_size=(1,), stride=(1,))
        (conv2): Conv1d(64, 128, kernel_size=(1,), stride=(1,))
        (conv3): Conv1d(128, 1024, kernel_size=(1,), stride=(1,))
        (fc1): Linear(in_features=1024, out_features=512, bias=True)
        (fc2): Linear(in_features=512, out_features=256, bias=True)
        (fc3): Linear(in_features=256, out_features=9, bias=True)
        (relu): ReLU()
        (bn1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn3): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn4): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn5): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

In [10]:
# Smoke test: 7-level sparse U-Net on the voxelized batch built in the data cell.
sparseunet = SparseUNet.build(
    6,                                   # in_channels: voxel feature width
    [16, 32, 48, 64, 80, 96, 112],       # channels per U-Net level
    2,                                   # block_repeat
    functools.partial(nn.BatchNorm1d, eps=1e-4, momentum=0.1),
)
sparseunet.to("cuda")
print(sparseunet)

print(voxel_features.size())
print(voxel_coords.size())
# NOTE(review): `_` is the voxel batch-index tensor unpacked in the data cell.
# IPython also uses `_` for the last displayed output, so this line is
# fragile under out-of-order execution.
print(_.size())
print(pc_voxel_id.size())

print(sparseunet(voxel_tensor))  # output features: (num_voxels, 16), all batches stacked

SparseUNet(
  (stem): SparseSequential(
    (0): SubMConv3d(6, 16, kernel_size=[3, 3, 3], stride=[1, 1, 1], padding=[1, 1, 1], dilation=[1, 1, 1], output_padding=[0, 0, 0], bias=False, algo=ConvAlgo.MaskImplicitGemm)
    (1): BatchNorm1d(16, eps=0.0001, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (ublock): UBlock(
    (encoder_blocks): SparseSequential(
      (0): ResBlock(
        (shortcut): Identity()
        (conv1): SparseSequential(
          (0): SubMConv3d(16, 16, kernel_size=[3, 3, 3], stride=[1, 1, 1], padding=[1, 1, 1], dilation=[1, 1, 1], output_padding=[0, 0, 0], bias=False, algo=ConvAlgo.MaskImplicitGemm)
          (1): BatchNorm1d(16, eps=0.0001, momentum=0.1, affine=True, track_running_stats=True)
        )
        (conv2): SparseSequential(
          (0): SubMConv3d(16, 16, kernel_size=[3, 3, 3], stride=[1, 1, 1], padding=[1, 1, 1], dilation=[1, 1, 1], output_padding=[0, 0, 0], bias=False, algo=ConvAlgo.MaskImplicitGemm)
          (1): Ba

In [7]:
class MainModel(nn.Module):
    """GAPartNet-style part segmentation / proposal model (inference forward pass).

    Pipeline implemented by ``forward``:
      1. ``backbone`` (SparseUNet) -> per-point features.
      2. ``sem_seg_head`` -> per-point semantic logits; ``offset_head`` -> per-point
         3-D offsets.
      3. Points whose predicted class is > 0 are clustered twice (on ``xyz`` and on
         ``xyz + offset``) into proposals, which are then re-voxelized per proposal.
      4. ``score_unet``/``score_head`` score each proposal for its class;
         ``npcs_unet``/``npcs_head`` regress per-point NPCS logits.

    Args:
        in_channels: input feature channels of the backbone.
        num_part_classes: number of semantic classes including class 0. Class-0
            points are never clustered, and the score/NPCS heads cover the
            remaining ``num_part_classes - 1`` classes.
        backbone_type: only ``"SparseUNet"`` is supported.
        backbone_cfg: dict with ``"channels"`` (list of ints) and ``"block_repeat"``.
        learning_rate, ignore_sem_label, use_sem_focal_loss, use_sem_dice_loss:
            stored on the module; not used by ``forward``.
        instance_seg_cfg: dict with ``ball_query_radius``, ``max_num_points_per_query``,
            ``min_num_points_per_proposal``, ``max_num_points_per_query_shift``,
            ``score_fullscale`` and ``score_scale``.
        symmetry_indices: per-class indices, stored as an int64 tensor on ``device``.
        training_schedule: exactly two values ``[start_scorenet, start_npcs]``.
        val_score_threshold, val_min_num_points_per_proposal, val_nms_iou_threshold,
        val_ap_iou_threshold: validation thresholds (stored only).
        visualize_cfg: stored only.
        debug, ckpt: accepted for compatibility; currently unused.
        device: device for tensors created in ``__init__`` (default ``"cuda"``).
            This is a plain attribute, so ``.to()``/``.cuda()`` do not update it.
    """

    def __init__(
        self,
        in_channels: int,
        num_part_classes: int,
        backbone_type: str = "SparseUNet",
        backbone_cfg: Optional[Dict] = None,
        learning_rate: float = 1e-3,
        # semantic segmentation
        ignore_sem_label: int = -100,
        use_sem_focal_loss: bool = True,
        use_sem_dice_loss: bool = True,
        # instance segmentation
        instance_seg_cfg: Optional[Dict] = None,
        # npcs segmentation
        symmetry_indices: Optional[List] = None,
        # training
        training_schedule: Optional[List] = None,
        # validation
        val_score_threshold: float = 0.09,
        val_min_num_points_per_proposal: int = 3,
        val_nms_iou_threshold: float = 0.3,
        val_ap_iou_threshold: float = 0.5,
        # testing
        visualize_cfg: Optional[Dict] = None,

        debug: bool = True,
        ckpt: str = "", # type: ignore
        device: Union[str, torch.device] = "cuda",
    ):
        super().__init__()
        # Fresh containers per instance: a literal `{}`/`[]` default would be shared
        # by every MainModel (and is aliased into self.backbone_cfg/visualize_cfg).
        backbone_cfg = {} if backbone_cfg is None else backbone_cfg
        instance_seg_cfg = {} if instance_seg_cfg is None else instance_seg_cfg
        symmetry_indices = [] if symmetry_indices is None else symmetry_indices
        training_schedule = [] if training_schedule is None else training_schedule
        visualize_cfg = {} if visualize_cfg is None else visualize_cfg

        self.validation_step_outputs = []
        self.device = device
        self.in_channels = in_channels
        self.num_part_classes = num_part_classes
        self.backbone_type = backbone_type
        self.backbone_cfg = backbone_cfg
        self.learning_rate = learning_rate
        self.ignore_sem_label = ignore_sem_label
        self.use_sem_focal_loss = use_sem_focal_loss
        self.use_sem_dice_loss = use_sem_dice_loss
        self.visualize_cfg = visualize_cfg
        # Requires exactly two entries; raises ValueError otherwise.
        self.start_scorenet, self.start_npcs = training_schedule
        self.start_clustering = min(self.start_scorenet, self.start_npcs)
        self.val_nms_iou_threshold = val_nms_iou_threshold
        self.val_ap_iou_threshold = val_ap_iou_threshold
        self.val_score_threshold = val_score_threshold
        self.val_min_num_points_per_proposal = val_min_num_points_per_proposal
        self.symmetry_indices = torch.as_tensor(symmetry_indices, dtype=torch.int64).to(self.device)

        self.ball_query_radius = instance_seg_cfg["ball_query_radius"]
        self.max_num_points_per_query = instance_seg_cfg["max_num_points_per_query"]
        self.min_num_points_per_proposal = instance_seg_cfg["min_num_points_per_proposal"]
        self.max_num_points_per_query_shift = instance_seg_cfg["max_num_points_per_query_shift"]
        self.score_fullscale = instance_seg_cfg["score_fullscale"]
        self.score_scale = instance_seg_cfg["score_scale"]

        ## network
        norm_fn = functools.partial(nn.BatchNorm1d, eps=1e-4, momentum=0.1)
        if self.backbone_type == "SparseUNet":
            channels = self.backbone_cfg["channels"]
            block_repeat = self.backbone_cfg["block_repeat"]
            fea_dim = channels[0]
            self.backbone = SparseUNet.build(in_channels, channels, block_repeat, norm_fn)
        else:
            raise NotImplementedError
        # per-point semantic logits: (N, fea_dim) -> (N, num_part_classes)
        self.sem_seg_head = nn.Linear(fea_dim, self.num_part_classes)
        # offset prediction: (N, fea_dim) -> (N, 3)
        self.offset_head = nn.Sequential(
            nn.Linear(fea_dim, fea_dim),
            norm_fn(fea_dim),
            nn.ReLU(inplace=True),
            nn.Linear(fea_dim, 3),
        )

        # proposal scoring: one logit per non-background class
        self.score_unet = SparseUNet.build(
            fea_dim, channels[:2], block_repeat, norm_fn, without_stem=True
        )
        self.score_head = nn.Linear(fea_dim, self.num_part_classes - 1)

        # NPCS regression: 3 values per non-background class
        self.npcs_unet = SparseUNet.build(
            fea_dim, channels[:2], block_repeat, norm_fn, without_stem=True
        )
        self.npcs_head = nn.Linear(fea_dim, 3 * (self.num_part_classes - 1))

        (
            symmetry_matrix_1, symmetry_matrix_2, symmetry_matrix_3
        ) = get_symmetry_matrix()
        self.symmetry_matrix_1 = symmetry_matrix_1
        self.symmetry_matrix_2 = symmetry_matrix_2
        self.symmetry_matrix_3 = symmetry_matrix_3

    def forward(
        self,
        point_clouds: List[PointCloud],
    ):
        """Run the inference forward pass on a list of point clouds.

        Returns:
            ``(pc_ids, sem_seg, proposals)``: the batch's point-cloud ids, a
            ``Segmentation`` with per-point argmax predictions, and an
            ``Instances`` with ``score_preds`` filled in — or ``None`` when no
            proposal survives clustering/filtering. No loss is computed.
        """
        batch_size = len(point_clouds)

        # data batch parsing
        data_batch = PointCloud.collate(point_clouds)
        points = data_batch.points
        sem_labels = data_batch.sem_labels
        pc_ids = data_batch.pc_ids
        instance_labels = data_batch.instance_labels
        batch_indices = data_batch.batch_indices
        instance_sem_labels = data_batch.instance_sem_labels

        pt_xyz = points[:, :3]

        pc_feature = self.forward_backbone(pc_batch=data_batch)

        # semantic segmentation: (N, num_part_classes) logits -> (N,) class ids
        sem_logits = self.forward_sem_seg(pc_feature)
        sem_preds = torch.argmax(sem_logits.detach(), dim=-1)

        # inference only: no labels / accuracy are attached
        sem_seg = Segmentation(
            batch_size=batch_size,
            sem_preds=sem_preds,
            sem_labels=None,
            all_accu=None,
            pixel_accu=None,
        )

        offsets_preds = self.forward_offset(pc_feature)  # (N, 3)

        voxel_tensor, pc_voxel_id, proposals = self.proposal_clustering_and_revoxelize(
            pt_xyz=pt_xyz,
            batch_indices=batch_indices,
            pt_features=pc_feature,
            sem_preds=sem_preds,
            offset_preds=offsets_preds,
            instance_labels=instance_labels,
        )

        # No proposal survived (no foreground points, or all proposals too small):
        # there is nothing to score, so stop here instead of dereferencing None.
        if proposals is None:
            return pc_ids, sem_seg, None

        if sem_labels is not None:
            proposals.sem_labels = sem_labels[proposals.valid_mask][
                proposals.sorted_indices
            ]
        proposals.instance_sem_labels = instance_sem_labels

        # proposal scoring: (num_proposals, num_part_classes - 1) logits
        score_logits = self.forward_proposal_score(
            voxel_tensor, pc_voxel_id, proposals
        )
        proposal_offsets_begin = proposals.proposal_offsets[:-1].long()

        # Each proposal takes the label of its first point; ground-truth labels
        # are preferred when available, otherwise predictions are used.
        if proposals.sem_labels is not None:
            proposal_sem_labels = proposals.sem_labels[proposal_offsets_begin].long()
        else:
            proposal_sem_labels = proposals.sem_preds[proposal_offsets_begin].long()
        # Score heads exclude class 0, hence the "- 1" shift into column index.
        score_logits = score_logits.gather(
            1, proposal_sem_labels[:, None] - 1
        ).squeeze(1)
        proposals.score_preds = score_logits.detach().sigmoid()

        # TODO(review): npcs_logits is computed but neither returned nor stored
        # (e.g. into proposals.npcs_preds); decide how it should be consumed.
        npcs_logits = self.forward_proposal_npcs(
            voxel_tensor, pc_voxel_id
        )

        # no total loss
        # loss = loss_sem_seg + loss_offset_dist + loss_offset_dir + loss_prop_score + loss_prop_npcs

        return pc_ids, sem_seg, proposals  # ids, semantic segmentation, proposals (no loss)

    def forward_backbone(
        self,
        pc_batch: PointCloudBatch,
    ):
        """Run the backbone on the batch's voxels and gather features back per point."""
        if self.backbone_type == "SparseUNet":
            voxel_tensor = pc_batch.voxel_tensor
            pc_voxel_id = pc_batch.pc_voxel_id
            voxel_features = self.backbone(voxel_tensor)
            # map each point to the features of the voxel it fell into
            pc_feature = voxel_features.features[pc_voxel_id]
        else:
            raise ValueError(f"unsupported backbone_type: {self.backbone_type}")

        return pc_feature

    def forward_sem_seg(
        self,
        pc_feature: torch.Tensor,
    ) -> torch.Tensor:
        """Per-point semantic logits: (N, fea_dim) -> (N, num_part_classes)."""
        sem_logits = self.sem_seg_head(pc_feature)

        return sem_logits

    def forward_offset(
        self,
        pc_feature: torch.Tensor,
    ) -> torch.Tensor:
        """Per-point 3-D offset prediction: (N, fea_dim) -> (N, 3)."""
        offset = self.offset_head(pc_feature)

        return offset

    def proposal_clustering_and_revoxelize(
        self,
        pt_xyz: torch.Tensor,  # (N, 3) point coordinates
        batch_indices: torch.Tensor,  # (N,) batch index of each point (assumed sorted)
        pt_features: torch.Tensor,  # (N, C) per-point features
        sem_preds: torch.Tensor,  # (N,) predicted class id per point
        offset_preds: torch.Tensor,  # (N, 3) predicted offset per point
        instance_labels: Optional[torch.Tensor],  # (N,) instance id per point, or None
    ):
        """Cluster foreground points into proposals and voxelize each proposal.

        Points with ``sem_preds == 0`` (and, if given, ``instance_labels < 0``)
        are dropped. The rest are clustered twice — on ``xyz`` and on
        ``xyz + offset`` — and clusters smaller than
        ``min_num_points_per_proposal`` are discarded.

        Returns:
            ``(voxel_tensor, pc_voxel_id, proposals)``, or ``(None, None, None)``
            when no point / no proposal survives filtering.

        Raises:
            RuntimeError: if voxelization yields a negative ``pc_voxel_id``.
        """
        # allocate on the inputs' device so the model also works off "cuda"
        device = pt_xyz.device
        # drop points predicted as class 0
        if instance_labels is not None:
            valid_mask = (sem_preds > 0) & (instance_labels >= 0)
        else:
            valid_mask = sem_preds > 0
        # apply the mask to every per-point input
        pt_xyz = pt_xyz[valid_mask]
        batch_indices = batch_indices[valid_mask]
        pt_features = pt_features[valid_mask]
        sem_preds = sem_preds[valid_mask].int()
        offset_preds = offset_preds[valid_mask]
        if instance_labels is not None:
            instance_labels = instance_labels[valid_mask]
        if pt_xyz.shape[0] == 0:
            # nothing to cluster
            return None, None, None

        # get batch offsets (csr) from batch indices
        _, batch_indices_compact, num_points_per_batch = torch.unique_consecutive(
            batch_indices, return_inverse=True, return_counts=True
        )
        batch_indices_compact = batch_indices_compact.int()
        batch_offsets = torch.zeros(
            (num_points_per_batch.shape[0] + 1,), dtype=torch.int32, device=device
        )
        batch_offsets[1:] = num_points_per_batch.cumsum(0)

        # cluster proposals: dual set
        # 1) on the original coordinates
        sorted_cc_labels, sorted_indices = cluster_proposals(
            pt_xyz, batch_indices_compact, batch_offsets, sem_preds,
            self.ball_query_radius, self.max_num_points_per_query,
        )

        # 2) on the offset-shifted coordinates
        sorted_cc_labels_shift, sorted_indices_shift = cluster_proposals(
            pt_xyz + offset_preds, batch_indices_compact, batch_offsets, sem_preds,
            self.ball_query_radius, self.max_num_points_per_query_shift,
        )

        # combine clusters; shift the second set's labels so they don't collide
        sorted_cc_labels = torch.cat([
            sorted_cc_labels,
            sorted_cc_labels_shift + sorted_cc_labels.shape[0],
        ], dim=0)
        sorted_indices = torch.cat([sorted_indices, sorted_indices_shift], dim=0)

        # compact the proposal ids (labels are consecutive runs)
        _, proposal_indices, num_points_per_proposal = torch.unique_consecutive(
            sorted_cc_labels, return_inverse=True, return_counts=True
        )

        # remove small proposals
        valid_proposal_mask = (
            num_points_per_proposal >= self.min_num_points_per_proposal
        )
        # proposal-level mask -> point-level mask
        valid_point_mask = valid_proposal_mask[proposal_indices]

        sorted_indices = sorted_indices[valid_point_mask]
        if sorted_indices.shape[0] == 0:
            return None, None, None

        batch_indices = batch_indices[sorted_indices]
        pt_xyz = pt_xyz[sorted_indices]
        pt_features = pt_features[sorted_indices]
        sem_preds = sem_preds[sorted_indices]
        if instance_labels is not None:
            instance_labels = instance_labels[sorted_indices]

        # re-compact the ids of the proposals that were kept
        proposal_indices = proposal_indices[valid_point_mask]
        _, proposal_indices, num_points_per_proposal = torch.unique_consecutive(
            proposal_indices, return_inverse=True, return_counts=True
        )
        num_proposals = num_points_per_proposal.shape[0]

        # get proposal batch offsets: [0, cumsum(counts)...] (CSR layout)
        proposal_offsets = torch.zeros(
            num_proposals + 1, dtype=torch.int32, device=device
        )
        proposal_offsets[1:] = num_points_per_proposal.cumsum(0)

        # voxelization: each proposal becomes one "batch" entry of the sparse tensor
        voxel_features, voxel_coords, pc_voxel_id = segmented_voxelize(
            pt_xyz, pt_features,
            proposal_offsets, proposal_indices,
            num_points_per_proposal,
            self.score_fullscale, self.score_scale,
        )
        voxel_tensor = spconv.SparseConvTensor(
            voxel_features, voxel_coords.int(),
            spatial_shape=[self.score_fullscale] * 3,
            batch_size=num_proposals,
        )
        if not (pc_voxel_id >= 0).all():
            # Previously a leftover pdb.set_trace(); fail loudly instead.
            raise RuntimeError(
                "segmented_voxelize returned negative pc_voxel_id entries "
                "(invalid voxel index for some points)"
            )

        proposals = Instances(
            valid_mask=valid_mask,
            sorted_indices=sorted_indices,
            pt_xyz=pt_xyz,
            batch_indices=batch_indices,
            proposal_offsets=proposal_offsets,
            proposal_indices=proposal_indices,
            num_points_per_proposal=num_points_per_proposal,
            sem_preds=sem_preds,
            instance_labels=instance_labels,
        )

        return voxel_tensor, pc_voxel_id, proposals

    def forward_proposal_score(
        self,
        voxel_tensor: spconv.SparseConvTensor,
        pc_voxel_id: torch.Tensor,
        proposals: Instances,
    ):
        """Max-pool per-proposal features and return (num_proposals, num_part_classes - 1) logits."""
        proposal_offsets = proposals.proposal_offsets
        proposal_offsets_begin = proposal_offsets[:-1] # type: ignore
        proposal_offsets_end = proposal_offsets[1:] # type: ignore

        score_features = self.score_unet(voxel_tensor)
        # voxel features -> per-point features
        score_features = score_features.features[pc_voxel_id]
        pooled_score_features, _ = segmented_maxpool(
            score_features, proposal_offsets_begin, proposal_offsets_end
        )
        score_logits = self.score_head(pooled_score_features)

        return score_logits

    def forward_proposal_npcs(
        self,
        voxel_tensor: spconv.SparseConvTensor,
        pc_voxel_id: torch.Tensor,
    ) -> torch.Tensor:
        """Per-point NPCS logits, shape (num_proposal_points, 3 * (num_part_classes - 1))."""
        npcs_features = self.npcs_unet(voxel_tensor)
        npcs_logits = self.npcs_head(npcs_features.features)
        # voxel-level logits -> logits for every point via pc_voxel_id
        npcs_logits = npcs_logits[pc_voxel_id]

        return npcs_logits


In [39]:
# Fixed seeds so both the random input and the model's initial weights are
# reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)
rng = np.random.default_rng(SEED)

backbone_cfg = {
    "channels": [16, 32, 48, 64, 80, 96, 112],
    "block_repeat": 2,
}
instance_seg_cfg = {
    "ball_query_radius": 0.04,
    "max_num_points_per_query": 50,
    "min_num_points_per_proposal": 5,  # 50 for scannet?
    "max_num_points_per_query_shift": 300,
    "score_fullscale": 28,
    "score_scale": 50,
}
model = MainModel(in_channels=6,
                  num_part_classes=10,
                  backbone_type="SparseUNet",
                  backbone_cfg=backbone_cfg,
                  instance_seg_cfg=instance_seg_cfg,
                  debug=True,
                  learning_rate=0.001,
                  ignore_sem_label=-100,
                  use_sem_focal_loss=True,
                  use_sem_dice_loss=True,
                  training_schedule=[5, 10],
                  val_nms_iou_threshold=0.3,
                  val_ap_iou_threshold=0.5,
                  symmetry_indices=[0, 1, 3, 3, 2, 0, 3, 2, 4, 1],).cuda()

print(model)

# Build 10 random PointCloud objects (2000 points x 6 channels each; the model
# reads the first 3 channels as xyz) and voxelize them.
point_clouds = []
for i in range(10):
    pc = PointCloud(
        pc_id="random_pc",
        points=rng.random((2000, 6), dtype=np.float32),
        obj_cat=[0, 2, 6],
        sem_labels=None,
        instance_labels=None,
    ).to_tensor()
    pc = apply_voxelization(pc, voxel_size=[0.01, 0.01, 0.01])
    point_clouds.append(pc.to("cuda"))

# Forward only (no loss is computed), so don't build an autograd graph.
with torch.no_grad():
    outputs = model(point_clouds)
outputs

MainModel(
  (backbone): SparseUNet(
    (stem): SparseSequential(
      (0): SubMConv3d(6, 16, kernel_size=[3, 3, 3], stride=[1, 1, 1], padding=[1, 1, 1], dilation=[1, 1, 1], output_padding=[0, 0, 0], bias=False, algo=ConvAlgo.MaskImplicitGemm)
      (1): BatchNorm1d(16, eps=0.0001, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (ublock): UBlock(
      (encoder_blocks): SparseSequential(
        (0): ResBlock(
          (shortcut): Identity()
          (conv1): SparseSequential(
            (0): SubMConv3d(16, 16, kernel_size=[3, 3, 3], stride=[1, 1, 1], padding=[1, 1, 1], dilation=[1, 1, 1], output_padding=[0, 0, 0], bias=False, algo=ConvAlgo.MaskImplicitGemm)
            (1): BatchNorm1d(16, eps=0.0001, momentum=0.1, affine=True, track_running_stats=True)
          )
          (conv2): SparseSequential(
            (0): SubMConv3d(16, 16, kernel_size=[3, 3, 3], stride=[1, 1, 1], padding=[1, 1, 1], dilation=[1, 1, 1], output_padding=[0, 0, 0], bias=F

(['random_pc',
  'random_pc',
  'random_pc',
  'random_pc',
  'random_pc',
  'random_pc',
  'random_pc',
  'random_pc',
  'random_pc',
  'random_pc'],
 Segmentation(batch_size=10, sem_preds=tensor([5, 4, 0,  ..., 5, 0, 7], device='cuda:0'), sem_labels=None, all_accu=None, pixel_accu=None),
 Instances(valid_mask=tensor([ True,  True, False,  ...,  True, False,  True], device='cuda:0'), sorted_indices=tensor([11525, 11877, 12005, 12027, 12522,    15,    93,   218,   509,   873,
           378,   407,   934,  1000,  1281,  3418,  3590,  3694,  4143,  4343,
          6830,  6976,  7056,  7124,  7612,  8650,  8692,  9280,  9456,  9499,
          9804,  9806,  9869,  9901, 10477,  9932, 10228, 10312, 10554, 11216,
         11827, 12025, 12357, 12470, 12585, 13163, 13761, 13994, 14370, 14528,
         14581], device='cuda:0'), pt_xyz=tensor([[0.5938, 0.4266, 0.8013],
         [0.6248, 0.4266, 0.8155],
         [0.5656, 0.4180, 0.7876],
         [0.6331, 0.3753, 0.7959],
         [0.6415, 0.39

In [8]:
# Scratch check: boolean-mask indexing, and Tensor.to vs. nn.Module.to.
# (torch / nn come from the imports cell at the top of the notebook.)
s = torch.arange(16).reshape(4, 4)
mask = s > 5
print(mask)
print(s[mask])  # masked selection flattens to a 1-D tensor

# Tensor.to() returns a NEW tensor; the result is discarded on purpose, so `s`
# itself stays on the CPU.
s.to("cuda")
print(s.device)

# nn.Module.to() moves the module's parameters in place.
model_ttt = nn.Conv2d(3, 3, 3)
print(model_ttt.weight.device)
model_ttt.to("cuda")
print(model_ttt.weight.device)

tensor([[False, False, False, False],
        [False, False,  True,  True],
        [ True,  True,  True,  True],
        [ True,  True,  True,  True]])
tensor([ 6,  7,  8,  9, 10, 11, 12, 13, 14, 15])
cpu
cpu
cuda:0
