In [1]:
import open3d as o3d
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchsparse
import hdbscan
import laspy
import time

from EHydro_TreeUnet.tree_unet import UNet
from pathlib import Path

from torch import nn
from torch.cuda import amp
from torchsparse import SparseTensor
from torchsparse.nn import functional as F
from torch.utils.data import random_split, DataLoader
from torchsparse.utils.collate import sparse_collate_fn
from torchsparse.utils.quantize import sparse_quantize
from torchsparse.utils.collate import sparse_collate_fn as _orig_collate

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
class FORinstanceDataset:
    """Map-style dataset over the LAS/LAZ files found under ``./datasets/FORinstance``.

    Every item is a dict with three ``SparseTensor`` objects that share the
    same quantized coordinates:

    * ``"input"``  - features (intensity, return number, number of returns).
    * ``"label"``  - semantic class per voxel, remapped to a contiguous range.
    * ``"offset"`` - vector from each voxel to the centroid of its instance,
      computed on the quantized coordinates.
    """

    def __init__(self, voxel_size: float) -> None:
        """Collect the dataset files.

        Args:
            voxel_size: Voxel size forwarded to ``sparse_quantize`` for every
                point cloud that is loaded.
        """
        self.voxel_size = voxel_size
        self.folder = Path('./datasets/FORinstance')
        self.extensions = ('.laz', '.las')
        # Sorting by file name keeps the index -> file mapping deterministic.
        self.files = sorted(
            [f for f in self.folder.rglob("*") if f.is_file() and f.suffix.lower() in self.extensions],
            key=lambda f: f.name
        )

    def __getitem__(self, idx):
        """Load one sample (integer index) or a list of samples (slice).

        Args:
            idx: A slice, a Python ``int`` or a NumPy integer scalar. Negative
                integers count from the end.

        Returns:
            The sample dict for an integer index, or a list of sample dicts
            for a slice.

        Raises:
            IndexError: If an integer index is out of range.
            TypeError: If ``idx`` is neither a slice nor an integer.
        """
        if isinstance(idx, slice):
            return [self._load_file(self.files[i]) for i in range(*idx.indices(len(self)))]

        # NumPy integer scalars (e.g. produced by np.arange / np.random) are
        # valid indices as well, not only built-in ints.
        if isinstance(idx, (int, np.integer)):
            idx = int(idx)
            if idx < 0:
                idx += len(self)
            if idx < 0 or idx >= len(self):
                raise IndexError("Index out of range")
            return self._load_file(self.files[idx])

        raise TypeError("Index must be a slice or an integer")

    def _load_file(self, path):
        """Read one LAS/LAZ file and turn it into quantized sparse tensors.

        Args:
            path: Path of the file to read with ``laspy``.

        Returns:
            Dict with the keys ``"input"``, ``"label"`` and ``"offset"``.
        """
        las = laspy.read(path)

        # (N, 3) coordinates, translated so the minimum corner is the origin.
        coords = np.vstack((las.x, las.y, las.z)).transpose()
        coords -= np.min(coords, axis=0, keepdims=True)
        feat = np.vstack((las.intensity, las.return_number, las.number_of_returns)).transpose()
        label = np.array(las.classification)
        instance_ids = np.array(las.treeID)

        # Discard points of class 3, then shift the classes above it down by
        # one so the remaining labels form a contiguous range.
        keep = label != 3
        coords = coords[keep]
        feat = feat[keep]
        label = label[keep]
        label = np.where(label > 3, label - 1, label)
        instance_ids = instance_ids[keep]

        # Quantize the coordinates and keep the attributes of the selected points.
        coords, indices = sparse_quantize(coords, self.voxel_size, return_index=True)
        feat = feat[indices]
        label = label[indices]
        instance_ids = instance_ids[indices]

        # Regression target: vector from every voxel to its instance centroid.
        # NOTE(review): every distinct treeID is treated as one instance,
        # including any id the dataset may reserve for non-tree points -
        # confirm that this is intended.
        offset = np.zeros((coords.shape[0], 3), dtype=np.float32)
        for inst_id in np.unique(instance_ids):
            members = instance_ids == inst_id
            tree_points = coords[members]
            offset[members, :] = tree_points.mean(axis=0) - tree_points

        coords = torch.tensor(coords, dtype=torch.int)
        feat = torch.tensor(feat.astype(np.float32), dtype=torch.float)
        label = torch.tensor(label, dtype=torch.long)
        offset = torch.tensor(offset, dtype=torch.float)

        # Distinct local names: avoids shadowing the built-in ``input`` and
        # rebinding ``label`` / ``offset`` to a different kind of object.
        input_tensor = SparseTensor(coords=coords, feats=feat)
        label_tensor = SparseTensor(coords=coords, feats=label)
        offset_tensor = SparseTensor(coords=coords, feats=offset)
        return {"input": input_tensor, "label": label_tensor, "offset": offset_tensor}

    def __len__(self):
        """Return the number of files found in the dataset folder."""
        return len(self.files)

In [3]:
# RGB colour (0-255) per semantic class id; the row index is the class id.
colormap = np.asarray(
    [
        (255, 0, 0),      # class 0 - Unclassified   - red
        (0, 255, 0),      # class 1 - Low-vegetation - green
        (128, 128, 128),  # class 2 - Terrain        - grey
        (255, 165, 0),    # class 3 - Stem           - orange
        (0, 128, 0),      # class 4 - Live-branches  - dark green
        (0, 0, 255),      # class 5 - Woody-branches - blue
    ],
    dtype=np.uint8,
)

def draw_pc(coords, labels, ids):
    """Display a point cloud twice: coloured by semantic class, then by instance.

    Args:
        coords: (N, 3) array of point positions.
        labels: Integer class index per point; used to index ``colormap``.
        ids: 1-D array with one instance id per point.
    """
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(coords)

    # First view: semantic classes, colours rescaled from 0-255 to 0-1.
    pcd.colors = o3d.utility.Vector3dVector(colormap[labels] / 255.0)
    o3d.visualization.draw_geometries([pcd])

    # Second view: one palette entry per distinct id (cycling through the 20
    # colours of the map). ``inverse`` gives, for each point, the position of
    # its id inside ``unique_ids``, so the lookup is a single array indexing
    # instead of a per-point Python loop.
    unique_ids, inverse = np.unique(ids, return_inverse=True)
    cmap = plt.get_cmap("tab20")  # Other options: 'tab10', 'gist_ncar', etc.
    palette = np.array(
        [cmap(i % 20)[:3] for i in range(len(unique_ids))], dtype=float
    ).reshape(-1, 3)
    point_colors = palette[inverse.reshape(-1)]

    pcd.colors = o3d.utility.Vector3dVector(point_colors)
    o3d.visualization.draw_geometries([pcd])


In [4]:
# Select the 'hashmap' kernel-map mode for all torchsparse convolutions.
conv_config = F.conv_config.get_default_conv_config(conv_mode=F.get_conv_mode())
conv_config.kmap_mode = 'hashmap'
F.conv_config.set_global_conv_config(conv_config)

# The dataset provides 3 features per voxel and 6 classes after remapping;
# presumably these are UNet's (in_channels, num_classes) - TODO confirm.
model = UNet(3, 6).to(device='cuda')

dataset = FORinstanceDataset(voxel_size=0.2)

# 90 % / 10 % train / validation split. The generator is seeded so the same
# files end up in each subset after every kernel restart.
train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size
split_generator = torch.Generator().manual_seed(0)
train_dataset, val_dataset = random_split(
    dataset, [train_size, val_size], generator=split_generator
)

train_loader = DataLoader(train_dataset, batch_size=1, collate_fn=sparse_collate_fn)
val_loader = DataLoader(val_dataset, batch_size=1, collate_fn=sparse_collate_fn)

criterion_semantic = nn.CrossEntropyLoss()
criterion_offset = nn.SmoothL1Loss()

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scaler = amp.GradScaler(enabled=True)

# NOTE(review): the saved output of this cell shows a RuntimeError (invalid
# reshape to [27, 64, 192]) right after step 1. The cause is not visible in
# this notebook - TODO investigate.
model.train()  # make the mode explicit in case the model was left in eval mode
for k, feed_dict in enumerate(train_loader):
    inputs = feed_dict["input"].to(device='cuda')
    label = feed_dict["label"].to(device='cuda')
    offset = feed_dict["offset"].to(device='cuda')

    # Forward pass and losses under mixed precision.
    with amp.autocast(enabled=True):
        semantic_output, offset_output = model(inputs)

        loss_semantic = criterion_semantic(semantic_output.feats, label.feats)
        loss_offset = criterion_offset(offset_output.feats, offset.feats)
        # Weighted sum: 10 % semantic loss, 90 % offset regression loss.
        loss = 0.1 * loss_semantic + 0.9 * loss_offset

    print(f"[step {k + 1}] loss = {loss.item()}")

    # Scaled backward pass and optimizer step (GradScaler handles fp16 gradients).
    optimizer.zero_grad()
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

[step 1] loss = 10.772480964660645


RuntimeError: shape '[27, 64, 192]' is invalid for input of size 276480

In [None]:
# enable torchsparse 2.0 inference
model.eval()
# enable fused and locality-aware memory access optimization
torchsparse.backends.benchmark = True  # type: ignore

clusterer = hdbscan.HDBSCAN(min_cluster_size=10)
with torch.no_grad():
    for k, feed_dict in enumerate(val_loader):
        inputs = feed_dict["input"].to(device='cuda')
        label = feed_dict["label"].to(device='cuda')
        offset = feed_dict["offset"].to(device='cuda')

        with amp.autocast(enabled=True):
            # CUDA kernels are launched asynchronously: synchronize before
            # reading the clock on both sides so the printed duration covers
            # the actual GPU work and not only the kernel launches.
            torch.cuda.synchronize()
            start = time.perf_counter()
            semantic_output, offset_output = model(inputs)
            torch.cuda.synchronize()
            print(f'duracion: {(time.perf_counter() - start):.2f} s')

            loss_semantic = criterion_semantic(semantic_output.feats, label.feats)
            loss_offset = criterion_offset(offset_output.feats, offset.feats)
            loss = 0.1 * loss_semantic + 0.9 * loss_offset

        # CPU post-processing. Predictions get their own names so the
        # ground-truth ``label`` / ``offset`` tensors are not rebound to
        # NumPy arrays and the built-in ``id`` is not shadowed.
        coords = semantic_output.coords[:, 1:].cpu().numpy()  # skip column 0 of coords
        semantic = semantic_output.feats.cpu()
        pred_offset = offset_output.feats.cpu().numpy()
        pred_label = torch.argmax(semantic, dim=1).numpy()

        print(pred_offset)

        # Move every voxel by its predicted offset and cluster the shifted
        # positions; each cluster is taken as one instance.
        collapsed_coords = coords + pred_offset
        instance_ids = clusterer.fit_predict(collapsed_coords)

        draw_pc(coords, pred_label, instance_ids)

        print(f"[inference step {k + 1}] loss = {loss.item()}")

[[ 0.7437  0.1825  4.305 ]
 [ 1.154   0.4395  4.074 ]
 [ 0.706  -0.6387  4.75  ]
 ...
 [-0.4534 -0.1288  2.275 ]
 [-0.857  -0.5356  2.559 ]
 [-1.559  -0.4438  3.717 ]]
[inference step 1] loss = 10.033637046813965
[[ 0.2091 -0.1625  0.4011]
 [ 0.2119 -0.1364  0.2485]
 [ 0.405  -0.1638  0.3062]
 ...
 [-0.288  -0.3105  0.578 ]
 [-0.4766 -0.3914  0.9814]
 [-0.7026 -0.37    1.198 ]]
[inference step 2] loss = 9.115703582763672
[[ 0.371  -0.2546  0.6567]
 [ 0.7065 -0.1864  0.9097]
 [ 0.2449 -0.2338  0.3481]
 ...
 [-0.873  -0.4504  1.921 ]
 [-0.893  -0.6753  2.    ]
 [-0.5645 -0.7246  1.573 ]]
[inference step 3] loss = 8.601606369018555
