# In [None]:
import numpy as np
import torch
import torchvision
from PIL import Image
import imageio
# import litellm

# Grounding DINO
# import GroundingDINO.groundingdino.datasets.transforms as T
# from GroundingDINO.groundingdino.models import build_model
# from GroundingDINO.groundingdino.util.slconfig import SLConfig
# from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap

# # segment anything
# from segment_anything import (
#     build_sam,
#     build_sam_hq,
#     SamPredictor
# ) 
import cv2
import matplotlib.pyplot as plt

# Recognize Anything Model & Tag2Text
# from ram.models import ram
# from ram import inference_ram
# import torchvision.transforms as TS

from functools import partial
from PCR_CG.lib.timer import Timer
from PCR_CG.lib.utils import load_obj, natural_key
from PCR_CG.datasets.indoor import IndoorDataset
from PCR_CG.datasets.modelnet import get_train_datasets, get_test_datasets
import os,re,sys,json,yaml,random, argparse, torch, pickle
from easydict import EasyDict as edict
from PCR_CG.configs.models import architectures
from PCR_CG.models.architectures import KPFCNN
import open3d as o3d
# import open3d.visualization.jupyter as o3d_jupyter
import plotly.graph_objects as go
import torch.nn.functional as F

def load_config(path):
    """
    Load a YAML config file and flatten its top-level sections into one dict.

    Every top-level value that is a mapping is merged key-by-key into the
    result (a key appearing in a later section overwrites an earlier one).
    Top-level entries whose value is not a mapping are copied through under
    their own key instead of crashing.

    Args:
        path (str): path to the config file

    Returns:
        config (dict): flattened configuration parameters; an empty file
        yields an empty dict.
    """
    with open(path, 'r', encoding='utf-8') as f:
        # safe_load returns None for an empty document.
        cfg = yaml.safe_load(f) or {}

    config = dict()
    for key, value in cfg.items():
        if isinstance(value, dict):
            config.update(value)
        else:
            config[key] = value

    return config




def get_datasets(config):
    """
    Build the training split of ``IndoorDataset`` with data augmentation on.

    Args:
        config: attribute-access config; ``config.train_info`` is passed to
            ``load_obj`` (presumably a path to the split's info file — confirm).

    Returns:
        IndoorDataset: the training dataset.
    """
    return IndoorDataset(load_obj(config.train_info), config, data_augmentation=True)

# ChatGPT or nltk is required when using tags_chinese
# import openai
# import nltk

def load_image(image_path):
    """
    Open an image as RGB and produce its normalized tensor version.

    NOTE(review): ``T`` is expected to be
    ``GroundingDINO.groundingdino.datasets.transforms``; that import is
    commented out at the top of this file, so this function raises NameError
    unless ``T`` is defined elsewhere.

    Returns:
        tuple: (RGB PIL image, transformed image — 3 x h x w per the original comment)
    """
    pil_rgb = Image.open(image_path).convert("RGB")

    # Resize shorter side to 800 (longer side capped at 1333), then
    # ImageNet mean/std normalization.
    pipeline = T.Compose([
        T.RandomResize([800], max_size=1333),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    tensor_img, _target = pipeline(pil_rgb, None)
    return pil_rgb, tensor_img


def check_tags_chinese(tags_chinese, pred_phrases, max_tokens=100, model="gpt-3.5-turbo", openai_key=None):
    """
    Print per-object counts from detector phrases and optionally ask an LLM
    to correct the object counts in ``tags_chinese``.

    Args:
        tags_chinese (str): tag string whose object counts may need revising.
        pred_phrases (list[str]): predicted phrases; the object name is the
            text before the first '(' (e.g. "cat(0.87)" -> "cat").
        max_tokens (int): completion length limit for the LLM call.
        model (str): model name passed to ``litellm.completion``.
        openai_key (str | None): API key gate. When None, falls back to a
            module-level ``openai_key`` global if one exists (the previous
            behavior). When falsy, no LLM call is made.

    Returns:
        str: the text after the last ':' of the LLM reply, or ``tags_chinese``
        unchanged when no key is available.
    """
    from collections import Counter

    # Counter keeps first-occurrence order, so the printed summary is deterministic.
    counts = Counter(phrase.split('(')[0] for phrase in pred_phrases)
    object_num = ', '.join(f'{n} {obj}' for obj, n in counts.items())
    print(f"Correct object number: {object_num}")

    if openai_key is None:
        openai_key = globals().get('openai_key')
    if not openai_key:
        return tags_chinese

    prompt = [
        {
            'role': 'system',
            'content': 'Revise the number in the tags_chinese if it is wrong. ' + \
                       f'tags_chinese: {tags_chinese}. ' + \
                       f'True object number: {object_num}. ' + \
                       'Only give the revised tags_chinese: '
        }
    ]
    # NOTE: `import litellm` is commented out at the top of this file; it must be
    # restored for this branch to run.
    response = litellm.completion(model=model, messages=prompt, temperature=0.6, max_tokens=max_tokens)
    reply = response['choices'][0]['message']['content']
    # The reply sometimes comes back as "tags_chinese: xxx, xxx, xxx".
    return reply.split(':')[-1].strip()


def load_model(model_config_path, model_checkpoint_path, device):
    """
    Build a detection model from a config file and load checkpoint weights.

    NOTE(review): ``SLConfig``, ``build_model`` and ``clean_state_dict`` come
    from the GroundingDINO imports that are commented out at the top of this
    file; they must be restored before this function can run.

    Args:
        model_config_path (str): path to the model config file.
        model_checkpoint_path (str): checkpoint containing a "model" state dict.
        device: device recorded on the config before building.

    Returns:
        the model, switched to eval mode (weights loaded on CPU, non-strict).
    """
    model_args = SLConfig.fromfile(model_config_path)
    model_args.device = device
    model = build_model(model_args)

    state = torch.load(model_checkpoint_path, map_location="cpu")["model"]
    # strict=False: report (print) missing/unexpected keys instead of failing.
    print(model.load_state_dict(clean_state_dict(state), strict=False))

    model.eval()
    return model



class PointCloudToImageMapper(object):
    """
    Project 3D world points into a pinhole camera image.

    ``image_dim`` is read as ``(width, height)``: projected column coordinates
    are bounded by ``image_dim[0]`` and row coordinates by ``image_dim[1]``.
    Both ``compute_mapping*`` methods return an N x 3 result whose rows are
    ``(row, col, valid)``; ``row``/``col`` stay 0 for points that are not visible.
    """

    # In ``compute_mapping``, a point whose camera depth exceeds the nearest
    # point landing on the same pixel by more than this is treated as occluded
    # (meters, per the original comment).
    occlusion_depth_threshold = 0.2

    def __init__(
            self, image_dim, visibility_threshold=0.1, cut_bound=0, intrinsics=None, device="cpu",
            use_torch=False, eps=1e-8):
        """
        :param image_dim: (width, height) of the target image
        :param visibility_threshold: max |depth image - point depth| for a point to be visible
            (absolute in compute_mapping_torch, relative to the depth value in compute_mapping)
        :param cut_bound: pixel margin excluded on every image border
        :param intrinsics: 3x3 camera matrix (fx=[0][0], fy=[1][1], cx=[0][2], cy=[1][2])
        :param device: device intrinsics are moved to when use_torch is True
        :param use_torch: store intrinsics as a torch tensor
        :param eps: camera depths with |z| < eps are set to eps before the perspective divide
        """
        self.image_dim = image_dim
        self.vis_thres = visibility_threshold
        self.cut_bound = cut_bound
        self.intrinsics = intrinsics
        self.eps = eps

        self.device = device
        # Skip the conversion when no intrinsics were given (torch cannot convert None).
        if use_torch and intrinsics is not None:
            self.intrinsics = torch.as_tensor(intrinsics).to(device)

    def compute_mapping_torch(self, camera_to_world, coords, depth=None, intrinsic=None, vis_thresh=None):
        """
        Torch version of ``compute_mapping`` (without the z-buffer step).

        :param camera_to_world: 4 x 4 (numpy array or tensor)
        :param coords: N x 3 tensor; its device is used for all computation
        :param depth: optional H x W (or 1 x H x W) depth image, numpy array or tensor.
            If given, a point is visible when |depth[row, col] - z| <= vis_thres.
            If None, a point is visible when it lies in front of the camera (z > 0).
        :param intrinsic: optional 3x3; when given it also replaces self.intrinsics
        :param vis_thresh: optional; when given it also replaces self.vis_thres
        :return: mapping, N x 3 long tensor, rows (row, col, valid)
        """
        device = coords.device
        if vis_thresh is not None:
            self.vis_thres = vis_thresh
        if intrinsic is not None:  # adjust intrinsic
            self.intrinsics = intrinsic
        else:
            intrinsic = self.intrinsics
        K = torch.as_tensor(intrinsic, device=device).float()

        camera_to_world = torch.as_tensor(camera_to_world, device=device).float()
        n = coords.shape[0]
        mapping = torch.zeros((3, n), dtype=torch.long, device=device)
        coords_h = torch.cat(
            [coords.float(), torch.ones((n, 1), dtype=torch.float, device=device)], dim=1
        ).T
        assert coords_h.shape[0] == 4, "[!] Shape error"

        world_to_camera = torch.linalg.inv(camera_to_world)
        p = world_to_camera @ coords_h
        # Avoid division by zero for points on the camera plane (mirrors compute_mapping).
        z = p[2]
        z[torch.abs(z) < self.eps] = self.eps
        p[0] = (p[0] * K[0, 0]) / p[2] + K[0, 2]
        p[1] = (p[1] * K[1, 1]) / p[2] + K[1, 2]
        pi = torch.round(p).long()  # simply round the projected coordinates

        inside_mask = (
            (pi[0] >= self.cut_bound)
            & (pi[1] >= self.cut_bound)
            & (pi[0] < self.image_dim[0] - self.cut_bound)
            & (pi[1] < self.image_dim[1] - self.cut_bound)
        )
        if depth is not None:
            # squeeze(0) drops a leading singleton axis; a plain H x W map is left as is.
            depth = torch.as_tensor(depth, device=device).squeeze(0)
            rows, cols = pi[1][inside_mask], pi[0][inside_mask]
            occlusion_mask = torch.abs(depth[rows, cols] - p[2][inside_mask]) <= self.vis_thres
            visible = torch.zeros_like(inside_mask)
            visible[inside_mask] = occlusion_mask
            inside_mask = visible
        else:
            inside_mask = inside_mask & (p[2] > 0)  # make sure the depth is in front

        mapping[0][inside_mask] = pi[1][inside_mask]
        mapping[1][inside_mask] = pi[0][inside_mask]
        mapping[2][inside_mask] = 1

        return mapping.T

    def compute_mapping(self, camera_to_world, coords, depth=None, intrinsic=None):
        """
        Project points into the image and keep only visible, non-occluded ones.

        :param camera_to_world: 4 x 4
        :param coords: N x 3 format
        :param depth: optional H x W depth image; if given, a point is kept when
            |depth[row, col] - z| <= vis_thres * depth[row, col]; otherwise it must have z > 0
        :param intrinsic: optional 3x3; when given it also replaces self.intrinsics
        :return: mapping, N x 3 int array, rows (row, col, valid)
        """
        if intrinsic is not None:  # adjust intrinsic
            self.intrinsics = intrinsic
        else:
            intrinsic = self.intrinsics

        n = coords.shape[0]
        mapping = np.zeros((3, n), dtype=int)
        coords_h = np.concatenate([coords, np.ones([n, 1])], axis=1).T
        assert coords_h.shape[0] == 4, "[!] Shape error"

        world_to_camera = np.linalg.inv(camera_to_world)
        p = np.matmul(world_to_camera, coords_h)
        p[2][np.abs(p[2]) < self.eps] = self.eps
        p[0] = (p[0] * intrinsic[0][0]) / p[2] + intrinsic[0][2]
        p[1] = (p[1] * intrinsic[1][1]) / p[2] + intrinsic[1][2]
        pi = np.round(p).astype(np.int32)  # simply round the projected coordinates

        inside_mask = (
            (pi[0] >= self.cut_bound)
            & (pi[1] >= self.cut_bound)
            & (pi[0] < self.image_dim[0] - self.cut_bound)
            & (pi[1] < self.image_dim[1] - self.cut_bound)
        )
        if depth is not None:
            depth_cur = depth[pi[1][inside_mask], pi[0][inside_mask]]
            occlusion_mask = np.abs(depth_cur - p[2][inside_mask]) <= self.vis_thres * depth_cur
            visible = np.zeros_like(inside_mask)
            visible[inside_mask] = occlusion_mask
            inside_mask = visible
        else:
            inside_mask &= p[2] > 0  # make sure the depth is in front

        # Z-buffer: among points landing on the same pixel keep those within
        # occlusion_depth_threshold of the nearest one. Uses the float depth p[2]
        # (rounding to whole units would make a 0.2 threshold meaningless).
        rows = pi[1][inside_mask]
        cols = pi[0][inside_mask]
        z = p[2][inside_mask]
        pixel_ids = rows.astype(np.int64) * self.image_dim[0] + cols
        unique_ids, group = np.unique(pixel_ids, return_inverse=True)
        group = group.ravel()
        depth_min = np.full(unique_ids.shape[0], np.inf)
        np.minimum.at(depth_min, group, z)  # every group is non-empty
        depth_min = np.maximum(depth_min, 0.0)
        depth_occlusion_mask = (z - depth_min[group]) <= self.occlusion_depth_threshold

        new_inside_mask = inside_mask.copy()
        new_inside_mask[inside_mask] = depth_occlusion_mask

        mapping[0][new_inside_mask] = pi[1][new_inside_mask]
        mapping[1][new_inside_mask] = pi[0][new_inside_mask]
        mapping[2][new_inside_mask] = 1

        return mapping.T

def adjust_intrinsic(intrinsic, intrinsic_image_dim, image_dim):
    """
    Rescale a 3x3 pinhole intrinsic matrix from one image size to another.

    Sizes are (width, height). The scale is chosen by the smaller of the
    width/height ratios, so the aspect ratio is preserved. fx/fy scale with
    the resized size and cx/cy with (size - 1). Integer-typed input is
    promoted to float so the in-place scaling cannot fail on a cast.

    Args:
        intrinsic: 3x3 intrinsic matrix (array-like).
        intrinsic_image_dim: (width, height) the intrinsics were calibrated for.
        image_dim: (width, height) of the target image.

    Returns:
        The input object itself when the two sizes compare equal, otherwise a
        new adjusted array (the input is not modified).
    """
    if intrinsic_image_dim == image_dim:
        return intrinsic

    intrinsic_return = np.array(intrinsic)
    if not np.issubdtype(intrinsic_return.dtype, np.floating):
        intrinsic_return = intrinsic_return.astype(float)

    height_after = image_dim[1]
    height_before = intrinsic_image_dim[1]
    height_ratio = height_after / height_before

    width_after = image_dim[0]
    width_before = intrinsic_image_dim[0]
    width_ratio = width_after / width_before

    if width_ratio >= height_ratio:
        resize_height = height_after
        resize_width = height_ratio * width_before
    else:
        resize_width = width_after
        resize_height = width_ratio * height_before

    intrinsic_return[0, 0] *= float(resize_width) / float(width_before)
    intrinsic_return[1, 1] *= float(resize_height) / float(height_before)
    # account for cropping/padding here
    intrinsic_return[0, 2] *= float(resize_width - 1) / float(width_before - 1)
    intrinsic_return[1, 2] *= float(resize_height - 1) / float(height_before - 1)

    return intrinsic_return

if __name__ == "__main__":
    torch.cuda.empty_cache()
    config = edict(load_config('./PCR_CG/configs/train/indoor.yaml'))
    train_set = get_datasets(config)
    print(train_set)
    point = train_set[10]
    # NOTE(review): the save below is commented out, so nothing is written even
    # though the next line prints "saved".
    # torch.save(point,'/home/cx/cv_project/Open3DIS/pointdata.pth')
    print('已保存')
    src_path = point['src_path']
    tgt_path = point['tgt_path']

    # CLI kept from the Grounded-SAM demo. The parsed values are currently unused;
    # the script runs with the hard-coded settings below.
    parser = argparse.ArgumentParser("Grounded-Segment-Anything Demo", add_help=True)
    parser.add_argument("--config", type=str, required=False, help="path to config file")
    parser.add_argument(
        "--ram_checkpoint", type=str, required=False, help="path to checkpoint file"
    )
    parser.add_argument(
        "--grounded_checkpoint", type=str, required=False, help="path to checkpoint file"
    )
    parser.add_argument(
        "--sam_checkpoint", type=str, required=False, help="path to checkpoint file"
    )
    parser.add_argument(
        "--sam_hq_checkpoint", type=str, default=None, help="path to sam-hq checkpoint file"
    )
    parser.add_argument(
        "--use_sam_hq", action="store_true", help="using sam-hq for prediction"
    )
    parser.add_argument("--input_image", type=str, required=False, help="path to image file")
    parser.add_argument("--split", default=",", type=str, help="split for text prompt")
    parser.add_argument("--openai_key", type=str, help="key for chatgpt")
    parser.add_argument("--openai_proxy", default=None, type=str, help="proxy for chatgpt")
    parser.add_argument(
        "--output_dir", "-o", type=str, default="outputs", required=False, help="output directory"
    )

    parser.add_argument("--box_threshold", type=float, default=0.25, help="box threshold")
    parser.add_argument("--text_threshold", type=float, default=0.2, help="text threshold")
    parser.add_argument("--iou_threshold", type=float, default=0.5, help="iou threshold")

    parser.add_argument("--device", type=str, default="cpu", help="running on cpu only!, default=False")
    args = parser.parse_known_args()

    # cfg (hard-coded)
    config_file = 'GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py'  # change the path of the model config file
    ram_checkpoint = './ram_swin_large_14m.pth'  # change the path of the model
    grounded_checkpoint = './groundingdino_swint_ogc.pth'  # change the path of the model
    sam_checkpoint = './sam_vit_h_4b8939.pth'
    sam_hq_checkpoint = None
    use_sam_hq = ''
    image_path_src = src_path
    image_path_tgt = tgt_path
    split = ','
    openai_key = ''
    openai_proxy = None
    output_dir = "outputs_src"
    output_dir_tgt = "outputs_tgt"
    box_threshold = 0.25
    text_threshold = 0.2
    iou_threshold = 0.5
    device = "cpu"
    depth_scale = 1000.0
    img_dim = (640, 480)  # (width, height)
    cut_num_pixel_boundary = 10

    os.makedirs(output_dir, exist_ok=True)

    # Source point cloud as a tensor on the working device.
    points = torch.from_numpy(point['src_pcd']).to(device)
    n_points = points.shape[0]

    img_dir = point['src_path']
    device = torch.device('cpu')

    # Column 0 is an unused placeholder; columns 1..3 receive (row, col, valid).
    mapping = torch.ones([n_points, 4], dtype=int, device=device)

    ################ Feature Fusion ###################
    # Camera-to-world pose stored next to the colour image (color/*.png -> pose/*.txt).
    posepath = img_dir.replace('color', 'pose').replace('.png', '.txt')
    pose = np.loadtxt(posepath).astype(float)

    # NOTE(review): depth is used as-is; depth_scale is not applied here, so this
    # assumes the dataset already returns depth in meters — confirm.
    depth = point['src_depth_image'].cpu().numpy()
    origin_intrinsics = point['origin_intrinsics']
    big_size, image_size = [640, 480], [640, 480]
    intrinsics = adjust_intrinsic(origin_intrinsics, big_size, image_size)
    print('intrinsics', intrinsics)
    print('pose', pose)
    # World-to-camera matrix returned by indoor.py (tried as an alternative; currently unused).
    src1_world2camera = point['src1_world2camera']

    image = np.array(Image.open(point['src_path']).convert("RGB").resize(img_dim))
    pointcloud_mapper = PointCloudToImageMapper(
        image_dim=img_dim, intrinsics=intrinsics, cut_bound=cut_num_pixel_boundary
    )
    # `pose` is the camera-to-world matrix read from the pose .txt file.
    mapping[:, 1:4] = pointcloud_mapper.compute_mapping_torch(pose, points, depth)

    # Colour each point from the pixel it projects to.
    rows = mapping[:, 1].cpu().numpy()
    cols = mapping[:, 2].cpu().numpy()
    valid = mapping[:, 3].cpu().numpy().astype(bool)
    colors = image[rows, cols] / 255  # normalise to [0, 1]
    colors[~valid] = 0.5  # grey for points that do not project into the image
    print(colors)

    # plotly and Open3D expect numpy arrays, not torch tensors.
    points_np = points.cpu().numpy().astype(np.float64)
    # plotly's marker.color takes one CSS colour string per point.
    marker_colors = [f'rgb({r},{g},{b})' for r, g, b in np.rint(colors * 255).astype(int)]
    fig = go.Figure(
        data=[
            go.Scatter3d(
                x=points_np[:, 0], y=points_np[:, 1], z=points_np[:, 2],
                mode='markers',
                marker=dict(size=1, color=marker_colors),
            )
        ],
        layout=dict(
            scene=dict(
                xaxis=dict(visible=False),
                yaxis=dict(visible=False),
                zaxis=dict(visible=False),
            )
        ),
    )
    fig.show()

    # Build an Open3D point cloud with the sampled colours and save it as PLY.
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(points_np)
    pcd.colors = o3d.utility.Vector3dVector(colors)
    o3d.io.write_point_cloud('src_mapping.ply', pcd)



 