In [1]:
import os
import torch
import json
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from torch.utils.data import DataLoader
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score
from scipy.optimize import linear_sum_assignment
from collections import defaultdict
from tqdm import tqdm

# Import classes from your main pipeline script
try:
    from trainModel import (
        Config, 
        PointUNet, 
        LiDARPointCloudDataset, 
        cluster_embeddings
    )
except ImportError:
    print("Error: Could not import from 'trainModel.py'.")
    exit()


Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.
CUDA available: True
Device count: 1
Device name: Tesla P100-PCIE-16GB


In [None]:
# -------------------------------------------------------------------------
# METRIC FUNCTIONS
# -------------------------------------------------------------------------
def compute_instance_metrics(gt_labels, pred_labels, iou_threshold=0.5):
    """
    Computes mCov, mWCov, mRec, and mPrec.
    """
    gt_ids = np.unique(gt_labels)
    pred_ids = np.unique(pred_labels)
    
    # Remove noise label (-1) if present
    gt_ids = gt_ids[gt_ids != -1]
    pred_ids = pred_ids[pred_ids != -1]

    num_gt = len(gt_ids)
    num_pred = len(pred_ids)

    # Handle edge cases (empty scenes)
    if num_gt == 0: return 0.0, 0.0, 0.0, 0.0, 0.0
    if num_pred == 0: return 0.0, 0.0, 0.0, 0.0, 0.0

    # 1. Compute IoU Matrix [num_gt x num_pred]
    iou_matrix = np.zeros((num_gt, num_pred))
    gt_sizes = np.zeros(num_gt)
    pred_sizes = np.zeros(num_pred)

    for i, g_id in enumerate(gt_ids):
        gt_mask = (gt_labels == g_id)
        gt_sizes[i] = np.sum(gt_mask)  # For Weighted Coverage
        
        for j, p_id in enumerate(pred_ids):
            pred_mask = (pred_labels == p_id)
            if i == 0: pred_sizes[j] = np.sum(pred_mask) # Calculate pred size once

            intersection = np.logical_and(gt_mask, pred_mask).sum()
            union = np.logical_or(gt_mask, pred_mask).sum()
            
            if union > 0:
                iou_matrix[i, j] = intersection / union

    # 2. Calculate Metrics
    # Max IoU for each GT instance (Best match found for this roof face)
    max_iou_per_gt = np.max(iou_matrix, axis=1) 
    
    # Max IoU for each Pred instance (Is this prediction valid?)
    max_iou_per_pred = np.max(iou_matrix, axis=0)

    # mCov (Mean Coverage)
    mCov = np.mean(max_iou_per_gt)

    # mWCov (Weighted Mean Coverage)
    mWCov = np.sum(max_iou_per_gt * gt_sizes) / np.sum(gt_sizes)

    # mRec (Mean Recall @ IoU 0.5) - "Did we find the roof?"
    mRec = np.sum(max_iou_per_gt >= iou_threshold) / num_gt

    # mPrec (Mean Precision @ IoU 0.5) - "Is the prediction real?"
    mPrec = np.sum(max_iou_per_pred >= iou_threshold) / num_pred

    # mWPrec (Weighted Mean Precision) - Matches literature that ignores tiny noise
    # Weights the precision of each cluster by its size.
    mWPrec = np.sum((max_iou_per_pred >= iou_threshold) * pred_sizes) / np.sum(pred_sizes)

    return mCov, mWCov, mRec, mPrec, mWPrec

def instance_mean_iou(gt_labels, pred_labels):
    """Calculates traditional mIoU using linear assignment."""
    gt_ids = np.unique(gt_labels)
    pred_ids = np.unique(pred_labels)
    gt_ids = gt_ids[gt_ids != -1]
    pred_ids = pred_ids[pred_ids != -1]
    
    if len(gt_ids) == 0 or len(pred_ids) == 0: return 0.0
        
    iou_matrix = np.zeros((len(gt_ids), len(pred_ids)))
    for i, gt_id in enumerate(gt_ids):
        gt_mask = gt_labels == gt_id
        for j, pred_id in enumerate(pred_ids):
            pred_mask = pred_labels == pred_id
            intersection = np.logical_and(gt_mask, pred_mask).sum()
            union = np.logical_or(gt_mask, pred_mask).sum()
            if union > 0: iou_matrix[i, j] = intersection / union

    row_ind, col_ind = linear_sum_assignment(-iou_matrix)
    return iou_matrix[row_ind, col_ind].mean()

def save_plotly_comparison(points, pred_labels, gt_labels, scene_name, save_dir):
    """
    Saves an interactive HTML plot comparing Ground Truth and Prediction side-by-side.
    """
    # Create subplots: 1 row, 2 cols, both 3D
    fig = make_subplots(
        rows=1, cols=2,
        specs=[[{'type': 'scene'}, {'type': 'scene'}]],
        subplot_titles=("Ground Truth", "Prediction")
    )

    # Helper to create a trace
    def create_trace(pts, lbls, name):
        # We use 'Jet' colormap to distinguish integer instance IDs visually
        return go.Scatter3d(
            x=pts[:, 0], y=pts[:, 1], z=pts[:, 2],
            mode='markers',
            marker=dict(
                size=3,
                color=lbls,
                colorscale='Jet', 
                opacity=0.8
            ),
            name=name,
            text=[f"Instance: {l}" for l in lbls],
            hoverinfo='text+x+y+z'
        )

    # Ground Truth Trace (Left)
    fig.add_trace(create_trace(points, gt_labels, "Ground Truth"), row=1, col=1)

    # Prediction Trace (Right)
    fig.add_trace(create_trace(points, pred_labels, "Prediction"), row=1, col=2)

    # Layout updates for better viewing
    fig.update_layout(
        title=f"Scene: {scene_name}",
        width=1400, height=700,
        showlegend=False,
        scene=dict(aspectmode='data'),
        scene2=dict(aspectmode='data')
    )
    
    # Save as HTML
    file_path = os.path.join(save_dir, f"{scene_name}_comparison.html")
    fig.write_html(file_path)

# -------------------------------------------------------------------------
# MAIN EVALUATION FUNCTION
# -------------------------------------------------------------------------
def evaluate(conf: Config = None, model_path="roof_segmentation_dgcnn_best.pth", vis_output_dir="./evaluation_results", save_plotly=True):
    # -------------------------------------------------------------------------
    # 1. SETUP
    # -------------------------------------------------------------------------
    if conf is None:
        conf = Config()
        
    os.makedirs(vis_output_dir, exist_ok=True)
    
    # Use config-defined clustering if available, else default to hdbscan
    cluster_method = getattr(conf, 'clustering_method', 'hdbscan')

    print(f"--- Evaluation Settings ---")
    print(f"Model: {model_path}")
    print(f"Clustering: {cluster_method}")
    print(f"Data Split: TEST")
    
    # -------------------------------------------------------------------------
    # 2. LOAD DATA
    # -------------------------------------------------------------------------
    # Using root_dir="data/roofNTNU" maps to internal path "data/roofNTNU/train_test_split"
    # This matches the user's intended data location.
    test_dataset = LiDARPointCloudDataset(
        base_dir="data/roofNTNU/train_test_split",  
        split='test', 
        max_points=conf.max_points,
        sampling_method=conf.sampling_method
    )
    
    if len(test_dataset) == 0:
        print("No test data found. Please check ./data/roofNTNU/train_test_split/points_test_n")
        return

    # Batch size 1 for accurate per-instance metric calculation
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
    
    # -------------------------------------------------------------------------
    # 3. LOAD MODEL
    # -------------------------------------------------------------------------
    model = PointUNet(conf).to(conf.device)
    if not os.path.exists(model_path):
        print(f"Checkpoint not found at {model_path}. Cannot evaluate.")
        return
        
    checkpoint = torch.load(model_path, map_location=conf.device)
    model.load_state_dict(checkpoint)
    model.eval()
    
    # -------------------------------------------------------------------------
    # 4. EVALUATION LOOP
    # -------------------------------------------------------------------------
    all_metrics = []
    complexity_buckets = defaultdict(list) # key = complexity, value = list of (ARI, mIoU)
    
    print("Starting Inference...")
    
    with torch.no_grad():
        for i, (points, labels, _) in enumerate(tqdm(test_loader)):
            scene_name = f"scene_{i:03d}"
            
            points = points.to(conf.device)
            gt_labels_full = labels.cpu().numpy()[0]
            
            # Forward Pass
            embeddings = model(points)
            emb_sample = embeddings[0].cpu().numpy()
            
            # Clustering
            pred_labels_full = cluster_embeddings(emb_sample, method=cluster_method)
            
            # -----------------------------------------------------------
            # 5. METRICS & COMPLEXITY
            # -----------------------------------------------------------
            # Filter padding (-1)
            valid_mask = gt_labels_full != -1
            if valid_mask.sum() == 0: continue
                
            gt_valid = gt_labels_full[valid_mask]
            pred_valid = pred_labels_full[valid_mask]
            
            # Ensure we only slice the first 3 columns (XYZ) for metrics/viz
            points_valid_xyz = points[0, :, :3].cpu().numpy()[valid_mask] 
            
            # Calculate Metrics
            ari = adjusted_rand_score(gt_valid, pred_valid)
            nmi = normalized_mutual_info_score(gt_valid, pred_valid)
            miou = instance_mean_iou(gt_valid, pred_valid)
            mCov, mWCov, mRec, mPrec, mWPrec = compute_instance_metrics(gt_valid, pred_valid, iou_threshold=0.5)
            
            # Determine Complexity
            num_gt_instances = len(np.unique(gt_valid))
            if num_gt_instances <= 2:
                complexity = "simple"
            elif num_gt_instances <= 5:
                complexity = "moderate"
            else:
                complexity = "complex"
                
            # Store results
            result_entry = {
                "scene": scene_name,
                "ARI": float(ari),
                "NMI": float(nmi),
                "mIoU": float(miou),
                "mCov": float(mCov),
                "mWCov": float(mWCov),
                "mRec": float(mRec),
                "mPrec": float(mPrec),
                "mWPrec": float(mWPrec),
                "gt_instances": int(num_gt_instances),
                "pred_instances": int(len(np.unique(pred_valid))),
                "complexity": complexity
            }
            all_metrics.append(result_entry)
            complexity_buckets[complexity].append((ari, miou))
            
            # -----------------------------------------------------------
            # 6. VISUALIZATION
            # -----------------------------------------------------------
            # Save ALL plots using Plotly (interactive HTML) if enabled
            if save_plotly:
                save_plotly_comparison(points_valid_xyz, pred_valid, gt_valid, scene_name, vis_output_dir)
            
            # Optional: Save .txt for CloudCompare
            txt_path = os.path.join(vis_output_dir, f"{scene_name}.txt")
            save_data = np.column_stack((points_valid_xyz, pred_valid, gt_valid))
            np.savetxt(txt_path, save_data, fmt="%.6f %.6f %.6f %d %d", header="x y z pred gt")

    # -------------------------------------------------------------------------
    # 7. SUMMARY REPORT
    # -------------------------------------------------------------------------
    # Save JSON
    with open(os.path.join(vis_output_dir, "evaluation_summary.json"), "w") as f:
        json.dump(all_metrics, f, indent=2)

    # Compute Averages
    keys = ["ARI", "mIoU", "mCov", "mWCov", "mRec", "mPrec", "mWPrec"]
    avgs = {k: np.mean([m[k] for m in all_metrics]) for k in keys}

    print("\n" + "="*50)
    print("           EVALUATION REPORT           ")
    print("="*50)
    print(f"Total Samples: {len(all_metrics)}")
    print(f"Mean ARI:              {avgs['ARI']:.4f}")
    print(f"Mean IoU (mIoU):       {avgs['mIoU']:.4f}")
    print("-" * 50)
    print("RoofNet Comparison Metrics:")
    print(f"Mean Coverage (mCov):  {avgs['mCov']:.4f}")
    print(f"Weighted Cov (mWCov):  {avgs['mWCov']:.4f}")
    print(f"Mean Recall (mRec):    {avgs['mRec']:.4f}")
    print(f"Mean Precision (mPrec):{avgs['mPrec']:.4f}")
    print(f"Mean Weighted Prec (mWPrec): {avgs['mWPrec']:.4f} <--- Likely matches paper")
    print("-" * 50)
    print("Performance by Roof Complexity:")
    
    for level in ["simple", "moderate", "complex"]:
        if level in complexity_buckets:
            scores = complexity_buckets[level]
            aris = [a for a, _ in scores]
            ious = [m for _, m in scores]
            print(f"  {level.capitalize().ljust(10)} ({len(scores)} scenes): ARI={np.mean(aris):.3f}, mIoU={np.mean(ious):.3f}")
        else:
            print(f"  {level.capitalize().ljust(10)} (0 scenes): N/A")
            
    print("="*50)
    print(f"Results saved to: {vis_output_dir}")

In [7]:
if __name__ == "__main__":
    evaluate(model_path="models/roof_segmentation_dgcnn_20251209.pth", save_plotly=False)

--- Evaluation Settings ---
Model: models/roof_segmentation_dgcnn_20251209.pth
Clustering: hdbscan
Data Split: TEST


  checkpoint = torch.load(model_path, map_location=conf.device)


Starting Inference...


100%|██████████| 99/99 [02:02<00:00,  1.24s/it]


           EVALUATION REPORT           
Total Samples: 99
Mean ARI:              0.7519
Mean IoU (mIoU):       0.9018
--------------------------------------------------
RoofNet Comparison Metrics:
Mean Coverage (mCov):  0.9010
Weighted Cov (mWCov):  0.8985
Mean Recall (mRec):    0.9688
Mean Precision (mPrec):0.6015
--------------------------------------------------
Performance by Roof Complexity:
  Simple     (51 scenes): ARI=0.704, mIoU=0.957
  Moderate   (41 scenes): ARI=0.820, mIoU=0.856
  Complex    (7 scenes): ARI=0.699, mIoU=0.771
Results saved to: ./evaluation_results



