# AdTree

This notebook tries to represent a tree in 3D using the Delft AdTree repo.

---------------

##### Imports

In [None]:
# Load the local package rather than the pip-installed version by adding the
# project src to the path (comment out the import to use the installed version).
import set_path

In [None]:
# Import modules.
import os
import math
import time
import trimesh
import shapely
import pymeshfix
import subprocess
import alphashape
import numpy as np
import open3d as o3d
import logging as log
import networkx as nx
import matplotlib.pyplot as plt
from tqdm import tqdm
from scipy.spatial import KDTree
from descartes import PolygonPatch
from shapely.geometry import Polygon
from scipy.spatial import ConvexHull
from plyfile import PlyData, PlyElement
from skimage.measure import CircleModel, ransac
from sklearn.neighbors import NearestNeighbors
import utils.math_utils as math_utils
import utils.plot_utils as plot_utils
import utils.o3d_utils as o3d_utils
from misc.quaternion import Quaternion

In [None]:
import networkx as nx

# Integer codes stored in the per-point `labels` array produced by `Tree`.
LEAF = 0  # default label; also used for every point not selected as wood
WOOD = 1  # points selected by the leaf-wood filter
STEM = 2  # points close to the skeleton path below the first branch split

class Tree:
    """Tree modelling pipeline built around the AdTree (Delft) executable.

    The pipeline (see `model`) optionally classifies points as leaf or wood,
    reconstructs a skeleton of the wood points by calling the external AdTree
    binary, and labels as stem the points near the skeleton path that runs
    from the lowest skeleton node up to the first branch split.

    Attributes:
        exe_path: Path of the AdTree executable.
        leafs_color: RGB colour used to draw leaf points.
        wood_color: RGB colour used to draw wood points.
        stem_color: RGB colour used to draw stem points.
    """

    def __init__(self, adtree_exe):
        """Store the executable path and the display colours.

        Args:
            adtree_exe: Path of the AdTree executable.
        """
        self.exe_path = adtree_exe
        self.leafs_color = [0,.8,0]
        self.wood_color = [.34,.12,0]
        self.stem_color = [1,.2,.2]

    def _lewo(self, pcd, method='curvature'):
        """Leaf-wood classification.

        Args:
            pcd: Open3D point cloud of the tree.
            method: ``'curvature'`` selects `o3d_utils.curvature_filter`;
                any other value selects `o3d_utils.surface_variation_filter`.

        Returns:
            Integer array with one label per point of ``pcd``: ``WOOD`` for
            the selected points and ``LEAF`` for all others.
        """
        labels = np.full(len(pcd.points), LEAF, dtype=int)

        # Down-sample while remembering, per voxel, the indices of the
        # original points, then drop statistical outliers.
        sampled, _, trace = pcd.voxel_down_sample_and_trace(
            0.02, pcd.get_min_bound(), pcd.get_max_bound())
        sampled, kept = sampled.remove_statistical_outlier(
            nb_neighbors=16, std_ratio=2.0)
        kept = np.asarray(kept, dtype=int)

        if method == 'curvature':
            mask = o3d_utils.curvature_filter(sampled, .075, min1=20, min2=35)
        else:
            mask = o3d_utils.surface_variation_filter(sampled, .1, .15)

        # Map the selected voxels back to the indices of the original points.
        traces = [trace[i] for i in kept[mask]]
        if traces:  # np.hstack raises on an empty list
            labels[np.hstack(traces)] = WOOD

        return labels

    def _construct_skeleton(self, cloud):
        """Reconstruct the tree skeleton by running the AdTree executable.

        The cloud is down-sampled, written to a temporary ``.xyz`` file and
        passed to the executable together with an output ``.ply`` path. The
        resulting graph is read with `read_ply_graph`. Failures are printed
        rather than raised; in that case an empty skeleton is returned.

        Args:
            cloud: Open3D point cloud to reconstruct.

        Returns:
            Dict with keys ``'graph'``, ``'vertices'`` and ``'edges'``.
        """
        # Function-scope import keeps this block self-contained.
        import tempfile

        # Defaults returned when reconstruction fails.
        graph = nx.DiGraph()
        vertices = np.empty((0, 3))
        edges = np.empty((0, 2), dtype=int)

        # The temporary directory (and both files) is removed on exit, also
        # when an error occurs.
        with tempfile.TemporaryDirectory() as tmp_folder:
            in_file = os.path.join(tmp_folder, 'tree.xyz')
            out_file = os.path.join(tmp_folder, 'tree_skeleton.ply')

            try:
                cloud_sampled = cloud.voxel_down_sample(0.04)
                if not o3d.io.write_point_cloud(in_file, cloud_sampled):
                    raise IOError("could not write {}".format(in_file))

                result = subprocess.run(
                    [self.exe_path, in_file, out_file],
                    capture_output=True
                )

                # Any stderr output or a non-zero exit code counts as failure.
                if result.returncode != 0 or result.stderr:
                    raise subprocess.CalledProcessError(
                        returncode=result.returncode,
                        cmd=result.args,
                        output=result.stdout,
                        stderr=result.stderr
                    )

                # read output graph
                graph, vertices, edges = read_ply_graph(out_file)

            except subprocess.CalledProcessError as e:
                message = (e.stderr or b'').decode('utf-8', errors='replace')
                print("Failed reconstructing tree:\n{}".format(message))
            except Exception as e:
                print("Failed:\n{}".format(e))

        return {
            'graph': graph,
            'vertices': vertices,
            'edges': edges
        }

    def _stem_crown_split(self, graph, cloud):
        """Select the stem points of a cloud given the tree skeleton.

        The stem is the skeleton path from the lowest node (minimum ``z``)
        up to the first branch split. Points are labelled as stem when they
        lie below the highest stem node and within 0.75 (cloud units) of the
        straight segment between the first and last stem node, which is
        sampled roughly every 0.05.

        Args:
            graph: Directed skeleton graph whose nodes carry ``x``, ``y`` and
                ``z`` attributes.
            cloud: Open3D point cloud of the tree.

        Returns:
            Boolean array with one entry per point of ``cloud``; ``True``
            marks a stem point. All ``False`` when the graph or the cloud is
            empty.
        """
        tree_points = np.asarray(cloud.points)
        labels = np.zeros(len(tree_points), dtype=bool)
        if len(graph.nodes) == 0 or len(tree_points) == 0:
            return labels

        # Start at the lowest skeleton node and follow it to the first split.
        start_node = min(graph.nodes, key=lambda node: graph.nodes[node]['z'])
        stem_path = self.branch_path(graph, start_node)
        stem_points = np.array(
            [[graph.nodes[node][axis] for axis in ('x', 'y', 'z')]
             for node in stem_path],
            dtype=float)

        # Only points below the top of the stem can belong to it.
        mask_idx = np.where(tree_points[:, 2] < stem_points[:, 2].max())[0]
        if mask_idx.size == 0:
            return labels
        kdtree = KDTree(tree_points[mask_idx])

        # Sample the full first-to-last segment; always keep both end points
        # so a short or single-node stem still yields query positions.
        stem_length = np.linalg.norm(stem_points[-1] - stem_points[0])
        num_samples = max(2, int(stem_length / 0.05))
        samples = np.linspace(
            start=stem_points[0], stop=stem_points[-1], num=num_samples)

        selection = set()
        for hits in kdtree.query_ball_point(samples, .75):
            selection.update(hits)
        if selection:
            labels[mask_idx[sorted(selection)]] = True

        return labels

    def branch_path(self, graph, start):
        """Follow single successors from ``start`` until the path ends or splits.

        Args:
            graph: Directed graph providing ``out_degree`` and ``successors``.
            start: Node to start from.

        Returns:
            List of nodes beginning with ``start``; the last node is the
            first one whose out-degree is not exactly one.
        """
        path = [start]
        visited = {start}
        while graph.out_degree(path[-1]) == 1:
            successor = next(iter(graph.successors(path[-1])))
            if successor in visited:
                # Guard against cycles, which would otherwise loop forever.
                break
            visited.add(successor)
            path.append(successor)

        return path

    def show_tree(self, cloud, labels, skeleton=None):
        """Show the point cloud coloured by label, optionally with skeleton.

        Args:
            cloud: Open3D point cloud of the tree.
            labels: Per-point array of ``LEAF``/``WOOD``/``STEM`` codes.
            skeleton: Optional dict with ``'vertices'`` and ``'edges'`` as
                returned by `_construct_skeleton`; skipped when empty.
        """
        # One uniformly coloured sub-cloud per label.
        o3d_geometries = []
        for label, color in (
                (LEAF, self.leafs_color),
                (WOOD, self.wood_color),
                (STEM, self.stem_color)):
            part = cloud.select_by_index(np.where(labels == label)[0])
            part.paint_uniform_color(color)
            o3d_geometries.append(part)

        # Skeleton
        if skeleton and len(skeleton['vertices']) > 0:
            colors = [[0.3, 0.3, 0.3] for _ in skeleton['edges']]
            line_set = o3d.geometry.LineSet()
            line_set.points = o3d.utility.Vector3dVector(skeleton['vertices'])
            line_set.lines = o3d.utility.Vector2iVector(skeleton['edges'])
            line_set.colors = o3d.utility.Vector3dVector(colors)

            skeleton_cloud = o3d.geometry.PointCloud(
                o3d.utility.Vector3dVector(skeleton['vertices']))
            skeleton_cloud = skeleton_cloud.paint_uniform_color([0.1,0.1,0.1])

            o3d_geometries.extend([line_set, skeleton_cloud])

        o3d.visualization.draw_geometries(o3d_geometries)

    def model(self, pcd, lw_classification=None):
        """Create the tree model and display it.

        Args:
            pcd (``o3d.geometry.PointCloud``): Tree point cloud.
            lw_classification: Leaf-wood method passed to `_lewo`
                (e.g. ``'curvature'``). When falsy, every point is treated
                as wood.

        Returns:
            Tuple ``(skeleton, labels)``: the skeleton dict from
            `_construct_skeleton` and the per-point label array.
        """
        # Leaf-Wood classification
        if lw_classification:
            print(f"Leaf-wood classification using `{lw_classification}` method...")
            labels = self._lewo(pcd, method=lw_classification)
            print(f"Done. {np.sum(labels==WOOD)}/{len(labels)} points wood.")
        else:
            labels = np.full(len(pcd.points), WOOD, dtype=int)

        # 3D reconstruct (wood points only)
        print("Reconstructing tree skeleton...")
        wood_cloud = pcd.select_by_index(np.where(labels==WOOD)[0])
        skeleton = self._construct_skeleton(wood_cloud)
        print(f"Done. Skeleton constructed containing {len(skeleton['vertices'])} nodes")

        # Stem-Crown split
        print("Splitting stem form crown...")
        stem_mask = self._stem_crown_split(skeleton['graph'], pcd)
        labels[stem_mask] = STEM
        print(f"Done. {np.sum(stem_mask)}/{len(labels)} points labeled as stem.")

        # Show tree
        self.show_tree(pcd, labels, skeleton)

        return skeleton, labels

def read_ply_graph(ply_file):
    """Read a skeleton graph from a PLY file with vertex and edge elements.

    Duplicate vertices are merged and edge indices are remapped to the
    merged vertex list.

    Args:
        ply_file: Path of the PLY file to read.

    Returns:
        Tuple ``(graph, vertices, edges)``:
        ``graph`` is a ``networkx.DiGraph`` with one node per unique vertex
        (attributes ``x``, ``y``, ``z``) and, for every edge pair ``(i, j)``,
        a directed edge ``j -> i``; ``vertices`` is the array of unique
        vertices; ``edges`` is the array of remapped index pairs (shape
        ``(0, 2)`` when the file has no edges).
    """
    plydata = PlyData.read(ply_file)

    # vertices: merge duplicates, keeping the old -> new index mapping
    vertices = np.array([[c for c in p] for p in plydata['vertex'].data])
    vertices, reverse_ = np.unique(vertices, axis=0, return_inverse=True)
    # The inverse is not 1-D on every NumPy release when `axis` is given.
    reverse_ = np.ravel(reverse_)

    # edges: the first field of each edge record holds the vertex indices
    # (assumed to be a pair, as the unpacking below requires)
    edges = np.array([reverse_[edge[0]] for edge in plydata['edge'].data])
    if edges.size == 0:
        # Keep a consistent (n, 2) integer shape for callers.
        edges = np.empty((0, 2), dtype=int)

    # construct graph
    graph = nx.DiGraph()
    for i, vertex in enumerate(vertices):
        graph.add_node(i, x=vertex[0], y=vertex[1], z=vertex[2])
    for i, j in edges:
        graph.add_edge(int(j), int(i))

    return graph, vertices, edges


In [None]:
# Load point cloud data.
# `source` picks the scan: 'ahn', 'cyclo', or any other value for Sonarski.
source = 'cyclo'

pcd = o3d_utils.read_las({
    'ahn': '../datasets/single_selection/single_121913_487434_AHN.las',
    'cyclo': '../datasets/single_selection/single_121913_487434_Cyclo.las',
}.get(source, '../datasets/single_selection/single_121913_487434_Sonarski.las'))

In [None]:
# Path of the compiled AdTree executable (here inside a macOS .app bundle);
# adjust it to the location of the build on your machine.
adtree_exe = '../../AdTree-single/build/bin/AdTree.app/Contents/MacOS/AdTree'
adTree = Tree(adtree_exe)

In [None]:
# Run the pipeline with curvature-based leaf-wood classification; returns the
# skeleton dict and the per-point labels, and opens the Open3D viewer.
skeleton, labels = adTree.model(pcd, lw_classification='curvature')