In [None]:
# Import Necessary Libraries
import open3d as o3d
import numpy as np
from pc_skeletor.laplacian import SLBC
from scipy.spatial import cKDTree
import time
import laspy
import os
import copy
from math import ceil
import matplotlib.pyplot as plt
from multiprocessing import Pool, cpu_count
from math import ceil
from sklearn.cluster import KMeans
import vtk

In [None]:
# compute FPFH (Fast Point Feature Histograms)
def compute_fpfh(pcd, radius_normal=0.4, radius_feature=0.4, max_nn=30):
    """Compute FPFH (Fast Point Feature Histogram) features for a point cloud.

    Normals are estimated on ``pcd`` in place before the features are
    computed, so the input cloud is modified by this call.

    Args:
        pcd: Open3D point cloud to describe.
        radius_normal: Search radius used for normal estimation.
        radius_feature: Search radius used for the FPFH computation.
        max_nn: Maximum number of neighbours used by both hybrid searches.

    Returns:
        The feature object produced by
        ``o3d.pipelines.registration.compute_fpfh_feature``.
    """
    normal_search = o3d.geometry.KDTreeSearchParamHybrid(radius=radius_normal, max_nn=max_nn)
    pcd.estimate_normals(normal_search)

    feature_search = o3d.geometry.KDTreeSearchParamHybrid(radius=radius_feature, max_nn=max_nn)
    return o3d.pipelines.registration.compute_fpfh_feature(pcd, feature_search)

# perform Fast Global Registration (FGR) between source and target point clouds
def perform_fast_global_registration(source, target, source_fpfh, target_fpfh,
                                     distance_threshold=0.4):
    """Run Fast Global Registration (FGR) between two point clouds.

    Args:
        source: Open3D point cloud to be aligned.
        target: Open3D point cloud to align against.
        source_fpfh: FPFH features of ``source``.
        target_fpfh: FPFH features of ``target``.
        distance_threshold: Maximum correspondence distance passed to the
            FGR option object.

    Returns:
        The result object returned by
        ``registration_fgr_based_on_feature_matching``.
    """
    fgr_option = o3d.pipelines.registration.FastGlobalRegistrationOption(
        maximum_correspondence_distance=distance_threshold)
    return o3d.pipelines.registration.registration_fgr_based_on_feature_matching(
        source, target, source_fpfh, target_fpfh, fgr_option)

# Apply the registration transformation to a copy of the source point cloud
def save_registration(source, target, transformation):
    """Return a transformed copy of ``source``, leaving the inputs untouched.

    Args:
        source: Object to align; it is deep-copied and the copy's
            ``transform`` method is called with ``transformation``.
        target: Unused. Kept only so existing callers keep working.
        transformation: Transformation passed to ``transform``.

    Returns:
        The transformed deep copy of ``source``.
    """
    # Only the source is needed; copying the target as well would just waste
    # time and memory on large clouds.
    aligned_source = copy.deepcopy(source)
    aligned_source.transform(transformation)
    return aligned_source

# perform multiple iterations of registration and choose the best result
def iterative_registration(source_ply, target_ply, iterations):
    """Run global registration several times and keep the best result.

    FPFH features are computed once for each cloud, then registration is
    repeated up to ``iterations`` times. The result with the lowest inlier
    RMSE among those that actually found inliers is kept, and the loop stops
    early once that RMSE drops below 0.01.

    Args:
        source_ply: Open3D point cloud to be aligned.
        target_ply: Open3D point cloud to align against.
        iterations: Maximum number of registration attempts.

    Returns:
        The best registration result. If no attempt found any inlier, the
        first attempt's result is returned. ``None`` if ``iterations`` is
        not positive.
    """
    source_fpfh = compute_fpfh(source_ply)
    target_fpfh = compute_fpfh(target_ply)

    best_rmse = float('inf')
    best_result = None
    first_result = None

    for _ in range(iterations):
        result = perform_fast_global_registration(source_ply, target_ply, source_fpfh, target_fpfh)
        if first_result is None:
            first_result = result

        # A registration without inlier correspondences reports fitness 0 and
        # an inlier RMSE of 0; that RMSE is meaningless and must not be
        # mistaken for a perfect alignment.
        if result.fitness <= 0:
            continue

        if result.inlier_rmse < best_rmse:
            best_rmse = result.inlier_rmse
            best_result = result
            if best_rmse < 0.01:
                break

    return best_result if best_result is not None else first_result


# Visualize Registration Results
def draw_registration_result(point_cloud, window_name):
    """Show a point cloud in an Open3D window until the window is closed.

    The window uses a black background with lighting switched off.

    Args:
        point_cloud: Open3D geometry to display.
        window_name: Title of the viewer window.
    """
    visualizer = o3d.visualization.Visualizer()
    visualizer.create_window(window_name=window_name)
    try:
        render_option = visualizer.get_render_option()
        render_option.background_color = [0, 0, 0]
        render_option.light_on = False
        visualizer.add_geometry(point_cloud)

        visualizer.run()
    finally:
        # Release the window even if setup or rendering raised.
        visualizer.destroy_window()

In [None]:
def create_folder_if_not_exists(folder_path):
    """Create ``folder_path`` (including missing parents) if it is absent.

    Calling this for a directory that already exists is a no-op.

    Raises:
        FileExistsError: If ``folder_path`` exists but is not a directory.
    """
    # exist_ok avoids the check-then-create race of testing os.path.exists first.
    os.makedirs(folder_path, exist_ok=True)


# Nearest neighbor search removes the stem point cloud
def remove_nearest_points(source_point_cloud, target_point_cloud, distance_threshold):
    """Drop every target point lying within a radius of any source point.

    Args:
        source_point_cloud: Open3D point cloud whose points act as query
            centres.
        target_point_cloud: Open3D point cloud to filter; it is not modified.
        distance_threshold: Search radius around each source point.

    Returns:
        A new point cloud holding the target points that are not within
        ``distance_threshold`` of any source point. With an empty source
        cloud, every target point is kept.
    """
    kdtree_target = o3d.geometry.KDTreeFlann(target_point_cloud)

    # True = keep. Marking hits in a mask avoids building Python sets of
    # indices and also copes with a source cloud that has no points.
    keep_mask = np.ones(len(target_point_cloud.points), dtype=bool)

    for point_source in np.asarray(source_point_cloud.points):
        _, neighbor_indices, _ = kdtree_target.search_radius_vector_3d(point_source, distance_threshold)
        keep_mask[np.asarray(neighbor_indices, dtype=np.int64)] = False

    return target_point_cloud.select_by_index(np.flatnonzero(keep_mask))


# Color threshold removes stem point clouds
def remove_by_color(point_cloud, color_threshold):
    """Keep only the points whose grayscale intensity exceeds a threshold.

    Each colour is converted to a gray value with the weights
    0.299 / 0.587 / 0.114 for the three channels.

    Args:
        point_cloud: Open3D point cloud with per-point colours.
        color_threshold: Gray value a point must exceed to be kept.

    Returns:
        A new point cloud containing the selected points. Its colours are
        the grayscale values, not the original colours.
    """
    colors = np.asarray(point_cloud.colors)
    # Same arithmetic as a per-point loop, evaluated for all points at once.
    gray = colors[:, 0] * 0.299 + colors[:, 1] * 0.587 + colors[:, 2] * 0.114

    graypcd = o3d.geometry.PointCloud()
    graypcd.points = o3d.utility.Vector3dVector(point_cloud.points)
    graypcd.colors = o3d.utility.Vector3dVector(np.repeat(gray[:, np.newaxis], 3, axis=1))

    # Compare the 1-D gray values so each selected point index appears once.
    selected_indices = np.flatnonzero(gray > color_threshold)
    return graypcd.select_by_index(selected_indices)

In [None]:
# Function to perform DBSCAN clustering on a point cloud
def dbscan_cluster(point_cloud, trunk_radius, output_folder):
    """Cluster a point cloud with DBSCAN, save large clusters, colour the cloud.

    DBSCAN runs with ``eps = trunk_radius / 2`` and ``min_points = 20``.
    Every cluster with at least 50 points is written to
    ``output_folder/Dbscan_cluster<j>.pcd``, where ``j`` counts the saved
    clusters starting from 1. The input cloud is then recoloured in place by
    cluster label using the "tab20" colormap; points with a negative label
    are painted black.

    Args:
        point_cloud: Open3D point cloud to cluster; its colours are replaced.
        trunk_radius: Radius from which the DBSCAN ``eps`` is derived.
        output_folder: Directory receiving the cluster files; created if it
            does not exist.

    Returns:
        The same ``point_cloud`` object. An empty cloud is returned unchanged.
    """
    with o3d.utility.VerbosityContextManager(o3d.utility.VerbosityLevel.Debug):
        labels = np.array(point_cloud.cluster_dbscan(eps=trunk_radius / 2, min_points=20, print_progress=False))

    # labels.max() is undefined for an empty cloud; there is nothing to do.
    if labels.size == 0:
        return point_cloud

    os.makedirs(output_folder, exist_ok=True)
    max_label = labels.max()

    saved_count = 0
    for cluster_id in range(max_label + 1):
        indices = np.where(labels == cluster_id)[0]
        # Clusters with fewer than 50 points are not saved.
        if len(indices) < 50:
            continue
        saved_count += 1
        cluster_point_cloud = point_cloud.select_by_index(indices)
        file_path = os.path.join(output_folder, f"Dbscan_cluster{saved_count}.pcd")
        o3d.io.write_point_cloud(file_path, cluster_point_cloud)

    colors = plt.get_cmap("tab20")(labels / (max_label if max_label > 0 else 1))
    colors[labels < 0] = 0
    point_cloud.colors = o3d.utility.Vector3dVector(colors[:, :3])
    return point_cloud


# convert a point cloud to a mesh
def point_cloud_to_mesh(point_cloud_path, alpha=1):
    """Load a point cloud file and build an alpha-shape triangle mesh from it.

    Args:
        point_cloud_path: Path of the point cloud file to read.
        alpha: Alpha value handed to the alpha-shape reconstruction.

    Returns:
        The reconstructed Open3D triangle mesh with vertex normals computed.
    """
    cloud = o3d.io.read_point_cloud(point_cloud_path)
    alpha_mesh = o3d.geometry.TriangleMesh.create_from_point_cloud_alpha_shape(cloud, alpha)
    alpha_mesh.compute_vertex_normals()
    return alpha_mesh


# compute the surface area and volume of a mesh
def compute_area_volume(mesh_path):
    """Read a PLY mesh with VTK and wrap it in a mass-properties filter.

    Args:
        mesh_path: Path of the PLY mesh file.

    Returns:
        A ``vtk.vtkMassProperties`` object whose input is the loaded mesh;
        callers query it (e.g. ``GetSurfaceArea()``, ``GetVolume()``).
    """
    ply_reader = vtk.vtkPLYReader()
    ply_reader.SetFileName(mesh_path)
    ply_reader.Update()

    properties_filter = vtk.vtkMassProperties()
    properties_filter.SetInputData(ply_reader.GetOutput())
    return properties_filter


# classify data using K-means clustering
def kmeans_classify(trait_data, n_clusters=2):
    """Assign each row of ``trait_data`` to one of ``n_clusters`` K-means clusters.

    Args:
        trait_data: Sequence or array with one row of traits per sample.
        n_clusters: Number of clusters to fit.

    Returns:
        An integer label array with one entry per sample. When there are
        fewer samples than clusters, K-means cannot be fitted and every
        sample receives label 0.
    """
    n_samples = len(trait_data)
    # KMeans raises when n_samples < n_clusters (e.g. a tree with a single
    # cluster); fall back to one shared label instead of crashing.
    if n_samples < n_clusters:
        return np.zeros(n_samples, dtype=np.int32)

    kmeans = KMeans(n_clusters=n_clusters, n_init=10, random_state=42)
    kmeans.fit(trait_data)
    return kmeans.labels_


# color a point cloud based on the given label
def color_points_by_label(point_cloud, label):
    """Paint a whole point cloud in the colour associated with ``label``.

    Label 1 is painted red, label 0 green, and any other label gray.

    Args:
        point_cloud: Object exposing ``paint_uniform_color``; painted in place.
        label: Class label selecting the colour.

    Returns:
        The same ``point_cloud`` object after painting.
    """
    red = [1.0, 0.0, 0.0]
    green = [0.0, 1.0, 0.0]
    gray = [0.5, 0.5, 0.5]  # used for unknown labels

    rgb = red if label == 1 else (green if label == 0 else gray)
    point_cloud.paint_uniform_color(rgb)
    return point_cloud


# fix labels
def fix_label(labels):
    """Normalise binary labels so that label 0 marks the larger group.

    Args:
        labels: Array of 0/1 labels.

    Returns:
        ``labels`` itself when zeros strictly outnumber the other entries;
        otherwise ``1 - labels`` (ties are flipped as well).
    """
    zero_count = np.sum(labels == 0)
    other_count = len(labels) - zero_count

    # Label 0 should always be the majority; swap the labels when it is not.
    if zero_count <= other_count:
        return 1 - labels
    return labels


# compute the number of flowers given points for single and multi-flower clusters
def compute_cluster_num(single_flower_points, multi_flower_points):
    """Estimate the total flower count from cluster point counts.

    The reference size of one flower is the mean point count of the largest
    single-flower clusters (at most 10 of them). Each multi-flower cluster
    contributes ``ceil(points / reference)`` flowers and each single-flower
    cluster contributes one.

    Args:
        single_flower_points: Point counts of single-flower clusters.
        multi_flower_points: Point counts of multi-flower clusters.

    Returns:
        Tuple ``(total_flower_num, average_point_num)``. When no reference
        size is available (no single-flower clusters), ``average_point_num``
        is 0.0 and every multi-flower cluster is counted as one flower.
    """
    largest_singles = sorted(single_flower_points, reverse=True)[:10]
    # Divide by the number of clusters actually used, not a fixed 10, so the
    # mean is not deflated when fewer than 10 single-flower clusters exist.
    if largest_singles:
        average_point_num = sum(largest_singles) / len(largest_singles)
    else:
        average_point_num = 0.0

    if average_point_num > 0:
        flower_num = [ceil(x / average_point_num) for x in multi_flower_points]
    else:
        # Without a reference size the clusters cannot be subdivided.
        flower_num = [1] * len(multi_flower_points)

    total_flower_num = sum(flower_num) + len(single_flower_points)
    return total_flower_num, average_point_num

# Extract the traits of each tree
def process_tree_traits(tree_result_folder):
    """Collect point count, surface area and volume for every cluster of every tree.

    For each tree sub-folder of ``<tree_result_folder>/dbscan_result``, every
    cluster file is meshed, the mesh is written as a ``.ply`` file under
    ``<tree_result_folder>/dbscan_mesh/<tree>/``, and its traits are measured.

    Args:
        tree_result_folder: Folder containing the ``dbscan_result`` directory.

    Returns:
        A list with one entry per tree (in ``os.listdir`` order); each entry
        is a list of ``[point_count, surface_area, volume]`` per cluster.
    """
    clusters_root = os.path.join(tree_result_folder, "dbscan_result")
    meshes_root = os.path.join(tree_result_folder, "dbscan_mesh")

    traits_per_tree = []
    for tree_name in os.listdir(clusters_root):
        tree_cluster_dir = os.path.join(clusters_root, tree_name)
        tree_mesh_dir = os.path.join(meshes_root, tree_name)

        traits = []
        for cluster_file in os.listdir(tree_cluster_dir):
            cluster_path = os.path.join(tree_cluster_dir, cluster_file)
            cluster_cloud = o3d.io.read_point_cloud(cluster_path)

            create_folder_if_not_exists(tree_mesh_dir)
            # Swap the last four characters of the file name for ".ply".
            mesh_path = os.path.join(tree_mesh_dir, cluster_file[:-4] + ".ply")
            o3d.io.write_triangle_mesh(mesh_path, point_cloud_to_mesh(cluster_path))

            # The mesh is re-read from disk by VTK to measure area and volume.
            mass = compute_area_volume(mesh_path)
            traits.append([len(cluster_cloud.points), mass.GetSurfaceArea(), mass.GetVolume()])

        traits_per_tree.append(traits)

    return traits_per_tree


# classify and count flowers
def classify_and_count_flowers(tree_result_folder):
    # Process tree traits to get a list of traits for each tree
    tree_trait_list = process_tree_traits(tree_result_folder)

    dbscan_folder = os.path.join(tree_result_folder, "dbscan_result")
    tree_num_folders = [os.path.join(dbscan_folder, num) for num in os.listdir(dbscan_folder)]
    all_tree_pcd_paths = []
    for i in range(len(tree_num_folders)):
        pcd_path_list = []
        for name in os.listdir(tree_num_folders[i]):
            cluster_path = os.path.join(tree_num_folders[i], name)
            pcd_path_list.append(cluster_path)
        all_tree_pcd_paths.append(pcd_path_list)

    tree_names = os.listdir(dbscan_folder)
    colored_pcd_output_folder = os.path.join(tree_result_folder, "pcd_cluster_with_color")
    create_folder_if_not_exists(colored_pcd_output_folder)
    flower_num_dict = {}

    # Iterate over the point cloud paths and tree traits for each tree
    for i in range(len(all_tree_pcd_paths)):
        tree_trait = np.array(tree_trait_list[i])
        pcd_path_list = all_tree_pcd_paths[i]

        # Classify the tree traits using K-Means
        labels = kmeans_classify(np.asarray(tree_trait_list[i]))
        fixed_labels = fix_label(labels)

        single_flower_points = []
        multi_flower_points = []
        combined_pcd = o3d.geometry.PointCloud()

        for pcd_path, label in zip(pcd_path_list, fixed_labels):
            pcd = o3d.io.read_point_cloud(pcd_path)
            # Color the points in the point cloud based on their label
            colored_pcd = color_points_by_label(pcd, label)
            combined_pcd += colored_pcd
            if label == 0:
                single_flower_points.append(len(pcd.points))
            elif label == 1:
                multi_flower_points.append(len(pcd.points))

        # Compute the number of flower clusters
        cluster_num, average_point_num = compute_cluster_num(single_flower_points, multi_flower_points)
        flower_num_dict[tree_names[i]] = cluster_num

        output_path = os.path.join(colored_pcd_output_folder, tree_names[i] + ".pcd")
        o3d.io.write_point_cloud(output_path, combined_pcd)