<a href="https://colab.research.google.com/github/Hannah1123/Learning_Attention_is_all_you_need/blob/master/pointnet2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim.lr_scheduler as lr_sched

from torch.utils.data import DataLoader, DistributedSampler
from torchvision import transforms

In [None]:
!pip install pytorch-lightning

Collecting pytorch-lightning
[?25l  Downloading https://files.pythonhosted.org/packages/61/9e/db4e1e3036e045a25d5c37617ded31a673a61f4befc62c5231818810b3a7/pytorch-lightning-0.7.1.tar.gz (6.0MB)
[K     |████████████████████████████████| 6.0MB 2.6MB/s 
Collecting future>=0.17.1
[?25l  Downloading https://files.pythonhosted.org/packages/45/0b/38b06fd9b92dc2b68d58b75f900e97884c45bedd2ff83203d933cf5851c9/future-0.18.2.tar.gz (829kB)
[K     |████████████████████████████████| 829kB 41.8MB/s 
Building wheels for collected packages: pytorch-lightning, future
  Building wheel for pytorch-lightning (setup.py) ... [?25l[?25hdone
  Created wheel for pytorch-lightning: filename=pytorch_lightning-0.7.1-cp36-none-any.whl size=145306 sha256=b91950c33dafa2922570c16134640063707f6daefcf5a3eb42493cd015644b94
  Stored in directory: /root/.cache/pip/wheels/dc/93/61/14094d2116ff739513dda993007501ae5701b78386b39d5912
  Building wheel for future (setup.py) ... [?25l[?25hdone
  Created wheel for future: 

In [None]:
import pytorch_lightning as pl

In [None]:
from pointnet2_ops.pointnet2_modules import PointnetFPModule, PointnetSAModule

In [None]:
lr_clip = 1e-5  # floor for the learning rate produced by the LR schedule
bnm_clip = 1e-2  # floor for the BatchNorm momentum produced by the BN schedule


class PointNet2ClassificationSSG(pl.LightningModule):
    """PointNet++ single-scale-grouping (SSG) classifier with 40 output classes.

    ``args`` is stored as ``self.hparams``. The methods below read
    ``model.use_xyz``, ``batch_size``, ``num_points`` and
    ``optimizer.{lr, lr_decay, bn_momentum, bnm_decay, decay_step,
    weight_decay}`` from it.
    """

    def __init__(self, args):
        super().__init__()

        self.hparams = args

        self._build_model()

    def _build_model(self):
        """Create three set-abstraction levels followed by an MLP classifier head."""
        self.SA_modules = nn.ModuleList()
        self.SA_modules.append(
            PointnetSAModule(
                npoint=512,
                radius=0.2,
                nsample=64,
                mlp=[3, 64, 64, 128],
                use_xyz=self.hparams.model.use_xyz,
            )
        )
        self.SA_modules.append(
            PointnetSAModule(
                npoint=128,
                radius=0.4,
                nsample=64,
                mlp=[128, 128, 128, 256],
                use_xyz=self.hparams.model.use_xyz,
            )
        )
        # No npoint: the last level groups all remaining points into one region.
        self.SA_modules.append(
            PointnetSAModule(
                mlp=[256, 256, 512, 1024], use_xyz=self.hparams.model.use_xyz
            )
        )

        self.fc_layer = nn.Sequential(
            nn.Linear(1024, 512, bias=False),
            nn.BatchNorm1d(512),
            nn.ReLU(True),
            nn.Linear(512, 256, bias=False),
            nn.BatchNorm1d(256),
            nn.ReLU(True),
            nn.Dropout(0.5),
            nn.Linear(256, 40),
        )

    def _break_up_pc(self, pc):
        """Split a (B, N, 3 + C) cloud into xyz (B, N, 3) and features (B, C, N) or None."""
        xyz = pc[..., 0:3].contiguous()
        features = pc[..., 3:].transpose(1, 2).contiguous() if pc.size(-1) > 3 else None

        return xyz, features

    def forward(self, pointcloud):
        r"""
            Forward pass of the network
            Parameters
            ----------
            pointcloud: Variable(torch.cuda.FloatTensor)
                (B, N, 3 + input_channels) tensor
                Point cloud to run predicts on
                Each point in the point-cloud MUST
                be formated as (x, y, z, features...)
            Returns
            -------
            (B, 40) tensor of class logits
        """
        xyz, features = self._break_up_pc(pointcloud)

        for module in self.SA_modules:
            xyz, features = module(xyz, features)

        # Drop the trailing point axis left by the final group-all level
        # (presumably size 1 — confirm against the SA module in use).
        return self.fc_layer(features.squeeze(-1))

    def training_step(self, batch, batch_idx):
        """Cross-entropy loss on one training batch; accuracy is computed without gradients."""
        pc, labels = batch

        logits = self.forward(pc)
        loss = F.cross_entropy(logits, labels)
        with torch.no_grad():
            acc = (torch.argmax(logits, dim=1) == labels).float().mean()

        log = dict(train_loss=loss, train_acc=acc)

        return dict(loss=loss, log=log, progress_bar=dict(train_acc=acc))

    def validation_step(self, batch, batch_idx):
        """Return the loss and accuracy of one validation batch."""
        pc, labels = batch

        logits = self.forward(pc)
        loss = F.cross_entropy(logits, labels)
        acc = (torch.argmax(logits, dim=1) == labels).float().mean()

        return dict(val_loss=loss, val_acc=acc)

    def validation_end(self, outputs):
        """Average every per-batch validation metric over the epoch.

        Returns the averaged metrics, plus ``log`` and ``progress_bar`` copies
        of them. Returns an empty dict when no validation batch produced output.
        """
        if not outputs:
            # Nothing to average; indexing outputs[0] would raise IndexError.
            return {}

        # One stack per metric: linear in the number of batches, unlike
        # repeatedly concatenating lists.
        reduced_outputs = {
            k: torch.stack([o[k] for o in outputs]).mean() for k in outputs[0]
        }

        reduced_outputs.update(
            dict(log=reduced_outputs.copy(), progress_bar=reduced_outputs.copy())
        )

        return reduced_outputs

    def _decay_exponent(self):
        """Number of completed decay periods, counted in training samples seen."""
        return int(
            self.global_step
            * self.hparams.batch_size
            / self.hparams.optimizer.decay_step
        )

    def configure_optimizers(self):
        """Adam with stepwise exponential LR decay and a matching BN-momentum decay.

        The LR factor is floored at ``lr_clip / lr`` (so the LR never drops
        below ``lr_clip``), and the BN momentum is floored at ``bnm_clip``.
        Both schedules read ``self.global_step`` lazily, so they ignore the
        epoch argument the schedulers pass in.
        """

        def lr_lbmd(_):
            # LambdaLR multiplies the base learning rate by this factor.
            return max(
                self.hparams.optimizer.lr_decay ** self._decay_exponent(),
                lr_clip / self.hparams.optimizer.lr,
            )

        def bn_lbmd(_):
            return max(
                self.hparams.optimizer.bn_momentum
                * self.hparams.optimizer.bnm_decay ** self._decay_exponent(),
                bnm_clip,
            )

        optimizer = torch.optim.Adam(
            self.parameters(),
            lr=self.hparams.optimizer.lr,
            weight_decay=self.hparams.optimizer.weight_decay,
        )
        lr_scheduler = lr_sched.LambdaLR(optimizer, lr_lambda=lr_lbmd)
        bnm_scheduler = BNMomentumScheduler(self, bn_lambda=bn_lbmd)

        return [optimizer], [lr_scheduler, bnm_scheduler]

    def _build_dataloader(self, mode="train"):
        """Build a ModelNet40 loader; augmentation, shuffling and drop_last apply only to training."""
        is_train = mode == "train"

        # The augmentation pipeline is only needed (and only built) for training.
        train_transforms = (
            transforms.Compose(
                [
                    d_utils.PointcloudToTensor(),
                    d_utils.PointcloudScale(),
                    d_utils.PointcloudRotate(),
                    d_utils.PointcloudRotatePerturbation(),
                    d_utils.PointcloudTranslate(),
                    d_utils.PointcloudJitter(),
                    d_utils.PointcloudRandomInputDropout(),
                ]
            )
            if is_train
            else None
        )

        dset = ModelNet40Cls(
            self.hparams.num_points,
            transforms=train_transforms,
            train=is_train,
        )
        return DataLoader(
            dset,
            batch_size=self.hparams.batch_size,
            shuffle=is_train,
            num_workers=4,
            pin_memory=True,
            drop_last=is_train,
        )

    @pl.data_loader
    def train_dataloader(self):
        return self._build_dataloader(mode="train")

    @pl.data_loader
    def val_dataloader(self):
        return self._build_dataloader(mode="val")



In [None]:
from typing import List, Optional, Tuple
def build_shared_mlp(mlp_spec: List[int], bn: bool = True):
    """Stack 1x1 ``Conv2d`` [+ ``BatchNorm2d``] + ``ReLU`` groups.

    One group is emitted for each consecutive channel pair
    ``(mlp_spec[k], mlp_spec[k + 1])``. The conv bias is dropped when batch
    norm follows it, because BN's own shift makes the bias redundant.
    Specs with fewer than two entries give an empty ``nn.Sequential``.
    """
    modules = []
    for in_ch, out_ch in zip(mlp_spec[:-1], mlp_spec[1:]):
        modules.append(nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=not bn))
        if bn:
            modules.append(nn.BatchNorm2d(out_ch))
        modules.append(nn.ReLU(True))
    return nn.Sequential(*modules)


class _PointnetSAModuleBase(nn.Module):
    def __init__(self):
        super(_PointnetSAModuleBase, self).__init__()
        self.npoint = None
        self.groupers = None
        self.mlps = None

    def forward(
        self, xyz: torch.Tensor, features: Optional[torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""
        Parameters
        ----------
        xyz : torch.Tensor
            (B, N, 3) tensor of the xyz coordinates of the features (1*1024*3)
        features : torch.Tensor
            (B, C, N) tensor of the descriptors of the the features  (1*C*1024)
        Returns
        -------
        new_xyz : torch.Tensor
            (B, npoint, 3) tensor of the new features' xyz  (1*512*3)
        new_features : torch.Tensor
            (B,  \sum_k(mlps[k][-1]), npoint) tensor of the new_features descriptors (1*40*512)
        """

        new_features_list = []

        xyz_flipped = xyz.transpose(1, 2).contiguous()
        new_xyz = (
            pointnet2_utils.gather_operation(
                xyz_flipped, pointnet2_utils.furthest_point_sample(xyz, self.npoint)
            )
            .transpose(1, 2)
            .contiguous()
            if self.npoint is not None
            else None
        )

        for i in range(len(self.groupers)):
            new_features = self.groupers[i](
                xyz, new_xyz, features
            )  # (B, C, npoint, nsample)

            new_features = self.mlps[i](new_features)  # (B, mlp[-1], npoint, nsample)
            new_features = F.max_pool2d(
                new_features, kernel_size=[1, new_features.size(3)]
            )  # (B, mlp[-1], npoint, 1)
            new_features = new_features.squeeze(-1)  # (B, mlp[-1], npoint)

            new_features_list.append(new_features)

        return new_xyz, torch.cat(new_features_list, dim=1)


class PointnetSAModuleMSG(_PointnetSAModuleBase):
    r"""Pointnet set abstraction layer with multiscale grouping
    Parameters
    ----------
    npoint : int or None
        Number of centroids to sample; None selects ``GroupAll`` groupers
    radii : list of float32
        list of radii to group with
    nsamples : list of int32
        Number of samples in each ball query
    mlps : list of list of int32
        Spec of the pointnet before the global max_pool for each scale.
        The specs are copied, never modified in place.
    bn : bool
        Use batchnorm
    use_xyz : bool
        When true, 3 is added to each copied spec's input width (presumably
        because the grouper appends xyz to the features — confirm against
        pointnet2_utils)
    """

    def __init__(self, npoint, radii, nsamples, mlps, bn=True, use_xyz=True):
        # type: (PointnetSAModuleMSG, int, List[float], List[int], List[List[int]], bool, bool) -> None
        super(PointnetSAModuleMSG, self).__init__()

        assert len(radii) == len(nsamples) == len(mlps)

        self.npoint = npoint
        self.groupers = nn.ModuleList()
        self.mlps = nn.ModuleList()
        for radius, nsample, spec in zip(radii, nsamples, mlps):
            self.groupers.append(
                pointnet2_utils.QueryAndGroup(radius, nsample, use_xyz=use_xyz)
                if npoint is not None
                else pointnet2_utils.GroupAll(use_xyz)
            )
            # Copy before adjusting: mutating the caller's list would add 3
            # channels again every time the same spec is reused.
            mlp_spec = list(spec)
            if use_xyz:
                mlp_spec[0] += 3

            self.mlps.append(build_shared_mlp(mlp_spec, bn))

In [None]:
# PointnetSAModule: single-scale set-abstraction layer (one radius, one MLP).
class PointnetSAModule(PointnetSAModuleMSG):
    r"""Pointnet set abstraction layer (single scale)
    Parameters
    ----------
    npoint : int
        Number of centroids to sample; None groups all points together
    radius : float
        Radius of the ball query around each centroid
    nsample : int
        Number of points gathered inside each ball
    mlp : list
        Spec of the pointnet before the global max_pool (not modified)
    bn : bool
        Use batchnorm
    use_xyz : bool
        Forwarded to PointnetSAModuleMSG
    """

    def __init__(
        self, mlp, npoint=None, radius=None, nsample=None, bn=True, use_xyz=True
    ):
        # type: (PointnetSAModule, List[int], int, float, int, bool, bool) -> None
        super(PointnetSAModule, self).__init__(
            # Copy: the parent may adjust the first channel of the spec in place.
            mlps=[list(mlp)],
            npoint=npoint,
            radii=[radius],
            nsamples=[nsample],
            bn=bn,
            use_xyz=use_xyz,
        )

In [None]:
class PointnetFPModule(nn.Module):
    r"""Propagates the features of one set to another
    Parameters
    ----------
    mlp : list
        Pointnet module parameters
    bn : bool
        Use batchnorm
    """

    def __init__(self, mlp, bn=True):
        # type: (PointnetFPModule, List[int], bool) -> None
        super(PointnetFPModule, self).__init__()
        self.mlp = build_shared_mlp(mlp, bn=bn)

    def forward(self, unknown, known, unknow_feats, known_feats):
        # type: (PointnetFPModule, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor
        r"""
        Parameters
        ----------
        unknown : torch.Tensor
            (B, n, 3) tensor of the xyz positions of the unknown features
        known : torch.Tensor
            (B, m, 3) tensor of the xyz positions of the known features, or
            None to broadcast ``known_feats`` to every unknown point
        unknow_feats : torch.Tensor
            (B, C1, n) tensor of the features to be propagated to, or None
        known_feats : torch.Tensor
            (B, C2, m) tensor of features to be propagated
        Returns
        -------
        new_features : torch.Tensor
            (B, mlp[-1], n) tensor of the features of the unknown features
        """

        if known is not None:
            # Inverse-distance weights over the three nearest known points.
            dist, idx = pointnet2_utils.three_nn(unknown, known)
            dist_recip = 1.0 / (dist + 1e-8)
            norm = torch.sum(dist_recip, dim=2, keepdim=True)
            weight = dist_recip / norm

            interpolated_feats = pointnet2_utils.three_interpolate(
                known_feats, idx, weight
            )
        else:
            # Broadcast along the point axis. torch.Size is a tuple, so it
            # cannot be concatenated with a list; pass the sizes explicitly.
            # expand only succeeds when known_feats' last dim is 1 (or n).
            interpolated_feats = known_feats.expand(
                known_feats.size(0), known_feats.size(1), unknown.size(1)
            )

        if unknow_feats is not None:
            new_features = torch.cat(
                [interpolated_feats, unknow_feats], dim=1
            )  # (B, C2 + C1, n)
        else:
            new_features = interpolated_feats

        # The shared MLP is built from Conv2d layers, so add a dummy width axis.
        new_features = new_features.unsqueeze(-1)
        new_features = self.mlp(new_features)

        return new_features.squeeze(-1)

In [None]:
class PointNet2ClassificationSSG(pl.LightningModule):
    """PointNet++ SSG classifier: model definition only (40 output classes).

    NOTE(review): running this cell rebinds ``PointNet2ClassificationSSG`` to
    a class with no training/validation/optimizer hooks; the earlier
    definition in this notebook is shadowed.
    """

    def __init__(self, args):
        super().__init__()
        self.hparams = args
        self._build_model()

    def _build_model(self):
        use_xyz = self.hparams.model.use_xyz
        # (npoint, radius, nsample, mlp) for the two sampling levels.
        local_levels = (
            (512, 0.2, 64, [3, 64, 64, 128]),
            (128, 0.4, 64, [128, 128, 128, 256]),
        )
        sa_layers = [
            PointnetSAModule(
                npoint=n_centroids,
                radius=ball_radius,
                nsample=n_neighbours,
                mlp=channels,
                use_xyz=use_xyz,
            )
            for n_centroids, ball_radius, n_neighbours, channels in local_levels
        ]
        # The last level passes no npoint, so it groups every remaining point.
        sa_layers.append(PointnetSAModule(mlp=[256, 256, 512, 1024], use_xyz=use_xyz))
        self.SA_modules = nn.ModuleList(sa_layers)

        head = [
            nn.Linear(1024, 512, bias=False),
            nn.BatchNorm1d(512),
            nn.ReLU(True),
            nn.Linear(512, 256, bias=False),
            nn.BatchNorm1d(256),
            nn.ReLU(True),
            nn.Dropout(0.5),
            nn.Linear(256, 40),
        ]
        self.fc_layer = nn.Sequential(*head)

    def _break_up_pc(self, pc):
        """Return xyz (B, N, 3) and features (B, C, N), or None when there are no extra channels."""
        xyz = pc[..., :3].contiguous()
        if pc.size(-1) <= 3:
            return xyz, None
        return xyz, pc[..., 3:].transpose(1, 2).contiguous()

    def forward(self, pointcloud):
        r"""
            Run the classifier on a point cloud.
            Parameters
            ----------
            pointcloud: torch.Tensor
                (B, N, 3 + input_channels) tensor; each point is laid out
                as (x, y, z, features...)
            Returns
            -------
            (B, 40) tensor of class logits
        """
        xyz, features = self._break_up_pc(pointcloud)
        for sa_module in self.SA_modules:
            xyz, features = sa_module(xyz, features)
        # Drop the trailing point axis left by the final group-all level.
        return self.fc_layer(features.squeeze(-1))

In [None]:
# Dummy inputs (presumably for smoke-testing the classifier): 2 clouds of 512
# xyz points, and one integer label in [0, 40) per cloud.
dummy_point_cloud = torch.rand(2, 512, 3)
dummy_object_classes = torch.randint(high=40, size=(2,))

In [None]:
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
def square_distance(src, dst):
    """Return the pairwise squared Euclidean distance between two point sets.

    Uses the expansion ||s - d||^2 = ||s||^2 + ||d||^2 - 2 * s.d, so the
    whole N x M matrix comes from one batched matmul plus two broadcasts.
    Results can be slightly negative for coincident points because of
    floating-point cancellation.

    Input:
        src: source points, [B, N, C]
        dst: target points, [B, M, C]
    Output:
        dist: per-point square distance, [B, N, M]
    """
    cross = torch.matmul(src, dst.transpose(1, 2))  # [B, N, M]
    src_sq = src.pow(2).sum(dim=-1, keepdim=True)  # [B, N, 1]
    dst_sq = dst.pow(2).sum(dim=-1).unsqueeze(1)  # [B, 1, M]
    return -2 * cross + src_sq + dst_sq

In [None]:
# Two random point clouds (batch 1, 512 points, xyz) for exercising square_distance.
a0 = torch.rand(1, 512, 3)
b0 = torch.rand(1, 512, 3)


In [None]:
# Squared distance from every point of a0 to every point of b0.
dist0 = square_distance(a0, b0)

In [None]:
# Expected [B, N, M] = [1, 512, 512].
dist0.shape

torch.Size([1, 512, 512])

In [None]:
# The two sets may differ in size: N=512 sources vs M=64 targets -> [1, 512, 64].
b1 = torch.rand(1, 64, 3)
dist1 = square_distance(a0, b1)
dist1.shape

torch.Size([1, 512, 64])

In [None]:

# Summing over the last (coordinate) axis removes it: [1, 512, 3] -> [1, 512].
torch.sum(a0, -1).shape

torch.Size([1, 512])

In [None]:
def index_points(points, idx):
    """Gather points per batch element with an arbitrary-rank index tensor.

    Input:
        points: input points data, [B, N, C]
        idx: sample index data, [B, S] or [B, S, K, ...]; values index the N axis
    Return:
        new_points: indexed points data, [B, S, C] (or [B, S, K, ..., C])
    """
    batch_size = idx.shape[0]
    # Shape [B, 1, 1, ...] so each batch id lines up with its row of idx.
    bcast_shape = [batch_size] + [1] * (idx.dim() - 1)
    batch_ids = (
        torch.arange(points.shape[0], dtype=torch.long)
        .to(points.device)
        .view(bcast_shape)
        .repeat([1] + list(idx.shape[1:]))
    )
    # Advanced indexing pairs batch_ids[...] with idx[...] elementwise.
    return points[batch_ids, idx, :]

In [None]:
# List repetition, as used for the trailing 1s of view_shape in index_points.
[1]*2

[1, 1]

In [None]:
# Small tensor for experimenting with view/repeat (the broadcasting trick in index_points).
a = torch.arange(5)
print('a：', a)

a： tensor([0, 1, 2, 3, 4])


In [None]:
# Reshape to (5, 1, 1): one value per leading index, singleton trailing dims.
a.view(5, 1, 1)

tensor([[[0]],

        [[1]],

        [[2]],

        [[3]],

        [[4]]])

In [None]:
# repeat tiles each dim by its factor: (5, 1, 1) -> (5, 2, 3).
a.view(5, 1, 1).repeat(1, 2, 3)

tensor([[[0, 0, 0],
         [0, 0, 0]],

        [[1, 1, 1],
         [1, 1, 1]],

        [[2, 2, 2],
         [2, 2, 2]],

        [[3, 3, 3],
         [3, 3, 3]],

        [[4, 4, 4],
         [4, 4, 4]]])

In [None]:
# Step-by-step replay of index_points with B=3 and 512 indices per batch row.
idx0 = torch.arange(512).repeat(3, 1)
view_shape = list(idx0.shape)
print(view_shape)

[3, 512]


In [None]:
# Keep the batch dim and set every trailing dim to 1 -> [3, 1].
view_shape[1:] = [1] * (len(view_shape) - 1)
print(view_shape)

[3, 1]


In [None]:
# Keep the trailing dims and set the batch dim to 1 -> [1, 512].
repeat_shape = list(idx0.shape)
repeat_shape[0] = 1
print('repeat_shape', repeat_shape)

repeat_shape [1, 512]


In [None]:
# [3, 1] tiled by [1, 512]: every row holds its own batch index -> [3, 512].
batch_indices = torch.arange(3, dtype=torch.long).view(view_shape).repeat(repeat_shape)
print('batch_indices', batch_indices)
print(batch_indices.shape)

batch_indices tensor([[0, 0, 0,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 1, 1, 1],
        [2, 2, 2,  ..., 2, 2, 2]])
torch.Size([3, 512])


In [None]:
# Both index tensors are already int64 (arange's default / explicit dtype), so .long() is a no-op.
batch_indices.long(), idx0.long()

(tensor([[0, 0, 0,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 1, 1, 1],
         [2, 2, 2,  ..., 2, 2, 2]]),
 tensor([[  0,   1,   2,  ..., 509, 510, 511],
         [  0,   1,   2,  ..., 509, 510, 511],
         [  0,   1,   2,  ..., 509, 510, 511]]))

In [None]:
# Advanced indexing pairs batch_indices[b, s] with idx0[b, s], selecting
# points0[b, idx0[b, s], :] -> new_points has shape [3, 512, 3].
points0 = torch.rand(3, 1024, 3)
new_points = points0[batch_indices, idx0, :]
print('points0', points0)
print('new_points', new_points)

points0 tensor([[[0.6989, 0.9967, 0.0904],
         [0.7578, 0.9562, 0.7620],
         [0.5944, 0.6421, 0.4615],
         ...,
         [0.9596, 0.6343, 0.1422],
         [0.8262, 0.3135, 0.7420],
         [0.1190, 0.4086, 0.6637]],

        [[0.4742, 0.6821, 0.8652],
         [0.7306, 0.0478, 0.6693],
         [0.2708, 0.3941, 0.0818],
         ...,
         [0.8244, 0.9031, 0.5429],
         [0.0101, 0.4054, 0.8145],
         [0.8775, 0.2921, 0.9876]],

        [[0.4233, 0.5617, 0.6352],
         [0.9891, 0.0732, 0.9106],
         [0.6275, 0.7898, 0.8102],
         ...,
         [0.6580, 0.8335, 0.8425],
         [0.2947, 0.7030, 0.7679],
         [0.5393, 0.8973, 0.8553]]])
new_points tensor([[[0.6989, 0.9967, 0.0904],
         [0.7578, 0.9562, 0.7620],
         [0.5944, 0.6421, 0.4615],
         ...,
         [0.5953, 0.0190, 0.9683],
         [0.0036, 0.2304, 0.4267],
         [0.9336, 0.9481, 0.4985]],

        [[0.4742, 0.6821, 0.8652],
         [0.7306, 0.0478, 0.6693],
       

In [None]:
def farthest_point_sample(xyz, npoint):
    """Select ``npoint`` well-spread point indices by farthest point sampling.

    The first index of each batch element is drawn at random. Each later pick
    is the point whose distance to its *nearest* already-selected point is
    largest, i.e. farthest from the whole selected set, not just from the
    previous pick. So two points being "mutually farthest" never causes a
    cycle: a point already selected has distance 0 to the set.

    Input:
        xyz: pointcloud data, [B, N, 3]
        npoint: number of samples
    Return:
        centroids: sampled pointcloud index, [B, npoint]
    """
    device = xyz.device
    B, N, _ = xyz.shape
    centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)
    # min_dist[b, n]: squared distance from point n to its nearest selected point.
    min_dist = torch.ones(B, N).to(device) * 1e10
    # Random start per batch element, drawn on the default CPU generator then moved.
    farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
    rows = torch.arange(B, dtype=torch.long).to(device)
    for step in range(npoint):
        centroids[:, step] = farthest
        newest = xyz[rows, farthest, :].view(B, 1, 3)
        dist_to_newest = ((xyz - newest) ** 2).sum(dim=-1)  # [B, N]
        # Keep the smaller of the old nearest-selected distance and the new one.
        min_dist = torch.where(dist_to_newest < min_dist, dist_to_newest, min_dist)
        farthest = min_dist.max(dim=-1)[1]
    return centroids



In [None]:
import torch

In [None]:
# Scratch inputs mirroring farthest_point_sample with B=1, N=1024, npoint=512.
xyz = torch.rand(1, 1024, 3)
centroids = torch.zeros(1, 512, dtype=torch.long)  # sampled indices are integers
distance = torch.ones(1, 1024)
# One random start index per batch element (B=1), like the (B,) draw in
# farthest_point_sample. Drawing 2 makes xyz[0, farthest, :] hold 6 values,
# and the later .view(1, 1, 3) then fails.
farthest = torch.randint(0, 1024, (1,), dtype=torch.long)

In [None]:
# Inspect the random start index/indices.
print(farthest)

tensor([764, 287])


In [None]:
# NOTE(review): xyz[0, farthest, :] has len(farthest) rows, so .view(1, 1, 3)
# only succeeds when farthest holds exactly one index.
centroid = xyz[0 , farthest, :].view(1, 1, 3)

In [None]:
# Per-coordinate squared differences to the centroid, broadcast over all points.
d0 = (xyz - centroid) ** 2
d0.shape

torch.Size([1, 1024, 3])

In [None]:
# Summing over the coordinate axis gives one squared distance per point.
dist = torch.sum((xyz - centroid) ** 2, -1)
print(dist.shape)

torch.Size([1, 1024])


In [None]:
#mask = dist < distance 
#distance[mask] = dist[mask]

In [None]:
# Toy tensors for replaying the masked update `distance[mask] = dist[mask]`
# from farthest_point_sample.
distance0 = torch.ones(10).long()
dist0 = torch.arange(10)
print(distance0)
print(dist0)

tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])


In [None]:
# A second toy tensor to compare against dist0.
dist1 = torch.arange(0, 20, step=2)
print(dist1)

tensor([ 0,  2,  4,  6,  8, 10, 12, 14, 16, 18])


In [None]:
# True wherever dist0 is strictly smaller than dist1.
mask = dist0 < dist1
print(mask)

tensor([False,  True,  True,  True,  True,  True,  True,  True,  True,  True])


In [None]:
# Overwrite only the masked positions; unmasked entries keep their old value.
distance0[mask] = dist0[mask]
print(distance0)

tensor([1, 1, 2, 3, 4, 5, 6, 7, 8, 9])


In [None]:
def query_ball_point(radius, nsample, xyz, new_xyz):
    """Collect up to ``nsample`` neighbour indices within ``radius`` of each query point.

    For each query point, the indices of the points of ``xyz`` whose squared
    distance is at most ``radius ** 2`` are kept in ascending index order and
    truncated to ``nsample``. If fewer than ``nsample`` neighbours exist, the
    remaining slots are filled with the first (smallest) neighbour index. A
    query with no neighbour at all keeps the sentinel value ``N`` in its row.

    Input:
        radius: local region radius
        nsample: max sample number in local region
        xyz: all points, [B, N, 3]
        new_xyz: query points, [B, S, 3]
    Return:
        group_idx: grouped points index, [B, S, nsample]
    """
    device = xyz.device
    B, N, _ = xyz.shape
    _, S, _ = new_xyz.shape
    # Candidate indices 0..N-1 for every (batch, query) pair: [B, S, N].
    candidates = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
    # Points outside the ball get the out-of-range sentinel N.
    outside = square_distance(new_xyz, xyz) > radius ** 2
    candidates = candidates.masked_fill(outside, N)
    # Ascending sort moves in-ball indices (< N) to the front; keep nsample of them.
    candidates = candidates.sort(dim=-1).values[:, :, :nsample]
    # Replace sentinel slots with each query's first in-ball index.
    first = candidates[:, :, 0].unsqueeze(-1).repeat(1, 1, nsample)
    return torch.where(candidates == N, first, candidates)

In [None]:
# Stand-alone demo of the sort-and-truncate step inside query_ball_point.
# Out-of-ball entries hold the sentinel N (1024 here), so after an ascending
# sort they sit at the end, and slicing keeps the smallest in-ball indices.
# (group_idx / nsample are only locals of query_ball_point, so define them here.)
nsample = 2
group_idx = torch.tensor([[[1024, 5, 1024, 2, 9]]])
group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]

In [None]:
# Ball query of radius 0.1 around each of the 512 gathered points, keeping
# 2 neighbour indices each -> [3, 512, 2].
group_idx0 = query_ball_point(0.1, 2, points0, new_points)
print(group_idx0.shape)


torch.Size([3, 512, 2])


In [None]:
def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False):
    """Sample centroids with FPS and gather a ball-query neighbourhood around each.

    Input:
        npoint: number of centroids picked by farthest point sampling
        radius: ball-query radius around each centroid
        nsample: number of neighbour indices gathered per centroid
        xyz: input points position data, [B, N, 3]
        points: per-point features, [B, N, D], or None
        returnfps: if True, also return grouped_xyz and fps_idx
    Return:
        new_xyz: sampled centroid positions, [B, npoint, 3]
        new_points: grouped data, [B, npoint, nsample, 3+D]; the first 3
            channels are neighbour xyz relative to their centroid, followed
            by the neighbour features when ``points`` is given
        grouped_xyz (returnfps only): absolute neighbour xyz, [B, npoint, nsample, 3]
        fps_idx (returnfps only): centroid indices into xyz, [B, npoint]
    """
    # torch.cuda.empty_cache() is intentionally not called here: it does not
    # free live tensors, only returns cached blocks to the driver, and calling
    # it on every forward pass just adds allocator overhead.
    B, N, C = xyz.shape
    S = npoint
    fps_idx = farthest_point_sample(xyz, npoint)  # [B, npoint]
    new_xyz = index_points(xyz, fps_idx)  # [B, npoint, C]
    idx = query_ball_point(radius, nsample, xyz, new_xyz)  # [B, npoint, nsample]
    grouped_xyz = index_points(xyz, idx)  # [B, npoint, nsample, C]
    # Express every neighbour relative to its own centroid.
    grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C)

    if points is not None:
        grouped_points = index_points(points, idx)
        # Append the features after the relative coordinates: [B, npoint, nsample, C+D].
        new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1)
    else:
        new_points = grouped_xyz_norm
    if returnfps:
        return new_xyz, new_points, grouped_xyz, fps_idx
    return new_xyz, new_points


In [None]:
def sample_and_group_all(xyz, points):
    """Treat the whole cloud as a single group whose centroid is the origin.

    Because the centroid is all zeros, the grouped coordinates are just the
    absolute input coordinates.

    Input:
        xyz: input points position data, [B, N, 3]
        points: input points data, [B, N, D], or None
    Return:
        new_xyz: one all-zero (float32) centroid per batch element, [B, 1, 3]
        new_points: every point in one group, [B, 1, N, 3+D]
    """
    B, N, C = xyz.shape
    origin = torch.zeros(B, 1, C, device=xyz.device)
    grouped = xyz.view(B, 1, N, C)
    if points is None:
        return origin, grouped
    return origin, torch.cat([grouped, points.view(B, 1, N, -1)], dim=-1)

In [None]:
class PointNetSetAbstraction(nn.Module):
    """Single-scale PointNet++ set-abstraction layer (pure PyTorch).

    Args:
        npoint: number of centroids sampled with FPS (unused when group_all)
        radius: ball-query radius
        nsample: neighbours gathered per centroid
        in_channel: channels of each grouped point fed to the first conv
            (xyz coordinates plus any feature channels)
        mlp: output widths of the successive 1x1 Conv2d + BN + ReLU layers,
            e.g. [64, 128, 128, 256]
        group_all: group the whole cloud into one region instead of sampling
    """

    def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all):
        super(PointNetSetAbstraction, self).__init__()
        self.npoint = npoint
        self.radius = radius
        self.nsample = nsample
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        last_channel = in_channel
        for out_channel in mlp:
            # Stray "[64]" annotations after these statements used to be
            # parsed as subscripts (None[64] / int[64]) and raised TypeError.
            self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1))
            self.mlp_bns.append(nn.BatchNorm2d(out_channel))
            last_channel = out_channel
        self.group_all = group_all

    def forward(self, xyz, points):
        """
        Input:
            xyz: input points position data, [B, C, N]
            points: input points data, [B, D, N], or None
        Return:
            new_xyz: sampled points position data, [B, C, S]
            new_points: sample points feature data, [B, mlp[-1], S]
        """
        xyz = xyz.permute(0, 2, 1)
        if points is not None:
            points = points.permute(0, 2, 1)

        if self.group_all:
            new_xyz, new_points = sample_and_group_all(xyz, points)
        else:
            new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points)
        # new_xyz: [B, npoint, C]; new_points: [B, npoint, nsample, C+D].
        # Conv2d needs channels on dim 1. The 1x1 kernels treat the two spatial
        # axes identically, so their order is irrelevant as long as the max
        # below reduces over the nsample axis (dim 2 after this permute).
        new_points = new_points.permute(0, 3, 2, 1)  # [B, C+D, nsample, npoint]
        for conv, bn in zip(self.mlp_convs, self.mlp_bns):
            new_points = F.relu(bn(conv(new_points)))

        new_points = torch.max(new_points, 2)[0]  # [B, mlp[-1], npoint]
        new_xyz = new_xyz.permute(0, 2, 1)
        return new_xyz, new_points

**x**   
[ batch_size, channels, height_1, width_1 ]  
batch_size 一个batch中样例的个数       2  
channels 通道数，也就是当前层的深度 1  
height_1, 图片的高  7  
width_1, 图片的宽 3

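**Conv2d 的权重（weight）形状**   
[ out_channels, in_channels, kernel_height, kernel_width ]（PyTorch 中 `conv.weight` 的排列顺序）

in_channels, 输入通道数，和上面 x 的 channels 保持一致，也就是当前层的深度  1  
out_channels, 输出通道数（输出的深度）                            8  
kernel_height, 过滤器 filter 的高                                                      2  
kernel_width, 过滤器 filter 的宽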


In [None]:
class PointNetSetAbstractionMsg(nn.Module):
    """Multi-scale-grouping (MSG) PointNet++ set-abstraction layer.

    Args:
        npoint: number of centroids sampled with FPS
        radius_list: one ball-query radius per scale
        nsample_list: neighbours gathered per centroid, one per scale
        in_channel: per-point feature channels D; 3 is added because the
            relative xyz is concatenated to the features
        mlp_list: conv output widths for each scale
    """

    def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_list):
        super(PointNetSetAbstractionMsg, self).__init__()
        self.npoint = npoint
        self.radius_list = radius_list
        self.nsample_list = nsample_list
        self.conv_blocks = nn.ModuleList()
        self.bn_blocks = nn.ModuleList()
        for widths in mlp_list:
            scale_convs = nn.ModuleList()
            scale_bns = nn.ModuleList()
            prev = in_channel + 3
            for width in widths:
                scale_convs.append(nn.Conv2d(prev, width, 1))
                scale_bns.append(nn.BatchNorm2d(width))
                prev = width
            self.conv_blocks.append(scale_convs)
            self.bn_blocks.append(scale_bns)

    def forward(self, xyz, points):
        """
        Input:
            xyz: input points position data, [B, C, N]
            points: input points data, [B, D, N], or None
        Return:
            new_xyz: sampled points position data, [B, C, S]
            new_points_concat: pooled features of all scales, [B, sum of last widths, S]
        """
        xyz = xyz.permute(0, 2, 1)  # [B, N, C]
        if points is not None:
            points = points.permute(0, 2, 1)  # [B, N, D]

        B, _, C = xyz.shape
        S = self.npoint
        # The same FPS centroids are shared by every scale.
        centroids = index_points(xyz, farthest_point_sample(xyz, S))  # [B, S, C]
        per_scale = []
        for scale, radius in enumerate(self.radius_list):
            k = self.nsample_list[scale]
            neighbour_idx = query_ball_point(radius, k, xyz, centroids)
            rel_xyz = index_points(xyz, neighbour_idx)  # [B, S, K, C]
            rel_xyz -= centroids.view(B, S, 1, C)  # relative to each centroid
            if points is None:
                grouped = rel_xyz
            else:
                grouped = torch.cat([index_points(points, neighbour_idx), rel_xyz], dim=-1)

            grouped = grouped.permute(0, 3, 2, 1)  # [B, D+C, K, S]
            for conv, bn in zip(self.conv_blocks[scale], self.bn_blocks[scale]):
                grouped = F.relu(bn(conv(grouped)))
            per_scale.append(torch.max(grouped, 2)[0])  # max over K -> [B, D', S]

        return centroids.permute(0, 2, 1), torch.cat(per_scale, dim=1)

In [None]:
class PointNetFeaturePropagation(nn.Module):
    """PointNet++ feature-propagation (upsampling) layer.

    Interpolates features from a sampled point set back onto a denser set,
    optionally concatenates skip features, then applies 1x1 Conv1d + BN + ReLU.

    Args:
        in_channel: channels after concatenating skip and interpolated features
        mlp: output widths of the successive Conv1d layers
    """

    def __init__(self, in_channel, mlp):
        super(PointNetFeaturePropagation, self).__init__()
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        last_channel = in_channel
        for out_channel in mlp:
            self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1))
            self.mlp_bns.append(nn.BatchNorm1d(out_channel))
            last_channel = out_channel

    def forward(self, xyz1, xyz2, points1, points2):
        """
        Input:
            xyz1: input points position data, [B, C, N]
            xyz2: sampled input points position data, [B, C, S]
            points1: skip features of the dense points, [B, D1, N], or None
            points2: features of the sampled points, [B, D2, S]
        Return:
            new_points: upsampled points data, [B, mlp[-1], N]
        """
        xyz1 = xyz1.permute(0, 2, 1)
        xyz2 = xyz2.permute(0, 2, 1)

        points2 = points2.permute(0, 2, 1)
        B, N, C = xyz1.shape
        _, S, _ = xyz2.shape

        if S == 1:
            # A single sampled point: copy its feature to every dense point.
            interpolated_points = points2.repeat(1, N, 1)
        else:
            # Weight the k nearest sampled points by inverse *squared* distance.
            # k is capped at S so that S == 2 no longer breaks the view below.
            k = min(3, S)
            dists = square_distance(xyz1, xyz2)  # [B, N, S]
            dists, idx = dists.sort(dim=-1)
            dists, idx = dists[:, :, :k], idx[:, :, :k]  # [B, N, k]

            dist_recip = 1.0 / (dists + 1e-8)
            norm = torch.sum(dist_recip, dim=2, keepdim=True)
            weight = dist_recip / norm  # weights sum to 1 for each dense point
            # [B, N, k, D2] * [B, N, k, 1] summed over k -> [B, N, D2]
            interpolated_points = torch.sum(index_points(points2, idx) * weight.view(B, N, k, 1), dim=2)
        if points1 is not None:
            points1 = points1.permute(0, 2, 1)
            new_points = torch.cat([points1, interpolated_points], dim=-1)  # [B, N, D1 + D2]
        else:
            new_points = interpolated_points

        new_points = new_points.permute(0, 2, 1)  # channels first for Conv1d
        for conv, bn in zip(self.mlp_convs, self.mlp_bns):
            new_points = F.relu(bn(conv(new_points)))
        return new_points