In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR, ReduceLROnPlateau
import torch.nn.functional as F
import torch.nn.init as init
from torch.utils.data import Dataset, DataLoader
from timm.models.layers import DropPath
import numpy as np
from sklearn.metrics import accuracy_score, balanced_accuracy_score
from sklearn.neighbors import KDTree
# from mmcv.ops.group_points import knn, grouping_operation   # knn: k近邻算法   grouping_operation: 使用索引获取特征数据
# from mmcv.ops import gather_points, furthest_point_sample, three_nn, three_interpolate # gather_points: 使用索引获取坐标数据   furthest_point_sample: 最远点采样
from mmcv.ops import three_nn, three_interpolate 
from knn_cuda import KNN
from pointnet2_ops import pointnet2_utils
from glob import glob
import h5py
import os
import json
from tqdm import tqdm
from time import time
import random
# import matplotlib.pyplot as plt
# import laspy
import pickle
from thop import profile   # 计算模型参数量和FLOPs

### DataSet

#### S3DIS

In [2]:
class S3DISDataset(Dataset):
    '''
    S3DIS dataset that serves fixed-size local crops of grid-sub-sampled rooms.

    All selected rooms are loaded once in __init__ (sub-sampled points, colours, labels and a
    pre-built search tree per room). __getitem__ ignores its index: it always crops around the
    currently least-visited point, tracked by a per-point "possibility" score that is raised
    every time a point is returned.

    def __init__(self, root="./data/S3DIS", num_point=40960, sub_grid_size=0.04, split="train", 
                 test_area=5, transform=None, num_layers=4, subsample_ratio=[4, 4, 4, 2], k=16):

    __getitem__ returns:
        query_xyz, feature, label
     #  query_xyz: [num_point, 3]  float32, coordinates relative to the (noisy) crop centre
     #  feature:   [num_point, 7]  float32, concat of relative xyz (3), colour (3), absolute z (1)
     #  label:     [num_point,]    int32
    '''
    def __init__(self, root="./data/S3DIS", num_point=40960, sub_grid_size=0.04, split="train", 
                 test_area=5, transform=None, num_layers=4, subsample_ratio=[4, 4, 4, 2], k=16):
        '''
        root: dataset folder; "<root>/original" is globbed for *.h5 room files and
              "<root>/sub_grid_sample" provides <room>.h5, <room>_KDTree.pkl and (non-train
              splits only) <room>_proj.pkl
        num_point: number of points in every returned crop
        sub_grid_size: grid cell size used for sub-sampling (only stored in this class)
        split: "train" keeps every area except test_area; any other value keeps only test_area
        test_area: area number, parsed from file names such as "Area_5_office_1.h5"
        transform: optional callable (xyz, color, label) -> (xyz, color, label) applied per crop
        num_layers: number of encoder layers (only stored in this class)
        subsample_ratio: per-layer down-sampling ratios (only stored in this class, never mutated)
        k: number of nearest neighbours (only stored in this class)
        '''
        
        self.root = root
        self.num_point = num_point   # number of points per crop
        self.sub_grid_size = sub_grid_size
        self.split = split
        self.transform = transform     # data augmentation
        self.num_layers = num_layers   # number of encoder layers
        self.subsample_ratio = subsample_ratio
        self.k = k   # k nearest neighbours

        self.original_path = root+"/"+"original"
        self.tree_path = root+"/"+"sub_grid_sample"

        self.test_proj = []
        self.test_proj_label = []
        
        self.trees = []
        self.colors = []
        self.labels = []
        
        self.possibility = []
        self.min_possibility = []
        
        # Sorted so that the room order (and therefore seeded sampling) does not depend on the
        # arbitrary order in which the OS lists the directory.
        # e.g. ['Area_1_conferenceRoom_1.h5', 'Area_1_conferenceRoom_2.h5', ..., 'Area_6_pantry_1.h5']
        original_files = sorted(os.path.basename(file) for file in glob(self.original_path+"/*.h5"))
        if split=="train":
            original_files = [file for file in original_files if int(file.split("_")[1]) != test_area]
        else:
            original_files = [file for file in original_files if int(file.split("_")[1]) == test_area]
        
        for original_file in original_files:
            filename = original_file.split(".")[0]   # 'Area_1_conferenceRoom_1'
            
            # Read the sub-sampled point cloud; the context manager closes the file even if a
            # dataset is missing or unreadable.
            with h5py.File(self.tree_path+"/"+filename+".h5", "r") as f:
                sub_points = np.array(f["data"])     # [S, 6]  XYZRGB
                sub_labels = np.array(f["label"])    # [S,]    L
            
            # Read the search tree built from the sub-sampled points.
            # NOTE(security): pickle.load can execute arbitrary code; only load files that were
            # produced by your own preprocessing.
            with open(self.tree_path+"/"+filename+"_KDTree.pkl", "rb") as f:
                search_tree = pickle.load(f)
            
            self.trees.append(search_tree)  # search_tree.data holds the sub-sampled coordinates
            self.colors.append(sub_points[:, 3:6])   # [S, 3]
            self.labels.append(sub_labels)   # [S,]
            
            if split!="train":   # Test
                # Same pickle caveat as above.
                with open(self.tree_path+"/"+filename+"_proj.pkl", "rb") as f:
                    proj_index, proj_labels = pickle.load(f)   # [N,]  [N,]
                self.test_proj.append(proj_index)
                self.test_proj_label.append(proj_labels)
            
        for color in self.colors:
            possi = np.random.rand(color.shape[0]) * 1e-3   # [S,]  range:[0, 0.001)
            self.possibility.append(possi)          # one score per point of the room
            self.min_possibility.append(possi.min())   # lowest score of the room

    def __getitem__(self, index):
        # `index` is unused on purpose: the crop location is driven by the possibility scores.
        
        # Room holding the globally least-visited point, then that point inside the room.
        pc_index = np.argmin(self.min_possibility)
        p_index = np.argmin(self.possibility[pc_index])
        
        tree = self.trees[pc_index]
        points = np.array(tree.data)   # [S, 3]
        center_point = points[p_index, :].reshape(1, -1)   # [1, 3]
        
        # Jitter the centre so that repeated crops are not anchored on exactly the same point.
        noise = np.random.normal(scale=0.35, size=center_point.shape)   # [1, 3]
        pick_point = center_point + noise   # [1, 3]
        
        num_available = points.shape[0]
        if num_available < self.num_point:
            # Room is smaller than a crop: take every point and pad by re-drawing some of them.
            query_index = tree.query(pick_point, k=num_available)[1][0].astype(np.int32)  # [S,]
            supplement_index = np.random.choice(query_index, self.num_point-num_available)
            query_index = np.concatenate([query_index, supplement_index])   # [num_point,]
        else:
            query_index = tree.query(pick_point, k=self.num_point)[1][0].astype(np.int32)   # [num_point,]
        
        # Shuffle so the points are not ordered by distance to the centre.
        shuffle_index = np.arange(len(query_index))
        np.random.shuffle(shuffle_index)
        query_index = query_index[shuffle_index]
        
        # Gather coordinates, colours and labels of the crop.
        query_xyz = points[query_index]   # [num_point, 3]
        query_height = query_xyz[:, -1:]  # [num_point, 1]  absolute z, taken before centring
        query_xyz = query_xyz - pick_point   # [num_point, 3]  relative to the crop centre
        query_color = self.colors[pc_index][query_index]   # [num_point, 3]
        query_label = self.labels[pc_index][query_index]   # [num_point,]
        
        # Raise the score of every returned point: the closer to the centre, the larger the
        # increment (range [0, 1]). Duplicate indices (padding case) are incremented once.
        dists = np.sum(query_xyz**2, axis=-1)   # [num_point,]
        delta = (1 - dists/np.max(dists))**2
        self.possibility[pc_index][query_index] += delta
        self.min_possibility[pc_index] = np.min(self.possibility[pc_index])
        
        # Optional data augmentation
        if self.transform:
            query_xyz, query_color, query_label = self.transform(query_xyz, query_color, query_label)
        
        query_xyz = query_xyz.astype(np.float32)
        feature = np.concatenate([query_xyz, query_color, query_height], axis=-1).astype(np.float32)   # [num_point, 7]
        label = query_label.astype(np.int32)
            
        return query_xyz, feature, label   # [N, 3]  [N, 7]  [N,]
        
    def __len__(self):
        # Nominal epoch length: a fixed number of crops per loaded room.
        if self.split=="train":
            return len(self.trees) * 240
        else:
            return len(self.trees) * 100

### 网络结构

#### FPS

In [3]:
def fps(points, number):
    '''
    Farthest point sampling: choose `number` point indices for every batch element.

    Args:
        points: [B, N, 3] point coordinates
        number: how many points to keep
    return:
        fps_idx: the index tensor produced by pointnet2_utils.furthest_point_sample
                 (presumably [B, number] -- TODO confirm against pointnet2_ops)
    '''
    # Only indices are returned; callers gather coordinates/features themselves.
    return pointnet2_utils.furthest_point_sample(points, number)

#### index_points

In [4]:
def index_points(data, index):
    '''
    Gather the coordinates or features addressed by `index`.

    data:  [B, C, N]
    index: [B, S, K]  integer positions along N (any integer dtype)

    return:
        new_data    # [B, C, S, K], new_data[b, :, s, k] == data[b, :, index[b, s, k]]
    '''
    data = data.permute(0, 2, 1)   # [B, N, C]
    index = index.long()

    B = data.shape[0]

    # Batch selector of shape [B, 1, 1]; it broadcasts against `index` inside the advanced
    # indexing below. It is created directly on the data's device, so no host->device copy and
    # no materialised [B, S, K] repeat are needed.
    batch_index = torch.arange(B, dtype=torch.long, device=data.device).view(
        B, *([1] * (index.dim() - 1))
    )

    new_data = data[batch_index, index]   # [B, S, K, C]

    return new_data.permute(0, 3, 1, 2)   # [B, C, S, K]

#### LPR

In [5]:
def LPR(x, index):   # [B, 3, N]
    '''
    Polar description of every point's neighbourhood.

    For each point the k neighbours given by `index` are gathered and described by their offset
    from the point, the offset length, and the offset's azimuth/elevation angles measured
    relative to the direction from the neighbourhood barycentre to the point.

    Input:
        x: [B, 3, N]      point coordinates
        index: [B, N, k]  neighbour indices
    Returns:
        knn_x         [B, 3, N, k]  neighbour coordinates
        relative_x    [B, 3, N, k]  centre point minus neighbour
        local_rep     [B, 3, N, k]  (distance, azimuth difference, elevation difference)
        exp_x         [B, 1, N, k]  exp(-distance)
        ratio_volume  [B, N]        (max neighbour distance)^3 / (max point norm in the cloud)^3
    '''

    def polar_angles(vec):
        # vec: [B, 3, ...] -> azimuth and elevation, each [B, 1, ...]
        azimuth = torch.atan2(vec[:, 1:2], vec[:, 0:1])
        planar = torch.sqrt(torch.sum(torch.square(vec[:, :2]), dim=1, keepdim=True))
        elevation = torch.atan2(vec[:, 2:3], planar)
        return azimuth, elevation

    knn_x = index_points(x, index)   # [B, 3, N, k]

    # Broadcasting replaces the explicit k-fold repeat of the centre coordinates.
    relative_x = x.unsqueeze(-1) - knn_x   # [B, 3, N, k]

    # Angles and length of every centre->neighbour offset
    alpha, beta = polar_angles(relative_x)   # [B, 1, N, k] each
    relative_dist = torch.sqrt(torch.sum(torch.square(relative_x), dim=1, keepdim=True))   # [B, 1, N, k]

    exp_x = torch.exp(-relative_dist)   # [B, 1, N, k]
    local_volume = torch.pow(torch.max(relative_dist.squeeze(1), dim=-1)[0], 3)   # [B, N]

    # Angles of the barycentre->centre direction; identical for all k neighbours of a point, so
    # they are computed once per point and broadcast over k.
    barycentrer_x = torch.mean(knn_x, dim=-1)   # [B, 3, N]
    direction_relative = (x - barycentrer_x).unsqueeze(-1)   # [B, 3, N, 1]
    barycentrer_alpha, barycentrer_beta = polar_angles(direction_relative)   # [B, 1, N, 1] each

    angle = torch.cat([alpha - barycentrer_alpha, beta - barycentrer_beta], dim=1)   # [B, 2, N, k]

    local_rep = torch.cat([relative_dist, angle], dim=1)   # [B, 3, N, k]

    # Neighbourhood size relative to the extent of the whole cloud (measured from the origin)
    global_dist = torch.sqrt(torch.sum(torch.square(x), dim=1))   # [B, N]
    global_volume = torch.pow(torch.max(global_dist, dim=-1, keepdim=True)[0], 3)   # [B, 1]

    ratio_volume = local_volume / global_volume   # [B, N]

    return knn_x, relative_x, local_rep, exp_x, ratio_volume   # [B, 3, N, k]  [B, 3, N, k]   [B, 3, N, k]  [B, 1, N, k]  [B, N]

#### FlaResMLP

In [10]:
class InvResMLP(nn.Module):
    '''
    Residual point-wise MLP block that mixes neighbour features with positional information,
    expands the channel width by 4 and projects back to the input width.

    def __init__(self, input_channel):

    def forward(self, relative_x, x_encoding, neighbor_index, feature):
        # relative_x      [B, 3, N, k]           neighbour offsets
        # x_encoding      [B, (C//3)*3, N, k]    encoded neighbour offsets
        # neighbor_index  [B, N, k]
        # feature         [B, C, N]

    Returns:
        feature    # [B, C, N]
    '''
    def __init__(self, input_channel):
        super().__init__()

        self.input_channel = input_channel

        # Largest multiple of 3 not exceeding input_channel: the grouped features are later
        # split into three equal channel groups, one per coordinate axis.
        grouped_channel = (input_channel // 3) * 3
        expanded_channel = 4 * input_channel

        self.mlpf = nn.Sequential(
            nn.Conv1d(input_channel, grouped_channel, 1, bias=False),
            nn.BatchNorm1d(grouped_channel),
            nn.LeakyReLU(0.2, True),
        )

        self.mlp1 = nn.Sequential(
            nn.Conv1d(grouped_channel, expanded_channel, 1, bias=False),
            nn.BatchNorm1d(expanded_channel),
            nn.ReLU(True),
        )

        self.mlp2 = nn.Sequential(
            nn.Conv1d(expanded_channel, input_channel, 1, bias=False),
            nn.BatchNorm1d(input_channel),
        )

    def forward(self, relative_x, x_encoding, neighbor_index, feature):   # [B, 3, N, k]  [B, C, N, k]  [B, N, k]  [B, C, N]
        shortcut = feature

        # Project, gather the k neighbours of every point and add the positional encoding.
        projected = self.mlpf(feature)   # [B, C', N]
        grouped = x_encoding + index_points(projected, neighbor_index)   # [B, C', N, k]

        # Weight the channels by the raw offsets: the first C'/3 channels by dx, the next C'/3
        # by dy and the last C'/3 by dz.
        axis_weight = relative_x.repeat_interleave(self.input_channel // 3, dim=1)   # [B, C', N, k]
        grouped = axis_weight * grouped

        # Max-pool over the neighbourhood, then expand and project back.
        pooled = torch.max(grouped, dim=-1)[0]   # [B, C', N]
        out = self.mlp2(self.mlp1(pooled))   # [B, C, N]

        # residual
        return F.relu(out + shortcut, inplace=True)   # [B, C, N]

#### Encoder

In [11]:
class Encoder(nn.Module):
    '''
    One encoder stage: k-NN grouping, local feature aggregation with a residual connection,
    `inverse_num` InvResMLP blocks, and farthest-point down-sampling by `stride`.

    def __init__(self, input_channel, output_channel, k, stride, inverse_num):

    def forward(self, x, feature):   # [B, 3, N]  [B, C, N]

    return:
        sub_feature, sub_x, neighbor_index   # [B, out, S]  [B, 3, S]  [B, N, k],  S = N // stride
        (neighbor_index refers to the points BEFORE down-sampling)
    '''
    def __init__(self, input_channel, output_channel, k, stride, inverse_num):
        super().__init__()
        
        self.k = k
        self.stride = stride

        # One scalar per neighbour edge ("scaling" channel).
        self.mlp_f1 = nn.Sequential(
                        nn.Conv2d(input_channel, 1, 1),
                    )
        
        # Edge embedding that is projected to two channels by Rw ("rotation" channels).
        self.mlp_f2 = nn.Sequential(
                        nn.Conv2d(input_channel, output_channel//2, 1),
                        nn.BatchNorm2d(output_channel//2),
                        nn.ReLU(True)
                    )

        self.Rw = nn.Parameter(torch.randn(output_channel//2, 2))
        
        # Input = 6 positional channels + C edge-feature channels + 1 scaling + 2 rotation.
        self.mlp1 = nn.Sequential(
                        nn.Conv2d(input_channel+9, output_channel, 1),
                        nn.BatchNorm2d(output_channel),
                        nn.ReLU(True)
                    )
        
        self.mlp2 = nn.Sequential(
                        nn.Conv1d(output_channel, output_channel, 1),
                        nn.BatchNorm1d(output_channel),
                    )
        # Projects the stage input to the output width for the residual sum.
        self.res = nn.Sequential(
                        nn.Conv1d(input_channel, output_channel, 1),
                        nn.BatchNorm1d(output_channel),
                    )

        # Encodes neighbour offsets to the channel count InvResMLP groups its features into.
        self.x_mlp = nn.Sequential(
                        nn.Conv2d(3, (output_channel//3)*3, 1),
                        nn.BatchNorm2d((output_channel//3)*3),
                        nn.ReLU(True)
                    )
        
        self.invres = nn.ModuleList()
        for i in range(inverse_num):
            self.invres.append(InvResMLP(output_channel))
        
        self.knn = KNN(k=k, transpose_mode=True)
        
    def forward(self, x, feature):   # [B, 3, N]  [B, C, N]
        _, _, N = x.shape

        # gather_operation is a custom CUDA op; hand it contiguous memory so that a transposed
        # view of the coordinates cannot trip it. This is a no-op for contiguous input.
        x = x.contiguous()
        
        identity = feature   # [B, C, N]

        # Channel-last copy of the coordinates, built once and shared by k-NN and FPS.
        xyz = x.permute(0, 2, 1).contiguous()   # [B, N, 3]
        _, neighbor_index = self.knn(xyz, xyz)  # [B, N, k]

        # grouping
        _, relative_x, local_rep, _, _ = LPR(x, neighbor_index)   # [B, 3, N, k]  [B, 3, N, k]
        knn_feature = index_points(feature, neighbor_index) - feature.unsqueeze(-1)   # [B, C, N, k]
        
        x_enc = torch.cat([relative_x, local_rep], dim=1)   # [B, 6, N, k]

        scaling = self.mlp_f1(knn_feature)   # [B, 1, N, k]
        feature = self.mlp_f2(knn_feature)   # [B, out//2, N, k]
        rotation = (feature.permute(0, 2, 3, 1) @ self.Rw).permute(0, 3, 1, 2)   # [B, 2, N, k]

        feature = torch.cat([x_enc, knn_feature, scaling, rotation], dim=1)   # [B, C+9, N, k]
        feature = self.mlp1(feature)   # [B, out, N, k]
        feature = torch.max(feature, dim=-1)[0]   # [B, out, N]

        # residual
        feature = self.mlp2(feature)   # [B, out, N]
        shortcut = self.res(identity)   # [B, out, N]
        feature = F.relu(feature+shortcut, True)

        x_encoding = self.x_mlp(relative_x)
        for block in self.invres:
            feature = block(relative_x, x_encoding, neighbor_index, feature)

        # sub-sample
        sample_num = N//self.stride
        fps_idx = fps(xyz, sample_num)
        sub_x = pointnet2_utils.gather_operation(x, fps_idx)   # [B, 3, S]
        sub_feature = pointnet2_utils.gather_operation(feature.contiguous(), fps_idx)    # [B, out, S]
        
        return sub_feature, sub_x, neighbor_index   # [B, out, S]  [B, 3, S]  [B, N, k]

#### serialization

In [12]:
def serialization(pos, feat, x_res=None, order="z", layers_outputs=[], grid_size=0.02):
    '''
    Re-order the points of every batch element according to the ordering computed by
    Point.serialization on the voxelised coordinates, and apply the same permutation to all
    tensors attached to the points.

    Args:
        pos: [B, N, 3] point coordinates
        feat: [B, N, C] per-point features
        x_res: optional [B, N, C'] tensor re-ordered alongside feat
        order(str): [ "xyz", "xzy", "yxz", "yzx", "zxy", "zyx", "hilbert", "z", "z-trans" ]
                    (a single name is wrapped into a one-element list before being passed on)
        layers_outputs: list of [B, N, *] tensors; its entries are replaced IN PLACE by their
                        re-ordered versions (the list itself is not returned)
        grid_size: voxel edge length used to quantise pos
    Returns:
        pos, feat, x_res  -- re-ordered; x_res stays None when it was None
    '''
    bs, n_p, _ = pos.size()
    order_names = order if isinstance(order, list) else [order]

    # Quantise to integer voxel coordinates and shift each batch element so its minimum voxel
    # is at the origin.
    grid_coord = torch.floor(pos / grid_size).to(torch.int64)
    grid_coord = grid_coord - grid_coord.min(dim=1, keepdim=True)[0]

    # Batch id of every point in the flattened [B*N] layout: 0,...,0,1,...,1,...
    batch_flat = torch.arange(bs, dtype=torch.int64, device=pos.device).repeat_interleave(n_p)

    # NOTE(review): Point is defined outside this cell; it is assumed to expose
    # `serialized_order` / `serialized_inverse` after `.serialization(...)` -- confirm.
    point = Point(batch=batch_flat, grid_coord=grid_coord.flatten(0, 1))
    point.serialization(order=order_names)

    permutation = point.serialized_order
    inverse_order = point.serialized_inverse   # read for parity with the original; not used here

    def _reorder(tensor):
        # Flatten batch and point axes, permute, and restore the [B, N, *] layout.
        return tensor.flatten(0, 1)[permutation].reshape(bs, n_p, -1).contiguous()

    pos = _reorder(pos)
    feat = _reorder(feat)
    if x_res is not None:
        x_res = _reorder(x_res)

    for i, layer_output in enumerate(layers_outputs):
        layers_outputs[i] = _reorder(layer_output)
    return pos, feat, x_res

#### Mamba Block

In [13]:
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

try:
    from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
except ImportError:
    RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
from timm.models.layers import DropPath

class Block(nn.Module):
    def __init__(
        self, dim, mixer_cls, norm_cls=nn.LayerNorm, fused_add_norm=False,
            residual_in_fp32=False, drop_path=0.
    ):
        """
        Wraps a mixer with a norm layer and a residual connection.

        Unlike the usual pre-norm layout (LN -> mixer -> Add), this block computes
        Add -> LN -> mixer and hands back BOTH the mixer output and the running residual,
        so that the add and the norm can be fused into one kernel.
        [Ref: https://arxiv.org/abs/2002.04745]
        Every block except the first one must therefore be given the residual produced by
        the previous block.

        Args:
            dim: feature width, passed to both mixer_cls and norm_cls
            mixer_cls: callable building the mixer from `dim`
            norm_cls: callable building the norm layer from `dim`
            fused_add_norm: use the fused Triton add+norm kernels
            residual_in_fp32: keep the residual stream in float32
            drop_path: stochastic-depth rate applied to the incoming hidden states
        """
        super().__init__()
        self.residual_in_fp32 = residual_in_fp32
        self.fused_add_norm = fused_add_norm
        self.mixer = mixer_cls(dim)
        self.norm = norm_cls(dim)

        # Stochastic depth on the branch that is added to the residual.
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

        if self.fused_add_norm:
            assert RMSNorm is not None, "RMSNorm import fails"
            assert isinstance(
                self.norm, (nn.LayerNorm, RMSNorm)
            ), "Only LayerNorm and RMSNorm are supported for fused_add_norm"

    def forward(
        self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None
    ):
        r"""Run Add -> Norm -> Mixer.

        Args:
            hidden_states: output of the previous block (or the model input for the first block).
            residual: running residual from the previous block; None for the first block.
            inference_params: forwarded unchanged to the mixer.

        Returns:
            (mixer output, updated residual), where mixer output = Mixer(Norm(residual)).
        """
        if self.fused_add_norm:
            # The fused kernel performs the add and the norm in one pass.
            if isinstance(self.norm, RMSNorm):
                fused_add_norm_fn = rms_norm_fn
            else:
                fused_add_norm_fn = layer_norm_fn
            hidden_states, residual = fused_add_norm_fn(
                self.drop_path(hidden_states),
                self.norm.weight,
                self.norm.bias,
                residual=residual,
                prenorm=True,
                residual_in_fp32=self.residual_in_fp32,
                eps=self.norm.eps,
            )
        else:
            if residual is None:
                # First block: the input itself starts the residual stream.
                residual = hidden_states
            else:
                residual = self.drop_path(hidden_states) + residual
            hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
            if self.residual_in_fp32:
                residual = residual.to(torch.float32)

        hidden_states = self.mixer(hidden_states, inference_params=inference_params)
        return hidden_states, residual

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        # Pure delegation: the cache belongs to the mixer.
        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)

In [14]:
from functools import partial

from mamba_ssm.modules.mamba_simple import Mamba

def create_block(
        d_model,
        ssm_cfg=None,
        norm_epsilon=1e-5,
        rms_norm=False,
        residual_in_fp32=False,
        fused_add_norm=False,
        layer_idx=None,
        drop_path=0.,
        device=None,
        dtype=None,
):
    """Build one `Block` whose mixer is a Mamba layer.

    Args:
        d_model: feature width of the block.
        ssm_cfg: extra keyword arguments for Mamba (None means no extras).
        norm_epsilon: eps of the norm layer.
        rms_norm: use RMSNorm instead of nn.LayerNorm.
        residual_in_fp32, fused_add_norm, drop_path: forwarded to Block.
        layer_idx: forwarded to Mamba and also stored on the block as `block.layer_idx`.
        device, dtype: forwarded to both Mamba and the norm layer.

    Returns:
        The constructed Block.
    """
    factory_kwargs = {"device": device, "dtype": dtype}
    mixer_kwargs = {} if ssm_cfg is None else ssm_cfg

    # Block instantiates both pieces itself with `dim`, so it is handed factories, not modules.
    make_mixer = partial(Mamba, layer_idx=layer_idx, **mixer_kwargs, **factory_kwargs)
    norm_type = RMSNorm if rms_norm else nn.LayerNorm
    make_norm = partial(norm_type, eps=norm_epsilon, **factory_kwargs)

    block = Block(
        d_model,
        make_mixer,
        norm_cls=make_norm,
        fused_add_norm=fused_add_norm,
        residual_in_fp32=residual_in_fp32,
        drop_path=drop_path,
    )
    block.layer_idx = layer_idx
    return block

#### Mamba

In [15]:
def _init_weights(
        module,
        n_layer,
        initializer_range=0.02,  # Now only used for embedding layer.
        rescale_prenorm_residual=True,
        n_residuals_per_layer=1,  # Change to 2 if we have MLP
):
    """Weight initialiser meant to be used with `nn.Module.apply`.

    * nn.Linear: the bias is zeroed unless it carries a truthy `_no_reinit` attribute.
    * nn.Embedding: weights are drawn from N(0, initializer_range^2).
    * If `rescale_prenorm_residual` is set, every parameter of `module` whose name is exactly
      "out_proj.weight" or "fc2.weight" is re-drawn with Kaiming-uniform and divided by
      sqrt(n_residuals_per_layer * n_layer).
    """
    if isinstance(module, nn.Linear):
        bias = module.bias
        if bias is not None and not getattr(bias, "_no_reinit", False):
            nn.init.zeros_(bias)
    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, std=initializer_range)

    if not rescale_prenorm_residual:
        return

    # GPT-2 style scaled initialisation of the layers that write into the residual stream:
    #   > Scale the weights of residual layers at initialization by a factor of 1/sqrt(N),
    #   > where N is the number of residual layers.
    #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
    # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
    residual_out_names = ("out_proj.weight", "fc2.weight")
    for name, p in module.named_parameters():
        if name not in residual_out_names:
            continue
        # Re-draw before scaling: this function may run several times on the same module, and
        # scaling alone would shrink the weights a little more on every call.
        nn.init.kaiming_uniform_(p, a=math.sqrt(5))
        with torch.no_grad():
            p /= math.sqrt(n_residuals_per_layer * n_layer)


In [16]:
import math

class MixerModel(nn.Module):
    """Stack of `n_layer` blocks built by `create_block`, followed by a final norm.

    The blocks follow the Add -> Norm -> Mixer layout: each returns the mixer output together
    with the running residual, and the last add + norm happens in `forward` of this class.
    """
    def __init__(
            self,
            d_model: int,
            n_layer: int,
            ssm_cfg=None,
            norm_epsilon: float = 1e-5,
            rms_norm: bool = False,
            initializer_cfg=None,
            fused_add_norm=False,
            residual_in_fp32=False,
            drop_out_in_block: float = 0.,
            drop_path: float = 0.1,
            device=None,
            dtype=None,
    ) -> None:
        """
        Args:
            d_model: feature width of every block.
            n_layer: number of blocks.
            ssm_cfg: extra keyword arguments forwarded to create_block.
            norm_epsilon: eps of all norm layers.
            rms_norm: use RMSNorm instead of nn.LayerNorm.
            initializer_cfg: extra keyword arguments for `_init_weights` (None means none).
            fused_add_norm: use the fused Triton add+norm kernels (raises ImportError when
                they could not be imported).
            residual_in_fp32: keep the residual stream in float32.
            drop_out_in_block: dropout rate applied to the output of every block.
            drop_path: stochastic-depth rate forwarded to every block.
            device, dtype: forwarded to the blocks and the final norm.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.residual_in_fp32 = residual_in_fp32

        # Residual and norm are re-ordered with respect to the usual pre-norm block:
        # instead of LN -> Mixer -> Add, each block does Add -> LN -> Mixer and returns both
        # the residual branch (output of Add) and the main branch (output of the mixer).
        # The model is mathematically unchanged; the layout allows fusing add + layer_norm.
        self.fused_add_norm = fused_add_norm
        if self.fused_add_norm:
            if layer_norm_fn is None or rms_norm_fn is None:
                raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")

        self.layers = nn.ModuleList(
            [
                create_block(
                    d_model,
                    ssm_cfg=ssm_cfg,
                    norm_epsilon=norm_epsilon,
                    rms_norm=rms_norm,
                    residual_in_fp32=residual_in_fp32,
                    fused_add_norm=fused_add_norm,
                    layer_idx=i,
                    drop_path=drop_path,
                    **factory_kwargs,
                )
                for i in range(n_layer)
            ]
        )

        self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(
            d_model, eps=norm_epsilon, **factory_kwargs
        )

        self.apply(
            partial(
                _init_weights,
                n_layer=n_layer,
                **(initializer_cfg if initializer_cfg is not None else {}),
            )
        )
        # NOTE: self.drop_path is created for parity with the blocks but is not used in forward.
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.drop_out_in_block = nn.Dropout(drop_out_in_block) if drop_out_in_block > 0. else nn.Identity()

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Return {layer position: inference cache allocated by that layer}."""
        return {
            i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
            for i, layer in enumerate(self.layers)
        }

    def forward(self, input_ids, pos, inference_params=None):
        """Run the block stack.

        Args:
            input_ids: input token features (despite the name, these are added to `pos`, so
                they are feature tensors rather than integer ids).
            pos: positional embedding, added to `input_ids` before the first block.
            inference_params: forwarded unchanged to every block.

        Returns:
            Normalised hidden states after the last block.
        """
        hidden_states = input_ids + pos
        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(
                hidden_states, residual, inference_params=inference_params
            )
            hidden_states = self.drop_out_in_block(hidden_states)
        if not self.fused_add_norm:
            # Final Add -> Norm closing the residual stream opened by the blocks.
            residual = (hidden_states + residual) if residual is not None else hidden_states
            hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
        else:
            # Set prenorm=False here since we don't need the residual
            fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
            hidden_states = fused_add_norm_fn(
                hidden_states,
                self.norm_f.weight,
                self.norm_f.bias,
                eps=self.norm_f.eps,
                residual=residual,
                prenorm=False,
                residual_in_fp32=self.residual_in_fp32,
            )

        return hidden_states

#### FeaturePropagation

In [17]:
class FeaturePropagation(nn.Module):
    '''
    Decoder stage: interpolates the features of a sparse point set onto a denser one from the
    three nearest sparse points (inverse-distance weights), concatenates the dense set's own
    features and applies a point-wise MLP.

    def __init__(self, input_channel, mlp)
        input_channel: channels of cat([points1, interpolated points2]), i.e. D + D'
        mlp: output width of every Conv1d + BatchNorm1d + ReLU layer

    def forward(self, coords1, coords2, points1, points2)
        Input:
            dense set (input of the matching encoder stage):   coords1  [B, C, N], points1  [B, D, N]
            sparse set (output of the matching encoder stage): coords2  [B, C, S], points2  [B, D', S]
        Returns:
            new_points  [B, mlp[-1], N]
    '''
    def __init__(self, input_channel, mlp):
        super().__init__()
        self.mlp_conv = nn.ModuleList()
        self.mlp_bn = nn.ModuleList()

        # Consecutive (in, out) channel pairs: input_channel -> mlp[0] -> mlp[1] -> ...
        widths = [input_channel, *mlp]
        for in_channel, out_channel in zip(widths[:-1], widths[1:]):
            self.mlp_conv.append(nn.Conv1d(in_channel, out_channel, 1, bias=False))
            self.mlp_bn.append(nn.BatchNorm1d(out_channel))

    def forward(self, coords1, coords2, points1, points2):
        '''
        Shapes, with the fourth (deepest) stage as the running example:
            coords1: [B, C, N]   dense coordinates     e.g. [B, 3, 64]
            coords2: [B, C, S]   sparse coordinates    e.g. [B, 3, 16]
            points1: [B, D, N]   dense features        e.g. [B, 256, 64]
            points2: [B, D', S]  sparse features       e.g. [B, 512, 16]
        '''
        dense_xyz = coords1.permute(0, 2, 1)     # [B, 64, 3]
        sparse_xyz = coords2.permute(0, 2, 1)    # [B, 16, 3]

        # For every dense point: distances to / indices of its three nearest sparse points.
        dists, index = three_nn(dense_xyz, sparse_xyz)   # [B, 64, 3]  [B, 64, 3]

        # Inverse-distance weights normalised to sum to one; the epsilon guards coincident points.
        inv_dist = 1 / (dists + 1e-8)
        weight = inv_dist / torch.sum(inv_dist, -1, keepdim=True)    # [B, 64, 3]

        interpolated_points = three_interpolate(points2, index, weight)   # [B, 512, 64]

        # Skip connection: dense features first, interpolated sparse features second.
        new_points = torch.cat([points1, interpolated_points], dim=1)   # [B, 768, 64]

        for layer_index in range(len(self.mlp_conv)):
            conv = self.mlp_conv[layer_index]
            bn = self.mlp_bn[layer_index]
            new_points = F.relu(bn(conv(new_points)), inplace=True)

        return new_points   # e.g. [B, 256, 64]

### Random Seed

In [None]:
def setup_seed(seed):
    '''
    Seed every random number generator this notebook draws from.

    The dataset samples crop centres, noise and shuffles through np.random, so seeding torch
    alone does not make a run repeatable; Python's `random` and NumPy are seeded as well.

    Args:
        seed: integer seed
    '''
    random.seed(seed)
    # NumPy only accepts seeds in [0, 2**32); fold larger/negative values into that range so
    # any seed that torch accepts keeps working.
    np.random.seed(seed % (2 ** 32))
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)   # safe to call without CUDA
    # Ask cuDNN for deterministic algorithms and disable its autotuner, which may otherwise
    # select different kernels from run to run.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

### Semantic Seg Task

#### IoU

In [None]:
def caculate_sem_IOU(pred, label, num_classes=13):
    '''
    Per-class intersection-over-union between predicted and ground-truth class ids.

    pred: integer class ids, any shape (e.g. [B x 4096])
    label: integer class ids, same number of elements and same device as pred
    num_classes: number of semantic classes (default 13)

    return:
        float32 CPU tensor [num_classes]; entry c is |pred==c & label==c| / |pred==c | label==c|.
        A class that occurs in neither pred nor label yields NaN (0/0), so aggregate with a
        NaN-aware mean if absent classes should be ignored.
    '''
    pred = pred.reshape(-1)
    label = label.reshape(-1)

    # One row per class: [num_classes, M] membership masks, computed in a single pass instead
    # of a Python loop over classes.
    classes = torch.arange(num_classes, device=pred.device).unsqueeze(1)   # [num_classes, 1]
    pred_mask = pred.unsqueeze(0) == classes
    label_mask = label.unsqueeze(0) == classes

    intersection = torch.logical_and(pred_mask, label_mask).sum(dim=1)   # [num_classes]
    union = torch.logical_or(pred_mask, label_mask).sum(dim=1)           # [num_classes]

    iou = intersection.to(torch.float32) / union.to(torch.float32)
    return iou.cpu()

#### Train & Val

In [None]:
def train():
    """Train the semantic-segmentation network on S3DIS and validate after every epoch.

    Per epoch the function
      1. runs one optimisation pass over the training split,
      2. steps the cosine learning-rate schedule,
      3. evaluates on the test split without gradients,
      4. appends one summary line to ``./PointMamba/train.txt`` and
         ``./PointMamba/test.txt``,
      5. writes ``./PointMamba/semseg_<epoch>.pkl`` whenever the validation mIoU
         improves and always refreshes ``./PointMamba/semseg_last.pkl``.

    Each checkpoint stores the model, optimizer and scheduler state dicts, the
    epoch index and the best validation mIoU seen so far.
    """
    setup_seed(1)

    NUM_CLASSES = 13            # width of the class dimension of the model output
    out_dir = "./PointMamba"    # logs and checkpoints are written here
    # Opening the log files in append mode fails if the directory is missing.
    os.makedirs(out_dir, exist_ok=True)

    transform = Compose([RandomRotate(), RandomScale([0.9, 1.1]), RandomFlip(0.5), RandomJitter(0.005, 0.02),
                         RandomDropColor()])

    train_loader = DataLoader(S3DISDataset(num_point=40960, split="train", transform=transform),
                              batch_size=2, shuffle=True, drop_last=False, num_workers=8, pin_memory=True)
    # Validation must see every sample, in a fixed order: shuffling together
    # with drop_last=True would score each epoch on a different, incomplete
    # subset and make the "best mIoU" comparison below unreliable.
    test_loader = DataLoader(S3DISDataset(num_point=40960, split="test", transform=None),
                             batch_size=4, shuffle=False, drop_last=False, num_workers=8, pin_memory=True)

    device = "cuda" if torch.cuda.is_available() else "cpu"

    START_EPOCH = 0
    EPOCH = 100
    lr = 0.01

    model = semseg_network().to(device)
    optimizer = optim.AdamW(model.parameters(), lr, weight_decay=1e-4)
    # scheduler = MultiStepLR(optimizer, [int(EPOCH*0.6), int(EPOCH*0.8)], gamma=0.1)
    scheduler = CosineAnnealingLR(optimizer, EPOCH, 1e-5)
    criterion = cal_loss

    best_test_iou = 0

#     # Load a checkpoint and resume training.
#     # NOTE(review): this path differs from the directory the checkpoints are
#     # written to below -- point it at the file you actually want to resume from.
#     path = "./AnisoVector/semseg_last.pkl"
#     checkpoint = torch.load(path)
#     model.load_state_dict(checkpoint["model"])
#     optimizer.load_state_dict(checkpoint["optimizer"])
#     scheduler.load_state_dict(checkpoint["scheduler"])
#     best_test_iou = checkpoint["best_test_iou"]
#     START_EPOCH = checkpoint["epoch"] + 1

    for epoch in range(START_EPOCH, EPOCH):
        # ------------------------------ Train ------------------------------
        start_time = time()

        train_loss = 0
        count = 0
        train_pred_seg = []
        train_true_seg = []

        model.train()

        for x, feature, semseg in tqdm(train_loader, total=len(train_loader)):

            B = x.shape[0]

            # The model is fed channel-first tensors, hence the transposes.
            x = x.to(device).transpose(2, 1).float()
            feature = feature.to(device).transpose(2, 1).float()
            # NOTE(review): targets are cast to int32 here -- confirm that
            # cal_loss accepts int32 class ids.
            semseg = semseg.to(device).int()

            optimizer.zero_grad()

            pred_semseg = model(x, feature)

            # Move the class dimension last so it can be flattened per point.
            pred_semseg = pred_semseg.transpose(2, 1).contiguous()

            loss = criterion(pred_semseg.view(-1, NUM_CLASSES), semseg.view(-1))
            loss.backward()
            optimizer.step()

            count = count + B
            # Weight by batch size so the epoch loss is a per-sample average.
            train_loss = train_loss + loss.item() * B

            pred_semseg_class = pred_semseg.max(-1)[1]   # arg-max class id per point

            train_pred_seg.append(pred_semseg_class.cpu().view(-1))
            train_true_seg.append(semseg.cpu().view(-1))

        end_time = time()

        scheduler.step()

        train_pred_seg = torch.cat(train_pred_seg)
        train_true_seg = torch.cat(train_true_seg)

        acc = accuracy_score(train_true_seg, train_pred_seg)
        avg_acc = balanced_accuracy_score(train_true_seg, train_pred_seg)

        iou = caculate_sem_IOU(train_pred_seg, train_true_seg)
        train_miou = torch.mean(iou).item()

        outstr = "Train: %s, loss: %s, acc: %s, avg acc: %s, mIOU: %s, time-consuming: %s" % (
            str(epoch), str(train_loss / count), str(acc), str(avg_acc),
            str(train_miou), str(end_time-start_time))
        print(outstr)
        with open(out_dir + "/train.txt", "a") as f:
            f.write(outstr)
            f.write("\n")

        # ------------------------------- Val -------------------------------
        val_start = time()

        test_loss = 0
        count = 0
        test_pred_seg = []
        test_true_seg = []

        # Nothing in the validation phase (forward passes, metrics, checkpoint
        # writing) needs gradients.
        with torch.no_grad():
            model.eval()
            for x, feature, semseg in test_loader:
                B = x.shape[0]

                x = x.to(device).transpose(2, 1).float()
                feature = feature.to(device).transpose(2, 1).float()
                semseg = semseg.to(device).int()

                pred_semseg = model(x, feature)

                pred_semseg = pred_semseg.transpose(2, 1).contiguous()

                loss = criterion(pred_semseg.view(-1, NUM_CLASSES), semseg.view(-1))

                count = count + B
                test_loss = test_loss + loss.item() * B

                pred_semseg_class = pred_semseg.max(-1)[1]   # arg-max class id per point

                test_pred_seg.append(pred_semseg_class.cpu().view(-1))
                test_true_seg.append(semseg.cpu().view(-1))

            val_end = time()

            test_pred_seg = torch.cat(test_pred_seg)
            test_true_seg = torch.cat(test_true_seg)

            acc = accuracy_score(test_true_seg, test_pred_seg)
            avg_acc = balanced_accuracy_score(test_true_seg, test_pred_seg)

            iou = caculate_sem_IOU(test_pred_seg, test_true_seg)
            # Keep the score as a plain float so the checkpoint does not carry a tensor.
            test_miou = torch.mean(iou).item()

            outstr = "Test: %s, loss: %s, acc: %s, avg acc: %s, mIOU: %s, time-consuming: %s" % (
                str(epoch), str(test_loss / count), str(acc), str(avg_acc),
                str(test_miou), str(val_end-val_start))
            print(outstr)
            with open(out_dir + "/test.txt", "a") as f:
                f.write(outstr)
                f.write("\n")

            improved = test_miou > best_test_iou
            if improved:
                best_test_iou = test_miou

            # Everything required to resume training from this epoch.
            checkpoint = {
                        "model": model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        "epoch": epoch,
                        'scheduler': scheduler.state_dict(),
                        'best_test_iou': best_test_iou
                    }
            if improved:
                # Keep a separate copy for every epoch that set a new best mIoU.
                torch.save(checkpoint, out_dir + '/semseg_' + str(epoch) + ".pkl")
            # Always refresh the latest checkpoint.
            torch.save(checkpoint, out_dir + '/semseg_last.pkl')