In [1]:
import open3d as o3d
import numpy as np
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt


# Load the point cloud to be segmented.
pcd = o3d.io.read_point_cloud("pbr28.pcd")

# read_point_cloud reports a failed read (missing file, unsupported or corrupt
# format) by returning an empty cloud rather than raising, so fail loudly here
# instead of letting the error surface later inside the segmentation code.
if not pcd.has_points():
    raise RuntimeError(
        "Point cloud 'pbr28.pcd' could not be read or contains no points."
    )


In [26]:
class KMeansSegmentation:
    """Segment a point cloud with K-Means over per-point geometric features.

    The feature vector of each point is: X, Y, height above the lowest point,
    slope (angle between the surface normal and the Z axis), a local
    surface-variation ("curvature") value, and the three normal components.
    Every feature column is standardized before clustering.

    Note that the point cloud passed in is modified: normals are estimated on
    it at construction time and ``visualize_segmentation`` overwrites its
    colors.

    Args:
        pcd: Open3D point cloud to segment.
        n_clusters: Number of K-Means clusters.
        normal_radius: Search radius used for normal estimation.
        normal_max_nn: Maximum number of neighbors used for normal estimation.
        curvature_radius: Neighborhood radius used for the curvature feature.
        random_state: Seed forwarded to ``KMeans`` so runs are reproducible;
            pass ``None`` for a non-deterministic initialization.
    """

    def __init__(self, pcd, n_clusters=2, normal_radius=0.1, normal_max_nn=100,
                 curvature_radius=0.2, random_state=0):
        self.pcd = pcd
        self.n_clusters = n_clusters
        self.curvature_radius = curvature_radius
        self.random_state = random_state
        # Stays empty until segment() has been called.
        self.labels = np.array([])
        self.pcd.estimate_normals(
            search_param=o3d.geometry.KDTreeSearchParamHybrid(
                radius=normal_radius, max_nn=normal_max_nn
            )
        )

    @staticmethod
    def _standardize(values):
        """Z-score ``values`` along axis 0; constant columns map to zeros."""
        values = np.asarray(values, dtype=float)
        std = values.std(axis=0)
        # A zero std would turn the whole column into NaN and break K-Means.
        std = np.where(std > 0, std, 1.0)
        return (values - values.mean(axis=0)) / std

    def calculate_slope(self, normals):
        """Return the angle in radians between each normal and the +Z axis.

        Args:
            normals: Array of shape (N, 3) (or a single 3-vector). The vectors
                do not need to be unit length; they are rescaled here.

        Returns:
            Angles in [0, pi]. A zero-length normal yields pi / 2.
        """
        normals = np.asarray(normals, dtype=float)
        lengths = np.linalg.norm(normals, axis=-1)
        # Avoid dividing by zero for degenerate (zero-length) normals.
        lengths = np.where(lengths > 0, lengths, 1.0)
        # Dot product with (0, 0, 1) is just the Z component.
        cos_angle = normals[..., 2] / lengths
        return np.arccos(np.clip(cos_angle, -1.0, 1.0))

    def calculate_curvature(self):
        """Return a per-point surface-variation value used as curvature proxy.

        For every point, the covariance of all neighbors within
        ``curvature_radius`` is eigen-decomposed and the value is the smallest
        eigenvalue divided by the sum of the eigenvalues. Points with fewer
        than three neighbors, or with a degenerate neighborhood, get 0.

        Returns:
            1-D float array with one value per point.
        """
        # Convert once; doing this inside the loop rebuilds the array per point.
        points = np.asarray(self.pcd.points)
        curvature = np.zeros(len(points))
        if len(points) == 0:
            return curvature

        pcd_tree = o3d.geometry.KDTreeFlann(self.pcd)

        for i, point in enumerate(points):
            k, idx, _ = pcd_tree.search_radius_vector_3d(point, self.curvature_radius)

            # Too few neighbors for a meaningful 3-D covariance.
            if k < 3:
                continue

            neighbors = points[np.asarray(idx)]
            # np.cov centers the data itself; rows are observations here.
            covariance_matrix = np.cov(neighbors, rowvar=False)

            # eigvalsh returns the eigenvalues in ascending order.
            eigen_values = np.linalg.eigvalsh(covariance_matrix)
            total = eigen_values.sum()

            # All neighbors coincide: leave the value at 0 instead of NaN.
            if total > 0:
                # Clamp tiny negative values caused by round-off.
                curvature[i] = max(eigen_values[0], 0.0) / total

        return curvature

    def segment(self):
        """Cluster the points with K-Means and return the cluster labels.

        Returns:
            Integer label array with one entry per point; also stored in
            ``self.labels``.

        Raises:
            ValueError: If the cloud has fewer points than ``n_clusters``.
        """
        pcd_points = np.asarray(self.pcd.points)
        if len(pcd_points) < max(self.n_clusters, 1):
            raise ValueError(
                f"Point cloud has {len(pcd_points)} points, "
                f"need at least {self.n_clusters} to form the clusters."
            )

        # Estimated normals have an arbitrary sign. Flip them into the upper
        # half-space so that slope and normal features are consistent between
        # neighboring points. This works on a copy; the cloud is not modified.
        normals = np.asarray(self.pcd.normals)
        normals = np.where(normals[:, 2:3] < 0, -normals, normals)

        # Height above the lowest point of the cloud.
        relative_height = pcd_points[:, 2] - np.min(pcd_points[:, 2])

        # Slope must be computed from the geometric normals, before any
        # statistical rescaling destroys their direction.
        slopes = self.calculate_slope(normals)

        curvature = self.calculate_curvature()

        features = np.column_stack((
            pcd_points[:, :2],
            relative_height,
            slopes,
            curvature,
            normals,
        ))

        # Standardize every column so that no feature (in particular the raw
        # XY coordinates) dominates the Euclidean distance used by K-Means.
        features = self._standardize(features)

        # n_init is set explicitly to keep results stable across scikit-learn
        # versions whose default differs.
        kmeans = KMeans(n_clusters=self.n_clusters, n_init=10,
                        random_state=self.random_state)
        kmeans.fit(features)

        self.labels = kmeans.labels_

        return self.labels

    def visualize_segmentation(self):
        """Color the cloud by cluster label and show it in an Open3D viewer.

        Overwrites the colors of ``self.pcd``.

        Raises:
            ValueError: If ``segment()`` has not produced labels yet.
        """
        if self.labels.size == 0:
            raise ValueError(
                "No labels available; call segment() before visualize_segmentation()."
            )

        max_label = self.labels.max()
        # Map labels to [0, 1] for the colormap; avoid 0/0 with a single cluster.
        scale = max_label if max_label > 0 else 1
        colors = plt.get_cmap('viridis')(self.labels / scale)
        colors = colors[:, :3]  # drop the alpha channel
        self.pcd.colors = o3d.utility.Vector3dVector(colors)
        o3d.visualization.draw_geometries([self.pcd])


In [27]:
# Build the segmenter for the loaded cloud, splitting it into two clusters.
kmeans_segmentation = KMeansSegmentation(pcd=pcd, n_clusters=2)

# Run the clustering; one label per point is returned.
segment_labels = kmeans_segmentation.segment()


In [None]:
# Visualize the segmented point cloud: colors the cloud by cluster label and
# shows it via o3d.visualization.draw_geometries. Requires that segment() has
# been run in the previous cell.
kmeans_segmentation.visualize_segmentation()