# Pointcloud 3D Tree Modelling

This notebook reconstructs a 3D model of a tree from its point cloud, using the deflt repository.

---------------

In [None]:
# Uncomment to load the local package rather than the pip-installed version.
# Add project src to path.
import set_path

In [None]:
# Import modules.
import os
import time
import trimesh
import pymeshfix
import subprocess
import alphashape
import numpy as np
import open3d as o3d
import logging as log
import networkx as nx
import matplotlib.pyplot as plt
from scipy.spatial import KDTree
from shapely.geometry import Polygon
from scipy.spatial import ConvexHull
from plyfile import PlyData, PlyElement

import utils.math_utils as math_utils
import utils.plot_utils as plot_utils
import utils.las_utils as las_utils

## Datasets

-------

#### AHN Reconstruction

![AHN Reconstruction](../imgs/ahn_reconstruction.png)

In [None]:
# AHN point cloud of the selected tree (same tile id as the Cyclo/Sonarski files).
ahn_tree_path = '../datasets/single_selection/single_121913_487434_AHN.las'
# Only loaded and shown as the cell output; the result is not used further yet.
las_utils.read_las(ahn_tree_path)

# TODO: python reconstruction implementation

#### Cyclo Reconstruction

![Cyclo Media Reconstruction](../imgs/cyclo_reconstruction.png)

In [None]:
# Cyclomedia point cloud of the selected tree: read the LAS file, convert it to
# an Open3D point cloud and thin it out on a voxel grid of size 0.02.
cyclo_tree_path = '../datasets/single_selection/single_121913_487434_Cyclo.las'
cyclo_pcd = las_utils.to_o3d(
    las_utils.read_las(cyclo_tree_path)
).voxel_down_sample(voxel_size=0.02)

#### Sonarski Reconstruction

![Sonarski Reconstruction](../imgs/sonarski_reconstruction.png)

In [None]:
# Sonarski point cloud of the selected tree (same tile id as the AHN/Cyclo files).
sonarski_tree_path = '../datasets/single_selection/single_121913_487434_Sonarski.las'
# Only loaded and shown as the cell output; the result is not used further yet.
las_utils.read_las(sonarski_tree_path)

# TODO: python reconstruction implementation

## Modules

--------

#### 1. Leaf-Wood Filtering

There are two filters:
- `surface_variation_filter()` filters on the **surface variation** of a point.
- `curvature_filter()` filters on the **curvature** of a point (the relative sizes of its local covariance eigenvalues).

In [None]:
def surface_variation_filter(pcd, radius=0.05, threshold=.15):
    '''
    Flag points whose neighbourhood has a low surface variation.

    The surface variation of a point is ``l_min / (l_0 + l_1 + l_2)``, where
    ``l_*`` are the eigenvalues of the covariance matrix of its neighbours
    within ``radius``.

    Params
    -----------
    pcd : o3d.geometry.PointCloud
        Input point cloud; its covariances are (re)estimated in place.
    radius : float
        Neighbourhood search radius.
    threshold : float
        Points with a surface variation strictly below this value are flagged.

    Returns
    -----------
    Boolean numpy array with one entry per point (True = below threshold).
    '''
    pcd.estimate_covariances(
        search_param=o3d.geometry.KDTreeSearchParamRadius(radius=radius))

    # Covariance matrices are symmetric: eigvalsh returns real eigenvalues,
    # already sorted ascending, so column 0 is the smallest one.
    eig_val = np.linalg.eigvalsh(np.asarray(pcd.covariances))
    total = eig_val.sum(axis=1)

    # A zero eigenvalue sum would give 0/0 = NaN; NaN < threshold is False, so
    # such points are not flagged. Silence the RuntimeWarning this would emit.
    with np.errstate(divide='ignore', invalid='ignore'):
        variation = eig_val[:, 0] / total

    return variation < threshold

def curvature_filter(pcd, radius, min1=0, max1=100, min2=0, max2=100, min3=0, max3=100):
    '''
    Select points by the relative size of their covariance eigenvalues.

    For each point the eigenvalues of its neighbourhood covariance are
    computed (L1 >= L2 >= L3). Each eigenvalue is rescaled to a percentage of
    its range over the whole cloud (0 = cloud minimum, 100 = cloud maximum).
    A point is kept when ``min < percentage <= max`` holds for all three.

    Params
    -----------
    pcd : o3d.geometry.PointCloud
        Input point cloud; its covariances are (re)estimated in place.
    radius : float
        Neighbourhood search radius.
    min1, max1 : float
        Percentage band (exclusive lower, inclusive upper) for L1, the largest eigenvalue.
    min2, max2 : float
        Percentage band for L2, the middle eigenvalue.
    min3, max3 : float
        Percentage band for L3, the smallest eigenvalue.

    Returns
    -----------
    (filter_mask, (L1, L2, L3))
        filter_mask : boolean numpy array, one entry per point.
        L1, L2, L3 : the eigenvalue percentages (floats in [0, 100]).
    '''
    pcd.estimate_covariances(
        search_param=o3d.geometry.KDTreeSearchParamRadius(radius=radius))

    # Covariances are symmetric, so eigvalsh gives real eigenvalues that are
    # already sorted ascending (column 0 = smallest, column 2 = largest).
    eig_val = np.linalg.eigvalsh(np.asarray(pcd.covariances))

    # Zero all eigenvalues of points whose largest eigenvalue is exactly 1.
    # NOTE(review): presumably targets the identity covariance Open3D assigns
    # to points with too few neighbours — confirm.
    eig_val[eig_val[:, 2] == 1] = 0

    L1 = _percent_of_range(eig_val[:, 2])
    L2 = _percent_of_range(eig_val[:, 1])
    L3 = _percent_of_range(eig_val[:, 0])

    filter_mask = (_in_percent_band(L1, min1, max1)
                   & _in_percent_band(L2, min2, max2)
                   & _in_percent_band(L3, min3, max3))
    return filter_mask, (L1, L2, L3)


def _percent_of_range(values):
    '''Linearly rescale `values` so its minimum maps to 0 and its maximum to 100.'''
    if values.size == 0:
        return values.astype(float)
    lo = values.min()
    span = values.max() - lo
    if span == 0:
        # Constant input: avoid 0/0 (NaN) and place every value at 0%.
        return np.zeros_like(values, dtype=float)
    return (values - lo) / span * 100


def _in_percent_band(percent, lo, hi):
    '''Boolean mask of `lo < percent <= hi`.'''
    # The upper bound is inclusive so that hi=100 (the default, "no upper
    # limit") keeps the points holding the maximum value.
    return (percent > lo) & (percent <= hi)

def filter_leaves(pcd, surface=False, plot=True):
    '''
    Split a tree point cloud into wood points and leaf points.

    Params
    -----------
    pcd : o3d.geometry.PointCloud
        Input tree point cloud.
    surface : bool
        If True use `surface_variation_filter`, otherwise `curvature_filter`.
    plot : bool
        Show the wood (brown) / leaf (light green) split in an Open3D window.

    Returns
    -----------
    (wood_pcd, l123, pcd_)
        wood_pcd : the points classified as wood.
        l123 : the (L1, L2, L3) eigenvalue percentages from `curvature_filter`,
            or None when `surface` is True (that filter does not produce them).
        pcd_ : the cloud after radius outlier removal; the wood mask indexes it.
    '''
    # Drop isolated points: fewer than 4 neighbours within radius 0.05.
    pcd_, _ = pcd.remove_radius_outlier(nb_points=4, radius=.05)

    # Only the curvature filter yields eigenvalue percentages; default to None
    # so the return statement never hits an unbound name.
    l123 = None
    if surface:
        mask = surface_variation_filter(pcd_, .1, .15)
    else:
        mask, l123 = curvature_filter(pcd_, .05, min1=20, min2=35)

    wood_idx = np.where(mask)[0]
    wood_pcd = pcd_.select_by_index(wood_idx)
    wood_pcd = wood_pcd.paint_uniform_color([.5, .3, 0])

    if plot:  # Visualize result
        leaf_pcd = pcd_.select_by_index(wood_idx, invert=True)
        leaf_pcd = leaf_pcd.paint_uniform_color([.8, 1, .8])
        o3d.visualization.draw_geometries([wood_pcd, leaf_pcd])

    return wood_pcd, l123, pcd_


In [None]:
# Split the Cyclomedia cloud into wood/leaves and report how many points remain.
pcd = cyclo_pcd
pcd_filtered, l123, pcd_ = filter_leaves(pcd, surface=False)
n_before, n_after = len(pcd.points), len(pcd_filtered.points)
print(n_before, '-->', n_after)

In [None]:
import laspy

# Export the outlier-filtered cloud (pcd_) to test.las, storing the per-point
# eigenvalue percentages L1-L3 from curvature_filter as uint8 extra dimensions.
if l123 is None:
    raise ValueError('l123 is None: run filter_leaves with surface=False to get L1-L3.')

points = np.asarray(pcd_.points)

las = laspy.create(file_version="1.2", point_format=3)
las.header.offsets = np.min(points, axis=0)
las.x = points[:, 0]
las.y = points[:, 1]
las.z = points[:, 2]

dim_names = ("L1", "L2", "L3")
for name in dim_names:
    las.add_extra_dim(laspy.ExtraBytesParams(name=name, type="uint8", description=name))

for name, values in zip(dim_names, l123):
    # The percentages are floats in [0, 100]. Convert explicitly rather than
    # relying on an implicit float->uint8 cast: NaN becomes 0 and values are
    # clipped to the uint8 range before being truncated.
    as_uint8 = np.clip(np.nan_to_num(np.asarray(values, dtype=float)), 0, 255).astype(np.uint8)
    setattr(las, name, as_uint8)

las.write('test.las')

#### 2. 3D-Reconstruction (AdTree)

In [None]:
# adtree executable
# NOTE(review): machine-specific path to a local macOS AdTree build; adjust per machine.
adtree_exe = '../../AdTree-single/build/bin/AdTree.app/Contents/MacOS/AdTree'

In [None]:
# Send INFO-and-above log records both to debug.log and to the console.
# (basicConfig is a no-op if the root logger already has handlers.)
_log_handlers = [log.FileHandler("debug.log"), log.StreamHandler()]
log.basicConfig(
    level=log.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=_log_handlers,
)

def create_graph(vertices, edges):
    '''
    Build a directed skeleton graph.

    Node ``k`` stores the coordinates of ``vertices[k]`` as the attributes
    ``x``, ``y`` and ``z`` (in that order). Every ``(i, j)`` pair in ``edges``
    is added as the directed edge ``j -> i``, i.e. with its direction reversed.
    '''
    graph = nx.DiGraph()
    graph.add_nodes_from(
        (idx, {'x': v[0], 'y': v[1], 'z': v[2]}) for idx, v in enumerate(vertices)
    )
    graph.add_edges_from((j, i) for i, j in edges)
    return graph

def _run_command(command):
    '''
    Run an external command (no shell) and return its CompletedProcess.

    Params
    -----------
    command : list of str
        Program followed by its arguments.

    Raises
    -----------
    subprocess.CalledProcessError
        If the command exits with a non-zero status, or if it writes anything
        to stderr (treated as a failure even when the exit status is 0).
        The captured stdout and stderr are attached to the exception.
    '''
    log.info("Command: %s", command)
    result = subprocess.run(command, capture_output=True)
    # Check the exit status too, so failures that print nothing to stderr
    # are not mistaken for success.
    if result.returncode != 0 or result.stderr:
        raise subprocess.CalledProcessError(
            returncode=result.returncode,
            cmd=result.args,
            output=result.stdout,
            stderr=result.stderr,
        )
    return result

def reconstruct_tree(pcd, adtree_exe):
    '''
    Reconstructs the branch skeleton of a point cloud tree with AdTree.

    Params
    -----------
    pcd : open3d.geometry.PointCloud
        Input tree point cloud.
    adtree_exe : str
        Path to the AdTree executable.

    Returns
    -----------
    (skeleton_graph, vertices, edges)
        skeleton_graph : networkx.DiGraph built with `create_graph`.
        vertices : numpy array of the unique skeleton vertices.
        edges : numpy array of vertex-index pairs into `vertices`.
    On failure the error is logged, and an empty graph and empty arrays are
    returned for whatever was not produced.
    '''
    import tempfile  # only needed for the scratch directory below

    skeleton_graph = nx.DiGraph()
    vertices, edges = np.array([]), np.array([])

    # A private scratch directory is unique per call, so it does not clash
    # with a user's ./tmp folder or with concurrent runs. It is removed
    # together with its contents even when the reconstruction fails.
    with tempfile.TemporaryDirectory() as tmp_folder:
        in_file = os.path.join(tmp_folder, 'tree.xyz')
        out_file = os.path.join(tmp_folder, 'tree_skeleton.ply')

        try:
            log.info("Reconstructing...")
            o3d.io.write_point_cloud(in_file, pcd)  # write input file
            _run_command([adtree_exe, in_file, out_file])  # run reconstruction
            plydata = PlyData.read(out_file)  # read output

            # Convert skeleton to graph
            log.info("Done. Constructing graph...")
            raw_vertices = np.array([[c for c in p] for p in plydata['vertex'].data])
            # Merge duplicate vertices. `reverse_` maps each original vertex index
            # to its row in the deduplicated array. reshape(-1) keeps it 1-D on
            # NumPy versions that return it with an extra axis when `axis` is set.
            vertices, reverse_ = np.unique(raw_vertices, axis=0, return_inverse=True)
            reverse_ = reverse_.reshape(-1)
            edges = np.array([reverse_[edge[0]] for edge in plydata['edge'].data])
            skeleton_graph = create_graph(vertices, edges)
            log.info("Done. Succesful.")

        except subprocess.CalledProcessError as e:
            # errors='replace' keeps the handler itself from raising on non-UTF-8 output.
            stderr = (e.stderr or b'').decode('utf-8', errors='replace')
            log.error("Failed reconstructing tree:\n{}".format(stderr))
        except Exception as e:
            # Best-effort: log with traceback and fall through to the empty result.
            log.exception("Failed:\n{}".format(e))

    return skeleton_graph, vertices, edges
    

In [None]:
# Reconstruct the skeleton from the wood points kept by the leaf filter.
skeleton, vertices, edges = reconstruct_tree(pcd_filtered, adtree_exe)
# Alternative: reconstruct from the cloud without leaf filtering.
# skeleton, vertices, edges = reconstruct_tree(pcd, adtree_exe)

#### 3. Stem Split

In [None]:
def plot_skeleton(vertices, edges, edges_sel=None):
    '''
    Show a skeleton in an Open3D window.

    Params
    -----------
    vertices : array-like, shape (n, 3)
        Skeleton vertex coordinates, drawn as green points.
    edges : array-like of index pairs into `vertices`
        Skeleton edges, drawn as red lines.
    edges_sel : array-like of index pairs into `vertices`, optional
        Additional edges to highlight, drawn as black lines.
    '''
    geometries = [_skeleton_lineset(vertices, edges, [1, 0, 0])]
    if edges_sel is not None and len(edges_sel) > 0:
        geometries.append(_skeleton_lineset(vertices, edges_sel, [0, 0, 0]))

    v_pcd = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(vertices))
    v_pcd = v_pcd.paint_uniform_color([0, 0.5, 0])
    geometries.append(v_pcd)

    o3d.visualization.draw_geometries(geometries)


def _skeleton_lineset(vertices, edges, color):
    '''Build a uniformly coloured Open3D LineSet from index pairs into `vertices`.'''
    # reshape(-1, 2) also turns an empty edge list into a valid (0, 2) array.
    lines = np.asarray(edges, dtype=int).reshape(-1, 2)
    line_set = o3d.geometry.LineSet()
    line_set.points = o3d.utility.Vector3dVector(vertices)
    line_set.lines = o3d.utility.Vector2iVector(lines)
    line_set.colors = o3d.utility.Vector3dVector(
        np.tile(color, (len(lines), 1)).astype(float))
    return line_set

def branch_path(graph, start):
    '''
    Follow the chain of single-successor nodes that begins at `start`.

    The walk advances while the current node has exactly one outgoing edge.
    It stops at the first node with zero or several successors (a tip or a
    split). Reaching an already visited node also ends the walk, so
    self-loops and cycles cannot make it loop forever.

    Params
    -----------
    graph : directed graph exposing `out_degree(node)` and `successors(node)`
    start : node id to start from

    Returns
    -----------
    List of node ids from `start` to the last node of the chain (inclusive).
    '''
    path = [start]
    visited = {start}
    while graph.out_degree(path[-1]) == 1:
        nxt = next(iter(graph.successors(path[-1])))
        if nxt in visited:
            break
        path.append(nxt)
        visited.add(nxt)

    print(path[0], '-->', path[-1], '(nodes:', len(path), ')')
    return path

def get_stem(skeleton, pcd, radius=.75, step=0.05, plot=True):
    '''
    Returns the stem of a point cloud given the skeleton.
    The stem is the part of the tree up to the first branch split.

    Params
    -----------
    skeleton : NetworkX Graph
        The tree skeleton; nodes need 'x', 'y' and 'z' attributes.
    pcd : o3d.geometry.PointCloud
        The point cloud of the tree.
    radius : float
        Points within this distance of the sampled stem path are selected.
    step : float
        Approximate spacing of the samples taken along the stem path.
    plot : bool
        Show the stem (green) and the crown in an Open3D window.

    Returns
    -----------
    (stem_pcd, crown_pcd): open3d point clouds of the stem and of all
    remaining points.

    Raises
    -----------
    ValueError
        If no skeleton node has a 'z' attribute (e.g. an empty skeleton).
    '''
    # get start node: the lowest skeleton node
    zs = nx.get_node_attributes(skeleton, 'z')
    if not zs:
        raise ValueError('get_stem: skeleton has no nodes with coordinates '
                         '(did the reconstruction fail?).')
    start_node = min(zs, key=zs.get)

    # get path till first split; read coordinates by name, not attribute order
    stem_path = branch_path(skeleton, start_node)
    stem_points = np.array([[skeleton.nodes[n]['x'],
                             skeleton.nodes[n]['y'],
                             skeleton.nodes[n]['z']] for n in stem_path])

    # Only points below the top of the stem path can belong to the stem.
    tree_points = np.asarray(pcd.points)
    mask_idx = np.where(tree_points[:, 2] < stem_points[:, 2].max())[0]

    selection = []
    if len(mask_idx) > 0:
        tree = KDTree(tree_points[mask_idx])
        # Sample every segment of the stem path, not just the first one.
        samples = _sample_polyline(stem_points, step)
        hits = set()
        for result in tree.query_ball_point(samples, radius):
            hits.update(result)
        selection = mask_idx[sorted(hits)].tolist()

    stem_pcd = pcd.select_by_index(selection)
    stem_pcd = stem_pcd.paint_uniform_color([.4, .7, 0])
    crown_pcd = pcd.select_by_index(selection, invert=True)

    if plot:  # Visualize
        o3d.visualization.draw_geometries([stem_pcd, crown_pcd])

    return stem_pcd, crown_pcd


def _sample_polyline(points, step):
    '''Sample about every `step` units along the polyline `points`, vertices included.'''
    if len(points) < 2:
        return points
    segments = []
    for a, b in zip(points[:-1], points[1:]):
        # At least 2 samples so both segment end points are always covered.
        num = max(int(np.linalg.norm(b - a) / step), 2)
        segments.append(np.linspace(a, b, num=num))
    return np.vstack(segments)


In [None]:
# Split `pcd` (the down-sampled Cyclo cloud, not leaf-filtered) into stem and crown.
stem_pcd, crown_pcd = get_stem(skeleton, pcd)

In [None]:
# Show the reconstructed skeleton (edges in red, vertices in green).
plot_skeleton(vertices, edges)

#### 4. Stem Analysis

In [None]:
# Fittings

#### 5. Crown Analysis

In [None]:
def pcd_convex_hull(pcd, plot=True):
    '''
    Compute the convex hull of a point cloud as an Open3D triangle mesh.

    The hull volume is logged (labelled as crown volume). With `plot=True`
    the hull surface is also drawn with matplotlib.
    '''
    log.info(f'Convex Hull for pcd with {len(pcd.points)} points.')
    hull_mesh = pcd.compute_convex_hull()[0]
    log.info(f'Done. Crown volume: {hull_mesh.get_volume():.2f}m3')

    if plot:
        ax = plt.figure().add_subplot(projection='3d')
        xs, ys, zs = zip(*hull_mesh.vertices)
        ax.plot_trisurf(xs, ys, zs, triangles=hull_mesh.triangles)
        plt.show()

    return hull_mesh

def alpha_shape(pcd, alpha=.8, plot=True):
    '''
    Wrap a point cloud in an alpha shape and return it as an Open3D mesh.

    The alpha shape is built with `alphashape`, repaired with `pymeshfix` and
    given consistent normals with trimesh. The repaired volume is logged as
    the crown volume. The mesh is optionally plotted with matplotlib, then
    converted to Open3D with vertex normals computed.
    '''
    log.info(f'Alpha Shapes for pcd with {len(pcd.points)} points.')
    t0 = time.time()
    raw = alphashape.alphashape(np.asarray(pcd.points), alpha)
    log.info(f'Done. {time.time()-t0:.2f}s.')

    # Repair broken faces before measuring the volume.
    log.info('Repair broken faces...')
    verts, faces = pymeshfix.clean_from_arrays(raw.vertices, raw.faces)
    repaired = trimesh.base.Trimesh(verts, faces)
    repaired.fix_normals()
    log.info(f'Done. Crown volume: {repaired.volume:.2f}m3')

    if plot:
        ax = plt.figure().add_subplot(projection='3d')
        ax.plot_trisurf(*zip(*repaired.vertices), triangles=repaired.faces)
        plt.show()

    result = repaired.as_open3d
    result.compute_vertex_normals()
    return result

def plot_meshlines(pcds, mesh):
    '''
    Draw geometries together with the wireframe of a triangle mesh.

    Params
    -----------
    pcds : iterable of open3d geometries
        Geometries to show; any iterable (list, tuple, generator) is accepted.
    mesh : o3d.geometry.TriangleMesh
        Mesh whose edges are drawn as red lines.
    '''
    mesh_lines = o3d.geometry.LineSet.create_from_triangle_mesh(mesh)
    mesh_lines.paint_uniform_color((1, 0, 0))
    o3d.visualization.draw_geometries([*pcds, mesh_lines])


In [None]:
voxel_size = 0.3

# Down sample crown
# (voxel grid of size 0.3; the convex-hull and alpha-shape cells below run on
# this reduced cloud — presumably to keep them fast; confirm the size is adequate)
crown_sampled = crown_pcd.voxel_down_sample(voxel_size)

In [None]:
# Convex hull of the down-sampled crown (its volume is logged).
mesh = pcd_convex_hull(crown_sampled)

In [None]:
# Alpha-shape alternative; overwrites `mesh` from the convex-hull cell when run after it.
mesh = alpha_shape(crown_sampled)

In [None]:
# Overlay the wireframe of the most recently computed `mesh` on the stem and crown clouds.
plot_meshlines([stem_pcd, crown_pcd], mesh)

## Miscellaneous

---------

In [None]:
# The stem behaves oddly...

In [None]:
# Show only the skeleton's leaf vertices: nodes without outgoing edges.
# `leaf_nodes` is derived here so this cell runs on its own.
leaf_nodes = [node for node, degree in skeleton.out_degree() if degree == 0]
skim = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(vertices[leaf_nodes]))
o3d.visualization.draw_geometries([skim])

In [None]:
# Keep the internal skeleton nodes (at least one outgoing edge), collect the
# edges among them and draw them on top of the full skeleton.
internal_nodes = [node for node, degree in skeleton.out_degree() if degree]
skeleton_sel = skeleton.subgraph(internal_nodes)
edges_sel = np.array(list(skeleton_sel.edges))
plot_skeleton(vertices, edges, edges_sel)