In [4]:
import logging
import os
import pickle

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

from s3dis_util import crop_pc
import s3dis_transforms as T


In [2]:
class S3DIS(Dataset):
    """S3DIS dataset, loading the subsampled entire room as input without block/sphere subsampling.

    Each sample is one room read from ``<data_root>/raw/<room>.npy``. Columns 0:3 are
    used as coordinates, 3:6 as point features and 6:7 as the label. Coordinates are
    shifted so that the per-axis minimum of every room is zero.

    Args:
        data_root (str, optional): dataset root containing a ``raw`` directory; a
            ``processed`` directory is created next to it for the presample cache.
            Defaults to 'data/S3DIS/s3disfull'.
        test_area (int, optional): rooms whose file name contains ``Area_<test_area>``
            form the non-train split; all other rooms form the train split. Defaults to 5.
        voxel_size (float, optional): the voxel size for downsampling, forwarded to
            ``crop_pc``. Defaults to 0.04.
        voxel_max (int, optional): the maximum number of points per point cloud, forwarded
            to ``crop_pc``. Set None to use all points. Defaults to None.
        split (str, optional): 'train' selects the training rooms, any other value selects
            the ``test_area`` rooms. Defaults to 'train'.
        transform (callable, optional): called as ``transform(pos, x, y)`` and expected to
            return the transformed ``(pos, x, y)`` triple. Defaults to None.
        loop (int, optional): number of passes over the rooms per epoch; ``len(dataset)``
            is the number of rooms times ``loop``. Defaults to 1.
        presample (bool, optional): whether to process every room once up front and cache
            the result as a pickle under ``<data_root>/processed``. Set to False to load
            and crop each room on the fly. Defaults to False.
        variable (bool, optional): forwarded to ``crop_pc``; whether the number of points
            per point cloud may vary. Defaults to False.
        n_shifted (int, optional): the number of shifted coordinates to be used. Stored on
            the instance; not otherwise used by this class. Defaults to 1.
        append_height (bool, optional): stored on the instance; not otherwise used by this
            class. Defaults to True.

    Raises:
        AssertionError: if no room matches the requested split.
    """

    classes = ['ceiling',
               'floor',
               'wall',
               'beam',
               'column',
               'window',
               'door',
               'chair',
               'table',
               'bookcase',
               'sofa',
               'board',
               'clutter']
    num_classes = 13
    num_per_class = np.array([3370714, 2856755, 4919229, 318158, 375640, 478001, 974733,
                              650464, 791496, 88727, 1284130, 229758, 2272837], dtype=np.int32)
    # NOTE(review): `classes` lists chair/table/bookcase/sofa at indices 7-10, while
    # `class2color` (and therefore `cmap`) lists table/chair/sofa/bookcase. As a result
    # cmap[i] is not the colour named for classes[i] at those indices. Which of the two
    # orders matches the label ids stored in the data cannot be determined from this
    # file -- confirm before relying on either for reporting or visualisation.
    class2color = {'ceiling':     [0, 255, 0],
                   'floor':       [0, 0, 255],
                   'wall':        [0, 255, 255],
                   'beam':        [255, 255, 0],
                   'column':      [255, 0, 255],
                   'window':      [100, 100, 255],
                   'door':        [200, 200, 100],
                   'table':       [170, 120, 200],
                   'chair':       [255, 0, 0],
                   'sofa':        [200, 100, 100],
                   'bookcase':    [10, 200, 100],
                   'board':       [200, 200, 200],
                   'clutter':     [50, 50, 50]}
    # Colours in the insertion order of `class2color` (see the review note above).
    cmap = [*class2color.values()]

    def __init__(self,
                 data_root: str = 'data/S3DIS/s3disfull',
                 test_area: int = 5,
                 voxel_size: float = 0.04,
                 voxel_max=None,
                 split: str = 'train',
                 transform=None,
                 loop: int = 1,
                 presample: bool = False,
                 variable: bool = False,
                 n_shifted: int = 1,
                 append_height: bool = True
                 ):

        super().__init__()
        self.split, self.voxel_size, self.transform, self.voxel_max, self.loop = \
            split, voxel_size, transform, voxel_max, loop
        self.presample = presample
        self.variable = variable
        self.n_shifted = n_shifted
        self.append_height = append_height

        raw_root = os.path.join(data_root, 'raw')
        self.raw_root = raw_root

        # Only real room arrays count: the rooms are later loaded as `<name>.npy`, so any
        # other file that merely contains 'Area_' would produce a duplicate or bogus entry.
        room_names = [
            os.path.splitext(fname)[0]
            for fname in sorted(os.listdir(raw_root))
            if 'Area_' in fname and fname.endswith('.npy')
        ]
        area_tag = 'Area_{}'.format(test_area)
        if split == 'train':
            self.data_list = [name for name in room_names if area_tag not in name]
        else:
            self.data_list = [name for name in room_names if area_tag in name]

        # Fail before any cache is written, so an empty split never leaves an empty pickle.
        assert len(self.data_list) > 0, \
            f'no rooms found for split {split!r} (test area {test_area}) under {raw_root}'

        processed_root = os.path.join(data_root, 'processed')
        filename = os.path.join(
            processed_root, f's3dis_{split}_area{test_area}_{voxel_size:.3f}.pkl')
        if presample and not os.path.exists(filename):
            # Seeds numpy's *global* RNG before the rooms are processed.
            np.random.seed(0)
            self.data = []
            for item in tqdm(self.data_list, desc=f'Loading S3DISFull {split} split on Test Area {test_area}'):
                cdata = self._load_raw_room(item)
                if voxel_size is not None:
                    coord, feat, label = cdata[:, 0:3], cdata[:, 3:6], cdata[:, 6:7]
                    # `downsample` is always False here because this branch requires presample.
                    coord, feat, label = crop_pc(
                        coord, feat, label, self.split, self.voxel_size, self.voxel_max,
                        downsample=not self.presample, variable=self.variable)
                    cdata = np.hstack((coord, feat, label))
                self.data.append(cdata)
            npoints = np.array([len(data) for data in self.data])
            logging.info('split: %s, median npoints %.1f, avg num points %.1f, std %.1f',
                         self.split, np.median(npoints), np.average(npoints), np.std(npoints))
            os.makedirs(processed_root, exist_ok=True)
            with open(filename, 'wb') as f:
                pickle.dump(self.data, f)
                print(f"{filename} saved successfully")
        elif presample:
            # SECURITY: pickle.load can execute arbitrary code; only point data_root at
            # cache files that were produced by this class.
            with open(filename, 'rb') as f:
                self.data = pickle.load(f)
                print(f"{filename} load successfully")
            # The cache is keyed only by split/area/voxel size, so it can go stale when
            # the raw directory changes. Indexing below follows the directory listing.
            if len(self.data) != len(self.data_list):
                logging.warning(
                    'cached %s holds %d rooms but %d were found under %s; '
                    'delete the cache to rebuild it',
                    filename, len(self.data), len(self.data_list), raw_root)
        self.data_idx = np.arange(len(self.data_list))
        logging.info("\nTotally %d samples in %s set", len(self.data_idx), split)

    def _load_raw_room(self, name):
        """Load `<raw_root>/<name>.npy` as float32 and shift xyz so each axis minimum is zero."""
        cdata = np.load(os.path.join(self.raw_root, name + '.npy')).astype(np.float32)
        cdata[:, :3] -= np.min(cdata[:, :3], 0)
        return cdata

    def __getitem__(self, idx):
        """Return one room as ``{'pos': coords, 'x': features, 'y': labels}``.

        ``idx`` wraps around the number of rooms so that ``loop`` > 1 revisits them.
        """
        data_idx = self.data_idx[idx % len(self.data_idx)]
        if self.presample:
            # np.split returns views into the cached array.
            # NOTE(review): if a transform modifies its inputs in place, the cache would
            # be altered across epochs -- confirm against s3dis_transforms.
            coord, feat, label = np.split(self.data[data_idx], [3, 6], axis=1)
        else:
            cdata = self._load_raw_room(self.data_list[data_idx])
            coord, feat, label = cdata[:, :3], cdata[:, 3:6], cdata[:, 6:7]
            coord, feat, label = crop_pc(
                coord, feat, label, self.split, self.voxel_size, self.voxel_max,
                downsample=not self.presample, variable=self.variable)
        # Labels are stored as a float column; flatten to an integer vector.
        label = label.squeeze(-1).astype(np.int_)
        data = {'pos': coord, 'x': feat, 'y': label}
        # augmentation
        if self.transform is not None:
            data['pos'], data['x'], data['y'] = self.transform(
                data['pos'], data['x'], data['y'])
        return data

    def __len__(self):
        """Number of rooms in the split multiplied by ``loop``."""
        return len(self.data_idx) * self.loop


In [3]:
# Training set: presampled rooms, voxel_max=24000, with the random geometric and
# colour transforms applied per sample.
train_ds = S3DIS(
    split='train',
    voxel_max=24000,
    presample=True,
    transform=T.Compose([
        T.PointCloudFloorCentering(),
        T.AppendHeight(),
        T.RandomScale(),
        T.RandomRotate(),
        T.RandomJitter(),
        T.ChromaticNormalize(),
        T.ChromaticAutoContrast(),
        T.RandomDropColor(),
        T.ToTensor(),
    ]),
)

# Validation set: presampled rooms, no voxel_max, and none of the Random* transforms.
test_ds = S3DIS(
    split='val',
    presample=True,
    transform=T.Compose([
        T.PointCloudFloorCentering(),
        T.AppendHeight(),
        T.ChromaticNormalize(),
        T.ToTensor(),
    ]),
)


data/S3DIS/s3disfull/processed/s3dis_train_area5_0.040.pkl load successfully
data/S3DIS/s3disfull/processed/s3dis_val_area5_0.040.pkl load successfully


In [7]:
# One room per batch, so rooms with different point counts never need to be collated.
dataloader = DataLoader(dataset=test_ds, batch_size=1)

In [8]:
# Smoke test: pull every batch from the loader once and discard it, which runs the
# dataset's __getitem__ and its transform for each sample without doing any training.
for item in dataloader:
    pass