In [None]:
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

In [None]:
class PointCloudACCDataset(Dataset):
    """Map-style dataset pairing each point cloud with a 6-value context vector.

    The context vector is split into two halves that are returned as separate
    tensors: ``context_1`` (columns 0-2) and ``context_2`` (columns 3-5).

    DataLoader usage example::

        for batch in dataloader:
            pointcloud_batch = batch['pointcloud']  # [batch_size, *pointcloud.shape[1:]]
            context_1 = batch['context_1']          # [batch_size, 3]
            context_2 = batch['context_2']          # [batch_size, 3]
    """

    def __init__(self, pointcloud, context):
        """
        Args:
            pointcloud: array-like of shape [B, ...]. Only the leading sample
                axis B is interpreted; the remaining axes are returned as-is
                (this notebook passes [B, D, N] after a transpose).
            context: array-like of shape [B, 6].

        Raises:
            ValueError: if ``context`` is not 2-D with exactly 6 columns, or if
                ``pointcloud`` and ``context`` disagree on the sample count B.
        """
        pointcloud = np.asarray(pointcloud)
        context = np.asarray(context)
        if context.ndim != 2 or context.shape[1] != 6:
            raise ValueError(f"context must have shape [B, 6], got {context.shape}")
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if pointcloud.shape[0] != context.shape[0]:
            raise ValueError("数据长度不一致")

        # torch.from_numpy shares memory with the NumPy array; .float() copies
        # only when the source dtype is not already float32.
        self.pointcloud = torch.from_numpy(pointcloud).float()
        self.context_1 = torch.from_numpy(context[:, :3]).float()  # [B, 3]
        self.context_2 = torch.from_numpy(context[:, 3:]).float()  # [B, 3]

    def __len__(self):
        """Return the number of samples B."""
        # .shape[0] instead of len(): avoids depending on the builtin `len`,
        # which notebook globals can shadow (e.g. `len = ...` in another cell).
        return self.pointcloud.shape[0]

    def __getitem__(self, idx):
        """Return sample ``idx`` as a dict with keys pointcloud, context_1, context_2."""
        return {
            'pointcloud': self.pointcloud[idx],  # shape = pointcloud.shape[1:]
            'context_1': self.context_1[idx],    # [3]
            'context_2': self.context_2[idx],    # [3]
        }

In [None]:
# Load data
data_path = os.path.join("..", "data", "train_data_pointcloud.npz")
# np.load on an .npz returns an NpzFile that keeps the archive open; using it
# as a context manager closes the handle once the arrays have been read out.
with np.load(data_path) as data:
    # Swap the last two axes: [B, N, D] -> [B, D, N].
    loaded_pointcloud = np.transpose(data['pointcloud'], (0, 2, 1))
    loaded_ACC = data['ACC']

In [None]:
# The first TRAIN_FRACTION of the samples forms the training set; the
# remaining samples are not used in this cell.
TRAIN_FRACTION = 0.8

# Do NOT name this `len`: that shadows the builtin for every later cell and for
# any function that calls len() at run time (e.g. inside a Dataset).
num_samples = loaded_pointcloud.shape[0]
# Slice bounds must be integers; `num_samples * 0.8` alone is a float.
train_len = int(num_samples * TRAIN_FRACTION)

train_dataset = PointCloudACCDataset(loaded_pointcloud[:train_len], loaded_ACC[:train_len])
train_dataloader = DataLoader(
    train_dataset,
    batch_size=10,
    shuffle=True,     # reshuffle the samples every epoch
    # NOTE(review): with the "spawn" start method (Windows/macOS default),
    # workers may fail to unpickle a Dataset class defined in the notebook;
    # use num_workers=0 or move the class into a .py module if that happens.
    num_workers=2,    # number of worker processes for loading
    pin_memory=True,  # page-locked host memory for faster host-to-GPU copies
)

In [None]:
# Sanity check: pull a single batch and display the tensor shapes.
batch = next(iter(train_dataloader))
pointcloud_batch = batch['pointcloud']  # [batch_size, D, N] (data was transposed on load)
context_1 = batch['context_1']          # [batch_size, 3]
context_2 = batch['context_2']          # [batch_size, 3]
pointcloud_batch.shape, context_1.shape, context_2.shape