In [1]:
from typing import Any, Dict, List

import open3d as o3d
import laspy
import numpy as np
import torch
from torch import nn
from torch.cuda import amp
from pathlib import Path

import torchsparse
from EHydro_TreeUnet.tree_unet import UNet
from torchsparse import SparseTensor
from torchsparse.nn import functional as F
from torchsparse.utils.collate import sparse_collate_fn
from torchsparse.utils.quantize import sparse_quantize
from torchsparse.utils.collate import sparse_collate_fn as _orig_collate

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
class FORinstanceDataset:
    def __init__(self, voxel_size: float) -> None:
        self.voxel_size = voxel_size
        self.folder = Path('./datasets/FORinstance')
        self.extensions = ('.laz', '.las')
        self.files = sorted(
            [f for f in self.folder.iterdir() if f.is_file() and f.suffix.lower() in self.extensions],
            key=lambda f: f.name
        )

    def __getitem__(self, idx):
        if isinstance(idx, slice):
            return [self._load_file(self.files[i]) for i in range(*idx.indices(len(self)))]
        elif isinstance(idx, int):
            if idx < 0:
                idx += len(self)
            if idx < 0 or idx >= len(self):
                raise IndexError("Index out of range")
            return self._load_file(self.files[idx])
        else:
            raise TypeError("Index must be a slice or an integer")
    
    def _load_file(self, path):
        las = laspy.read(path)

        coords = np.vstack((las.x, las.y, las.z)).transpose()
        coords -= np.min(coords, axis=0, keepdims=True)
        feats = np.vstack((las.intensity, las.return_number, las.number_of_returns)).transpose()
        class_labels = np.array(las.classification)
        instance_ids = np.array(las.treeID)

        mask = class_labels != 3

        coords = coords[mask]
        feats = feats[mask]
        class_labels = class_labels[mask]
        class_labels = np.where(class_labels > 3, class_labels - 1, class_labels)
        instance_ids = instance_ids[mask]

        #N = coords.shape[0]
        #labels = np.zeros((N, 4), dtype=np.float32)
        #labels[:, 0] = class_labels

        #unique_ids = np.unique(instance_ids)
        #for inst_id in unique_ids:
        #    mask = instance_ids == inst_id
        #    if not np.any(mask):
        #        continue

        #    center = coords[mask].mean(axis=0)
        #    labels[mask, 1:] = center - coords[mask]

        min_coords = coords.min(axis=0)
        max_coords = coords.max(axis=0)
        print(f'tamaño: {max_coords - min_coords}')

        coords, indices = sparse_quantize(coords, self.voxel_size, return_index=True)

        coords = torch.tensor(coords, dtype=torch.int)
        feats = torch.tensor(feats[indices].astype(np.float32), dtype=torch.float)
        labels = torch.tensor(class_labels[indices], dtype=torch.long)

        input = SparseTensor(coords=coords, feats=feats)
        label = SparseTensor(coords=coords, feats=labels)
        return {"input": input, "label": label}

    def __len__(self):
        return len(self.files)

In [3]:
colormap = np.array([
    [255, 0, 0],    # clase 0 - Unclassified - rojo
    [0, 255, 0],    # clase 1 - Low-vegetation - verde
    [128, 128, 128],# clase 2 - Terrain - gris
    [255, 165, 0],  # clase 3 - Stem - naranja
    [0, 128, 0],  # clase 4 - Live-branches - verde oscuro
    [0, 0, 255]     # clase 5 - Woody-branches - azul
], dtype=np.uint8)

def draw_pc(coords, labels):
    pcd = o3d.geometry.PointCloud()

    colors = colormap[labels]
    colors = colors / 255.0
    
    pcd.points = o3d.utility.Vector3dVector(coords)
    pcd.colors = o3d.utility.Vector3dVector(colors)

    o3d.visualization.draw_geometries([pcd])

In [4]:
conv_config = F.conv_config.get_default_conv_config(conv_mode=F.get_conv_mode())
conv_config.kmap_mode = 'hashmap'
F.conv_config.set_global_conv_config(conv_config)

model = UNet(3, 6).to(device='cuda')

dataset = FORinstanceDataset(voxel_size=0.2)
dataflow = torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    collate_fn=sparse_collate_fn,
)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scaler = amp.GradScaler(enabled=True)

for k, feed_dict in enumerate(dataflow):
    inputs = feed_dict["input"].to(device='cuda')
    labels = feed_dict["label"].to(device='cuda')

    with amp.autocast(enabled=True):
        outputs = model(inputs)
        loss = criterion(outputs.feats, labels.feats)

    print(f"[step {k + 1}] loss = {loss.item()}")

    optimizer.zero_grad()
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

tamaño: [27.83 27.83 28.25]
[step 1] loss = 2.2216506004333496
tamaño: [27.83 27.83 36.33]
[step 2] loss = 1.8627995252609253
tamaño: [27.83 27.83 33.01]
[step 3] loss = 1.839050531387329
tamaño: [27.83 27.83 35.2 ]
[step 4] loss = 1.480320692062378
tamaño: [27.83 27.83 35.83]


KeyboardInterrupt: 

In [None]:
# enable torchsparse 2.0 inference
model.eval()
# enable fused and locality-aware memory access optimization
torchsparse.backends.benchmark = True  # type: ignore

with torch.no_grad():
    for k, feed_dict in enumerate(dataflow):
        inputs = feed_dict["input"].to(device='cuda').half()
        labels = feed_dict["label"].to(device='cuda')

        draw_pc(labels.coords[:, 1:].cpu().numpy(), labels.feats.cpu().numpy())
        with amp.autocast(enabled=True):
            outputs = model(inputs)
            loss = criterion(outputs.feats, labels.feats)

            coords = outputs.coords[:, 1:].cpu().numpy()
            features = outputs.feats.cpu()
            labels = torch.argmax(features, dim=1).numpy()

            draw_pc(coords, labels)

        print(f"[inference step {k + 1}] loss = {loss.item()}")

tamaño: [27.83 27.83 28.25]


KeyboardInterrupt: 