# In [None]:
# 变换位置编码后的验证程序
import torch
import numpy as np
# torch.autograd.set_detect_anomaly(True)
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math
import os
import random
import matplotlib.pyplot as plt
import re
from matplotlib.widgets import Slider
from skimage import measure

import plotly.graph_objects as go
import io

class NeRF(nn.Module):
    """MLP mapping positionally-encoded (position, view-direction) features to per-sample outputs.

    With use_viewdirs=True the output columns are [alpha, sigma]: alpha (volume density)
    depends only on the encoded position, sigma (backscattering coefficient) also on the
    encoded view direction; both pass through ReLU, so they are non-negative.
    """

    def __init__(self, D=8, W=256, input_ch=3, input_ch_views=3, output_ch=2, skips=(4,), use_viewdirs=True):
        """
        D: number of linear layers in the position trunk.
        W: hidden width of the position trunk.
        input_ch: width of the encoded position features (first input_ch columns of the input).
        input_ch_views: width of the encoded view-direction features (last input_ch_views columns).
        output_ch: output width; only used when use_viewdirs=False.
        skips: trunk layer indices after which the encoded position is concatenated back in
            (concatenation, not a ResNet-style sum). Copied, so the default is never shared.
        use_viewdirs: if True, output = [alpha, sigma]; otherwise one linear head of width output_ch.

        The network receives already-encoded inputs, e.g. [64*bs, 90] in and [64*bs, 2] out
        (volume density, backscattering coefficient).
        """
        super().__init__()
        self.D = D
        self.W = W
        self.input_ch = input_ch
        self.input_ch_views = input_ch_views
        # Private copy: a shared mutable default would leak mutations between instances.
        self.skips = list(skips)
        self.use_viewdirs = use_viewdirs

        # Position trunk. Element i of the comprehension is layer i+1; it takes W + input_ch
        # inputs when i is a skip layer, because forward() concatenates the encoded position
        # onto the output of layer i.
        self.pts_linears = nn.ModuleList(
            [nn.Linear(input_ch, W)]
            + [nn.Linear(W + input_ch, W) if i in self.skips else nn.Linear(W, W) for i in range(D - 1)]
        )

        # View branch with half the channels (W // 2).
        ### Implementation according to the official code release (https://github.com/bmild/nerf/blob/master/run_nerf_helpers.py#L104-L105)
        self.views_linears = nn.ModuleList([nn.Linear(input_ch_views + W, W // 2)])
        self.relu = nn.ReLU()
        # Not used by forward(); kept so the module layout stays the same.
        self.sigmoid = nn.Sigmoid()

        if use_viewdirs:
            # Feature vector fed into the view branch.
            self.feature_linear = nn.Linear(W, W)
            # Volume density, one value.
            self.alpha_linear = nn.Linear(W, 1)
            # Backscattering coefficient, one value.
            self.rho_linear = nn.Linear(W // 2, 1)
        else:
            self.output_linear = nn.Linear(W, output_ch)

    def forward(self, x):
        """x: [N, input_ch + input_ch_views] (e.g. [bs*64, 63 + 27]). Returns [N, 2] or [N, output_ch]."""
        input_pts, input_views = torch.split(x, [self.input_ch, self.input_ch_views], dim=-1)

        h = input_pts
        for i, layer in enumerate(self.pts_linears):
            h = F.relu(layer(h))
            # Skip connection: concatenate the encoded position after this layer.
            if i in self.skips:
                h = torch.cat([input_pts, h], -1)

        if not self.use_viewdirs:
            return self.output_linear(h)

        # alpha depends only on position.
        alpha = self.relu(self.alpha_linear(h))
        # sigma depends on both position features and view direction.
        h = torch.cat([self.feature_linear(h), input_views], -1)
        for layer in self.views_linears:
            h = F.relu(layer(h))
        sigma = self.relu(self.rho_linear(h))

        return torch.cat([alpha, sigma], -1)


def batchrender(omega,LOS,model,doppler_num, *, render_equation=2, use_position_coding=True):
    '''
    Render one range profile (one Doppler row) per batch element.

    omega:       [bs, 3] tensor; direction = rotation axis, norm = angular speed.
    LOS:         [bs, 3] tensor; unit line-of-sight direction pointing at the object.
    model:       NeRF network; fed the encoded position and LOS it returns, per sample,
                 (volume density, scattering coefficient).
    doppler_num: [bs] long tensor; index (0..doppler_gap-1) of the Doppler bin to render.
    render_equation: accumulation rule 0-3 (see branches below); 2 is the previous hard-coded choice.
    use_position_coding: True (previous behaviour) encodes positions/LOS with positon_code_xyz /
                 position_code_LOS; False feeds raw 3-D coordinates.

    Returns (distance_profile, scatter):
        distance_profile: [bs, distance_gap] rendered profile.
        scatter: [bs, distance_gap, n_gap] second model output at every sample.

    Samples along the projection-plane normal are drawn with torch.rand, so the result is
    stochastic unless the torch RNG is seeded. Raises ValueError for an unknown render_equation.
    '''
    # Grid parameters (must stay identical to those used when the model was trained).
    distance_max = 0.6
    distance_min = -0.6
    distance_gap = 100
    doppler_max = 0.15
    doppler_min = -0.15
    doppler_gap = 100
    n_max = 0.60
    n_min = -0.60
    n_gap = 120
    batch_size, _ = omega.shape

    # Unit Doppler axis of the projection plane (LOS x omega) and |sin| of the LOS/omega angle.
    Doppler_vector = torch.cross(LOS, omega, dim=1)
    cross_norm = torch.linalg.norm(Doppler_vector, dim=1)
    LOSomega_sin_angel = cross_norm / (torch.linalg.norm(omega, dim=1) * torch.linalg.norm(LOS, dim=1))
    Doppler_vector = Doppler_vector / cross_norm.unsqueeze(1)

    # Range / Doppler grids expressed as 3-D offsets along LOS and Doppler_vector.
    distance = torch.linspace(distance_min, distance_max, distance_gap).to(device).repeat(batch_size, 1)
    # NOTE(review): linspace spacing is (max-min)/(gap-1) but the step used below is (max-min)/gap;
    # kept unchanged so rendering matches the trained model.
    distance_delta = torch.tensor((distance_max - distance_min) / distance_gap).to(device)
    doppler = torch.linspace(doppler_min, doppler_max, doppler_gap).repeat(batch_size, 1).to(device)
    # Scale Doppler bins into a spatial offset along Doppler_vector.
    # NOTE(review): the factor 4 and the 1/sin(angle) term encode the imaging geometry — confirm.
    doppler = doppler * 4 / LOSomega_sin_angel.unsqueeze(1)
    distance_map = distance.unsqueeze(2) * LOS.unsqueeze(1)
    doppler_map = doppler.unsqueeze(2) * Doppler_vector.unsqueeze(1)

    # Unit normal of the projection plane; rays are cast along it.
    n = torch.cross(LOS, Doppler_vector, dim=1)
    n = n / torch.linalg.norm(n, dim=1).unsqueeze(1)
    # Stratified sampling: one uniform random sample inside each of the n_gap bins of [n_min, n_max].
    n_array = torch.linspace(n_min, n_max, n_gap + 1).to(device).repeat(batch_size, distance_gap, 1)
    n_lower = n_array[:, :, :-1]
    n_random_array = n_lower + (n_array[:, :, 1:] - n_lower) * torch.rand(batch_size, distance_gap, n_gap).to(device)
    n_random_map = n_random_array.unsqueeze(3) * n.unsqueeze(1).unsqueeze(2)
    # Spacing between consecutive samples; the first one is measured from the n_min end of the ray.
    start_n = (n.unsqueeze(1).unsqueeze(2) * n_min).expand(-1, distance_gap, -1, -1)
    n_random_map_temp = torch.cat((start_n, n_random_map), dim=2)
    n_delta = torch.norm(n_random_map_temp[:, :, :-1, :] - n_random_map, dim=3)

    # 3-D sample positions: selected Doppler offset + range offset + offset along the ray.
    xyz = doppler_map[torch.arange(batch_size), doppler_num, :].unsqueeze(1).unsqueeze(2) + distance_map.unsqueeze(2) + n_random_map
    if use_position_coding:
        xyz_coding = positon_code_xyz(xyz)
        LOS_coding = position_code_LOS(LOS)
    else:
        xyz_coding = xyz.view(-1, 3)
        LOS_coding = LOS
    # Every sample of a batch element shares that element's LOS features.
    LOS_coding = LOS_coding[:, None, None, :].expand(-1, distance_gap, n_gap, -1).reshape(-1, LOS_coding.shape[-1])
    xyzLOS_coding = torch.cat((xyz_coding, LOS_coding), dim=1)
    output = model(xyzLOS_coding).view(batch_size, distance_gap, n_gap, 2)

    density = output[:, :, :, 0]
    scatter = output[:, :, :, 1]
    if render_equation == 0:
        Ti = torch.cumprod(torch.exp(-density * distance_delta), dim=1)
        distance_profile = torch.sum(density * (1 - torch.exp(-scatter * n_delta)) * Ti, dim=2)
    elif render_equation == 1:
        Ti = torch.cumprod(torch.exp(-density ** 2 * distance_delta), dim=1)
        distance_profile = torch.sum(density * scatter * n_delta * Ti, dim=2)
    elif render_equation == 2:
        # Transmittance accumulated along the range axis (dim=1), shifted by one cell so the
        # first range cell (distance_min) is unattenuated.
        Ti = torch.cumprod(torch.exp(-density * distance_delta), dim=1)
        Ti = torch.cat((torch.ones(batch_size, 1, n_gap).to(device), Ti), dim=1)[:, :-1, :]
        alphai = 1 - torch.exp(-density * distance_delta)
        distance_profile = torch.sum(alphai * scatter * n_delta * Ti, dim=2)
    elif render_equation == 3:
        # RaNeRF accumulation rule: transmittance along each ray (dim=2), shifted by one sample.
        Ti = torch.cumprod(torch.exp(-density * distance_delta), dim=2)
        Ti = torch.cat((torch.ones(batch_size, distance_gap, 1).to(device), Ti), dim=2)[:, :, :-1]
        alphai = 1 - torch.exp(-density * n_delta)
        distance_profile = torch.sum(alphai * scatter * n_delta * Ti, dim=2)
    else:
        raise ValueError(f"unknown render_equation: {render_equation}")
    return distance_profile, scatter


def positon_code_xyz(xyz):
    """
    Positional encoding of 3-D sample positions.

    xyz: tensor of shape (..., D) (e.g. [bs, distance, n, 3]); all leading dims are flattened.
    Returns [N, D * (2 * code_len + 1)] = [xyz, encoding] (63 columns for D=3, code_len=10).

    Angles are pi * 2**k * xyz for k < code_len, flattened in (k, d) order; sines fill the even
    columns and cosines the odd columns of the encoding. This interleaved layout is what trained
    checkpoints expect, so it must not change. The result lives on xyz's device.
    """
    code_len = 10
    dimension = xyz.shape[-1]
    flat = xyz.reshape(-1, dimension)
    freqs = (2 ** torch.arange(code_len, device=flat.device)).view(1, code_len, 1)
    angles = (flat.unsqueeze(1) * math.pi * freqs).reshape(flat.shape[0], -1)
    encoding = flat.new_zeros(flat.shape[0], 2 * code_len * dimension)
    encoding[:, 0::2] = torch.sin(angles)
    encoding[:, 1::2] = torch.cos(angles)
    return torch.cat((flat, encoding), dim=1)

def position_code_LOS(LOS):
    """
    Positional encoding of line-of-sight directions.

    LOS: [bs, D] tensor. Returns [bs, D * (2 * code_len + 1)] = [LOS, encoding]
    (27 columns for D=3, code_len=4).

    Angles are 2**k * LOS for k < code_len (no pi factor, unlike positon_code_xyz). They are
    flattened in (k, d) order with sines in even and cosines in odd encoding columns. The layout
    must match what trained checkpoints expect. The result lives on LOS's device.
    """
    code_len = 4
    batch_size, dimension = LOS.shape
    freqs = (2 ** torch.arange(code_len, device=LOS.device)).view(1, code_len, 1)
    angles = (LOS.unsqueeze(1) * freqs).reshape(batch_size, -1)
    encoding = LOS.new_zeros(batch_size, 2 * code_len * dimension)
    encoding[:, 0::2] = torch.sin(angles)
    encoding[:, 1::2] = torch.cos(angles)
    return torch.cat((LOS, encoding), dim=1)

def picture_sample(images,LOS_dirs,omegas,batch_size,image_hight = 100,image_width = 100, image_num = 30, image_index = 54):
    '''
    Build a batch that contains every Doppler row of a single image.

    images: list of per-image tensors indexed as images[k][row, :].
    LOS_dirs, omegas: indexable per image (LOS_dirs[k], omegas[k]).
    batch_size, image_width: kept for interface compatibility; not used (the batch always has
        image_hight rows).
    image_num: number of images to draw from when image_index is None.
    image_index: image to use. The default 54 reproduces the previously hard-coded choice;
        None picks one uniformly at random from range(image_num).

    Returns (omegas, LOS_dirs, range_profiles, doppler_numbers) tensors on `device`, each with
    image_hight rows; doppler_numbers is 0..image_hight-1.
    Raises IndexError if image_index is outside range(len(images)).
    '''
    if image_index is None:
        image_index = random.randrange(image_num)
    if not 0 <= image_index < len(images):
        raise IndexError(f"image_index {image_index} out of range for {len(images)} images")
    print([image_index])

    doppler_numbers = list(range(image_hight))
    image = images[image_index]

    omegas_batch_tensor = torch.stack([omegas[image_index]] * image_hight).to(device)
    LOS_dirs_batch_tensor = torch.stack([LOS_dirs[image_index]] * image_hight).to(device)
    range_profile_batch_tensor = torch.stack([image[row, :] for row in doppler_numbers]).to(device)
    doppler_profil_num_tensor = torch.tensor(doppler_numbers).long().to(device)

    return omegas_batch_tensor,LOS_dirs_batch_tensor,range_profile_batch_tensor,doppler_profil_num_tensor

def natural_sort_key(s):
    """
    Sort key ordering embedded digit runs numerically and text case-insensitively.

    re.split with one capture group puts the matched digit runs at odd indices, so those are
    converted with int(). Relying on the position, rather than str.isdigit(), avoids crashing
    on characters such as '²' that isdigit() accepts but int() rejects.
    """
    return [int(chunk) if i % 2 else chunk.lower() for i, chunk in enumerate(re.split(r'(\d+)', s))]

def loaddata(folder_path):
    '''
    Load every .npz file in folder_path, in natural filename order.

    Each file must contain the arrays 'image', 'LOS' and 'rotation_axis'.
    Returns (images, LOS_dirs, omegas):
        images: list of tensors on `device`.
        LOS_dirs: one tensor stacking all 'LOS' arrays.
        omegas: list of tensors on `device`.
    Also shows a plotly 3-D scatter of the LOS directions.
    Raises RuntimeError if the folder contains no .npz files.
    '''
    # Regular .npz files only (sub-folders are skipped).
    files = [item for item in os.listdir(folder_path)
             if item.endswith('.npz') and os.path.isfile(os.path.join(folder_path, item))]
    files_sorted = sorted(files, key=natural_sort_key)
    if not files_sorted:
        raise RuntimeError(f"no .npz files found in {folder_path}")

    images = []
    LOS_dirs = []
    omegas = []
    for file in files_sorted:
        # np.load returns an NpzFile that keeps the archive open; the context manager closes it
        # once the arrays have been read into memory.
        with np.load(os.path.join(folder_path, file)) as data:
            images.append(torch.from_numpy(data['image']).to(device))
            LOS_dirs.append(torch.from_numpy(data['LOS']).to(device))
            omegas.append(torch.from_numpy(data['rotation_axis']).to(device))

    LOS_dirs = torch.stack(LOS_dirs)
    print(LOS_dirs.shape)
    # Visualize the LOS directions as points in 3-D space.
    fig = go.Figure(data=[go.Scatter3d(x=LOS_dirs[:, 0].cpu().numpy(),
                                        y=LOS_dirs[:, 1].cpu().numpy(),
                                        z=LOS_dirs[:, 2].cpu().numpy(),
                                        mode='markers',
                                        marker=dict(size=2, color='blue'))])
    fig.update_layout(scene=dict(xaxis_title='X', yaxis_title='Y', zaxis_title='Z'),
                      title='LOS Directions in 3D Space')
    fig.show()

    return images,LOS_dirs,omegas

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(torch.cuda.is_available())

modelname = 'experiment212'

model = NeRF(input_ch = 63, input_ch_views = 27, use_viewdirs = True).to(device)
# map_location lets a checkpoint saved from a GPU be loaded on a CPU-only machine.
model.load_state_dict(torch.load('/DATA/disk1/asteroid/asteroid_inverse/Instant-ngp/model/'+ modelname +'/model_state_dict.pth', map_location=device))
model.eval()

# Scale factor applied to the stored rotation axes before rendering.
omega_real = math.pi/900

folder_path = '/DATA/disk1/asteroid/asteroid_inverse/ImageGen/3dmodel/XXX/XXX_dilate_real_image_13.8du'
# folder_path = '/DATA/disk1/asteroid/asteroid_inverse/Instant-ngp/new_dataset/sys_data/contactball_rot90'

images,LOS_dirs,omegas = loaddata(folder_path)

# Number of loaded images.
image_num = len(images)

omegas_batch_tensor,LOS_dirs_batch_tensor,range_profile_batch_tensor,doppler_profil_num_tensor = picture_sample(images,LOS_dirs,omegas,batch_size = 40,image_num=image_num)
# Inference only: do not build an autograd graph for the whole batch of samples.
with torch.no_grad():
    distance_profile_batch,output = batchrender(omegas_batch_tensor*omega_real,LOS_dirs_batch_tensor,model,doppler_profil_num_tensor)

# Rendered (NeRF) image and ground-truth image as CPU tensors.
range_image1 = distance_profile_batch.detach().cpu()
range_image2 = range_profile_batch_tensor.detach().cpu()

# Side-by-side comparison.
fig, axes = plt.subplots(1, 2, figsize=(12, 6))

axes[0].imshow(range_image2,cmap='gray')
axes[0].set_title("groundtruth image")
axes[0].axis('off')

axes[1].imshow(range_image1,cmap='gray')
axes[1].set_title("nerf image")
axes[1].axis('off')

# Export both images as grayscale PNGs.
plt.imsave('groundtruth_image.png', range_image2, cmap='gray')
plt.imsave('nerf_image.png', range_image1, cmap='gray')

# Crop rows 30:50 and columns 40:65 (a 20x25 region) of both images and export them.
range_image1_cropped = range_image1[30:50, 40:65]
range_image2_cropped = range_image2[30:50, 40:65]
plt.imsave('groundtruth_image_cropped.png', range_image2_cropped, cmap='gray')
plt.imsave('nerf_image_cropped.png', range_image1_cropped, cmap='gray')

# Mark the pixel at (column 40, row 50) on both images.
axes[0].scatter(40, 50, c='r', s=10)
axes[1].scatter(40, 50, c='r', s=10)

# Pixel-wise MSE computed with torch (kept separate from skimage's `mse` imported below).
mse_torch = ((range_image1 - range_image2) ** 2).mean()
print(mse_torch)
plt.tight_layout()
plt.show()

# Image similarity metrics.
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import mean_squared_error as mse
from skimage.metrics import peak_signal_noise_ratio as psnr
print(range_image1.max(), range_image1.min())
range_image1_np = range_image1.numpy()
range_image2_np = range_image2.numpy()
# NOTE(review): data_range=1.0365 is a fixed constant, presumably the intensity range of this dataset — confirm.
similarity_index = ssim(range_image1_np, range_image2_np, data_range=1.0365)
print("SSIM:", similarity_index)
psnr_value = psnr(range_image1_np, range_image2_np, data_range=1.0365)
print("PSNR:", psnr_value)

# Second model output at every sample of every rendered ray, as a numpy array.
ray_distribution = output.detach().cpu().numpy()




# In [None]:

# 1. 交互式切片可视化
def visualize_slices(density_volume):
    """
    Show the three orthogonal slices of a 3-D volume, each movable with its own slider.

    density_volume: 3-D array indexed as [x, y, z]; the slices start at the middle indices.
    """
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    # Initial slice positions: the middle of each axis.
    x_idx, y_idx, z_idx = density_volume.shape[0]//2, density_volume.shape[1]//2, density_volume.shape[2]//2

    x_slice = axes[0].imshow(density_volume[x_idx, :, :], cmap='viridis')
    axes[0].set_title(f'YZ平面 (X={x_idx})')
    y_slice = axes[1].imshow(density_volume[:, y_idx, :], cmap='viridis')
    axes[1].set_title(f'XZ平面 (Y={y_idx})')
    z_slice = axes[2].imshow(density_volume[:, :, z_idx], cmap='viridis')
    axes[2].set_title(f'XY平面 (Z={z_idx})')

    # One slider per axis, stacked under the images.
    ax_x = plt.axes([0.1, 0.02, 0.8, 0.03])
    ax_y = plt.axes([0.1, 0.06, 0.8, 0.03])
    ax_z = plt.axes([0.1, 0.10, 0.8, 0.03])

    s_x = Slider(ax_x, 'X切片', 0, density_volume.shape[0]-1, valinit=x_idx, valstep=1)
    s_y = Slider(ax_y, 'Y切片', 0, density_volume.shape[1]-1, valinit=y_idx, valstep=1)
    s_z = Slider(ax_z, 'Z切片', 0, density_volume.shape[2]-1, valinit=z_idx, valstep=1)

    def update(_val):
        """Redraw all three slices at the current slider positions."""
        xi, yi, zi = int(s_x.val), int(s_y.val), int(s_z.val)
        x_slice.set_data(density_volume[xi, :, :])
        axes[0].set_title(f'YZ平面 (X={xi})')
        y_slice.set_data(density_volume[:, yi, :])
        axes[1].set_title(f'XZ平面 (Y={yi})')
        z_slice.set_data(density_volume[:, :, zi])
        axes[2].set_title(f'XY平面 (Z={zi})')
        fig.canvas.draw_idle()

    for slider in (s_x, s_y, s_z):
        slider.on_changed(update)
    # matplotlib widgets stop responding once garbage-collected; keep them alive with the figure
    # (matters for non-blocking backends where plt.show() returns immediately).
    fig._slice_sliders = (s_x, s_y, s_z)

    plt.tight_layout()
    plt.subplots_adjust(bottom=0.25)
    plt.show()

# 2. Marching Cubes可视化
def visualize_isosurfaces(density_volume, threshold_ratio=0.5, filename="mesh_export25.obj"):
    """
    Extract one isosurface with Marching Cubes, plot it with matplotlib and export it as OBJ.

    density_volume: 3-D array accepted by skimage.measure.marching_cubes.
    threshold_ratio: iso-level as a fraction of [min, max] of density_volume.
    filename: OBJ output path (default = the previously hard-coded name); None skips the export.
    """
    vmin, vmax = density_volume.min(), density_volume.max()
    threshold = vmin + (vmax - vmin) * threshold_ratio

    verts, faces, _, _ = measure.marching_cubes(density_volume, threshold)

    # Map voxel indices into [-0.6, 0.6) (uses shape[0] for every axis, i.e. assumes a cubic grid).
    verts = verts / density_volume.shape[0] * 1.2 - 0.6

    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111, projection='3d')
    ax.plot_trisurf(verts[:, 0], verts[:, 1], faces, verts[:, 2], color='cyan', alpha=0.7, shade=True)

    if filename is not None:
        # UTF-8 explicitly, because the header lines contain non-ASCII text.
        with open(filename, 'w', encoding='utf-8') as f:
            f.write("# OBJ文件由NeRF模型生成\n")
            f.write(f"# 顶点数: {verts.shape[0]}\n")
            f.write(f"# 面数: {faces.shape[0]}\n\n")
            for v in verts:
                f.write(f"v {v[0]} {v[1]} {v[2]}\n")
            # OBJ face indices are 1-based.
            for face in faces:
                f.write(f"f {face[0]+1} {face[1]+1} {face[2]+1}\n")
        print(f"OBJ文件导出成功: {filename}")

    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    ax.set_title(f'密度等值面 (阈值: {threshold:.4f})')

    ax.set_xlim([-1.0, 1.0])
    ax.set_ylim([-1.0, 1.0])
    ax.set_zlim([-1.0, 1.0])

    plt.tight_layout()
    plt.show()

# 3. 多阈值等值面可视化
def visualize_multi_isosurfaces(density_volume, thresholds=(0.3, 0.5, 0.7)):
    """
    Plot nested isosurfaces, one per threshold, with matplotlib.

    density_volume: 3-D array accepted by skimage.measure.marching_cubes.
    thresholds: iso-levels as fractions of [min, max] of density_volume. Colors and alphas
        cycle when there are more than three.
    """
    vmin, vmax = density_volume.min(), density_volume.max()
    levels = [vmin + (vmax - vmin) * t for t in thresholds]

    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111, projection='3d')

    colors = ['blue', 'cyan', 'green']
    alphas = [0.3, 0.5, 0.7]

    for i, level in enumerate(levels):
        verts, faces, _, _ = measure.marching_cubes(density_volume, level)
        # Map voxel indices into [-0.6, 0.6), the same mapping as the other isosurface helpers.
        # The previous factor 1.0 squeezed the mesh into [-0.6, 0.4).
        verts = verts / density_volume.shape[0] * 1.2 - 0.6
        ax.plot_trisurf(verts[:, 0], verts[:, 1], faces, verts[:, 2],
                        color=colors[i % len(colors)], alpha=alphas[i % len(alphas)], shade=True)

    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    ax.set_title('多阈值密度等值面')

    ax.set_xlim([-0.6, 0.6])
    ax.set_ylim([-0.6, 0.6])
    ax.set_zlim([-0.6, 0.6])

    plt.tight_layout()
    plt.show()

def visualize_plotly_isosurface(density_volume, threshold_ratio=0.3, filename=None, los_dir=None):
    """
    Interactive Plotly isosurface of a density volume, with the line of sight drawn as a line.

    density_volume: 3-D density array.
    threshold_ratio: iso-level as a fraction of [min, max] (0.0-1.0).
    filename: optional output path; '.obj' exports the mesh, '.html' saves the interactive figure.
    los_dir: 3-vector (tensor or array) to draw. None falls back to the module-level
        LOS_dirs_batch_tensor[0, :] produced by the evaluation cell (previous behaviour).

    Returns the plotly Figure.
    """
    vmin, vmax = density_volume.min(), density_volume.max()
    threshold = vmin + (vmax - vmin) * threshold_ratio

    print(f"正在提取等值面，阈值: {threshold:.4f}...")
    verts, faces, _, values = measure.marching_cubes(density_volume, threshold)

    # Map voxel indices into [-0.6, 0.6) (uses shape[0] for every axis, i.e. assumes a cubic grid).
    verts = verts / density_volume.shape[0] * 1.2 - 0.6

    fig = go.Figure(data=[go.Mesh3d(
        x=verts[:, 0],
        y=verts[:, 1],
        z=verts[:, 2],
        i=faces[:, 0],
        j=faces[:, 1],
        k=faces[:, 2],
        opacity=0.8,
        colorscale='Blues',
        intensity=values,
        showscale=True
    )])

    # Draw the reversed line of sight (length 0.5) from the origin.
    if los_dir is None:
        los_dir = LOS_dirs_batch_tensor[0, :]
    if torch.is_tensor(los_dir):
        los_dir = los_dir.detach().cpu().numpy()
    los_dir = np.asarray(los_dir)
    fig.add_trace(go.Scatter3d(
        x=[0, -los_dir[0]*0.5],
        y=[0, -los_dir[1]*0.5],
        z=[0, -los_dir[2]*0.5],
        mode='lines+markers',
        marker=dict(size=2),
        line=dict(width=2)
    ))

    fig.update_layout(
        title=f'NeRF密度等值面 (阈值: {threshold:.4f})',
        scene=dict(
            xaxis=dict(range=[-0.6, 0.6], title='X'),
            yaxis=dict(range=[-0.6, 0.6], title='Y'),
            zaxis=dict(range=[-0.6, 0.6], title='Z'),
            aspectmode='cube'
        ),
        width=900,
        height=900,
    )

    if filename and filename.endswith('.obj'):
        # UTF-8 explicitly, because the header lines contain non-ASCII text.
        with open(filename, 'w', encoding='utf-8') as f:
            f.write(f"# 顶点数: {verts.shape[0]}\n")
            f.write(f"# 面数: {faces.shape[0]}\n\n")
            for v in verts:
                f.write(f"v {v[0]} {v[1]} {v[2]}\n")
            # OBJ face indices are 1-based.
            for face in faces:
                f.write(f"f {face[0]+1} {face[1]+1} {face[2]+1}\n")
        print(f"OBJ文件导出成功: {filename}")

    if filename and filename.endswith('.html'):
        fig.write_html(filename)
        print(f"交互式HTML文件导出成功: {filename}")

    return fig

def visualize_plotly_multi_isosurfaces(density_volume, thresholds=(0.3, 0.5, 0.7), colorscales=('Blues', 'Greens', 'Reds'), filename=None):
    """
    Interactive Plotly figure with one isosurface per threshold.

    density_volume: 3-D density array.
    thresholds: iso-levels as fractions of [min, max] of density_volume.
    colorscales: colorscale names, cycled when there are more thresholds than colorscales.
    filename: optional '.html' output path.

    Returns the plotly Figure.
    """
    fig = go.Figure()
    vmin, vmax = density_volume.min(), density_volume.max()

    for i, threshold_ratio in enumerate(thresholds):
        threshold = vmin + (vmax - vmin) * threshold_ratio
        verts, faces, _, values = measure.marching_cubes(density_volume, threshold)
        verts = verts / density_volume.shape[0] * 1.2 - 0.6

        # Later thresholds are drawn more transparent, with opacity floored at 0.3.
        opacity = max(0.3, 0.9 - i * 0.2)

        fig.add_trace(go.Mesh3d(
            x=verts[:, 0],
            y=verts[:, 1],
            z=verts[:, 2],
            i=faces[:, 0],
            j=faces[:, 1],
            k=faces[:, 2],
            opacity=opacity,
            colorscale=colorscales[i % len(colorscales)],
            intensity=values,
            showscale=True,
            name=f'阈值: {threshold:.4f}'
        ))

    fig.update_layout(
        title='NeRF密度多阈值等值面',
        scene=dict(
            xaxis=dict(range=[-0.6, 0.6], title='X'),
            yaxis=dict(range=[-0.6, 0.6], title='Y'),
            zaxis=dict(range=[-0.6, 0.6], title='Z'),
            aspectmode='cube'
        ),
        width=900,
        height=900,
    )

    if filename and filename.endswith('.html'):
        fig.write_html(filename)
        print(f"交互式HTML文件导出成功: {filename}")

    return fig


# Combined isosurface + orthogonal-slice visualization. Its `def` line had been commented out,
# which turned the body into unreachable code inside the function above; restored here.
def visualize_plotly_combined(density_volume, threshold_ratio=0.3, slice_pos=None, filename=None):
    """
    Interactive Plotly figure combining one isosurface with X/Y/Z density slices.

    density_volume: 3-D density array indexed as [x, y, z].
    threshold_ratio: iso-level as a fraction of [min, max].
    slice_pos: [x, y, z] voxel indices of the slices; None uses the middle of each axis.
    filename: optional '.html' output path.

    Returns the plotly Figure.
    """
    if slice_pos is None:
        slice_pos = [density_volume.shape[0]//2,
                     density_volume.shape[1]//2,
                     density_volume.shape[2]//2]

    fig = go.Figure()
    vmin, vmax = density_volume.min(), density_volume.max()
    threshold = vmin + (vmax - vmin) * threshold_ratio

    # Isosurface.
    verts, faces, _, values = measure.marching_cubes(density_volume, threshold)
    verts = verts / density_volume.shape[0] * 1.2 - 0.6

    fig.add_trace(go.Mesh3d(
        x=verts[:, 0],
        y=verts[:, 1],
        z=verts[:, 2],
        i=faces[:, 0],
        j=faces[:, 1],
        k=faces[:, 2],
        opacity=0.7,
        colorscale='Blues',
        intensity=values,
        showscale=True,
        name=f'等值面: {threshold:.4f}'
    ))

    # World coordinates of every voxel, spanning [-0.6, 0.6] on each axis.
    x, y, z = np.meshgrid(
        np.linspace(-0.6, 0.6, density_volume.shape[0]),
        np.linspace(-0.6, 0.6, density_volume.shape[1]),
        np.linspace(-0.6, 0.6, density_volume.shape[2]),
        indexing='ij'
    )

    # X, Y and Z slices through the volume.
    fig.add_trace(go.Volume(
        x=x.flatten(),
        y=y.flatten(),
        z=z.flatten(),
        value=density_volume.flatten(),
        opacity=0.2,
        surface_count=1,
        colorscale='Viridis',
        caps=dict(x=dict(show=False), y=dict(show=False), z=dict(show=False)),
        slices=dict(
            x=dict(show=True, locations=[x[slice_pos[0],0,0]]),
            y=dict(show=True, locations=[y[0,slice_pos[1],0]]),
            z=dict(show=True, locations=[z[0,0,slice_pos[2]]])
        ),
        name='密度切片'
    ))

    fig.update_layout(
        title='NeRF密度等值面与切片',
        scene=dict(
            xaxis=dict(range=[-0.6, 0.6], title='X'),
            yaxis=dict(range=[-0.6, 0.6], title='Y'),
            zaxis=dict(range=[-0.6, 0.6], title='Z'),
            aspectmode='cube'
        ),
        width=900,
        height=900,
    )

    if filename and filename.endswith('.html'):
        fig.write_html(filename)
        print(f"交互式HTML文件导出成功: {filename}")

    return fig

def visualize_slice_imagesc(density_volume, slice_direction='z', slice_index=None, 
                           cmap='viridis', figsize=(10, 8), title=None, 
                           save_path=None, scale_factor=1.2, offset=-0.6):
    """
    Show one axis-aligned slice of a 3-D volume with imshow, similar to MATLAB's imagesc.

    Args:
        density_volume: 3-D array indexed as [x, y, z].
        slice_direction: 'x', 'y' or 'z'; any other value is treated as 'z'.
        slice_index: index along the slicing axis; None selects the middle slice of that axis.
        cmap: colormap.
        figsize: figure size.
        title: figure title; None builds one from the slice position.
        save_path: if given, the figure is saved there at 300 dpi.
        scale_factor, offset: voxel index k maps to offset + k / size * scale_factor.

    Returns:
        fig, ax: the matplotlib figure and axes.
    """
    if slice_direction == 'x':
        axis, xlabel, ylabel = 0, 'Y', 'Z'
    elif slice_direction == 'y':
        axis, xlabel, ylabel = 1, 'X', 'Z'
    else:  # z
        axis, xlabel, ylabel = 2, 'X', 'Y'

    if slice_index is None:
        slice_index = density_volume.shape[axis] // 2

    if axis == 0:
        slice_data = density_volume[slice_index, :, :]
    elif axis == 1:
        slice_data = density_volume[:, slice_index, :]
    else:
        slice_data = density_volume[:, :, slice_index]
    # imshow draws array rows vertically. Transpose so the first remaining axis (xlabel) runs
    # horizontally and the second (ylabel) vertically, matching the axis labels.
    slice_data = slice_data.T

    extent = [offset, offset + scale_factor, offset, offset + scale_factor]

    # World coordinate of the slice, normalized by the size of the sliced axis.
    real_coord = offset + (slice_index / density_volume.shape[axis]) * scale_factor

    fig, ax = plt.subplots(figsize=figsize)

    im = ax.imshow(slice_data, 
                   cmap=cmap, 
                   interpolation='nearest', 
                   aspect='equal', 
                   origin='lower', 
                   extent=extent)

    cbar = plt.colorbar(im, ax=ax)
    cbar.set_label('colorbar')

    if title is None:
        title = f'{slice_direction.upper()} slice ({slice_direction}={real_coord:.2f})'
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

    ax.set_xticks(np.linspace(offset, offset + scale_factor, 5))
    ax.set_yticks(np.linspace(offset, offset + scale_factor, 5))

    ax.grid(color='white', linestyle='-', linewidth=0.3, alpha=0.7)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"图像已保存到: {save_path}")

    return fig, ax

def visualize_plotly_isosurface_rotate_x(density_volume, degree1=46.7, degree2=162, threshold_ratio=0.01, filename=None):
    """
    Build a Plotly isosurface of a density volume after stretching and rotating it.

    Pipeline:
      1. Extract the isosurface with marching cubes at
         ``min + (max - min) * threshold_ratio``.
      2. Map voxel indices into roughly [-0.6, 0.6] (every axis is divided
         by ``density_volume.shape[0]``).
      3. Divide the X coordinate by ``cos(degree1)``.
      4. Rotate clockwise about the X axis by ``degree1`` degrees, then
         clockwise about the Z axis by ``degree2`` degrees.

    The figure displays the Y coordinate negated; the exported OBJ stores the
    rotated vertices without that flip.

    Args:
        density_volume: 3D density array.
        degree1: clockwise rotation about X, in degrees (also sets the X stretch).
        degree2: clockwise rotation about Z, in degrees.
        threshold_ratio: iso level as a fraction of the value range (0.0-1.0).
        filename: optional output location. A path ending in ``.html`` receives
            the interactive figure and the OBJ is written next to it; any other
            value is treated as a directory that receives
            ``mesh_export_rotate_x.obj``. ``None`` disables file output.

    Returns:
        The ``go.Figure``.
    """
    # Iso level expressed as a fraction of the volume's value range.
    vmin, vmax = density_volume.min(), density_volume.max()
    threshold = vmin + (vmax - vmin) * threshold_ratio

    print(f"正在提取等值面，阈值: {threshold:.4f}...")
    verts, faces, _, _ = measure.marching_cubes(density_volume, threshold)

    # Voxel indices -> roughly [-0.6, 0.6]; all axes use shape[0] as the grid size.
    verts = verts / density_volume.shape[0] * 1.2 - 0.6

    # Stretch X by 1/cos(degree1) before rotating.
    verts[:, 0] = verts[:, 0] / np.cos(np.radians(degree1))

    # Negative angles make both rotations clockwise.
    ax_rad = -np.radians(degree1)
    rot_x = np.array([
        [1, 0, 0],
        [0, np.cos(ax_rad), -np.sin(ax_rad)],
        [0, np.sin(ax_rad), np.cos(ax_rad)],
    ])
    az_rad = -np.radians(degree2)
    rot_z = np.array([
        [np.cos(az_rad), -np.sin(az_rad), 0],
        [np.sin(az_rad), np.cos(az_rad), 0],
        [0, 0, 1],
    ])

    # Row-vector convention (v' = v @ R.T): rotate about X first, then about Z.
    rotated_verts = verts @ rot_x.T @ rot_z.T

    fig = go.Figure(data=[go.Mesh3d(
        x=rotated_verts[:, 0],
        y=-rotated_verts[:, 1],
        z=rotated_verts[:, 2],
        i=faces[:, 0],
        j=faces[:, 1],
        k=faces[:, 2],
        opacity=0.5,
        color='lightgray',
    )])

    # Hidden axes, transparent background and orthographic camera for a clean render.
    hidden_axis = dict(
        range=[-0.8, 0.8],
        visible=False,
        showgrid=False,
        showline=False,
        showticklabels=False,
    )
    fig.update_layout(
        title=f'NeRF密度等值面 (绕X轴顺时针旋转{degree1}度)',
        scene=dict(
            xaxis=dict(hidden_axis, title='X'),
            yaxis=dict(hidden_axis, title='Y'),
            zaxis=dict(hidden_axis, title='Z'),
            aspectmode='cube',
            bgcolor='rgba(0,0,0,0)',
            camera=dict(
                projection=dict(type='orthographic')
            )
        ),
        width=1500,
        height=1500,
        legend=dict(x=0.7, y=0.1),
        paper_bgcolor='rgba(0,0,0,0)',
        plot_bgcolor='rgba(0,0,0,0)'
    )

    if not filename:
        return fig

    if filename.endswith('.html'):
        fig.write_html(filename)
        print(f"交互式HTML文件导出成功: {filename}")
        # Put the OBJ beside the HTML file instead of treating "x.html" as a directory.
        obj_dir = os.path.dirname(filename)
    else:
        obj_dir = filename
    obj_filename = os.path.join(obj_dir, "mesh_export_rotate_x.obj")

    # UTF-8 explicitly: the header comments contain non-ASCII text.
    with open(obj_filename, 'w', encoding='utf-8') as f:
        f.write(f"# OBJ文件由NeRF模型生成\n")
        f.write(f"# 顶点数: {rotated_verts.shape[0]}\n")
        f.write(f"# 面数: {faces.shape[0]}\n\n")
        for v in rotated_verts:
            f.write(f"v {v[0]} {v[1]} {v[2]}\n")
        # OBJ face indices are 1-based.
        for face in faces:
            f.write(f"f {face[0]+1} {face[1]+1} {face[2]+1}\n")
    print(f"OBJ文件导出成功: {obj_filename}")

    return fig

In [None]:
# visualize_slices(ray_distribution)

# print(ray_distribution.shape)
# visualize_isosurfaces(ray_distribution,threshold_ratio=0.1)

# visualize_multi_isosurfaces(ray_distribution, thresholds=[0.25, 0.5, 0.75])

# fig = visualize_plotly_isosurface(ray_distribution, threshold_ratio=0.01, 
#                                 filename='/DATA/disk1/asteroid/asteroid_inverse/Instant-ngp/model/experiment52')
# fig.show()

# 使用多个阈值创建多等值面可视化
# fig2 = visualize_plotly_multi_isosurfaces(ray_distribution, 
#                                       thresholds=[0.1, 0.2, 0.3],
#                                       filename='/DATA/disk1/Instant-ngp/model/experiment60/multi_isosurface1.html')
# fig2.show()

# # 结合等值面和切片的综合可视化
# fig3 = visualize_plotly_combined(ray_distribution, threshold_ratio=0.3,
#                                filename='/DATA/disk1/Instant-ngp/model/experiment27/combined_view.html')
# fig3.show()

# Image several X-slices of ray_distribution directly (no marching cubes).
# Each title reports the slice index actually shown.
for x_slice in (40, 45, 50, 55, 60, 65):
    visualize_slice_imagesc(ray_distribution, slice_direction='x', slice_index=x_slice,
                            cmap='jet', title=f'Density Distribution X={x_slice} Slice')
    plt.show()

# # 根据LOS_dir计算需要旋转的角度
# def calculate_rotation_angle(LOS_dir):
#     """
#     计算LOS方向向量与X轴之间的旋转角度（顺时针）
    
#     参数:
#         LOS_dir: 3D LOS方向向量 (numpy array)
    
#     返回:
#         旋转角度（度）
#     """
#     # 计算LOS方向向量在XZ平面上的投影
#     xz_projection = np.array([LOS_dir[0],0, LOS_dir[2]])
    
#     # 计算投影与原向量之间的夹角
#     if np.linalg.norm(xz_projection) == 0:
#         return 0, 0  # 如果投影为零向量，返回0角度
#     xz_projection = xz_projection / np.linalg.norm(xz_projection)  # 归一化投影向量
#     LOS_dir = LOS_dir / np.linalg.norm(LOS_dir)  # 归一化LOS方向向量
#     # 计算投影与LOS方向向量的夹角
#     angle_rad1 = np.arccos(np.clip(np.dot(xz_projection, LOS_dir), -1.0, 1.0))
    
#     angle_deg1 = np.degrees(angle_rad1)  # 转换为度数

#     # 计算投影的余弦值
#     angle_rad2 = np.arctan2(xz_projection[2], xz_projection[0])
#     angle_deg2 = np.degrees(angle_rad2)
    
#     return angle_deg1, angle_deg2

# angle1,angle2 = calculate_rotation_angle(LOS_dirs_batch_tensor[0,:].cpu().numpy())
# print(f"计算得到的旋转角度: {angle1:.2f}度")
# print(f"计算得到的旋转角度2: {angle2:.2f}度")

# # 调用函数进行旋转
# fig = visualize_plotly_isosurface_rotate_x(
#     ray_distribution,
#     degree1=angle1,  
#     degree2=angle2,  
#     threshold_ratio=0.2,
#     filename='/DATA/disk1/asteroid/asteroid_inverse/Instant-ngp/model/' + modelname
# )
# fig.show()

In [None]:
import trimesh
from scipy.spatial.distance import directed_hausdorff
from sklearn.metrics import mean_squared_error

def _parse_obj_fallback(obj_path):
    """Parse the 'v' and 'f' records of an OBJ file without third-party libraries."""
    vertices = []
    faces = []
    # Only numeric 'v'/'f' records matter, so undecodable bytes (e.g. in
    # comments) are replaced rather than aborting the load.
    with open(obj_path, 'r', encoding='utf-8', errors='replace') as f:
        for line in f:
            tokens = line.split()
            if not tokens:
                continue
            if tokens[0] == 'v':
                # Keep x, y, z only; some exporters append w or RGB components.
                vertices.append([float(t) for t in tokens[1:4]])
            elif tokens[0] == 'f':
                n_seen = len(vertices)
                corner_ids = []
                for tok in tokens[1:]:
                    # Accept 'v', 'v/vt', 'v//vn', 'v/vt/vn'; only the vertex index is used.
                    k = int(tok.split('/')[0])
                    # OBJ indices are 1-based; negative ones count back from the latest vertex.
                    corner_ids.append(k - 1 if k > 0 else n_seen + k)
                # Fan-triangulate polygons so every face row has exactly 3 indices.
                for b, c in zip(corner_ids[1:-1], corner_ids[2:]):
                    faces.append([corner_ids[0], b, c])
    return (np.asarray(vertices, dtype=float).reshape(-1, 3),
            np.asarray(faces, dtype=np.int64).reshape(-1, 3))


def load_obj_model(obj_path):
    """
    Load an OBJ model and return its vertices and faces.

    trimesh is tried first; if it raises for any reason (not installed, parse
    error, result without ``vertices``), a minimal built-in parser is used.

    Args:
        obj_path: path to the OBJ file.

    Returns:
        vertices: (V, 3) array of vertex coordinates.
        faces: (F, 3) integer array of 0-based vertex indices. In the fallback
            parser, polygons with more than three corners are fan-triangulated.
    """
    try:
        mesh = trimesh.load_mesh(obj_path)
        return mesh.vertices, mesh.faces
    except Exception:
        # Deliberate best-effort: fall back to the simple parser below.
        pass
    return _parse_obj_fallback(obj_path)

def _normalize_mesh_vertices(verts):
    """Return a copy of verts scaled so the largest bounding-box side is 1 and centered on the vertex mean."""
    verts = np.asarray(verts, dtype=float)
    scale = np.max(np.max(verts, axis=0) - np.min(verts, axis=0))
    # Degenerate input (all vertices identical): skip scaling instead of dividing by zero.
    if scale > 0:
        verts = verts / scale
    return verts - np.mean(verts, axis=0)


def _subsample_vertices(verts, count):
    """Return at most `count` vertices, chosen uniformly at random without replacement."""
    if len(verts) > count:
        return verts[np.random.choice(len(verts), count, replace=False)]
    return verts


def _min_max_ratio(a, b):
    """Return min(a, b) / max(a, b), or None when max(a, b) is 0."""
    largest = max(a, b)
    return min(a, b) / largest if largest else None


def compare_meshes(mesh1_verts, mesh1_faces, mesh2_verts, mesh2_faces, angle1=0, angle2=0,
                   sample_points=10000, normalize=True):
    """
    Compare two triangle meshes and return a dict of similarity metrics.

    Args:
        mesh1_verts, mesh1_faces: vertices and faces of the first mesh.
        mesh2_verts, mesh2_faces: vertices and faces of the second mesh.
        angle1, angle2: if either is non-zero, mesh 1 is transformed with
            ``rotate_mesh(degree1=angle1, degree2=angle2)`` after normalization.
        sample_points: number of points sampled per mesh for the distance metrics.
        normalize: scale each mesh so its largest bounding-box side is 1 and
            center it on its vertex mean.

    Returns:
        dict with keys ``hausdorff_distance``, ``hausdorff_1to2``,
        ``hausdorff_2to1``, ``mean_distance_1to2/2to1/symmetric``,
        ``rms_distance_1to2/2to1/symmetric``, ``volume_ratio``,
        ``volume_difference``, ``area_ratio``, ``area_difference``,
        ``vertex_count_ratio``, ``face_count_ratio``. Values that cannot be
        computed are None.

    NOTE(review): visualize_mesh_comparison draws the meshes in a different
    alignment (degree2 = angle2 + 90 and a permuted/flipped frame for mesh 2),
    so these metrics are not measured in the frame shown there — confirm
    which alignment is intended.
    """
    if normalize:
        pts1 = _normalize_mesh_vertices(mesh1_verts)
        pts2 = _normalize_mesh_vertices(mesh2_verts)
    else:
        pts1 = np.asarray(mesh1_verts)
        pts2 = np.asarray(mesh2_verts)

    if angle1 != 0 or angle2 != 0:
        # Pass a float copy so rotate_mesh can never modify the caller's array.
        pts1 = rotate_mesh(np.array(pts1, dtype=float), degree1=angle1, degree2=angle2)

    # Surface sampling via trimesh when possible; otherwise random vertex subsets.
    mesh1_trimesh = mesh2_trimesh = None
    try:
        mesh1_trimesh = trimesh.Trimesh(vertices=pts1, faces=mesh1_faces)
        mesh2_trimesh = trimesh.Trimesh(vertices=pts2, faces=mesh2_faces)
        samples1, _ = trimesh.sample.sample_surface_even(mesh1_trimesh, sample_points)
        samples2, _ = trimesh.sample.sample_surface_even(mesh2_trimesh, sample_points)
    except Exception:
        print("Trimesh不可用，使用简单顶点采样")
        samples1 = _subsample_vertices(pts1, sample_points)
        samples2 = _subsample_vertices(pts2, sample_points)

    metrics = {}

    # 1. Hausdorff distance (symmetric = max of both directions).
    try:
        h12 = directed_hausdorff(samples1, samples2)[0]
        h21 = directed_hausdorff(samples2, samples1)[0]
        metrics['hausdorff_distance'] = max(h12, h21)
        metrics['hausdorff_1to2'] = h12
        metrics['hausdorff_2to1'] = h21
    except Exception as e:
        print(f"计算Hausdorff距离时出错: {e}")
        metrics['hausdorff_distance'] = None
        metrics['hausdorff_1to2'] = None
        metrics['hausdorff_2to1'] = None

    # 2./3. Nearest-neighbour distances between the two sample sets.
    from scipy.spatial import cKDTree

    dist12, _ = cKDTree(samples2).query(samples1)
    dist21, _ = cKDTree(samples1).query(samples2)

    metrics['mean_distance_1to2'] = np.mean(dist12)
    metrics['mean_distance_2to1'] = np.mean(dist21)
    metrics['mean_distance_symmetric'] = (metrics['mean_distance_1to2'] + metrics['mean_distance_2to1']) / 2

    metrics['rms_distance_1to2'] = np.sqrt(np.mean(dist12 ** 2))
    metrics['rms_distance_2to1'] = np.sqrt(np.mean(dist21 ** 2))
    metrics['rms_distance_symmetric'] = (metrics['rms_distance_1to2'] + metrics['rms_distance_2to1']) / 2

    # 4./5. Volume and surface area (only when both trimesh objects exist).
    metrics['volume_ratio'] = None
    metrics['volume_difference'] = None
    metrics['area_ratio'] = None
    metrics['area_difference'] = None
    if mesh1_trimesh is not None and mesh2_trimesh is not None:
        try:
            # Signed volume can be negative for inverted face winding; compare magnitudes.
            vol1 = abs(mesh1_trimesh.volume)
            vol2 = abs(mesh2_trimesh.volume)
            metrics['volume_ratio'] = _min_max_ratio(vol1, vol2)
            metrics['volume_difference'] = abs(vol1 - vol2)
        except Exception:
            pass  # best-effort metric; leave None
        try:
            area1 = mesh1_trimesh.area
            area2 = mesh2_trimesh.area
            metrics['area_ratio'] = _min_max_ratio(area1, area2)
            metrics['area_difference'] = abs(area1 - area2)
        except Exception:
            pass  # best-effort metric; leave None

    # 6. Element counts of the input meshes.
    metrics['vertex_count_ratio'] = _min_max_ratio(len(mesh1_verts), len(mesh2_verts))
    metrics['face_count_ratio'] = _min_max_ratio(len(mesh1_faces), len(mesh2_faces))

    return metrics

# 设计一个旋转函数
def rotate_mesh(verts, degree1=46.7, degree2=162):
    """
    Stretch and rotate mesh vertices, returning a new array.

    Steps (row-vector convention, v' = v @ R.T):
      1. divide the X coordinate by cos(degree1);
      2. rotate clockwise about the X axis by ``degree1`` degrees;
      3. rotate clockwise about the Z axis by ``degree2`` degrees.

    Args:
        verts: (N, 3) array-like of vertex coordinates. It is not modified.
        degree1: X-axis rotation in degrees (also sets the X stretch factor).
        degree2: Z-axis rotation in degrees.

    Returns:
        (N, 3) float array of transformed vertices.
    """
    # Float copy: the caller's array is left untouched and integer input is not truncated.
    pts = np.array(verts, dtype=float)
    pts[:, 0] = pts[:, 0] / np.cos(np.radians(degree1))

    # Negative angles make both rotations clockwise.
    ax_rad = -np.radians(degree1)
    rot_x = np.array([
        [1, 0, 0],
        [0, np.cos(ax_rad), -np.sin(ax_rad)],
        [0, np.sin(ax_rad), np.cos(ax_rad)],
    ])

    az_rad = -np.radians(degree2)
    rot_z = np.array([
        [np.cos(az_rad), -np.sin(az_rad), 0],
        [np.sin(az_rad), np.cos(az_rad), 0],
        [0, 0, 1],
    ])

    # Rotate about X first, then about Z.
    return pts @ rot_x.T @ rot_z.T

def visualize_mesh_comparison(mesh1_verts, mesh1_faces, mesh2_verts, mesh2_faces, angle1=46.7, angle2=162,
                             title1="模型1", title2="模型2",normalize=True):
    """
    Overlay two meshes in one Plotly figure (mesh 1 red, mesh 2 blue).

    Mesh 1 is not rescaled; it is transformed with
    ``rotate_mesh(degree1=angle1, degree2=angle2 + 90)``. When ``normalize``
    is True, mesh 2 is scaled so its largest bounding-box side is 1, centered
    on its vertex mean, and its Y coordinate negated. Mesh 2 is then drawn
    with its axes remapped as (x, y, z) = (-z, x, -y).

    Args:
        mesh1_verts, mesh1_faces: first mesh (not modified).
        mesh2_verts, mesh2_faces: second mesh (not modified).
        angle1, angle2: rotation angles in degrees for mesh 1.
        title1, title2: trace names.
        normalize: whether to normalize/mirror mesh 2 as described above.

    Returns:
        The ``go.Figure``.
    """
    # Pass a float copy so rotate_mesh cannot modify the caller's array.
    mesh1_plot = rotate_mesh(np.array(mesh1_verts, dtype=float), degree1=angle1, degree2=angle2 + 90)

    if normalize:
        extent2 = np.max(mesh2_verts, axis=0) - np.min(mesh2_verts, axis=0)
        mesh2_plot = mesh2_verts / np.max(extent2)
        mesh2_plot = mesh2_plot - np.mean(mesh2_plot, axis=0)
        # Mirror along Y (negate the Y coordinate).
        mesh2_plot[:, 1] = -mesh2_plot[:, 1]
    else:
        mesh2_plot = mesh2_verts

    fig = go.Figure()

    fig.add_trace(go.Mesh3d(
        x=mesh1_plot[:, 0],
        y=mesh1_plot[:, 1],
        z=mesh1_plot[:, 2],
        i=mesh1_faces[:, 0],
        j=mesh1_faces[:, 1],
        k=mesh1_faces[:, 2],
        opacity=0.5,
        color='red',
        name=title1
    ))

    # Reference model drawn with remapped axes: (x, y, z) = (-z, x, -y).
    fig.add_trace(go.Mesh3d(
        x=-mesh2_plot[:, 2],
        y=mesh2_plot[:, 0],
        z=-mesh2_plot[:, 1],
        i=mesh2_faces[:, 0],
        j=mesh2_faces[:, 1],
        k=mesh2_faces[:, 2],
        opacity=0.5,
        color='blue',
        name=title2
    ))

    # Fixed coordinate ranges so both meshes share one cubic view.
    fig.update_layout(
        title="模型比较",
        width=900,
        height=900,
        scene=dict(
            xaxis=dict(range=[-0.8, 0.8], title='X'),
            yaxis=dict(range=[-0.8, 0.8], title='Y'),
            zaxis=dict(range=[-0.8, 0.8], title='Z'),
            aspectmode='cube'
        )
    )
    return fig

# 使用示例
def compare_with_reference_model(ray_distribution, threshold_ratio=0.3, 
                                reference_obj_path=None,angle1=46.7, angle2=162):
    """
    Extract an isosurface from a NeRF density grid and compare it with a reference OBJ.

    Args:
        ray_distribution: 3D density array produced by the NeRF.
        threshold_ratio: iso level as a fraction of the density range [min, max].
        reference_obj_path: path of the reference OBJ model.
        angle1, angle2: rotation angles (degrees) forwarded to both
            compare_meshes and visualize_mesh_comparison.

    Returns:
        (metrics dict, plotly figure), or (None, None) when the reference path
        is empty or does not exist.
    """
    lo, hi = ray_distribution.min(), ray_distribution.max()
    iso_level = lo + (hi - lo) * threshold_ratio

    nerf_verts, nerf_faces, _, _ = measure.marching_cubes(ray_distribution, iso_level)
    # Voxel indices -> roughly [-0.6, 0.6], scaled by the first axis length.
    nerf_verts = nerf_verts / ray_distribution.shape[0] * 1.2 - 0.6

    have_reference = bool(reference_obj_path) and os.path.exists(reference_obj_path)
    if not have_reference:
        print(f"参考模型文件不存在: {reference_obj_path}")
        return None, None

    ref_verts, ref_faces = load_obj_model(reference_obj_path)

    metrics = compare_meshes(nerf_verts, nerf_faces,
                             ref_verts, ref_faces,
                             angle1=angle1, angle2=angle2)
    fig = visualize_mesh_comparison(nerf_verts, nerf_faces,
                                    ref_verts, ref_faces,
                                    angle1=angle1, angle2=angle2,
                                    title1="NeRF模型", title2="参考模型")
    return metrics, fig

# Compare the NeRF reconstruction with the reference model.
# NOTE(review): angle1/angle2 must already be defined by an earlier cell; the
# code that computed them in this section is commented out.
reference_model_path = "/DATA/disk1/asteroid/asteroid_inverse/Instant-ngp/new_dataset/model_real/Geographos Radar-based, low-res.obj"
metrics, comparison_fig = compare_with_reference_model(
    ray_distribution, 
    threshold_ratio=0.35,
    reference_obj_path=reference_model_path,
    angle1=angle1,  
    angle2=angle2,  
)

if metrics:
    print("模型比较结果:")
    # Metrics that could not be computed are None (e.g. volume without trimesh);
    # print "N/A" for them instead of crashing on the float format.
    for metric_label, metric_key in (
        ("Hausdorff距离", 'hausdorff_distance'),
        ("平均对称距离", 'mean_distance_symmetric'),
        ("RMS对称距离", 'rms_distance_symmetric'),
        ("体积比", 'volume_ratio'),
        ("表面积比", 'area_ratio'),
        ("顶点数比", 'vertex_count_ratio'),
    ):
        metric_value = metrics.get(metric_key)
        if metric_value is None:
            print(f"{metric_label}: N/A")
        else:
            print(f"{metric_label}: {metric_value:.6f}")

    comparison_fig.show()

In [None]:

# Render the Arrokoth reference model on its own with remapped axes.
fig = go.Figure()

mesh_path = '/DATA/disk1/asteroid/asteroid_inverse/Instant-ngp/new_dataset/model_real/Arrokoth Stern 2019.obj'
mesh1_verts, mesh1_faces = load_obj_model(mesh_path)

# Divide by the largest absolute coordinate so every coordinate lies in [-1, 1].
mesh1_verts = mesh1_verts / np.max(np.abs(mesh1_verts))
print(mesh1_verts.shape, mesh1_faces.shape)

# Reorder the coordinates as (z, x, y) ...
mesh2_verts = mesh1_verts[:, [2, 0, 1]]

# ... then rotate 90 degrees about Z (row-vector convention: v @ R.T).
quarter_turn_z = np.array([
    [0, -1, 0],
    [1, 0, 0],
    [0, 0, 1]
])
mesh2_verts = np.dot(mesh2_verts, quarter_turn_z.T)

fig.add_trace(go.Mesh3d(
    x=mesh2_verts[:, 0],
    y=mesh2_verts[:, 1],
    z=mesh2_verts[:, 2],
    i=mesh1_faces[:, 0],
    j=mesh1_faces[:, 1],
    k=mesh1_faces[:, 2],
    color='lightgray',
))

# Hidden axes, transparent background and orthographic projection; the figure
# is the cell's last expression so the notebook displays it.
invisible_axis = dict(
    range=[-1, 1],
    visible=False,
    showgrid=False,
    showline=False,
    showticklabels=False,
)
fig.update_layout(
    title="模型比较",
    scene=dict(
        xaxis=dict(invisible_axis, title='X'),
        yaxis=dict(invisible_axis, title='Y'),
        zaxis=dict(invisible_axis, title='Z'),
        aspectmode='cube',
        bgcolor='rgba(0,0,0,0)',
        camera=dict(
            projection=dict(type='orthographic')
        )
    ),
    width=1500,
    height=1500,
    legend=dict(x=0.7, y=0.1),
    paper_bgcolor='rgba(0,0,0,0)',
    plot_bgcolor='rgba(0,0,0,0)'
)