In [1]:
%load_ext autoreload
%autoreload 2

## 1. 导入必要库，选择场类型

In [2]:
import torch
import os
import sys
import hydra
import traceback
from pathlib import Path
from lightning import Fabric
from torch_geometric.utils import to_dense_batch

# Make the project package importable. The parent of the current working
# directory is treated as the project root, so this notebook is expected to be
# run from a direct subdirectory of the project (e.g. `notebooks/`).
project_root = Path(os.getcwd()).parent
sys.path.insert(0, str(project_root))
from funcmol.utils.constants import PADDING_INDEX
from funcmol.utils.utils_nf import load_neural_field
from funcmol.dataset.dataset_field import create_field_loaders, create_gnf_converter
# NOTE(review): `gnf_visualizer` has no package prefix, so it is presumably
# resolved from the notebook's working directory — confirm.
from gnf_visualizer import visualize_reconstruction

# torch.compile compatibility: with suppress_errors=True, TorchDynamo falls back
# to eager execution instead of raising when it fails to compile a frame.
import torch._dynamo
torch._dynamo.config.suppress_errors = True

Fabric will use only 1 of 2 GPUs because it is running inside an interactive / notebook environment. You may try to set `Fabric(devices=2)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.


## 2. 加载模型并完成重建过程

In [3]:
# Single-device CPU setup; `fabric` is reused below for checkpoint/model loading.
fabric = Fabric(
    accelerator="cpu",
    devices=1,
    precision="32-true",
    strategy="auto",
)
fabric.launch()
# Derive the tensor device from Fabric so the two cannot disagree
# (change `accelerator` above to run on another device).
device = fabric.device

# Compose the training config used for the QM9 neural-field run.
config_dir = str(project_root / "configs")
with hydra.initialize_config_dir(config_dir=config_dir, version_base=None):
    config = hydra.compose(config_name="train_nf_qm9")

# Point the dataset at the local data directory (edit here if it lives elsewhere).
config["dset"]["data_dir"] = str(project_root / "dataset" / "data")

# Locate the most recently modified `nf_qm9_*` experiment and its checkpoint.
exp_dir = project_root / "exps" / "neural_field"
if not exp_dir.exists():
    raise FileNotFoundError(f"Experiment directory not found: {exp_dir}")

exp_dirs = [d for d in exp_dir.iterdir() if d.is_dir() and d.name.startswith("nf_qm9_")]
if not exp_dirs:
    # Without this guard `max()` below fails with an opaque "empty sequence" error.
    raise FileNotFoundError(f"No 'nf_qm9_*' experiment directories found in: {exp_dir}")

latest_exp_dir = max(exp_dirs, key=lambda d: d.stat().st_mtime)
model_path = latest_exp_dir / "model.pt"
if not model_path.exists():
    raise FileNotFoundError(f"Model file not found: {model_path}")

print(f"Loading model from: {model_path}")

# Load the trained weights and build the models ONCE. The original called
# load_neural_field twice (once per index), constructing every model twice.
# Element 0 is used as the encoder and element 1 as the decoder.
checkpoint = fabric.load(str(model_path))
nf_models = load_neural_field(checkpoint, fabric, config)
enc, dec = nf_models[0], nf_models[1]

print("Model loaded successfully!")

# Inference only: move to the target device and switch to eval mode so that
# layers such as dropout/normalization (if present) behave deterministically.
enc = enc.to(device).eval()
dec = dec.to(device).eval()

# The GNFConverter drives data loading; it is built on CPU.
converter = create_gnf_converter(config, device="cpu")

# Grab the first validation batch and move it to the model device.
loader_val = create_field_loaders(config, converter, split="val", fabric=fabric)
batch = next(iter(loader_val)).to(device)

# Ground-truth tensors, padded per molecule with to_dense_batch:
# coordinates are padded with 0, atom types with PADDING_INDEX.
coords, _ = to_dense_batch(batch.pos, batch.batch, fill_value=0)
atoms_channel, _ = to_dense_batch(batch.x, batch.batch, fill_value=PADDING_INDEX)
gt_coords, gt_types = coords, atoms_channel

# Encoding needs no gradients; the reconstruction call below runs outside no_grad.
with torch.no_grad():
    codes = enc(batch)

output_dir = Path.cwd() / "gnf_visualization_outputs"
output_dir.mkdir(exist_ok=True)

# Index (within the batch) of the molecule to reconstruct.
RECON_SAMPLE_IDX = 0
results = visualize_reconstruction(
    gt_coords,
    gt_types,
    converter,
    dec,
    codes,
    output_dir=str(output_dir),
    sample_idx=RECON_SAMPLE_IDX,
)

Loading model from: /home/huayuchen/funcmol-main-neuralfield/funcmol/exps/neural_field/nf_qm9_20250726_004704_299326/model.pt
>> loaded dec
>> loaded enc
>> loaded dec
>> loaded enc
Model loaded successfully!
>> val set size: 20042


W0727 23:25:09.012452 1344025 site-packages/torch/_dynamo/variables/tensor.py:776] [7/0] Graph break from `Tensor.item()`, consider setting:
W0727 23:25:09.012452 1344025 site-packages/torch/_dynamo/variables/tensor.py:776] [7/0]     torch._dynamo.config.capture_scalar_outputs = True
W0727 23:25:09.012452 1344025 site-packages/torch/_dynamo/variables/tensor.py:776] [7/0] or:
W0727 23:25:09.012452 1344025 site-packages/torch/_dynamo/variables/tensor.py:776] [7/0]     env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1
W0727 23:25:09.012452 1344025 site-packages/torch/_dynamo/variables/tensor.py:776] [7/0] to include these operations in the captured graph.
W0727 23:25:09.012452 1344025 site-packages/torch/_dynamo/variables/tensor.py:776] [7/0] 
W0727 23:25:09.012452 1344025 site-packages/torch/_dynamo/variables/tensor.py:776] [7/0] Graph break: from user code at:
W0727 23:25:09.012452 1344025 site-packages/torch/_dynamo/variables/tensor.py:776] [7/0]   File "/home/huayuchen/miniconda3/envs/funcmol_


Starting reconstruction for molecule 0
Ground truth atoms: 4
Atom types: [0, 2, 1, 1]
Iteration 0, Atom type 0: grad norm = 0.970468
Iteration 0, Atom type 1: grad norm = 0.956370
Iteration 0, Atom type 2: grad norm = 0.970780
Iteration 0, Atom type 3: grad norm = 0.003595
Iteration 0, Atom type 4: grad norm = 0.002163
[DBSCAN] Total points: 500, Clusters found: 1, Noise points: 0
[DBSCAN] Total points: 500, Clusters found: 1, Noise points: 0
[DBSCAN] Total points: 500, Clusters found: 1, Noise points: 0
[DBSCAN] Total points: 500, Clusters found: 0, Noise points: 500
[DBSCAN] Total points: 500, Clusters found: 0, Noise points: 500


## 3. 可视化一维梯度场

In [4]:
from gnf_visualizer import visualize_1d_gradient_field_comparison

# Element symbols indexed by atom-type id (0=C, 1=H, 2=O, 3=N, 4=F).
ATOM_SYMBOLS = ["C", "H", "O", "N", "F"]

# Output folder for the gradient-field figures.
output_dir = Path.cwd() / "gnf_visualization_outputs"
output_dir.mkdir(exist_ok=True)

print("\n" + "=" * 50)
print("开始梯度场分析...")

# Index (within the batch) of the molecule to analyse.
sample_idx = 2

# Drop padding entries so only the molecule's real atoms are considered.
gt_mask = gt_types[sample_idx] != PADDING_INDEX
gt_valid_types = gt_types[sample_idx][gt_mask]

# The y/z coordinates of the 1-D slice are fixed at the molecule's centroid.
# This does not depend on the atom type, so it is computed once, outside the loop.
molecule_center = gt_coords[sample_idx][gt_mask].mean(dim=0)

for atom_type, symbol in enumerate(ATOM_SYMBOLS):
    if not (gt_valid_types == atom_type).any():
        print(f"样本中没有 {symbol} 原子，跳过分析")
        continue

    print(f"\n分析 {symbol} 原子的梯度场...")
    save_path = output_dir / f"gradient_field_sample_{sample_idx}_{symbol}.png"

    try:
        result = visualize_1d_gradient_field_comparison(
            gt_coords=gt_coords,
            gt_types=gt_types,
            converter=converter,
            decoder=dec,
            codes=codes,
            sample_idx=sample_idx,
            atom_type=atom_type,
            y_coord=molecule_center[1].item(),
            z_coord=molecule_center[2].item(),
            save_path=str(save_path),
        )

        if result:
            print(f"{symbol} 原子梯度场分析完成！")
            print(f"MSE: {result['mse']:.6f}, MAE: {result['mae']:.6f}")
        else:
            print(f"{symbol} 原子梯度场分析失败")

    except Exception as e:
        # Best effort: report the failure and continue with the remaining atom types.
        print(f"分析 {symbol} 原子时出错: {e}")
        traceback.print_exc()

print("\n所有分析完成！")


开始梯度场分析...

分析 C 原子的梯度场...
Gradient field comparison saved to: /home/huayuchen/funcmol-main-neuralfield/funcmol/notebooks/gnf_visualization_outputs/gradient_field_sample_2_C.png
C 原子梯度场分析完成！
MSE: 0.000019, MAE: 0.003621

分析 H 原子的梯度场...
Gradient field comparison saved to: /home/huayuchen/funcmol-main-neuralfield/funcmol/notebooks/gnf_visualization_outputs/gradient_field_sample_2_H.png
H 原子梯度场分析完成！
MSE: 0.000010, MAE: 0.002555
样本中没有 O 原子，跳过分析
样本中没有 N 原子，跳过分析
样本中没有 F 原子，跳过分析

所有分析完成！
