In [1]:
import os
import open3d as o3d
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from pathlib import Path
from dotenv import load_dotenv

from EHydro_TreeUnet.trainers import TreeProjectorTrainer
from torchsparse.nn import functional as F

# Configure torchsparse's kernel-map mode globally; this runs before the trainer (and its
# sparse network) is constructed further down. NOTE(review): the mode name suggests kernel
# maps are built with a hash map on the fly rather than cached — confirm against the
# installed torchsparse version.
F.set_kmap_mode("hashmap_on_the_fly")

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
# Pull overrides (e.g. TREE_PROJECTOR_DIR) from a local .env file into the environment.
load_dotenv()

# --- Paths / run identity ---------------------------------------------------
# Root directory for the tree projector; falls back to ~/tree_projector when unset.
TREE_PROJECTOR_DIR = Path(os.getenv('TREE_PROJECTOR_DIR', Path.home() / 'tree_projector'))
DATASET_FOLDER = 'MixedDataset'
# NOTE(review): the suffix appears to restate VOXEL_SIZE (VS-0.2), DATA_AUGMENTATION_COEF
# (DA-48) and EPOCHS (E-3) by hand — keep it in sync when those values change.
VERSION_NAME = 'tree_projector_instance_VS-0.2_DA-48_E-3'

# --- Data preparation / augmentation ----------------------------------------
VOXEL_SIZE: float = 0.2             # voxel edge in meters (trainer log reports a 0.2 m base scale)
FEAT_KEYS: list = ['intensity']     # point attributes used as input features
CENTROID_SIGMA: float = 1.5
TRAIN_PCT: float = 0.9              # presumably the training share of the data split
DATA_AUGMENTATION_COEF: int = 48
YAW_RANGE: tuple = (0.0, 360.0)     # presumably degrees
TILT_RANGE: tuple = (-5.0, 5.0)     # presumably degrees
SCALE_RANGE: tuple = (0.9, 1.2)     # presumably multiplicative scale factors

# --- Optimisation -----------------------------------------------------------
TRAINING: bool = True               # gates the trainer.train() call in the run cell
EPOCHS: int = 3
START_ON_EPOCH: int = 0
BATCH_SIZE: int = 1
# Per-term loss coefficients.
SEMANTIC_LOSS_COEF: float = 1.0
CENTROID_LOSS_COEF: float = 1.0
INSTANCE_LOSS_COEF: float = 1.0

# --- Model architecture -----------------------------------------------------
# One tuple per ResNet stage. Judging by the trainer's startup log, field 2 is the number
# of output features and field 4 the stride w.r.t. the previous stage; fields 1 and 3 look
# like block count and kernel size (NOTE(review): confirm).
RESNET_BLOCKS = [
    (3, 16, 3, 1),                    # 0.2 m         -> 16 feats
    (3, 32, 3, 2),                    # 0.4 m         -> 32 feats
    (3, 64, 3, 2),                    # 0.8 m         -> 64 feats
    (3, 128, 3, 2),                   # 1.6 m         -> 128 feats
    (1, 128, (1, 1, 3), (1, 1, 2)),   # 1.6x1.6x3.2 m -> 128 feats (downsamples z only)
]
LATENT_DIM: int = 512               # backbone channels (16+32+64+128+128 = 368) map to this size
INSTANCE_DENSITY: float = 0.01
CENTROID_THRES: float = 0.1
DESCRIPTOR_DIM: int = 64

# Class names left out of the per-class IoU chart and the summary table.
CHARTS_IGNORE_CLASS: list = []

In [3]:
def smooth(arr: np.ndarray, window: int) -> np.ndarray:
    """Centered moving average along the first axis (column by column for 2-D input).

    Each output sample is the mean of the input samples that fall inside a window of
    ``window`` steps centred on it. Near the edges only the samples that exist are
    averaged, so the ends are not pulled toward zero. The output always has the same
    shape as ``arr``, even when ``window`` is longer than the series.

    Args:
        arr: 1-D series, or an array of shape (steps, ...) smoothed independently
            along axis 0 (e.g. one column per class).
        window: Window length in steps. Values <= 1 disable smoothing.

    Returns:
        ``arr`` itself when ``window <= 1``; otherwise a new float array with
        ``arr``'s shape.
    """
    if window <= 1:
        return arr

    arr = np.asarray(arr)
    if arr.size == 0:
        # np.convolve rejects empty inputs; an empty series smooths to an empty series.
        return np.zeros(arr.shape, dtype=float)

    n_steps = arr.shape[0]
    kernel = np.ones(window, dtype=float)
    # Slice the full convolution instead of using mode="same": "same" yields
    # max(len(arr), window) samples, which breaks the shape when window > len(arr).
    # For window <= len(arr) this slice is exactly what mode="same" returns.
    start = (window - 1) // 2
    # Number of real samples under each window position (the edge-aware denominator).
    counts = np.convolve(np.ones(n_steps), kernel, mode="full")[start:start + n_steps]

    def _average(series: np.ndarray) -> np.ndarray:
        # Moving sum over the same centred window, divided by how many samples it covered.
        return np.convolve(series, kernel, mode="full")[start:start + n_steps] / counts

    # For 1-D input this calls _average once; for 2-D it smooths each column separately.
    return np.apply_along_axis(_average, 0, arr)
        
def _plot_curve(curve, label: str, ylabel: str, title: str, ylim=None) -> None:
    """Draw ``curve`` against the step index in its own figure and show it."""
    plt.figure(figsize=(10, 5))
    plt.plot(curve, label=label)
    plt.xlabel("Step"); plt.ylabel(ylabel)
    plt.title(title)
    if ylim is not None:
        plt.ylim(*ylim)
    plt.legend(); plt.grid(True); plt.show()


def gen_charts(trainer, losses, stats, training: bool, window: int = 1, ignore_class=()):
    """Plot per-step loss/metric curves and display a per-class summary table.

    Args:
        trainer: Object exposing ``dataset.num_classes`` and ``dataset.class_names``.
        losses: Per-step loss rows; columns 0-3 are plotted as total, semantic,
            centroid and instance loss.
        stats: Per-step dicts of metric values, each value having a ``.cpu()`` method
            (presumably torch tensors). Must contain the per-class ``iou_semantic``,
            ``precision_semantic``, ``recall_semantic``, ``f1_semantic`` and the scalar
            ``mean_iou_semantic`` and ``mean_iou_instance``.
        training: Only selects the "Training"/"Inference" wording of titles and labels.
        window: Moving-average window passed to ``smooth``; 1 disables smoothing.
        ignore_class: Class names left out of the per-class IoU chart and the table.
    """
    if len(losses) == 0 or len(stats) == 0:
        # e.g. a run interrupted before its first step: nothing to plot, and
        # stats[0] below would raise IndexError.
        print("gen_charts: no recorded steps, nothing to plot.")
        return

    phase = 'Training' if training else 'Inference'
    class_names = trainer.dataset.class_names
    kept_classes = [c for c in range(trainer.dataset.num_classes) if class_names[c] not in ignore_class]

    # One array per metric key, with one row per step.
    stats_by_key = {key: np.array([step[key].cpu() for step in stats]) for key in stats[0].keys()}

    # --- Loss curves (clipped to [0, 10] for display only) ---------------------
    loss = np.asarray(losses)
    loss_label = f"{phase} Loss (MA{window})"
    loss_titles = (
        f"Loss evolution during {phase}",
        f"Semantic loss evolution during {phase}",
        f"Centroid loss evolution during {phase}",
        f"Instance loss evolution during {phase}",
    )
    for column, title in enumerate(loss_titles):
        curve = np.clip(smooth(loss[:, column], window), 0.0, 10.0)
        _plot_curve(curve, loss_label, "Loss", title)

    # --- Semantic IoU, one line per kept class ---------------------------------
    iou_semantic = smooth(stats_by_key['iou_semantic'], window)
    plt.figure(figsize=(10, 5))
    for c in kept_classes:
        plt.plot(iou_semantic[:, c], label=class_names[c])
    plt.xlabel("Step"); plt.ylabel("IoU")
    plt.title(f"Semantic IoU evolution during {phase} (MA{window})")
    plt.ylim(0, 1)
    plt.legend(); plt.grid(True); plt.show()

    # --- Mean IoU curves --------------------------------------------------------
    _plot_curve(smooth(stats_by_key['mean_iou_semantic'], window),
                f"{phase} mIoU (MA{window})", "mIoU",
                f"Semantic mIoU evolution during {phase}", ylim=(0, 1))
    _plot_curve(smooth(stats_by_key['mean_iou_instance'], window),
                f"{phase} Instance mIoU (MA{window})", "mIoU",
                f"Instance mIoU evolution during {phase}", ylim=(0, 1))

    # --- Summary table: unsmoothed per-class averages over all steps ------------
    metric_keys = ['iou_semantic', 'precision_semantic', 'recall_semantic', 'f1_semantic']
    rows = [[stats_by_key[key][:, c].mean() for key in metric_keys] for c in kept_classes]
    row_names = [class_names[c] for c in kept_classes]
    if rows:
        # Unweighted mean over the kept classes only.
        rows.append(list(np.array(rows).mean(axis=0)))
        row_names.append('Mean')

    df = pd.DataFrame(rows, columns=['IoU', 'Precision', 'Recall', 'F1'], index=row_names)
    display(df)

In [4]:
# Set to True to inspect every evaluated cloud in Open3D (seven viewer windows per cloud).
# Left off by default so this cell only runs the evaluation pass and the charts.
SHOW_EVAL_CLOUDS = False
# Radius of the sphere drawn at each centroid, in the same units as the coordinates
# yielded by trainer.eval() (NOTE(review): confirm whether those are voxel indices or meters).
CENTROID_MARKER_RADIUS = 1.5


def _random_id_colors(ids: np.ndarray, seed: int = 0) -> np.ndarray:
    """Give each distinct id in the 1-D array ``ids`` a reproducible random RGB colour in [0, 1)."""
    unique_ids, inverse = np.unique(ids, return_inverse=True)
    palette = np.random.default_rng(seed).random((len(unique_ids), 3))
    # inverse[i] is the position of ids[i] in unique_ids: a vectorised id -> colour lookup.
    return palette[inverse.reshape(-1)]


def _draw_with_colors(pcd, colors, extra_geometries=()) -> None:
    """Recolour ``pcd`` in place and open one Open3D viewer with it plus any extra meshes."""
    pcd.colors = o3d.utility.Vector3dVector(colors)
    o3d.visualization.draw_geometries([pcd, *extra_geometries])


def _show_eval_cloud(pcd, class_colormap, cmap, cloud_voxels,
                     semantic_output, semantic_labels,
                     centroid_score_output, centroid_score_labels,
                     instance_output, instance_labels,
                     centroid_voxels, centroid_confidence_output) -> None:
    """Open label-vs-prediction viewers for one cloud: semantics, centroid scores, centroids, instances."""
    pcd.points = o3d.utility.Vector3dVector(cloud_voxels)
    semantic_label_colors = class_colormap[semantic_labels] / 255.0

    _draw_with_colors(pcd, semantic_label_colors)
    _draw_with_colors(pcd, class_colormap[semantic_output] / 255.0)
    _draw_with_colors(pcd, cmap(centroid_score_labels[:, 0])[:, :3])
    _draw_with_colors(pcd, cmap(centroid_score_output[:, 0])[:, :3])

    # Centroid spheres coloured by their confidence output, over the ground-truth semantics.
    spheres = []
    for center, confidence_row in zip(centroid_voxels, centroid_confidence_output):
        sphere = o3d.geometry.TriangleMesh.create_sphere(radius=CENTROID_MARKER_RADIUS)
        sphere.translate(center)
        sphere.paint_uniform_color(cmap(confidence_row[0])[:3])
        spheres.append(sphere)
    _draw_with_colors(pcd, semantic_label_colors, spheres)

    _draw_with_colors(pcd, _random_id_colors(instance_labels))
    _draw_with_colors(pcd, _random_id_colors(instance_output))


trainer = TreeProjectorTrainer(
    tree_projector_dir=TREE_PROJECTOR_DIR,
    dataset_folder=DATASET_FOLDER,
    version_name=VERSION_NAME,

    voxel_size=VOXEL_SIZE,
    feat_keys=FEAT_KEYS,
    centroid_sigma=CENTROID_SIGMA,
    train_pct=TRAIN_PCT,
    data_augmentation_coef=DATA_AUGMENTATION_COEF,
    yaw_range=YAW_RANGE,
    tilt_range=TILT_RANGE,
    scale_range=SCALE_RANGE,

    training=TRAINING,
    epochs=EPOCHS,
    start_on_epoch=START_ON_EPOCH,
    batch_size=BATCH_SIZE,
    semantic_loss_coef=SEMANTIC_LOSS_COEF,
    centroid_loss_coef=CENTROID_LOSS_COEF,
    instance_loss_coef=INSTANCE_LOSS_COEF,

    resnet_blocks=RESNET_BLOCKS,
    latent_dim=LATENT_DIM,
    instance_density=INSTANCE_DENSITY,
    centroid_thres=CENTROID_THRES,
    descriptor_dim=DESCRIPTOR_DIM
)

if TRAINING:
    trainer.train()
    stats = trainer.stats
    losses = trainer.losses
    gen_charts(trainer=trainer, losses=losses, stats=stats, training=True, window=1, ignore_class=CHARTS_IGNORE_CLASS)

pcd = o3d.geometry.PointCloud()
cmap = plt.get_cmap('viridis')
for (voxels, semantic_output, semantic_labels, centroid_score_output, centroid_score_labels,
     instance_output, instance_labels, centroid_voxels, centroid_confidence_output) in trainer.eval():
    # The evaluation generator is always fully consumed; only the viewers are optional.
    if not SHOW_EVAL_CLOUDS:
        continue

    # Column 0 of both coordinate arrays is the batch index; the rest are coordinates.
    batch_idx, coords = voxels[:, 0], voxels[:, 1:]
    centroid_batch_idx, centroid_coords = centroid_voxels[:, 0], centroid_voxels[:, 1:]

    for idx in np.unique(batch_idx):
        point_mask = batch_idx == idx
        centroid_mask = centroid_batch_idx == idx
        _show_eval_cloud(
            pcd, trainer.dataset.class_colormap, cmap,
            coords[point_mask],
            semantic_output[point_mask], semantic_labels[point_mask],
            centroid_score_output[point_mask], centroid_score_labels[point_mask],
            instance_output[point_mask], instance_labels[point_mask],
            centroid_coords[centroid_mask], centroid_confidence_output[centroid_mask],
        )

# NOTE(review): this assumes trainer.eval() replaces trainer.losses / trainer.stats with
# inference values; otherwise these "Inference" charts re-plot the training history.
gen_charts(trainer=trainer, losses=trainer.losses, stats=trainer.stats, training=False, window=1, ignore_class=CHARTS_IGNORE_CLASS)


Parámetros totales: 34,512,431
Parámetros entrenables: 34,512,431
Resnet generates features at the following scales:
	* (0.2, 0.2, 0.2) meters -> 16 feats.
	* (0.4, 0.4, 0.4) meters -> 32 feats.
	* (0.8, 0.8, 0.8) meters -> 64 feats.
	* (1.6, 1.6, 1.6) meters -> 128 feats.
	* (1.6, 1.6, 3.2) meters -> 128 feats.

Minimum scene size: (4.8, 4.8, 9.6) meters
Total channels in backbone: 368 -> 512 in latent space.
Version name: tree_projector_instance_VS-0.2_DA-48_E-3

=== Starting epoch 1 ===
Training instance correlation with labels instead of predictions by now...

[Train]:   1%|          | 44/5520 [02:45<5:43:51,  3.77s/it, loss=5.3105, Sem mIoU=0.2243, centroid loss=3.2532, Inst mIoU=0.1162, centroids found=9 (6) / 9]    


KeyboardInterrupt: 