In [9]:
import yaml

# Load the project settings from config.yaml (resolved against the current working directory).
with open('config.yaml', mode='r', encoding='utf-8') as file:
    config = yaml.safe_load(file)

# Fields consumed later: the project base directory and the ShapeNet dataset path.
ROOT_DIR, ShapeNet_path = (
    config['project_base_dir'],
    config['dataset']['ShapeNet_path'],
)

In [4]:
import open3d as o3d
import numpy as np
import torch
from data_utils.ShapeNetDataLoader import PartNormalDataset
import argparse
import os
import logging
import sys
import importlib
from tqdm import tqdm
import yaml  

In [5]:
def parse_args(argv=None):
    '''Parse the options for single-PLY part-segmentation testing.

    Args:
        argv: list of argument strings to parse instead of ``sys.argv[1:]``.
            Pass an explicit list when calling from a notebook, where
            ``sys.argv`` holds the kernel launcher's own arguments and
            argparse would otherwise exit with status 2.

    Returns:
        argparse.Namespace holding the options declared below.
    '''
    parser = argparse.ArgumentParser('PointNet')
    parser.add_argument('--batch_size', type=int, default=24, help='测试时的批处理大小')
    parser.add_argument('--gpu', type=str, default='0', help='指定GPU设备')
    parser.add_argument('--num_point', type=int, default=2048, help='点云中的点数')
    parser.add_argument('--log_dir', type=str, required=True, help='实验的根目录')
    parser.add_argument('--normal', action='store_true', default=False, help='是否使用法线信息')
    parser.add_argument('--num_votes', type=int, default=3, help='通过投票聚合分割分数的次数')
    parser.add_argument('--ply_file', type=str, required=True, help='待测试的PLY文件路径')
    # Object-category index used to condition the part-segmentation model;
    # the default 0 matches the category previously hard-wired in main().
    parser.add_argument('--category', type=int, default=0, help='测试点云的物体类别索引(0-15)，作为分割模型的类别条件输入')
    # parse_args(None) falls back to sys.argv[1:], preserving the old behaviour.
    return parser.parse_args(argv)


In [6]:
def load_ply(ply_file, num_point, rng=None):
    """Read a point cloud and turn it into a fixed-size, normalized model input.

    Args:
        ply_file: path of a point-cloud file readable by
            ``open3d.io.read_point_cloud``, or an already-loaded numpy array
            of shape (N, 3).
        num_point: number of points the returned cloud must contain (>= 1).
        rng: optional ``numpy.random.Generator`` used for (re)sampling. When
            None, the global ``np.random`` state is used, so ``np.random.seed``
            still controls reproducibility.

    Returns:
        float32 tensor of shape (num_point, 3), centered on its centroid and
        scaled so the farthest point lies at distance 1. A degenerate cloud
        whose points all coincide is returned as all zeros.

    Raises:
        ValueError: if num_point < 1, the points are not shaped (N, 3), or the
            cloud contains no points.
    """
    if num_point < 1:
        raise ValueError('num_point must be >= 1, got %d' % num_point)

    if isinstance(ply_file, np.ndarray):
        source = 'input array'
        points = np.asarray(ply_file, dtype=np.float64)
    else:
        source = ply_file
        points = np.asarray(o3d.io.read_point_cloud(ply_file).points, dtype=np.float64)

    if points.ndim != 2 or points.shape[1] != 3:
        raise ValueError('expected points of shape (N, 3), got %s' % (points.shape,))
    n_points = len(points)
    if n_points == 0:
        # NOTE(review): open3d appears to return an empty cloud (with a warning)
        # instead of raising when a file cannot be read; fail loudly here.
        raise ValueError('point cloud %r contains no points' % (source,))

    choice = np.random.choice if rng is None else rng.choice
    if n_points > num_point:
        # Random subset without replacement.
        points = points[choice(n_points, num_point, replace=False)]
    elif n_points < num_point:
        # Pad by duplicating randomly chosen real points. Zero-padding would add
        # fake points at the origin that shift the centroid, distort the scale
        # and receive predicted labels of their own.
        extra = choice(n_points, num_point - n_points, replace=True)
        points = np.concatenate([points, points[extra]], axis=0)

    # Normalise into the unit sphere; skip the division when every point
    # coincides (radius 0) to avoid producing NaNs.
    points = points - points.mean(axis=0)
    radius = np.sqrt((points ** 2).sum(axis=1)).max()
    if radius > 0:
        points = points / radius

    return torch.from_numpy(points).float()

In [10]:
def _load_part_seg_model(experiment_dir, num_part, normal_channel, device):
    """Build the experiment's part-segmentation network and load its best checkpoint.

    Args:
        experiment_dir: experiment folder containing ``logs/`` and
            ``checkpoints/best_model.pth``.
        num_part: number of part classes the network outputs.
        normal_channel: forwarded to ``get_model`` (whether the model takes normals).
        device: torch device the network and checkpoint tensors are placed on.

    Returns:
        The network, in eval mode.
    """
    # Model modules are imported by bare name. They presumably live in
    # <project_base_dir>/models, as in the PointNet/PointNet++ reference code
    # -- TODO confirm the project layout. Appending is harmless if they are
    # already importable.
    models_dir = os.path.join(ROOT_DIR, 'models')
    if models_dir not in sys.path:
        sys.path.append(models_dir)

    # The module name is the first entry of logs/ without its extension.
    # NOTE(review): os.listdir order is arbitrary, so this assumes logs/ holds one file.
    model_name = os.listdir(os.path.join(experiment_dir, 'logs'))[0].split('.')[0]
    model_module = importlib.import_module(model_name)
    classifier = model_module.get_model(num_part, normal_channel=normal_channel).to(device)

    # map_location lets a checkpoint saved from a GPU run be restored on CPU.
    checkpoint = torch.load(os.path.join(experiment_dir, 'checkpoints', 'best_model.pth'),
                            map_location=device)
    classifier.load_state_dict(checkpoint['model_state_dict'])
    return classifier.eval()


def main(args):
    """Predict per-point part labels for one PLY file with a trained part-seg model.

    Args:
        args: namespace as produced by ``parse_args``. ``args.category`` (the
            object-category index used to condition the model) is optional and
            defaults to 0.

    Returns:
        numpy array of shape (args.num_point,) with the predicted part label of
        each sampled point.

    Raises:
        ValueError: if the category index is outside [0, 16), or the loaded
            points do not have the channel count the model expects
            (6 with --normal, otherwise 3).
    """
    num_classes = 16  # object categories the model is conditioned on
    num_part = 50     # part labels predicted per point

    # Restrict visible GPUs. This only takes effect if CUDA has not been
    # initialised yet in this process (e.g. by an earlier notebook cell).
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    device = torch.device("cuda" if torch.cuda.is_available() and args.gpu != '-1' else "cpu")
    experiment_dir = 'log/part_seg/' + args.log_dir

    # Log to <experiment_dir>/eval.txt. The handler is removed and closed in the
    # `finally` below, so repeated calls (e.g. re-running the notebook cell)
    # neither duplicate log lines nor leak open file handles.
    logger = logging.getLogger("Model")
    logger.setLevel(logging.INFO)
    file_handler = logging.FileHandler('%s/eval.txt' % experiment_dir)
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
    logger.addHandler(file_handler)

    def log_string(msg):
        """Write msg to the evaluation log and echo it to stdout."""
        logger.info(msg)
        print(msg)

    try:
        log_string('PARAMETER ...')
        log_string(args)

        # Only the size of the ShapeNet test split is reported; its samples are
        # not used for the single-file prediction below.
        test_dataset = PartNormalDataset(root=ShapeNet_path, npoints=args.num_point,
                                         split='test', normal_channel=args.normal)
        log_string("The number of test data is: %d" % len(test_dataset))

        category = getattr(args, 'category', 0)
        if not 0 <= category < num_classes:
            raise ValueError('category must be in [0, %d), got %d' % (num_classes, category))
        # Guard against num_votes <= 0, which would divide the vote pool by zero.
        num_votes = max(1, args.num_votes)

        classifier = _load_part_seg_model(experiment_dir, num_part, args.normal, device)

        with torch.no_grad():
            ply_points = load_ply(args.ply_file, args.num_point)  # (num_point, channels)
            expected_channels = 6 if args.normal else 3
            if ply_points.shape[-1] != expected_channels:
                raise ValueError('model expects %d input channels but the point cloud has %d'
                                 % (expected_channels, ply_points.shape[-1]))

            # NOTE(review): assumes the PointNet/PointNet++ part-seg interface,
            # i.e. points as (batch, channels, num_point) and the category as a
            # (batch, 1, num_classes) one-hot -- confirm against the model module.
            ply_points = ply_points.unsqueeze(0).transpose(2, 1).to(device)
            cls_one_hot = torch.zeros(1, 1, num_classes, device=device)
            cls_one_hot[0, 0, category] = 1.0

            # Average the per-point scores over num_votes forward passes.
            # NOTE(review): the input is identical on every pass, so for a
            # deterministic model in eval mode the votes presumably agree.
            vote_pool = torch.zeros(1, args.num_point, num_part, device=device)
            for _ in range(num_votes):
                seg_pred, _trans_feat = classifier(ply_points, cls_one_hot)
                vote_pool += seg_pred
            mean_scores = (vote_pool / num_votes).cpu().numpy()[0]  # (num_point, num_part)
            predicted_labels = np.argmax(mean_scores, axis=1)

        log_string('PLY file prediction:')
        log_string(predicted_labels)

        # An arbitrary PLY file has no ground truth. As in the original
        # placeholder, this only measures agreement with part label 0.
        placeholder_target = np.zeros(args.num_point, dtype=np.int64)
        test_metrics = {
            'accuracy': np.sum(predicted_labels == placeholder_target) / float(args.num_point)
        }
        log_string('Accuracy is: %.5f' % test_metrics['accuracy'])
        return predicted_labels
    finally:
        logger.removeHandler(file_handler)
        file_handler.close()

# Script entry point: parse the command line, then run single-file inference.
# NOTE: inside Jupyter, __name__ is also '__main__' but sys.argv holds the kernel
# launcher's own arguments, so argparse exits (SystemExit: 2) unless the required
# --log_dir/--ply_file options are present there.
if __name__ == '__main__':
    args = parse_args()
    main(args)

usage: PointNet [-h] [--batch_size BATCH_SIZE] [--gpu GPU]
                [--num_point NUM_POINT] --log_dir LOG_DIR [--normal]
                [--num_votes NUM_VOTES] --ply_file PLY_FILE
PointNet: error: the following arguments are required: --log_dir, --ply_file


SystemExit: 2