In [7]:
import open3d as o3d
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchsparse
import laspy
import pandas as pd

from tqdm import tqdm
from plyfile import PlyData

from EHydro_TreeUnet.tree_projector import TreeProjector
from EHydro_TreeUnet.tree_unet import UNet
from pathlib import Path

from torch import nn
from torch.nn import functional as tF
from torch.cuda import amp
from torchsparse import SparseTensor
from torchsparse.nn import functional as F
from torch.utils.data import random_split, DataLoader
from torchsparse.utils.collate import sparse_collate_fn
from torchsparse.utils.quantize import sparse_quantize
from scipy.optimize import linear_sum_assignment

In [17]:
# --- Run mode ----------------------------------------------------------------
# True: the trainer is trained and its weights saved; False: saved weights are loaded.
TRAINING = True

# --- Model -------------------------------------------------------------------
CHANNELS = [16 * 2 ** level for level in range(4)]  # -> [16, 32, 64, 128], passed as `channels=`
LATENT_DIM = 512
MAX_INSTANCES = 64

# --- Data --------------------------------------------------------------------
VOXEL_SIZE = 0.1              # voxel edge, in the same units as the point coordinates
TRAIN_PCT = 0.8               # fraction of files assigned to the training split
DATA_AUGMENTATION_COEF = 2.0  # virtual training length = files * coef (extra items are augmented)
BATCH_SIZE = 1

# --- Loss weights ------------------------------------------------------------
SEMANTIC_LOSS_COEF = 1.0
INSTANCE_LOSS_COEF = 0.0      # instance loss is currently disabled

In [9]:
class Dataset:
    """Voxelized LAS/LAZ point-cloud dataset with optional on-the-fly augmentation.

    The virtual length is ``int(len(files) * data_augmentation)``. Indices below
    ``len(files)`` return the original clouds; higher indices wrap around onto the
    file list (``idx % len(files)``) and receive a random yaw/tilt rotation and scale.

    Args:
        files: sequence of ``pathlib.Path`` objects pointing at ``.las``/``.laz`` files.
        voxel_size: voxel edge length passed to ``sparse_quantize``.
        data_augmentation: multiplier for the virtual dataset length.
        yaw_range: (min, max) yaw angle in degrees.
        tilt_range: (min, max) pitch and roll angles in degrees.
        scale: (min, max) uniform scale factor.
        seed: optional seed for the augmentation RNG (``None`` = non-deterministic).
    """

    def __init__(self, files, voxel_size: float, data_augmentation: float = 1.0, yaw_range = (0, 360), tilt_range = (-5, 5), scale = (0.9, 1.1), seed = None) -> None:
        self._rng = np.random.default_rng(seed)
        self._files = files

        self._voxel_size = voxel_size
        self._len = int(len(self._files) * data_augmentation)

        self._yaw_range = yaw_range
        self._tilt_range = tilt_range
        self._scale = scale

    def __getitem__(self, idx):
        """Return one preprocessed sample, or a list of them for a slice.

        Raises:
            IndexError: if an integer index is out of range.
            TypeError: if ``idx`` is neither a slice nor an integer.
        """
        if isinstance(idx, slice):
            return [self._preprocess(i) for i in range(*idx.indices(len(self)))]
        # numpy integers are accepted too (e.g. indices drawn from numpy samplers).
        if isinstance(idx, (int, np.integer)):
            idx = int(idx)
            if idx < 0:
                idx += len(self)
            if idx < 0 or idx >= len(self):
                raise IndexError("Index out of range")
            return self._preprocess(idx)
        raise TypeError("Index must be a slice or an integer")

    def __len__(self):
        return self._len

    def _load_file(self, path):
        """Read one LAS/LAZ cloud.

        Returns:
            coords: (N, 3) array of x, y, z.
            feats: (N, 2) array: intensity / 65535 and z min-max scaled to [0, 1].
            semantic_labels: per-point ``classification`` values.
            instance_labels: per-point ``treeID`` values.

        Raises:
            ValueError: if the file suffix is not ``.las``/``.laz``.
        """
        ext = path.suffix.lower()

        # A real tuple: the former string literal '.las, .laz' turned ``in`` into a
        # substring test, so e.g. an empty suffix or '.l' was handed to laspy.
        if ext not in ('.las', '.laz'):
            raise ValueError(f'Unsupported file extension: {ext}!')

        las = laspy.read(path)

        z = np.asarray(las.z)
        coords = np.vstack((las.x, las.y, z)).transpose()

        intensity = np.array(las.intensity) / 65535
        min_z = np.min(z)
        # Guard perfectly flat clouds, which would otherwise divide by zero.
        z_range = max(np.max(z) - min_z, np.finfo(float).eps)
        feats = np.column_stack((intensity, (z - min_z) / z_range))

        semantic_labels = np.array(las.classification)
        instance_labels = np.array(las.treeID)

        return coords, feats, semantic_labels, instance_labels

    def _agument_data(self, coords):
        """Apply a random rotation (yaw, pitch, roll drawn from the configured ranges)
        followed by a random uniform scale to (N, 3) ``coords``."""
        yaw = np.deg2rad(self._rng.uniform(*self._yaw_range))
        pitch = np.deg2rad(self._rng.uniform(*self._tilt_range))
        roll = np.deg2rad(self._rng.uniform(*self._tilt_range))
        scale = self._rng.uniform(*self._scale)

        cy, sy = np.cos(yaw), np.sin(yaw)
        cp, sp = np.cos(pitch), np.sin(pitch)
        cr, sr = np.cos(roll), np.sin(roll)

        # R = Rz(yaw) @ Ry(pitch) @ Rx(roll)
        rotation_mtx = np.array([[cy*cp,  cy*sp*sr - sy*cr,  cy*sp*cr + sy*sr],
                                 [sy*cp,  sy*sp*sr + cy*cr,  sy*sp*cr - cy*sr],
                                 [ -sp ,            cp*sr ,            cp*cr ]],
                                dtype=coords.dtype)

        # Row-vector points, so multiply by R transposed.
        return (coords @ rotation_mtx.T) * scale

    def _preprocess(self, idx: int):
        """Load, optionally augment, shift to a non-negative origin and voxelize one cloud.

        Returns:
            dict with ``inputs`` (SparseTensor of voxel coords and features),
            ``semantic_labels`` / ``instance_labels`` (SparseTensors on the same voxels,
            taken from the point ``sparse_quantize`` selected for each voxel),
            ``coords`` (the shifted point coordinates) and ``inverse_map``
            (point -> voxel row, used to project voxel predictions back to points).
        """
        coords, feat, semantic_labels, instance_labels = self._load_file(self._files[idx % len(self._files)])
        if idx >= len(self._files):
            coords = self._agument_data(coords)

        coords -= np.min(coords, axis=0, keepdims=True)

        voxels, indices, inverse_map = sparse_quantize(coords, self._voxel_size, return_index=True, return_inverse=True)
        feat = feat[indices]
        semantic_labels = semantic_labels[indices]
        instance_labels = instance_labels[indices]

        voxels = torch.tensor(voxels, dtype=torch.int)
        feat = torch.tensor(feat.astype(np.float32), dtype=torch.float)
        semantic_labels = torch.tensor(semantic_labels, dtype=torch.long)
        instance_labels = torch.tensor(instance_labels, dtype=torch.long)

        inputs = SparseTensor(coords=voxels, feats=feat)
        semantic_labels = SparseTensor(coords=voxels, feats=semantic_labels)
        instance_labels = SparseTensor(coords=voxels, feats=instance_labels)

        return {"inputs": inputs, "semantic_labels": semantic_labels, "instance_labels": instance_labels, "coords": coords, "inverse_map": inverse_map}


class MixedDataset:
    """LAS/LAZ tree dataset split by file into training and validation subsets.

    Files are collected recursively from ``folder`` and sorted by file name. The first
    ``train_pct`` fraction becomes the (augmented) training set; the remainder is the
    validation set, which is built without augmentation.

    Args:
        voxel_size: voxel edge length forwarded to both subsets.
        data_augmentation, yaw_range, tilt_range, scale: forwarded to the training
            subset only.
        folder: root directory searched for ``.las``/``.laz`` files.
        train_pct: fraction of files used for training; ``None`` uses the notebook
            constant ``TRAIN_PCT``.

    Raises:
        FileNotFoundError: if ``folder`` is not an existing directory.
        ValueError: if ``train_pct`` is outside [0, 1].
    """

    def __init__(self, voxel_size: float, data_augmentation: float = 1.0, yaw_range = (0, 360), tilt_range = (-5, 5), scale = (0.9, 1.1), folder = './datasets/MixedDataset', train_pct = None) -> None:
        self._folder = Path(folder)
        # Fail early: a missing folder used to yield an empty dataset that only
        # broke much later (e.g. when plotting metrics of zero steps).
        if not self._folder.is_dir():
            raise FileNotFoundError(f'Dataset folder not found: {self._folder}')

        self._extensions = ('.laz', '.las')

        self._feat_channels = 2
        self._num_classes = 4
        self._class_names = ['Terrain', 'Low Vegetation', 'Stem', 'Canopy']
        self._class_colormap = np.array([
            [128, 128, 128],  # class 0 - Terrain - grey
            [147, 255, 138],  # class 1 - Low vegetation - light green
            [255, 165, 0],    # class 2 - Stem - orange
            [0, 128, 0],      # class 3 - Canopy - dark green
        ], dtype=np.uint8)

        if train_pct is None:
            train_pct = TRAIN_PCT
        if not 0.0 <= train_pct <= 1.0:
            raise ValueError(f'train_pct must be in [0, 1], got {train_pct}')

        files = sorted(
            [f for f in self._folder.rglob("*") if f.is_file() and f.suffix.lower() in self._extensions],
            key=lambda f: f.name
        )

        train_idx = int(train_pct * len(files))
        self._train_dataset = Dataset(files[:train_idx], voxel_size, data_augmentation, yaw_range, tilt_range, scale)
        self._val_dataset = Dataset(files[train_idx:], voxel_size)

    @property
    def feat_channels(self):
        return self._feat_channels

    @property
    def num_classes(self):
        return self._num_classes

    @property
    def class_names(self):
        return self._class_names

    @property
    def class_colormap(self):
        return self._class_colormap

    @property
    def train_dataset(self):
        return self._train_dataset

    @property
    def val_dataset(self):
        return self._val_dataset
    

In [10]:
class ScanNet:
    """ScanNet scenes stored as ``*.labels.ply`` files, served as voxelized sparse samples.

    The virtual length is ``int(len(files) * data_augmentation)``; indices at or past
    ``len(files)`` wrap around onto the file list and receive a random rotation/scale.

    NOTE(review): unlike ``MixedDataset`` this class exposes no ``train_dataset`` /
    ``val_dataset`` split, so code expecting those attributes cannot use it as-is.
    """

    def __init__(self, voxel_size: float, data_augmentation: float = 1.0, yaw_range = (0, 360), tilt_range = (-5, 5), scale = (0.9, 1.1), seed = None) -> None:
        self._rng = np.random.default_rng(seed)
        self._folder = Path('./datasets/ScanNet/')
        # Must be a tuple: ('.ply') is just the string '.ply', and ``in`` on a string is
        # a substring test that would also accept '', '.p', '.pl', ...
        self._extensions = ('.ply',)
        self._feat_channels = 3
        self._num_classes = 41
        # NOTE(review): 20 names for 41 classes; anything indexing class_names by class
        # id must handle ids >= len(class_names). TODO confirm the label mapping.
        self._class_names = [
            "wall", "floor", "cabinet", "bed", "chair", "sofa", "table", "door", "window", "bookshelf",
            "picture", "counter", "desk", "curtain", "refrigerator", "shower curtain", "toilet", "sink",
            "bathtub", "otherfurniture"
        ]
        self._class_labels = np.array([1, 4, 5])  # NOTE(review): not used inside this class
        cmap = plt.get_cmap('tab10', self._num_classes)
        colors = cmap(np.arange(self._num_classes))[:, :3]
        self._class_colormap = (colors * 255).astype(np.uint8)

        # Only '*.labels.ply' files are used; _load_file reads their per-vertex 'label'.
        self._files = sorted(
            [f for f in self._folder.rglob("*") if f.is_file() and f.suffix.lower() in self._extensions and f.stem.endswith('.labels')],
            key=lambda f: f.name
        )

        self._voxel_size = voxel_size
        self._len = int(len(self._files) * data_augmentation)

        self._yaw_range = yaw_range
        self._tilt_range = tilt_range
        self._scale = scale

    def __getitem__(self, idx):
        """Return one preprocessed sample, or a list of them for a slice.

        Raises:
            IndexError: if an integer index is out of range.
            TypeError: if ``idx`` is neither a slice nor an integer.
        """
        if isinstance(idx, slice):
            # _preprocess takes an index (it resolves the file itself).
            return [self._preprocess(i) for i in range(*idx.indices(len(self)))]
        if isinstance(idx, (int, np.integer)):
            idx = int(idx)
            if idx < 0:
                idx += len(self)
            if idx < 0 or idx >= len(self):
                raise IndexError("Index out of range")
            return self._preprocess(idx)
        raise TypeError("Index must be a slice or an integer")

    def __len__(self):
        return self._len

    @property
    def feat_channels(self):
        return self._feat_channels

    @property
    def num_classes(self):
        return self._num_classes

    @property
    def class_names(self):
        return self._class_names

    @property
    def class_colormap(self):
        return self._class_colormap

    def _load_file(self, path):
        """Read one cloud and shift its coordinates to a non-negative origin.

        Returns:
            coords, feats, semantic_labels, instance_labels as numpy arrays.

        Raises:
            ValueError: for unsupported file suffixes.
        """
        ext = path.suffix.lower()

        if ext in ('.las', '.laz'):
            # NOTE(review): this branch yields 1 feature channel while feat_channels is 3,
            # and __init__ only collects .ply files, so it is currently not reached.
            las = laspy.read(path)

            coords = np.vstack((las.x, las.y, las.z)).transpose()
            coords -= np.min(coords, axis=0, keepdims=True)

            intensity = np.array(las.intensity)
            p1, p99 = np.percentile(intensity, [1, 99])
            # Guard a constant intensity distribution against division by zero.
            spread = max(p99 - p1, np.finfo(float).eps)
            i_norm = np.clip((intensity - p1) / spread, 0, 1)
            i_norm = i_norm - np.median(i_norm)
            feats = i_norm[:, None]

            semantic_labels = np.array(las.classification)
            instance_labels = np.array(las.treeID)
        elif ext == '.ply':
            ply = PlyData.read(str(path))
            v = ply['vertex']

            coords = np.stack([v['x'], v['y'], v['z']], axis=-1).astype(np.float32)
            coords -= np.min(coords, axis=0, keepdims=True)
            # Assumes 8-bit RGB (0..255), so 255 maps the brightest value to exactly 1.
            feats = np.stack([v['red'], v['green'], v['blue']], axis=-1).astype(np.float32) / 255

            semantic_labels = v['label'].astype(np.int64)
            # No instance annotation is read from these files: one dummy id for all points.
            instance_labels = np.zeros(semantic_labels.shape, dtype=np.int64)
        else:
            raise ValueError(f'Unsupported file extension: {ext}!')

        return coords, feats, semantic_labels, instance_labels

    def _agument_data(self, coords):
        """Random rotation R = Rz(yaw) Ry(pitch) Rx(roll) plus a uniform random scale."""
        yaw = np.deg2rad(self._rng.uniform(*self._yaw_range))
        pitch = np.deg2rad(self._rng.uniform(*self._tilt_range))
        roll = np.deg2rad(self._rng.uniform(*self._tilt_range))
        scale = self._rng.uniform(*self._scale)

        cy, sy = np.cos(yaw), np.sin(yaw)
        cp, sp = np.cos(pitch), np.sin(pitch)
        cr, sr = np.cos(roll), np.sin(roll)

        rotation_mtx = np.array([[cy*cp,  cy*sp*sr - sy*cr,  cy*sp*cr + sy*sr],
                                 [sy*cp,  sy*sp*sr + cy*cr,  sy*sp*cr - cy*sr],
                                 [ -sp ,            cp*sr ,            cp*cr ]],
                                dtype=coords.dtype)

        return (coords @ rotation_mtx.T) * scale

    def _preprocess(self, idx: int):
        """Load, optionally augment and voxelize one scene.

        Returns the same dict layout as the tree ``Dataset``: ``inputs``,
        ``semantic_labels``, ``instance_labels`` (SparseTensors), ``coords`` and
        ``inverse_map`` (point -> voxel row).
        """
        coords, feat, semantic_labels, instance_labels = self._load_file(self._files[idx % len(self._files)])
        if idx >= len(self._files):
            coords = self._agument_data(coords)

        voxels, indices, inverse_map = sparse_quantize(coords, self._voxel_size, return_index=True, return_inverse=True)
        feat = feat[indices]
        semantic_labels = semantic_labels[indices]
        instance_labels = instance_labels[indices]

        voxels = torch.tensor(voxels, dtype=torch.int)
        feat = torch.tensor(feat.astype(np.float32), dtype=torch.float)
        semantic_labels = torch.tensor(semantic_labels, dtype=torch.long)
        instance_labels = torch.tensor(instance_labels, dtype=torch.long)

        inputs = SparseTensor(coords=voxels, feats=feat)
        semantic_labels = SparseTensor(coords=voxels, feats=semantic_labels)
        instance_labels = SparseTensor(coords=voxels, feats=instance_labels)

        return {"inputs": inputs, "semantic_labels": semantic_labels, "instance_labels": instance_labels, "coords": coords, "inverse_map": inverse_map}

In [18]:
# Visual sanity check: one voxelized training sample coloured by its semantic label.
dataset = MixedDataset(voxel_size=VOXEL_SIZE, data_augmentation=DATA_AUGMENTATION_COEF)
sample = dataset.train_dataset[2]  # named `sample` so the builtin `input` is not shadowed
inputs = sample['inputs']
semantic_labels = sample['semantic_labels'].F.numpy()

# Voxel coords are an int torch tensor; Open3D's Vector3dVector wants a float64 numpy array.
coords = inputs.C.numpy().astype(np.float64)
# colors = inputs.F

colors = dataset.class_colormap[semantic_labels] / 255.0

pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(coords)
pcd.colors = o3d.utility.Vector3dVector(colors)
o3d.visualization.draw_geometries([pcd])

In [None]:
class TreeProjectorTrainer:
    """Train / evaluate ``TreeProjector`` for semantic segmentation on ``MixedDataset``.

    The constructor builds the dataset, model and data loaders; when the notebook
    constant ``TRAINING`` is False the weights saved by ``train`` are loaded instead.
    Only the semantic loss is active (the instance term is weighted by
    ``INSTANCE_LOSS_COEF`` and currently fixed at 0).
    """

    # Shared by ``train`` (save) and ``_load_weights`` (load).
    _WEIGHTS_PATH = Path('./weights/tree_unet_weights.pth')

    def __init__(self):
        F.set_kmap_mode("hashmap")

        self._dataset = MixedDataset(voxel_size=VOXEL_SIZE, data_augmentation=DATA_AUGMENTATION_COEF)

        self._model = TreeProjector(self._dataset.feat_channels, self._dataset.num_classes, MAX_INSTANCES, channels = CHANNELS, latent_dim = LATENT_DIM)
        # self._model = UNet(self._dataset.feat_channels, self._dataset.num_classes, base_channels=64, depth=3)
        self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        # torch.cuda.amp autocast/GradScaler only act on CUDA (on CPU they warn and
        # disable themselves), so mixed precision is requested only with a GPU.
        self._use_amp = self._device.type == "cuda"

        total_params = sum(p.numel() for p in self._model.parameters())
        trainable_params = sum(p.numel() for p in self._model.parameters() if p.requires_grad)

        print(f"Parámetros totales: {total_params:,}")
        print(f"Parámetros entrenables: {trainable_params:,}")

        self._train_loader = DataLoader(self._dataset.train_dataset, batch_size=BATCH_SIZE, collate_fn=sparse_collate_fn, shuffle=True)
        # Evaluation needs no shuffling; a fixed order keeps runs comparable.
        self._val_loader = DataLoader(self._dataset.val_dataset, batch_size=BATCH_SIZE, collate_fn=sparse_collate_fn, shuffle=False)

        self._criterion_semantic = nn.CrossEntropyLoss()
        self._criterion_instance = nn.CrossEntropyLoss()

        if not TRAINING:
            self._load_weights()

        self._model.to(self._device)

    @property
    def dataset(self):
        return self._dataset

    def _load_weights(self):
        """Load the state dict written by ``train``."""
        # map_location lets weights saved from a GPU run load on a CPU-only machine.
        self._model.load_state_dict(torch.load(self._WEIGHTS_PATH, map_location=self._device))

    @torch.no_grad()
    def _apply_hungarian(self, logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100) -> torch.Tensor:
        """Relabel ground-truth instance ids onto the predicted slots that best match them.

        Args:
            logits: (N, K) per-point instance logits.
            labels: (N,) ground-truth instance ids (arbitrary integers).
            ignore_index: value given to points whose id could not be matched, which
                happens when there are more ids than the K slots. The default equals
                ``nn.CrossEntropyLoss``'s default ``ignore_index``, so those points are
                skipped by the loss instead of triggering an out-of-range target error.

        Returns:
            (N,) tensor of slot indices in ``[0, K)`` or ``ignore_index``.
        """
        K = logits.shape[1]
        log_p = tF.log_softmax(logits, dim=-1)
        uniq = torch.unique(labels)

        # cost[m, k]: mean negative log-likelihood of slot k over the points of id m.
        cost = torch.empty((len(uniq), K), device=logits.device)
        for m, g in enumerate(uniq):
            cost[m] = -log_p[labels == g].mean(0)

        rows, cols = linear_sum_assignment(cost.cpu().numpy())

        remapped = torch.full_like(labels, fill_value=ignore_index)
        for r, c in zip(rows, cols):
            remapped[labels == uniq[r]] = int(c)
        return remapped

    def _compute_loss(self, semantic_output, semantic_labels, instance_output = 0, instance_labels = 0):
        """Weighted sum of the semantic cross-entropy and the (currently disabled) instance loss."""
        loss_sem = self._criterion_semantic(semantic_output, semantic_labels)
        # TODO: enable once the instance head is trained:
        # loss_inst = self._criterion_instance(instance_output, self._apply_hungarian(instance_output, instance_labels))
        loss_inst = 0

        return SEMANTIC_LOSS_COEF * loss_sem + INSTANCE_LOSS_COEF * loss_inst

    @torch.no_grad()
    def _compute_metrics(self, pred_labels, gt_labels, num_classes, ignore_index = None):
        """Confusion-matrix based segmentation metrics for one batch.

        Args:
            pred_labels: (N, C) per-point class logits (arg-maxed here).
            gt_labels: (N,) integer ground-truth classes.
            num_classes: C.
            ignore_index: optional label value dropped before scoring.

        Returns:
            Dict of numpy values: ``confusion`` (rows = ground truth, columns =
            prediction), per-class IoU / precision / recall / F1, ``miou`` and the
            macro / micro precision, recall and F1.

        Ground-truth labels outside ``[0, C)`` are dropped; they would otherwise break
        the bincount reshape. A class absent from both prediction and ground truth has
        an undefined IoU: it is reported as NaN and excluded from ``miou`` rather than
        counted as 0. Precision / recall / F1 still score such classes as 0.
        """
        C = num_classes
        valid = (gt_labels >= 0) & (gt_labels < C)
        if ignore_index is not None:
            valid &= gt_labels != ignore_index
        pred_labels, gt_labels = pred_labels[valid], gt_labels[valid]

        pred_labels = torch.argmax(pred_labels, dim=1)

        idx = C * gt_labels + pred_labels
        conf = torch.bincount(idx, minlength=C ** 2).reshape(C, C)

        TP = conf.diag()
        FP = conf.sum(0) - TP
        FN = conf.sum(1) - TP

        precision = TP.float() / (TP + FP).clamp(min=1)
        recall    = TP.float() / (TP + FN).clamp(min=1)
        f1        = 2 * precision * recall / (precision + recall).clamp(min=1e-6)

        # 0 / 0 -> NaN for classes with an empty union.
        iou = TP.float() / (TP + FP + FN).float()
        miou = torch.nanmean(iou)

        macroP, macroR, macroF = precision.mean(), recall.mean(), f1.mean()

        microTP = TP.sum()
        microP = microTP.float() / (microTP + FP.sum()).clamp(min=1)
        microR = microTP.float() / (microTP + FN.sum()).clamp(min=1)
        microF = 2 * microP * microR / (microP + microR).clamp(min=1e-6)

        return {
            "confusion":             conf.cpu().numpy(),
            "iou_per_class":         iou.cpu().numpy(),
            "miou":                  miou.cpu().numpy(),
            "precision_per_class":   precision.cpu().numpy(),
            "recall_per_class":      recall.cpu().numpy(),
            "f1_per_class":          f1.cpu().numpy(),
            "precision_macro":       macroP.cpu().numpy(),
            "recall_macro":          macroR.cpu().numpy(),
            "f1_macro":              macroF.cpu().numpy(),
            "precision_micro":       microP.cpu().numpy(),
            "recall_micro":          microR.cpu().numpy(),
            "f1_micro":              microF.cpu().numpy(),
        }

    def _plot_lines(self, series, labels, ylabel, title):
        """Draw one figure with a line per (values, label) pair against the step index."""
        fig, ax = plt.subplots(figsize=(10, 5))
        for values, label in zip(series, labels):
            ax.plot(values, label=label)
        ax.set_xlabel("Step")
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.legend()
        ax.grid(True)
        plt.show()

    def _gen_charts(self, losses, stats, training):
        """Plot per-step loss and metric curves and display a per-class summary table.

        Args:
            losses: per-step scalar losses.
            stats: per-step dicts returned by ``_compute_metrics``.
            training: selects the "Training" / "Inference" wording.
        """
        if not stats:
            # Nothing was iterated (e.g. an empty loader): nothing to plot.
            return

        phase = 'Training' if training else 'Inference'
        stats = {k: np.array([d[k] for d in stats]) for k in stats[0].keys()}

        num_classes = self._dataset.num_classes
        names = self._dataset.class_names
        # Fall back to generic labels when a dataset lists fewer names than classes.
        class_labels = [names[c] if c < len(names) else f'Class {c}' for c in range(num_classes)]

        iou_arr = stats['iou_per_class']
        prec_arr = stats['precision_per_class']
        recall_arr = stats['recall_per_class']
        f1_arr = stats['f1_per_class']

        self._plot_lines([losses], [f"{phase} Loss"], "Loss", f"Loss evolution during {phase}")
        self._plot_lines([stats['miou']], [f"{phase} mIoU"], "mIoU", f"mIoU evolution during {phase}")
        self._plot_lines(iou_arr.T, class_labels, "IoU", f"IoU evolution during {phase}")
        self._plot_lines([stats['precision_macro']], [f"{phase} precision"], "Precision", f"Precision evolution during {phase}")
        self._plot_lines(prec_arr.T, class_labels, "Precision", f"Precision evolution during {phase}")
        self._plot_lines([stats['recall_macro']], [f"{phase} recall"], "Recall", f"Recall evolution during {phase}")
        self._plot_lines(recall_arr.T, class_labels, "Recall", f"Recall evolution during {phase}")
        self._plot_lines([stats['f1_macro']], [f"{phase} F1"], "F1", f"F1 evolution during {phase}")
        self._plot_lines(f1_arr.T, class_labels, "F1", f"F1 evolution during {phase}")

        # nanmean: per-class IoU is NaN on steps where that class was absent.
        data = [
            [np.nanmean(iou_arr[:, c]), prec_arr[:, c].mean(), recall_arr[:, c].mean(), f1_arr[:, c].mean()]
            for c in range(num_classes)
        ]
        data.append([np.nanmean(stats['miou']), stats['precision_macro'].mean(), stats['recall_macro'].mean(), stats['f1_macro'].mean()])

        df = pd.DataFrame(data, columns=['IoU', 'Precision', 'Recall', 'F1'], index=class_labels + ['Mean'])
        display(df)

    def train(self, epochs: int = 1, lr: float = 1e-3):
        """Optimize the model on the training loader, save its weights and plot curves.

        Args:
            epochs: passes over the training loader (default 1, the previous behaviour).
            lr: Adam learning rate.
        """
        # eval() switches the model to inference mode; switch it back explicitly.
        self._model.train()

        optimizer = torch.optim.Adam(self._model.parameters(), lr=lr)
        scaler = amp.GradScaler(enabled=self._use_amp)
        losses = []
        stats = []

        for epoch in range(epochs):
            desc = '[Train]' if epochs == 1 else f'[Train {epoch + 1}/{epochs}]'
            pbar = tqdm(self._train_loader, desc=desc)
            for feed_dict in pbar:
                inputs = feed_dict["inputs"].to(self._device)
                semantic_labels = feed_dict["semantic_labels"].to(self._device)

                with amp.autocast(enabled=self._use_amp):
                    # TODO: unpack (semantic_output, instance_output) once the instance loss is on.
                    semantic_output = self._model(inputs)
                    loss = self._compute_loss(semantic_output.F, semantic_labels.F)
                    stat = self._compute_metrics(semantic_output.F, semantic_labels.F, num_classes=self._dataset.num_classes)

                stats.append(stat)
                losses.append(loss.item())
                pbar.set_postfix({
                    'loss': f'{loss.item():.4f}',
                    'mIoU': f'{stat["miou"]:.4f}'
                })

                optimizer.zero_grad()
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()

                del inputs, semantic_output, semantic_labels

        # The weights directory may not exist yet in a fresh checkout.
        self._WEIGHTS_PATH.parent.mkdir(parents=True, exist_ok=True)
        torch.save(self._model.state_dict(), self._WEIGHTS_PATH)
        self._gen_charts(losses, stats, True)

    def eval(self):
        """Run inference over the validation loader, yielding per-batch results.

        This is a generator: nothing runs until it is iterated, and the loss / metric
        charts are drawn only after the last batch has been consumed.

        Yields:
            (voxels, semantic_pred, instance_pred, semantic_labels, instance_labels,
            coords, inverse_map) as numpy arrays. ``instance_pred`` is all zeros until
            the instance head is enabled.
        """
        self._model.eval()
        losses = []
        stats = []

        # enable torchsparse 2.0 inference
        # enable fused and locality-aware memory access optimization
        torchsparse.backends.benchmark = True  # type: ignore

        with torch.no_grad():
            pbar = tqdm(self._val_loader, desc='[Inference]')
            for feed_dict in pbar:
                semantic_labels_cpu = feed_dict["semantic_labels"].F.numpy()
                instance_labels_cpu = feed_dict["instance_labels"].F.numpy()
                coords = feed_dict["coords"].numpy()
                inverse_map = feed_dict["inverse_map"].numpy()

                inputs = feed_dict["inputs"].to(self._device)
                semantic_labels = feed_dict["semantic_labels"].to(self._device)

                with amp.autocast(enabled=self._use_amp):
                    semantic_output = self._model(inputs)
                    loss = self._compute_loss(semantic_output.F, semantic_labels.F)
                    stat = self._compute_metrics(semantic_output.F, semantic_labels.F, num_classes=self._dataset.num_classes)

                losses.append(loss.item())
                stats.append(stat)

                # NOTE(review): assumes column 0 of the collated coords is the batch index.
                voxels = semantic_output.C[:, 1:].cpu().numpy()
                semantic_output = torch.argmax(semantic_output.F.cpu(), dim=1).numpy()
                instance_output = np.zeros(semantic_output.shape)

                pbar.set_postfix({
                    'loss': f'{loss.item():.4f}',
                    'mIoU': f'{stat["miou"]:.4f}'
                })

                yield voxels, semantic_output, instance_output, semantic_labels_cpu, instance_labels_cpu, coords, inverse_map

        self._gen_charts(losses, stats, False)

In [12]:
# Build the trainer, optionally train, then run validation inference.
# tester.eval() is a generator and must be fully consumed: its charts are drawn after the last batch.
VISUALIZE = False  # True opens four blocking Open3D windows per validation cloud


def show_cloud(points, colors):
    """Open a blocking Open3D window showing `points` with per-point RGB `colors` in [0, 1]."""
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(points)
    pcd.colors = o3d.utility.Vector3dVector(colors)
    o3d.visualization.draw_geometries([pcd])


def instance_colors(ids, seed=0):
    """Give each distinct id a random RGB color (deterministic for a given seed)."""
    unique_ids, inverse = np.unique(ids, return_inverse=True)
    palette = np.random.default_rng(seed).random((len(unique_ids), 3))
    return palette[inverse.reshape(-1)]


tester = TreeProjectorTrainer()

if TRAINING:
    tester.train()

for voxels, semantic_output, instance_output, semantic_labels, instance_labels, coords, inverse_map in tester.eval():
    if not VISUALIZE:
        continue
    # Assumes BATCH_SIZE == 1: take the only cloud and map voxel values back to its points.
    coords = coords[0]
    inverse_map = inverse_map[0]

    show_cloud(coords, tester.dataset.class_colormap[semantic_labels[inverse_map]] / 255.0)
    show_cloud(coords, tester.dataset.class_colormap[semantic_output[inverse_map]] / 255.0)
    show_cloud(coords, instance_colors(instance_labels)[inverse_map])
    show_cloud(coords, instance_colors(instance_output)[inverse_map])


Parámetros totales: 16,119,472
Parámetros entrenables: 16,119,472


AttributeError: 'ScanNet' object has no attribute 'train_dataset'