In [None]:
import os
import torch
import json
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score
from scipy.optimize import linear_sum_assignment
from collections import defaultdict
from tqdm import tqdm

# Import the model, dataset and clustering utilities from the training pipeline
# module (trainModel.py), which has to be on the import path.
try:
    from trainModel import (
        Config,
        PointUNet,
        LiDARPointCloudDataset,
        cluster_embeddings,
    )
except ImportError as err:
    # Re-raise with context instead of calling exit(): exit() hides the
    # underlying cause (e.g. a missing dependency inside trainModel), reports
    # success to the shell when run as a script, and shuts the kernel down when
    # run inside a notebook.
    raise ImportError(
        "Could not import Config, PointUNet, LiDARPointCloudDataset and "
        "cluster_embeddings from 'trainModel.py'."
    ) from err

In [None]:
# -------------------------------------------------------------------------
# HELPER FUNCTIONS
# -------------------------------------------------------------------------
def instance_mean_iou(gt_labels, pred_labels):
    """
    Mean IoU between optimally matched ground-truth and predicted instances.

    Every distinct value in ``gt_labels`` / ``pred_labels`` is treated as one
    instance (this includes negative values such as -1 if they are present).
    A one-to-one matching that maximizes the total IoU is computed with the
    Hungarian algorithm (``scipy.optimize.linear_sum_assignment``) and the IoUs
    of the matched pairs are averaged.

    Note: when the two labelings contain a different number of instances, only
    ``min(n_gt, n_pred)`` pairs are matched. Unmatched instances do not enter
    the mean (they are NOT counted as IoU 0).

    Args:
        gt_labels: array-like of per-point ground-truth instance ids.
        pred_labels: array-like of per-point predicted instance ids, same
            shape as ``gt_labels``.

    Returns:
        float: mean IoU over the matched instance pairs, or 0.0 when either
        input is empty.

    Raises:
        ValueError: if the two label arrays do not have the same shape.
    """
    gt_arr = np.asarray(gt_labels)
    pred_arr = np.asarray(pred_labels)
    if gt_arr.shape != pred_arr.shape:
        raise ValueError(
            f"gt_labels and pred_labels must have the same shape, "
            f"got {gt_arr.shape} and {pred_arr.shape}"
        )

    # Flatten so the inverse indices returned by np.unique are 1-D.
    gt_flat = gt_arr.ravel()
    pred_flat = pred_arr.ravel()

    gt_ids, gt_inv, gt_counts = np.unique(
        gt_flat, return_inverse=True, return_counts=True
    )
    pred_ids, pred_inv, pred_counts = np.unique(
        pred_flat, return_inverse=True, return_counts=True
    )

    if gt_ids.size == 0 or pred_ids.size == 0:
        return 0.0

    n_gt, n_pred = gt_ids.size, pred_ids.size

    # Contingency table in a single pass: entry (i, j) is the number of points
    # that belong to GT instance i and predicted instance j. This replaces a
    # full-length boolean mask per (gt, pred) pair.
    pair_index = gt_inv * n_pred + pred_inv
    intersection = np.bincount(pair_index, minlength=n_gt * n_pred).reshape(n_gt, n_pred)

    # |A ∪ B| = |A| + |B| - |A ∩ B|; always >= 1 because every id returned by
    # np.unique occurs at least once, so the division below is safe.
    union = gt_counts[:, None] + pred_counts[None, :] - intersection
    iou_matrix = intersection / union

    # linear_sum_assignment minimizes cost, so negate to maximize total IoU.
    row_ind, col_ind = linear_sum_assignment(-iou_matrix)
    return float(iou_matrix[row_ind, col_ind].mean())

def save_cluster_plot(points, labels, path):
    """
    Save a 3D scatter plot of a point cloud colored by cluster label.

    Args:
        points: 2-D array; columns 0, 1 and 2 are plotted as x, y and z.
        labels: per-point values used to color the points through the
            'tab20' colormap.
        path: destination file path handed to ``Figure.savefig``.
    """
    fig = plt.figure(figsize=(8, 6))
    try:
        ax = fig.add_subplot(111, projection='3d')
        # Using tab20 for distinct colors
        ax.scatter(points[:, 0], points[:, 1], points[:, 2], c=labels, cmap='tab20', s=2)
        ax.set_axis_off()
        fig.tight_layout()
        fig.savefig(path)
    finally:
        # Close this specific figure even when plotting or saving fails, so
        # figures do not accumulate when the function is called in a loop.
        plt.close(fig)

In [None]:
def _complexity_level(num_instances):
    """Map a ground-truth instance count to a roof complexity bucket name."""
    if num_instances <= 2:
        return "simple"
    if num_instances <= 5:
        return "moderate"
    return "complex"


def evaluate(conf: Config = None, model_path="roof_segmentation_dgcnn_best.pth", vis_output_dir="./evaluation_results",
             data_dir="data/roofNTNU/train_test_split", max_visualizations=10):
    """
    Evaluate a trained PointUNet checkpoint on the test split.

    For every test scene the model's embeddings are clustered into predicted
    instances and compared against the ground truth with ARI, NMI and the
    Hungarian-matched mean IoU. Per-scene results are written to
    ``evaluation_summary.json`` in ``vis_output_dir``, a summary (overall and
    per roof complexity) is printed, and for the first scenes a prediction
    plot (.png) and a point dump (.txt with "x y z pred gt") are saved.

    Args:
        conf: pipeline configuration; a default ``Config()`` is created when
            None. The attributes ``device``, ``max_points``,
            ``sampling_method`` and ``clustering_method`` are read here.
        model_path: path to the checkpoint holding the model state dict.
        vis_output_dir: directory for the JSON summary, plots and .txt dumps;
            created if missing.
        data_dir: base directory passed to ``LiDARPointCloudDataset``.
        max_visualizations: scenes with loader index below this value get a
            plot and a .txt dump.

    Returns:
        None. Returns early (after printing a message) when the test set is
        empty or the checkpoint file does not exist.
    """
    # -------------------------------------------------------------------------
    # 1. SETUP
    # -------------------------------------------------------------------------
    if conf is None:
        conf = Config()
    os.makedirs(vis_output_dir, exist_ok=True)

    # Clustering algorithm name as configured by the pipeline.
    clustering_method = conf.clustering_method

    print("--- Evaluation Settings ---")
    print(f"Model: {model_path}")
    print(f"Clustering: {clustering_method}")
    print("Data Split: TEST")

    # -------------------------------------------------------------------------
    # 2. LOAD DATA
    # -------------------------------------------------------------------------
    test_dataset = LiDARPointCloudDataset(
        base_dir=data_dir,
        split='test',
        max_points=conf.max_points,
        sampling_method=conf.sampling_method
    )

    if len(test_dataset) == 0:
        print(f"No test data found. Please check the 'test' split under {data_dir}")
        return

    # Batch size 1 for accurate per-instance metric calculation
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

    # -------------------------------------------------------------------------
    # 3. LOAD MODEL
    # -------------------------------------------------------------------------
    # Check for the checkpoint before building the model so a missing file
    # does not cost a model construction / device transfer.
    if not os.path.exists(model_path):
        print(f"Checkpoint not found at {model_path}. Cannot evaluate.")
        return

    model = PointUNet(conf).to(conf.device)
    # NOTE: torch.load unpickles the file; only load checkpoints from a
    # trusted source.
    checkpoint = torch.load(model_path, map_location=conf.device)
    model.load_state_dict(checkpoint)
    model.eval()

    # -------------------------------------------------------------------------
    # 4. EVALUATION LOOP
    # -------------------------------------------------------------------------
    all_metrics = []
    complexity_buckets = defaultdict(list)  # key = complexity, value = list of (ARI, mIoU)

    print("Starting Inference...")

    with torch.no_grad():
        for i, (points, labels, _) in enumerate(tqdm(test_loader)):
            scene_name = f"scene_{i:03d}"

            gt_labels_full = labels.cpu().numpy()[0]

            # Padding points carry the label -1. Skip scenes without any real
            # point before spending time on the forward pass and clustering.
            valid_mask = gt_labels_full != -1
            if valid_mask.sum() == 0:
                continue

            # Forward Pass
            points = points.to(conf.device)
            embeddings = model(points)
            emb_sample = embeddings[0].cpu().numpy()

            # Clustering
            # NOTE(review): clustering runs on all points, padded ones included,
            # and padding is only filtered afterwards. Confirm padded
            # embeddings cannot distort the clusters of the real points.
            pred_labels_full = np.asarray(
                cluster_embeddings(emb_sample, method=clustering_method)
            )

            # -----------------------------------------------------------
            # 5. METRICS & COMPLEXITY
            # -----------------------------------------------------------
            gt_valid = gt_labels_full[valid_mask]
            pred_valid = pred_labels_full[valid_mask]
            points_valid = points[0].cpu().numpy()[valid_mask]  # For visualization

            # Calculate Metrics
            ari = adjusted_rand_score(gt_valid, pred_valid)
            nmi = normalized_mutual_info_score(gt_valid, pred_valid)
            miou = instance_mean_iou(gt_valid, pred_valid)

            # Determine Complexity
            num_gt_instances = len(np.unique(gt_valid))
            complexity = _complexity_level(num_gt_instances)

            # Store results
            result_entry = {
                "scene": scene_name,
                "ARI": float(ari),
                "NMI": float(nmi),
                "mIoU": float(miou),
                "gt_instances": int(num_gt_instances),
                "pred_instances": int(len(np.unique(pred_valid))),
                "complexity": complexity
            }
            all_metrics.append(result_entry)
            complexity_buckets[complexity].append((ari, miou))

            # -----------------------------------------------------------
            # 6. VISUALIZATION
            # -----------------------------------------------------------
            if i < max_visualizations:
                vis_path = os.path.join(vis_output_dir, f"{scene_name}_pred.png")
                save_cluster_plot(points_valid, pred_valid, vis_path)

                # Optional: Save .txt for CloudCompare
                txt_path = os.path.join(vis_output_dir, f"{scene_name}.txt")
                # Only the first three columns are written so the 5-field
                # format below still matches if points carry extra features.
                save_data = np.column_stack((points_valid[:, :3], pred_valid, gt_valid))
                np.savetxt(txt_path, save_data, fmt="%.6f %.6f %.6f %d %d", header="x y z pred gt")

    # -------------------------------------------------------------------------
    # 7. SUMMARY REPORT
    # -------------------------------------------------------------------------
    # Save JSON
    with open(os.path.join(vis_output_dir, "evaluation_summary.json"), "w") as f:
        json.dump(all_metrics, f, indent=2)

    # Guard against an empty result list (every scene skipped).
    avg_ari = np.mean([m['ARI'] for m in all_metrics]) if all_metrics else 0.0
    avg_miou = np.mean([m['mIoU'] for m in all_metrics]) if all_metrics else 0.0

    print("\n" + "="*50)
    print("           EVALUATION REPORT           ")
    print("="*50)
    print(f"Total Samples: {len(all_metrics)}")
    print(f"Overall Mean ARI:  {avg_ari:.4f}")
    print(f"Overall Mean mIoU: {avg_miou:.4f}")
    print("-" * 50)
    print("Performance by Roof Complexity:")

    for level in ["simple", "moderate", "complex"]:
        # .get() avoids inserting empty buckets into the defaultdict.
        scores = complexity_buckets.get(level)
        if scores:
            aris = [a for a, _ in scores]
            ious = [m for _, m in scores]
            print(f"  {level.capitalize().ljust(10)} ({len(scores)} scenes): ARI={np.mean(aris):.3f}, mIoU={np.mean(ious):.3f}")
        else:
            print(f"  {level.capitalize().ljust(10)} (0 scenes): N/A")

    print("="*50)
    print(f"Results saved to: {vis_output_dir}")

In [None]:
# Run the evaluation with its default arguments when this file is executed as
# the main module.
if __name__ == "__main__":
    evaluate()