In [32]:
import math
from typing import Any, Dict, List
import copy
import numpy as np
from scipy.spatial.transform import Rotation
from scipy.stats import special_ortho_group
import torch
import torch.utils.data
import torchvision
import open3d as o3d

def se3_init(rot, trans):
    """Assemble an [R | t] pose by appending the translation column to the rotation.

    Args:
        rot: ([B,] 3, 3) rotation matrix (matrices)
        trans: ([B,] 3, 1) translation column(s)

    Returns:
        ([B,] 3, 4) pose array
    """
    return np.concatenate((rot, trans), axis=-1)
def se3_inv(pose):
    """Inverts the SE3 transform.

    For pose = [R | t] the inverse is [R^T | -R^T t].

    Args:
        pose: ([B,] 3, 4) pose. Only the top 3 rows are read, so a ([B,] 4, 4)
            homogeneous matrix is accepted as well.

    Returns:
        ([B,] 3, 4) inverse pose
    """
    rot, trans = pose[..., :3, :3], pose[..., :3, 3:4]
    # np.swapaxes handles any number of leading batch dims; ndarray.transpose(-1, -2)
    # only works for 2-D arrays and raises for batched ([B,] 3, 3) rotations.
    irot = np.swapaxes(rot, -1, -2)
    itrans = -irot @ trans
    return se3_init(irot, itrans)


def se3_transform(pose, xyz):
    """Apply rigid transformation to points, i.e. compute R x + t for every point.

    Args:
        pose: ([B,] 3, 4) pose [R | t]
        xyz: ([B,] N, 3) points

    Returns:
        ([B,] N, 3) transformed points
    """

    assert xyz.shape[-1] == 3 and pose.shape[:-2] == xyz.shape[:-2]

    rot, trans = pose[..., :3, :3], pose[..., :3, 3:4]
    # swapaxes turns the ([B,] 3, 1) translation into ([B,] 1, 3) so it broadcasts over
    # the N points; unlike ndarray.transpose(-1, -2) it also works for batched input.
    transformed = np.einsum('...ij,...bj->...bi', rot, xyz) + np.swapaxes(trans, -1, -2)  # Rx + t

    return transformed

def so3_transform(rot, xyz):
    """Rotate points without translating them, i.e. compute R x for every point.

    Args:
        rot: ([B,] 3, 3) rotation matrix
        xyz: ([B,] N, 3) points or direction vectors (e.g. normals)

    Returns:
        ([B,] N, 3) rotated vectors
    """
    assert xyz.shape[-1] == 3
    assert rot.shape[:-2] == xyz.shape[:-2]
    return np.einsum('...ij,...nj->...ni', rot, xyz)

class RandomTransformSE3:
    def __init__(self, rot_mag: float = 180.0, trans_mag: float = 1.0, random_mag: bool = False):
        """Applies a random rigid transformation to the source point cloud

        Args:
            rot_mag (float): Maximum rotation in degrees
            trans_mag (float): Maximum translation. Random translation will
              be in the range [-trans_mag, trans_mag] in each axis
            random_mag (bool): If true, both maxima are scaled by one common random factor
                               in [0, 1), i.e. will bias towards small perturbations
        """
        self._rot_mag = rot_mag
        self._trans_mag = trans_mag
        self._random_mag = random_mag

    def generate_transform(self):
        """Generate a random SE3 transformation (3, 4) as float32.

        A uniformly random rotation is drawn and its axis-angle vector is scaled by
        rot_mag / 180, so the resulting rotation angle is at most rot_mag degrees.
        """

        if self._random_mag:
            attenuation = np.random.random()
            rot_mag, trans_mag = attenuation * self._rot_mag, attenuation * self._trans_mag
        else:
            rot_mag, trans_mag = self._rot_mag, self._trans_mag

        # Generate rotation. from_matrix/as_matrix (SciPy >= 1.4) replace from_dcm/as_dcm,
        # which were removed in SciPy 1.6 and fail with AttributeError on current SciPy.
        rand_rot = special_ortho_group.rvs(3)
        axis_angle = Rotation.from_matrix(rand_rot).as_rotvec()
        axis_angle *= rot_mag / 180.0
        rand_rot = Rotation.from_rotvec(axis_angle).as_matrix()

        # Generate translation, uniform in [-trans_mag, trans_mag] per axis
        rand_trans = np.random.uniform(-trans_mag, trans_mag, 3)
        rand_SE3 = np.concatenate((rand_rot, rand_trans[:, None]), axis=1).astype(np.float32)

        return rand_SE3

    def apply_transform(self, p0, transform_mat):
        """Apply a (3, 4) rigid transform to a point cloud.

        Args:
            p0: (N, 3) points, or (N, 6) points with normals in columns 3:6
                (normals are rotated but not translated)
            transform_mat: (3, 4) transform [R | t] to apply

        Returns:
            p1: transformed cloud with the same column layout as p0
            gt: inverse of transform_mat (maps p1 back onto p0)
            igt: transform_mat itself
        """
        p1 = se3_transform(transform_mat, p0[:, :3])
        if p0.shape[1] == 6:  # Need to rotate normals also
            n1 = so3_transform(transform_mat[:3, :3], p0[:, 3:6])
            p1 = np.concatenate((p1, n1), axis=-1)

        igt = transform_mat
        gt = se3_inv(igt)

        return p1, gt, igt

    def transform(self, tensor):
        """Apply a freshly sampled random transform; returns (p1, gt, igt) like apply_transform."""
        transform_mat = self.generate_transform()
        return self.apply_transform(tensor, transform_mat)

    def __call__(self, sample):
        """Randomly transform sample['points'], or sample['points_src'] plus 'transform_gt'.

        When sample['deterministic'] is truthy the global NumPy RNG is reseeded with
        sample['idx'], so the same sample always receives the same transform.
        """

        if 'deterministic' in sample and sample['deterministic']:
            np.random.seed(sample['idx'])

        if 'points' in sample:
            sample['points'], _, _ = self.transform(sample['points'])
        else:
            src_transformed, transform_r_s, transform_s_r = self.transform(sample['points_src'])
            sample['transform_gt'] = transform_r_s  # Apply to source to get reference
            sample['points_src'] = src_transformed

        return sample


# noinspection PyPep8Naming
class RandomTransformSE3_euler(RandomTransformSE3):
    """Same as RandomTransformSE3, but rotates using Euler angle rotations.

    The x, y and z angles are each drawn independently from [0, rot_mag) degrees and
    composed as Rx @ Ry @ Rz. This transformation is consistent with Deep Closest Point
    but does not generate uniform rotations.
    """

    def generate_transform(self):
        """Generate a random SE3 transformation (3, 4) as float32 from Euler angles."""
        if not self._random_mag:
            rot_mag, trans_mag = self._rot_mag, self._trans_mag
        else:
            scale = np.random.random()
            rot_mag, trans_mag = scale * self._rot_mag, scale * self._trans_mag

        # Angles are drawn in x, y, z order, one uniform sample each.
        angle_x, angle_y, angle_z = (np.random.uniform() * np.pi * rot_mag / 180.0 for _ in range(3))

        c_x, s_x = np.cos(angle_x), np.sin(angle_x)
        c_y, s_y = np.cos(angle_y), np.sin(angle_y)
        c_z, s_z = np.cos(angle_z), np.sin(angle_z)

        rot_x = np.array([[1, 0, 0], [0, c_x, -s_x], [0, s_x, c_x]])
        rot_y = np.array([[c_y, 0, s_y], [0, 1, 0], [-s_y, 0, c_y]])
        rot_z = np.array([[c_z, -s_z, 0], [s_z, c_z, 0], [0, 0, 1]])
        rotation = rot_x @ rot_y @ rot_z

        translation = np.random.uniform(-trans_mag, trans_mag, 3)
        return np.concatenate((rotation, translation[:, None]), axis=1).astype(np.float32)
class ReadPcd:
    """Read the source/target point cloud files named in a sample.

    Replaces the paths in sample['src_pcd'] and sample['tar_pcd'] with (N, 3) float32
    coordinate arrays, and adds sample['correspondences'], a (2, N) integer array that
    pairs source index i with target index i (N = number of source points).
    """

    @staticmethod
    def _read_points(path) -> np.ndarray:
        """Read one point cloud file with Open3D and return its xyz coordinates as float32."""
        pcd = o3d.io.read_point_cloud(path)
        # Open3D does not raise on a missing/unreadable file: it logs a warning and returns
        # an empty cloud. Fail loudly rather than pass (0, 3) arrays downstream.
        if not pcd.has_points():
            raise ValueError(f'No points could be read from point cloud file: {path}')
        return np.asarray(pcd.points).astype(np.float32)

    def __call__(self, sample: Dict):
        sample['src_pcd'] = self._read_points(sample['src_pcd'])
        sample['tar_pcd'] = self._read_points(sample['tar_pcd'])

        # NOTE(review): identity correspondences assume both files list the same points in
        # the same order; the target point count is not checked against the source here.
        n_points = sample['src_pcd'].shape[0]
        sample['correspondences'] = np.tile(np.arange(n_points), (2, 1))

        return sample
class RandomTransform(RandomTransformSE3_euler):
    """Randomly rotate/translate the source cloud of a ReadPcd-style sample (Euler sampling).

    The untouched source is kept under 'src_raw' and the transformed cloud is stored
    under 'src_pcd'. 'transform_gt' receives the inverse of the applied transform, i.e.
    the transform that maps the new 'src_pcd' back onto 'src_raw'.
    """

    def __call__(self, sample: Dict):
        # Same as RandomTransformSE3.__call__: seed from the sample index when the sample
        # is flagged deterministic, so such samples always get the same transform.
        if sample.get('deterministic'):
            np.random.seed(sample['idx'])

        src_transformed, transform_r_s, _ = self.transform(sample['src_pcd'])
        sample['transform_gt'] = transform_r_s  # Apply to source to get reference
        sample['src_raw'] = sample.pop('src_pcd')
        sample['src_pcd'] = src_transformed
        return sample

In [48]:
# Root of the custom dataset; edit this single line to run on another machine.
DATA_ROOT = '/home/yangqi/Documents/DeepLearning/dataset/CustomData/train_data'

sample = {
    'src_pcd': f'{DATA_ROOT}/src/src_41_left_0.ply',
    'tar_pcd': f'{DATA_ROOT}/tar/tar_41_left_0.ply',
    'idx': 0,
}
# Load both clouds, then randomly move the source (up to 45 degrees / 0.5 per axis).
transforms = [
    ReadPcd(),
    RandomTransform(rot_mag=45.0, trans_mag=0.5),
]
train_transforms = torchvision.transforms.Compose(transforms)
sample = train_transforms(sample)

def show_np(pcd):
    """Display an array of xyz points in an Open3D visualization window."""
    cloud = o3d.geometry.PointCloud()
    cloud.points = o3d.utility.Vector3dVector(pcd)
    o3d.visualization.draw_geometries([cloud])
def get_correspondences(src_ply, tgt_ply, transf, search_radius, K=1):
    """Find source->target point correspondences after moving the source by ``transf``.

    Args:
        src_ply: (N, 3) source points.
        tgt_ply: (M, 3) target points.
        transf: (3, 4) [R | t] transform applied to the source; a (4, 4) homogeneous
            matrix is accepted too (only its top 3 rows are used).
        search_radius: Maximum distance for a target point to count as a match.
        K: Keep at most the first K radius-search hits per source point (Open3D's radius
            search presumably returns them nearest-first — verify for your version);
            None keeps every hit inside the radius.

    Returns:
        (C, 2) integer array of [source_index, target_index] pairs, grouped by source
        index. Shape is (0, 2) when nothing matches.
    """
    transf = np.asarray(transf, dtype=np.float64)
    rot, trans = transf[:3, :3], transf[:3, 3]
    # Same rigid motion as PointCloud.transform, done in NumPy so (3, 4) and (4, 4) both work.
    src_points = np.asarray(src_ply, dtype=np.float64) @ rot.T + trans

    # Vector3dVector stores float64 coordinates, so convert once up front.
    tgt_cloud = o3d.geometry.PointCloud(
        o3d.utility.Vector3dVector(np.asarray(tgt_ply, dtype=np.float64)))
    pcd_tree = o3d.geometry.KDTreeFlann(tgt_cloud)

    corrs = []
    for i, point in enumerate(src_points):
        _, idx, _ = pcd_tree.search_radius_vector_3d(point, search_radius)
        if K is not None:
            idx = idx[:K]
        corrs.extend([i, j] for j in idx)

    # reshape keeps the result 2-D even when empty, so callers can index [:, 0] / .T[0].
    return np.array(corrs, dtype=np.int64).reshape(-1, 2)
# Map the transformed source back with transform_gt and match it against the target
# within a search radius of 5 (at most one target point per source point by default).
coor = get_correspondences(
    sample['src_pcd'],
    sample['tar_pcd'],
    sample['transform_gt'],
    search_radius=5,
)

# Optional visual check (uncomment to use): both clouds painted green, matched source
# points (from 'src_raw') red and matched target points blue.
# temp_src = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(sample['src_raw']))
# temp_src.paint_uniform_color([0,0.5,0])
# color = np.array(temp_src.colors)
# color[coor[:,0]] = [1,0,0]
# temp_src.colors = o3d.utility.Vector3dVector(color)
# temp_tar = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(sample['tar_pcd']))
# temp_tar.paint_uniform_color([0,0.5,0])
# color = np.array(temp_tar.colors)
# color[coor[:,1]] = [0,0,1]
# temp_tar.colors = o3d.utility.Vector3dVector(color)
# o3d.visualization.draw_geometries([temp_src,temp_tar])

# Print: number of pairs, number of distinct matched source points, and the source indices.
a = np.transpose(coor)
print(a[0].size, np.unique(a[0]).size, a[0])


1102 1102 [   0    1    2 ... 2044 2046 2047]


1607 1102 [   0    1    2 ... 2046 2047 2047]
