In [1]:
import torch
import numpy as np
import os
from torch.utils.data import Dataset, DataLoader
from torch import nn
import torch.nn.functional as F
from ignite.metrics import Accuracy
from tqdm.notebook import tqdm

---
## 数据部分

In [2]:
class ModelNet40(Dataset):
    """Point-cloud classification dataset read from per-shape text files.

    Expected layout under ``root``:

    * ``modelnet40_shape_names.txt`` - one class name per line; the line
      number is used as the class index.
    * ``modelnet40_<split>.txt`` - one shape id per line, formatted as
      ``<class name>_<suffix>``.
    * ``<class name>/<shape id>.txt`` - comma separated values, one point
      per row, with the xyz coordinates in the first three columns.

    Args:
        root: Dataset root directory.
        split: Name of the split file to read (``'train'`` or ``'test'``
            in this notebook).
        npoints: Number of leading rows kept from every shape file.
    """

    def __init__(self, root, split='train', npoints=1024):
        super(ModelNet40, self).__init__()
        self.npoints = npoints
        self.class_to_idx = {}

        with open(os.path.join(root, 'modelnet40_shape_names.txt'), 'r') as f:
            for i, line in enumerate(f):
                line = line.strip()
                # Skip blank lines so they do not register an empty class name.
                if line:
                    self.class_to_idx[line] = i

        self.file_paths = []
        self.labels = []
        with open(os.path.join(root, 'modelnet40_' + split + '.txt'), 'r') as f:
            for line in f:
                line = line.strip()
                # A blank (e.g. trailing) line carries no shape id.
                if not line:
                    continue
                # The class name is everything before the last underscore,
                # e.g. 'night_stand_0001' -> 'night_stand'.
                class_name = line.rpartition('_')[0]
                self.file_paths.append(os.path.join(root, class_name, line + '.txt'))
                self.labels.append(self.class_to_idx[class_name])

    def __len__(self):
        """Return the number of shapes listed in the split file."""
        return len(self.file_paths)

    def pcd_norm(self, points):
        """Centre ``points`` on their mean and scale them into the unit sphere.

        Args:
            points: Array of shape (n, 3).

        Returns:
            A new array of the same shape. When every point coincides the
            centred (all-zero) cloud is returned unscaled instead of
            dividing by zero.
        """
        points = points - points.mean(axis=0)
        max_dis = np.sqrt(np.square(points).sum(axis=1)).max()
        # A zero radius means all points are identical; dividing would give NaN.
        if max_dis > 0:
            points = points / max_dis

        return points

    def __getitem__(self, index):
        """Load shape ``index`` and return ``(points, label)``.

        ``points`` holds the first ``npoints`` rows of the file as float32,
        with the first three columns normalised by :meth:`pcd_norm`; the
        remaining columns are returned unchanged.
        """
        file = self.file_paths[index]
        label = self.labels[index]

        points = np.genfromtxt(file, delimiter=',', dtype=np.float32)

        # Keep only the leading rows so every sample has the same length.
        points = points[0:self.npoints, :]

        points[:, 0:3] = self.pcd_norm(points[:, 0:3])

        return points, label

In [3]:
# Build the train and test splits from the same dataset root.
train_dataset, test_dataset = (
    ModelNet40('modelnet40_normal_resampled', split=split_name)
    for split_name in ('train', 'test')
)

In [4]:
# Show the class-name -> index mapping and the size of each split, one per line.
print(train_dataset.class_to_idx, len(train_dataset), len(test_dataset), sep='\n')

{'airplane': 0, 'bathtub': 1, 'bed': 2, 'bench': 3, 'bookshelf': 4, 'bottle': 5, 'bowl': 6, 'car': 7, 'chair': 8, 'cone': 9, 'cup': 10, 'curtain': 11, 'desk': 12, 'door': 13, 'dresser': 14, 'flower_pot': 15, 'glass_box': 16, 'guitar': 17, 'keyboard': 18, 'lamp': 19, 'laptop': 20, 'mantel': 21, 'monitor': 22, 'night_stand': 23, 'person': 24, 'piano': 25, 'plant': 26, 'radio': 27, 'range_hood': 28, 'sink': 29, 'sofa': 30, 'stairs': 31, 'stool': 32, 'table': 33, 'tent': 34, 'toilet': 35, 'tv_stand': 36, 'vase': 37, 'wardrobe': 38, 'xbox': 39}
9843
2468


In [5]:
# Fetch the first training sample: its point array and its integer class label.
pcd, gt = train_dataset[0]

In [6]:
# Check that the normalisation is correct for the sample fetched above.
# Column-wise mean over every column of the sample (not only xyz).
mean = pcd.mean(axis=0)
# Distance of every point from the origin, computed on the xyz columns only.
dis = np.sqrt((pcd[:, 0:3] ** 2).sum(axis=1))
# Largest of those distances.
max_dis = dis.max()

print(mean)
print(dis)
print(max_dis)

[ 5.6534191e-09 -3.8970029e-08 -2.0793777e-08 -2.2611087e-02
  6.2716450e-03  6.5562748e-03]
[0.18740387 0.9741179  0.99170816 ... 0.42549008 0.17417295 0.16141398]
1.0


In [7]:
# Inspect the Python types of one sample and the shape/dtype of its point array.
print(f'{type(pcd)} {type(gt)}')
print(f'{pcd.shape} {pcd.dtype}')

<class 'numpy.ndarray'> <class 'int'>
(1024, 6) float32


In [8]:
# Training batches are reshuffled every epoch; test batches keep a fixed order.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=24,
    shuffle=True,
    num_workers=8,
)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=24,
    shuffle=False,
    num_workers=8,
)

In [9]:
# Pull a single batch from the training loader to inspect what the model will receive.
pcds, labels = next(iter(train_dataloader))

In [10]:
# Shape and dtype of the batched point clouds and of their labels.
print(f'{pcds.shape} {pcds.dtype}')
print(f'{labels.shape} {labels.dtype}')

torch.Size([24, 1024, 6]) torch.float32
torch.Size([24]) torch.int64


---
## 模型部分

In [11]:
class PointSetAbstractionLayer(nn.Module):
    """Set-abstraction layer in the style of PointNet++.

    In the regular mode the layer selects ``nsamples`` centroids with
    farthest point sampling, gathers the ``k`` nearest points of every
    centroid (neighbours farther away than ``radius`` are replaced by the
    nearest one), runs a shared 1x1-convolution MLP over every group and
    max-pools each group into a single feature vector. With
    ``is_group_all=True`` all points form one single group instead.

    Args:
        nsamples: Number of centroids to sample.
        radius: Neighbourhood radius; unused when ``is_group_all`` is True.
        k: Number of neighbours per centroid; unused when ``is_group_all``
            is True.
        in_channels: Channels fed to the MLP, i.e. 3 relative coordinates
            plus the number of input feature channels.
        mlp_units: Output width of every MLP stage.
        is_group_all: Group every point into one set instead of sampling.
    """

    def __init__(self, nsamples, radius, k, in_channels, mlp_units, is_group_all=False):
        super(PointSetAbstractionLayer, self).__init__()
        self.nsamples = nsamples
        self.radius = radius
        self.k = k
        self.is_group_all = is_group_all

        # Shared per-point MLP: one Conv(1x1) -> BatchNorm -> ReLU stage per unit.
        layers = []
        prev_channels = in_channels
        for units in mlp_units:
            layers += [nn.Conv2d(prev_channels, units, kernel_size=1),
                       nn.BatchNorm2d(units),
                       nn.ReLU(inplace=True)]
            prev_channels = units

        self.mlp = nn.Sequential(*layers)

    def fps(self, points):
        """Farthest point sampling.

        points.shape = (b, n, 3)
        return indices.shape = (b, self.nsamples)

        The first selected index of every batch element is always 0.
        """
        b, n, _ = points.shape
        device = points.device
        # Running squared distance from every point to its closest selected point.
        dis = torch.full((b, n), 1e10, device=device)
        indices = torch.zeros((b, self.nsamples), device=device, dtype=torch.long)
        batch_idx = torch.arange(b, device=device)

        for i in range(1, self.nsamples):
            cur_index = indices[:, i - 1].view(b, 1, 1).expand(-1, -1, 3)
            cur_point = points.gather(1, cur_index)

            temp = (points - cur_point).square().sum(dim=2)
            # Keep the smaller of the old and the new distance for every point.
            dis = torch.where(temp < dis, temp.to(dis.dtype), dis)

            # The next sample is the point farthest from everything chosen so far.
            index = dis.argmax(dim=1)
            dis[batch_idx, index] = 0
            indices[:, i] = index
        return indices

    def index_points(self, points, indices):
        """Gather rows of ``points`` selected by ``indices``.

        points.shape = (b, n, c)
        indices.shape = (b, m) or (b, m, k)
        return res.shape = (b, m, c) or (b, m, k, c)

        Raises:
            ValueError: if ``indices`` is neither 2-D nor 3-D.
        """
        _, _, c = points.shape
        if indices.dim() == 2:
            gather_idx = indices.unsqueeze(dim=2).expand(-1, -1, c)
            return points.gather(dim=1, index=gather_idx)
        if indices.dim() == 3:
            # Take the group count from the indices themselves so the method
            # works for any number of groups, not only self.nsamples.
            m = indices.shape[1]
            gather_idx = indices.unsqueeze(dim=3).expand(-1, -1, -1, c)
            expanded = points.unsqueeze(dim=1).expand(-1, m, -1, -1)
            return expanded.gather(dim=2, index=gather_idx)
        raise ValueError(
            f'indices must have 2 or 3 dimensions, got {indices.dim()}')

    def index_points_2(self, points, idx):
        """Alternative to :meth:`index_points` based on advanced indexing.

        points.shape = (b, n, c)
        idx.shape = (b, ...)
        return new_points.shape = idx.shape + (c,)

        Not called by the other methods of this class.
        """
        B = points.shape[0]
        # One batch index per entry of idx, broadcast along the trailing dims.
        view_shape = [B] + [1] * (idx.dim() - 1)
        batch_indices = torch.arange(
            B, dtype=torch.long, device=points.device).view(view_shape).expand_as(idx)
        new_points = points[batch_indices, idx, :]
        return new_points

    def get_square_distance(self, points_1, points_2):
        """Pairwise squared Euclidean distance.

        points_1.shape = (b, n, 3)
        points_2.shape = (b, m, 3)
        return res.shape = (b, m, n)
        """
        # (b, m, 1, 3) - (b, 1, n, 3) broadcasts to (b, m, n, 3).
        diff = points_2.unsqueeze(dim=2) - points_1.unsqueeze(dim=1)
        return diff.square().sum(dim=-1)

    def group(self, points, features, centroids, distance):
        """Collect the neighbourhood of every centroid.

        points.shape = (b, n, 3)
        features.shape = (b, n, c)
        centroids.shape = (b, m, 3)
        distance.shape = (b, m, n)
        return res.shape = (b, m, k, 3+c)

        The first 3 output channels are coordinates relative to the centroid.
        """
        sorted_distance, indices = distance.sort(dim=-1)
        sorted_distance = sorted_distance[:, :, 0:self.k]
        indices = indices[:, :, 0:self.k]

        # Neighbours outside the radius are replaced by the nearest point.
        nearest = indices[:, :, 0:1].expand(-1, -1, indices.shape[-1])
        indices = torch.where(sorted_distance > self.radius ** 2, nearest, indices)

        group_points = self.index_points(points, indices)
        group_point_features = self.index_points(features, indices)

        # Relative coordinates are what the MLP consumes.
        group_points = group_points - centroids.unsqueeze(dim=2)

        res = torch.cat((group_points, group_point_features), dim=-1)
        return res

    def _encode_groups(self, grouped):
        """Run the shared MLP over grouped points and max-pool every group.

        grouped.shape = (b, m, k, c_in)
        return shape = (b, c_out, m)
        """
        grouped = grouped.permute(0, 3, 2, 1)   # (b, c_in, k, m)
        encoded = self.mlp(grouped)
        pooled, _ = encoded.max(dim=2)
        return pooled

    def group_all(self, points, features):
        """Treat the whole cloud as one group.

        points.shape = (b, n, 3)
        features.shape = (b, n, c)
        return centroids.shape = (b, 3, 1)
        return group_features.shape = (b, c', 1)
        """
        b, n, _ = points.shape
        device = points.device
        # NOTE(review): the reference point is drawn at random on every call,
        # and the coordinates fed to the MLP are relative to it, so the output
        # is not deterministic even in eval mode. Kept as is to preserve the
        # existing behaviour; consider a fixed reference point.
        centroid_idx = torch.randint(0, n, (b, 1), device=device)
        centroids = self.index_points(points, centroid_idx)   # (b, 1, 3)

        # Every point belongs to the single group, so no gather is needed.
        group_points = points.unsqueeze(dim=1) - centroids.unsqueeze(dim=2)
        grouped = torch.cat((group_points, features.unsqueeze(dim=1)), dim=-1)

        group_features = self._encode_groups(grouped)
        centroids = centroids.permute(0, 2, 1)

        return centroids, group_features

    def forward(self, points, features):
        """Apply the layer.

        points.shape = (b, 3, n)   coordinates
        features.shape = (b, c, n)   per-point features
        return centroids.shape = (b, 3, m)
        return group_features.shape = (b, c', m)

        ``m`` is ``self.nsamples`` in the regular mode and 1 when
        ``is_group_all`` is set.
        """
        points = points.permute(0, 2, 1)
        features = features.permute(0, 2, 1)
        if self.is_group_all:
            return self.group_all(points, features)

        fps_indices = self.fps(points)
        centroids = self.index_points(points, fps_indices)

        square_distance = self.get_square_distance(points, centroids)
        grouped = self.group(points, features, centroids, square_distance)

        group_features = self._encode_groups(grouped)
        centroids = centroids.permute(0, 2, 1)

        return centroids, group_features

In [12]:
class PointNetPlusPlus(nn.Module):
    """Point-cloud classifier: three set-abstraction layers and an MLP head.

    Args:
        class_num: Number of output classes (size of the logits vector).
    """

    def __init__(self, class_num):
        super(PointNetPlusPlus, self).__init__()
        # in_channels = feature channels of the previous stage + 3 relative coordinates.
        self.sa1 = PointSetAbstractionLayer(512, 0.2, 32, 6+3, [64, 64, 128])
        self.sa2 = PointSetAbstractionLayer(128, 0.4, 64, 128+3, [128, 128, 256])
        # The last stage groups all remaining points into one global feature.
        self.sa3 = PointSetAbstractionLayer(1, None, None, 256+3, [256, 512, 1024], True)

        self.mlp = nn.Sequential(nn.Linear(1024, 512),
                                nn.BatchNorm1d(512),
                                nn.ReLU(inplace=True),
                                nn.Dropout(0.4),
                                nn.Linear(512, 256),
                                nn.BatchNorm1d(256),
                                nn.ReLU(inplace=True),
                                nn.Dropout(0.4),
                                nn.Linear(256, class_num))

    def forward(self, x):
        """Classify a batch of point clouds.

        x.shape = (b, n, 3+c), xyz coordinates in the first three channels
        return y.shape = (b, class_num), unnormalised class scores
        """
        x = x.permute(0, 2, 1)
        points = x[:, 0:3, :]
        # The whole input (coordinates included) is used as the initial feature set.
        features = x

        points_layer1, features_layer1 = self.sa1(points, features)
        points_layer2, features_layer2 = self.sa2(points_layer1, features_layer1)
        points_layer3, features_layer3 = self.sa3(points_layer2, features_layer2)

        # Drop only the trailing singleton point axis; a bare squeeze() would
        # also remove the batch axis when b == 1.
        final_features = features_layer3.squeeze(dim=-1)

        y = self.mlp(final_features)
        return y

In [13]:
def train_loop(dataloader, model, loss_fn, metric_fn, optimizer, device, cur_epoch, total_epoch, show_gap, interval):
    """Run one training pass of ``model`` over ``dataloader``.

    A tqdm progress bar is shown only when ``cur_epoch`` is a multiple of
    ``show_gap``; in that case its postfix (loss and accuracy of the current
    batch) is refreshed every ``interval`` batches. ``metric_fn`` is reset
    before every batch, so the accuracy shown is per batch.
    """
    model.train()

    show_progress = cur_epoch % show_gap == 0
    if show_progress:
        pbar = tqdm(dataloader, desc=f'Epoch {cur_epoch}/{total_epoch}', unit='batch')
    else:
        pbar = dataloader

    for step, batch in enumerate(pbar):
        x, y = batch
        x, y = x.to(device), y.to(device)

        # Standard optimisation step.
        optimizer.zero_grad()
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        loss.backward()
        optimizer.step()

        # Accuracy of this batch only.
        metric_fn.reset()
        metric_fn.update((y_pred, y))
        acc = metric_fn.compute()

        if not show_progress or step % interval != 0:
            continue
        pbar.set_postfix_str(f'loss={loss:.4f}, acc={acc:.4f}')

In [14]:
best_acc = 0
best_epoch = 0

def test_loop(dataloader, model, loss_fn, metric_fn, device, cur_epoch, path, show_gap):
    model.eval()
    steps = len(dataloader)
    loss = 0
    acc = 0

    with torch.no_grad():
        for x, y in dataloader:
            x = x.to(device)
            y = y.to(device)
            y_pred = model(x)
            loss += loss_fn(y_pred, y)

            metric_fn.reset()
            metric_fn.update((y_pred, y))
            acc += metric_fn.compute()
    loss = loss / steps
    acc = acc / steps

    global best_acc, best_epoch
    if acc >= best_acc:
        torch.save(model.state_dict(), path)
        best_acc = acc
        best_epoch = cur_epoch

    if cur_epoch % show_gap == 0:
        print(f'test_loss={loss:.4f}, test_acc={acc:.4f}')

---
## 开始训练

In [15]:
# Use the first CUDA device when one is available and fall back to the CPU
# otherwise, so the notebook also runs on machines without a GPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# 40 output classes, matching the 40 entries of the class-name mapping printed above.
pointnet = PointNetPlusPlus(40).to(device)
loss_fn = nn.CrossEntropyLoss()
metric_fn = Accuracy(device=device)
optimizer = torch.optim.Adam(pointnet.parameters(), lr=0.001, weight_decay=1e-4)
# Cosine learning-rate schedule with T_max=200 scheduler steps.
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 200)

In [16]:
epochs = 200
show_gap = 1
save_path = 'pointnet++_cls.pth'

# One training pass, one evaluation pass and one scheduler step per epoch.
for i in range(epochs):
    train_loop(
        dataloader=train_dataloader,
        model=pointnet,
        loss_fn=loss_fn,
        metric_fn=metric_fn,
        optimizer=optimizer,
        device=device,
        cur_epoch=i,
        total_epoch=epochs,
        show_gap=show_gap,
        interval=1,
    )
    test_loop(
        dataloader=test_dataloader,
        model=pointnet,
        loss_fn=loss_fn,
        metric_fn=metric_fn,
        device=device,
        cur_epoch=i,
        path=save_path,
        show_gap=show_gap,
    )
    lr_scheduler.step()

Epoch 0/200:   0%|          | 0/411 [00:00<?, ?batch/s]

KeyboardInterrupt: 

In [None]:
# Report the best test accuracy recorded by test_loop and the epoch it was reached in.
print(best_acc, best_epoch)