In [None]:
from chat_with_nerf.chat.agent import Agent 
import os
import json
import numpy as np
from chat_with_nerf.chat.session import Session
import time
import open3d as o3d
from pathlib import Path
from chat_with_nerf.chat.system_prompt import EDITED_SYSTEM_PROMPT, NO_VISUAL_FEEDBACK_SYSTEM_PROMPT
from chat_with_nerf.settings import Settings
from joblib import Parallel, delayed
from evaluation_vis_util import draw_plotly, create_bbox
from tqdm import tqdm
import uuid
import torch
from collections import OrderedDict
from utils import box3d_iou, process_json, process_all_json_files, is_label_unique, convert_origin_bbox, get_transformation_matrix, construct_bbox_corners, get_box3d_min_max

In [None]:
# Credentials: the evaluation loop reads os.environ['API_URL'] and
# os.environ['OPENAI_API_KEY'] (GPT-3.5 or GPT-4 endpoint). Export them in the
# shell that launches Jupyter instead of hardcoding them in this notebook.

# Root of the directory tree that get_val_set() walks for per-scene JSON files.
# Use '.' for the current directory: os.walk('') yields nothing, so an empty
# string would silently evaluate zero scenes. Adjust the path if needed.
root_directory = '.'

In [None]:
def get_val_set(root_directory):
    """Map each directory under ``root_directory`` to one JSON file it contains.

    The tree is walked with ``os.walk``. Every directory that directly contains
    at least one ``*.json`` file contributes one entry:

    - key: the directory's basename;
    - value: the path to the alphabetically first ``*.json`` file in it.

    Non-JSON files are ignored, even if they happen to be listed first.
    Directories without JSON files are skipped. If two directories share a
    basename, the one visited later by os.walk wins.

    Args:
        root_directory: Path of the tree to search.

    Returns:
        dict[str, str]: directory basename -> JSON file path.
    """
    json_dict = {}
    for dirpath, _, filenames in os.walk(root_directory):
        # Only consider JSON files. Sort them because os.walk's filename order
        # depends on the filesystem.
        json_files = sorted(fn for fn in filenames if fn.endswith('.json'))
        if json_files:
            json_dict[os.path.basename(dirpath)] = os.path.join(dirpath, json_files[0])
    return json_dict

# Directory basename (used below as the scene name) -> path of one JSON file in
# that directory; the evaluation loop reads each file's 'objects' list.
json_dict = get_val_set(root_directory)

In [None]:
# Running tallies filled in by the evaluation and aggregation cells below.
# Top-1 accuracy at IoU > 0.25 / > 0.5, overall and split by unique vs. repeated label.
acc_25 = acc_50 = 0
acc_25_unique = acc_50_unique = 0
acc_25_multiple = acc_50_multiple = 0

# Per-result IoU values and session ids, appended in processing order.
# (Two separate list objects -- do not chain-assign mutable values.)
list_iou = []
session_id_list = []

# Number of descriptions queried, overall and by label uniqueness.
total_object = total_unique_object = total_multiple_object = 0

# IoU > 0.25 hit counts when looking at the top-2 / top-3 / top-5 scored
# candidates or at every candidate, overall and by label uniqueness.
acc_25_top2_hit = acc_25_top2_hit_unique = acc_25_top2_hit_multiple = 0
acc_25_top3_hit = acc_25_top3_hit_unique = acc_25_top3_hit_multiple = 0
acc_25_top5_hit = acc_25_top5_hit_unique = acc_25_top5_hit_multiple = 0
acc_25_all_hit = acc_25_all_hit_unique = acc_25_all_hit_multiple = 0

# Column-oriented store for per-query outputs; every key gets its own empty list.
result_dict = {
    column: []
    for column in (
        'scene_name',
        'description',
        'centroid_list',
        'extent_list',
        'similarity_mean_list_list',
        'ground truth',
    )
}

In [None]:
def _candidate_iou(gt_corners, candidate):
    """Return the 3D IoU between ground-truth corners and one candidate box.

    ``candidate`` is an entry of ``session.candidate``. Its 'centroid' and
    'extent' values are turned into corners with construct_bbox_corners
    before calling box3d_iou.
    """
    pred_corners = construct_bbox_corners(candidate['centroid'], candidate['extent'])
    return box3d_iou(gt_corners, pred_corners)


def process_description(agent, scene_name, description, corners_original, is_unique, object_id, query_rank_id, center_original, extents_original,
                        session_output_dir="/workspace/chat-with-nerf-dev/chat-with-nerf/session_output_lerf_14_with_gpt_4_final",
                        max_retries=3):
    """Run one grounding query through ``agent`` and score it against the ground truth.

    Each attempt does the following:

    1. Creates a fresh Session tagged
       ``{scene_name}-{object_id}-{query_rank_id}-{uuid4}``.
    2. Drains ``agent.act``.
    3. Scores the session from the generator's final yield.
    4. Saves ``new_session`` under ``session_output_dir``.

    Any exception in an attempt is printed and the attempt is retried, up to
    ``max_retries`` attempts in total.

    Args:
        agent: Agent used to answer the query (shared across worker threads by the caller).
        scene_name: Scene identifier passed to Session.create_for_scene and agent.act.
        description: Natural-language grounding query.
        corners_original: Ground-truth box corners (from construct_bbox_corners).
        is_unique: Whether the target's label is unique in the scene; returned unchanged.
        object_id: Identifier embedded in the session id (the caller passes the label).
        query_rank_id: Index of this description for the object; embedded in the session id.
        center_original, extents_original: Ground-truth box, stored on the session.
        session_output_dir: Directory passed to ``Session.save``.
        max_retries: Number of attempts before giving up.

    Returns:
        Tuple ``(iou3d_top_1, top2_hit, top3_hit, top5_hit, all_hit, session_id, session, is_unique)``:

        - ``iou3d_top_1``: IoU of ``session.chosen_candidate_id``.
        - ``topK_hit``: best IoU among the K highest-scored entries of
          ``session.top_5_objects2scores``.
        - ``all_hit``: best IoU over every entry of ``session.candidate``.

        Returns ``None`` if every attempt raised.
    """
    for retry in range(max_retries):
        try:
            gt_corners = np.array(corners_original)
            print(f"Created new session for scene {scene_name}........................................................")
            new_session = Session.create_for_scene(scene_name)
            unique_id = str(uuid.uuid4())
            new_session.session_id = f"{scene_name}-{object_id}-{query_rank_id}-{unique_id}"
            new_session.working_scene_name = scene_name
            new_session.grounding_query = description
            new_session.ground_truth = [center_original, extents_original]
            print(description)
            # NOTE(review): 0.9 and 1 are presumably temperature / top_p -- confirm against Agent.act.
            generator = agent.act(
                NO_VISUAL_FEEDBACK_SYSTEM_PROMPT,
                description,
                0.9,
                1,
                scene_name,
                new_session
            )
            # Drain the generator; only the session from its final yield is used.
            # Reset per attempt so a previous failed attempt's session is never reused.
            session = None
            for _chat_history, _chat_counter, _status_code, session, _mesh_path in generator:
                pass
            if session is None:
                raise RuntimeError("agent.act yielded no session")

            iou3d_top_1 = _candidate_iou(gt_corners, session.candidate[str(session.chosen_candidate_id)])

            # Rank candidate ids by score, highest first (stable for ties), then
            # take the best IoU within the first 2 / 3 / 5 of them.
            scores = session.top_5_objects2scores
            ranked_ids = sorted(scores, key=scores.get, reverse=True)
            ranked_ious = [_candidate_iou(gt_corners, session.candidate[str(cid)]) for cid in ranked_ids]
            # max() of an empty list raises, which triggers a retry.
            top2_hit = max(ranked_ious[:2])
            top3_hit = max(ranked_ious[:3])
            top5_hit = max(ranked_ious[:5])

            # Upper bound: best IoU over every candidate the agent considered.
            all_hit = max(_candidate_iou(gt_corners, cand) for cand in session.candidate.values())

            print("new session being saved........................................................")
            new_session.save(session_output_dir)
            return iou3d_top_1, top2_hit, top3_hit, top5_hit, all_hit, session.session_id, session, is_unique
        except Exception as exp:
            if retry < max_retries - 1:  # If it's not the last retry
                print(f"Attempt {retry + 1} failed. Retrying... Error is {exp}")
            else:
                print(exp)
    # Every attempt failed; the aggregation cell skips None results.
    return None

In [None]:
# One entry per evaluated scene: the list of process_description results for that scene.
results_list = list()

In [None]:
# Scenes evaluated in this run; every other entry of json_dict is skipped.
SCENES_TO_EVALUATE = ('scene0169_00', 'scene0412_00')
# Number of worker threads issuing grounding queries concurrently.
N_JOBS = 15

# Reset everything this cell accumulates, so a re-run starts clean instead of
# appending duplicate results and double-counting totals.
results_list = []
total_object = 0
total_unique_object = 0
total_multiple_object = 0

for i, (scene_name, json_path) in enumerate(json_dict.items()):
    if scene_name not in SCENES_TO_EVALUATE:
        continue
    print(f"Processing scene {i+1}/{len(json_dict)}: {scene_name}")
    with open(json_path, 'r') as file:
        scene_data = json.load(file)
    furnitures = scene_data['objects']
    agent = Agent(scene_name=scene_name)
    agent.API_URL = os.environ['API_URL']
    agent.OPENAI_API_KEY = os.environ['OPENAI_API_KEY']

    # Build one argument tuple per (object, description) pair for process_description.
    preprocessed_data = []
    for furniture in furnitures:
        bbox = furniture['bbox']
        center_original, extents_original = bbox[:3], bbox[3:6]
        corners_original = construct_bbox_corners(center_original, extents_original)
        label = furniture['label']
        is_unique = is_label_unique(furnitures, label)

        if is_unique:
            total_unique_object += len(furniture['description'])
        else:
            total_multiple_object += len(furniture['description'])

        for idx, description in enumerate(furniture['description']):
            total_object += 1
            preprocessed_data.append((agent, scene_name, description, corners_original, is_unique, label, idx, center_original, extents_original))

    results = Parallel(n_jobs=N_JOBS, backend='threading')(
        delayed(process_description)(*args)
        for args in tqdm(preprocessed_data, desc="Processing descriptions")
    )
    results_list.append(results)
    # Every tuple in preprocessed_data references the agent, so it must be dropped
    # too; otherwise `del agent` alone frees nothing before empty_cache().
    del agent, preprocessed_data
    torch.cuda.empty_cache()
    
    

In [None]:
def _above(iou, threshold):
    """Return True when ``iou`` is truthy (not None / 0) and strictly above ``threshold``."""
    return bool(iou) and iou > threshold


# Reset every counter and list this cell fills, so re-running it does not double-count.
acc_25 = acc_50 = 0
acc_25_unique = acc_50_unique = 0
acc_25_multiple = acc_50_multiple = 0
acc_25_top2_hit = acc_25_top2_hit_unique = acc_25_top2_hit_multiple = 0
acc_25_top3_hit = acc_25_top3_hit_unique = acc_25_top3_hit_multiple = 0
acc_25_top5_hit = acc_25_top5_hit_unique = acc_25_top5_hit_multiple = 0
acc_25_all_hit = acc_25_all_hit_unique = acc_25_all_hit_multiple = 0
list_iou = []
session_id_list = []

for results in results_list:
    for result in results:
        # process_description returns None once all of its retries have failed.
        if result is None:
            continue

        iou3d, top2_hit, top3_hit, top5_hit, all_hit, session_id, session, is_unique = result
        # A missing/zero IoU is recorded as 0 in list_iou; session_id is kept as-is.
        if not iou3d:
            iou3d = 0

        # Top-1 accuracy at two IoU thresholds.
        if _above(iou3d, 0.25):
            acc_25 += 1
            if is_unique:
                acc_25_unique += 1
            else:
                acc_25_multiple += 1
        if _above(iou3d, 0.5):
            acc_50 += 1
            if is_unique:
                acc_50_unique += 1
            else:
                acc_50_multiple += 1

        # Top-k / all-candidate hits at IoU > 0.25.
        if _above(top2_hit, 0.25):
            acc_25_top2_hit += 1
            if is_unique:
                acc_25_top2_hit_unique += 1
            else:
                acc_25_top2_hit_multiple += 1
        if _above(top3_hit, 0.25):
            acc_25_top3_hit += 1
            if is_unique:
                acc_25_top3_hit_unique += 1
            else:
                acc_25_top3_hit_multiple += 1
        if _above(top5_hit, 0.25):
            acc_25_top5_hit += 1
            if is_unique:
                acc_25_top5_hit_unique += 1
            else:
                acc_25_top5_hit_multiple += 1
        if _above(all_hit, 0.25):
            acc_25_all_hit += 1
            if is_unique:
                acc_25_all_hit_unique += 1
            else:
                acc_25_all_hit_multiple += 1

        list_iou.append(iou3d)
        session_id_list.append(session_id)

In [None]:
# Print every metric as "<label> <value>", in a fixed order with fixed labels.
_report_rows = [
    ("acc_25 =", acc_25),
    ("acc_50 =", acc_50),
    ("acc_25_unique =", acc_25_unique),
    ("acc_50_unique =", acc_50_unique),
    ("acc_25_multiple =", acc_25_multiple),
    ("acc_50_multiple =", acc_50_multiple),
    ("acc_25_top2_hit =", acc_25_top2_hit),
    ("acc_25_top3_hit =", acc_25_top3_hit),
    ("acc_25_top5_hit =", acc_25_top5_hit),
    ("acc_25_all_hit =", acc_25_all_hit),
    ('total_object: ', total_object),
    ('total_unique_object: ', total_unique_object),
    ('total_multiple_object: ', total_multiple_object),
    ("acc_25_top2_hit_unique: ", acc_25_top2_hit_unique),
    ("acc_25_top2_hit_multiple: ", acc_25_top2_hit_multiple),
    ("acc_25_top3_hit_unique: ", acc_25_top3_hit_unique),
    ("acc_25_top3_hit_multiple: ", acc_25_top3_hit_multiple),
    ("acc_25_top5_hit_unique: ", acc_25_top5_hit_unique),
    ("acc_25_top5_hit_multiple: ", acc_25_top5_hit_multiple),
    ("acc_25_all_hit_unique: ", acc_25_all_hit_unique),
    ("acc_25_all_hit_multiple: ", acc_25_all_hit_multiple),
    ("list_iou =", list_iou),
    ("total_object =", total_object),
    ("session_id_list =", session_id_list),
]
for _label, _value in _report_rows:
    print(_label, _value)