# In [None]:
import torch
import gzip
import os
import numpy as np
import matplotlib.pyplot as plt
import re
import cv2
import random


# Method to extract camera number from file name
def extract_camera_number(file_name):
    """Return the first run of decimal digits in ``file_name`` as an int.

    Intended as a sort key for per-camera files: names that contain no digits
    map to ``float('inf')`` so they sort after every numbered name.
    """
    digits = re.search(r'\d+', file_name)
    return float('inf') if digits is None else int(digits.group())

# Function to load a compressed tensor
def load_compressed_tensor(file_path):
    """Load an object saved with ``torch.save`` inside a gzip-compressed file.

    The gzip stream is handed directly to ``torch.load``, which decompresses it
    on the fly. ``torch.load`` is pickle-based, so only use this on trusted files.
    """
    with gzip.open(file_path, 'rb') as compressed_stream:
        return torch.load(compressed_stream)

# Function to read poses from a file
def read_poses(file_path):
    """Read 2-D positions from a whitespace-separated pose file.

    Every non-blank line must contain exactly three numbers; the first two are
    kept as ``(x, y)`` and the third is parsed but discarded.

    Args:
        file_path: path to the text pose file.

    Returns:
        A list of ``(x, y)`` float tuples, one per non-blank line, in file order.

    Raises:
        ValueError: if a non-blank line does not hold exactly three numbers.
    """
    poses = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line_no, line in enumerate(f, start=1):
            fields = line.split()
            if not fields:
                # Blank lines (e.g. a trailing empty line) carry no pose; skip
                # them instead of failing the three-value unpack below.
                continue
            try:
                x, y, _ = map(float, fields)
            except ValueError as err:
                raise ValueError(
                    f"{file_path}:{line_no}: expected three numbers, got {line.strip()!r}"
                ) from err
            poses.append((x, y))
    return poses

# Function to plot a single map using the provided method
def plot_single_map(prob_map, occ, true_pose, H, W, resolution, title, ax, upscale=10):
    """Draw a 2-D score map on ``ax`` with markers for its maximum and the true pose.

    The map is upsampled by ``upscale`` with bilinear interpolation, flipped
    vertically, and drawn with ``origin='lower'``, so row 0 of the input ends
    up at the top of the panel.

    Args:
        prob_map: 2-D NumPy array of scores (cast to float32 before resizing).
        occ: unused; kept so existing call sites keep working.
        true_pose: ``(x, y)`` in the same units as ``resolution``, or ``None`` to
            skip the pose marker. ``x`` maps to columns and ``y`` to rows of
            ``prob_map``.
        H, W: unused on input; they are recomputed from the upsampled map.
        resolution: size of one ``prob_map`` cell in pose units.
        title: panel title.
        ax: matplotlib Axes to draw on.
        upscale: integer upsampling factor applied to both the map and the pose
            coordinates (default 10, the previously hard-coded value).
    """
    prob_map = prob_map.astype(np.float32)
    rows, cols = prob_map.shape[:2]
    # cv2.resize expects dsize as (width, height).
    upsampled = cv2.resize(prob_map, (cols * upscale, rows * upscale), interpolation=cv2.INTER_LINEAR)
    # Flip vertically; combined with origin='lower' this draws input row 0 at the top.
    prob_map = np.flipud(upsampled)
    H, W = prob_map.shape[:2]
    ax.imshow(prob_map, extent=[0, W, 0, H], cmap='viridis', alpha=0.6, origin='lower')

    # Circle one maximum, breaking ties at random. nanmax keeps a stray NaN from
    # producing an empty match list (which would crash random.choice); an
    # all-NaN map simply gets no max marker.
    if not np.isnan(prob_map).all():
        max_value = np.nanmax(prob_map)
        max_positions = list(zip(*np.where(prob_map == max_value)))  # (row, col) pairs
        max_y, max_x = random.choice(max_positions)
        ax.plot(max_x, max_y, 'o', markeredgecolor='red', markerfacecolor='none', markersize=10, label='Max Value', alpha=0.8)

    # `is not None` rather than truthiness so array-valued poses are accepted.
    if true_pose is not None:
        # Pose units -> input cells (/ resolution) -> upsampled pixels (* upscale).
        pixels_per_unit = upscale / resolution
        true_pose_x = true_pose[0] * pixels_per_unit
        # Input rows grow downward, but after the flip display y grows upward.
        true_pose_y = H - true_pose[1] * pixels_per_unit
        ax.plot(true_pose_x, true_pose_y, 'o', markeredgecolor='green', markerfacecolor='none', markersize=10, label='True Pose', alpha=0.8)

    ax.set_title(title, fontsize=10)
    ax.axis('off')
    ax.legend(fontsize=8)
    
def _keep_top_percentile(heatmap, percentile):
    """Return ``heatmap`` with every value below its ``percentile``-th percentile set to 0."""
    threshold = np.percentile(heatmap, percentile)
    return np.where(heatmap >= threshold, heatmap, 0)


def _max_over_dim2(prob_vol):
    """Collapse a tensor with ``max`` over dim 2 (the "O" dimension) and return a NumPy array."""
    return prob_vol.max(dim=2).values.numpy()


def _show_map(prob_map, true_pose, title, ax, resolution=0.1):
    """Plot one panel with ``plot_single_map``, passing the map's own shape as H/W."""
    plot_single_map(prob_map, None, true_pose, prob_map.shape[0], prob_map.shape[1], resolution, title, ax)


def plot_camera_heatmaps_with_combination(scene_dir, poses, num_cameras=5):
    """Plot depth, semantic and fused heatmaps for the first cameras of a scene.

    Files in ``scene_dir`` ending in ``pred_depth_prob_vol.pt.gz`` are sorted by
    ``extract_camera_number`` and the first ``num_cameras`` are used. For each one,
    the sibling file with the suffix swapped to ``pred_semantic_prob_vol.pt.gz``
    is loaded if it exists. Each camera gets two rows of six panels:

    * row 1: depth max-map, its top 1% and top 0.1%, then the volume-level fusion
      (max over dim 2 of 0.5*depth + 0.5*semantic) and its top 0.1% / 0.05%;
    * row 2: semantic max-map, its "top 2" and top-1 value maps, then the
      map-level fusion (mean of the two max-maps) and its top 0.1% / 0.05%.

    If the semantic file is missing, the fused and semantic panels are left blank.

    Args:
        scene_dir: directory containing the ``*.pt.gz`` probability volumes.
        poses: sequence of ``(x, y)`` poses; ``poses[i]`` is drawn on the i-th
            selected camera (``None`` once the list runs out).
        num_cameras: how many cameras (after sorting) to plot.

    Raises:
        ValueError: if no depth files are selected.
    """
    depth_suffix = "pred_depth_prob_vol.pt.gz"
    semantic_suffix = "pred_semantic_prob_vol.pt.gz"

    # Select cameras before loading so volumes that would be discarded are never read.
    depth_files = sorted(
        (name for name in os.listdir(scene_dir) if name.endswith(depth_suffix)),
        key=extract_camera_number,
    )[:num_cameras]
    if not depth_files:
        raise ValueError(
            f"no '*{depth_suffix}' files selected from {scene_dir!r} (num_cameras={num_cameras})"
        )

    _, axes = plt.subplots(len(depth_files) * 2, 6, figsize=(24, len(depth_files) * 12))

    for i, file_name in enumerate(depth_files):
        # NOTE(review): poses are matched by position in the sorted list, not by the
        # number in the file name; these agree only if cameras are numbered 0..N-1
        # without gaps — confirm.
        true_pose = poses[i] if i < len(poses) else None
        depth_row, semantic_row = axes[2 * i], axes[2 * i + 1]

        depth_prob_vol = load_compressed_tensor(os.path.join(scene_dir, file_name))
        depth_heatmap = _max_over_dim2(depth_prob_vol)

        _show_map(depth_heatmap, true_pose, f"Original Depth: {file_name}", depth_row[0])
        _show_map(_keep_top_percentile(depth_heatmap, 99), true_pose, f"top 1%: {file_name}", depth_row[1])
        _show_map(_keep_top_percentile(depth_heatmap, 99.9), true_pose, f"top 0.1%: {file_name}", depth_row[2])

        semantic_file_name = file_name.replace(depth_suffix, semantic_suffix)
        semantic_file_path = os.path.join(scene_dir, semantic_file_name)
        if not os.path.exists(semantic_file_path):
            # Nothing to fuse: blank the remaining panels rather than plotting
            # None or maps left over from a previous camera.
            for ax in (*depth_row[3:], *semantic_row):
                ax.axis('off')
            continue

        semantic_prob_vol = load_compressed_tensor(semantic_file_path)
        semantic_heatmap = _max_over_dim2(semantic_prob_vol)

        # Two fusions: average the volumes and then take the max ("before max"),
        # or average the two already-maxed maps ("after max").
        combined_before_max = _max_over_dim2(0.5 * depth_prob_vol + 0.5 * semantic_prob_vol)
        combined_after_max = 0.5 * semantic_heatmap + 0.5 * depth_heatmap

        _show_map(combined_before_max, true_pose, "combined_heatmap_before_max", depth_row[3])
        _show_map(_keep_top_percentile(combined_before_max, 99.9), true_pose, "Combined 0.1%", depth_row[4])
        _show_map(_keep_top_percentile(combined_before_max, 99.95), true_pose, "Combined 0.05%", depth_row[5])

        # Distinct semantic scores, largest first.
        ranked = np.unique(semantic_heatmap)[::-1]
        top_1_semantic = (
            np.where(semantic_heatmap == ranked[0], semantic_heatmap, 0)
            if ranked.size > 0 else np.zeros_like(semantic_heatmap)
        )
        # NOTE(review): labelled "Top 2" but keeps the 100 largest distinct values — confirm intent.
        top_2_semantic = (
            np.where(np.isin(semantic_heatmap, ranked[:100]), semantic_heatmap, 0)
            if ranked.size > 1 else np.zeros_like(semantic_heatmap)
        )

        _show_map(semantic_heatmap, true_pose, f"Original Semantic: {semantic_file_name}", semantic_row[0])
        _show_map(top_2_semantic, true_pose, f"Top 2 Semantic: {semantic_file_name}", semantic_row[1])
        _show_map(top_1_semantic, true_pose, f"Top 1 Semantic: {semantic_file_name}", semantic_row[2])
        _show_map(combined_after_max, true_pose, "combined_heatmap_after_max", semantic_row[3])
        _show_map(_keep_top_percentile(combined_after_max, 99.9), true_pose, "Combined 0.1%", semantic_row[4])
        _show_map(_keep_top_percentile(combined_after_max, 99.95), true_pose, "Combined 0.05%", semantic_row[5])

    plt.tight_layout()
    plt.show()

# Example usage: plot the first `num_cameras` cameras of one test scene.
scene_num = 0
num_cameras = 20  # Number of cameras to process
data_root = "/datadrive2/CRM.AI.Research/TeamFolders/Email/repo_yuval/FloorPlan/Semantic_Floor_plan_localization/data/test_data_set_full"
scene_dir = f"{data_root}/prob_vols/scene_{scene_num}"
poses_file = f"{data_root}/structured3d_perspective_full/scene_{scene_num}/poses.txt"

poses = read_poses(poses_file)
plot_camera_heatmaps_with_combination(scene_dir, poses, num_cameras=num_cameras)

