In [None]:
import os

import torch
import accelerate

from rd3d.datasets import build_dataloader
from rd3d.models import build_detector
from rd3d.core import Config
from rd3d import PROJECT_ROOT
from rd3d.core.base import ScopeTimer
from rd3d.ops import sfc

# Run from the repository root so the relative config path used in the loading cell resolves.
os.chdir(PROJECT_ROOT)
# NOTE(review): `acc` is not referenced anywhere else in this notebook; presumably it is
# constructed for its side effects (device / process setup) — confirm before removing.
acc = accelerate.Accelerator()

from rd3d.models.dense_heads.point_seg import PointSegmentor


In [None]:

class VoxelGrouper(torch.nn.Module):
    """Orders voxels along two space-filling curves and pads each sample to whole groups.

    Two orderings are produced: one from the voxel coordinates as given and one from
    the coordinates mirrored against ``grid_size``. The batch id is packed into the
    high bits of the curve code, so a single ``argsort`` keeps samples contiguous and
    in batch order.

    Args:
        grid_size: per-axis extent used to mirror the spatial coordinates.
            NOTE(review): assumed to be in the same axis order as ``vox_coors[:, 1:]`` —
            confirm; if the grid is (x, y, z) while coords are (z, y, x), the mirrored
            coordinates can go negative.
        group_size: number of voxels per group; every sample is padded to a multiple of it.
    """

    def __init__(self, grid_size, group_size):
        super().__init__()
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # Non-persistent buffer: follows .to()/.cuda() with the module but stays out of state_dict.
        self.register_buffer('grid_size', torch.as_tensor(grid_size, device=device).clone(), persistent=False)
        self.group_size = group_size

    @staticmethod
    def padding(bs, indices, group_size, bid_num_cum):
        """Pad every sample's columns of ``indices`` up to a multiple of ``group_size``.

        Args:
            bs: number of samples.
            indices: 2-D index tensor whose columns are already grouped by sample;
                sample ``i`` occupies columns ``bid_num_cum[i]:bid_num_cum[i + 1]``.
            group_size: target group length.
            bid_num_cum: cumulative voxel counts of length ``bs + 1``, starting at 0.

        Returns:
            Tensor with the same number of rows where sample ``i`` (``n_i`` columns) is
            followed by ``ceil(n_i / group_size) * group_size - n_i`` filler columns, taken
            cyclically from the end of that sample's own columns.
        """
        elems_num_flat = bid_num_cum[1:] - bid_num_cum[:-1]
        elems_num_group = torch.div(elems_num_flat + group_size - 1, group_size, rounding_mode='floor') * group_size
        elems_num_pad = elems_num_group - elems_num_flat

        indices_padding_list = []
        for i in range(bs):
            start, end = int(bid_num_cum[i]), int(bid_num_cum[i + 1])
            sample = indices[:, start:end]
            indices_padding_list.append(sample)
            n, pad = end - start, int(elems_num_pad[i])
            if pad > 0:
                # Cyclic fill. For pad <= n this equals the trailing `pad` columns; the old
                # `sample[:, -pad:]` silently returned only n columns when pad > n, leaving
                # the sample short of a full group. n > 0 whenever pad > 0.
                fill = torch.arange(n - pad, n, device=indices.device) % n
                indices_padding_list.append(sample[:, fill])
        if not indices_padding_list:
            return indices[:, :0]
        return torch.cat(indices_padding_list, dim=-1)

    @staticmethod
    def indices_convert_from_to(indices2, indices1):
        """For each entry of ``indices1``, return the position of that value in ``indices2``.

        NOTE(review): when ``indices2`` repeats a value (padding fillers), which of its
        positions is kept by the scatter below is unspecified in PyTorch.
        """
        indices2_in_original = torch.empty_like(indices1)
        indices2_in_original[indices2] = torch.arange(indices2.shape[0], device=indices2.device)
        return indices2_in_original[indices1]

    def forward(self, vox_coors, vox_numbs=None, return_groups=False):
        """Sort voxels along the original and the mirrored curve.

        Args:
            vox_coors: (N, 4) integer coordinates; column 0 is the batch id, the other
                three are the spatial coordinates.
            vox_numbs: unused; kept (now optional) for call compatibility.
            return_groups: False (default) returns only the raw sort, as before. True also
                pads each sample to whole groups and builds the cross-ordering maps.

        Returns:
            return_groups=False: (2, N) argsort of the curve codes — row 0 for the original
                coordinates, row 1 for the mirrored ones (assuming the encoder returns
                one code per coordinate row).
            return_groups=True: ``(indices1_group, indices2_group, indices_to_1, indices_to_2)``.
        """
        bid, vox_coors = torch.split(vox_coors, [1, 3], dim=-1)
        bid = bid.view(-1)
        bs_info = torch.nn.functional.pad(torch.bincount(bid), (1, 0), mode='constant', value=0)
        bid_num_cum = torch.cumsum(bs_info, dim=0)
        bs = bid_num_cum.shape[0] - 1

        grid_size = self.grid_size.to(vox_coors.device)
        vox_coors = torch.cat((vox_coors, grid_size - vox_coors)).int()
        # Size the curve from BOTH coordinate sets: mirrored coordinates (grid - c) can
        # exceed the original maximum and would not fit an order derived from it alone.
        order = sfc.min_required_order(vox_coors)
        codes = sfc.hilbert_curve_encoder(vox_coors, order)[None, :].view(2, -1)
        # Shift in the codes' dtype so an int32 batch id cannot overflow.
        codes += bid.to(codes.dtype) << (order * 3)
        indices = torch.argsort(codes, dim=-1)
        if not return_groups:
            return indices

        padded = self.padding(bs, indices, self.group_size, bid_num_cum)
        indices1_group, indices2_group = padded[0], padded[1]
        indices_to_1 = self.indices_convert_from_to(indices2_group, indices1_group)
        indices_to_2 = self.indices_convert_from_to(indices1_group, indices2_group)
        return indices1_group, indices2_group, indices_to_1, indices_to_2


class Layer(torch.nn.Module):
    """Placeholder per-group layer: two projections that are currently identities.

    ``input_channel`` / ``output_channels`` are accepted but not used yet.
    """

    def __init__(self, input_channel, output_channels):
        super().__init__()
        self.linear1 = torch.nn.Identity()
        self.linear2 = torch.nn.Identity()

    def forward(self, x):
        """Apply linear1 row-wise over the flattened groups, then linear2 on the grouped result."""
        num_groups, group_len, channels = x.size()
        rows = x.view(-1, channels)
        local_feats = self.linear1(rows).view(num_groups, group_len, -1)
        # TODO: a per-group global feature (max-pool over the group axis) was sketched here.
        return self.linear2(local_feats)


class Block(torch.nn.Module):
    """Runs a shared layer stack over grouped features twice, re-ordering rows in between.

    Args:
        group_size: rows per group; the flat row count must be a multiple of it.
        cfg: channel list; each consecutive pair ``(cfg[k], cfg[k + 1])`` builds one ``Layer``.
    """

    def __init__(self, group_size, cfg):
        super().__init__()
        stack = [Layer(cfg[k], cfg[k + 1]) for k in range(len(cfg) - 1)]
        self.group_size = group_size
        self.layers = torch.nn.Sequential(*stack)

    def group(self, x):
        """Reshape flat (rows, C) features into (rows // group_size, group_size, C)."""
        channels = x.shape[-1]
        return x.view(-1, self.group_size, channels)

    def flatten(self, x):
        """Collapse grouped features back to (rows, C)."""
        channels = x.shape[-1]
        return x.view(-1, channels)

    def forward(self, x1, i_to, i_from):
        """Process ``x1`` in its current order, re-order it with ``i_to``, and process again.

        Returns:
            ``(x2, i_from, i_to)`` — the index pair is handed back with roles exchanged
            so the caller can feed it straight into the next block.
        """
        first_pass = self.flatten(self.layers(self.group(x1)))
        reordered = first_pass[i_to]
        second_pass = self.layers(self.group(reordered))
        return self.flatten(second_pass), i_from, i_to


class CurveBackBone(torch.nn.Module):
    """A chain of ``Block``s, one per entry of ``mlps``, threading the index pair through."""

    def __init__(self, group_size, mlps):
        super().__init__()
        self.blocks = torch.nn.ModuleList(Block(group_size, mlp) for mlp in mlps)

    def forward(self, x, i_to, i_from):
        """Feed features through every block; each block returns the index pair to use next."""
        feats, fwd_idx, bwd_idx = x, i_to, i_from
        for blk in self.blocks:
            feats, fwd_idx, bwd_idx = blk(feats, fwd_idx, bwd_idx)
        return feats

In [None]:



# Build the detector from the voxformer KITTI config and move one batch to the GPU.
# The path is relative: it relies on the working directory having been set to PROJECT_ROOT.
CFG_PATH = "configs/voxformer/voxformer_4x2_80e_kitti_3cls.py"
cfg = Config.fromfile_py(CFG_PATH)
dataloader = build_dataloader(cfg.DATASET, cfg.RUN)
model = build_detector(cfg.MODEL, dataset=dataloader.dataset).cuda()
vfe = model.vfe  # the cells below only exercise the voxel feature encoder
batch_dict = next(iter(dataloader))
dataloader.dataset.load_data_to_gpu(batch_dict)

In [None]:
group_size = 32  # voxels per curve group
grouper = VoxelGrouper(vfe.grid_size, group_size)

# NOTE: this feeds batch_dict back into vfe and rebinds the result to the same name, so
# re-running this cell alone re-encodes an already-processed batch; re-run the loading
# cell first.
batch_dict = vfe(batch_dict)
vox_feats = batch_dict['voxel_features']
vox_coors = batch_dict['voxel_coords']
vox_numbs = batch_dict['voxel_numbers']

In [None]:
from rd3d.ops import sfc

# Micro-benchmark of the pre-encoding steps: curve-order estimation and mirroring of
# the voxel coordinates against the grid.
group_size = 32
# Mirror extent per column of vox_coors; column 0 (batch id) is never mirrored.
grid_size = torch.tensor((0, 2000, 2000, 2000), device=vox_coors.device)
for _ in range(1000):
    with ScopeTimer("", average=True, verbose=False) as t:
        order = sfc.min_required_order(vox_coors)
        vox_coors_mirror = vox_coors.detach().clone()
        # Mirror against the grid extent. The old code subtracted from group_size (32),
        # which produced negative coordinates.
        vox_coors_mirror[:, 1:] = grid_size[1:] - vox_coors_mirror[:, 1:]
        # Bind to a new name. Assigning back to vox_coors doubled the tensor on every
        # iteration (exponential growth -> OOM) and clobbered vox_coors for later cells.
        vox_coors_both = torch.cat((vox_coors, vox_coors_mirror)).int()
        # codes = sfc.hilbert_curve_encoder(vox_coors_both, order)[None, :].view(2, -1)
        # indices = torch.argsort(codes, dim=-1)
        # out = sfc.indices_grouping(indices, vox_numbs, group_size)

# NOTE(review): `t` is the last iteration's timer; this assumes ScopeTimer with
# average=True reports the average across iterations — confirm.
print(t.duration)

In [None]:
# from rd3d.ops.sfc import indices_grouping
#
# for i in range(10000):
#     indices_grouping(indices, vox_numbs, group_size)
# for i in range(10000):
#     with ScopeTimer("grouping:", average=True, verbose=False) as t:
#         indices_grouping(indices, vox_numbs, group_size)
#
# print(t.duration)

In [None]:
# print(indices.shape)
# backbone = CurveBackBone(32, mlps).cuda()
#
# ind1, ind2, ind21, ind12 = grouper(vox_coors)
#
# vox_coors = vox_coors[ind1]
# vox_feats = vox_feats[ind1]
# vox_feats = backbone(vox_feats, ind12, ind21)
# seg = PointSegmentor(c, 3, [c, c]).cuda()
# seg(vox_coors, vox_feats)
# seg.assign_targets(batch_dict)