In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os
from tqdm import tqdm

# Select the compute device.  The second GPU ("cuda:1") is preferred, but it is
# only requested when at least two CUDA devices are visible; otherwise fall back
# to the first GPU, and finally to the CPU, so the notebook also runs on
# single-GPU and CPU-only machines instead of failing with an invalid ordinal.
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [None]:
class SyntheticPointCloud(Dataset):
    """Synthetic dataset of cube-surface point clouds under random rigid poses.

    Every access draws a fresh sample: ``num_points`` points are picked (with
    replacement) from a fixed set of cube-surface points, rotated by a random
    rotation and shifted by a random translation.  All randomness comes from
    NumPy's global RNG, so an item is not a deterministic function of its
    index; seed ``np.random`` to make a run reproducible.

    Args:
        num_samples: Length reported by ``len()``.  Valid indices are
            ``-num_samples .. num_samples - 1``.
        num_points: Number of points in every returned cloud.
    """

    def __init__(self, num_samples=1000, num_points=1024):
        self.num_samples = num_samples
        self.num_points = num_points
        # Base shape (a cube) that every sample is drawn from.
        self.base_points = self._generate_cube()

    def _generate_cube(self, size=1.0):
        """Build the surface points of a 10 x 10 x 10 lattice spanning a cube.

        Args:
            size: Edge length of the axis-aligned cube centred at the origin.

        Returns:
            float32 tensor of shape (K, 3) with the lattice points that lie on
            at least one face (K = 10**3 - 8**3 = 488), ordered with x varying
            slowest and z fastest.
        """
        steps = 10
        ticks = np.linspace(-size / 2, size / 2, steps)
        # Enumerate lattice indices (i, j, k) in C order so the point order
        # matches a nested x / y / z loop.
        lattice = np.indices((steps, steps, steps)).reshape(3, -1).T
        # A lattice point is on the surface when any of its indices is the
        # first or last tick.  Testing indices avoids float equality checks.
        on_face = ((lattice == 0) | (lattice == steps - 1)).any(axis=1)
        surface = ticks[lattice[on_face]]
        return torch.tensor(surface, dtype=torch.float32)

    def _apply_random_transform(self, points):
        """Apply one random rotation and translation to a point array.

        The rotation angle is uniform in [0, 2*pi) about an axis drawn from an
        isotropic Gaussian; the translation is Gaussian with standard
        deviation 0.5 per coordinate.  (Note: a uniform angle does not give
        rotations uniformly distributed over SO(3).)

        Args:
            points: NumPy array of shape (N, 3).

        Returns:
            Tuple ``(transformed_points, R, t)`` of float32 NumPy arrays with
            shapes (N, 3), (3, 3) and (3,), where
            ``transformed_points[i] = R @ points[i] + t``.
        """
        # Random rotation; the axis is normalised inside the helper.
        angle = np.random.uniform(0, 2 * np.pi)
        axis = np.random.randn(3)
        R = self._axis_angle_to_matrix(angle, axis)
        # Random translation.
        t = np.random.randn(3) * 0.5
        # Rotate the column-stacked points, then shift every point by t.
        transformed_points = (R @ points.T).T + t
        return (
            transformed_points.astype(np.float32),
            R.astype(np.float32),
            t.astype(np.float32),
        )

    def _axis_angle_to_matrix(self, angle, axis):
        """Convert an axis-angle rotation to a 3 x 3 rotation matrix.

        Uses the Euler-Rodrigues formula.  The result rotates column vectors
        counter-clockwise by ``angle`` about ``axis`` (right-hand rule).

        Args:
            angle: Rotation angle in radians.
            axis: Length-3 rotation axis; it does not need to be normalised.

        Returns:
            NumPy array of shape (3, 3).

        Raises:
            ValueError: If ``axis`` has zero length, which has no direction.
        """
        axis = np.asarray(axis, dtype=float)
        norm = np.linalg.norm(axis)
        if norm == 0:
            raise ValueError("rotation axis must be non-zero")
        axis = axis / norm
        # Euler-Rodrigues parameters: a is the scalar part and (b, c, d) the
        # vector part of the unit quaternion.  The vector part must keep the
        # sign of the axis; negating it would yield the inverse rotation.
        a = np.cos(angle / 2)
        b, c, d = axis * np.sin(angle / 2)
        return np.array([
            [a*a + b*b - c*c - d*d, 2*(b*c - a*d), 2*(b*d + a*c)],
            [2*(b*c + a*d), a*a + c*c - b*b - d*d, 2*(c*d - a*b)],
            [2*(b*d - a*c), 2*(c*d + a*b), a*a + d*d - b*b - c*c]
        ])

    def __len__(self):
        """Return the nominal number of samples in the dataset."""
        return self.num_samples

    def __getitem__(self, idx):
        """Draw one randomly posed point cloud.

        Args:
            idx: Sample index.  It is only range-checked; the sample content
                is random and does not depend on it.

        Returns:
            Dict with float32 tensors ``'points'`` (num_points, 3),
            ``'R_true'`` (3, 3) and ``'t_true'`` (3,), where ``'points'`` are
            the sampled base points after applying ``R_true`` and ``t_true``.

        Raises:
            IndexError: If ``idx`` is outside ``-num_samples .. num_samples-1``.
                Without this, plain iteration (``for s in dataset``) would
                never terminate.
        """
        if not -self.num_samples <= idx < self.num_samples:
            raise IndexError(
                f"index {idx} is out of range for dataset of size {self.num_samples}"
            )
        # Sample points (with replacement) from the base shape.
        chosen = np.random.choice(len(self.base_points), self.num_points, replace=True)
        points_np = self.base_points[chosen].numpy()
        # Apply the random rigid transform.
        transformed_points, R_true, t_true = self._apply_random_transform(points_np)
        return {
            'points': torch.from_numpy(transformed_points),
            'R_true': torch.from_numpy(R_true),
            't_true': torch.from_numpy(t_true)
        }

In [None]:
# from point_transformer_modules import *  # contains all the modules you provided
import os
import sys

# Directory that is put on the import path before `models.models` is imported
# (presumably the Mulsen3 project checkout -- confirm).  The MULSEN3_ROOT
# environment variable overrides the default so the notebook is not tied to one
# machine's directory layout.
MULSEN3_ROOT = os.environ.get('MULSEN3_ROOT', '/home/dataset/zbz_new/Mulsen3')
# Only append once so re-running this cell does not pile up duplicate entries.
if MULSEN3_ROOT not in sys.path:
    sys.path.append(MULSEN3_ROOT)
from models.models import PointTransformer