In [None]:
import sys
sys.path.append('..')
import torch
from src.datasets.synthetic_config import CLASS_NAMES, CLASS_COLORS, STUFF_CLASSES
from src.datasets.synthetic import read_synthetic


# Point cloud to run inference on. Other inputs used during development:
#   ../data/synthetic/raw/40_202506241038_frames_1_to_1056_noise_parts_processed.las
#   ../data/ontras/ontras_1_leveled.las
filepath = '../data/ontras/ontras_3.las'

# Read the cloud with the project reader and visualise it using the configured class names/colours.
data = read_synthetic(filepath)
data.show(class_names=CLASS_NAMES, class_colors=CLASS_COLORS)

In [None]:
from src.transforms import SampleXYTiling
from src.data import Batch

# Cut the cloud into an xy_tiling[0] x xy_tiling[1] grid of XY tiles (SampleXYTiling(x, y, tiling)
# presumably extracts tile (x, y) — confirm), tag each point with its row-major tile id
# (x * n_tiles_y + y), and merge the tiles back into one batch to display them.
xy_tiling = (5, 5)
n_tiles_y = xy_tiling[1]

chunks = []
for tile_id in range(xy_tiling[0] * n_tiles_y):
    # Recover the grid coordinates from the flat id; same order as an x-outer / y-inner loop.
    x, y = divmod(tile_id, n_tiles_y)
    chunk = SampleXYTiling(x=x, y=y, tiling=xy_tiling)(data)
    chunk.tile = torch.full((chunk.num_points,), tile_id)
    chunks.append(chunk)

data_tiled = Batch.from_data_list(chunks)
print(data_tiled)
data_tiled.show(keys='tile')

In [None]:
from src.utils import init_config
from src.transforms import instantiate_datamodule_transforms
from src.transforms import NAGRemoveKeys

# `synthetic_11g` experiment config with `load_full_res_idx=True`.
# NOTE(review): presumably this override is what keeps `nag[0].sub` available for the
# full-resolution .las export further down — confirm.
# (A smaller alternative experiment is `semantic/synthetic_nano`.)
cfg = init_config(overrides=[
    "experiment=semantic/synthetic_11g",
    "datamodule.load_full_res_idx=True",
])

# Apply the datamodule's `pre_transform` to the raw cloud, producing a NAG.
transforms_dict = instantiate_datamodule_transforms(cfg.datamodule)
nag = transforms_dict['pre_transform'](data)

# Drop every level-0 attribute not listed in `point_load_keys` and every level-1+ attribute
# not listed in `segment_load_keys`.
# NOTE(review): the level-1+ drop list is built from nag[1] only; attributes present solely on
# higher levels are not removed — confirm that is intended.
point_keys_to_drop = [k for k in nag[0].keys if k not in cfg.datamodule.point_load_keys]
nag = NAGRemoveKeys(level=0, keys=point_keys_to_drop)(nag)
segment_keys_to_drop = [k for k in nag[1].keys if k not in cfg.datamodule.segment_load_keys]
nag = NAGRemoveKeys(level='1+', keys=segment_keys_to_drop)(nag)

# Move to the GPU, then run the test-time on-device transforms.
nag = nag.cuda()
nag = transforms_dict['on_device_test_transform'](nag)

In [None]:
import hydra
from src.utils import init_config

# Build the model from the experiment config, then restore the trained weights from the checkpoint.
ckpt_path = "../checkpoints/0701_2_spt-2_synthetic_11g_epoch_099.ckpt"
cfg = init_config(overrides=["experiment=semantic/synthetic_11g"])
model = hydra.utils.instantiate(cfg.model)._load_from_checkpoint(ckpt_path)

# Inference in eval mode, on the NAG's device, without building an autograd graph.
model = model.eval().to(nag.device)
print(nag)
with torch.no_grad():
    output = model(nag)

# Quick check: shape of `semantic_pred()` next to `nag.num_points`.
output.semantic_pred().shape, nag.num_points

## Save full-resolution predictions to a .las file

In [None]:
import os

import numpy as np
import laspy

# Per-raw-point semantic predictions, computed from the level-0 -> level-1 superpoint index
# and the level-0 -> raw-point mapping (`nag[0].sub`), as named by the keyword arguments.
raw_semseg_y = output.full_res_semantic_pred(
    super_index_level0_to_level1=nag[0].super_index,
    sub_level0_to_raw=nag[0].sub
)

print(f"Full resolution predictions shape: {raw_semseg_y.shape}")
print(f"Original data points: {data.num_points}")

# Exactly one prediction is required per point of the original file.
original_las = laspy.read(filepath)
assert len(raw_semseg_y) == len(original_las.points), f"Mismatch: {len(raw_semseg_y)} predictions vs {len(original_las.points)} points"

# Convert to a numpy array if the predictions are a torch tensor.
if hasattr(raw_semseg_y, "cpu"):
    predictions = raw_semseg_y.cpu().numpy()
else:
    predictions = np.asarray(raw_semseg_y)

# Reject labels that do not fit in uint8 instead of letting the cast wrap them silently
# (e.g. a negative "ignore" label would otherwise become 255).
if predictions.size and (predictions.min() < 0 or predictions.max() > np.iinfo(np.uint8).max):
    raise ValueError(
        f"Predicted labels must lie in [0, 255] to be stored as LAS classification, "
        f"got range [{predictions.min()}, {predictions.max()}]"
    )
predictions = predictions.astype(np.uint8)

# New LAS data reusing the original header and point records.
# NOTE(review): the point record object is assigned, not copied, so it is presumably shared
# with `original_las` (setting the classification below then also changes `original_las` in
# memory; the input file on disk is untouched) — confirm if `original_las` is reused later.
output_las = laspy.LasData(original_las.header)
output_las.points = original_las.points

# Store the predictions in the classification field.
# NOTE(review): for LAS point formats 0-5 this field holds only 5 bits (0-31); confirm the
# number of classes fits the input file's point format.
output_las.classification = predictions

# Per-class point counts, with class names where available.
print("Predicted classes:")
classes, counts = np.unique(predictions, return_counts=True)
for cls, count in zip(classes, counts):
    class_name = CLASS_NAMES[cls] if cls < len(CLASS_NAMES) else f"Unknown_{cls}"
    print(f"  Class {cls} ({class_name}): {count} points")

# Write next to the input as `<name>_predicted<ext>`. Splitting off the extension (rather
# than str.replace('.las', ...)) keeps '.LAS'/'.laz' inputs from mapping onto the input path
# itself, which would overwrite the source file, and leaves '.las' in directory names alone.
stem, ext = os.path.splitext(filepath)
output_filename = f"{stem}_predicted{ext}"
output_las.write(output_filename)
print(f"Saved predictions to: {output_filename}")

## Show voxel-level predictions

In [None]:
# Voxel-level (level-0) predictions obtained through `super_index`, attached to the NAG
# (presumably the attribute `nag.show` colours by — confirm).
nag[0].semantic_pred = output.voxel_semantic_pred(super_index=nag[0].super_index)

# Interactive view capped at 100k points. Optional kwargs used before:
# title="My Interactive Visualization Partition", path="my_interactive_visualization.html".
nag.show(
    class_names=CLASS_NAMES,
    class_colors=CLASS_COLORS,
    stuff_classes=STUFF_CLASSES,
    max_points=100000,
    centroids=True,
    v_edge=True,
    v_edge_width=2,
    gap=[0, 0, 4],
)

In [None]:
# Display `nag.level_ratios`.
# NOTE(review): presumably the size ratio between consecutive partition levels — confirm.
nag.level_ratios