In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os
import points_query
import phf_cuda
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import open3d as o3d

In [None]:
class S3dis(Dataset):
    """S3DIS room dataset backed by one array file per room.

    Every entry of ``root`` is loaded with ``np.load`` and read as an ``(n, c)`` array
    whose first three columns are xyz, whose last column is the label and whose
    remaining columns are the per-point features ``x``.

    Args:
        root: directory containing the per-room files.
        split: ``'train'`` keeps the rooms whose file name does not contain
            ``Area_{test_area}``; every other value keeps the rooms that do.
            ``'test'`` additionally makes ``__getitem__`` return the full cloud
            together with its voxel grouping instead of a subsample.
        loop: how many times the room list is repeated in ``__len__``.
        npoints: number of points drawn per room for the ``'train'`` split.
        voxel_size: edge length of the voxel grid used for subsampling.
        test_area: index of the area held out from training.
        transforms: optional callable ``(pos, x) -> (pos, x)`` applied to
            train/val samples.
    """

    def __init__(self, root, split, loop, npoints=24000, voxel_size=0.04, test_area=5, transforms=None):
        super(S3dis, self).__init__()
        self.root = root
        self.split = split
        self.loop = loop
        self.npoints = npoints
        self.voxel_size = voxel_size
        self.transforms = transforms
        self.idx_to_class = {0: 'ceiling', 1: 'floor', 2: 'wall', 3: 'beam', 4: 'column', 
                5: 'window', 6: 'door', 7: 'table', 8: 'chair', 9: 'sofa', 10: 'bookcase', 11: 'board', 12: 'clutter'}

        # os.listdir returns entries in arbitrary order; sort so that a given
        # dataset index always maps to the same room.
        room_list = sorted(os.listdir(root))
        if split == 'train':
            self.room_list = list(filter(lambda x : f'Area_{test_area}' not in x, room_list))
        else:
            self.room_list = list(filter(lambda x : f'Area_{test_area}' in x, room_list))

    def __len__(self):
        """Number of samples per epoch: every room is visited ``loop`` times."""
        return len(self.room_list) * self.loop

    def voxel_grid_sampling(self, pos):
        """Group points into voxels of edge ``self.voxel_size``.

        Args:
            pos: ``(n, 3)`` array of coordinates.

        Returns:
            For ``split == 'test'``: ``(sort_idx, counts)`` where ``sort_idx`` orders
            the points so that points of one voxel are contiguous and ``counts``
            holds the number of points per voxel in that same order.
            Otherwise: an index array holding one randomly chosen point per
            occupied voxel.
        """
        voxel_indices = np.floor(pos / self.voxel_size).astype(np.int64)
        # Shift to start at 0 so the linearised key stays unique for negative coordinates.
        voxel_indices = voxel_indices - voxel_indices.min(axis=0)
        # The number of cells per axis is max index + 1. Using the bare maximum as
        # the stride makes distinct voxels share a key, e.g. (max_x, 0, 0) and (0, 1, 0).
        grid_dims = voxel_indices.max(axis=0) + 1

        strides = np.ones_like(grid_dims)
        strides[1] = grid_dims[0]
        strides[2] = grid_dims[0] * grid_dims[1]

        voxel_hash = (voxel_indices * strides).sum(axis=-1)
        sort_idx = voxel_hash.argsort()

        # np.unique returns keys in ascending order, matching the order of sort_idx.
        _, counts = np.unique(voxel_hash, return_counts=True)
        if self.split == 'test':   # test needs the full grouping, unlike train/val
            return sort_idx, counts

        # First position of each voxel inside sort_idx, plus a random offset below its count.
        starts = np.cumsum(np.insert(counts, 0, 0)[0:-1])
        offsets = np.random.randint(0, counts.max(), counts.size) % counts
        return sort_idx[starts + offsets]

    def __getitem__(self, index):
        """Load one room and return it according to the split.

        Returns:
            ``'test'``: ``(pos, x, y, sort_idx, counts)`` for the whole room.
            otherwise: ``(pos, x, y)`` for the voxel-subsampled room; for ``'train'``
            exactly ``npoints`` points are drawn (with replacement) from that subsample.
            ``pos`` and ``x`` are float32, ``y`` is int64.
        """
        room = os.path.join(self.root, self.room_list[index % len(self.room_list)])
        points = np.load(room)

        # Common practice: move the room so its minimum corner sits at the origin.
        points[:, 0:3] = points[:, 0:3] - np.min(points[:, 0:3], axis=0)

        if self.split == 'test':
            sort_idx, counts = self.voxel_grid_sampling(points[:, 0:3])
            pos, x, y = points[:, 0:3], points[:, 3:-1], points[:, -1]
            pos, x, y = pos.astype(np.float32), x.astype(np.float32), y.astype(np.int64)
            return pos, x, y, sort_idx, counts

        # train / val pipeline
        sample_indices = self.voxel_grid_sampling(points[:, 0:3])
        # then randomly draw a fixed number of points
        if self.split == 'train':
            sample_indices = np.random.choice(sample_indices, (self.npoints, ))
        pos, x, y = points[sample_indices, 0:3], points[sample_indices, 3:-1], points[sample_indices, -1]
        if self.transforms:
            pos, x = self.transforms(pos, x)

        pos, x, y = pos.astype(np.float32), x.astype(np.float32), y.astype(np.int64)
        return pos, x, y


class Compose:
    """Chain several ``(pos, x) -> (pos, x)`` transforms into a single callable."""

    def __init__(self, transforms):
        """
        transforms: list of callables, applied in list order
        """
        self.transforms = transforms

    def __call__(self, pos, x):
        """Feed ``(pos, x)`` through every transform and return the final pair."""
        state = (pos, x)
        for step in self.transforms:
            new_pos, new_x = step(*state)
            state = (new_pos, new_x)
        return state


class PointCloudFloorCentering:
    """Centre a cloud on its mean, then lift it so the lowest z becomes 0."""

    def __init__(self):
        """No configuration is needed."""
        pass

    def __call__(self, pos, x):
        """Return the re-centred copy of ``pos`` together with the untouched ``x``."""
        # The subtraction allocates a new array, so the caller's `pos` is not modified.
        centered = pos - pos.mean(axis=0, keepdims=True)
        centered[:, 2] -= centered[:, 2].min()
        return centered, x


class ColorNormalize:
    """Scale features from [0, 255] to [0, 1], then standardise with ``mean`` / ``std``."""

    def __init__(self, mean=[0.5136457, 0.49523646, 0.44921124], std=[0.18308958, 0.18415008, 0.19252081]):
        """
        mean, std: per-channel statistics, broadcast against the last axis of ``x``
        """
        self.mean, self.std = mean, std

    def __call__(self, pos, x):
        """Return ``pos`` unchanged and the normalised copy of ``x``."""
        scaled = x / 255
        standardised = (scaled - self.mean) / self.std
        return pos, standardised


def index_points(points, indices):
    """
    Batched gather along the point axis.

    points.shape = (b, n, c)
    indices.shape = (b, nsamples) or (b, nsamples, k)
    return res.shape = (b, nsamples, c) or (b, nsamples, k, c)
    """
    idx_shape = list(indices.shape)

    # Batch ids shaped (b, 1, ..., 1), then stretched to the shape of `indices`
    # so both index tensors line up element by element.
    singleton_shape = [idx_shape[0]] + [1] * (len(idx_shape) - 1)
    target_shape = [-1] + idx_shape[1:]
    batch_ids = torch.arange(points.shape[0], device=points.device)
    batch_ids = batch_ids.view(singleton_shape).expand(target_shape)

    return points[batch_ids, indices, :]


def my_knn_query(k, query_pos, all_pos, all_x):
    """
    Gather positions and features of the k neighbours selected by points_query.knn_query.

    query_pos.shape = (b, sample, 3)
    all_pos.shape = (b, n, 3)
    all_x.shape = (b, n, c)
    return shape = (b, sample, k, 3), (b, sample, k, c), (b, sample, k)
    (neighbour positions, neighbour features, distance buffer)
    """
    batch, num_query, _ = query_pos.shape
    buffer_shape = (batch, num_query, k)
    neighbour_idx = torch.zeros(buffer_shape, dtype=torch.long, device=query_pos.device)
    neighbour_dist = torch.zeros(buffer_shape, dtype=torch.float32, device=query_pos.device)

    # No return value is used: the extension is expected to fill both buffers in place.
    # NOTE(review): whether the distances are squared is not visible here — confirm in the kernel.
    points_query.knn_query(k, all_pos, query_pos, neighbour_idx, neighbour_dist)

    grouped_pos = index_points(all_pos, neighbour_idx)
    grouped_x = index_points(all_x, neighbour_idx)
    return grouped_pos, grouped_x, neighbour_dist


def my_ball_query(radius, k, query_pos, all_pos, all_x):
    """
    Gather positions and features of up to k neighbours selected by points_query.ball_query.

    query_pos.shape = (b, sample, 3)
    all_pos.shape = (b, n, 3)
    all_x.shape = (b, n, c)
    return shape = (b, sample, k, 3), (b, sample, k, c), (b, sample, k)
    (neighbour positions, neighbour features, distance buffer)
    """
    batch, num_query, _ = query_pos.shape
    buffer_shape = (batch, num_query, k)
    neighbour_idx = torch.zeros(buffer_shape, dtype=torch.long, device=query_pos.device)
    neighbour_dist = torch.zeros(buffer_shape, dtype=torch.float32, device=query_pos.device)

    # No return value is used: the extension is expected to fill both buffers in place.
    # NOTE(review): slots the kernel does not write keep index 0 / distance 0 — confirm
    # how it pads when fewer than k points fall inside the radius.
    points_query.ball_query(k, radius, all_pos, query_pos, neighbour_idx, neighbour_dist)

    grouped_pos = index_points(all_pos, neighbour_idx)
    grouped_x = index_points(all_x, neighbour_idx)
    return grouped_pos, grouped_x, neighbour_dist


def point_hist_feature(group_points, distance):
    """
    Distance-weighted 8-bin histogram over the neighbours of every point.

    group_points.shape = (b, n, k, 3)   relative coordinates
    distance.shape = (b, n, k)
    return res.shape = (b, n, 8)
    """
    b, n, k, _ = group_points.shape
    bin_masks = torch.zeros((b, n, k, 8), dtype=torch.float32, device=group_points.device)

    # phf_cuda.phf is expected to fill bin_masks in place
    # (presumably a per-neighbour bin indicator — confirm against the kernel).
    phf_cuda.phf(group_points, bin_masks)

    # Weight every neighbour's bin entries by its distance, then sum over the k neighbours.
    weighted = distance.unsqueeze(dim=-1) * bin_masks
    return weighted.sum(dim=2)


def get_normal(pos):
    """
    Estimate per-point normals with Open3D, one cloud of the batch at a time.

    pos.shape = (b, n, 3)
    return shape = (b, n, 3), float32 tensor built from a NumPy array (CPU)
    """
    b, n, _ = pos.shape
    clouds = pos.to(device='cpu', dtype=torch.float64).numpy()
    normals = np.zeros((b, n, 3), dtype=np.float64)

    for batch_id, cloud in enumerate(clouds):
        pcd = o3d.geometry.PointCloud()
        pcd.points = o3d.utility.Vector3dVector(cloud)   # must be float64
        # pcd.estimate_normals(o3d.geometry.KDTreeSearchParamRadius(0.2))
        pcd.estimate_normals()
        normals[batch_id] = pcd.normals

    return torch.as_tensor(normals.astype(np.float32))

In [None]:
# Validation pipeline: floor-centre every room, then normalise its colour features.
val_aug = Compose([PointCloudFloorCentering(), ColorNormalize()])
val_dataset = S3dis(
    '/home/lindi/chenhr/threed/data/processed_s3dis',
    split='val',
    loop=1,
    transforms=val_aug,
)
val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=8)

# Step through the loader and keep the 6th batch (a single room, since batch_size=1).
val_iter = iter(val_dataloader)
for i in range(6):
    pos, x, y = next(val_iter)

# Move the batch to the GPU.
# NOTE(review): the GPU index is hard-coded; this needs at least three visible devices.
device = 'cuda:2'
pos, x, y = (t.to(device=device) for t in (pos, x, y))

# 64 nearest neighbours of every point, expressed relative to the query point.
neigh_pos, _, dis = my_knn_query(64, pos, pos, x)
neigh_pos = neigh_pos - pos.unsqueeze(dim=2)
# Neighbour slot 0 is dropped before building the histogram (presumably the query point itself).
hist_features = point_hist_feature(neigh_pos[:, :, 1:, :], dis[:, :, 1:])   # hist_features.shape = (b, n, 8)

# Alternative kept for reference: ball query, centred on the neighbourhood mean.
# neigh_pos, _, _ = my_ball_query(0.1, 8, pos, pos, x)
# centers = neigh_pos.mean(dim=2, keepdim=True)
# neigh_pos = neigh_pos - centers
# dis = (neigh_pos ** 2).sum(dim=-1)
# hist_features = point_hist_feature(neigh_pos, dis)

# Back to NumPy for the per-class statistics in the following cells.
hist_features = hist_features[0].cpu().numpy()
y = y[0].cpu().numpy()
# tsne = TSNE(init='random', learning_rate='auto')
# low_features = tsne.fit_transform(hist_features)
# print(low_features.shape)

In [None]:
# Mean histogram feature over the points labelled 0 ('ceiling').
ceiling_mask = (y == 0)
temp = hist_features[ceiling_mask].mean(axis=0)
print(temp)

In [None]:
# Mean histogram feature over the points labelled 1 ('floor').
floor_mask = (y == 1)
temp = hist_features[floor_mask].mean(axis=0)
print(temp)

In [None]:
# Rebuild the validation loader so this cell starts again from the first room.
val_aug = Compose([PointCloudFloorCentering(), ColorNormalize()])
val_dataset = S3dis(
    '/home/lindi/chenhr/threed/data/processed_s3dis',
    split='val',
    loop=1,
    transforms=val_aug,
)
val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=8)

# Keep the 6th batch (a single room, since batch_size=1).
val_iter = iter(val_dataloader)
for i in range(6):
    pos, x, y = next(val_iter)

# Per-point normals of that room; labels of the only sample in the batch.
normal = get_normal(pos)
print(normal.shape)
y = y[0]

In [None]:
# Project the per-point normals of the first (only) room in the batch to 2-D.
# random_state pins t-SNE's random initialisation, so re-running the cell on the
# same input gives the same embedding.
tsne = TSNE(init='random', learning_rate='auto', random_state=0)
# Hand scikit-learn an ndarray explicitly instead of relying on implicit tensor conversion.
low_features = tsne.fit_transform(normal[0].numpy())
print(low_features.shape)

In [None]:
# Rebuild the validation loader so this cell starts again from the first room.
val_aug = Compose([PointCloudFloorCentering(),
                            ColorNormalize()])
val_dataset = S3dis('/home/lindi/chenhr/threed/data/processed_s3dis', split='val', loop=1, transforms=val_aug)
val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=8)

# Keep the 5th batch (a single room, since batch_size=1).
val_iter = iter(val_dataloader)
for i in range(5):
    pos, x, y = next(val_iter)

# Drop the batch axis and convert to float64 NumPy arrays.
pos = pos[0].to(torch.float64).numpy()
x = x[0].to(torch.float64).numpy()
y = y[0].to(torch.float64).numpy()

# Embed the normalised per-point features in 2-D.
# random_state pins t-SNE's random initialisation, so re-running the cell on the
# same input gives the same embedding.
tsne = TSNE(init='random', learning_rate='auto', random_state=0)
low_features = tsne.fit_transform(x)
print(low_features.shape)

In [None]:
def plot_class_embedding(embedding, labels, class_id, class_name=''):
    """Scatter, in a new figure, the rows of a 2-D embedding whose label equals ``class_id``.

    Args:
        embedding: ``(n, 2)`` array of embedded points.
        labels: ``(n,)`` array of class ids aligned with ``embedding``.
        class_id: label value to select.
        class_name: optional readable name appended to the figure title.

    Returns:
        The matplotlib Axes that was drawn on.
    """
    mask = (labels == class_id)
    fig, ax = plt.subplots()
    ax.scatter(embedding[mask, 0], embedding[mask, 1])
    ax.set_title(f'class {class_id} {class_name}'.strip())
    return ax


# One figure per class, for the same class ids the separate cells used to plot
# one by one (class 3 was not among them).
for class_id in (0, 1, 2, 4, 5, 6, 7, 8, 9, 10, 11, 12):
    plot_class_embedding(low_features, y, class_id, val_dataset.idx_to_class.get(class_id, ''))
    plt.show()