In [1]:
import torch
import numpy as np
# import open3d as o3d
from pointnet.model import PointNetCls

# 一. 加载数据

In [2]:
def load_ply(file_path):
    # 使用 open3d 读取 PLY 文件
    pcd = o3d.io.read_point_cloud(file_path)
    # 提取点云数据（x, y, z坐标）
    points = np.asarray(pcd.points)
    return points

def load_off(file_path):
    with open(file_path, 'r') as file:
        lines = file.readlines()
        # if lines[0].strip() != 'OFF':
        #     raise ValueError('Not a valid OFF header')

        parts = lines[1].strip().split()
        num_vertices = int(parts[0])

        vertices = []
        for i in range(2, 2 + num_vertices):
            vertex = list(map(float, lines[i].strip().split()))
            vertices.append(vertex)

        return np.array(vertices)

In [3]:
CLASS_MAP = {
    0: "airplane",
    1: "bag",
    2: "cap",
    3: "car",
    4: "chair",
    5: "earphone",
    6: "guitar",
    7: "knife",
    8: "lamp",
    9: "laptop",
    10: "motorbike",
    11: "mug",
    12: "pistol",
    13: "rocket",
    14: "skateboard",
    15: "table"
}

# 二. 预测

In [4]:
def preprocess_points(points, num_points=2500):
    if points.shape[0] < num_points:
        # 如果点数不足2500，进行补零
        points = np.pad(points, ((0, num_points - points.shape[0]), (0, 0)), mode='constant')
    elif points.shape[0] > num_points:
        # 如果点数超过2500，进行随机下采样
        indices = np.random.choice(points.shape[0], num_points, replace=False)
        points = points[indices]
    return points

def classify_point_cloud(points_arr, model_path):
    # 加载模型
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    classifier = PointNetCls(k=16)  # 根据你的数据集类别数调整k的值
    classifier = classifier.to(device)
    classifier.load_state_dict(torch.load(model_path))
    classifier.eval()

    # 加载和预处理点云数据
    points = preprocess_points(points_arr)
    points = torch.tensor(points, dtype=torch.float32).unsqueeze(0)  # 添加批次维度

    # 转置点云数据以符合网络输入格式
    points = points.transpose(2, 1).to(device)

    # 模型推断
    with torch.no_grad():
        pred, _, _ = classifier(points)
        pred_choice = pred.data.max(1)[1]
        print(f"Predicted class index: {pred_choice.item()}, Predicted class name: {CLASS_MAP[pred_choice.item()]}")



# Z. 模型预测

## 测试1

In [7]:
model_path = './cls/cls_model_0.pth'
ply_file = './test_datesets.lee/chair_0011.off'
classify_point_cloud(load_off(ply_file), model_path)

Predicted class index: 8, Predicted class name: lamp


  classifier.load_state_dict(torch.load(model_path))
