In [8]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
import os
import sys
import torch
import hydra
from pathlib import Path
from lightning import Fabric

# torch.compile compatibility: switch on TorchDynamo's `suppress_errors` flag
# whenever the torch._dynamo module can be imported.
try:
    import torch._dynamo
except ImportError:
    # Original note: PyTorch versions older than 2.0 do not ship torch._dynamo.
    print("Warning: torch._dynamo not available in this PyTorch version")
else:
    torch._dynamo.config.suppress_errors = True

## set up environment
# The project root is taken to be the parent of the current working directory,
# so the kernel's working directory must be a direct sub-folder of the repository.
project_root = Path(os.getcwd()).parent
# Drop any earlier copy before inserting: re-running this cell then leaves exactly
# one entry, still at the front of sys.path, instead of stacking duplicates.
project_root_str = str(project_root)
sys.path[:] = [entry for entry in sys.path if entry != project_root_str]
sys.path.insert(0, project_root_str)

from funcmol.utils.constants import PADDING_INDEX
from funcmol.utils.gnf_visualizer import (
    load_config_from_exp_dir, load_model, 
    create_converter, prepare_data, visualize_1d_gradient_field_comparison,
    GNFVisualizer
)

# Root directory that holds the experiment folders.
# Defaults to the original cluster path; set the FUNCMOL_MODEL_ROOT environment
# variable to point the notebook at another location without editing this cell.
model_root = os.environ.get(
    "FUNCMOL_MODEL_ROOT",
    "/datapool/data2/home/pxg/data/hyc/funcmol-main-neuralfield/exps/neural_field",
)

In [10]:
# TODO: this is normally the only value that needs editing; the mode (and with it
# the gradient_field_method) is derived from exp_name.
exp_name = 'nf_qm9_20250804_153549_358664'
sample_idx = 1
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Both modes use the same directory under model_root.
model_dir = os.path.join(model_root, exp_name)

# Decide between gt_only and gt_pred.
# Trained-run names appear to embed a timestamp (e.g. nf_qm9_20250804_153549_358664).
# Besides the original literal '2025' check, accept any 8-digit token so that runs
# from other years are not mistaken for a ground-truth field name.
has_date_token = any(
    len(token) == 8 and token.isdigit() for token in exp_name.split('_')
)
if '2025' in exp_name or has_date_token:  # gt + predicted field
    option = 'gt_pred'
else:  # gt only. exp_name is the name of field (e.g., gaussian_mag)
    option = 'gt_only'
    # A gt_only "experiment" may have no directory yet; create it so outputs can be saved.
    os.makedirs(model_dir, exist_ok=True)

output_dir = model_dir
print(f"Option: {option}")
print(f"Model directory: {model_dir}")
print(f"Output directory: {output_dir}")

Option: gt_pred
Model directory: /datapool/data2/home/pxg/data/hyc/funcmol-main-neuralfield/exps/neural_field/nf_qm9_20250804_153549_358664
Output directory: /datapool/data2/home/pxg/data/hyc/funcmol-main-neuralfield/exps/neural_field/nf_qm9_20250804_153549_358664


In [11]:
## Load data
# Single-device Fabric runtime; accelerator and strategy are left on "auto".
fabric_options = {
    "accelerator": "auto",
    "devices": 1,
    "precision": "32-true",
    "strategy": "auto",
}
fabric = Fabric(**fabric_options)
fabric.launch()

# Read the configuration stored in the experiment directory.
config = load_config_from_exp_dir(model_dir)

# Prepare the data: a batch plus its ground-truth coordinates and atom types.
batch, gt_coords, gt_types = prepare_data(fabric, config, device)
print(f"Data loaded: {gt_coords.shape}, {gt_types.shape}")

Dataset directory: /datapool/data3/storage/pengxingang/pxg/hyc/funcmol-main-neuralfield/funcmol/dataset/data
>> val set size: 20042
Data loaded: torch.Size([128, 18, 3]), torch.Size([128, 18])


In [12]:
print(f"\nProcessing model from: {model_dir}")


def _ensure_batched(points):
    """Return `points` with a leading batch axis.

    A 2-D tensor is unsqueezed at dim 0, a 3-D tensor is returned unchanged, and
    any other rank raises ValueError.
    """
    rank = points.dim()
    if rank == 3:  # [batch, n_points, 3]
        return points
    if rank == 2:  # [n_points, 3] -> [1, n_points, 3]
        return points.unsqueeze(0)
    raise ValueError(f"Unexpected points shape: {points.shape}")


def _first_if_batched(result):
    """Return `result[0]` when `result` is 4-D, otherwise `result` unchanged.

    Intended to yield [n_points, n_atom_types, 3] from a
    [batch, n_points, n_atom_types, 3] field (shapes per the original notes).
    """
    return result[0] if result.dim() == 4 else result


## Load model
if option != 'gt_pred':  # gt only
    encoder = decoder = None

    def gt_field_func(points):
        """Evaluate the ground-truth field of sample `sample_idx` at `points`."""
        # Keep only the atoms of this sample that are not padding.
        keep = (gt_types[sample_idx] != PADDING_INDEX)
        sample_coords = gt_coords[sample_idx][keep]
        sample_types = gt_types[sample_idx][keep]

        points = _ensure_batched(points)
        # `converter` is looked up at call time; it is created at the end of this cell.
        field = converter.mol2gnf(
            sample_coords.unsqueeze(0),
            sample_types.unsqueeze(0),
            points
        )
        return _first_if_batched(field)

    field_func = gt_field_func
else:
    encoder, decoder = load_model(fabric, config, model_dir=model_dir)
    # Encode the batch into codes without tracking gradients.
    with torch.no_grad():
        codes = encoder(batch)

    def predicted_field_func(points):
        """Evaluate the decoder's field for sample `sample_idx` at `points`."""
        points = _ensure_batched(points)
        # Slice (rather than index) so the selected code keeps its batch axis.
        field = decoder(points, codes[sample_idx:sample_idx+1])
        return _first_if_batched(field)

    field_func = predicted_field_func

converter = create_converter(config, device)
print(f"Model loaded successfully!")


Processing model from: /datapool/data2/home/pxg/data/hyc/funcmol-main-neuralfield/exps/neural_field/nf_qm9_20250804_153549_358664
Loading model from: /datapool/data2/home/pxg/data/hyc/funcmol-main-neuralfield/exps/neural_field/nf_qm9_20250804_153549_358664/model.pt
>> loaded dec
>> loaded enc
>> loaded dec
>> loaded enc
Model loaded successfully!
GNF Converter created with n_iter: 2000, gradient_field_method: tanh
Model loaded successfully!



You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.



In [13]:
# Compare 1D gradient-field slices for every atom type in a single call.
atom_types = list(range(5))  # C, H, O, N, F
save_path = os.path.join(output_dir, "recon", f"field1d_sample_{sample_idx}")

comparison_args = dict(
    gt_coords=gt_coords,
    gt_types=gt_types,
    converter=converter,
    field_func=field_func,
    sample_idx=sample_idx,
    atom_types=atom_types,  # the whole list is passed, so no loop is needed here
    x_range=None,  # the captured output shows the x range being computed automatically
    y_coord=0.0,
    z_coord=0.0,
    save_path=save_path,
)
gradient_results = visualize_1d_gradient_field_comparison(**comparison_args)

if gradient_results:
    print(f"Gradient field comparison (model: {model_dir}):")
    print(f"  Available atom types: {gradient_results['available_atom_types']}")

    # Report the error statistics and output file for each atom type.
    per_atom_results = gradient_results['all_results']
    for atom_name, stats in per_atom_results.items():
        print(f"  {atom_name}: MSE={stats['mse']:.6f}, MAE={stats['mae']:.6f}")
        print(f"    Saved to: {stats['save_path']}")

警告：样本 1 中没有类型为 N 的原子
警告：样本 1 中没有类型为 F 的原子
自动计算 x 轴范围: (-1.567720103263855, 1.5677199840545655)
Field 1D comparison (atom_type=C) saved to: /datapool/data2/home/pxg/data/hyc/funcmol-main-neuralfield/exps/neural_field/nf_qm9_20250804_153549_358664/recon/field_1d_sample_1_atom_C.png
Field 1D comparison (atom_type=H) saved to: /datapool/data2/home/pxg/data/hyc/funcmol-main-neuralfield/exps/neural_field/nf_qm9_20250804_153549_358664/recon/field_1d_sample_1_atom_H.png
Field 1D comparison (atom_type=O) saved to: /datapool/data2/home/pxg/data/hyc/funcmol-main-neuralfield/exps/neural_field/nf_qm9_20250804_153549_358664/recon/field_1d_sample_1_atom_O.png
Gradient field comparison (model: /datapool/data2/home/pxg/data/hyc/funcmol-main-neuralfield/exps/neural_field/nf_qm9_20250804_153549_358664):
  Available atom types: [0, 1, 2]
  C: MSE=0.025678, MAE=0.121450
    Saved to: /datapool/data2/home/pxg/data/hyc/funcmol-main-neuralfield/exps/neural_field/nf_qm9_20250804_153549_358664/recon/field_1d_sa

In [14]:
# Choose which fields to reconstruct from, depending on `option`.
rec_list = ['gt_field'] if option == 'gt_only' else ['predicted_field', 'gt_field']

# Build the visualizer that writes into output_dir.
visualizer = GNFVisualizer(output_dir)

# Run one reconstruction per entry in rec_list.
for rec_type in rec_list:
    print(f"\n=== 执行 {rec_type} 重建 ===")

    # Select the field function that matches this reconstruction type.
    if rec_type != 'gt_field':  # predicted_field
        def predicted_field_func(points):
            """Evaluate the decoder's field for sample `sample_idx` at `points`."""
            rank = points.dim()
            if rank not in (2, 3):
                raise ValueError(f"Unexpected points shape: {points.shape}")
            if rank == 2:
                points = points.unsqueeze(0)
            output = decoder(points, codes[sample_idx:sample_idx+1])
            if output.dim() == 4:
                return output[0]
            return output
        field_func = predicted_field_func
    else:
        def gt_field_func(points):
            """Evaluate the ground-truth field of sample `sample_idx` at `points`."""
            # NOTE(review): unlike predicted_field_func, `points` and the result are
            # passed through without any batch-axis normalisation - confirm intended.
            keep = (gt_types[sample_idx] != PADDING_INDEX)
            sample_coords = gt_coords[sample_idx][keep]
            sample_types = gt_types[sample_idx][keep]
            return converter.mol2gnf(
                sample_coords.unsqueeze(0),
                sample_types.unsqueeze(0),
                points
            )
        field_func = gt_field_func

    # Run the reconstruction and produce its animation.
    animation_args = dict(
        gt_coords=gt_coords,
        gt_types=gt_types,
        converter=converter,
        field_func=field_func,
        save_interval=100,
        animation_name=f"recon_sample_{sample_idx}_{rec_type}",
        sample_idx=sample_idx,
    )
    results = visualizer.create_reconstruction_animation(**animation_args)

    # Summarise the final metrics and the files that were written.
    print(f"\n=== {rec_type} 重建结果 ===")
    print(f"RMSD: {results['final_rmsd']:.4f}")
    print(f"Reconstruction Loss: {results['final_loss']:.4f}")
    print(f"KL Divergence (orig->recon): {results['final_kl_1to2']:.4f}")
    print(f"KL Divergence (recon->orig): {results['final_kl_2to1']:.4f}")
    print(f"GIF动画: {results['gif_path']}")
    print(f"对比图: {results['comparison_path']}")


=== 执行 predicted_field 重建 ===

Starting reconstruction for molecule 1
Ground truth atoms: 9
[DBSCAN] Total points: 1000, Clusters found: 9, Noise points: 529
[DBSCAN] Total points: 1000, Clusters found: 6, Noise points: 131
[DBSCAN] Total points: 1000, Clusters found: 1, Noise points: 51
[DBSCAN] Total points: 1000, Clusters found: 0, Noise points: 1000
[DBSCAN] Total points: 1000, Clusters found: 0, Noise points: 1000

=== predicted_field 重建结果 ===
RMSD: 0.4137
Reconstruction Loss: 0.3995
KL Divergence (orig->recon): 8.4911
KL Divergence (recon->orig): -3.3217
GIF动画: /datapool/data2/home/pxg/data/hyc/funcmol-main-neuralfield/exps/neural_field/nf_qm9_20250804_153549_358664/recon/recon_sample_1_predicted_field.gif
对比图: /datapool/data2/home/pxg/data/hyc/funcmol-main-neuralfield/exps/neural_field/nf_qm9_20250804_153549_358664/recon/recon_sample_1_predicted_field_final.png

=== 执行 gt_field 重建 ===

Starting reconstruction for molecule 1
Ground truth atoms: 9
[DBSCAN] Total points: 1000, Clu