# 🎯 3D点云补全案例 - PCN网络 + ShapeNet数据集

这是一个简单的点云补全演示项目，展示了如何使用深度学习技术从不完整的点云数据中重建完整的3D形状。

## 一、3D点云补全任务简介

3D点云补全（Point Cloud Completion）致力于在传感器采集存在遮挡、稀疏和噪声的现实条件下，重建目标或场景的完整几何。它位于“3D感知”技术栈中，承上游采集/重建（LiDAR、RGB‑D、SfM）之缺，启下游理解/交互（识别、定位、交互、仿真）之用。

- 定位（在3D感知流水线中的角色）
  - 上游：传感器采样不可避免的视角遮挡、距离衰减、材质反射导致点云缺失与不均匀。
  - 中游：补全模块依据先验（形状、结构、拓扑）推断缺失区域，输出致密、连贯的点集或隐式几何。
  - 下游：为检测/分割/姿态估计、配准/建图、物理仿真与渲染提供更完整、更鲁棒的输入。

- 作用（为什么需要补全）
  - 抗遮挡与抗稀疏：恢复关键结构，提高下游任务鲁棒性与精度。
  - 统一密度与尺度：为基于点的网络与NeRF/隐式表面方法提供更稳定的几何表示。
  - 降噪与可视化：补齐轮廓、减少孔洞，便于测量、展示与制造。

- 典型应用场景
  - 自动驾驶/机器人：单帧或多帧LiDAR补全，提升目标轮廓与距离估计；机械臂抓取中的形状推断。
  - AR/VR与数字孪生：室内/建筑扫描补洞，生成完整的资产用于渲染与交互。
  - 逆向工程/工业检测：部件扫描补全用于尺寸测量、误差对比与后续CAD重建。
  - 文化遗产/文物修复：受限视角采集下的结构补全与还原。

- 数据与形式
  - 输入：不完整点云（partial），可来自LiDAR、RGB‑D或从CAD网格投影采样。
  - 输出：完整点云（complete）、表面网格或隐式场（如SDF/Occ/NeRF）。
  - 监督：成对的 partial/complete（合成管线常用）或弱/自监督（真实数据）。

- 评测指标与目标
  - 几何一致性：Chamfer Distance（L1/L2）、EMD、F‑score@τ。
  - 结构合理性：法向一致、曲率连续、体素/网格重建质量。
  - 下游收益：在检测/分割/抓取等任务中的性能增益。

- 方法概览（简）
  - 基于先验的显式补洞（模板/对称性/检索）与深度学习式隐式/显式重建。
  - 全局先验（编码器‑解码器、扩散/自回归）+ 局部细化（patch/上采样/细节增强）。

本Notebook聚焦最小可行Demo：以简化PCN范式完成从partial到complete的端到端训练与可视化，便于快速理解补全任务的定位与价值。

In [None]:
# 安装必要的依赖包
%pip install torch numpy open3d matplotlib tqdm lmdb

import os
import sys
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import open3d as o3d
import matplotlib.pyplot as plt
from tqdm import tqdm
import lmdb
import pickle

# Seed the PyTorch and NumPy RNGs so repeated runs draw the same random numbers
torch.manual_seed(42)
np.random.seed(42)

# Use the GPU when CUDA is available, otherwise fall back to the CPU;
# `device` is reused as a default argument by the training/evaluation helpers below
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')


## 二、环境配置与依赖安装

In [None]:
def visualize_point_cloud(points, colors=None, window_name="Point Cloud"):
    """
    Render one point cloud, together with a reference axis frame, in an Open3D window.

    Args:
        points: numpy array of shape (N, 3) with the xyz coordinates
        colors: optional numpy array of shape (N, 3) with per-point RGB in [0, 1];
            when None every point is painted the same blue
        window_name: title of the viewer window
    """
    cloud = o3d.geometry.PointCloud()
    cloud.points = o3d.utility.Vector3dVector(points)

    if colors is None:
        # No per-point colors supplied: fall back to a uniform blue
        cloud.paint_uniform_color([0, 0.651, 0.929])
    else:
        cloud.colors = o3d.utility.Vector3dVector(colors)

    # Axis frame at the origin to give a sense of orientation and scale
    axis_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(
        size=0.5, origin=[0, 0, 0])

    o3d.visualization.draw_geometries(
        [cloud, axis_frame],
        window_name=window_name,
        width=800,
        height=600,
    )

def visualize_partial_complete(partial, complete, window_name="Partial vs Complete"):
    """
    Show a partial and a complete point cloud side by side in one Open3D window.

    The partial cloud is drawn in red at its original position; the complete
    cloud is drawn in green and shifted 2 units along +x so the two do not
    overlap. Each cloud gets its own axis frame.

    Args:
        partial: numpy array of shape (N, 3), the incomplete point cloud
        complete: numpy array of shape (M, 3), the complete point cloud
        window_name: title of the viewer window
    """
    # Partial cloud (red)
    pcd_partial = o3d.geometry.PointCloud()
    pcd_partial.points = o3d.utility.Vector3dVector(partial)
    pcd_partial.paint_uniform_color([1, 0, 0])

    # Complete cloud (green)
    pcd_complete = o3d.geometry.PointCloud()
    pcd_complete.points = o3d.utility.Vector3dVector(complete)
    pcd_complete.paint_uniform_color([0, 1, 0])

    # Shift the complete cloud along +x so both clouds are visible next to each other
    translation = np.array([2.0, 0, 0])
    pcd_complete.translate(translation)

    # One axis frame per cloud: at the origin and at the shifted position
    coordinate_frame1 = o3d.geometry.TriangleMesh.create_coordinate_frame(
        size=0.5, origin=[0, 0, 0])
    coordinate_frame2 = o3d.geometry.TriangleMesh.create_coordinate_frame(
        size=0.5, origin=translation)

    o3d.visualization.draw_geometries([pcd_partial, pcd_complete,
                                       coordinate_frame1, coordinate_frame2],
                                      window_name=window_name,
                                      width=1600,
                                      height=600)

# Smoke test for the visualization helpers
if __name__ == "__main__":
    # Build a synthetic example cloud
    num_points = 1000
    
    # Sample points on the unit sphere from random spherical angles.
    # NOTE(review): drawing phi uniformly from [0, pi] concentrates points near
    # the poles, so this is not a uniform distribution over the sphere surface.
    theta = np.random.uniform(0, 2*np.pi, num_points)
    phi = np.random.uniform(0, np.pi, num_points)
    r = np.ones(num_points)
    
    # Spherical -> Cartesian coordinates
    x = r * np.sin(phi) * np.cos(theta)
    y = r * np.sin(phi) * np.sin(theta)
    z = r * np.cos(phi)
    
    complete_cloud = np.stack([x, y, z], axis=1)
    
    # Simulate an incomplete scan by keeping only the upper hemisphere (z > 0)
    partial_cloud = complete_cloud[complete_cloud[:, 2] > 0]
    
    # Show each cloud, then both side by side
    print("显示完整点云...")
    visualize_point_cloud(complete_cloud)
    
    print("显示不完整点云...")
    visualize_point_cloud(partial_cloud)
    
    print("并排显示对比...")
    visualize_partial_complete(partial_cloud, complete_cloud)


In [None]:
def normalize_point_cloud(points):
    """
    Center a point cloud on its centroid and scale it into the unit sphere.

    Args:
        points: array-like of shape (N, 3) with N >= 1

    Returns:
        normalized_points: the centered and scaled points
        centroid: mean of the input points (subtracted during centering)
        scale: largest distance from the centroid (the divisor that was applied).
            For a degenerate cloud whose points all coincide this is 1, so the
            result is all zeros instead of NaN and
            ``normalized_points * scale + centroid`` still recovers the input.

    Raises:
        ValueError: if ``points`` contains no points.
    """
    points = np.asarray(points)
    if points.shape[0] == 0:
        raise ValueError("normalize_point_cloud() requires at least one point")

    # Move the centroid to the origin
    centroid = np.mean(points, axis=0)
    centered = points - centroid

    # Radius of the smallest origin-centered sphere containing every point
    scale = np.max(np.linalg.norm(centered, axis=1))
    if scale == 0:
        # All points coincide: dividing by zero would fill the output with NaN
        scale = scale.dtype.type(1)

    normalized_points = centered / scale

    return normalized_points, centroid, scale

def random_sample_points(points, num_points):
    """
    Draw exactly ``num_points`` points at random from a point cloud.

    Args:
        points: numpy array of shape (N, 3)
        num_points: number of points to return

    Returns:
        sampled_points: array with ``num_points`` rows taken from ``points``;
            rows are distinct when N >= num_points, otherwise rows repeat
    """
    available = len(points)
    # Sampling with replacement is only needed when the cloud is too small
    allow_repeats = available < num_points
    chosen = np.random.choice(available, num_points, replace=allow_repeats)
    return points[chosen]

def add_noise(points, sigma=0.01, clip=0.05):
    """
    Perturb every coordinate with clipped zero-mean Gaussian noise.

    Args:
        points: numpy array of shape (N, 3)
        sigma: standard deviation of the Gaussian noise
        clip: bound on the absolute value of the noise added to any coordinate

    Returns:
        noisy_points: a new array holding ``points`` plus the clipped noise
    """
    jitter = np.random.normal(loc=0, scale=sigma, size=points.shape)
    bounded_jitter = np.clip(jitter, -clip, clip)
    return points + bounded_jitter

def create_partial_point_cloud(points, num_patches=1):
    """
    Build an incomplete point cloud by cutting random spherical patches out of it.

    For each patch a random remaining point is picked as the center and every
    point within a random radius in [0.2, 0.4] of it (the center included) is
    removed.

    Args:
        points: array-like of shape (N, 3); the input is not modified
        num_patches: number of patches to remove

    Returns:
        partial_points: the points left after removal. May be empty when the
            patches cover the whole cloud; removal stops as soon as no points
            remain.
    """
    remaining = np.array(points)  # copy so the caller's array is never aliased

    for _ in range(num_patches):
        if len(remaining) == 0:
            # Nothing left to carve; picking a center from an empty cloud would raise
            break

        # Random center among the points that are still present
        center = remaining[np.random.randint(0, len(remaining))]

        distances = np.linalg.norm(remaining - center, axis=1)

        # Random patch radius
        radius = np.random.uniform(0.2, 0.4)

        # Keep only the points strictly outside the patch
        remaining = remaining[distances > radius]

    return remaining

# Smoke test for the preprocessing helpers
if __name__ == "__main__":
    # Build a synthetic cloud on the unit sphere from random spherical angles
    num_points = 2000
    theta = np.random.uniform(0, 2*np.pi, num_points)
    phi = np.random.uniform(0, np.pi, num_points)
    r = np.ones(num_points)
    
    # Spherical -> Cartesian coordinates
    x = r * np.sin(phi) * np.cos(theta)
    y = r * np.sin(phi) * np.sin(theta)
    z = r * np.cos(phi)
    
    points = np.stack([x, y, z], axis=1)
    
    # Normalization: report the coordinate range after scaling
    normalized_points, centroid, scale = normalize_point_cloud(points)
    print(f"归一化后点云的范围: [{np.min(normalized_points):.3f}, {np.max(normalized_points):.3f}]")
    
    # Sampling: reduce the cloud to 1000 points
    sampled_points = random_sample_points(normalized_points, 1000)
    print(f"采样后点云的形状: {sampled_points.shape}")
    
    # Noise: jitter the sampled points with the default parameters
    noisy_points = add_noise(sampled_points)
    
    # Partial cloud: cut one random patch out of the sampled points
    partial_points = create_partial_point_cloud(sampled_points)
    print(f"不完整点云的点数: {len(partial_points)}")
    
    # Visualize each intermediate result
    print("\n显示原始点云...")
    visualize_point_cloud(sampled_points)
    
    print("\n显示添加噪声后的点云...")
    visualize_point_cloud(noisy_points)
    
    print("\n显示不完整点云与原始点云的对比...")
    visualize_partial_complete(partial_points, sampled_points)


## 三、数据集下载

In [None]:
class PointNetFeatureExtractor(nn.Module):
    """Shared per-point MLP (3 -> 64 -> 128 -> 256) followed by a max-pool over points."""

    def __init__(self):
        super().__init__()

        # Kernel-size-1 convolutions apply the same linear map to every point
        self.conv1 = nn.Conv1d(3, 64, 1)
        self.conv2 = nn.Conv1d(64, 128, 1)
        self.conv3 = nn.Conv1d(128, 256, 1)

        # One batch-norm per convolution
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(256)

    def forward(self, x):
        """
        Args:
            x: tensor of shape (batch_size, num_points, 3)

        Returns:
            per-point features of shape (batch_size, 256, num_points) and the
            global feature of shape (batch_size, 256, 1) obtained by taking the
            maximum over the point dimension
        """
        # Conv1d expects channels first: (batch_size, 3, num_points)
        features = x.transpose(2, 1)

        stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
        )
        for conv, norm in stages:
            features = F.relu(norm(conv(features)))

        # Max-pool over the points to get one descriptor per cloud
        pooled = features.max(dim=2, keepdim=True).values

        return features, pooled

class PointCompletionNet(nn.Module):
    """Encoder-decoder that maps a partial cloud to a fixed-size completed cloud.

    The encoder is a PointNetFeatureExtractor; the decoder is a three-layer
    fully connected network that regresses ``num_points`` xyz coordinates from
    the 256-dimensional global feature.
    """

    def __init__(self, num_points=2048):
        super().__init__()

        self.num_points = num_points

        # Encoder
        self.feature_extractor = PointNetFeatureExtractor()

        # Decoder: 256 -> 512 -> 1024 -> num_points * 3
        self.fc1 = nn.Linear(256, 512)
        self.fc2 = nn.Linear(512, 1024)
        self.fc3 = nn.Linear(1024, num_points * 3)

        self.bn1 = nn.BatchNorm1d(512)
        self.bn2 = nn.BatchNorm1d(1024)

        # Shared dropout applied after each hidden decoder layer
        self.dropout = nn.Dropout(p=0.3)

    def forward(self, x):
        """
        Args:
            x: tensor of shape (batch_size, num_points_in, 3)

        Returns:
            tensor of shape (batch_size, self.num_points, 3)
        """
        n_batch = x.size(0)

        # Only the pooled global feature feeds the decoder
        _, global_feature = self.feature_extractor(x)
        code = global_feature.view(n_batch, -1)

        hidden = self.dropout(F.relu(self.bn1(self.fc1(code))))
        hidden = self.dropout(F.relu(self.bn2(self.fc2(hidden))))
        flat_coords = self.fc3(hidden)

        # Unflatten into one xyz triple per output point
        return flat_coords.view(n_batch, self.num_points, 3)

def chamfer_distance(pred, gt, reduce_mean=True):
    """
    Symmetric Chamfer distance based on squared Euclidean distances.

    For every sample this is the mean squared distance from each predicted
    point to its nearest ground-truth point plus the mean squared distance from
    each ground-truth point to its nearest predicted point.

    Args:
        pred: predicted point clouds, shape (B, N, 3)
        gt: ground-truth point clouds, shape (B, M, 3)
        reduce_mean: if True return the mean over the batch as a scalar tensor,
            otherwise return one value per sample with shape (B,)

    Returns:
        chamfer_dist: the Chamfer distance as described above
    """
    # Broadcast (B, N, 1, 3) against (B, 1, M, 3) to get all pairwise offsets;
    # note this materializes a (B, N, M, 3) tensor
    offsets = pred[:, :, None, :] - gt[:, None, :, :]
    sq_dist = offsets.pow(2).sum(dim=-1)  # (B, N, M)

    pred_to_gt = sq_dist.min(dim=2).values  # (B, N)
    gt_to_pred = sq_dist.min(dim=1).values  # (B, M)

    per_sample = pred_to_gt.mean(dim=1) + gt_to_pred.mean(dim=1)  # (B,)

    return per_sample.mean() if reduce_mean else per_sample

# Smoke test for the model and the loss
if __name__ == "__main__":
    # Instantiate the network and move it to the selected device
    model = PointCompletionNet(num_points=2048)
    model = model.to(device)
    print(f"模型参数数量: {sum(p.numel() for p in model.parameters())}")
    
    # Random input batch: 2 clouds of 1024 points each
    batch_size = 2
    input_points = torch.randn(batch_size, 1024, 3).to(device)
    
    # Forward pass
    output_points = model(input_points)
    print(f"输出点云形状: {output_points.shape}")
    
    # Chamfer distance against a random target with 2048 points per cloud
    target_points = torch.randn(batch_size, 2048, 3).to(device)
    loss = chamfer_distance(output_points, target_points)
    print(f"Chamfer距离: {loss.item():.6f}")
    
    # Visualize the first sample of the batch
    if batch_size > 0:
        input_np = input_points[0].cpu().numpy()
        # detach() is required because the output still carries gradient history
        output_np = output_points[0].detach().cpu().numpy()
        target_np = target_points[0].cpu().numpy()
        
        print("\n显示输入点云...")
        visualize_point_cloud(input_np)
        
        print("\n显示输出点云...")
        visualize_point_cloud(output_np)
        
        print("\n显示目标点云...")
        visualize_point_cloud(target_np)
        
        print("\n显示输入与输出点云的对比...")
        visualize_partial_complete(input_np, output_np)


In [None]:
# Dataset location check: report which of the expected LMDB files are present locally
DATA_DIR = "/Users/arkin/Desktop/Dev/notebooks/point-cloud-completion/shapenet_car"
print('DATA_DIR =', DATA_DIR)

required_files = ['train-001.lmdb', 'valid.lmdb']
if not os.path.isdir(DATA_DIR):
    # Without the directory every required file is missing
    print('目录不存在！')
    missing = required_files
else:
    missing = [
        name for name in required_files
        if not os.path.exists(os.path.join(DATA_DIR, name))
    ]

if not missing:
    print('数据路径完整，可用于训练。')
else:
    print('缺失文件:', missing)
    print('将使用合成数据进行演示。')


In [None]:
class PointCloudDataset(Dataset):
    """Paired partial/complete point clouds read from an LMDB database.

    The database is read as follows: a ``length`` entry holding the number of
    samples as a decimal string, plus pickled ``complete_{i}`` and
    ``partial_{i}`` entries for every sample index ``i``.

    The LMDB environment is opened lazily on first access and is dropped when
    the dataset is pickled, so the dataset can be handed to DataLoader worker
    processes, each of which then opens its own handle.
    """

    def __init__(self, lmdb_path, num_points=2048, partial_points=1024,
                 mode='train', transform=None):
        """
        Args:
            lmdb_path: path of the LMDB database
            num_points: number of points in every returned complete cloud
            partial_points: number of points in every returned partial cloud
            mode: 'train' or 'valid'; the transform is only applied in 'train'
            transform: optional augmentation callable mapping an (N, 3) array
                to an (N, 3) array with the same point order

        Raises:
            KeyError: if the database has no ``length`` entry.
        """
        super().__init__()

        self.lmdb_path = lmdb_path
        self.num_points = num_points
        self.partial_points = partial_points
        self.mode = mode
        self.transform = transform

        # Opened on demand by _get_env(); kept closed here so that the dataset
        # stays picklable and no handle is shared with worker processes
        self.env = None

        # Read the sample count with a short-lived handle
        env = self._open_env()
        try:
            with env.begin() as txn:
                raw_length = txn.get('length'.encode())
        finally:
            env.close()
        if raw_length is None:
            raise KeyError(f"'length' entry not found in LMDB database: {lmdb_path}")
        self.length = int(raw_length.decode())

    def _open_env(self):
        """Open a read-only LMDB environment for ``self.lmdb_path``."""
        return lmdb.open(self.lmdb_path, readonly=True, lock=False)

    def _get_env(self):
        """Return this process's LMDB environment, opening it on first use."""
        if self.env is None:
            self.env = self._open_env()
        return self.env

    def __getstate__(self):
        # An open LMDB environment cannot be pickled; the receiving process
        # reopens it lazily
        state = self.__dict__.copy()
        state['env'] = None
        return state

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        """Return ``{'partial': (partial_points, 3), 'complete': (num_points, 3)}`` float32 tensors.

        Raises:
            IndexError: if the database has no entries for ``index``.
        """
        with self._get_env().begin() as txn:
            raw_complete = txn.get(f'complete_{index}'.encode())
            raw_partial = txn.get(f'partial_{index}'.encode())
        if raw_complete is None or raw_partial is None:
            raise IndexError(f"sample {index} not found in {self.lmdb_path}")

        # SECURITY: pickle.loads executes arbitrary code from the payload;
        # only point this dataset at LMDB files from a trusted source
        complete = np.array(pickle.loads(raw_complete), dtype=np.float32)
        partial = np.array(pickle.loads(raw_partial), dtype=np.float32)

        # 1. Normalize both clouds with the complete cloud's centroid and
        #    scale so that the pair stays aligned
        complete, centroid, scale = normalize_point_cloud(complete)
        partial = (partial - centroid) / scale

        # 2. Resample to the fixed sizes. Clouds that are too small are
        #    resampled as well (with repetition) so every sample has the same
        #    shape and batches can be stacked.
        if complete.shape[0] != self.num_points:
            complete = random_sample_points(complete, self.num_points)
        if partial.shape[0] != self.partial_points:
            partial = random_sample_points(partial, self.partial_points)

        # 3. Augment. Both clouds go through ONE transform call so that any
        #    random rotation/translation/scale drawn inside it is identical
        #    for the pair; separate calls would misalign partial and complete.
        if self.transform and self.mode == 'train':
            n_complete = complete.shape[0]
            augmented = self.transform(np.concatenate([complete, partial], axis=0))
            complete, partial = augmented[:n_complete], augmented[n_complete:]

        complete = torch.from_numpy(np.ascontiguousarray(complete, dtype=np.float32))
        partial = torch.from_numpy(np.ascontiguousarray(partial, dtype=np.float32))

        return {'partial': partial, 'complete': complete}

class PointCloudTransform:
    """Random augmentation for a point cloud: noise, z-rotation, translation, scaling."""

    def __init__(self, noise_sigma=0.01, noise_clip=0.05, 
                 rotation=True, translation=True, scale=True):
        """
        Args:
            noise_sigma: standard deviation of the Gaussian noise
            noise_clip: bound on the absolute value of the noise
            rotation: whether to apply a random rotation about the z axis
            translation: whether to apply a random translation
            scale: whether to apply a random uniform scaling
        """
        self.noise_sigma, self.noise_clip = noise_sigma, noise_clip
        self.rotation, self.translation, self.scale = rotation, translation, scale

    def __call__(self, points):
        """
        Augment one point cloud.

        Args:
            points: numpy array of shape (N, 3); left unmodified

        Returns:
            transformed_points: float32 array of shape (N, 3)
        """
        # Noise is always applied; the remaining steps are optional
        augmented = add_noise(points.copy(), self.noise_sigma, self.noise_clip)

        if self.rotation:
            # Random angle in [0, 2*pi) for a rotation about the z axis
            angle = np.random.uniform(0, 2*np.pi)
            cos_a, sin_a = np.cos(angle), np.sin(angle)
            z_rotation = np.array([
                [cos_a, -sin_a, 0],
                [sin_a, cos_a, 0],
                [0, 0, 1],
            ])
            augmented = augmented @ z_rotation

        if self.translation:
            # Independent offset in [-0.1, 0.1) per axis
            augmented = augmented + np.random.uniform(-0.1, 0.1, size=3)

        if self.scale:
            # Single scale factor in [0.8, 1.2) for all axes
            augmented = augmented * np.random.uniform(0.8, 1.2)

        return augmented.astype(np.float32)

# 创建数据加载器
def create_dataloader(lmdb_path, batch_size=32, num_workers=4, 
                     num_points=2048, partial_points=1024, mode='train'):
    """
    Build a DataLoader over a PointCloudDataset.

    In 'train' mode the samples are augmented with PointCloudTransform,
    shuffled, and the last incomplete batch is dropped. In any other mode no
    augmentation or shuffling is applied and every sample is kept, so that
    evaluation covers the whole dataset.

    Args:
        lmdb_path: path of the LMDB database
        batch_size: number of samples per batch
        num_workers: number of worker processes used for loading
        num_points: number of points in each complete cloud
        partial_points: number of points in each partial cloud
        mode: 'train' or 'valid'

    Returns:
        dataloader: the configured DataLoader
    """
    is_train = mode == 'train'

    # Augmentation is only used for training
    transform = None
    if is_train:
        transform = PointCloudTransform(
            noise_sigma=0.01,
            noise_clip=0.05,
            rotation=True,
            translation=True,
            scale=True
        )

    dataset = PointCloudDataset(
        lmdb_path=lmdb_path,
        num_points=num_points,
        partial_points=partial_points,
        mode=mode,
        transform=transform
    )

    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=is_train,
        num_workers=num_workers,
        # Pinned host memory only helps host-to-GPU copies
        pin_memory=torch.cuda.is_available(),
        # Dropping the tail batch is fine for training but would silently
        # exclude samples from validation metrics
        drop_last=is_train
    )

    return dataloader

# Smoke test for the data loaders
if __name__ == "__main__":
    # Locations of the training and validation LMDB databases
    train_path = "/Users/arkin/Desktop/Dev/notebooks/point-cloud-completion/shapenet_car/train-001.lmdb"
    valid_path = "/Users/arkin/Desktop/Dev/notebooks/point-cloud-completion/shapenet_car/valid.lmdb"
    
    # Build one loader per split with the default batch size and point counts
    train_loader = create_dataloader(train_path, mode='train')
    valid_loader = create_dataloader(valid_path, mode='valid')
    
    # len() of a DataLoader is the number of batches, not the number of samples
    print(f"训练数据加载器大小: {len(train_loader)}")
    print(f"验证数据加载器大小: {len(valid_loader)}")
    
    # Fetch a single training batch
    batch = next(iter(train_loader))
    partial = batch['partial']
    complete = batch['complete']
    
    print(f"不完整点云形状: {partial.shape}")
    print(f"完整点云形状: {complete.shape}")
    
    # Visualize the first sample of the batch
    partial_np = partial[0].numpy()
    complete_np = complete[0].numpy()
    
    print("\n显示不完整点云与完整点云的对比...")
    visualize_partial_complete(partial_np, complete_np)


## 四、数据可视化展示

In [None]:
def train_model(model, train_loader, valid_loader, num_epochs=100, 
              learning_rate=0.001, device=device, checkpoint_path='best_model.pth'):
    """
    Train a completion model with Adam and a Chamfer-distance loss.

    After every epoch the model is validated; whenever the mean validation
    loss improves, a checkpoint is written to ``checkpoint_path``.

    Args:
        model: the model to train; expected to already live on ``device``
        train_loader: loader yielding dicts with 'partial' and 'complete' tensors
        valid_loader: loader used for validation, same batch format
        num_epochs: number of training epochs
        learning_rate: initial learning rate; halved every 20 epochs
        device: device the batches are moved to
        checkpoint_path: file the best checkpoint is saved to

    Returns:
        history: dict with per-epoch lists 'train_loss', 'valid_loss' and 'lr'
            (the learning rate that was in effect during each epoch)
        best_model_state: a snapshot of the state dict from the epoch with the
            lowest validation loss, or None if no epoch improved on infinity
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Halve the learning rate every 20 epochs
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)

    best_valid_loss = float('inf')
    best_model_state = None

    history = {
        'train_loss': [],
        'valid_loss': [],
        'lr': []
    }

    for epoch in range(num_epochs):
        # Read the rate before scheduler.step() so the value logged for this
        # epoch is the one the optimizer actually used
        current_lr = optimizer.param_groups[0]['lr']
        history['lr'].append(current_lr)

        # ---- training ----
        model.train()
        train_losses = []
        train_progress = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Train]')

        for batch in train_progress:
            partial = batch['partial'].to(device)
            complete = batch['complete'].to(device)

            output = model(partial)
            loss = chamfer_distance(output, complete)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_losses.append(loss.item())
            train_progress.set_postfix({'loss': f'{loss.item():.6f}'})

        # Plain Python floats keep the checkpoint free of NumPy scalar objects
        avg_train_loss = float(np.mean(train_losses))
        history['train_loss'].append(avg_train_loss)

        # ---- validation ----
        model.eval()
        valid_losses = []
        valid_progress = tqdm(valid_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Valid]')

        with torch.no_grad():
            for batch in valid_progress:
                partial = batch['partial'].to(device)
                complete = batch['complete'].to(device)

                output = model(partial)
                loss = chamfer_distance(output, complete)

                valid_losses.append(loss.item())
                valid_progress.set_postfix({'loss': f'{loss.item():.6f}'})

        avg_valid_loss = float(np.mean(valid_losses))
        history['valid_loss'].append(avg_valid_loss)

        scheduler.step()

        print(f'\nEpoch {epoch+1}/{num_epochs}:')
        print(f'Train Loss: {avg_train_loss:.6f}')
        print(f'Valid Loss: {avg_valid_loss:.6f}')
        print(f'Learning Rate: {current_lr:.6f}\n')

        if avg_valid_loss < best_valid_loss:
            best_valid_loss = avg_valid_loss
            # state_dict() returns references to the live parameters; clone
            # them so later optimizer steps cannot overwrite the snapshot
            best_model_state = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }
            print(f'Found new best model with validation loss: {best_valid_loss:.6f}')

            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': best_model_state,
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': avg_train_loss,
                'valid_loss': best_valid_loss,
                'history': history
            }, checkpoint_path)

    return history, best_model_state

def plot_training_history(history):
    """
    Plot the loss curves and the learning-rate schedule side by side.

    Args:
        history: dict with equally long lists under 'train_loss', 'valid_loss'
            and 'lr', one entry per epoch
    """
    epochs = range(1, len(history['train_loss']) + 1)

    _, (loss_ax, lr_ax) = plt.subplots(1, 2, figsize=(12, 4))

    # Left panel: training vs. validation loss
    loss_ax.plot(epochs, history['train_loss'], 'b-', label='Train Loss')
    loss_ax.plot(epochs, history['valid_loss'], 'r-', label='Valid Loss')
    loss_ax.set_title('Training and Validation Loss')
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Loss')
    loss_ax.legend()
    loss_ax.grid(True)

    # Right panel: learning rate per epoch
    lr_ax.plot(epochs, history['lr'], 'g-')
    lr_ax.set_title('Learning Rate')
    lr_ax.set_xlabel('Epoch')
    lr_ax.set_ylabel('Learning Rate')
    lr_ax.grid(True)

    plt.tight_layout()
    plt.show()

# Train the model
if __name__ == "__main__":
    # Build the training and validation loaders
    train_path = "/Users/arkin/Desktop/Dev/notebooks/point-cloud-completion/shapenet_car/train-001.lmdb"
    valid_path = "/Users/arkin/Desktop/Dev/notebooks/point-cloud-completion/shapenet_car/valid.lmdb"
    
    train_loader = create_dataloader(train_path, batch_size=32, mode='train')
    valid_loader = create_dataloader(valid_path, batch_size=32, mode='valid')
    
    # Create the network; its output size matches the loaders' default num_points
    model = PointCompletionNet(num_points=2048)
    model = model.to(device)
    
    # Run training; the best checkpoint is written to disk by train_model
    history, best_model_state = train_model(
        model=model,
        train_loader=train_loader,
        valid_loader=valid_loader,
        num_epochs=100,
        learning_rate=0.001,
        device=device
    )
    
    # Plot the loss and learning-rate curves
    plot_training_history(history)


In [None]:
def evaluate_model(model, test_loader, device=device, num_visualize=5):
    """
    Report the mean Chamfer loss over a loader and visualize a few predictions.

    The first sample of each of the first ``num_visualize`` batches is shown,
    first as a three-panel matplotlib figure and then in Open3D windows.

    Args:
        model: trained model
        test_loader: loader yielding dicts with 'partial' and 'complete' tensors
        device: device the batches are moved to
        num_visualize: number of samples to visualize
    """
    model.eval()
    batch_losses = []
    samples = []

    with torch.no_grad():
        for batch_idx, batch in enumerate(tqdm(test_loader, desc='Evaluating')):
            partial = batch['partial'].to(device)
            complete = batch['complete'].to(device)

            output = model(partial)
            batch_losses.append(chamfer_distance(output, complete).item())

            if batch_idx >= num_visualize:
                continue

            # Keep the first sample of this batch for visualization
            tensors = (('partial', partial), ('complete', complete), ('output', output))
            samples.append({key: tensor[0].cpu().numpy() for key, tensor in tensors})

    avg_test_loss = np.mean(batch_losses)
    print(f'\nAverage Test Loss: {avg_test_loss:.6f}')

    # (sample key, marker color, panel title) for the three matplotlib panels
    panels = (
        ('partial', 'r', 'Partial Point Cloud'),
        ('complete', 'g', 'Ground Truth'),
        ('output', 'b', 'Model Output'),
    )

    for sample_no, sample in enumerate(samples, start=1):
        plt.figure(figsize=(15, 5))
        axes = [plt.subplot(1, 3, col, projection='3d') for col in (1, 2, 3)]

        for ax, (key, color, title) in zip(axes, panels):
            cloud = sample[key]
            ax.scatter(cloud[:, 0], cloud[:, 1], cloud[:, 2], c=color, marker='.')
            ax.set_title(title)

        # Same camera angle and axis labels on every panel
        for ax in axes:
            ax.view_init(elev=30, azim=45)
            ax.set_xlabel('X')
            ax.set_ylabel('Y')
            ax.set_zlabel('Z')

        plt.tight_layout()
        plt.show()

        # Interactive Open3D views of the same sample
        print(f'\nSample {sample_no} - Interactive Visualization:')
        print('显示不完整点云与预测点云的对比...')
        visualize_partial_complete(sample['partial'], sample['output'])

        print('显示预测点云与真实点云的对比...')
        visualize_partial_complete(sample['output'], sample['complete'])

def compute_metrics(model, test_loader, device=device):
    """
    Compute per-sample Chamfer distances over a loader, print summary
    statistics and plot their distribution.

    Args:
        model: trained model
        test_loader: loader yielding dicts with 'partial' and 'complete' tensors
        device: device the batches are moved to

    Returns:
        dict with the 'mean', 'std', 'median', 'min' and 'max' of the
        per-sample Chamfer distances, as Python floats
    """
    model.eval()
    chamfer_distances = []

    with torch.no_grad():
        for batch in tqdm(test_loader, desc='Computing Metrics'):
            partial = batch['partial'].to(device)
            complete = batch['complete'].to(device)
            output = model(partial)

            # One distance per sample rather than one per batch
            cd = chamfer_distance(output, complete, reduce_mean=False)
            chamfer_distances.extend(cd.cpu().numpy())

    cd_mean = np.mean(chamfer_distances)
    cd_std = np.std(chamfer_distances)
    cd_median = np.median(chamfer_distances)
    cd_min = np.min(chamfer_distances)
    cd_max = np.max(chamfer_distances)

    print('\nEvaluation Metrics:')
    print('Chamfer Distance:')
    print(f'  Mean   : {cd_mean:.6f}')
    print(f'  Std    : {cd_std:.6f}')
    print(f'  Median : {cd_median:.6f}')
    print(f'  Min    : {cd_min:.6f}')
    print(f'  Max    : {cd_max:.6f}')

    # Histogram of the per-sample distances with mean and median marked
    plt.figure(figsize=(10, 5))
    plt.hist(chamfer_distances, bins=50, density=True)
    plt.axvline(cd_mean, color='r', linestyle='--', label=f'Mean: {cd_mean:.6f}')
    plt.axvline(cd_median, color='g', linestyle='--', label=f'Median: {cd_median:.6f}')
    plt.title('Distribution of Chamfer Distances')
    plt.xlabel('Chamfer Distance')
    plt.ylabel('Density')
    plt.legend()
    plt.grid(True)
    plt.show()

    return {
        'mean': float(cd_mean),
        'std': float(cd_std),
        'median': float(cd_median),
        'min': float(cd_min),
        'max': float(cd_max),
    }

# Evaluate the trained model
if __name__ == "__main__":
    # Restore the best checkpoint. map_location remaps the stored tensors onto
    # the current device, so a checkpoint saved on a GPU also loads on a
    # CPU-only machine.
    checkpoint = torch.load('best_model.pth', map_location=device)
    model = PointCompletionNet(num_points=2048)
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)
    
    # The validation split doubles as the test set here
    test_path = "/Users/arkin/Desktop/Dev/notebooks/point-cloud-completion/shapenet_car/valid.lmdb"
    test_loader = create_dataloader(test_path, batch_size=32, mode='valid')
    
    # Mean loss plus qualitative visualizations
    print("正在评估模型...")
    evaluate_model(model, test_loader, num_visualize=5)
    
    # Per-sample statistics and histogram
    print("\n计算详细评估指标...")
    compute_metrics(model, test_loader)


## 五、数据预处理

## 六、PCN模型架构

## 七、数据集类与数据加载器

## 八、模型训练

## 九、模型评估与可视化