In [None]:
import os
import torch
import tqdm
import numpy as np
from multiprocessing import Pool
from data_utils.prob_vol_data_utils import ProbVolDataset
from utils.localization_utils import finalize_localization

# Hardcoded configuration for the depth/semantic probability-volume weight sweep.
# Every path lives under one data root, factored out so relocating the experiment
# means editing a single line.
_DATA_ROOT = (
    "/datadrive2/CRM.AI.Research/TeamFolders/Email/repo_yuval/FloorPlan/"
    "Semantic_Floor_plan_localization/data/test_data_set_full"
)

CONFIG = {
    # Per-scene data directory handed to ProbVolDataset as dataset_dir.
    "dataset_dir": f"{_DATA_ROOT}/structured3d_perspective_full",
    # Not read by the code in this cell; kept so existing consumers still find it.
    "prob_vol_save_dir": _DATA_ROOT,
    # Directory handed to ProbVolDataset as prob_vol_path.
    "prob_vol_path": f"{_DATA_ROOT}/prob_vols",
    # Forwarded to ProbVolDataset as L (presumably the number of context frames; TODO confirm).
    "L": 0,
    # Inclusive scene index range; scene names are built as scene_XXXXX (zero-padded to 5).
    "start_scene": 0,
    "end_scene": 3000,
    # (depth_weight, semantic_weight) pairs used to blend the two probability volumes.
    # generate_summary looks up [1.0, 0.0] and [0.5, 0.5] by value, so both must stay listed.
    "weight_combinations": [
        [1.0, 0.0],  # Only depth
        [0.9, 0.1],  # Weighted combination
        [0.5, 0.5],  # Equal combination
    ],
}

# Per-process dataset cache: process_data_idx builds it lazily on first use in each worker.
dataset = None


def calculate_sample_accs(prob_vol_pred_depth, prob_vol_pred_semantic, weight_combinations, ref_pose, device):
    """Return the 2-D localization error of one sample for every weight pair.

    Both probability volumes are cropped to their common (element-wise minimum)
    shape, blended as ``depth_weight * depth + semantic_weight * semantic`` and
    decoded with ``finalize_localization``. The predicted x/y is divided by 10
    and compared with ``ref_pose[:2]`` using the Euclidean norm.

    Args:
        prob_vol_pred_depth: Depth-based probability volume (tensor).
        prob_vol_pred_semantic: Semantic-based probability volume (tensor).
        weight_combinations: Iterable of ``(depth_weight, semantic_weight)`` pairs.
        ref_pose: Ground-truth pose tensor; only its first two entries are used.
        device: Device on which the predicted pose tensor is created.

    Returns:
        list[float]: One error per weight pair, in ``weight_combinations`` order.
    """
    # The crop depends only on the input shapes, so compute it once rather than
    # once per weight pair.
    common = tuple(
        slice(0, min(d, s))
        for d, s in zip(prob_vol_pred_depth.shape, prob_vol_pred_semantic.shape)
    )
    depth_vol = prob_vol_pred_depth[common]
    semantic_vol = prob_vol_pred_semantic[common]
    ref_xy = ref_pose[:2]

    sample_accs = []
    for depth_weight, semantic_weight in weight_combinations:
        combined_pred = depth_weight * depth_vol + semantic_weight * semantic_vol

        _, _, _, pose_pred = finalize_localization(combined_pred)
        pose_pred = torch.tensor(pose_pred, device=device, dtype=torch.float32)
        # Rescale x/y by 1/10 before comparing with ref_pose
        # (presumably map cells -> metres; TODO confirm against the map resolution).
        pose_pred[:2] = pose_pred[:2] / 10
        sample_accs.append(torch.norm(pose_pred[:2] - ref_xy, p=2).item())

    return sample_accs


def process_data_idx(data_idx):
    """Compute the per-weight localization errors for a single dataset sample.

    Used as the ``Pool.imap`` worker. The module-level ``dataset`` is built on
    the first call inside each process and reused for every later sample that
    process handles. All tensors are kept on the CPU.

    Args:
        data_idx: Index into the ``ProbVolDataset``.

    Returns:
        list[float]: Errors ordered like ``CONFIG["weight_combinations"]``.
    """
    global dataset
    cpu = torch.device("cpu")

    if dataset is None:
        first_scene, last_scene = CONFIG["start_scene"], CONFIG["end_scene"]
        names = [f"scene_{str(n).zfill(5)}" for n in range(first_scene, last_scene + 1)]
        dataset = ProbVolDataset(
            dataset_dir=CONFIG["dataset_dir"],
            scene_names=names,
            L=CONFIG["L"],
            prob_vol_path=CONFIG["prob_vol_path"],
            acc_only=False,
        )

    sample = dataset[data_idx]
    gt_pose = torch.tensor(sample["ref_pose"], device=cpu, dtype=torch.float32)
    depth_vol = sample['prob_vol_depth'].to(cpu)
    semantic_vol = sample['prob_vol_semantic'].to(cpu)

    return calculate_sample_accs(
        depth_vol, semantic_vol, CONFIG["weight_combinations"], gt_pose, cpu
    )


def _count_below(accs_by_weights, thresholds):
    """Return ``{label: [count per weight pair]}`` of errors strictly below each limit."""
    return {
        label: [sum(1 for acc in accs if acc < limit) for accs in accs_by_weights]
        for label, limit in thresholds
    }


def _count_better_below_1m(accs_by_weights, reference_accs, thresholds):
    """Count, per weight pair, samples under 1 m that beat ``reference_accs`` by more than each margin."""
    return {
        label: [
            sum(
                1
                for acc, ref_acc in zip(accs, reference_accs)
                if acc < 1.0 and acc < ref_acc - margin
            )
            for accs in accs_by_weights
        ]
        for label, margin in thresholds
    }


def _print_count_table(column_titles, counts, weight_combinations, total_images, col_width):
    """Print an aligned table: one row per weight pair, ``count (pct%)`` cells padded to col_width."""
    header = f"{'Weights':<30} " + " ".join(title.ljust(col_width) for title in column_titles)
    print(header)
    print("-" * len(header))
    for idx, (depth_weight, semantic_weight) in enumerate(weight_combinations):
        weights_str = f"Depth={depth_weight}, Semantic={semantic_weight}"
        cells = []
        for per_weight in counts.values():
            count = per_weight[idx]
            cells.append(f"{count} ({count / total_images * 100:.2f}%)".ljust(col_width))
        print(f"{weights_str:<30} " + " ".join(cells))


def generate_summary():
    """Evaluate every (depth, semantic) weight pair on the dataset and print summaries.

    Builds a ``ProbVolDataset`` for CONFIG's scene range and fans the samples out
    to a process pool via ``process_data_idx``. It then prints three tables:

    * samples whose error is below 0.1 m / 0.5 m / 1 m;
    * samples with error below 1 m that beat the (0.5, 0.5) blend by more than
      0.1 m / 0.5 m / 1 m;
    * the same comparison against depth-only (1.0, 0.0).

    All percentages are relative to the total number of samples.

    Raises:
        ValueError: if ``[0.5, 0.5]`` or ``[1.0, 0.0]`` is missing from
            ``CONFIG["weight_combinations"]``.
    """
    weight_combinations = CONFIG["weight_combinations"]
    thresholds = (("0.1m", 0.1), ("0.5m", 0.5), ("1m", 1.0))

    # Resolve the reference columns before the expensive sweep so a bad config fails fast.
    default_idx = weight_combinations.index([0.5, 0.5])
    ref_idx_1_0 = weight_combinations.index([1.0, 0.0])

    # Built here only to count samples; each Pool worker builds its own copy
    # lazily in process_data_idx.
    full_dataset = ProbVolDataset(
        dataset_dir=CONFIG["dataset_dir"],
        scene_names=[f"scene_{str(i).zfill(5)}" for i in range(CONFIG["start_scene"], CONFIG["end_scene"] + 1)],
        L=CONFIG["L"],
        prob_vol_path=CONFIG["prob_vol_path"],
        acc_only=False,
    )
    total_scenes = CONFIG["end_scene"] - CONFIG["start_scene"] + 1
    total_images = len(full_dataset)

    if total_images == 0:
        # Nothing to evaluate; the tables below would index empty results and divide by zero.
        print(f"Total number of scenes: {total_scenes}")
        print(f"Total number of images: {total_images}\n")
        print("No samples found; nothing to summarise.")
        return

    # os.cpu_count() may return None, and "cores - 5" is <= 0 on small machines,
    # which Pool rejects; always keep at least one worker.
    num_processes = max(1, (os.cpu_count() or 1) - 5)

    with Pool(processes=num_processes) as pool:
        results = list(
            tqdm.tqdm(
                pool.imap(process_data_idx, range(total_images)),
                total=total_images,
                desc="Processing Data",
            )
        )

    # results[i][w] is sample i's error for weight pair w; transpose to one tuple per pair.
    accs_by_weights = list(zip(*results))

    less_than_counts = _count_below(accs_by_weights, thresholds)
    below_1m_better_0_5 = _count_better_below_1m(accs_by_weights, accs_by_weights[default_idx], thresholds)
    below_1m_better_1_0 = _count_better_below_1m(accs_by_weights, accs_by_weights[ref_idx_1_0], thresholds)

    print(f"Total number of scenes: {total_scenes}")
    print(f"Total number of images: {total_images}\n")

    print("Summary of Results:\n")
    _print_count_table(
        ["Count < 0.1m", "Count < 0.5m", "Count < 1m"],
        less_than_counts,
        weight_combinations,
        total_images,
        col_width=25,
    )

    better_titles = ["Better by 0.1m", "Better by 0.5m", "Better by 1m"]

    print("\nResults: Below 1m and Better than (0.5, 0.5):\n")
    _print_count_table(better_titles, below_1m_better_0_5, weight_combinations, total_images, col_width=30)

    print("\nResults: Below 1m and Better than (1.0, 0.0):\n")
    _print_count_table(better_titles, below_1m_better_1_0, weight_combinations, total_images, col_width=30)


if __name__ == "__main__":
    # True both for `python file.py` and inside a notebook kernel, so the sweep runs in either.
    generate_summary()







  7%|▋         | 197/3001 [00:00<00:02, 1021.90it/s]

Scene scene_46 has floorplan_semantic.png with dimensions 3444x1454, skipping this scene.
Missing prediction files for scene scene_103, image 0. Skipping this scene.
Missing files for scene scene_144: ['/datadrive2/CRM.AI.Research/TeamFolders/Email/repo_yuval/FloorPlan/Semantic_Floor_plan_localization/data/test_data_set_full/structured3d_perspective_full/scene_144/poses.txt'], skipping this scene.


 17%|█▋        | 509/3001 [00:00<00:02, 856.33it/s] 

Scene scene_360 has floorplan_semantic.png with dimensions 1492x3715, skipping this scene.
Missing prediction files for scene scene_500, image 0. Skipping this scene.


 32%|███▏      | 950/3001 [00:00<00:01, 1052.41it/s]

Scene scene_765 has floorplan_semantic.png with dimensions 3240x2440, skipping this scene.
Scene scene_978 has floorplan_semantic.png with dimensions 3325x999, skipping this scene.


 38%|███▊      | 1152/3001 [00:01<00:02, 871.41it/s]

Scene scene_1045 has floorplan_semantic.png with dimensions 3070x1428, skipping this scene.
Scene scene_1078 has floorplan_semantic.png with dimensions 3357x2248, skipping this scene.
Missing files for scene scene_1155: ['/datadrive2/CRM.AI.Research/TeamFolders/Email/repo_yuval/FloorPlan/Semantic_Floor_plan_localization/data/test_data_set_full/structured3d_perspective_full/scene_1155/poses.txt'], skipping this scene.


 45%|████▌     | 1352/3001 [00:01<00:01, 892.74it/s]

Missing files for scene scene_1192: ['/datadrive2/CRM.AI.Research/TeamFolders/Email/repo_yuval/FloorPlan/Semantic_Floor_plan_localization/data/test_data_set_full/structured3d_perspective_full/scene_1192/poses.txt'], skipping this scene.


 61%|██████▏   | 1841/3001 [00:01<00:00, 1399.56it/s]

Missing files for scene scene_1600: ['/datadrive2/CRM.AI.Research/TeamFolders/Email/repo_yuval/FloorPlan/Semantic_Floor_plan_localization/data/test_data_set_full/structured3d_perspective_full/scene_1600/poses.txt'], skipping this scene.
Missing files for scene scene_1601: ['/datadrive2/CRM.AI.Research/TeamFolders/Email/repo_yuval/FloorPlan/Semantic_Floor_plan_localization/data/test_data_set_full/structured3d_perspective_full/scene_1601/poses.txt'], skipping this scene.
Missing files for scene scene_1602: ['/datadrive2/CRM.AI.Research/TeamFolders/Email/repo_yuval/FloorPlan/Semantic_Floor_plan_localization/data/test_data_set_full/structured3d_perspective_full/scene_1602/poses.txt'], skipping this scene.
Missing files for scene scene_1603: ['/datadrive2/CRM.AI.Research/TeamFolders/Email/repo_yuval/FloorPlan/Semantic_Floor_plan_localization/data/test_data_set_full/structured3d_perspective_full/scene_1603/poses.txt'], skipping this scene.
Missing files for scene scene_1604: ['/datadrive2/CR

 74%|███████▍  | 2228/3001 [00:02<00:00, 1016.50it/s]

Missing prediction files for scene scene_2026, image 0. Skipping this scene.
Scene scene_2051 has floorplan_semantic.png with dimensions 4016x1516, skipping this scene.
Missing prediction files for scene scene_2098, image 0. Skipping this scene.
Missing files for scene scene_2119: ['/datadrive2/CRM.AI.Research/TeamFolders/Email/repo_yuval/FloorPlan/Semantic_Floor_plan_localization/data/test_data_set_full/structured3d_perspective_full/scene_2119/poses.txt'], skipping this scene.
Scene scene_2208 has floorplan_semantic.png with dimensions 2028x3035, skipping this scene.


 85%|████████▌ | 2558/3001 [00:02<00:00, 978.22it/s] 

Missing files for scene scene_2401: ['/datadrive2/CRM.AI.Research/TeamFolders/Email/repo_yuval/FloorPlan/Semantic_Floor_plan_localization/data/test_data_set_full/structured3d_perspective_full/scene_2401/poses.txt'], skipping this scene.


100%|██████████| 3001/3001 [00:03<00:00, 976.95it/s] 

Missing prediction files for scene scene_2800, image 0. Skipping this scene.
Number of scenes after filtering: 2781





In [None]:
import argparse
import os
import numpy as np
import torch
import tqdm
import yaml
from attrdict import AttrDict
import cv2
import torch.nn as nn
from PIL import Image

# Import your models
from modules.mono.depth_net_pl import depth_net_pl
from modules.semantic.semantic_net_pl import semantic_net_pl

# Import your dataset
from data_utils.data_utils import GridSeqDataset
import torch.nn.functional as F

# Import helper functions from utils
from utils.localization_utils import (
    get_ray_from_depth,
    get_ray_from_semantics,
    localize,
    finalize_localization,
)
from utils.data_loader_helper import load_scene_data

# Torch DataLoader
from torch.utils.data import DataLoader


def _pose_error(prob_vol, ref_xy, device):
    """Return the x/y Euclidean error between the pose decoded from ``prob_vol`` and ``ref_xy``."""
    _, _, _, pose_pred = finalize_localization(prob_vol)
    pose_pred = torch.tensor(pose_pred, device=device, dtype=torch.float32)
    # Rescale x/y by 1/10 before comparing with the ground truth
    # (presumably map cells -> metres; TODO confirm against the map resolution).
    pose_pred[:2] = pose_pred[:2] / 10
    return torch.norm(pose_pred[:2] - ref_xy, p=2).item()


def _print_cases(title, cases):
    """Print the stored rays and per-image accuracies of one outcome bucket."""
    print(f"\n{title} Accuracy Cases:")
    print("All Rays Before Get Ray:")
    for rays in cases["rays_before"]:
        print(rays)

    print("\nAll Rays After Get Ray:")
    for rays in cases["rays_after"]:
        print(rays)

    print(f"\n{title} Accuracy Details:")
    for idx, acc_depth, acc_combined in cases["accuracies"]:
        print(f"Image {idx}: Acc Depth = {acc_depth}, Acc Combined = {acc_combined}")


def evaluate_combined_model(
    depth_net,
    semantic_net,
    desdfs,
    semantics,
    test_set,
    gt_poses,
    maps,
    device,
    valid_scene_names,
    config,
):
    """Compare depth-only localization with an equal depth/semantic blend, per image.

    For every sample whose scene is in ``valid_scene_names``:

    1. Predict depth rays and sample semantic rays from the networks.
    2. Localize each set of rays against the scene's ``desdf`` / ``semantic`` maps.
    3. Compute the x/y error of the depth-only pose and of the 0.5/0.5 blended pose.

    Samples where the blend is no worse (ties included) go to the "improved"
    bucket, and the rest to "degraded". Both buckets are printed at the end.

    Args:
        depth_net, semantic_net: Loaded networks; only their ``encoder`` is called.
        desdfs, semantics: Per-scene dicts whose ``"desdf"`` entry is passed to ``localize``.
        test_set: Dataset exposing ``scene_start_idx`` and ``scene_names``.
        gt_poses: Per-scene ground-truth pose arrays.
        maps: Unused here; kept for interface compatibility.
        device: Device used for network inference and error computation.
        valid_scene_names: Scenes to evaluate; all others are skipped.
        config: Must provide ``"L"``, ``"V"`` and ``"F_W"``.
    """
    improved_cases = {"rays_before": [], "rays_after": [], "accuracies": []}
    degraded_cases = {"rays_before": [], "rays_after": [], "accuracies": []}

    # Loop-invariant: scene boundaries and config values do not change per sample.
    scene_starts = np.array(test_set.scene_start_idx)
    num_context = config["L"]
    num_v, f_w = config["V"], config["F_W"]

    for data_idx in tqdm.tqdm(range(len(test_set))):
        scene_idx = np.sum(data_idx >= scene_starts) - 1
        scene = test_set.scene_names[scene_idx]
        if "floor" not in scene:
            # Drop zero padding (e.g. "scene_00046" -> "scene_46") to match the keys used
            # by valid_scene_names/desdfs/semantics (presumably set by load_scene_data).
            scene = f"scene_{int(scene.split('_')[1])}"

        if scene not in valid_scene_names:
            continue

        # Load the sample only once we know the scene is evaluated.
        data = test_set[data_idx]
        idx_within_scene = data_idx - test_set.scene_start_idx[scene_idx]

        desdf = desdfs[scene]
        semantic = semantics[scene]
        # Picks row idx*(L+1)+L, i.e. presumably the last frame of this sample's sequence.
        ref_pose_map = gt_poses[scene][idx_within_scene * (num_context + 1) + num_context, :]
        ref_xy = torch.tensor(ref_pose_map[:2], device=device)

        # Pure inference: no autograd graph is needed for either network.
        with torch.no_grad():
            ref_img_torch = torch.tensor(data["ref_img"], device=device).unsqueeze(0)

            # Predict depth
            pred_depths, _, _ = depth_net.encoder(ref_img_torch, None)
            pred_depths = pred_depths.squeeze(0).detach().cpu().numpy()
            pred_rays_depth = get_ray_from_depth(pred_depths, V=num_v, F_W=f_w)

            # Predict semantics: draw one class index per row of the probability matrix.
            _, _, prob = semantic_net.encoder(ref_img_torch, None)
            sampled_indices = torch.multinomial(prob.squeeze(dim=0), num_samples=1, replacement=True)
            sampled_indices_np = sampled_indices.squeeze(dim=1).cpu().numpy()

        # Snapshot the rays before and after conversion for the final report.
        rays_before_get_ray = sampled_indices_np.copy()
        pred_rays_semantic = get_ray_from_semantics(sampled_indices_np, V=num_v, F_W=f_w)
        rays_after_get_ray = pred_rays_semantic.copy()

        # Depth localization
        prob_vol_pred_depth, _, _, _ = localize(
            torch.tensor(desdf["desdf"]),
            torch.tensor(pred_rays_depth, device="cpu"),
            return_np=False,
        )
        acc_depth = _pose_error(prob_vol_pred_depth, ref_xy, device)

        # Combined localization
        prob_vol_pred_semantic, _, _, _ = localize(
            torch.tensor(semantic["desdf"]),
            torch.tensor(pred_rays_semantic, device="cpu"),
            return_np=False,
        )
        # Crop both volumes to their common shape so mismatched map grids cannot
        # fail to broadcast; a no-op when the shapes already agree.
        common = tuple(
            slice(0, min(d, s))
            for d, s in zip(prob_vol_pred_depth.shape, prob_vol_pred_semantic.shape)
        )
        combined_prob_vol_pred = 0.5 * prob_vol_pred_depth[common] + 0.5 * prob_vol_pred_semantic[common]
        acc_combined = _pose_error(combined_prob_vol_pred, ref_xy, device)

        # Ties count as "improved". NaN errors satisfy neither comparison and are skipped.
        if acc_combined <= acc_depth:
            bucket = improved_cases
        elif acc_combined > acc_depth:
            bucket = degraded_cases
        else:
            continue
        bucket["rays_before"].append(rays_before_get_ray)
        bucket["rays_after"].append(rays_after_get_ray)
        bucket["accuracies"].append((data_idx, acc_depth, acc_combined))

    print("\nEvaluation Complete.")
    _print_cases("Improved", improved_cases)
    _print_cases("Degraded", degraded_cases)


def evaluate_observation(config, device):
    """Load the split, dataset, networks and scene data, then run ``evaluate_combined_model``.

    Args:
        config: Dict containing:
            * paths: ``dataset_dir``, ``desdf_path``, ``log_dir_depth``,
              ``log_dir_semantic``, ``split_file``;
            * dataset/model settings: ``L``, ``D``, ``d_min``, ``d_max``,
              ``d_hyp``, ``num_classes``, ``num_of_scenes``;
            * ray settings used downstream: ``V``, ``F_W``.
        device: Device the networks are moved to and inference runs on.
    """
    dataset_dir = config["dataset_dir"]
    desdf_path = config["desdf_path"]
    log_dir_depth = config["log_dir_depth"]
    log_dir_semantic = config["log_dir_semantic"]
    split_file = config["split_file"]

    # Load dataset
    with open(split_file, "r", encoding="utf-8") as f:
        split = AttrDict(yaml.safe_load(f))

    # NOTE(review): scenes are taken from split.train although the result is named
    # test_set -- confirm this is intended rather than split.test.
    test_set = GridSeqDataset(
        dataset_dir,
        split.train[:config["num_of_scenes"]],
        L=config["L"],
    )

    depth_net = depth_net_pl.load_from_checkpoint(
        checkpoint_path=log_dir_depth,
        d_min=config["d_min"],
        d_max=config["d_max"],
        d_hyp=config["d_hyp"],
        D=config["D"],
    ).to(device)

    semantic_net = semantic_net_pl.load_from_checkpoint(
        checkpoint_path=log_dir_semantic,
        num_classes=config["num_classes"],
    ).to(device)

    # load_from_checkpoint does not switch to inference mode; without eval() any
    # dropout / batch-norm layers would run with training behaviour.
    depth_net.eval()
    semantic_net.eval()

    desdfs, semantics, maps, gt_poses, valid_scene_names = load_scene_data(
        test_set, dataset_dir, desdf_path
    )

    evaluate_combined_model(
        depth_net,
        semantic_net,
        desdfs,
        semantics,
        test_set,
        gt_poses,
        maps,
        device,
        valid_scene_names,
        config=config,
    )


def main():
    """Build the hard-coded evaluation config, pick a device and run ``evaluate_observation``."""
    # Shared prefixes, factored out so the setup can be relocated by editing one line.
    repo_root = (
        "/datadrive2/CRM.AI.Research/TeamFolders/Email/repo_yuval/FloorPlan/"
        "Semantic_Floor_plan_localization"
    )
    data_root = f"{repo_root}/data/test_data_set_full"
    weights_root = f"{repo_root}/modules/Final_wights"  # directory name is spelled this way on disk

    config = {
        "dataset_dir": f"{data_root}/structured3d_perspective_full",
        "desdf_path": f"{data_root}/desdf",
        "log_dir_depth": f"{weights_root}/depth/final_depth_model_checkpoint.ckpt",
        "log_dir_semantic": f"{weights_root}/semantic/final_semantic_model_checkpoint.ckpt",
        "split_file": f"{data_root}/structured3d_perspective_full/split.yaml",
        # Forwarded to GridSeqDataset as L and used to index gt_poses.
        "L": 0,
        # Forwarded to depth_net_pl.load_from_checkpoint.
        "D": 128,
        "d_min": 0.1,
        "d_max": 15.0,
        "d_hyp": -0.2,
        # Forwarded to get_ray_from_depth / get_ray_from_semantics.
        "F_W": 0.59587643422,
        "V": 7,
        # Forwarded to semantic_net_pl.load_from_checkpoint.
        "num_classes": 4,
        # Number of scenes taken from the start of the split.
        "num_of_scenes": 1,
    }

    # Set device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # Run evaluation
    evaluate_observation(config, device)


if __name__ == "__main__":
    # True both for `python file.py` and inside a notebook kernel, so evaluation runs in either.
    main()
