In [74]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [75]:
import os
import sys
import torch
import hydra
from pathlib import Path
from lightning import Fabric

# torch.compile compatibility: turn on TorchDynamo's `suppress_errors` flag when
# the module is available.
try:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
except ImportError:
    # torch._dynamo could not be imported (PyTorch < 2.0 does not provide it).
    print("Warning: torch._dynamo not available in this PyTorch version")

## Set up environment
# Put the parent of the current working directory at the front of sys.path,
# presumably so the `funcmol` imports that follow resolve. NOTE(review): this
# depends on the directory the kernel was started from.
project_root = Path(os.getcwd()).parent
sys.path.insert(0, str(project_root))

from funcmol.utils.constants import PADDING_INDEX
from funcmol.utils.gnf_visualizer import (
    load_config_from_exp_dir, load_model, 
    create_converter, GNFVisualizer, create_gnf_converter
)
from funcmol.dataset.dataset_field import create_field_loaders
from torch_geometric.utils import to_dense_batch

# Root directory of the models; experiment directories are resolved as
# os.path.join(model_root, exp_name).
# NOTE(review): machine-specific absolute path - adjust when running elsewhere.
model_root = "/datapool/data2/home/pxg/data/hyc/funcmol-main-neuralfield/exps/neural_field"

In [76]:
# Experiment parameters
exp_name = 'nf_qm9_20250804_153549_358664'
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# The experiment directory is the same in both modes; only the mode label and
# whether the directory gets created differ.
model_dir = os.path.join(model_root, exp_name)

# Mode selection: a name containing '2025' is treated as a trained experiment
# (gt + predicted field); anything else is treated as the name of a field
# (e.g. gaussian_mag) and runs in gt-only mode.
# NOTE(review): the '2025' substring test is a heuristic tied to the run year.
if '2025' not in exp_name:
    option = 'gt_only'
    os.makedirs(model_dir, exist_ok=True)
else:
    option = 'gt_pred'

output_dir = model_dir
for line in (
    f"Option: {option}",
    f"Model directory: {model_dir}",
    f"Output directory: {output_dir}",
):
    print(line)

Option: gt_pred
Model directory: /datapool/data2/home/pxg/data/hyc/funcmol-main-neuralfield/exps/neural_field/nf_qm9_20250804_153549_358664
Output directory: /datapool/data2/home/pxg/data/hyc/funcmol-main-neuralfield/exps/neural_field/nf_qm9_20250804_153549_358664


In [None]:
## Load data
# Lightning Fabric on a single device with "32-true" precision; accelerator and
# strategy are left on "auto".
fabric = Fabric(
    accelerator="auto",
    devices=1,
    precision="32-true",
    strategy="auto"
)
fabric.launch()

# Load the configuration from the experiment directory.
config = load_config_from_exp_dir(model_dir)

# Prepare data by reading samples straight from the dataset so that the
# selection is representative of the whole split.
def prepare_data_directly(config, device, max_samples=1000, max_atoms=18, seed=42):
    """Load a random subset of the validation split without a DataLoader.

    Args:
        config: Experiment configuration. The ``config["dset"]`` entries
            ``dset_name``, ``data_dir``, ``elements``, ``n_points``,
            ``resolution``, ``grid_dim`` and ``atomic_radius`` are read.
        device: Accepted for call compatibility but not used here; the GNF
            converter handed to the dataset is always created on the CPU.
        max_samples: Maximum number of samples to load. When it is at least
            the dataset size, every sample is used in dataset order.
        max_atoms: Fixed per-molecule length of the returned tensors. Shorter
            molecules are padded (zeros for coordinates, ``PADDING_INDEX`` for
            types); longer ones are truncated and a warning is printed.
        seed: Value passed to ``random.seed`` before drawing sample indices.
            Note that this reseeds the global ``random`` module.

    Returns:
        Tuple ``(sample_batch, gt_coords, gt_types)`` where ``sample_batch`` is
        ``dataset[0]`` (kept for compatibility with batch-based code),
        ``gt_coords`` has shape ``[n_samples, max_atoms, 3]`` and ``gt_types``
        has shape ``[n_samples, max_atoms]``.
    """
    import random

    from funcmol.dataset.dataset_field import FieldDataset

    # GNF converter used by the dataset (CPU only).
    gnf_converter = create_gnf_converter(config, device="cpu")

    # Instantiate the dataset directly instead of going through a DataLoader.
    dataset = FieldDataset(
        gnf_converter=gnf_converter,
        dset_name=config["dset"]["dset_name"],
        data_dir=config["dset"]["data_dir"],
        elements=config["dset"]["elements"],
        split="val",
        n_points=config["dset"]["n_points"],
        rotate=False,  # no rotation for validation data
        resolution=config["dset"]["resolution"],
        grid_dim=config["dset"]["grid_dim"],
        radius=config["dset"]["atomic_radius"],
        sample_full_grid=False,
        debug_one_mol=False,
        debug_subset=False,
    )

    print(f"数据集总大小: {len(dataset)}")

    # Draw the requested number of sample indices.
    random.seed(seed)
    if max_samples >= len(dataset):
        sample_indices = list(range(len(dataset)))
        print(f"请求样本数({max_samples}) >= 数据集大小({len(dataset)})，使用所有样本")
    else:
        sample_indices = random.sample(range(len(dataset)), max_samples)
        print(f"从 {len(dataset)} 个样本中随机选择 {max_samples} 个")

    all_coords = []
    all_types = []
    n_truncated = 0

    print("开始加载样本...")
    for i, idx in enumerate(sample_indices):
        if i % 100 == 0:
            print(f"已加载 {i}/{len(sample_indices)} 个样本...")

        sample = dataset[idx]

        # Coordinates and atom types, with padding entries removed.
        coords = sample.pos
        atoms_channel = sample.x
        valid_mask = atoms_channel != PADDING_INDEX
        coords = coords[valid_mask]
        atoms_channel = atoms_channel[valid_mask]

        # Bring every molecule to exactly `max_atoms` entries.
        n_valid = len(coords)
        if n_valid < max_atoms:
            n_pad = max_atoms - n_valid
            # new_zeros / new_full inherit dtype and device from the sample.
            pad_coords = coords.new_zeros((n_pad, 3))
            pad_atoms = atoms_channel.new_full((n_pad,), PADDING_INDEX)
            coords = torch.cat([coords, pad_coords], dim=0)
            atoms_channel = torch.cat([atoms_channel, pad_atoms], dim=0)
        elif n_valid > max_atoms:
            # Truncation drops atoms; count it so it can be reported below.
            n_truncated += 1
            coords = coords[:max_atoms]
            atoms_channel = atoms_channel[:max_atoms]

        all_coords.append(coords)
        all_types.append(atoms_channel)

    if n_truncated:
        print(f"Warning: {n_truncated} sample(s) had more than {max_atoms} atoms and were truncated")

    # Stack the per-sample tensors along a new leading dimension.
    gt_coords = torch.stack(all_coords, dim=0)  # [n_samples, max_atoms, 3]
    gt_types = torch.stack(all_types, dim=0)    # [n_samples, max_atoms]

    print(f"成功加载 {len(gt_coords)} 个样本")
    print(f"数据形状: coords={gt_coords.shape}, types={gt_types.shape}")

    # Example sample returned for compatibility with batch-based code.
    sample_batch = dataset[0]

    return sample_batch, gt_coords, gt_types

# Load the data.
# The loader defined in this cell is `prepare_data_directly`; the previously
# used name `prepare_data_directly_from_dataset` is not defined or imported in
# this notebook, so the call raised NameError on a fresh kernel.
batch, gt_coords, gt_types = prepare_data_directly(config, device, max_samples=1000)
print(f"Data loaded: {gt_coords.shape}, {gt_types.shape}")

Dataset directory: /datapool/data2/home/pxg/data/hyc/funcmol-main-neuralfield/funcmol/dataset/data
数据集总大小: 20042
从 20042 个样本中随机选择 1000 个
开始加载样本...
已加载 0/1000 个样本...
已加载 100/1000 个样本...
已加载 200/1000 个样本...
已加载 300/1000 个样本...
已加载 400/1000 个样本...
已加载 500/1000 个样本...
已加载 600/1000 个样本...
已加载 700/1000 个样本...
已加载 800/1000 个样本...
已加载 900/1000 个样本...
成功加载 1000 个样本
数据形状: coords=torch.Size([1000, 18, 3]), types=torch.Size([1000, 18])
Data loaded: torch.Size([1000, 18, 3]), torch.Size([1000, 18])


In [79]:
print(f"\nProcessing model from: {model_dir}")

## GT field comparison only: encoder, decoder and codes are not needed here.
encoder = None
decoder = None
codes = None
print("GT field comparison mode - no encoder/decoder needed")

# Converter used for the ground-truth field computation.
converter = create_converter(config, device)
print("GNF Converter created successfully!")


Processing model from: /datapool/data2/home/pxg/data/hyc/funcmol-main-neuralfield/exps/neural_field/nf_qm9_20250804_153549_358664
GT field comparison mode - no encoder/decoder needed
GNF Converter created with n_iter: 2000, gradient_field_method: tanh, n_atom_types: 5
GNF Converter created successfully!


In [80]:
# 添加RMSD计算函数（双向匹配）
import numpy as np
from scipy.optimize import linear_sum_assignment

def compute_rmsd_hungarian(coords1, coords2):
    """Symmetric RMSD between two point sets using Hungarian matching.

    The smaller set is matched one-to-one against the larger set with
    ``linear_sum_assignment`` on the Euclidean distance matrix. Two terms are
    then averaged:

    * the RMSD over the matched pairs, and
    * the RMSD of the distances from every *unmatched* point of the larger set
      to its nearest point in the smaller set (0.0 when nothing is unmatched).

    NOTE(review): when both sets have the same size the second term is 0, so
    the returned value is half of the matched RMSD. This is kept unchanged for
    compatibility with previously reported numbers.

    Args:
        coords1: First coordinate set, ``[n_atoms1, 3]`` (tensor or array-like).
        coords2: Second coordinate set, ``[n_atoms2, 3]`` (tensor or array-like).

    Returns:
        The symmetric RMSD, or ``float('inf')`` if either set is empty.
    """
    coords1 = coords1.detach().cpu().numpy() if torch.is_tensor(coords1) else np.asarray(coords1)
    coords2 = coords2.detach().cpu().numpy() if torch.is_tensor(coords2) else np.asarray(coords2)

    n1, n2 = len(coords1), len(coords2)
    if n1 == 0 or n2 == 0:
        return float('inf')

    # Orient the problem so the smaller set indexes the rows; every row then
    # receives exactly one column from the assignment.
    if n1 <= n2:
        small, large = coords1, coords2
    else:
        small, large = coords2, coords1

    # Pairwise Euclidean distances, shape [len(small), len(large)].
    dist = np.sqrt(np.sum((small[:, np.newaxis, :] - large[np.newaxis, :, :]) ** 2, axis=2))

    row_idx, col_idx = linear_sum_assignment(dist)
    matched_rmsd = np.sqrt(np.mean(dist[row_idx, col_idx] ** 2))

    # Points of the larger set that were not assigned. These must be looked up
    # through the *column* indices: the original n1 > n2 branch used the row
    # indices (which index the smaller set) and therefore picked wrong points.
    unmatched = np.setdiff1d(np.arange(len(large)), col_idx)
    if unmatched.size:
        nearest = dist[:, unmatched].min(axis=0)
        unmatched_rmsd = np.sqrt(np.mean(nearest ** 2))
    else:
        unmatched_rmsd = 0.0

    return (matched_rmsd + unmatched_rmsd) / 2

def reconstruction_loss(coords1, points):
    """Reconstruction loss between original coordinates and sampled points.

    Two nearest-neighbour terms on squared distances are combined:

    * coverage: mean over original atoms of the squared distance to the
      closest sampled point;
    * clustering: mean over sampled points of the squared distance to the
      closest original atom, weighted by 0.1.

    Args:
        coords1: Original atom coordinates, tensor ``[n_atoms, 3]``.
        points: Sampled / reconstructed points, tensor ``[n_points, 3]``.

    Returns:
        A scalar tensor ``sqrt(coverage + 0.1 * clustering)``. If either input
        has no rows, a scalar ``inf`` tensor is returned (the same convention
        as ``compute_rmsd_hungarian``) instead of letting ``torch.min`` raise
        on an empty dimension.
    """
    if coords1.shape[0] == 0 or points.shape[0] == 0:
        return torch.tensor(float('inf'), device=coords1.device)

    # Keep both operands on one device; this is a no-op when they already match.
    points = points.to(coords1.device)

    # Pairwise squared distances, shape [n_atoms, n_points].
    dist1 = torch.sum((coords1.unsqueeze(1) - points.unsqueeze(0)) ** 2, dim=2)

    eps = 1e-8

    # For each original atom, the closest sampled point.
    min_dist_to_samples = torch.min(dist1 + eps, dim=1)[0]

    # For each sampled point, the closest original atom.
    min_dist_to_atoms = torch.min(dist1 + eps, dim=0)[0]

    coverage_loss = torch.mean(min_dist_to_samples)  # every atom should have a nearby sample
    clustering_loss = torch.mean(min_dist_to_atoms)  # samples should concentrate near atoms

    # Weighted sum of both directions; the 0.1 weight is tunable.
    total_loss = coverage_loss + 0.1 * clustering_loss

    return torch.sqrt(total_loss)

print("RMSD计算函数已定义")

RMSD计算函数已定义


In [81]:
# Factory for the ground-truth field function
def create_gt_field_func(converter, gt_coords, gt_types, sample_idx):
    """Return a closure that evaluates the ground-truth field of one sample.

    The closure takes query ``points`` with 2 dimensions (a batch dimension is
    added) or 3 dimensions (used as-is), calls ``converter.mol2gnf`` with the
    non-padding atoms of sample ``sample_idx``, and returns the first batch
    entry when the result has 4 dimensions, otherwise the result unchanged.
    """
    def gt_field_func(points):
        # Select the non-padding atoms of the requested sample.
        sample_types = gt_types[sample_idx]
        keep = sample_types != PADDING_INDEX
        valid_coords = gt_coords[sample_idx][keep]
        valid_types = sample_types[keep]

        # Only 2-D or 3-D point tensors are accepted.
        n_dims = points.dim()
        if n_dims not in (2, 3):
            raise ValueError(f"Unexpected points shape: {points.shape}")
        if n_dims == 2:
            points = points.unsqueeze(0)

        field = converter.mol2gnf(
            valid_coords.unsqueeze(0),
            valid_types.unsqueeze(0),
            points,
        )
        # A 4-D result carries a leading batch dimension; drop it.
        return field[0] if field.dim() == 4 else field

    return gt_field_func

print("GT Field函数定义完成")

GT Field函数定义完成


In [82]:
# Sampling configuration for the batch analysis (TOTAL_SAMPLES samples)
import random
import numpy as np

# Batch-analysis parameters
TOTAL_SAMPLES = 100  # total number of samples to analyse
BATCH_SIZE = 20      # number of samples handled per batch
RANDOM_SEED = 42     # fixed seed so the selection is reproducible

# Seed both RNGs used in this notebook.
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

# Number of samples actually loaded.
actual_dataset_size = len(gt_coords)
print(f"实际数据集大小: {actual_dataset_size}")

# Never request more samples than are available.
if actual_dataset_size < TOTAL_SAMPLES:
    print(f"警告: 请求的样本数({TOTAL_SAMPLES})大于数据集大小({actual_dataset_size})")
    print(f"将使用所有可用样本: {actual_dataset_size}")
    TOTAL_SAMPLES = actual_dataset_size

# Random sampling without replacement, sorted so progress is easy to follow.
# Alternatives that were considered: taking the first TOTAL_SAMPLES indices,
# or stratified sampling over ten equal index intervals.
sample_indices = sorted(random.sample(range(actual_dataset_size), TOTAL_SAMPLES))

field_methods = ['gaussian_mag', 'tanh']  # field methods to compare
# NOTE(review): this overrides the output_dir chosen earlier with a
# machine-specific absolute path.
output_dir = "/datapool/data2/home/pxg/data/hyc/funcmol-main-neuralfield/funcmol/field_set"

print(f"将分析 {len(sample_indices)} 个样本")
print(f"样本索引范围: {min(sample_indices)} - {max(sample_indices)}")
print(f"前10个样本: {sample_indices[:10]}")
print(f"后10个样本: {sample_indices[-10:]}")
print(f"数据集大小: {actual_dataset_size}")
print(f"请求样本数: {TOTAL_SAMPLES}")

实际数据集大小: 1000
将分析 100 个样本
样本索引范围: 6 - 990
前10个样本: [6, 25, 27, 30, 32, 44, 46, 71, 80, 81]
后10个样本: [890, 906, 913, 947, 954, 964, 978, 983, 986, 990]
数据集大小: 1000
请求样本数: 100


In [None]:
# GT Field比较主逻辑
import time
import pandas as pd
import json
from tqdm import tqdm

# Batch-processing helper
def process_batch(sample_indices_batch, field_methods, converters, gt_coords, gt_types, visualizer, output_dir):
    """Reconstruct every sample of one batch with every field method.

    For each (sample, method) pair the ground-truth field function is built,
    ``visualizer.create_reconstruction_animation`` is run and timed, and the
    Hungarian RMSD and reconstruction loss against the non-padding
    ground-truth atoms are recorded. ``output_dir`` is accepted but not used
    inside this function.

    Returns:
        A dict with keys ``field_methods``, ``sample_indices``,
        ``comparison_results`` (per-method metric lists plus per-sample
        details) and ``summary_statistics`` (left empty here).
    """
    # One empty result record per method.
    per_method = {
        method: {
            "rmsd_values": [],
            "reconstruction_losses": [],
            "rmsd_hungarian_values": [],
            "reconstruction_times": [],
            "sample_details": {},
        }
        for method in field_methods
    }
    batch_results = {
        "field_methods": field_methods,
        "sample_indices": sample_indices_batch,
        "comparison_results": per_method,
        "summary_statistics": {},
    }

    for sample_idx in tqdm(sample_indices_batch, desc="处理批次", leave=False):
        # Ground-truth molecule without padding entries.
        keep = gt_types[sample_idx] != PADDING_INDEX
        gt_valid_coords = gt_coords[sample_idx][keep]
        gt_valid_types = gt_types[sample_idx][keep]

        for method in field_methods:
            converter = converters[method]
            field_func = create_gt_field_func(converter, gt_coords, gt_types, sample_idx)

            # Timed reconstruction through the visualizer.
            started = time.time()
            outcome = visualizer.create_reconstruction_animation(
                gt_coords=gt_coords,
                gt_types=gt_types,
                converter=converter,
                field_func=field_func,
                save_interval=100,
                animation_name=f"gt_field_sample_{sample_idx}_{method}",
                sample_idx=sample_idx,
            )
            reconstruction_time = time.time() - started

            recon_coords = outcome['final_points']
            recon_types = outcome['final_types']

            # Metrics against the ground truth.
            rmsd_hungarian = compute_rmsd_hungarian(gt_valid_coords, recon_coords)
            loss_value = reconstruction_loss(gt_valid_coords, recon_coords).item()
            final_rmsd = outcome['final_rmsd']

            record = per_method[method]
            record["rmsd_values"].append(final_rmsd)
            record["reconstruction_losses"].append(loss_value)
            record["rmsd_hungarian_values"].append(rmsd_hungarian)
            record["reconstruction_times"].append(reconstruction_time)

            # Per-sample details.
            if len(recon_types) > 0:
                recon_type_list = recon_types.cpu().numpy().tolist()
            else:
                recon_type_list = []
            record["sample_details"][sample_idx] = {
                "rmsd": final_rmsd,
                "rmsd_hungarian": rmsd_hungarian,
                "reconstruction_loss": loss_value,
                "reconstruction_time": reconstruction_time,
                "gt_atoms": len(gt_valid_coords),
                "recon_atoms": len(recon_coords),
                "gt_types": gt_valid_types.cpu().numpy().tolist(),
                "recon_types": recon_type_list,
                "gif_path": outcome['gif_path'],
                "comparison_path": outcome['comparison_path'],
            }

    return batch_results

# Make sure the output directory and its "recon" sub-directory exist.
os.makedirs(output_dir, exist_ok=True)
os.makedirs(os.path.join(output_dir, "recon"), exist_ok=True)

# Visualizer bound to the output directory.
visualizer = GNFVisualizer(output_dir)
print(f"输出目录: {output_dir}")

# Container for the aggregated results of all samples.
gt_results = dict(
    field_methods=field_methods,
    sample_indices=sample_indices,
    comparison_results={},
    summary_statistics={},
)

# One converter per field method. The config is reloaded for every method so
# that overriding gradient_field_method never touches a shared config object.
converters = {}
for method in field_methods:
    print(f"\n创建 {method} gt field转换器...")
    method_config = load_config_from_exp_dir(model_dir)
    method_config["converter"]["gradient_field_method"] = method

    method_converter = create_converter(method_config, device)
    converters[method] = method_converter
    print(f"{method} gt field转换器创建完成")

print("\n开始gt field比较...")

# 选择处理方式：分批处理或一次性处理
USE_BATCH_PROCESSING = True  # 设置为True启用分批处理，False则一次性处理所有样本

if USE_BATCH_PROCESSING and len(sample_indices) > BATCH_SIZE:
    print(f"启用分批处理：将分 {(len(sample_indices) + BATCH_SIZE - 1) // BATCH_SIZE} 批处理")
    
    # 分批处理
    for i in range(0, len(sample_indices), BATCH_SIZE):
        batch_indices = sample_indices[i:i+BATCH_SIZE]
        print(f"\n处理第 {i//BATCH_SIZE + 1} 批，样本 {i} 到 {i+len(batch_indices)-1}")
        
        batch_results = process_batch(batch_indices, field_methods, converters, gt_coords, gt_types, visualizer, output_dir)
        
        # 合并结果到主结果中
        for method in field_methods:
            if method not in gt_results["comparison_results"]:
                gt_results["comparison_results"][method] = {
                    "rmsd_values": [],
                    "reconstruction_losses": [],
                    "rmsd_hungarian_values": [],
                    "reconstruction_times": [],
                    "sample_details": {}
                }
            
            gt_results["comparison_results"][method]["rmsd_values"].extend(batch_results["comparison_results"][method]["rmsd_values"])
            gt_results["comparison_results"][method]["reconstruction_losses"].extend(batch_results["comparison_results"][method]["reconstruction_losses"])
            gt_results["comparison_results"][method]["rmsd_hungarian_values"].extend(batch_results["comparison_results"][method]["rmsd_hungarian_values"])
            gt_results["comparison_results"][method]["reconstruction_times"].extend(batch_results["comparison_results"][method]["reconstruction_times"])
            gt_results["comparison_results"][method]["sample_details"].update(batch_results["comparison_results"][method]["sample_details"])
        
        # 每批保存一次结果（可选）
        batch_file = os.path.join(output_dir, f"batch_{i//BATCH_SIZE + 1}_results.json")
        with open(batch_file, 'w', encoding='utf-8') as f:
            json.dump(batch_results, f, indent=2, ensure_ascii=False)
        print(f"第 {i//BATCH_SIZE + 1} 批结果已保存到: {batch_file}")

else:
    # 一次性处理所有样本
    print(f"一次性处理所有 {len(sample_indices)} 个样本...")
    
    # 对每个样本进行测试 - 使用进度条
    for sample_idx in tqdm(sample_indices, desc="处理样本"):
        print(f"\n处理样本 {sample_idx}...")
        
        # 获取真实分子信息
        gt_mask = (gt_types[sample_idx] != PADDING_INDEX)
        gt_valid_coords = gt_coords[sample_idx][gt_mask]
        gt_valid_types = gt_types[sample_idx][gt_mask]
        
        print(f"真实原子数: {len(gt_valid_coords)}")
        print(f"真实原子类型: {gt_valid_types.cpu().numpy()}")
        
        # 对每种field方法进行重建
        for method in field_methods:
            print(f"\n--- 测试 {method} gt field方法 ---")
            
            converter = converters[method]
            
            # 使用gt field函数
            field_func = create_gt_field_func(converter, gt_coords, gt_types, sample_idx)
            
            # 执行重建
            start_time = time.time()
            
            # 使用visualizer的create_reconstruction_animation方法
            reconstruction_results = visualizer.create_reconstruction_animation(
                gt_coords=gt_coords,
                gt_types=gt_types,
                converter=converter,
                field_func=field_func,
                save_interval=100,
                animation_name=f"gt_field_sample_{sample_idx}_{method}",
                sample_idx=sample_idx
            )
            
            reconstruction_time = time.time() - start_time
            
            # 获取重建结果
            recon_coords = reconstruction_results['final_points']
            recon_types = reconstruction_results['final_types']
            
            # 计算RMSD（使用Hungarian算法）
            rmsd_hungarian = compute_rmsd_hungarian(gt_valid_coords, recon_coords)
            
            # 计算重建损失
            recon_loss = reconstruction_loss(gt_valid_coords, recon_coords)
            
            # 保存结果
            if method not in gt_results["comparison_results"]:
                gt_results["comparison_results"][method] = {
                    "rmsd_values": [],
                    "reconstruction_losses": [],
                    "rmsd_hungarian_values": [],
                    "reconstruction_times": [],
                    "sample_details": {}
                }
            
            gt_results["comparison_results"][method]["rmsd_values"].append(reconstruction_results['final_rmsd'])
            gt_results["comparison_results"][method]["reconstruction_losses"].append(recon_loss.item())
            gt_results["comparison_results"][method]["rmsd_hungarian_values"].append(rmsd_hungarian)
            gt_results["comparison_results"][method]["reconstruction_times"].append(reconstruction_time)
            
            # 保存样本详细信息
            gt_results["comparison_results"][method]["sample_details"][sample_idx] = {
                "rmsd": reconstruction_results['final_rmsd'],
                "rmsd_hungarian": rmsd_hungarian,
                "reconstruction_loss": recon_loss.item(),
                "reconstruction_time": reconstruction_time,
                "gt_atoms": len(gt_valid_coords),
                "recon_atoms": len(recon_coords),
                "gt_types": gt_valid_types.cpu().numpy().tolist(),
                "recon_types": recon_types.cpu().numpy().tolist() if len(recon_types) > 0 else [],
                "gif_path": reconstruction_results['gif_path'],
                "comparison_path": reconstruction_results['comparison_path']
            }
            
            print(f"样本 {sample_idx} - {method} gt field:")
            print(f"  RMSD: {reconstruction_results['final_rmsd']:.4f}")
            print(f"  RMSD_Hungarian: {rmsd_hungarian:.4f}")
            print(f"  Reconstruction Loss: {recon_loss.item():.4f}")
            print(f"  重建时间: {reconstruction_time:.2f}s")
            print(f"  真实原子数: {len(gt_valid_coords)}")
            print(f"  重建原子数: {len(recon_coords)}")
            print(f"  重建原子类型: {recon_types.cpu().numpy() if len(recon_types) > 0 else 'None'}")

print("\n" + "=" * 60)
print("gt field比较完成！")

输出目录: /datapool/data2/home/pxg/data/hyc/funcmol-main-neuralfield/funcmol/field_set

创建 gaussian_mag gt field转换器...
Dataset directory: /datapool/data2/home/pxg/data/hyc/funcmol-main-neuralfield/funcmol/dataset/data
GNF Converter created with n_iter: 2000, gradient_field_method: gaussian_mag, n_atom_types: 5
gaussian_mag gt field转换器创建完成

创建 tanh gt field转换器...
Dataset directory: /datapool/data2/home/pxg/data/hyc/funcmol-main-neuralfield/funcmol/dataset/data
GNF Converter created with n_iter: 2000, gradient_field_method: tanh, n_atom_types: 5
tanh gt field转换器创建完成

开始gt field比较...
启用分批处理：将分 5 批处理

处理第 1 批，样本 0 到 19


处理批次:   0%|          | 0/20 [00:00<?, ?it/s]


Starting reconstruction for molecule 6
Ground truth atoms: 14
[DBSCAN] Total points: 2000, Clusters found: 5, Noise points: 1880
[DBSCAN] Total points: 2000, Clusters found: 5, Noise points: 1840
[DBSCAN] Total points: 2000, Clusters found: 1, Noise points: 1962
[DBSCAN] Total points: 2000, Clusters found: 3, Noise points: 1911
[DBSCAN] Total points: 2000, Clusters found: 0, Noise points: 2000

Starting reconstruction for molecule 6
Ground truth atoms: 14
[DBSCAN] Total points: 1000, Clusters found: 7, Noise points: 432
[DBSCAN] Total points: 1000, Clusters found: 5, Noise points: 297
[DBSCAN] Total points: 1000, Clusters found: 1, Noise points: 751
[DBSCAN] Total points: 1000, Clusters found: 5, Noise points: 467
[DBSCAN] Total points: 1000, Clusters found: 0, Noise points: 1000


处理批次:   5%|▌         | 1/20 [01:42<32:28, 102.58s/it]


Starting reconstruction for molecule 25
Ground truth atoms: 18
[DBSCAN] Total points: 2000, Clusters found: 5, Noise points: 1928
[DBSCAN] Total points: 2000, Clusters found: 9, Noise points: 1862
[DBSCAN] Total points: 2000, Clusters found: 1, Noise points: 1981
[DBSCAN] Total points: 2000, Clusters found: 3, Noise points: 1954
[DBSCAN] Total points: 2000, Clusters found: 0, Noise points: 2000

Starting reconstruction for molecule 25
Ground truth atoms: 18
[DBSCAN] Total points: 1000, Clusters found: 6, Noise points: 832
[DBSCAN] Total points: 1000, Clusters found: 10, Noise points: 608
[DBSCAN] Total points: 1000, Clusters found: 2, Noise points: 903
[DBSCAN] Total points: 1000, Clusters found: 6, Noise points: 821
[DBSCAN] Total points: 1000, Clusters found: 0, Noise points: 1000


处理批次:  10%|█         | 2/20 [03:46<34:34, 115.27s/it]


Starting reconstruction for molecule 27
Ground truth atoms: 18
[DBSCAN] Total points: 2000, Clusters found: 4, Noise points: 1950
[DBSCAN] Total points: 2000, Clusters found: 9, Noise points: 1868
[DBSCAN] Total points: 2000, Clusters found: 0, Noise points: 2000
[DBSCAN] Total points: 2000, Clusters found: 1, Noise points: 1986
[DBSCAN] Total points: 2000, Clusters found: 0, Noise points: 2000

Starting reconstruction for molecule 27
Ground truth atoms: 18
[DBSCAN] Total points: 1000, Clusters found: 5, Noise points: 905
[DBSCAN] Total points: 1000, Clusters found: 11, Noise points: 736
[DBSCAN] Total points: 1000, Clusters found: 0, Noise points: 1000
[DBSCAN] Total points: 1000, Clusters found: 2, Noise points: 939
[DBSCAN] Total points: 1000, Clusters found: 0, Noise points: 1000


处理批次:  15%|█▌        | 3/20 [05:47<33:20, 117.69s/it]


Starting reconstruction for molecule 30
Ground truth atoms: 17
[DBSCAN] Total points: 2000, Clusters found: 6, Noise points: 1877
[DBSCAN] Total points: 2000, Clusters found: 8, Noise points: 1816
[DBSCAN] Total points: 2000, Clusters found: 1, Noise points: 1980
[DBSCAN] Total points: 2000, Clusters found: 2, Noise points: 1952
[DBSCAN] Total points: 2000, Clusters found: 0, Noise points: 2000

Starting reconstruction for molecule 30
Ground truth atoms: 17
[DBSCAN] Total points: 1000, Clusters found: 6, Noise points: 641
[DBSCAN] Total points: 1000, Clusters found: 8, Noise points: 394
[DBSCAN] Total points: 1000, Clusters found: 2, Noise points: 805
[DBSCAN] Total points: 1000, Clusters found: 3, Noise points: 761
[DBSCAN] Total points: 1000, Clusters found: 0, Noise points: 1000


处理批次:  20%|██        | 4/20 [12:10<59:17, 222.37s/it]


Starting reconstruction for molecule 32
Ground truth atoms: 16
[DBSCAN] Total points: 2000, Clusters found: 5, Noise points: 1937
[DBSCAN] Total points: 2000, Clusters found: 5, Noise points: 1937
[DBSCAN] Total points: 2000, Clusters found: 1, Noise points: 1981
[DBSCAN] Total points: 2000, Clusters found: 1, Noise points: 1988
[DBSCAN] Total points: 2000, Clusters found: 0, Noise points: 2000

Starting reconstruction for molecule 32
Ground truth atoms: 16
[DBSCAN] Total points: 1000, Clusters found: 6, Noise points: 840
[DBSCAN] Total points: 1000, Clusters found: 9, Noise points: 745
[DBSCAN] Total points: 1000, Clusters found: 2, Noise points: 921
[DBSCAN] Total points: 1000, Clusters found: 1, Noise points: 980
[DBSCAN] Total points: 1000, Clusters found: 0, Noise points: 1000


处理批次:  25%|██▌       | 5/20 [14:06<46:01, 184.13s/it]


Starting reconstruction for molecule 44
Ground truth atoms: 18
[DBSCAN] Total points: 2000, Clusters found: 8, Noise points: 1869
[DBSCAN] Total points: 2000, Clusters found: 8, Noise points: 1859
[DBSCAN] Total points: 2000, Clusters found: 0, Noise points: 2000
[DBSCAN] Total points: 2000, Clusters found: 1, Noise points: 1984
[DBSCAN] Total points: 2000, Clusters found: 0, Noise points: 2000

Starting reconstruction for molecule 44
Ground truth atoms: 18
[DBSCAN] Total points: 1000, Clusters found: 10, Noise points: 653
[DBSCAN] Total points: 1000, Clusters found: 10, Noise points: 562
[DBSCAN] Total points: 1000, Clusters found: 0, Noise points: 1000
[DBSCAN] Total points: 1000, Clusters found: 3, Noise points: 889
[DBSCAN] Total points: 1000, Clusters found: 0, Noise points: 1000


处理批次:  30%|███       | 6/20 [16:03<37:38, 161.30s/it]


Starting reconstruction for molecule 46
Ground truth atoms: 17
[DBSCAN] Total points: 2000, Clusters found: 7, Noise points: 1851
[DBSCAN] Total points: 2000, Clusters found: 8, Noise points: 1828
[DBSCAN] Total points: 2000, Clusters found: 2, Noise points: 1955
[DBSCAN] Total points: 2000, Clusters found: 0, Noise points: 2000
[DBSCAN] Total points: 2000, Clusters found: 0, Noise points: 2000

Starting reconstruction for molecule 46
Ground truth atoms: 17
[DBSCAN] Total points: 1000, Clusters found: 8, Noise points: 699
[DBSCAN] Total points: 1000, Clusters found: 12, Noise points: 525
[DBSCAN] Total points: 1000, Clusters found: 3, Noise points: 721
[DBSCAN] Total points: 1000, Clusters found: 0, Noise points: 1000
[DBSCAN] Total points: 1000, Clusters found: 0, Noise points: 1000


处理批次:  35%|███▌      | 7/20 [22:06<49:12, 227.09s/it]


Starting reconstruction for molecule 71
Ground truth atoms: 13
[DBSCAN] Total points: 2000, Clusters found: 4, Noise points: 1734
[DBSCAN] Total points: 2000, Clusters found: 4, Noise points: 1710
[DBSCAN] Total points: 2000, Clusters found: 1, Noise points: 1890
[DBSCAN] Total points: 2000, Clusters found: 3, Noise points: 1771
[DBSCAN] Total points: 2000, Clusters found: 1, Noise points: 1903

Starting reconstruction for molecule 71
Ground truth atoms: 13
[DBSCAN] Total points: 1000, Clusters found: 6, Noise points: 149
[DBSCAN] Total points: 1000, Clusters found: 5, Noise points: 41
[DBSCAN] Total points: 1000, Clusters found: 4, Noise points: 410
[DBSCAN] Total points: 1000, Clusters found: 8, Noise points: 161
[DBSCAN] Total points: 1000, Clusters found: 1, Noise points: 457


处理批次:  40%|████      | 8/20 [23:52<37:44, 188.67s/it]


Starting reconstruction for molecule 80
Ground truth atoms: 14
[DBSCAN] Total points: 2000, Clusters found: 5, Noise points: 1899
[DBSCAN] Total points: 2000, Clusters found: 5, Noise points: 1888
[DBSCAN] Total points: 2000, Clusters found: 3, Noise points: 1929
[DBSCAN] Total points: 2000, Clusters found: 1, Noise points: 1975
[DBSCAN] Total points: 2000, Clusters found: 0, Noise points: 2000

Starting reconstruction for molecule 80
Ground truth atoms: 14
[DBSCAN] Total points: 1000, Clusters found: 6, Noise points: 593
[DBSCAN] Total points: 1000, Clusters found: 6, Noise points: 489
[DBSCAN] Total points: 1000, Clusters found: 6, Noise points: 585
[DBSCAN] Total points: 1000, Clusters found: 3, Noise points: 764
[DBSCAN] Total points: 1000, Clusters found: 0, Noise points: 1000


处理批次:  45%|████▌     | 9/20 [25:35<29:39, 161.78s/it]


Starting reconstruction for molecule 81
Ground truth atoms: 14
[DBSCAN] Total points: 2000, Clusters found: 4, Noise points: 1935
[DBSCAN] Total points: 2000, Clusters found: 5, Noise points: 1914
[DBSCAN] Total points: 2000, Clusters found: 2, Noise points: 1971
[DBSCAN] Total points: 2000, Clusters found: 1, Noise points: 1978
[DBSCAN] Total points: 2000, Clusters found: 0, Noise points: 2000

Starting reconstruction for molecule 81
Ground truth atoms: 14
[DBSCAN] Total points: 1000, Clusters found: 6, Noise points: 781
[DBSCAN] Total points: 1000, Clusters found: 5, Noise points: 745
[DBSCAN] Total points: 1000, Clusters found: 2, Noise points: 858
[DBSCAN] Total points: 1000, Clusters found: 2, Noise points: 877
[DBSCAN] Total points: 1000, Clusters found: 0, Noise points: 1000


处理批次:  50%|█████     | 10/20 [31:47<37:48, 226.86s/it]


Starting reconstruction for molecule 89
Ground truth atoms: 17
[DBSCAN] Total points: 2000, Clusters found: 4, Noise points: 1928
[DBSCAN] Total points: 2000, Clusters found: 8, Noise points: 1865
[DBSCAN] Total points: 2000, Clusters found: 2, Noise points: 1961
[DBSCAN] Total points: 2000, Clusters found: 2, Noise points: 1958
[DBSCAN] Total points: 2000, Clusters found: 0, Noise points: 2000

Starting reconstruction for molecule 89
Ground truth atoms: 17


In [None]:
# Summary printed once processing has finished
print("\n" + "=" * 80, "批量处理完成！", "=" * 80, sep="\n")

# Processing statistics
total_samples = len(sample_indices)
if USE_BATCH_PROCESSING:
    # Ceiling division: number of batches needed for all samples.
    total_batches = (total_samples + BATCH_SIZE - 1) // BATCH_SIZE
else:
    total_batches = 1

print(
    "处理统计:",
    f"  总样本数: {total_samples}",
    f"  总批次数: {total_batches}",
    f"  每批样本数: {BATCH_SIZE}",
    f"  处理方式: {'分批处理' if USE_BATCH_PROCESSING and total_samples > BATCH_SIZE else '一次性处理'}",
    sep="\n",
)

# Locations of the generated files
print(
    "\n生成的文件:",
    f"  输出目录: {output_dir}",
    f"  重建图片: {output_dir}/recon/",
    sep="\n",
)
if USE_BATCH_PROCESSING and total_samples > BATCH_SIZE:
    print(f"  分批结果: {output_dir}/batch_*_results.json")
print(f"  最终结果: {output_dir}/gt_field_comparison_results.json")

print("\n" + "=" * 80)


In [None]:
# Summary statistics of the GT-field comparison
import numpy as np

print("正在生成gt field比较统计报告...")

# Mean and standard deviation of every metric, per field method.
for method in field_methods:
    method_results = gt_results["comparison_results"][method]

    rmsd_mean, rmsd_std = (
        np.mean(method_results["rmsd_values"]),
        np.std(method_results["rmsd_values"]),
    )
    rmsd_hungarian_mean, rmsd_hungarian_std = (
        np.mean(method_results["rmsd_hungarian_values"]),
        np.std(method_results["rmsd_hungarian_values"]),
    )
    loss_mean, loss_std = (
        np.mean(method_results["reconstruction_losses"]),
        np.std(method_results["reconstruction_losses"]),
    )
    time_mean, time_std = (
        np.mean(method_results["reconstruction_times"]),
        np.std(method_results["reconstruction_times"]),
    )

    # Stored as plain Python floats.
    gt_results["summary_statistics"][method] = dict(
        rmsd_mean=float(rmsd_mean),
        rmsd_std=float(rmsd_std),
        rmsd_hungarian_mean=float(rmsd_hungarian_mean),
        rmsd_hungarian_std=float(rmsd_hungarian_std),
        loss_mean=float(loss_mean),
        loss_std=float(loss_std),
        time_mean=float(time_mean),
        time_std=float(time_std),
    )

# Print the per-method summary.
print("\n" + "=" * 80, "GT FIELD方法比较结果汇总", "=" * 80, sep="\n")

for method in field_methods:
    stats = gt_results["summary_statistics"][method]
    print(
        f"\n{method.upper()} GT FIELD方法:",
        f"  RMSD (原始): {stats['rmsd_mean']:.4f} ± {stats['rmsd_std']:.4f}",
        f"  RMSD (Hungarian): {stats['rmsd_hungarian_mean']:.4f} ± {stats['rmsd_hungarian_std']:.4f}",
        f"  Reconstruction Loss: {stats['loss_mean']:.4f} ± {stats['loss_std']:.4f}",
        f"  重建时间: {stats['time_mean']:.2f} ± {stats['time_std']:.2f}s",
        sep="\n",
    )

# 比较两种方法
print(f"\n" + "=" * 80)
print("GT FIELD方法比较:")
print("=" * 80)

gaussian_stats = gt_results["summary_statistics"]["gaussian_mag"]
tanh_stats = gt_results["summary_statistics"]["tanh"]

print(f"RMSD (Hungarian) 比较:")
print(f"  gaussian_mag: {gaussian_stats['rmsd_hungarian_mean']:.4f}")
print(f"  tanh: {tanh_stats['rmsd_hungarian_mean']:.4f}")
if gaussian_stats['rmsd_hungarian_mean'] < tanh_stats['rmsd_hungarian_mean']:
    print(f"  → gaussian_mag gt field 更好 (低 {tanh_stats['rmsd_hungarian_mean'] - gaussian_stats['rmsd_hungarian_mean']:.4f})")
else:
    print(f"  → tanh gt field 更好 (低 {gaussian_stats['rmsd_hungarian_mean'] - tanh_stats['rmsd_hungarian_mean']:.4f})")

print(f"\nReconstruction Loss 比较:")
print(f"  gaussian_mag: {gaussian_stats['loss_mean']:.4f}")
print(f"  tanh: {tanh_stats['loss_mean']:.4f}")
if gaussian_stats['loss_mean'] < tanh_stats['loss_mean']:
    print(f"  → gaussian_mag gt field 更好 (低 {tanh_stats['loss_mean'] - gaussian_stats['loss_mean']:.4f})")
else:
    print(f"  → tanh gt field 更好 (低 {gaussian_stats['loss_mean'] - tanh_stats['loss_mean']:.4f})")

# 保存结果到JSON
gt_results_file = os.path.join(output_dir, "gt_field_comparison_results.json")
with open(gt_results_file, 'w', encoding='utf-8') as f:
    json.dump(gt_results, f, indent=2, ensure_ascii=False)
print(f"\ngt field比较结果已保存到: {gt_results_file}")

print(f"\n" + "=" * 80)
print("GT FIELD比较完成！")
print("=" * 80)

正在生成gt field比较统计报告...

GT FIELD方法比较结果汇总

GAUSSIAN_MAG GT FIELD方法:
  RMSD (原始): 0.7208 ± 0.0000
  RMSD (Hungarian): 0.0002 ± 0.0000
  Reconstruction Loss: 0.0005 ± 0.0000
  重建时间: 57.89 ± 0.00s

TANH GT FIELD方法:
  RMSD (原始): 0.1667 ± 0.0000
  RMSD (Hungarian): 0.0046 ± 0.0000
  Reconstruction Loss: 0.0096 ± 0.0000
  重建时间: 41.35 ± 0.00s

GT FIELD方法比较:
RMSD (Hungarian) 比较:
  gaussian_mag: 0.0002
  tanh: 0.0046
  → gaussian_mag gt field 更好 (低 0.0044)

Reconstruction Loss 比较:
  gaussian_mag: 0.0005
  tanh: 0.0096
  → gaussian_mag gt field 更好 (低 0.0091)

gt field比较结果已保存到: /datapool/data2/home/pxg/data/hyc/funcmol-main-neuralfield/funcmol/field_set/gt_field_comparison_results.json

GT FIELD比较完成！
