# 3D mask proposal

In [None]:
import numpy as np
import torch
from scripts.utils import load_ply
from pytorch3d.structures import Meshes,Pointclouds
from pytorch3d.renderer import Textures
from pytorch3d.io import load_obj
from point_sam.build_model import build_point_sam
from sklearn.neighbors import NearestNeighbors
from tqdm import tqdm
import pytorch3d.ops as ops
from pytorch3d.ops import sample_farthest_points
from utils.nms import apply_pointwise_nms,visualize_point_clouds_with_masks
from mask_proposal import  mask_proposal,mask_proposal_v2,batch_mask_proposal
from utils.render import render_all_angles_pc,render_single_view,project_3d_to_2d
import glob
from point_sam.build_model import build_point_sam
import numpy as np
import torch
from scripts.utils import load_ply
from pytorch3d.structures import Meshes
from pytorch3d.renderer import Textures
from pytorch3d.io import load_obj
# Use glob to access all files in the directory
from transformers import AutoProcessor, AutoModelForCausalLM
import random
import os
from utils.inference_florence import run_florence2
from PIL import Image
import cv2
import supervision as sv
import open3d as o3d

# Run on the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def load_prediction_data(filename):
    """Parse a whitespace-separated prediction summary file.

    Every usable line holds exactly three fields: a mask file name, an
    integer class prediction and a float confidence. Lines with any other
    number of fields (blank lines included) are skipped.

    Returns:
        A list of dicts with keys 'file', 'prediction' and 'confidence',
        in the order the lines appear in the file.
    """
    records = []
    with open(filename, 'r') as handle:
        for raw_line in handle:
            fields = raw_line.strip().split()
            if len(fields) != 3:
                continue
            mask_file, label, score = fields
            records.append({
                'file': mask_file,
                'prediction': int(label),
                'confidence': float(score),
            })
    return records

def normalize_point_cloud(xyz):
    """Center a point cloud on its centroid and scale it into the unit sphere.

    Args:
        xyz: Array of point coordinates, one point per row.

    Returns:
        The centered points divided by the largest distance of any point
        from the centroid. When every point coincides with the centroid
        (for example a single-point cloud) that distance is zero; the
        centered, all-zero coordinates are then returned unscaled instead
        of the NaNs a 0/0 division would produce.
    """
    centroid = np.mean(xyz, axis=0)
    xyz_centered = xyz - centroid
    furthest_distance = np.max(np.sqrt(np.sum(xyz_centered**2, axis=1)))
    if furthest_distance == 0:
        # Degenerate cloud: nothing to scale, and dividing would yield NaN.
        return xyz_centered
    return xyz_centered / furthest_distance

def process_scene(scene_id, scene_path, mask_info_path, model, output_dir, mask_infos):
    """Run 3D mask proposal and rendering for every instance of one scene.

    For each entry of ``mask_infos`` the instance mask is loaded, the masked
    points are normalized, ``mask_proposal`` is run on them and the instance is
    rendered with ``render_all_angles_pc``. The proposals, depth, screen
    coordinates and normalized points are saved under
    ``<output_dir>/<idx>/ins_info``.

    Args:
        scene_id: Scene identifier (not used inside this function).
        scene_path: Path of the scene point cloud read with Open3D.
        mask_info_path: Path of the scene summary file; mask files are
            resolved relative to its directory.
        model: Model handed to ``mask_proposal``.
        output_dir: Directory that receives one sub-directory per instance.
        mask_infos: Iterable of dicts with at least a ``'file'`` key, as
            produced by ``load_prediction_data``.

    Returns:
        ``(top_k_masks, img_dir, pc_depth, screen_coords, num_views, cameras,
        obj_xyz)`` of the LAST processed instance. When ``mask_infos`` is
        empty every element is ``None`` (previously this raised
        ``UnboundLocalError``).

    Note:
        Reads the module-level ``NUM_PROMPTS``, which must be defined before
        this function is called.
    """
    pcd = o3d.io.read_point_cloud(scene_path)
    xyz = np.asarray(pcd.points)
    rgb = np.asarray(pcd.colors) * 255

    # Mask file names in the summary are relative to the summary's directory.
    mask_dir = os.path.dirname(mask_info_path)

    # Pre-bind the results so that a scene without instances still returns.
    top_k_masks = img_dir = pc_depth = screen_coords = None
    num_views = cameras = obj_xyz = None

    for idx, mask_info in enumerate(mask_infos):
        mask = np.loadtxt(os.path.join(mask_dir, mask_info['file'])).astype(bool)
        obj_xyz = normalize_point_cloud(xyz[mask])
        obj_rgb = rgb[mask]

        obj_xyz_tensor = torch.tensor(obj_xyz).to(device).float()
        obj_rgb_tensor = torch.tensor(obj_rgb).to(device).float()

        # The Pointclouds object takes the un-batched tensors; the proposal
        # call below receives them with a leading batch dimension added.
        obj_pcd = Pointclouds(points=[obj_xyz_tensor], features=[obj_rgb_tensor])
        obj_xyz_tensor = obj_xyz_tensor.unsqueeze(0)
        obj_rgb_tensor = obj_rgb_tensor.unsqueeze(0)
        top_k_masks, _, _ = mask_proposal(obj_xyz_tensor, obj_rgb_tensor, NUM_PROMPTS, model)

        instance_dir = os.path.join(output_dir, str(idx))
        img_dir, pc_depth, screen_coords, num_views, cameras = render_all_angles_pc(obj_pcd, instance_dir, device)

        # Per-instance cache directory read back by the later notebook cells.
        instance_info_dir = os.path.join(instance_dir, 'ins_info')
        os.makedirs(instance_info_dir, exist_ok=True)

        # Save top_k_masks, pc_depth, and screen_coords as pt files
        torch.save(top_k_masks, os.path.join(instance_info_dir, 'top_k_masks.pt'))
        torch.save(pc_depth, os.path.join(instance_info_dir, 'pc_depth.pt'))
        torch.save(screen_coords, os.path.join(instance_info_dir, 'screen_coords.pt'))

        # Save obj_xyz as numpy array
        np.save(os.path.join(instance_info_dir, 'obj_xyz.npy'), obj_xyz)

        print(f"Saved instance information and point cloud data to {instance_info_dir}")
    return top_k_masks, img_dir, pc_depth, screen_coords, num_views, cameras, obj_xyz

# Main execution
if __name__ == "__main__":
    # Proposal hyper-parameters. NUM_PROMPTS is read as a global by
    # process_scene; the other three are not referenced in the visible code.
    NUM_PROMPTS = 1024
    NUM_MASKS_PER_PROMPT = 3
    NMS_THRESHOLD = 0.3
    TOP_K_PROPOSALS = 250

    
    # Hard-coded locations on the author's machine.
    dataset_dir = '/home/wan/Datasets/Test_scene/part_valid'
    project_path = '/home/wan/Workplace-why/PartScene'
    # Directory holding the per-scene instance masks and '<scene>_summary.txt'.
    final_masks_save_dir = os.path.join(project_path, 'part_scene_results')
    # Directory (relative to project_path) receiving the per-instance by-products.
    by_product_save_dir = 'part_scene_saved'
    ckpt_path = os.path.join(project_path, "checkpoints/model.safetensors")

    # NOTE(review): 512 and 64 are passed positionally to build_point_sam;
    # their meaning is defined by that function and is not visible here.
    model = build_point_sam(ckpt_path, 512, 64).to(device)
    print('Model built successfully')

    # Every entry of dataset_dir is treated as a scene id.
    for scene_id in os.listdir(dataset_dir):
        print(scene_id)
        scene_path = os.path.join(dataset_dir, scene_id, f'points_{scene_id}.ply')
        mask_info_path = os.path.join(final_masks_save_dir, scene_id, f'{scene_id}_summary.txt')
        output_dir = os.path.join(project_path, by_product_save_dir, scene_id)
        os.makedirs(output_dir, exist_ok=True)
        mask_infos = load_prediction_data(mask_info_path)

        # process_scene writes its results to disk; the values unpacked here
        # are those of the last instance of the scene only.
        top_k_masks,img_dir, pc_depth, screen_coords, num_views, cameras,obj_xyz  = process_scene(scene_id, scene_path, mask_info_path, model, output_dir,mask_infos)


# Segment 2d

In [None]:
import numpy as np
import torch
from scripts.utils import load_ply
from pytorch3d.structures import Meshes,Pointclouds
from pytorch3d.renderer import Textures
from pytorch3d.io import load_obj
from point_sam.build_model import build_point_sam
from sklearn.neighbors import NearestNeighbors
from tqdm import tqdm
import glob
import numpy as np
import torch
# Use glob to access all files in the directory
from transformers import AutoProcessor, AutoModelForCausalLM
from third_party.torkit3d.config.config import * 
import random
import os
from utils.inference_florence import run_florence2
from PIL import Image
import cv2
import supervision as sv
from utils.utils_3d import * 
from third_party.Ground_SAM.sam2.build_sam import build_sam2
from third_party.Ground_SAM.sam2.sam2_image_predictor import SAM2ImagePredictor
from third_party.Ground_SAM.mask_proposal_2d import segment2d
import open3d as o3d
import json 


def load_instance_info(instance_info_dir):
    """
    Load the cached per-instance artifacts stored in a directory.

    Args:
    instance_info_dir (str): Directory containing 'top_k_masks.pt',
        'pc_depth.pt', 'screen_coords.pt' and 'obj_xyz.npy'.

    Returns:
    tuple: (top_k_masks, pc_depth, screen_coords, obj_xyz). The first three
        are whatever torch.load returns for the '.pt' files; obj_xyz is the
        numpy array read from 'obj_xyz.npy'.
    """
    def _in_dir(file_name):
        # Resolve a file name inside the instance directory.
        return os.path.join(instance_info_dir, file_name)

    # The three torch files are read in the same order as they are returned.
    tensor_files = ('top_k_masks.pt', 'pc_depth.pt', 'screen_coords.pt')
    top_k_masks, pc_depth, screen_coords = [
        torch.load(_in_dir(file_name)) for file_name in tensor_files
    ]

    obj_xyz = np.load(_in_dir('obj_xyz.npy'))

    return top_k_masks, pc_depth, screen_coords, obj_xyz



import os 


       
# Model identifiers / checkpoint locations for the 2D segmentation stage.
FLORENCE2_MODEL_ID = "microsoft/Florence-2-large"
SAM2_CHECKPOINT = "/home/wan/Workplace-why/PartScene/third_party/Ground_SAM/checkpoints/sam2_hiera_large.pt"
SAM2_CONFIG = "sam2_hiera_l.yaml"

# The CUDA-specific setup is guarded so the cell stays consistent with the
# CPU fallback chosen for `device` below: torch.cuda.get_device_properties(0)
# raises when no CUDA device is present.
if torch.cuda.is_available():
    # Enter bfloat16 autocast for the rest of the session. The context manager
    # is entered manually and intentionally never exited.
    torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()

    if torch.cuda.get_device_properties(0).major >= 8:
        # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
# Invert cls_dict so the integer prediction stored in the summary file can be
# mapped back to its key.
# NOTE(review): cls_dict is not defined in this file; presumably it comes from
# one of the star imports — confirm.
reversed_dict = {value: key for key, value in cls_dict.items()}

# build florence-2
florence2_model = AutoModelForCausalLM.from_pretrained(FLORENCE2_MODEL_ID, trust_remote_code=True, torch_dtype='auto').eval().to(device)
florence2_processor = AutoProcessor.from_pretrained(FLORENCE2_MODEL_ID, trust_remote_code=True)
# build sam 2
sam2_model = build_sam2(SAM2_CONFIG, SAM2_CHECKPOINT, device=device)
sam2_predictor = SAM2ImagePredictor(sam2_model)






# Hard-coded locations on the author's machine (same values as in the 3D
# mask proposal cell).
dataset_dir = '/home/wan/Datasets/Test_scene/part_valid'
project_path = '/home/wan/Workplace-why/PartScene'
final_masks_save_dir = os.path.join(project_path, 'part_scene_results')
by_product_save_dir = 'part_scene_saved'
ckpt_path = os.path.join(project_path, "checkpoints/model.safetensors")


# For every instance of every scene: run segment2d with a text prompt derived
# from the predicted class, and store the result as 'sem_seg.pt' next to the
# artifacts written by the 3D mask proposal cell.
# Relies on load_prediction_data defined in the earlier notebook cell.
for scene_id in tqdm(os.listdir(dataset_dir)):
        print(scene_id)
        scene_path = os.path.join(dataset_dir, scene_id, f'points_{scene_id}.ply')
        mask_result_path = os.path.join(final_masks_save_dir, scene_id)
        output_scene_dir = os.path.join(project_path, by_product_save_dir, scene_id)
        mask_infos = load_prediction_data( f'{mask_result_path}/{scene_id}_summary.txt')
        for idx,mask in enumerate(mask_infos):
                # Integer class id predicted for this instance.
                ins_num = mask['prediction']
                # Instance directories are named by the instance's position in the summary file.
                instance_dir = os.path.join(output_scene_dir,str(idx))
                top_k_masks,pc_depth,screen_coords,obj_xyz = load_instance_info(f'{instance_dir}/ins_info')
                # Map the class id back through reversed_dict, then look up the text prompt.
                # NOTE(review): cls_part_dict is not defined in this file; presumably from a star import — confirm.
                ins =reversed_dict[ins_num]
                prompt = cls_part_dict[ins]
                # file_paths, points_3d and visible_pts_list are assigned but not used below.
                file_paths = glob.glob(os.path.join(f'{instance_dir}/rendered_img', '*'))
                points_3d =[]
                visible_pts_list = []
                # The first dimension of pc_depth is taken as the number of rendered views.
                num_views = pc_depth.shape[0]
                task_prompt = "<OPEN_VOCABULARY_DETECTION>"
                text_input = prompt
                # torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__(
                torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

                result_dict = segment2d(num_views = num_views ,save_dir=instance_dir,text_input=text_input,task_prompt=task_prompt,florence2_model=florence2_model,florence2_processor=florence2_processor,sam2_predictor= sam2_predictor)
                # Cache the 2D segmentation result for the mask classification cell.
                torch.save(result_dict, os.path.join(f'{instance_dir}/ins_info', 'sem_seg.pt'))
                



# Mask classification

In [10]:
from utils.utils_3d import * 
import torch
import numpy as np
import open3d as o3d 
from matplotlib import pyplot as plt 
import os 
import json
import glob
from utils.process import *
import sys
import glob
import shutil
import glob
import re
from third_party.torkit3d.config.config import * 
from tqdm import tqdm

In [11]:
def save_mask_results(scene_id, part_mask_after_process, scene_pcd, ins_mask, ins, output_dir, part_label_v2):
    """Write the part masks of one instance and append them to the scene's part summary.

    Each part mask is mapped to scene level with ``map_parts_to_scene`` and
    saved as ``<output_dir>/<scene_id>/pred_part_mask/NNN.txt``. Numbering
    continues after the highest purely numeric ``.txt`` file already in that
    directory. One line per part, ``pred_part_mask/NNN.txt <label_num>
    <score>``, is appended to ``<output_dir>/<scene_id>/<scene_id>_part_summary.txt``.

    Args:
        scene_id: Scene identifier, used for the directory and file names.
        part_mask_after_process: Mapping from label key to a dict with the
            entries ``'mask'`` and ``'score'``.
        scene_pcd: Scene point cloud forwarded to ``map_parts_to_scene``.
        ins_mask: Instance mask forwarded to ``map_parts_to_scene``.
        ins: Instance class name; its last space-separated word, lower-cased,
            prefixes every part label.
        output_dir: Root directory of the results.
        part_label_v2: Mapping from ``'<base_cls>_<label_key>'`` to the
            label number written to the summary.

    Raises:
        KeyError: If a part label is missing from ``part_label_v2``.
        SystemExit: Any failure while creating the directory or writing files
            is printed and ends the process via ``sys.exit(1)``.
    """
    scene_dir = os.path.join(output_dir, f'{scene_id}')
    pred_part_mask_dir = os.path.join(scene_dir, 'pred_part_mask')
    try:
        # Create the output directory here; previously the try block only
        # joined path strings, so saving failed when the caller had not
        # created the directory beforehand.
        os.makedirs(pred_part_mask_dir, exist_ok=True)
    except PermissionError:
        print(f"Error: Permission denied when trying to create directory: {pred_part_mask_dir}")
        print("Please check that you have write permissions for the output directory.")
        print(f"Current working directory: {os.getcwd()}")
        print(f"Output directory path: {output_dir}")
        sys.exit(1)
    except Exception as e:
        print(f"An unexpected error occurred while creating directories: {str(e)}")
        sys.exit(1)

    summary_data = []
    base_cls = ins.split(' ')[-1].lower()

    # Continue numbering after the highest purely numeric file already present.
    existing_files = glob.glob(os.path.join(pred_part_mask_dir, '*.txt'))
    numeric_files = [f for f in existing_files if re.match(r'^\d+\.txt$', os.path.basename(f))]
    if numeric_files:
        highest_idx = max(int(os.path.splitext(os.path.basename(f))[0]) for f in numeric_files)
        start_idx = highest_idx + 1
    else:
        start_idx = 0

    for idx, (label_key, data) in enumerate(part_mask_after_process.items(), start=start_idx):
        part_mask = data['mask']
        part_score = data['score']

        # Map parts to scene
        scene_part_mask = map_parts_to_scene(scene_pcd, part_mask, ins_mask)

        # Save individual mask file
        mask_filename = f'{idx:03d}.txt'
        mask_filepath = os.path.join(pred_part_mask_dir, mask_filename)

        # Convert scene_part_mask to integer numpy array
        if isinstance(scene_part_mask, np.ndarray):
            mask_to_save = scene_part_mask.astype(int)
        else:
            mask_to_save = np.array(scene_part_mask, dtype=int)

        try:
            # Save the mask
            np.savetxt(mask_filepath, mask_to_save, fmt='%d')
        except PermissionError:
            print(f"Error: Permission denied when trying to save file: {mask_filepath}")
            print("Please check that you have write permissions for the output directory.")
            sys.exit(1)
        except Exception as e:
            print(f"An unexpected error occurred while saving mask file: {str(e)}")
            sys.exit(1)

        # Get part label and number
        part_label = f'{base_cls}_{label_key}'

        # Plain indexing: an unknown label raises KeyError rather than
        # writing a bogus entry to the summary.
        part_label_num = part_label_v2[part_label]
        # Append to summary data
        summary_data.append(f"pred_part_mask/{mask_filename} {part_label_num} {part_score:.4f}")

    # Save part summary file at the same level as regular summary
    part_summary_filepath = os.path.join(scene_dir, f'{scene_id}_part_summary.txt')

    try:
        # Append one line per part. With no parts the file is still created
        # but nothing is written, so no stray blank line ends up in it.
        with open(part_summary_filepath, 'a') as f:
            f.writelines(f"{line}\n" for line in summary_data)
        print(f"Part summary appended in {part_summary_filepath}")
    except PermissionError:
        print(f"Error: Permission denied when trying to save file: {part_summary_filepath}")
        print("Please check that you have write permissions for the output directory.")
        sys.exit(1)
    except Exception as e:
        print(f"An unexpected error occurred while saving part summary file: {str(e)}")
        sys.exit(1)

In [12]:
# Hard-coded locations on the author's machine. output_dir and
# final_masks_save_dir resolve to the same directory.
dataset_dir = '/home/wan/Datasets/Test_scene/part_valid'
project_path = '/home/wan/Workplace-why/PartScene'
output_dir = '/home/wan/Workplace-why/PartScene/part_scene_results'
final_masks_save_dir = os.path.join(project_path, 'part_scene_results')
by_product_save_dir = 'part_scene_saved'
ckpt_path = os.path.join(project_path, "checkpoints/model.safetensors")
# Invert cls_dict so the integer prediction can be mapped back to its key.
reversed_dict = {value: key for key, value in cls_dict.items()}




# For every scene: reset the part-mask output, then for every instance combine
# the cached 3D proposals with the cached 2D segmentation to obtain labelled
# part masks.
# Relies on load_prediction_data, load_instance_info and project_3d_to_2d from
# the earlier notebook cells.
# NOTE(review): 'part_scene_saved' is listed relative to the current working
# directory, while every other path is built from project_path.
for scene_id in tqdm(os.listdir('part_scene_saved')[:]):
        print(scene_id)
        # Scene '0055' is explicitly skipped; the reason is not stated in the code.
        if scene_id == '0055':
                continue
        scene_path = os.path.join(dataset_dir, scene_id, f'points_{scene_id}.ply')
        mask_result_path = os.path.join(final_masks_save_dir, scene_id)
        output_scene_dir = os.path.join(project_path, by_product_save_dir, scene_id)
        mask_infos = load_prediction_data( f'{mask_result_path}/{scene_id}_summary.txt')
        pred_part_mask_dir = os.path.join(mask_result_path, 'pred_part_mask')

        # Destructive reset: any existing part-mask directory of this scene is
        # deleted and recreated empty.
        if os.path.exists(pred_part_mask_dir):
                        shutil.rmtree(pred_part_mask_dir)
        # Create the directory
        os.makedirs(pred_part_mask_dir)

        # Remove a stale part summary, since save_mask_results appends to it.
        summary_file = os.path.join(mask_result_path, f'{scene_id}_part_summary.txt')
        if os.path.exists(summary_file):
                # Handle the case when the summary file exists (if needed)
                os.remove(summary_file)

        for idx,mask in enumerate(mask_infos):
                ins_num = mask['prediction']
                mask_file = mask['file']
                # Instance directories are named by the instance's position in the summary file.
                instance_dir = os.path.join(output_scene_dir,str(idx))
                top_k_masks,pc_depth,screen_coords,obj_xyz = load_instance_info(f'{instance_dir}/ins_info')
                ins =reversed_dict[ins_num]
                prompt = cls_part_dict[ins]
                # file_paths, points_3d, visible_pts_list and text_input are assigned but not used below.
                file_paths = glob.glob(os.path.join(f'{instance_dir}/rendered_img', '*'))
                points_3d =[]
                visible_pts_list = []
                # The first dimension of pc_depth is taken as the number of rendered views.
                num_views = pc_depth.shape[0]
                text_input = prompt
                # load the segment result:
                result_dict = torch.load(os.path.join(f'{instance_dir}/ins_info', 'sem_seg.pt'))
                # Scene-level instance mask, stored as a text file and cast to bool.
                ins_mask =np.loadtxt(f'{mask_result_path}/{mask_file}').astype('bool')
                mask2d_view_list, mask_2d_bbox_correspondences, binary_masks_list = project_3d_to_2d(obj_xyz, top_k_masks, screen_coords, pc_depth)
                # NOTE(review): the meaning of the trailing 0 and 0.1 arguments and of N=2
                # is defined by the helpers and is not visible here — confirm.
                target_3d_masks = process_masks_and_calculate_iou(result_dict, num_views, binary_masks_list, 0,0.1)
                final_predictions = assign_labels_to_masks(result_dict, target_3d_masks, num_views, N=2)
                # NOTE(review): same file as scene_path above, re-read once per instance.
                scene_pcd = o3d.io.read_point_cloud(f'/home/wan/Datasets/Test_scene/part_valid/{scene_id}/points_{scene_id}.ply')




                part_mask_after_process = process_mask_results(final_predictions,top_k_masks)
                # Saving is currently disabled, so this cell only computes the masks:
                # save_mask_results(scene_id, part_mask_after_process, scene_pcd, ins_mask, ins, output_dir, part_label_v2)      
        # break  

  top_k_masks = torch.load(os.path.join(instance_info_dir, 'top_k_masks.pt'))
  pc_depth = torch.load(os.path.join(instance_info_dir, 'pc_depth.pt'))
  screen_coords = torch.load(os.path.join(instance_info_dir, 'screen_coords.pt'))
  result_dict = torch.load(os.path.join(f'{instance_dir}/ins_info', 'sem_seg.pt'))


0055
0197
Final predictions:
Final predictions:


  1%|          | 2/300 [00:06<15:42,  3.16s/it]

Final predictions:
0187
Final predictions:
Final predictions:


  1%|          | 2/300 [00:10<27:12,  5.48s/it]


KeyboardInterrupt: 

# Evaluation

In [2]:
import math
import os, sys, argparse
import inspect
from copy import deepcopy
from uuid import uuid4
import numpy as np
import torch
import glob

from scipy import stats
from utils.eval_util import *
import \
    utils.eval_util as util_3d
# import wandb
import numpy as np
from collections import defaultdict

def convert_new_dataset_to_gt_instances(gt_ids, gt_dict, CLASS_LABELS, VALID_CLASS_IDS, ID_TO_LABEL):
    """Group the ground-truth instances of one scan by standardized class label.

    Args:
        gt_ids: Per-point instance ids.
        gt_dict: Mapping from ``str(instance_id)`` to the dataset's label.
        CLASS_LABELS: Evaluated class labels; must provide ``.index``.
        VALID_CLASS_IDS: Class ids aligned with ``CLASS_LABELS``.
        ID_TO_LABEL: Accepted for interface compatibility; not used.

    Returns:
        A dict with one list per entry of ``CLASS_LABELS``. Each list holds
        instance dicts with the keys ``instance_id``, ``label_id``,
        ``vert_count``, ``med_dist`` and ``dist_conf``, in order of first
        appearance in ``gt_ids``. Instances whose standardized label is not
        in ``CLASS_LABELS`` are left out.

    Raises:
        KeyError: If an id in ``gt_ids`` has no entry in ``gt_dict``.
    """
    gt_instances = {class_label: [] for class_label in CLASS_LABELS}

    # Number of points per instance id, keyed in order of first appearance.
    vert_counts = {}
    for point_id in gt_ids:
        vert_counts[point_id] = vert_counts.get(point_id, 0) + 1

    for instance_id, vert_count in vert_counts.items():
        # Translate the dataset label through util_3d.label_mapping; labels
        # without a mapping entry are kept unchanged.
        raw_label = gt_dict[str(instance_id)]
        label = util_3d.label_mapping.get(raw_label, raw_label)

        try:
            label_id = VALID_CLASS_IDS[CLASS_LABELS.index(label)]
        except ValueError:
            # Label is not part of the evaluated classes.
            continue

        gt_instances[label].append({
            'instance_id': int(instance_id),
            'label_id': int(label_id),
            'vert_count': int(vert_count),
            'med_dist': -1,
            'dist_conf': 0.0
        })

    return gt_instances


def identify_void_areas(gt_ids, gt_dict, CLASS_LABELS, VALID_CLASS_IDS):
    """Flag the points that do not belong to an instance of an evaluated class.

    An instance id is valid when ``gt_dict`` lists it and its label, after
    translation through ``util_3d.label_mapping``, is in ``CLASS_LABELS``.
    Every other id, including ids missing from ``gt_dict``, is void.

    Args:
        gt_ids: Per-point instance ids.
        gt_dict: Mapping from instance id (convertible with ``int``) to label.
        CLASS_LABELS: Evaluated class labels.
        VALID_CLASS_IDS: Accepted for interface compatibility; not used.

    Returns:
        Boolean array with the shape of ``gt_ids``; ``True`` marks void points.
    """
    valid_labels = set(CLASS_LABELS)

    # Ids whose standardized label is one of the evaluated classes.
    valid_ids = [
        int(instance_id)
        for instance_id, label in gt_dict.items()
        if util_3d.label_mapping.get(label, label) in valid_labels
    ]

    # One vectorized membership test instead of a Python call per point via
    # np.vectorize, which also raised on an empty gt_ids.
    return ~np.isin(gt_ids, valid_ids)

# Example usage:

def get_args():
    '''Command line arguments.

    Returns the parsed namespace with ``result_save``, ``gt_path`` and
    ``dataset``. Unrecognized arguments are ignored rather than rejected:
    when this runs inside a notebook kernel, ``sys.argv`` carries the kernel
    launcher's own ``-f <connection file>`` option, which ``parse_args()``
    would turn into a ``SystemExit``.
    '''

    parser = argparse.ArgumentParser(description='OpenIns3D evaluation')
    parser.add_argument('--result_save', default="scannet_results", type=str, help='Path of detection results')
    parser.add_argument('--gt_path', default="data/processed/s3dis/instance_gt/Area_5", help='Path of gt instance')
    parser.add_argument('--dataset', default="part_scene", help='dataset for evaluation, could be s3dis, scannet, stpls3d')
    args, _unknown = parser.parse_known_args()
    return args

# Parse the arguments once at module level; `dataset` selects the dataset name.
args = get_args()
dataset = args.dataset






# Two-way lookup between class ids and class labels.
# NOTE(review): VALID_CLASS_IDS and CLASS_LABELS are not defined in this file;
# presumably they come from `from utils.eval_util import *` — confirm.
ID_TO_LABEL = {}
LABEL_TO_ID = {}
for i in range(len(VALID_CLASS_IDS)):
    LABEL_TO_ID[CLASS_LABELS[i]] = VALID_CLASS_IDS[i]
    ID_TO_LABEL[VALID_CLASS_IDS[i]] = CLASS_LABELS[i]
# ---------- Evaluation params ---------- #
opt = {}
# IoU thresholds: 0.1 up to (excluding) 0.95 in steps of 0.05, with the two
# extra values 0.24 and 0.324 appended at the end.
opt["overlaps"] = np.append(np.append(np.arange(0.1, 0.95, 0.05),0.24),0.324)

print(opt["overlaps"])
# minimum region size for evaluation [verts]
opt["min_region_sizes"] = np.array([0])  # 100 for s3dis, scannet
# distance thresholds [m]
opt["distance_threshes"] = np.array([float("inf")])
# distance confidences
opt["distance_confs"] = np.array([-float("inf")])


def evaluate_matches(matches):
    """Compute average precision per class and per IoU threshold.

    Args:
        matches: Mapping from a scan key to a dict with the entries "pred"
            and "gt". Both are dicts keyed by class label holding lists of
            instance dicts. Ground-truth instances carry "matched_pred",
            predictions carry "matched_gt", "uuid", "confidence",
            "vert_count" and "void_intersection".

    Returns:
        Float array of shape (1, len(CLASS_LABELS), len(opt["overlaps"])).
        An entry is NaN when the class has no ground truth in any scan and
        0.0 when it has ground truth but no prediction.

    Reads the module-level `opt` and `CLASS_LABELS`.
    """
    overlaps = opt["overlaps"]
    # Only the first entry of each of these option arrays is used.
    min_region_sizes = [opt["min_region_sizes"][0]]
    dist_threshes = [opt["distance_threshes"][0]]
    dist_confs = [opt["distance_confs"][0]]

    # results: class x overlap
    ap = np.zeros(
        (len(dist_threshes), len(CLASS_LABELS), len(overlaps)), float
    )
    for di, (min_region_size, distance_thresh, distance_conf) in enumerate(
        zip(min_region_sizes, dist_threshes, dist_confs)
    ):
        for oi, overlap_th in enumerate(overlaps):
            # Reset the "already matched" flag of every prediction for this threshold.
            pred_visited = {}
            for m in matches:
                # NOTE(review): this outer loop only repeats the inner loops once per
                # key of matches[m]["pred"]; its variable p is shadowed below.
                for p in matches[m]["pred"]:
                    for label_name in CLASS_LABELS:
                        for p in matches[m]["pred"][label_name]:
                            if "uuid" in p:
                                pred_visited[p["uuid"]] = False
            for li, label_name in enumerate(CLASS_LABELS):
                # Labels (1 = true positive, 0 = false positive) and scores
                # accumulated over all scans for this class.
                y_true = np.empty(0)
                y_score = np.empty(0)
                hard_false_negatives = 0
                has_gt = False
                has_pred = False
                for m in matches:
                    pred_instances = matches[m]["pred"][label_name]
                    gt_instances = matches[m]["gt"][label_name]
                    # filter groups in ground truth
                    gt_instances = [
                        gt
                        for gt in gt_instances
                        # if gt["instance_id"] >= 1000
                        if gt["vert_count"] >= min_region_size
                        and gt["med_dist"] <= distance_thresh
                        and gt["dist_conf"] >= distance_conf
                    ]
                    if gt_instances:
                        has_gt = True
                    if pred_instances:
                        has_pred = True

                    cur_true = np.ones(len(gt_instances))
                    cur_score = np.ones(len(gt_instances)) * (-float("inf"))
                    cur_match = np.zeros(len(gt_instances), dtype=bool)
                    # collect matches
                    for (gti, gt) in enumerate(gt_instances):
                        found_match = False
                        num_pred = len(gt["matched_pred"])
                        for pred in gt["matched_pred"]:
                            # greedy assignments
                            if pred_visited[pred["uuid"]]:
                                continue
                            # Intersection over union, in vertex counts.
                            overlap = float(pred["intersection"]) / (
                                gt["vert_count"]
                                + pred["vert_count"]
                                - pred["intersection"]
                            )
                            if overlap > overlap_th:
                                confidence = pred["confidence"]
                                # if already have a prediction for this gt,
                                # the prediction with the lower score is automatically a false positive
                                if cur_match[gti]:
                                    max_score = max(cur_score[gti], confidence)
                                    min_score = min(cur_score[gti], confidence)
                                    cur_score[gti] = max_score
                                    # append false positive
                                    cur_true = np.append(cur_true, 0)
                                    cur_score = np.append(cur_score, min_score)
                                    cur_match = np.append(cur_match, True)
                                # otherwise set score
                                else:
                                    found_match = True
                                    cur_match[gti] = True
                                    cur_score[gti] = confidence
                                    pred_visited[pred["uuid"]] = True
                        if not found_match:
                            hard_false_negatives += 1
                    # remove non-matched ground truth instances
                    cur_true = cur_true[cur_match == True]
                    cur_score = cur_score[cur_match == True]

                    # collect non-matched predictions as false positive
                    for pred in pred_instances:
                        found_gt = False
                        for gt in pred["matched_gt"]:
                            overlap = float(gt["intersection"]) / (
                                gt["vert_count"]
                                + pred["vert_count"]
                                - gt["intersection"]
                            )
                            if overlap > overlap_th:
                                found_gt = True
                                break
                        if not found_gt:
                            # Count how much of the prediction lies on void or
                            # ignored ground truth.
                            num_ignore = pred["void_intersection"]
                            for gt in pred["matched_gt"]:
                                # group?
                                if gt["instance_id"] < 1000:
                                    num_ignore += gt["intersection"]
                                # small ground truth instances
                                if (
                                    gt["vert_count"] < min_region_size
                                    or gt["med_dist"] > distance_thresh
                                    or gt["dist_conf"] < distance_conf
                                ):
                                    num_ignore += gt["intersection"]
                            proportion_ignore = (
                                float(num_ignore) / pred["vert_count"]
                            )
                            # if not ignored append false positive
                            if proportion_ignore <= overlap_th:
                                cur_true = np.append(cur_true, 0)
                                confidence = pred["confidence"]
                                cur_score = np.append(cur_score, confidence)

                    # append to overall results
                    y_true = np.append(y_true, cur_true)
                    y_score = np.append(y_score, cur_score)

                # compute average precision
                if has_gt and has_pred:
                    # compute precision recall curve first

                    # sorting and cumsum
                    score_arg_sort = np.argsort(y_score)
                    y_score_sorted = y_score[score_arg_sort]
                    y_true_sorted = y_true[score_arg_sort]
                    y_true_sorted_cumsum = np.cumsum(y_true_sorted)

                    # unique thresholds
                    (thresholds, unique_indices) = np.unique(
                        y_score_sorted, return_index=True
                    )
                    num_prec_recall = len(unique_indices) + 1

                    # prepare precision recall
                    num_examples = len(y_score_sorted)
                    # https://github.com/ScanNet/ScanNet/pull/26
                    # all predictions are non-matched but also all of them are ignored and not counted as FP
                    # y_true_sorted_cumsum is empty
                    # num_true_examples = y_true_sorted_cumsum[-1]
                    num_true_examples = (
                        y_true_sorted_cumsum[-1]
                        if len(y_true_sorted_cumsum) > 0
                        else 0
                    )
                    precision = np.zeros(num_prec_recall)
                    recall = np.zeros(num_prec_recall)

                    # deal with the first point
                    # (the appended 0 is what index -1 reads below when idx_scores is 0)
                    y_true_sorted_cumsum = np.append(y_true_sorted_cumsum, 0)
                    # deal with remaining
                    for idx_res, idx_scores in enumerate(unique_indices):
                        cumsum = y_true_sorted_cumsum[idx_scores - 1]
                        tp = num_true_examples - cumsum
                        fp = num_examples - idx_scores - tp
                        fn = cumsum + hard_false_negatives
                        p = float(tp) / (tp + fp)
                        r = float(tp) / (tp + fn)
                        precision[idx_res] = p
                        recall[idx_res] = r

                    # first point in curve is artificial
                    precision[-1] = 1.0
                    recall[-1] = 0.0

                    # compute average of precision-recall curve
                    recall_for_conv = np.copy(recall)
                    recall_for_conv = np.append(
                        recall_for_conv[0], recall_for_conv
                    )
                    recall_for_conv = np.append(recall_for_conv, 0.0)

                    stepWidths = np.convolve(
                        recall_for_conv, [-0.5, 0, 0.5], "valid"
                    )
                    # integrate is now simply a dot product
                    ap_current = np.dot(precision, stepWidths)

                elif has_gt:
                    ap_current = 0.0
                else:
                    ap_current = float("nan")
                ap[di, li, oi] = ap_current
    return ap




def make_pred_info(pred: dict):
    """Split a batched prediction dict into one record per predicted instance.

    Args:
        pred: Dict with "pred_classes" and "pred_scores" (one entry per
            instance) and "pred_masks" (one column per instance).

    Returns:
        Dict mapping a fresh ``uuid4()`` key to a record with the entries
        "label_id", "conf" and "mask" (the instance's column of
        "pred_masks"), in instance order.
    """
    classes = pred["pred_classes"]
    scores = pred["pred_scores"]
    masks = pred["pred_masks"]
    # All three inputs must describe the same number of instances.
    assert classes.shape[0] == scores.shape[0] == masks.shape[1]

    # The random key lets later stages identify each prediction uniquely.
    return {
        uuid4(): {
            "label_id": classes[col],
            "conf": scores[col],
            "mask": masks[:, col],
        }
        for col in range(len(classes))
    }


def assign_instances_for_scan(pred: dict, gt_file: str, gt_dict: dict):
    """Associate one scan's predicted masks with its ground-truth instances.

    Args:
        pred: batched predictions for the scan, in the layout accepted by
            ``make_pred_info``.
        gt_file: path of the ground-truth id file, read with
            ``util_3d.load_ids``.
        gt_dict: label dictionary of the scan; forwarded to
            ``convert_new_dataset_to_gt_instances`` and
            ``identify_void_areas``.

    Returns:
        ``(gt2pred, pred2gt)``.  ``gt2pred`` is a deep copy of the
        ground-truth instances in which every instance carries a
        ``"matched_pred"`` list.  ``pred2gt`` maps each label in
        ``CLASS_LABELS`` to the kept predictions of that label, each with a
        ``"matched_gt"`` list.  Every match records the overlapping point
        count under ``"intersection"``.
    """
    pred_info = make_pred_info(pred)
    try:
        gt_ids = util_3d.load_ids(gt_file)
    except Exception as e:
        util_3d.print_error("unable to load " + gt_file + ": " + str(e))

    # Ground-truth side: work on a private copy and give every instance an
    # empty list that will collect its overlapping predictions.
    gt2pred = deepcopy(
        convert_new_dataset_to_gt_instances(
            gt_ids, gt_dict, CLASS_LABELS, VALID_CLASS_IDS, ID_TO_LABEL
        )
    )
    for gt_list in gt2pred.values():
        for gt_entry in gt_list:
            gt_entry["matched_pred"] = []

    # Prediction side: one bucket per known class label.
    pred2gt = {label: [] for label in CLASS_LABELS}

    # Mask of void labels in the ground truth.
    bool_void = identify_void_areas(gt_ids, gt_dict, CLASS_LABELS, VALID_CLASS_IDS)

    kept = 0  # running id handed to each prediction that survives filtering
    for uuid, record in pred_info.items():
        label_id = int(record["label_id"])
        if label_id not in ID_TO_LABEL:
            continue
        label_name = ID_TO_LABEL[label_id]

        raw_mask = record["mask"]
        assert len(raw_mask) == len(gt_ids)
        # Any non-zero entry counts as "point belongs to the instance".
        pred_mask = np.not_equal(raw_mask, 0)
        num = np.count_nonzero(pred_mask)
        if num < opt["min_region_sizes"][0]:
            continue  # skip if empty

        pred_instance = {
            "uuid": uuid,
            "pred_id": kept,
            "label_id": label_id,
            "vert_count": num,
            "confidence": record["conf"],
            "void_intersection": np.count_nonzero(
                np.logical_and(bool_void, pred_mask)
            ),
        }

        # Compare against every ground-truth instance of the same label.
        matched_gt = []
        for gt_inst in gt2pred[label_name]:
            overlap = np.count_nonzero(
                np.logical_and(gt_ids == gt_inst["instance_id"], pred_mask)
            )
            if overlap <= 0:
                continue
            gt_copy = gt_inst.copy()
            gt_copy["intersection"] = overlap
            matched_gt.append(gt_copy)
            # The stored prediction copy is taken before "matched_gt" is
            # attached, so it does not carry that key.
            gt_inst["matched_pred"].append(
                {**pred_instance, "intersection": overlap}
            )

        pred_instance["matched_gt"] = matched_gt
        pred2gt[label_name].append(pred_instance)
        kept += 1

    return gt2pred, pred2gt


def print_results(avgs):
    """Print the averaged AP / AP_50% / AP_25% summary table.

    ``avgs`` must provide ``avgs["classes"][label]`` with the keys ``"ap"``,
    ``"ap50%"`` and ``"ap25%"`` for every label in ``CLASS_LABELS``, plus the
    overall ``"all_ap"``, ``"all_ap_50%"`` and ``"all_ap_25%"`` values.  Only
    the header and the overall "average" row are written to stdout.
    """
    sep = ""
    col1 = ":"
    lineLen = 64
    hash_rule = "#" * lineLen

    print("")
    print(hash_rule)
    header_cells = ["{:<15}".format("what") + sep + col1]
    header_cells.extend(
        "{:>15}".format(title) + sep for title in ("AP", "AP_50%", "AP_25%")
    )
    print("".join(header_cells))
    print(hash_rule)

    # Per-class rows are not printed; the values are still looked up so that
    # an incomplete ``avgs`` fails here with a KeyError.
    for label_name in CLASS_LABELS:
        class_scores = avgs["classes"][label_name]
        _ap, _ap50, _ap25 = (
            class_scores[key] for key in ("ap", "ap50%", "ap25%")
        )

    overall = [avgs[key] for key in ("all_ap", "all_ap_50%", "all_ap_25%")]
    print("-" * lineLen)
    summary_cells = ["{:<15}".format("average") + sep + col1]
    summary_cells.extend("{:>15.3f}".format(value) + sep for value in overall)
    print("".join(summary_cells))
    print("")




def evaluate(
    preds: dict, gt_path: str, gt_label_path: str, output_file: str, dataset: str = "scannet"
):
    """Match predictions against ground truth for every scan and print AP.

    Args:
        preds: maps a scan key ``k`` to that scan's prediction dict (the
            layout consumed by ``assign_instances_for_scan``).
        gt_path: directory holding the ``gt_mask_{k}.txt`` files.
        gt_label_path: directory holding the ``id2part_r{k}.json`` files.
        output_file: accepted for interface compatibility; not used here.
        dataset: when ``"s3dis"``, extra semantic-accuracy and coverage
            statistics are accumulated per scan.  They stay local to this
            function and are not part of the printed results.

    Returns:
        None.  The averaged AP scores are printed via ``print_results``.
    """
    global CLASS_LABELS
    global VALID_CLASS_IDS
    global ID_TO_LABEL
    global LABEL_TO_ID
    global opt

    total_true = 0
    total_seen = 0
    NUM_CLASSES = len(VALID_CLASS_IDS)

    true_positive_classes = np.zeros(NUM_CLASSES)
    positive_classes = np.zeros(NUM_CLASSES)
    gt_classes = np.zeros(NUM_CLASSES)

    # mucov and mwcov
    all_mean_cov = [[] for _ in range(NUM_CLASSES)]
    all_mean_weighted_cov = [[] for _ in range(NUM_CLASSES)]

    print("evaluating", len(preds), "scans...")
    matches = {}
    for i, (k, v) in enumerate(preds.items()):
        gt_file = os.path.join(gt_path, f"gt_mask_{k}.txt")
        label_dict_path = os.path.join(gt_label_path, f'id2part_r{k}.json')
        gt_label = util_3d.load_json(label_dict_path)
        print(gt_file)
        if not os.path.isfile(gt_file):
            util_3d.print_error(
                "Scan {} does not match any gt file".format(k), user_fault=True
            )

        if dataset == "s3dis":
            gt_ids = util_3d.load_ids(gt_file)
            # Ids are decoded as semantic * 1000 + instance.
            gt_sem = (gt_ids // 1000) - 1
            gt_ins = gt_ids - (gt_ids // 1000) * 1000

            # Builtin ``int`` replaces the ``np.int`` alias, which no longer
            # exists in current NumPy releases; the resulting dtype is the
            # same.
            pred_sem = np.zeros(v["pred_masks"].shape[0], dtype=int)
            pred_ins = np.zeros(v["pred_masks"].shape[0], dtype=int)

            # Walk the instances back to front so that, where masks overlap,
            # the instance with the lowest index is written last and wins.
            for inst_id in reversed(range(v["pred_masks"].shape[1])):
                point_ids = np.argwhere(v["pred_masks"][:, inst_id] == 1.0)[
                    :, 0
                ]
                pred_ins[point_ids] = inst_id + 1
                pred_sem[point_ids] = v["pred_classes"][inst_id] - 1

            # semantic acc
            total_true += np.sum(pred_sem == gt_sem)
            total_seen += pred_sem.shape[0]

            # Per-class point counts for semantic mIoU, accumulated with
            # np.unique instead of a per-point Python loop.
            uniq, counts = np.unique(pred_sem, return_counts=True)
            positive_classes[uniq] += counts

            uniq, counts = np.unique(gt_sem, return_counts=True)
            gt_classes[uniq] += counts

            uniq, counts = np.unique(
                gt_sem[pred_sem == gt_sem], return_counts=True
            )
            true_positive_classes[uniq] += counts

            # Group predicted instance masks by their dominant semantic class.
            pts_in_pred = [[] for _ in range(NUM_CLASSES)]
            for g in np.unique(pred_ins):  # each object in prediction
                if g == -1:
                    continue
                tmp = pred_ins == g
                sem_seg_i = int(stats.mode(pred_sem[tmp])[0])
                pts_in_pred[sem_seg_i] += [tmp]

            # Same grouping for the ground-truth instances.
            pts_in_gt = [[] for _ in range(NUM_CLASSES)]
            for g in np.unique(gt_ins):
                tmp = gt_ins == g
                sem_seg_i = int(stats.mode(gt_sem[tmp])[0])
                pts_in_gt[sem_seg_i] += [tmp]

            # instance mucov & mwcov
            for i_sem in range(NUM_CLASSES):
                sum_cov = 0
                mean_cov = 0
                mean_weighted_cov = 0
                num_gt_point = 0
                for ins_gt in pts_in_gt[i_sem]:
                    # Best IoU any same-class prediction reaches on this
                    # ground-truth instance.
                    ovmax = 0.0
                    num_ins_gt_point = np.sum(ins_gt)
                    num_gt_point += num_ins_gt_point
                    for ins_pred in pts_in_pred[i_sem]:
                        union = ins_pred | ins_gt
                        intersect = ins_pred & ins_gt
                        iou = float(np.sum(intersect)) / np.sum(union)

                        if iou > ovmax:
                            ovmax = iou

                    sum_cov += ovmax
                    mean_weighted_cov += ovmax * num_ins_gt_point

                if len(pts_in_gt[i_sem]) != 0:
                    mean_cov = sum_cov / len(pts_in_gt[i_sem])
                    all_mean_cov[i_sem].append(mean_cov)

                    mean_weighted_cov /= num_gt_point
                    all_mean_weighted_cov[i_sem].append(mean_weighted_cov)

        matches_key = os.path.abspath(gt_file)
        # assign gt to predictions
        gt2pred, pred2gt = assign_instances_for_scan(v, gt_file, gt_label)
        matches[matches_key] = {}
        matches[matches_key]["gt"] = gt2pred
        matches[matches_key]["pred"] = pred2gt
        sys.stdout.write("\rscans processed: {}".format(i + 1))
        sys.stdout.flush()
    print("")
    ap_scores = evaluate_matches(matches)
    avgs = compute_averages(ap_scores)

    # print
    print_results(avgs)



def main():
    """Load every scene's predicted masks from disk and run the evaluation.

    For each entry ``<scene>`` under ``pred_dir`` the summary file
    ``<scene>/<scene>_part_summary.txt`` is read.  Each of its lines holds
    three values: mask file path (relative to the scene directory),
    predicted class id and confidence score.

    Raises:
        ValueError: if a summary file does not have exactly three columns.
    """
    pred_dir = 'part_scene_results'

    gt_path = '/home/wan/Datasets/Test_scene/part_valid_gt'
    gt_label_path = '/home/wan/Datasets/Test_scene/id2part_valid_gt'
    finished_scene_path = glob.glob(pred_dir+"/*")
    # basename instead of split("/") so the scene name is also extracted
    # correctly when glob returns paths with a different separator.
    finished_scene = [os.path.basename(scene) for scene in finished_scene_path]

    preds = {}
    # e.g. part_scene_results/0001/0001_part_summary.txt
    for scene_name in finished_scene:
        print(scene_name)
        file_path = os.path.join(pred_dir, scene_name, scene_name + '_part_summary.txt')  # {SCENE_ID}_part_summary.txt file
        # ndmin=2 keeps a one-line file two-dimensional, so the column count
        # can be checked before the rows are unpacked.
        scene_pred_mask_list = np.loadtxt(file_path, dtype=str, ndmin=2)
        if scene_pred_mask_list.size and scene_pred_mask_list.shape[1] != 3:
            raise ValueError(
                f'{scene_name} Each line should have 3 values: instance mask path, predicted class and confidence score.'
            )
        scene_pred_mask_list = scene_pred_mask_list.reshape(-1, 3)  # (num_masks, 3)

        pred_masks = []
        pred_scores = []
        pred_class = []

        for mask_path, prediction, conf_score in scene_pred_mask_list:
            # Read mask, class and confidence score for each instance mask
            pred_mask = np.loadtxt(os.path.join(pred_dir, scene_name, mask_path), dtype=int)  # Values: 0 for the background, 1 for the instance
            pred_masks.append(pred_mask)
            pred_scores.append(float(conf_score))
            pred_class.append(int(prediction))

        # Aggregate masks, scores and classes for the scene.
        if len(pred_masks) > 0:
            preds[scene_name] = {
                'pred_masks': torch.from_numpy(np.vstack(pred_masks).T),
                'pred_scores': torch.from_numpy(np.vstack(pred_scores)).squeeze(0),
                'pred_classes': torch.from_numpy(np.vstack(pred_class)).squeeze(0),
            }
        else:
            # Placeholder entry for a scene without any predicted mask.
            preds[scene_name] = {
                'pred_masks': np.zeros((1, 1)),
                'pred_scores': np.zeros(1),
                'pred_classes': np.ones(1, dtype=np.int64)*255,
            }

    # NOTE(review): `dataset` is not defined in this function; it is expected
    # to exist as a module-level name elsewhere in the notebook - confirm.
    evaluate(preds, gt_path, gt_label_path, f"./{dataset}_final_result.csv")

# Script entry point: run the evaluation only when executed directly.
if __name__ == "__main__":
    main()


[0.1   0.15  0.2   0.25  0.3   0.35  0.4   0.45  0.5   0.55  0.6   0.65
 0.7   0.75  0.8   0.85  0.9   0.24  0.324]
0055
0197
0187
0247
0036
0296
0292
0181
0224
0157
0294
0223
0268
0210
0171
0010
0005
0291
0066
0297
0076
0059
0192
0186
0168
0215
0255
0245
0206
0284
0065
0204
0019
0071
0048
0250
0111
0242
0044
0193
0182
0078
0056
0239
0094
0295
0042
0184
0229
0238
0063
0051
0228
0275
0041
0022
0199
0271
0152
0113
0014
0054
0052
0064
0053
0082
0240
0176
0097
0069
0191
0159
0232
0150
0104
0083
0246
0025
0230
0248
0024
0143
0149
0231
0220
0103
0233
0283
0267
0288
0045
0100
0004
0135
0209
0081
0251
0273
0098
0194
0256
0074
0105
0144
0153
0075
0260
0033
0161
0236
0287
0026
0203
0034
0158
0013
0058
0132
0120
0166
0259
0263
0090
0016
0087
0145
0020
0290
0235
0070
0002
0179
0108
0241
0009
0091
0222
0102
0021
0109
0155
0293
0285
0095
0079
0119
0133
0130
0141
0234
0188
0101
0123
0112
0272
0080
0001
0127
0279
0015
0226
0061
0136
0254
0202
0110
0189
0086
0167
0046
0174
0011
0237
0289
0128
0040
0115

  pred_mask = np.not_equal(pred_mask, 0)
  np.logical_and(bool_void, pred_mask)
  np.logical_and(gt_ids == gt_inst["instance_id"], pred_mask)


/home/wan/Datasets/Test_scene/part_valid_gt/gt_mask_0291.txt
scans processed: 18/home/wan/Datasets/Test_scene/part_valid_gt/gt_mask_0066.txt
scans processed: 19/home/wan/Datasets/Test_scene/part_valid_gt/gt_mask_0297.txt
scans processed: 20/home/wan/Datasets/Test_scene/part_valid_gt/gt_mask_0076.txt
scans processed: 21/home/wan/Datasets/Test_scene/part_valid_gt/gt_mask_0059.txt
scans processed: 22/home/wan/Datasets/Test_scene/part_valid_gt/gt_mask_0192.txt
scans processed: 23/home/wan/Datasets/Test_scene/part_valid_gt/gt_mask_0186.txt
scans processed: 24/home/wan/Datasets/Test_scene/part_valid_gt/gt_mask_0168.txt
scans processed: 25/home/wan/Datasets/Test_scene/part_valid_gt/gt_mask_0215.txt
scans processed: 26/home/wan/Datasets/Test_scene/part_valid_gt/gt_mask_0255.txt
scans processed: 27/home/wan/Datasets/Test_scene/part_valid_gt/gt_mask_0245.txt
scans processed: 28/home/wan/Datasets/Test_scene/part_valid_gt/gt_mask_0206.txt
scans processed: 29/home/wan/Datasets/Test_scene/part_valid

KeyboardInterrupt: 