In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os
from tqdm import tqdm

# 定义设备
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

In [2]:
class SyntheticPointCloud(Dataset):
    """生成带随机位姿的合成点云数据集"""
    def __init__(self, num_samples=1000, num_points=1024):
        self.num_samples = num_samples
        self.num_points = num_points
        # 生成基础形状（立方体）
        self.base_points = self._generate_cube()

    def _generate_cube(self, size=1.0):
        """生成立方体表面点云"""
        points = []
        for x in np.linspace(-size/2, size/2, 10):
            for y in np.linspace(-size/2, size/2, 10):
                for z in np.linspace(-size/2, size/2, 10):
                    if abs(x) == size/2 or abs(y) == size/2 or abs(z) == size/2:
                        points.append([x, y, z])
        return torch.tensor(points, dtype=torch.float32)

    def _apply_random_transform(self, points):
        """应用随机旋转和平移"""
        # 生成随机旋转矩阵
        angle = np.random.uniform(0, 2*np.pi)
        axis = np.random.randn(3)
        axis /= np.linalg.norm(axis)
        R = self._axis_angle_to_matrix(angle, axis)
        # 生成随机平移
        t = np.random.randn(3) * 0.5
        # 应用变换
        transformed_points = (R @ points.T).T + t
        return transformed_points.astype(np.float32), R.astype(np.float32), t.astype(np.float32)

    def _axis_angle_to_matrix(self, angle, axis):
        """轴角转旋转矩阵"""
        axis = axis / np.linalg.norm(axis)
        a = np.cos(angle / 2)
        b, c, d = -axis * np.sin(angle / 2)
        return np.array([
            [a*a + b*b - c*c - d*d, 2*(b*c - a*d), 2*(b*d + a*c)],
            [2*(b*c + a*d), a*a + c*c - b*b - d*d, 2*(c*d - a*b)],
            [2*(b*d - a*c), 2*(c*d + a*b), a*a + d*d - b*b - c*c]
        ])

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # 从基础形状中采样点
        idx = np.random.choice(len(self.base_points), self.num_points, replace=True)
        points = self.base_points[idx]
        # 应用随机变换
        points_np = points.numpy()
        transformed_points, R_true, t_true = self._apply_random_transform(points_np)
        return {
            'points': torch.from_numpy(transformed_points),
            'R_true': torch.from_numpy(R_true),
            't_true': torch.from_numpy(t_true)
        }

In [3]:
# from point_transformer_modules import *  # 包含你提供的所有模块
import sys
sys.path.append('/home/dataset/zbz_new/Mulsen3')
from models.models import PointTransformer

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np


class PoseNetWithTransformer(nn.Module):
    def __init__(self, num_group=1024, group_size=128, encoder_dims=384):
        super().__init__()
        # Point Transformer 主干网络
        self.backbone = PointTransformer(
            group_size=group_size,
            num_group=num_group,
            encoder_dims=encoder_dims
        )
        self.backbone.load_model_from_ckpt(r"/home/dataset/zbz_new/Mulsen3/checkpoints/pointmae_pretrain.pth")
        # 位姿预测头
        self.rot_head = nn.Sequential(
            nn.Linear(encoder_dims*3*num_group, 512),
            # nn.BatchNorm1d(512),
            nn.LayerNorm(512),
            nn.ReLU(),
            nn.Linear(512, 6)  # 6D旋转表示
        )
        
        self.trans_head = nn.Sequential(
            nn.Linear(encoder_dims*3*num_group, 512),
            # nn.BatchNorm1d(512),
            nn.LayerNorm(512),
            nn.ReLU(),
            nn.Linear(512, 3)  # 平移向量
        )

    def forward(self, pts):
        # 输入pts形状: (B, N, 3)
        B, N, _ = pts.shape
        
        # 通过Transformer主干网络
        features, center, ori_idx, center_idx,pts = self.backbone(pts.transpose(1, 2).contiguous().float())  # features形状: (B, 1152)
        # print(features.shape)
        features = features.view(features.size(0), -1)
        # print(features.shape)
        # 位姿预测
        rot_params = self.rot_head(features)  # (B, 6)
        trans = self.trans_head(features)     # (B, 3)
        
        # 将6D旋转转换为旋转矩阵
        R = self.rot6d_to_matrix(rot_params)  # (B, 3, 3)
        return R, trans

    def rot6d_to_matrix(self, d6):
        """将6D旋转表示转换为旋转矩阵"""
        a1, a2 = d6[..., :3], d6[..., 3:]
        b1 = F.normalize(a1, dim=-1)
        b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1
        b2 = F.normalize(b2, dim=-1)
        b3 = torch.cross(b1, b2, dim=-1)
        return torch.stack([b1, b2, b3], dim=-2)

# 训练和测试代码（与之前类似，只需替换模型）
if __name__ == "__main__":
    # 参数设置
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    num_epochs = 50
    batch_size = 1
    
    # 初始化模型
    model = PoseNetWithTransformer(
        num_group=1024,
        group_size=128,
        encoder_dims=384
    ).to(device)
    
    # 数据集和加载器（使用之前定义的SyntheticPointCloud）
    train_dataset = SyntheticPointCloud(num_samples=1000)
    test_dataset = SyntheticPointCloud(num_samples=200)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)

    # 优化器
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
    
    # 训练循环
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        
        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
            points = batch['points'].to(device)  # (B, N, 3)
            R_true = batch['R_true'].to(device)  # (B, 3, 3)
            t_true = batch['t_true'].to(device)  # (B, 3)
            
            # 前向传播
            R_pred, t_pred = model(points)
            
            # 计算损失
            loss_R = F.mse_loss(R_pred, R_true)
            loss_t = F.mse_loss(t_pred, t_true)
            loss = 0.7*loss_R + 0.3*loss_t
            
            # 反向传播
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
        
        # 验证
        model.eval()
        val_R_error = []
        val_t_error = []
        with torch.no_grad():
            for batch in test_loader:
                points = batch['points'].to(device)
                R_true = batch['R_true'].to(device)
                t_true = batch['t_true'].to(device)
                
                R_pred, t_pred = model(points)
                
                # 计算旋转误差（角度差）
                trace = torch.einsum('bii->b', R_pred @ R_true.transpose(1,2))
                theta = torch.acos((trace - 1)/2)
                val_R_error.extend(torch.rad2deg(theta).cpu().numpy())
                
                # 计算平移误差
                val_t_error.extend(torch.norm(t_pred - t_true, dim=1).cpu().numpy())
        
        print(f"Epoch {epoch+1} | "
              f"Train Loss: {total_loss/len(train_loader):.4f} | "
              f"Val Rot Error: {np.mean(val_R_error):.2f}° | "
              f"Val Trans Error: {np.mean(val_t_error):.2f}m")

# 测试用例
def test_single_example():
    # 加载训练好的模型
    model = PoseNetWithTransformer().to(device)
    model.load_state_dict(torch.load("best_model.pth"))
    model.eval()
    
    # 生成测试数据
    test_data = SyntheticPointCloud(num_samples=1)[0]
    points = test_data['points'].unsqueeze(0).to(device)
    R_true = test_data['R_true'].numpy()
    t_true = test_data['t_true'].numpy()
    
    # 推理
    with torch.no_grad():
        R_pred, t_pred = model(points)
    
    print("\n真实旋转矩阵:\n", R_true)
    print("预测旋转矩阵:\n", R_pred.squeeze().cpu().numpy())
    print("真实平移:", t_true)
    print("预测平移:", t_pred.squeeze().cpu().numpy())

if __name__ == "__main__":
    test_single_example()

Epoch 1: 100%|██████████| 1000/1000 [05:23<00:00,  3.09it/s]


Epoch 1 | Train Loss: 0.3510 | Val Rot Error: 94.21° | Val Trans Error: 0.13m


Epoch 2: 100%|██████████| 1000/1000 [05:22<00:00,  3.10it/s]


Epoch 2 | Train Loss: 0.3216 | Val Rot Error: 86.69° | Val Trans Error: 0.12m


Epoch 3: 100%|██████████| 1000/1000 [05:24<00:00,  3.08it/s]


Epoch 3 | Train Loss: 0.3159 | Val Rot Error: 99.29° | Val Trans Error: 0.15m


Epoch 4: 100%|██████████| 1000/1000 [05:22<00:00,  3.10it/s]


Epoch 4 | Train Loss: 0.3096 | Val Rot Error: 81.39° | Val Trans Error: 0.14m


Epoch 5: 100%|██████████| 1000/1000 [05:22<00:00,  3.10it/s]


Epoch 5 | Train Loss: 0.3110 | Val Rot Error: 90.30° | Val Trans Error: 0.08m


Epoch 6: 100%|██████████| 1000/1000 [05:22<00:00,  3.10it/s]


Epoch 6 | Train Loss: 0.3114 | Val Rot Error: 92.81° | Val Trans Error: 0.10m


Epoch 7: 100%|██████████| 1000/1000 [05:22<00:00,  3.10it/s]


Epoch 7 | Train Loss: 0.3083 | Val Rot Error: 90.23° | Val Trans Error: 0.12m


Epoch 8: 100%|██████████| 1000/1000 [05:22<00:00,  3.10it/s]


Epoch 8 | Train Loss: 0.3203 | Val Rot Error: 92.75° | Val Trans Error: 0.10m


Epoch 9: 100%|██████████| 1000/1000 [05:22<00:00,  3.10it/s]


Epoch 9 | Train Loss: 0.2996 | Val Rot Error: 94.49° | Val Trans Error: 0.09m


Epoch 10: 100%|██████████| 1000/1000 [05:22<00:00,  3.10it/s]


Epoch 10 | Train Loss: 0.3074 | Val Rot Error: 95.58° | Val Trans Error: 0.07m


Epoch 11: 100%|██████████| 1000/1000 [05:23<00:00,  3.09it/s]


Epoch 11 | Train Loss: 0.3145 | Val Rot Error: 87.67° | Val Trans Error: 0.11m


Epoch 12: 100%|██████████| 1000/1000 [05:24<00:00,  3.09it/s]


Epoch 12 | Train Loss: 0.3087 | Val Rot Error: 93.31° | Val Trans Error: 0.10m


Epoch 13: 100%|██████████| 1000/1000 [05:22<00:00,  3.10it/s]


Epoch 13 | Train Loss: 0.3130 | Val Rot Error: 90.11° | Val Trans Error: 0.12m


Epoch 14: 100%|██████████| 1000/1000 [05:22<00:00,  3.10it/s]


Epoch 14 | Train Loss: 0.3127 | Val Rot Error: 87.06° | Val Trans Error: 0.11m


Epoch 15: 100%|██████████| 1000/1000 [05:22<00:00,  3.10it/s]


Epoch 15 | Train Loss: 0.3837 | Val Rot Error: 92.09° | Val Trans Error: 0.80m


Epoch 16: 100%|██████████| 1000/1000 [05:22<00:00,  3.10it/s]


Epoch 16 | Train Loss: 0.4030 | Val Rot Error: 89.14° | Val Trans Error: 0.82m


Epoch 17: 100%|██████████| 1000/1000 [05:22<00:00,  3.10it/s]


Epoch 17 | Train Loss: 0.3825 | Val Rot Error: 85.86° | Val Trans Error: 0.69m


Epoch 18: 100%|██████████| 1000/1000 [05:23<00:00,  3.09it/s]


Epoch 18 | Train Loss: 0.3628 | Val Rot Error: 94.56° | Val Trans Error: 0.60m


Epoch 19: 100%|██████████| 1000/1000 [05:23<00:00,  3.09it/s]


Epoch 19 | Train Loss: 0.3457 | Val Rot Error: 92.66° | Val Trans Error: 0.46m


Epoch 20: 100%|██████████| 1000/1000 [05:22<00:00,  3.10it/s]


Epoch 20 | Train Loss: 0.3257 | Val Rot Error: 88.54° | Val Trans Error: 0.43m


Epoch 21: 100%|██████████| 1000/1000 [05:22<00:00,  3.10it/s]


Epoch 21 | Train Loss: 0.3312 | Val Rot Error: 92.29° | Val Trans Error: 0.16m


Epoch 22: 100%|██████████| 1000/1000 [05:22<00:00,  3.10it/s]


Epoch 22 | Train Loss: 0.3297 | Val Rot Error: 88.88° | Val Trans Error: 0.13m


Epoch 23: 100%|██████████| 1000/1000 [05:22<00:00,  3.10it/s]


Epoch 23 | Train Loss: 0.3069 | Val Rot Error: 96.16° | Val Trans Error: 0.10m


Epoch 24: 100%|██████████| 1000/1000 [05:22<00:00,  3.10it/s]


Epoch 24 | Train Loss: 0.3124 | Val Rot Error: 90.62° | Val Trans Error: 0.10m


Epoch 25: 100%|██████████| 1000/1000 [05:22<00:00,  3.10it/s]


Epoch 25 | Train Loss: 0.3095 | Val Rot Error: 87.74° | Val Trans Error: 0.11m


Epoch 26: 100%|██████████| 1000/1000 [05:22<00:00,  3.10it/s]


Epoch 26 | Train Loss: 0.3234 | Val Rot Error: 89.78° | Val Trans Error: 0.08m


Epoch 27: 100%|██████████| 1000/1000 [05:25<00:00,  3.08it/s]


Epoch 27 | Train Loss: 0.3079 | Val Rot Error: 85.42° | Val Trans Error: 0.08m


Epoch 28: 100%|██████████| 1000/1000 [05:22<00:00,  3.10it/s]


Epoch 28 | Train Loss: 0.3030 | Val Rot Error: 83.46° | Val Trans Error: 0.08m


Epoch 29: 100%|██████████| 1000/1000 [05:23<00:00,  3.10it/s]


Epoch 29 | Train Loss: 0.3108 | Val Rot Error: 94.57° | Val Trans Error: 0.09m


Epoch 30: 100%|██████████| 1000/1000 [05:22<00:00,  3.10it/s]


Epoch 30 | Train Loss: 0.3137 | Val Rot Error: 91.21° | Val Trans Error: 0.13m


Epoch 31: 100%|██████████| 1000/1000 [05:22<00:00,  3.10it/s]


Epoch 31 | Train Loss: 0.3095 | Val Rot Error: 93.03° | Val Trans Error: 0.07m


Epoch 32: 100%|██████████| 1000/1000 [05:22<00:00,  3.10it/s]


Epoch 32 | Train Loss: 0.3063 | Val Rot Error: 89.02° | Val Trans Error: 0.10m


Epoch 33: 100%|██████████| 1000/1000 [05:22<00:00,  3.10it/s]


Epoch 33 | Train Loss: 0.3041 | Val Rot Error: 93.64° | Val Trans Error: 0.13m


Epoch 34:  52%|█████▏    | 524/1000 [02:49<02:34,  3.09it/s]


KeyboardInterrupt: 

In [3]:
import torch
import torch.nn as nn
import numpy as np

class PositionalEncoding(nn.Module):
    def __init__(self, num_freqs=10, include_input=True):
        """
        :param num_freqs: 编码的频率数量，控制编码的维度
        :param include_input: 是否将原始的输入坐标(x, y, z)也包含进编码
        """
        super(PositionalEncoding, self).__init__()
        self.num_freqs = num_freqs
        self.include_input = include_input

    def forward(self, x):
        """
        :param x: 输入坐标，形状为(Batch, 3)，即 (x, y, z)
        :return: 编码后的坐标特征
        """
        # 获取x, y, z的分量
        x = x.unsqueeze(-1)  # (B, 3, 1)
        
        # 定义频率值
        freqs = 2.0 ** torch.linspace(0, self.num_freqs - 1, self.num_freqs).to(x.device)  # (num_freqs,)
        
        # 对每个坐标维度进行频率编码
        encoded = [x * freq for freq in freqs]  # 对每个频率进行扩展
        
        # 对于每个频率，分别计算sin和cos
        encoding = []
        for e in encoded:
            encoding.append(torch.sin(e))  # sin编码
            encoding.append(torch.cos(e))  # cos编码
        
        # 将sin和cos编码结果拼接在一起
        encoding = torch.cat(encoding, dim=-1)  # (B, 3 * 2 * num_freqs)

        # 如果需要将原始输入坐标(x, y, z)也包含进编码中
        if self.include_input:
            encoding = torch.cat([x.squeeze(-1), encoding], dim=-1)  # (B, 3 + 3 * 2 * num_freqs)
        
        return encoding


# 测试
xyz = torch.randn(1, 3)  # 输入的x, y, z坐标 (Batch, 3)

# 创建位置编码器，指定频率数量
position_encoder = PositionalEncoding(num_freqs=10, include_input=False)

# 获取编码后的结果
encoded_xyz = position_encoder(xyz)
print(encoded_xyz.shape)  # 输出编码后的维度 (2, 3 + 3 * 2 * 10)


torch.Size([1, 3, 20])


In [4]:
encoded_xyz

tensor([[[-0.7333,  0.6799, -0.9971, -0.0756,  0.1508, -0.9886, -0.2981,
           0.9545, -0.5691,  0.8223, -0.9359,  0.3523, -0.6594, -0.7518,
           0.9915,  0.1305,  0.2587, -0.9660, -0.4998,  0.8661],
         [-0.0383,  0.9993, -0.0765,  0.9971, -0.1525,  0.9883, -0.3014,
           0.9535, -0.5747,  0.8183, -0.9407,  0.3393, -0.6384, -0.7697,
           0.9828,  0.1849,  0.3634, -0.9316, -0.6771,  0.7359],
         [-0.1808,  0.9835, -0.3556,  0.9346, -0.6647,  0.7471, -0.9932,
           0.1162, -0.2309, -0.9730,  0.4493,  0.8934,  0.8028,  0.5962,
           0.9573, -0.2891, -0.5536, -0.8328,  0.9220,  0.3871]]])

In [6]:
import torch
import numpy as np

# 定义位置编码器 (Embedder)
class Embedder:
    def __init__(self, **kwargs):
        self.kwargs = kwargs
        self.create_embedding_fn()

    def create_embedding_fn(self):
        embed_fns = []
        d = self.kwargs['input_dims']
        out_dim = 0
        if self.kwargs['include_input']:
            embed_fns.append(lambda x: x)
            out_dim += d

        max_freq = self.kwargs['max_freq_log2']
        N_freqs = self.kwargs['num_freqs']

        if self.kwargs['log_sampling']:
            freq_bands = 2.**torch.linspace(0., max_freq, N_freqs)
        else:
            freq_bands = torch.linspace(2.**0., 2.**max_freq, N_freqs)

        for freq in freq_bands:
            for p_fn in self.kwargs['periodic_fns']:
                embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
                out_dim += d

        self.embed_fns = embed_fns
        self.out_dim = out_dim

    def embed(self, inputs):
        return torch.cat([fn(inputs) for fn in self.embed_fns], -1)


# 测试用例
def test_position_encoding():

    embed_kwargs = {
        'include_input': True,
        'input_dims': 3,  # 3D 坐标
        'max_freq_log2': 4,  # 最大频率 (log2 scale)
        'num_freqs': 10,  # 频率数量
        'log_sampling': True,  # 使用对数采样
        'periodic_fns': [torch.sin, torch.cos],  # 使用 sin 和 cos 编码
    }


    embedder = Embedder(**embed_kwargs)

    # 随机生成一些 3D 坐标作为输入
    xyz = torch.rand((5, 3)) * 2 - 1  # 生成5个随机3D坐标，范围在 [-1, 1] 之间

    # 对坐标进行位置编码
    encoded_xyz = embedder.embed(xyz)

    print(f"原始坐标 (xyz): {xyz}")
    print(f"位置编码后的结果: {encoded_xyz}")
    print(f"位置编码后的维度: {encoded_xyz.shape}")


# 运行测试
test_position_encoding()


原始坐标 (xyz): tensor([[-0.0170, -0.9113,  0.1114],
        [ 0.6311,  0.9102, -0.8234],
        [ 0.2032,  0.6423, -0.1358],
        [ 0.9093, -0.7948,  0.4350],
        [ 0.0484,  0.0912, -0.2194]])
位置编码后的结果: tensor([[-0.0170, -0.9113,  0.1114, -0.0170, -0.7903,  0.1111,  0.9999,  0.6127,
          0.9938, -0.0231, -0.9458,  0.1510,  0.9997,  0.3247,  0.9885, -0.0315,
         -0.9932,  0.2048,  0.9995, -0.1165,  0.9788, -0.0428, -0.7481,  0.2770,
          0.9991, -0.6636,  0.9609, -0.0583, -0.0167,  0.3727,  0.9983, -0.9999,
          0.9280, -0.0793,  0.8960,  0.4966,  0.9969, -0.4440,  0.8680, -0.1078,
          0.4765,  0.6497,  0.9942,  0.8792,  0.7602, -0.1464, -0.9998,  0.8205,
          0.9892, -0.0202,  0.5716, -0.1986,  0.9609,  0.9661,  0.9801, -0.2768,
          0.2583, -0.2688, -0.9031,  0.9778,  0.9632, -0.4294, -0.2096],
        [ 0.6311,  0.9102, -0.8234,  0.5900,  0.7896, -0.7335,  0.8074,  0.6136,
          0.6797,  0.7571,  0.9453, -0.9003,  0.6534,  0.3262,  0.4352,