# Introduction

ModelNet40 数据集介绍
ModelNet40 是一个广泛用于三维物体识别和分类的 CAD 模型数据集（模型以 OFF 网格格式存储，通常被转换为点云使用）。它由普林斯顿大学维护，包含 40 个类别，共 12311 个 CAD 模型，其中训练集包含 9843 个模型，测试集包含 2468 个模型。

本文将实现一个 PointNet++ 分类模型，并在该数据集上进行训练和测试。

## 数据集来源

https://www.kaggle.com/datasets/balraj98/modelnet40-princeton-3d-object-dataset

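注意：该数据集的元数据 CSV 中有几个类别的名称被截断了（例如 `tv_stand` 被记为 `tv`，`night_stand` 被记为 `night`，此外还有 `range_hood`、`glass_box`、`flower_pot`）。对应的 `object_path` 也以截断后的名称开头，与磁盘上的实际文件夹名不一致。如果直接在 Kaggle 上在线使用该数据集，需要在代码中对这些路径进行判断和修正（见下面 `__getitem__` 中的处理）。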

## 代码实现

### 首先是实现 Dataset 和 DataLoader

In [None]:
import os
import torch
import pandas as pd
import numpy as np
# 进度条库
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader

# Kaggle input paths: the folder holding the per-class OFF files, and the
# metadata CSV (one row per model: class, split, relative object path).
root_dir = "/kaggle/input/modelnet40-princeton-3d-object-dataset/ModelNet40"
csv_file = "/kaggle/input/modelnet40-princeton-3d-object-dataset/metadata_modelnet40.csv"

class ModelNet40Dataset(Dataset):
    """ModelNet40 point-cloud dataset backed by the Kaggle metadata CSV.

    Each sample is the vertex list of one ``.off`` mesh, resampled to exactly
    ``num_points`` rows.

    NOTE(review): coordinates are returned exactly as read from the OFF file
    (no centring or scaling). The ball-query radii used by the model presumably
    assume unit-scale clouds -- consider passing a normalising ``transform``.
    """

    # Some classes appear in the CSV with a truncated name (e.g. ``tv`` for
    # ``tv_stand``), and their ``object_path`` starts with that truncated name,
    # while the folder on disk uses the full name.
    _FOLDER_FIXES = {
        'tv': 'tv_stand',
        'night': 'night_stand',
        'range': 'range_hood',
        'glass': 'glass_box',
        'flower': 'flower_pot',
    }

    def __init__(self, csv_file, root_dir, num_points=2048, transform=None, verbose=False):
        """
        Args:
            csv_file: path to the metadata CSV. Column 1 is read as the class
                name, column 3 as the object path relative to ``root_dir``, and
                the ``class`` column builds the label mapping.
            root_dir: directory containing the per-class folders.
            num_points: number of points returned per sample.
            transform: optional callable applied to the float tensor of points.
            verbose: if True, print a message whenever a sample is resampled.
        """
        self.data_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.num_points = num_points
        self.transform = transform
        self.verbose = verbose
        # Integer labels follow the sorted order of the raw CSV class names.
        self.label_to_idx = {
            label: idx for idx, label in enumerate(sorted(set(self.data_frame['class'])))
        }

    def __len__(self):
        return len(self.data_frame)

    def __getitem__(self, idx):
        """Return ``(points, label)``.

        ``points`` is a float32 tensor with ``num_points`` rows (one per point);
        ``label`` is an int64 scalar tensor.
        """
        # Accept tensor indices by converting them to plain Python values.
        if torch.is_tensor(idx):
            idx = idx.tolist()
        label = self.data_frame.iloc[idx, 1]
        path = self.data_frame.iloc[idx, 3]
        fixed_folder = self._FOLDER_FIXES.get(label)
        if fixed_folder is not None:
            # Swap the truncated class prefix of the path for the real folder name.
            path = fixed_folder + path[len(label):]
        object_path = os.path.join(self.root_dir, path)
        label = self.label_to_idx[label]

        points = self.read_off(object_path)

        n = len(points)
        if n < self.num_points:
            # Too few vertices: upsample by drawing indices with replacement.
            if self.verbose:
                print("points is too less ", n)
            points = points[np.random.choice(n, size=self.num_points, replace=True)]
        elif n > self.num_points:
            # Too many vertices: downsample without replacement.
            if self.verbose:
                print("points is too large ", n)
            points = points[np.random.choice(n, size=self.num_points, replace=False)]

        points = torch.from_numpy(points).float()
        label = torch.tensor(label).long()

        if self.transform:
            points = self.transform(points)
        return points, label

    def read_off(self, file_path):
        """Parse the vertex block of an OFF file into a float64 NumPy array.

        Supports the standard layout (``OFF`` alone on line 1, counts on line 2)
        and the variant where the counts follow ``OFF`` on the same line
        (e.g. ``OFF490 518 0``).

        Raises:
            ValueError: if the header is not ``OFF`` or the file has fewer
                vertex lines than its header declares.
        """
        with open(file_path, 'r') as f:
            lines = f.readlines()

        header = lines[0].strip() if lines else ''
        if header == 'OFF':
            num_vertices = int(lines[1].split()[0])
            start_index = 2
        elif header[0:3] == 'OFF':
            num_vertices = int(header[3:].split()[0])
            start_index = 1
        else:
            raise ValueError(f'Invalid OFF file format: {file_path}')

        vertex_lines = lines[start_index:start_index + num_vertices]
        if len(vertex_lines) < num_vertices:
            raise ValueError(
                f'Truncated OFF file (expected {num_vertices} vertices): {file_path}')
        return np.array([list(map(float, line.split())) for line in vertex_lines])

### 接下来实现一些辅助的函数

主要包括最远点采样（farthest point sampling）、按索引收集点（index_points）以及球查询（ball query）。



In [None]:
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

# ------------------- 核心辅助函数 -------------------
def farthest_point_sample(xyz, npoint):
    """Select ``npoint`` indices per batch element by iterative farthest point sampling.

    Args:
        xyz: tensor of shape (B, N, C) with point coordinates (any C).
        npoint: number of indices to select per batch element.

    Returns:
        Long tensor of shape (B, npoint) indexing dimension 1 of ``xyz``.
        The first index per batch element is drawn with ``torch.randint``.
    """
    device = xyz.device
    B, N, C = xyz.shape
    centroids = torch.zeros(B, npoint, dtype=torch.long, device=device)
    # Squared distance from each point to its nearest already-selected centroid.
    distance = torch.ones(B, N, device=device) * 1e10
    farthest = torch.randint(0, N, (B,), dtype=torch.long, device=device)
    batch_indices = torch.arange(B, device=device)

    for i in range(npoint):
        centroids[:, i] = farthest
        # Reshape with C (not a hard-coded 3) so any coordinate width works.
        centroid = xyz[batch_indices, farthest, :].view(B, 1, C)
        dist = torch.sum((xyz - centroid) ** 2, -1)
        mask = dist < distance
        distance[mask] = dist[mask]
        # The next centroid is the point farthest from every centroid chosen so far.
        farthest = torch.max(distance, -1)[1]
    return centroids

def index_points(points, idx):
    """Gather points per batch element.

    Args:
        points: tensor of shape (B, N, C).
        idx: long tensor of shape (B, S) or (B, S, K) with indices into dim 1.

    Returns:
        Tensor of shape ``idx.shape + (C,)`` with
        ``out[b, ...] = points[b, idx[b, ...], :]``.
    """
    B = points.shape[0]
    # Batch index of shape (B, 1, ..., 1): advanced indexing broadcasts it against
    # idx. The original repeated it along points.shape[1] (N), which fails to
    # broadcast whenever idx's second dim differs from N.
    view_shape = [B] + [1] * (idx.dim() - 1)
    batch_indices = torch.arange(B, dtype=torch.long, device=points.device).view(view_shape)
    return points[batch_indices, idx, :]

def query_ball_point(radius, nsample, xyz, new_xyz):
    """Group up to ``nsample`` neighbours within ``radius`` of each query centroid.

    Args:
        radius: ball radius, in the same units as the coordinates.
        nsample: maximum number of neighbours per group.
        xyz: (B, N, C) all points.
        new_xyz: (B, S, C) query centroids.

    Returns:
        grouped_xyz: (B, S, K, C) neighbour coordinates relative to their centroid.
        idx: (B, S, K) long indices into dim 1 of ``xyz``, K = min(nsample, N).

    Groups with fewer than K in-ball points are padded by repeating the group's
    first (smallest-index) in-ball point. A group with no in-ball point at all
    falls back to index 0.
    """
    device = xyz.device
    B, S, C = new_xyz.shape
    N = xyz.shape[1]

    # torch.cdist returns Euclidean (not squared) distances, so compare with the
    # radius itself; the original compared against radius ** 2.
    dists = torch.cdist(new_xyz, xyz)
    idx = torch.arange(N, dtype=torch.long, device=device).view(1, 1, N).repeat(B, S, 1)
    idx[dists > radius] = N  # N marks "outside the ball"; sorts after valid indices
    idx = idx.sort(dim=-1)[0][:, :, :nsample]

    # Pad empty slots with the group's first in-ball point rather than point 0,
    # which may lie far outside the ball.
    first = idx[:, :, :1].repeat(1, 1, idx.shape[-1])
    pad = idx == N
    idx[pad] = first[pad]
    idx[idx == N] = 0  # only hit when a ball contains no point at all

    # Gather neighbour coordinates: (B, 1, 1) batch index broadcasts with idx.
    batch_indices = torch.arange(B, dtype=torch.long, device=device).view(B, 1, 1)
    grouped_xyz = xyz[batch_indices, idx, :] - new_xyz.view(B, S, 1, C)
    return grouped_xyz, idx

### 接下来实现PointNet++模块

+ 包含3个 SetAbstraction 层，逐步下采样点云并提取特征
+ 全局特征经过3个全连接层进行分类
+ 使用 BatchNorm 和 Dropout 提升泛化能力


In [None]:
class PointNetSetAbstraction(nn.Module):
    """One PointNet++ set-abstraction (SA) layer.

    Samples ``npoint`` centroids by farthest point sampling and groups up to
    ``nsample`` neighbours within ``radius`` of each. It then runs a shared MLP
    (1x1 Conv2d + BatchNorm2d + ReLU) over every neighbour and max-pools within
    each group. With ``group_all=True`` the whole cloud forms a single group.

    Tensor layout:
        xyz:    (B, N, 3) input coordinates.
        points: (B, D, N) input features, or None.
        returns new_xyz (B, S, 3) and new_points (B, mlp[-1], S).
    ``in_channel`` must equal ``D + 3`` (or 3 when ``points`` is None).
    """

    def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all=False):
        """
        Args:
            npoint: number of centroids (unused when ``group_all``).
            radius: ball-query radius (unused when ``group_all``).
            nsample: neighbours per group (unused when ``group_all``).
            in_channel: channels of the grouped input (3 + feature dims).
            mlp: output channels of each shared MLP layer.
            group_all: group the entire cloud into one region.
        """
        super().__init__()
        self.npoint = npoint
        self.radius = radius
        self.nsample = nsample
        self.group_all = group_all
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        last_channel = in_channel

        for out_channel in mlp:
            self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1))
            self.mlp_bns.append(nn.BatchNorm2d(out_channel))
            last_channel = out_channel

    @staticmethod
    def _group_all(xyz, points):
        """Treat the whole cloud as one group centred at the origin.

        ``points`` is point-major (B, N, D) or None. Returns new_xyz of zeros
        (B, 1, C) and grouped features (B, 1, N, C [+ D]).
        """
        B, _, C = xyz.shape
        new_xyz = torch.zeros(B, 1, C, dtype=xyz.dtype, device=xyz.device)
        grouped = xyz.unsqueeze(1)
        if points is not None:
            grouped = torch.cat([grouped, points.unsqueeze(1)], dim=-1)
        return new_xyz, grouped

    def forward(self, xyz, points):
        # Features come out of the previous SA layer channel-first (B, D, N);
        # gathering and concatenation below need them point-major (B, N, D).
        if points is not None:
            points = points.permute(0, 2, 1)

        if self.group_all:
            new_xyz, grouped_points = self._group_all(xyz, points)
        else:
            new_xyz = index_points(xyz, farthest_point_sample(xyz, self.npoint))
            # query_ball_point returns (relative coordinates, neighbour indices).
            grouped_xyz, idx = query_ball_point(self.radius, self.nsample, xyz, new_xyz)
            if points is not None:
                grouped_points = torch.cat([grouped_xyz, index_points(points, idx)], dim=-1)
            else:
                grouped_points = grouped_xyz

        # (B, S, K, C) -> (B, C, K, S) so Conv2d sees channels on dim 1.
        grouped_points = grouped_points.permute(0, 3, 2, 1)
        for conv, bn in zip(self.mlp_convs, self.mlp_bns):
            grouped_points = F.relu(bn(conv(grouped_points)))

        # Max-pool over the K neighbours of each group -> (B, C', S).
        new_points = torch.max(grouped_points, 2)[0]
        return new_xyz, new_points

# ------------------- 完整模型 -------------------
class PointNet2Cls(nn.Module):
    """PointNet++ (single-scale grouping) classifier.

    Three set-abstraction stages (512 centroids -> 128 centroids -> one global
    group) followed by a fully connected head 1024 -> 512 -> 256 -> num_classes
    with BatchNorm and Dropout.

    ``forward`` takes channel-first coordinates (B, 3, N) and returns
    unnormalised class scores of shape (B, num_classes).
    """

    def __init__(self, num_classes=40):
        super().__init__()
        # SA arguments: npoint, radius, nsample, in_channel, mlp, group_all.
        # Registration order is kept fixed so parameter init and state_dict keys match.
        self.sa1 = PointNetSetAbstraction(512, 0.2, 32, 3, [64, 64, 128], False)
        self.sa2 = PointNetSetAbstraction(128, 0.4, 64, 128 + 3, [128, 128, 256], False)
        self.sa3 = PointNetSetAbstraction(None, None, None, 256 + 3, [256, 512, 1024], True)

        self.fc1 = nn.Linear(1024, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.drop1 = nn.Dropout(0.4)
        self.fc2 = nn.Linear(512, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.drop2 = nn.Dropout(0.5)
        self.fc3 = nn.Linear(256, num_classes)

    def forward(self, xyz):
        batch_size, _, _ = xyz.shape
        # SA layers consume point-major coordinates (B, N, 3).
        coords = xyz.permute(0, 2, 1)

        feats = None
        for stage in (self.sa1, self.sa2, self.sa3):
            coords, feats = stage(coords, feats)

        # The final global stage is expected to yield (B, 1024, 1); flatten it.
        hidden = feats.view(batch_size, 1024)
        hidden = self.drop1(F.relu(self.bn1(self.fc1(hidden))))
        hidden = self.drop2(F.relu(self.bn2(self.fc2(hidden))))
        return self.fc3(hidden)

### 接下来添加测试和训练的函数

In [None]:
# 训练函数
def train(model, train_loader, criterion, optimizer, device):
    """Run one training epoch over ``train_loader``.

    Returns:
        (mean of the per-batch losses, accuracy in percent).
    """
    model.train()
    running_loss = 0
    n_correct = 0
    n_seen = 0
    for batch_points, batch_labels in tqdm(train_loader):
        # Presumably the loader yields (B, N, 3); the model takes (B, 3, N).
        inputs = batch_points.to(device).permute(0, 2, 1)
        targets = batch_labels.to(device)

        optimizer.zero_grad()
        logits = model(inputs)
        loss = criterion(logits, targets)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        predictions = logits.max(1).indices
        n_seen += targets.size(0)
        n_correct += (predictions == targets).sum().item()

    mean_loss = running_loss / len(train_loader)
    accuracy = 100. * n_correct / n_seen
    return mean_loss, accuracy

# 测试函数
def test(model, test_loader, criterion, device):
    """Evaluate ``model`` on ``test_loader`` without gradient tracking.

    Returns:
        (mean of the per-batch losses, accuracy in percent).
    """
    model.eval()
    running_loss = 0
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for batch_points, batch_labels in tqdm(test_loader):
            # Same channel-first layout as in training: (B, 3, N).
            inputs = batch_points.to(device).permute(0, 2, 1)
            targets = batch_labels.to(device)
            logits = model(inputs)
            loss = criterion(logits, targets)

            running_loss += loss.item()
            predictions = logits.max(1).indices
            n_seen += targets.size(0)
            n_correct += (predictions == targets).sum().item()

    mean_loss = running_loss / len(test_loader)
    accuracy = 100. * n_correct / n_seen
    return mean_loss, accuracy

### 主程序


In [None]:
if __name__ == "__main__":
    # `config` is not defined anywhere in this notebook, so the cell raised a
    # NameError. Use placeholder defaults; an existing `config` dict overrides them.
    config = {
        'device': 'cuda' if torch.cuda.is_available() else 'cpu',
        'lr': 1e-3,
        'epochs': 20,
        **globals().get('config', {}),
    }

    # Initialisation
    dataset = ModelNet40Dataset(csv_file=csv_file, root_dir=root_dir, num_points=8192)

    split = dataset.data_frame['split']
    train_indices = dataset.data_frame.index[split == 'train'].tolist()
    test_indices = dataset.data_frame.index[split == 'test'].tolist()

    train_dataset = torch.utils.data.Subset(dataset, train_indices)
    test_dataset = torch.utils.data.Subset(dataset, test_indices)

    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    # Evaluation metrics do not depend on order, so the test set is not shuffled.
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

    device = config['device']
    model = PointNet2Cls().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'])

    best_acc = 0
    for epoch in range(config['epochs']):
        train_loss, train_acc = train(model, train_loader, criterion, optimizer, device)
        test_loss, test_acc = test(model, test_loader, criterion, device)

        print(f'Epoch {epoch+1}/{config["epochs"]}')
        print(f'Train Loss: {train_loss:.4f} Acc: {train_acc:.2f}%')
        print(f'Test Loss: {test_loss:.4f} Acc: {test_acc:.2f}%')

        # Keep the weights of the best-scoring epoch on the test split.
        if test_acc > best_acc:
            best_acc = test_acc
            torch.save(model.state_dict(), 'best_model.pth')

    print(f'Best Test Accuracy: {best_acc:.2f}%')