In [1]:
import copy
import os
from pathlib import Path

import open3d as o3d
import torch
import laspy
import numpy as np
import matplotlib.pyplot as plt

from dotenv import load_dotenv
from laspy import LasData
from tqdm import tqdm
from torch.cuda import amp
from torchsparse.utils.quantize import sparse_quantize
from torchsparse import SparseTensor

from EHydro_TreeUnet import TreeProjector

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [None]:
load_dotenv()

TREE_PROJECTOR_DIR = Path(os.environ.get('TREE_PROJECTOR_DIR', Path.home() / 'tree_projector'))
VERSION_NAME = 'tree_projector_VS-0.2_DA-48_E-3_V2'
SAVE_CHUNKS = True          # chunkerize() writes every kept XY tile to output_folder
USE_STORED = True           # segment the tiles already in output_folder instead of re-tiling the inputs
SAVE_SEGMENTED = True
SAVE_ALL_SEGMENTS = True    # write each segmented tile to segmented/by_folders
CHUNK_SIZE = 12.5           # XY tile side length, in the input cloud's coordinate units
MIN_POINTS_PER_PC = 0       # tiles with fewer points than this are skipped by chunkerize()

VOXEL_SIZE = 0.2            # voxel edge length passed to sparse_quantize (coordinate units)
# NOTE(review): entries presumably follow TreeProjector's (layers, channels, kernel, stride)
# block convention; the last entry uses per-axis kernel/stride tuples — confirm against the model.
RESNET_BLOCKS = [
    (3, 16, 3, 1),
    (3, 32, 3, 2),
    (3, 64, 3, 2),
    (3, 128, 3, 2),
    (1, 128, (1, 1, 3), (1, 1, 2)),
]
LATENT_DIM = 512  # NOTE(review): not passed to TreeProjector below — confirm it is still needed.
INSTANCE_DENSITY = 0.01
SCORE_THRES = 0.1
CENTROID_THRES = 0.1
DESCRIPTOR_DIM = 64

# Semantic classes predicted by the model; the list index is the class id.
# An earlier 4-class variant inserted 'Low Vegetation' (light green [147, 255, 138])
# between Terrain and Stem; these weights use the 3-class layout below.
class_names = ['Terrain', 'Stem', 'Canopy']
class_colormap = np.array([
    [128, 128, 128],  # class 0 - Terrain - grey
    [255, 165, 0],    # class 1 - Stem - orange
    [0, 128, 0],      # class 2 - Canopy - dark green
], dtype=np.uint8)
assert len(class_colormap) == len(class_names), 'class_colormap needs exactly one RGB row per class'

weights_folder = TREE_PROJECTOR_DIR / 'weights' / VERSION_NAME
process_folder = TREE_PROJECTOR_DIR / 'to_process'

input_folder = process_folder / 'input'              # raw .las/.laz clouds to segment
output_folder = process_folder / 'output'            # XY tiles produced by chunkerize()
segmented_folder = process_folder / 'segmented'
segmented_by_folder = segmented_folder / 'by_folders'

for folder in (process_folder, input_folder, output_folder, segmented_folder, segmented_by_folder):
    folder.mkdir(parents=True, exist_ok=True)

In [None]:
def chunkerize(file: LasData, ext: str):
    """Split a point cloud into square XY tiles of side ``CHUNK_SIZE``.

    Each point is binned by ``floor(x / CHUNK_SIZE), floor(y / CHUNK_SIZE)``.
    Tiles are visited in ascending raveled-bin order. Tiles with fewer than
    ``MIN_POINTS_PER_PC`` points are dropped, but a dropped tile still consumes
    its index ``i``. When ``SAVE_CHUNKS`` is set, each kept tile is written to
    ``output_folder / f'plot_{i}{ext}'``.

    Args:
        file: Source point cloud.
        ext: Extension for the saved tiles, e.g. ``'.las'`` or ``'.laz'``.

    Returns:
        List of ``LasData`` tiles. Each tile owns a private copy of the source
        header, so writing or extending one tile never mutates another's header.
        An empty input cloud yields an empty list.
    """
    points = file.points
    if len(points) == 0:
        return []

    coords = np.vstack((file.x, file.y)).transpose()

    cell = np.floor_divide(coords, CHUNK_SIZE).astype(int)
    cell -= cell.min(axis=0)
    flat_idx = np.ravel_multi_index(cell.T, cell.max(axis=0) + 1)

    # Group point indices per tile with one stable sort instead of one
    # full-array boolean mask per tile. Stability keeps each tile's points
    # in their original file order.
    _, inverse, counts = np.unique(flat_idx, return_inverse=True, return_counts=True)
    order = np.argsort(inverse, kind='stable')
    groups = np.split(order, np.cumsum(counts)[:-1])

    chunks = []
    for i, member_idx in enumerate(tqdm(groups)):
        if len(member_idx) < MIN_POINTS_PER_PC:
            continue

        # A private header copy avoids aliasing. A shared header would be
        # rewritten by every tile's update/write and later by add_extra_dims.
        # Passing points= also skips allocating a zero record of the full
        # file's size for every tile.
        chunk_file = laspy.LasData(header=copy.deepcopy(file.header), points=points[member_idx])
        chunk_file.update_header()
        chunks.append(chunk_file)

        if SAVE_CHUNKS:
            chunk_file.write(output_folder / f'plot_{i}{ext}')

    return chunks


In [None]:
extensions = ('.laz', '.las')
point_clouds = sorted(
    (f for f in input_folder.rglob("*") if f.is_file() and f.suffix.lower() in extensions),
    key=lambda f: f.name,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = TreeProjector(
    in_channels=1,
    num_classes=len(class_names),
    resnet_blocks=RESNET_BLOCKS,
    instance_density=INSTANCE_DENSITY,
    score_thres=SCORE_THRES,
    centroid_thres=CENTROID_THRES,
    descriptor_dim=DESCRIPTOR_DIM
)
# map_location lets weights saved from a GPU run load on a CPU-only machine.
model.load_state_dict(torch.load(weights_folder / f'{VERSION_NAME}_weights.pth', map_location=device))
model.to(device)
model.eval()

# NOTE(review): unused in this cell; presumably kept for visualisation elsewhere.
pcd = o3d.geometry.PointCloud()


def _normalize_intensity(intensity) -> np.ndarray:
    """Min-max scale intensities to [0, 1] as a float32 column of shape (N, 1).

    A chunk whose intensities are all equal maps to zeros instead of the
    NaNs that 0/0 would produce.
    """
    values = np.asarray(intensity, dtype=np.float64).reshape(-1, 1)
    low = values.min()
    span = values.max() - low
    if span == 0:
        return np.zeros(values.shape, dtype=np.float32)
    return ((values - low) / span).astype(np.float32)


def _remap_instances(semantic_voxel: np.ndarray, instance_voxel: np.ndarray, instance_offset: int):
    """Scatter foreground instance ids into a per-voxel label array.

    Voxels predicted as class 0 get instance 0. Every other voxel gets its
    instance id plus ``instance_offset``, so ids from different chunks of the
    same source do not collide.

    NOTE(review): as in the original code, this assumes the instance head
    returns exactly one row per non-class-0 voxel — confirm against TreeProjector.

    Returns:
        ``(labels, next_offset)``. ``next_offset`` equals ``instance_offset``
        unchanged when no instances were predicted.
    """
    labels = np.zeros_like(semantic_voxel)
    labels[semantic_voxel != 0] = instance_voxel + instance_offset
    if instance_voxel.size == 0:
        return labels, instance_offset
    return labels, instance_offset + int(instance_voxel.max()) + 1


def _segment_chunk(chunk: LasData, instance_offset: int):
    """Run the model on one chunk and return ``(segmented_chunk, next_offset)``.

    The returned LasData copies the chunk's points and adds two extra dims:
    ``semantic_pred`` (int16) and ``instance_pred`` (int32). It uses its own
    header copy, so the input chunk is not modified.
    """
    coords = np.vstack((chunk.x, chunk.y, chunk.z)).transpose()
    coords -= coords.min(axis=0)

    i_norm = _normalize_intensity(chunk.intensity)

    voxels, indices, inverse_map = sparse_quantize(coords, VOXEL_SIZE, return_index=True, return_inverse=True)

    voxels = torch.tensor(voxels, dtype=torch.int, device=device)
    batch_index = torch.zeros((voxels.shape[0], 1), dtype=torch.int, device=device)
    voxels = torch.cat([batch_index, voxels], dim=1)
    # sparse_quantize keeps one representative point per voxel. `indices`
    # selects those points so the feature rows line up with the voxel rows.
    feats = torch.tensor(i_norm[indices], dtype=torch.float, device=device)

    inputs = SparseTensor(coords=voxels, feats=feats)
    with torch.no_grad(), amp.autocast(enabled=device.type == 'cuda'):
        semantic_output, instance_output = model(inputs)

    semantic_voxel = torch.argmax(semantic_output.F.cpu(), dim=1).numpy()
    instance_voxel = torch.argmax(instance_output.F.cpu(), dim=1).numpy()
    instance_full, next_offset = _remap_instances(semantic_voxel, instance_voxel, instance_offset)

    out_file = laspy.LasData(header=copy.deepcopy(chunk.header), points=chunk.points.copy())
    out_file.add_extra_dims([
        laspy.ExtraBytesParams(name="semantic_pred", type=np.int16),
        laspy.ExtraBytesParams(name="instance_pred", type=np.int32),
    ])
    # inverse_map maps every original point back to its voxel.
    out_file.semantic_pred = semantic_voxel[inverse_map]
    out_file.instance_pred = instance_full[inverse_map]
    return out_file, next_offset


def _chunk_sets():
    """Yield ``(name, chunks)`` groups to segment.

    With USE_STORED, the tiles already in output_folder are read once, in name
    order. Otherwise each input cloud is tiled by chunkerize().
    """
    if USE_STORED:
        stored = sorted(
            (f for f in output_folder.rglob("*") if f.is_file() and f.suffix.lower() in extensions),
            key=lambda f: f.name,
        )
        yield 'stored chunks', [laspy.read(f) for f in stored]
        return
    for point_cloud in point_clouds:
        yield point_cloud.name, chunkerize(laspy.read(point_cloud), point_cloud.suffix.lower())


# NOTE(review): output names (plot_{i}.las, combined.las) are not keyed by source,
# so with several input clouds each source overwrites the previous one's outputs.
for source_name, chunks in _chunk_sets():
    if not chunks:
        print(f'No chunks to segment for {source_name}; skipping.')
        continue

    segmented_chunks = []
    instance_offset = 1
    for i, chunk in enumerate(tqdm(chunks, desc=source_name)):
        if len(chunk.points) == 0:
            continue
        out_file, instance_offset = _segment_chunk(chunk, instance_offset)
        segmented_chunks.append(out_file)
        if SAVE_ALL_SEGMENTS:
            out_file.write(segmented_by_folder / f'plot_{i}.las')

    if SAVE_SEGMENTED and segmented_chunks:
        # The header must come from a segmented chunk so that it declares the
        # semantic_pred / instance_pred extra dims.
        # NOTE(review): points are written as-is, which presumably requires all
        # chunks to share scales/offsets (true for tiles cut from one cloud).
        combined_header = copy.deepcopy(segmented_chunks[0].header)
        with laspy.open(segmented_folder / 'combined.las', mode='w', header=combined_header) as writer:
            for segmented in segmented_chunks:
                writer.write_points(segmented.points)


100%|██████████| 106/106 [00:29<00:00,  3.63it/s]
