# Extract tree trunks

The input for this notebook is the pre-processed cloud that results from notebook `1. Tree filter.ipynb`. From this data we try to extract individual tree trunks. The methods used here are based on the [internship project of Jorges Nofulla](https://github.com/Amsterdam-Internships/Tree-trunk-segmentation).

(**Note**: this notebook currently works on a single input file only, in contrast to notebooks 1-3)

In [None]:
import set_path

import numpy as np
import laspy as lp
import scipy
import matplotlib.pyplot as plt
from scipy.spatial import distance
from scipy.spatial import KDTree
import open3d as o3d
import pandas as pd
import geopandas as gpd
import shapefile
import pathlib

import gvl.trunk_utils as utils

## Settings

In [None]:
DATA_FOLDER = pathlib.Path("../data")

my_tile = "2496_9727"

# Input
input_las = DATA_FOLDER / "ahn4_trees" / f"trees_{my_tile}.laz"
area_file = DATA_FOLDER / "ground_truth" / "gt_area.gpkg"

# Output
output_dir = DATA_FOLDER / "ahn4_trunks"
output_trunks = output_dir / f"trunks_{my_tile}.laz"
# The trunk centroids are written as an ESRI shapefile (see step 8), so the
# path carries a .shp suffix rather than the point-cloud .laz suffix.
output_centroids = output_dir / f"trunk_centroids_{my_tile}.shp"

CRS = "epsg:28992"

In [None]:
# Algorithm parameters (all distances in meters)

# Radius of the 3D search sphere used for the initial bottom-up clustering.
r = 3
# Radius of the horizontal (x, y) neighbourhood in which the point density is
# counted for every point; that density drives the local-maxima search.
radius = 0.8
# Half-width of the square x/y search window used to find local maxima
# inside each cluster.
window_size = 4
# Half-width of the square x/y area around an accepted peak whose points
# receive that peak's trunk label (the delineated trunk "radius").
max_distance = 0.8
# Minimum Euclidean x/y distance two accepted peaks of one cluster must have.
restrict_d = 3
# Clusters with at most this many points are suspected outliers. They are not
# deleted; they are merged into a nearby larger cluster when one qualifies,
# otherwise they remain individual clusters.
small_clusters = 100
# Minimum cluster size kept as a tree; smaller clusters are deleted
# (optional clean-up step).
small_outliers = 30
# Minimum height difference between two very close clusters for the smaller
# one to be treated as a branch that was split off as a separate cluster.
diff_height = 1.5
# Maximum distance a branch cluster may be from the main tree.
branch_dist = 0.8
# Maximum distance between two clusters for them to be checked as the same tree.
min_dist_tree = 1

In [None]:
# Make sure the output directory (and any missing parents) exists; an
# already-existing directory is fine.
output_dir.mkdir(exist_ok=True, parents=True)

## 1. Load and Prepare the data

In [None]:
# Load the pre-processed tree point cloud of this tile into memory.
tile_path = input_las
las_file = lp.read(tile_path)

In [None]:
# Keep only the points whose `label` dimension equals 1 ('tree' points).
# NOTE(review): `label` is presumably an extra dimension written by notebook 1 — confirm.
is_tree = las_file.label == 1
las_file.points = las_file.points[is_tree]

In [None]:
# Build an (n, 3) array of scaled x, y, z coordinates.
coord = np.column_stack((las_file.x, las_file.y, las_file.z))
coord.shape

### Reduce the number of points by downsampling

In [None]:
# Voxel-grid downsampling with Open3D (voxel size 0.5) to thin the cloud.
pcd = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(coord))
pcd_down = pcd.voxel_down_sample(voxel_size=0.5)
coord = np.asarray(pcd_down.points)
coord.shape

### Reduce the number of points by selecting an area

In [None]:
# Polygon(s) delimiting the area of interest.
df_area = gpd.read_file(area_file)

# Wrap the coordinates in a GeoDataFrame: columns 0/1/2 hold x/y/z and the
# point geometry is built from x and y in the project CRS.
df_coord = pd.DataFrame(coord)
xy_points = gpd.points_from_xy(df_coord[0], df_coord[1])
gdf_coord = gpd.GeoDataFrame(df_coord, geometry=xy_points, crs=CRS)

In [None]:
# Keep only the points that lie within the area polygon(s).
gdf_coord_sel = gdf_coord.sjoin(df_area, predicate="within").drop(
    columns=["index_right"]
)
# sjoin returns a point once per polygon that contains it, so overlapping area
# polygons would duplicate points (and inflate cluster sizes later on). Keep a
# single copy per original point (the left index is preserved by sjoin).
gdf_coord_sel = gdf_coord_sel[~gdf_coord_sel.index.duplicated(keep="first")]

# Back to a plain (n, 3) array: x, y from the geometry, z from column 2.
coord = np.c_[gdf_coord_sel.geometry.x, gdf_coord_sel.geometry.y, gdf_coord_sel[2]]
coord.shape

### Reshape data

In [None]:
# Reorder the points from lowest to highest z.
z_order = np.argsort(coord[:, 2])
position = coord[z_order]

In [None]:
# Wrap every sorted point in a utils.Point carrying its index into `position`.
points = [utils.Point(idx, xyz) for idx, xyz in enumerate(position)]

## 2. Find centroids of point clusters and tree peaks

1. A collection of points in 3D space is given, together with a user-defined search radius `r`.
2. For each point, the code finds all neighbours within radius `r` that lie higher (larger Z-value) than the point itself, and computes the centroid of those higher neighbours.
3. The point is linked to the higher neighbour closest to that centroid. A point without any higher neighbour is a local peak (i.e., the top of a tree) and is linked to itself.
4. The code outputs, for each point, the index of the point it is linked to (`links`) and whether the point is a peak (`has_parent`; despite its name, a value of 1 marks a peak, i.e. a point *without* a parent).

In [None]:
# Spatial index over the sorted points. For each point, collect the indices of
# every point within a 3D ball of radius r (the point itself included).
tree = scipy.spatial.cKDTree(position)
search_radius = r
nn = tree.query_ball_point(position, search_radius)

In [None]:
# Link every point to a point higher up in its neighbourhood. Points without a
# higher neighbour are local peaks (tree tops) and are linked to themselves.
#   links[i]      index of the point that point i is linked to
#   centroids[i]  centroid of point i's higher neighbours (point i itself for a peak)
#   has_parent[i] 1 when point i is a PEAK, i.e. it has NO parent. The name is
#                 inverted, and the path-building step relies on this meaning.
n_pts = len(position)
links = np.zeros(n_pts, dtype=int)
centroids = np.zeros((n_pts, 3))
has_parent = np.zeros(n_pts, dtype=bool)

for idx, neighbours in enumerate(nn):
    z_here = position[idx, 2]
    # Neighbours strictly higher than this point (the point itself never qualifies).
    higher = [j for j in neighbours if position[j, 2] > z_here]

    if len(neighbours) == 1 or not higher:
        # Isolated point or local maximum: treat it as a peak.
        links[idx] = idx
        centroids[idx] = position[idx]
        has_parent[idx] = True
        continue

    higher_xyz = position[higher]
    centroids[idx] = higher_xyz.mean(axis=0)
    # Link to the higher neighbour nearest to the centroid of all higher neighbours.
    gaps = distance.cdist(higher_xyz, centroids[idx][np.newaxis, :], "euclidean")
    links[idx] = higher[np.argmin(gaps)]

has_parent = has_parent.astype("int")

## 3. Label the points

1. For each point, the code checks if it has already been assigned to a path.
2. If not, it creates a new path and adds the current point to it.
3. It then follows the links created in Part 2 to add more points to the path, until it reaches a point with no parent (i.e., the top of the tree). There it ends the path and creates a new network (tree) for it.
4. If the code encounters a point that is already in a path, it ends the new path there and adds it to the network that the existing point belongs to.

In [None]:
# Group points into trees ("networks") by following the links from step 2.
# Starting from every point not yet on a path, walk link by link upward,
# collecting points into a new path. Links always point to a strictly higher
# point (see the `>` filter in step 2), so every walk terminates:
#   - at a peak (has_parent == 1, i.e. NO parent): a new Network is created for
#     the path, with that peak as its top;
#   - at a point already on some path: the new path joins that point's network.
# NOTE(review): `network`, `paths`, `add_path`, `add_point` come from
# gvl.trunk_utils and are not visible here. The step that builds the labels
# array reads `p.network.index` for every point, so presumably
# Network.add_path sets `.network` on the path's points — TODO confirm.
networks = []
all_paths = []
for p in points:
    current_idx = p.index

    if len(points[current_idx].paths) == 0:
        end = False

        # initialize new path
        new_path = utils.Path(len(all_paths))  # len paths as index
        all_paths.append(new_path)

        # add first point to the path
        new_path.add_point(points[current_idx])
        points[current_idx].add_path(new_path)

        # append path
        while end is False:
            # point has a parent (has_parent == 1 marks a peak, so != 1 means "not a peak")
            if has_parent[current_idx] != 1:
                # make link
                points[current_idx].linked_to = points[links[current_idx]]

                if len(points[current_idx].linked_to.paths) == 0:
                    # not in path: extend the current path upward and continue from there
                    points[current_idx].linked_to.add_path(new_path)
                    new_path.add_point(points[current_idx].linked_to)
                    current_idx = links[current_idx]

                else:
                    # in path: attach the new path to the network of the point we ran into
                    points[current_idx].linked_to.network.add_path(new_path)
                    # NOTE(review): the current point already received new_path
                    # (either when the path started or when it was reached via a
                    # link), so this adds it a second time unless add_path
                    # de-duplicates — confirm in gvl.trunk_utils.
                    points[current_idx].add_path(new_path)
                    points[current_idx].linked_to.add_path(new_path)
                    end = True

            # point has no parent
            # make network, end path
            else:
                points[current_idx].linked_to = points[current_idx]
                # init new network
                new_network = utils.Network(len(networks))  # len networks as index
                new_network.add_path(
                    new_path
                )  # path and points are assigned to network
                new_network.top = current_idx
                new_network.points = new_path.points  # add points to the network
                networks.append(new_network)
                points[current_idx].network = new_network
                end = True

## 4. Remove all the outlier clusters

### Get the labels array

In [None]:
# One cluster label per point: the index of the network (tree) it belongs to.
labels = np.zeros(len(points), dtype=int)
for pt in points:
    labels[pt.index] = pt.network.index

# Attach the labels as a 4th column next to the sorted x, y, z coordinates.
array_test = np.column_stack((position, labels))

In [None]:
# Split the labelled array back into its coordinates and its label column.
array, labels_new = array_test[:, 0:3], array_test[:, 3]

### Remove clusters

In [None]:
# Number of points carrying each cluster label.
unique, point_to_unique, counts = np.unique(
    labels_new, return_inverse=True, return_counts=True
)
label_count = dict(zip(unique, counts))

# Keep only points whose cluster has at least 10 points (order preserved).
large_cluster_indices = np.flatnonzero(counts[point_to_unique] >= 10).tolist()

# Reassemble x, y, z + label for the retained points.
array_test = np.column_stack(
    (array[large_cluster_indices, :], labels_new[large_cluster_indices])
)

## 5. Fix the small clusters

In [None]:
# Integer cluster labels and the number of points in each cluster, used by the
# "fix small clusters" step.
labels_2 = array_test[:, 3].astype(int)
labels33, point_count33 = np.unique(labels_2, return_counts=True)

In [None]:
# Labels of the clusters with at most `small_clusters` points.
iterating_array = [
    lab for lab, n_pts in zip(labels33, point_count33) if n_pts <= small_clusters
]

In [None]:
# x/y centroid of every cluster, in ascending label order.
all_labs = list(np.unique(array_test[:, 3]))
all_centroids = [
    array_test[array_test[:, 3] == lab, :2].mean(axis=0) for lab in all_labs
]

In [None]:
# Pair every cluster with its nearest neighbouring cluster (by x/y centroid).
# With fewer than two clusters there is nothing to pair. Querying k=2 on a
# single-point tree would return a "missing neighbour" index and crash on
# all_labs[...], so that case is skipped.
labels_nn = []
if len(all_labs) >= 2:
    tree1 = KDTree(all_centroids)
    for i, (point_cent, lab) in enumerate(zip(all_centroids, all_labs)):
        _, idx = tree1.query(point_cent, k=2)
        # Normally the cluster itself comes back first; with duplicate
        # centroids it may not, hence the explicit check.
        closest_idx = idx[1] if idx[0] == i else idx[0]
        labels_nn.append([lab, all_labs[closest_idx]])

# Keep only the pairs whose first cluster is a small one that may be fixed.
small_label_set = {int(v) for v in iterating_array}
filtered_list = [x for x in labels_nn if int(x[0]) in small_label_set]
# Snapshot used for reading sizes/distances while array_test is relabelled.
array_test2 = array_test.copy()

In [None]:
# Merge each small cluster into its nearest neighbouring cluster when any of
# three geometric rules fires. Sizes and distances are read from the unmodified
# snapshot `array_test2`; the relabelling is written into `array_test`.
for small_label, nearest_label in filtered_list:
    small_pts = array_test2[array_test2[:, 3] == small_label]
    nearest_pts = array_test2[array_test2[:, 3] == nearest_label]

    # Closest horizontal (x, y) distance between the two clusters.
    min_xy_dist = distance.cdist(
        small_pts[:, :2], nearest_pts[:, :2], "euclidean"
    ).min()
    # Height difference between the lowest points of the two clusters.
    z = abs(small_pts[:, 2].min() - nearest_pts[:, 2].min())
    # Size of the small cluster. Count only the label column; comparing the
    # whole array to the label would also count coordinates that happen to
    # equal it.
    n_small = len(small_pts)

    merge = (
        # very small cluster close to another cluster
        (n_small < small_clusters / 2 and min_xy_dist < min_dist_tree)
        # branch: close by, but starting clearly higher/lower than the tree
        or (min_xy_dist < branch_dist and z > diff_height)
        # small cluster very close to another cluster
        or (n_small < small_clusters and min_xy_dist < min_dist_tree / 2)
    )
    if merge:
        array_test[array_test[:, 3] == small_label, 3] = nearest_label

### Delete small clusters (optional)

In [None]:
# Split the (possibly merged) labelled array into coordinates and labels.
labels_new = array_test[:, 3]
array = array_test[:, 0:3]

# Number of points carrying each cluster label.
unique, point_to_unique, counts = np.unique(
    labels_new, return_inverse=True, return_counts=True
)
label_count = dict(zip(unique, counts))

# Optional clean-up: keep only clusters with at least `small_outliers` points.
large_cluster_indices = np.flatnonzero(
    counts[point_to_unique] >= small_outliers
).tolist()

# Reassemble x, y, z + label for the retained points.
array_test = np.column_stack(
    (array[large_cluster_indices, :], labels_new[large_cluster_indices])
)

## 6. Count the neighbouring points within a horizontal buffer around each point (the density column used for the local-maxima search)

In [None]:
# Horizontal (x, y) coordinates of every remaining point.
points = array_test[:, :2]
kd_tree = KDTree(points)

# For each point, count the OTHER points within `radius` in the x-y plane.
# A point always lies in its own ball, hence the -1.
neighbour_lists = kd_tree.query_ball_point(points, radius)
count = np.array([len(nbrs) - 1 for nbrs in neighbour_lists], dtype=int)

## 7. Find the tree trunks

In [None]:
def cluster_local_maxima(full_array, window_size, max_distance, restrict_d):
    """Split every tree cluster into trunk labels around local density maxima.

    Parameters
    ----------
    full_array : ndarray, shape (n, 5)
        Columns: x, y, z, tree-cluster label, point-density count.
    window_size : float
        Half-width of the square x/y window in which a point must have the
        highest density count to be a local maximum (peak).
    max_distance : float
        Half-width of the square x/y area around an accepted peak whose
        points receive that peak's trunk label.
    restrict_d : float
        A candidate peak is rejected unless its x/y distance to every peak
        already accepted in the same cluster is larger than this value.

    Returns
    -------
    ndarray, shape (n, 4)
        Columns x, y, z, trunk label, grouped per cluster in ascending
        cluster-label order. Trunk labels start at 1 and are unique across
        clusters; label 0 means "not assigned to any trunk". For empty input
        an array of shape (0, 4) is returned.
    """
    unique_clusters = np.unique(full_array[:, 3])
    # Append a zero-initialised trunk-label column (index 5).
    full_array = np.column_stack(
        (full_array, np.zeros(full_array.shape[0], dtype=np.int64))
    )
    current_label = 1
    pieces = []

    # Process every tree cluster separately.
    for cluster_id in unique_clusters:
        cluster = full_array[full_array[:, 3] == cluster_id]
        x1 = cluster[:, 0]
        y1 = cluster[:, 1]
        z1 = cluster[:, 2]
        density = cluster[:, 4]
        labels_k = cluster[:, 5]  # view: writes go into `cluster`
        peaks = []

        for i in range(len(cluster)):
            # Square search window around the point.
            x_min = x1[i] - window_size
            x_max = x1[i] + window_size
            y_min = y1[i] - window_size
            y_max = y1[i] + window_size
            in_window = (x1 >= x_min) & (x1 <= x_max) & (y1 >= y_min) & (y1 <= y_max)

            # The point must carry the highest density in its window...
            is_window_max = np.max(density[in_window]) == density[i]
            # ...and be far enough from the peaks already accepted here.
            if peaks:
                nearest_peak = distance.cdist(
                    np.asarray(peaks), [[x1[i], y1[i]]], "euclidean"
                ).min()
                far_enough = nearest_peak > restrict_d
            else:
                far_enough = True

            if is_window_max and far_enough:
                peaks.append([x1[i], y1[i]])
                near = np.flatnonzero(
                    (np.abs(x1 - x1[i]) <= max_distance)
                    & (np.abs(y1 - y1[i]) <= max_distance)
                )
                if labels_k[i] == 0:
                    # New trunk: label the square around the peak.
                    labels_k[near] = current_label
                    current_label += 1
                else:
                    # Peak already inside a trunk: grow that trunk instead.
                    labels_k[near] = labels_k[i]

        pieces.append(np.c_[x1, y1, z1, labels_k])

    # Stack once at the end (stacking inside the loop is quadratic).
    if not pieces:
        return np.empty((0, 4))
    return np.vstack(pieces)

In [None]:
# Append the per-point density count as a 5th column, then split the tree
# clusters into individual trunks.
full_array = np.c_[array_test, count]
Final_labels = cluster_local_maxima(
    full_array,
    window_size=window_size,
    max_distance=max_distance,
    restrict_d=restrict_d,
)

In [None]:
# Distinct trunk labels. Label 0 marks points that cluster_local_maxima did
# not assign to any trunk (its labels start at 1), so it is not a tree.
tree_count = np.unique(Final_labels[:, 3])
tree_count = tree_count[tree_count != 0]
print("there are", len(tree_count), "trees in this area")

## 8. Store results

### Save the trunk Point Cloud as a new LAS file

In [None]:
# Write the trunk points to a LAS/LAZ file; each point's trunk label is stored
# in the standard `point_source_id` dimension.
# laspy 2 API: the point format is passed to LasHeader, scale/offset are the
# `scales`/`offsets` properties, and the dimension is named `point_source_id`
# (`data_format_id`, `pt_src_id`, `scale`, `offset` are laspy-1 style names).
trunk_labels = Final_labels[:, 3]
# point_source_id is a uint16; refuse to silently wrap larger labels.
if trunk_labels.size and trunk_labels.max() > np.iinfo(np.uint16).max:
    raise ValueError("Trunk labels exceed the uint16 range of point_source_id")

header = lp.LasHeader(point_format=2, version="1.2")
header.scales = np.array([0.01, 0.01, 0.01])
# Offset at the minimum coordinates keeps the stored scaled integers small.
header.offsets = Final_labels[:, :3].min(axis=0)

new_las = lp.LasData(header)
new_las.x = Final_labels[:, 0]
new_las.y = Final_labels[:, 1]
new_las.z = Final_labels[:, 2]
new_las.point_source_id = trunk_labels.astype("uint16")
new_las.write(output_trunks)

### Get the centroid in X and Y for each tree trunk

In [None]:
# Trunk labels without label 0 (points not assigned to any trunk). Filtering
# explicitly matters: np.unique(...)[1:] would silently drop a real trunk
# whenever no point carries label 0.
trunk_labels = np.unique(Final_labels[:, 3])
Centroid_tree = trunk_labels[trunk_labels != 0]

# x/y centroid of every trunk, stored as rows [x, y, label].
centroids_array = []
for label in Centroid_tree:
    cluster_points = Final_labels[Final_labels[:, 3] == label, :2]
    cx, cy = cluster_points.mean(axis=0)
    centroids_array.append([cx, cy, label])

# reshape keeps a well-formed (0, 3) array when there are no trunks.
centroids_array = np.array(centroids_array).reshape(-1, 3)

### Save the tree centroids as 2D points (shapefile)

In [None]:
# Write one 2D point per trunk centroid to an ESRI shapefile with a single
# numeric "label" attribute. The `with` block closes (and flushes) the writer
# even if an error occurs while adding records.
# NOTE(review): pyshp writes .shp/.shx/.dbf next to this path; no .prj is
# written, so the CRS (see CRS setting) is not recorded — consider adding one.
with shapefile.Writer(str(output_centroids), shapeType=shapefile.POINT) as sf:
    sf.field("label", "N")
    for x, y, label in centroids_array:
        sf.point(x, y)
        # Labels are stored as floats in centroids_array; the field is an integer.
        sf.record(int(label))