In [33]:
%load_ext autoreload
%autoreload 2

# %cd /home/huayuchen/Neurl-voxel/funcmol/notebooks
%cd /datapool/data2/home/pxg/data/hyc/funcmol-main-neuralfield/funcmol/notebooks

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
/datapool/data3/storage/pengxingang/pxg/hyc/funcmol-main-neuralfield/funcmol/notebooks


In [34]:
import os
import sys
os.environ['CUDA_VISIBLE_DEVICES'] = "7"

import torch
import hydra
import numpy as np
import random
from pathlib import Path
from omegaconf import OmegaConf

# torch.compile compatibility: ask TorchDynamo to suppress compilation errors instead of raising them
try:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
except ImportError:
    # torch._dynamo does not exist in PyTorch < 2.0
    print("Warning: torch._dynamo not available in this PyTorch version")

## set up environment
# Use a hard-coded project root so that every path in this notebook is derived from one place:
# /datapool/data2/home/pxg/data/hyc/funcmol-main-neuralfield
project_root = Path("/datapool/data2/home/pxg/data/hyc/funcmol-main-neuralfield")
# Put the project root first on sys.path so the `funcmol` imports that follow resolve to this checkout.
sys.path.insert(0, str(project_root))
print(f"Project root: {project_root}")
print(f"Python path: {sys.path[0]}")

from funcmol.dataset.dataset_field import create_gnf_converter, prepare_data_with_sample_idx
from funcmol.utils.utils_nf import load_neural_field
from funcmol.utils.utils_fm import load_checkpoint_fm
from funcmol.utils.constants import PADDING_INDEX
from funcmol.utils.gnf_visualizer import (
    visualize_1d_gradient_field_comparison, 
    GNFVisualizer,
    visualize_generated_molecule,
    create_visualization_callback,
    create_gif_from_frames
)
from funcmol.utils.misc import load_nf_config, load_funcmol_config, create_field_function
from funcmol.models.funcmol import FuncMol

# Seed Python's `random`, PyTorch (CPU and, when available, every CUDA device) and NumPy.
seed = 1234
random.seed(seed)
torch.manual_seed(seed)
np.random.seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

# Neural-field experiment root (derived from the project root so the two stay consistent)
model_root = str(project_root / "exps" / "neural_field")
# Neural-field config; later cells pass it to data loading, model loading and converter creation.
config = load_nf_config("train_nf_qm9")

Project root: /datapool/data2/home/pxg/data/hyc/funcmol-main-neuralfield
Python path: /datapool/data2/home/pxg/data/hyc/funcmol-main-neuralfield
Dataset directory: /datapool/data2/home/pxg/data/hyc/funcmol-main-neuralfield/funcmol/dataset/data
Config loaded successfully: train_nf_qm9
n_iter from converter config: 1000


In [35]:
##### SETTINGS #####
# TODO: choose the mode by hand:
#   'gt_only'       - field computed from the ground-truth molecule only
#   'gt_pred'       - field decoded from encoder codes, compared against ground truth
#   'denoiser_only' - field decoded from FuncMol codes (loaded from disk or sampled)
option = 'gt_pred'  # 'gt_only', 'gt_pred', 'denoiser_only'

# TODO: set the checkpoint paths by hand; exp_name is extracted from the checkpoint path automatically
nf_ckpt_path = '/datapool/data2/home/pxg/data/hyc/funcmol-main-neuralfield/exps/neural_field/nf_qm9/20260123/lightning_logs/version_1/checkpoints/last.ckpt'
# nf_ckpt_path = '/datapool/data2/home/pxg/data/hyc/funcmol-main-neuralfield/exps/neural_field/nf_qm9/20251121/lightning_logs/version_1/checkpoints/model-epoch=999.ckpt'
fm_ckpt_path = '/datapool/data2/home/pxg/data/hyc/funcmol-main-neuralfield/exps/funcmol/fm_qm9/20260116/lightning_logs/version_2/checkpoints/last.ckpt'
# nf_ckpt_path = '/datapool/data2/home/pxg/data/hyc/funcmol-main-neuralfield/exps/neural_field/nf_qm9/20251024/lightning_logs/version_0/checkpoints/model-epoch=409.ckpt'
# fm_ckpt_path = '/datapool/data2/home/pxg/data/hyc/funcmol-main-neuralfield/exps/funcmol/fm_qm9/20251108/lightning_logs/version_1/checkpoints/last.ckpt'
# nf_ckpt_path = '/home/huayuchen/Neurl-voxel/exps/neural_field/nf_qm9/20250911/lightning_logs/version_1/checkpoints/model-epoch=39.ckpt'
# fm_ckpt_path = '/home/huayuchen/Neurl-voxel/exps/funcmol/fm_qm9/20250917/lightning_logs/version_22/checkpoints/model-epoch=144.ckpt'

# TODO: set sample_idx by hand (only used in gt_only and gt_pred modes)
sample_idx = 3362  # 2,7,74,83,108,158,186,375,404,433
# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# TODO: choose where the codes come from (only used in denoiser_only mode)
# codes_source = 'load'  # 'load' or 'sample'
codes_source = 'sample'  # 'load' or 'sample'

# TODO: set the codes directory by hand (only used in denoiser_only mode with codes_source='load')
# e.g. '/datapool/data2/home/pxg/data/hyc/funcmol-main-neuralfield/exps/funcmol/fm_qm9/20251223/samples/20251223_version_2_last/molecule'
codes_idx = 0  # e.g. 0 selects code_0000_tanh.pt
codes_dir = '/datapool/data2/home/pxg/data/hyc/funcmol-main-neuralfield/exps/funcmol/fm_qm9/20251223/samples/20251223_version_2_last/molecule'

In [36]:
# Derive exp_name / ckpt_name / model_dir / output_dir from the checkpoint path of the
# model that drives the selected mode, and (for the ground-truth modes) load the sample.
if option == 'denoiser_only':
    # denoiser_only mode: paths are derived from the FuncMol checkpoint
    ckpt_parts = Path(fm_ckpt_path).parts
    if 'funcmol' not in ckpt_parts:
        raise ValueError(
            f"fm_ckpt_path must contain a 'funcmol' directory component: {fm_ckpt_path}"
        )
    funcmol_idx = ckpt_parts.index('funcmol')
    # The two components after 'funcmol' are "<experiment>/<run date>", e.g. fm_qm9/20250912
    exp_name = f"{ckpt_parts[funcmol_idx + 1]}/{ckpt_parts[funcmol_idx + 2]}"
    ckpt_name = Path(fm_ckpt_path).stem  # e.g. funcmol-epoch=319
    # Derived from project_root (like model_root) instead of repeating the absolute path.
    model_dir = os.path.join(str(project_root / "exps" / "funcmol"), exp_name)
    # NOTE(review): for .../version_2/checkpoints/last.ckpt this yields
    # "version_2_checkpoints_last"; confirm the "checkpoints" component is intended.
    output_dir = os.path.join(model_dir, "visualize", f"{Path(fm_ckpt_path).parent.parent.name}_{Path(fm_ckpt_path).parent.name}_{ckpt_name}")

    os.makedirs(output_dir, exist_ok=True)
    print(f"Option: {option}")
    print(f"FuncMol model directory: {model_dir}")
    print(f"FuncMol checkpoint: {ckpt_name}")
    print(f"Neural Field checkpoint: {nf_ckpt_path}")
    print(f"Output directory: {output_dir}")
else:
    # gt_only and gt_pred modes: paths are derived from the Neural Field checkpoint
    ckpt_parts = Path(nf_ckpt_path).parts
    if 'neural_field' not in ckpt_parts:
        raise ValueError(
            f"nf_ckpt_path must contain a 'neural_field' directory component: {nf_ckpt_path}"
        )
    neural_field_idx = ckpt_parts.index('neural_field')
    # The two components after 'neural_field' are "<experiment>/<run date>", e.g. nf_qm9/20250911
    exp_name = f"{ckpt_parts[neural_field_idx + 1]}/{ckpt_parts[neural_field_idx + 2]}"
    ckpt_name = Path(nf_ckpt_path).stem  # e.g. model-epoch=39
    model_dir = os.path.join(model_root, exp_name)
    output_dir = os.path.join(model_dir, ckpt_name)
    os.makedirs(output_dir, exist_ok=True)
    # Load the single sample selected by sample_idx together with its ground-truth atoms.
    batch, gt_coords, gt_types = prepare_data_with_sample_idx(config, device, sample_idx)
    print(f"Data loaded for sample {sample_idx}: {gt_coords.shape}, {gt_types.shape}")
    print(f"Option: {option}")
    print(f"Model directory: {model_dir}")
    print(f"Checkpoint: {ckpt_name}")
    print(f"Output directory: {output_dir}")

[ClusteringProcessor] enable_bond_validation = True
>> val set size: 20042
Data loaded for sample 3362: torch.Size([1, 13, 3]), torch.Size([1, 13])
Option: gt_pred
Model directory: /datapool/data2/home/pxg/data/hyc/funcmol-main-neuralfield/exps/neural_field/nf_qm9/20260123
Checkpoint: last
Output directory: /datapool/data2/home/pxg/data/hyc/funcmol-main-neuralfield/exps/neural_field/nf_qm9/20260123/last


In [37]:
## Load model, generate or load codes
# Produces, depending on `option`:
#   encoder / decoder  - neural-field networks (None in gt_only mode)
#   codes              - latent codes fed to the decoder (None in gt_only mode)
#   converter          - GNF converter used for field <-> molecule conversion
#   field_func         - callable built by create_field_function for the 1D field plots
if option == 'denoiser_only':
    # Fail early with a clear message, matching the check done in gt_pred mode.
    if not os.path.exists(nf_ckpt_path):
        raise FileNotFoundError(f"Checkpoint file not found: {nf_ckpt_path}")
    if not os.path.exists(fm_ckpt_path):
        raise FileNotFoundError(f"Checkpoint file not found: {fm_ckpt_path}")

    print(f"Loading Neural Field model from: {nf_ckpt_path}")
    encoder, decoder = load_neural_field(nf_ckpt_path, config)
    # Make sure the models live on the selected device
    encoder = encoder.to(device)
    decoder = decoder.to(device)
    encoder.eval()
    decoder.eval()

    # Load the FuncMol configuration from its YAML config
    funcmol_config = load_funcmol_config("train_fm_qm9", config)

    # Build the FuncMol model
    funcmol = FuncMol(funcmol_config)
    funcmol = funcmol.to(device)

    # Load the checkpoint and obtain code_stats
    funcmol, code_stats = load_checkpoint_fm(funcmol, fm_ckpt_path)
    funcmol.eval()

    # Hand the code statistics to the decoder
    decoder.set_code_stats(code_stats)

    print(">> FuncMol model loaded successfully!")
    configs_dir = project_root / "funcmol" / "configs"
    with hydra.initialize_config_dir(config_dir=str(configs_dir), version_base=None):
        sample_fm_config = hydra.compose(config_name="sample_fm")

    # Convert to a plain dict (mirrors sample_fm.py)
    config_dict = OmegaConf.to_container(sample_fm_config, resolve=True)

    # Build the converter from the sampling config (mirrors sample_fm.py).
    # This converter is the one kept for denoiser_only mode; it is no longer overwritten
    # by a converter built from the neural-field training config further down.
    converter = create_gnf_converter(config_dict)

    # Shape information for the codes, with the same fallbacks as sample_fm.py
    grid_size = config_dict.get('dset', {}).get('grid_size', 9)
    code_dim = config_dict.get('encoder', {}).get('code_dim', 128)

    if codes_source == 'load':
        mol_save_dir = Path(codes_dir)

        code_path = mol_save_dir / f"code_{codes_idx:04d}_tanh.pt"
        print(f"使用手动指定的 codes 目录:")
        print(f"  codes_dir: {codes_dir}")
        print(f"  文件名格式: code_{codes_idx:04d}_tanh.pt")
        print(f"\n最终使用的codes路径: {code_path}")

        if not code_path.exists():
            raise FileNotFoundError(
                f"Codes file not found: {code_path}\n"
                f"Please check if the file exists or verify the codes_dir path."
            )

        print(f"Loading codes from: {code_path}")
        # NOTE(review): torch.load unpickles the file; only load code files you produced yourself.
        codes = torch.load(code_path, map_location=device)
        # Ensure codes have a batch dimension: [1, grid_size^3, code_dim]
        if codes.dim() == 2:
            # A [grid_size^3, code_dim] tensor gets a leading batch dimension
            codes = codes.unsqueeze(0)
        print(f"Loaded codes shape: {codes.shape}")

    else:
        # Sample fresh codes with the denoiser
        print("Sampling codes using DDPM...")
        with torch.no_grad():
            codes = funcmol.sample_ddpm(shape=(1, grid_size**3, code_dim), progress=False)
        print(f"Sampled codes shape: {codes.shape}")

    # Unified field function (ddpm mode, using the loaded/sampled codes)
    field_func = create_field_function(
        mode='ddpm',
        decoder=decoder,
        codes=codes
    )
    print("Codes loaded/sampled and field_func set.")


elif option == 'gt_pred':
    # Use the manually specified checkpoint path
    if not os.path.exists(nf_ckpt_path):
        raise FileNotFoundError(f"Checkpoint file not found: {nf_ckpt_path}")

    print(f"Loading model from: {nf_ckpt_path}")
    encoder, decoder = load_neural_field(nf_ckpt_path, config)

    # Make sure the models live on the selected device
    encoder = encoder.to(device)
    decoder = decoder.to(device)
    encoder.eval()
    decoder.eval()

    # Encode the selected sample into codes
    print(f"Batch device: {batch.pos.device}")
    print(f"Encoder device: {next(encoder.parameters()).device}")
    print(f"Batch size (number of graphs): {batch.num_graphs}")
    with torch.no_grad():
        codes = encoder(batch)
    print(f"Codes shape: {codes.shape}")

    # Unified field function (predicted mode)
    field_func = create_field_function(
        mode='predicted',
        decoder=decoder,
        codes=codes
    )
else:  # gt only
    encoder, decoder = None, None
    codes = None

# Ground-truth modes build their converter from the neural-field config.
# denoiser_only already built one from the sample_fm config above and keeps it.
if option != 'denoiser_only':
    converter = create_gnf_converter(config)

# Print the key converter parameters
print(f"\n=== Converter 参数 ===")
print(f"step_size: {converter.step_size}")
print(f"eps: {converter.eps}")
print(f"min_samples: {converter.min_samples}")
# field parameters
print(f"field_variance_k_neighbors: {converter.field_variance_k_neighbors}")
print(f"field_variance_weight: {converter.field_variance_weight}")

# gt_only needs the converter, so its field function is created only now
if option == 'gt_only':
    field_func = create_field_function(
        mode='gt',
        converter=converter,
        gt_coords=gt_coords,
        gt_types=gt_types
    )
print(f"Model loaded successfully!")

Loading model from: /datapool/data2/home/pxg/data/hyc/funcmol-main-neuralfield/exps/neural_field/nf_qm9/20260123/lightning_logs/version_1/checkpoints/last.ckpt
Loading Lightning checkpoint from: /datapool/data2/home/pxg/data/hyc/funcmol-main-neuralfield/exps/neural_field/nf_qm9/20260123/lightning_logs/version_1/checkpoints/last.ckpt
>> loaded dec
>> loaded enc
Model loaded successfully!
Batch device: cuda:0
Encoder device: cuda:0
Batch size (number of graphs): 1
Codes shape: torch.Size([1, 125, 384])
[ClusteringProcessor] enable_bond_validation = True

=== Converter 参数 ===
step_size: 0.015
eps: 0.1
min_samples: 3
field_variance_k_neighbors: 10
field_variance_weight: 0.01
Model loaded successfully!


Plot field: 1D gradient-field curves along the x-axis (y = 0, z = 0)

In [38]:
def _format_stat(value):
    """Format a statistic with six decimals; fall back to str() for placeholders such as 'N/A'."""
    try:
        return f"{value:.6f}"
    except (TypeError, ValueError):
        # A missing statistic is reported as the string 'N/A', which has no float format.
        return str(value)


# Plot every atom-type channel in both modes.
atom_types = [0, 1, 2, 3, 4]  # C, H, O, N, F

if option != 'denoiser_only':
    # Compare the 1D gradient field against ground truth (all atom types).
    # Results go to <model_dir>/recon_animation so that no per-checkpoint directory is needed.
    field_1d_output_dir = os.path.join(model_dir, "recon_animation")
    os.makedirs(field_1d_output_dir, exist_ok=True)
    save_path = os.path.join(field_1d_output_dir, f"field_1d_sample_{sample_idx}")

    gradient_results = visualize_1d_gradient_field_comparison(
        gt_coords=gt_coords,
        gt_types=gt_types,
        converter=converter,
        field_func=field_func,
        sample_idx=0,  # the loaded batch holds a single sample, so index 0
        atom_types=atom_types,  # a list is accepted; no loop needed
        x_range=None,
        y_coord=0.0,
        z_coord=0.0,
        save_path=save_path,
        display_sample_idx=sample_idx,  # original dataset index, used for file names and titles
    )

    if gradient_results:
        print(f"Gradient field comparison (model: {model_dir}):")
        print(f"  Available atom types: {gradient_results['available_atom_types']}")

        # Per-atom-type error statistics
        for atom_name, stats in gradient_results['all_results'].items():
            print(f"  {atom_name}: MSE={stats['mse']:.6f}, MAE={stats['mae']:.6f}")
            print(f"    Saved to: {stats['save_path']}")

else:
    # denoiser_only: plot the 1D gradient field decoded from the denoiser codes (no ground truth)
    print("\n=== 可视化1D梯度场（仅预测） ===")

    # File index: codes_idx for loaded codes, 0 for freshly sampled ones
    if codes_source == 'load':
        field1d_idx = codes_idx
    else:
        field1d_idx = 0

    save_path = os.path.join(output_dir, f"field1d_gen_sample_{field1d_idx}")

    # No gt_coords / gt_types are passed, so only the predicted field is drawn
    gradient_results = visualize_1d_gradient_field_comparison(
        gt_coords=None,  # no ground truth
        gt_types=None,   # no ground truth
        converter=None,  # not needed without ground truth
        field_func=field_func,
        sample_idx=0,
        atom_types=atom_types,
        x_range=None,  # let the visualizer choose its default range
        y_coord=0.0,
        z_coord=0.0,
        save_path=save_path,
        display_sample_idx=field1d_idx,  # identifier used in file names
    )

    if gradient_results:
        print(f"Gradient field visualization (generation mode):")
        print(f"  Available atom types: {gradient_results['available_atom_types']}")

        # Per-atom-type magnitude statistics; missing values print as N/A instead of raising
        for atom_name, stats in gradient_results['all_results'].items():
            print(f"  {atom_name}:")
            print(f"    Magnitude: Mean={_format_stat(stats.get('magnitude_mean', 'N/A'))}, Std={_format_stat(stats.get('magnitude_std', 'N/A'))}")
            print(f"    Saved to: {stats['save_path']}")

自动计算 x 轴范围: (-3.712660026550293, 3.712659788131714)
Decoder输出的场形状: torch.Size([3000, 5, 3]) (n_points=3000, n_atom_types=5, dim=3)
Field 1D comparison (atom_type=C) saved to: /datapool/data2/home/pxg/data/hyc/funcmol-main-neuralfield/exps/neural_field/nf_qm9/20260123/recon_animation/field_1d_sample_3362_atom_C.png
Field 1D comparison (atom_type=H) saved to: /datapool/data2/home/pxg/data/hyc/funcmol-main-neuralfield/exps/neural_field/nf_qm9/20260123/recon_animation/field_1d_sample_3362_atom_H.png
Field 1D comparison (atom_type=O) saved to: /datapool/data2/home/pxg/data/hyc/funcmol-main-neuralfield/exps/neural_field/nf_qm9/20260123/recon_animation/field_1d_sample_3362_atom_O.png
Field 1D comparison (atom_type=N) saved to: /datapool/data2/home/pxg/data/hyc/funcmol-main-neuralfield/exps/neural_field/nf_qm9/20260123/recon_animation/field_1d_sample_3362_atom_N.png
Field 1D comparison (atom_type=F) saved to: /datapool/data2/home/pxg/data/hyc/funcmol-main-neuralfield/exps/neural_field/nf_qm9/2

In [39]:
# 临时修改可视化函数：采样点大小改为两倍，去掉右上角图注
import matplotlib.pyplot as plt
from funcmol.utils.gnf_visualizer import ATOM_COLORS, _get_max_atom_type, _get_atom_names, _setup_3d_axis
from typing import Optional, Dict
import torch

def visualize_generation_step_modified(
    current_points: torch.Tensor,
    iteration: int,
    save_path: str,
    current_types: Optional[torch.Tensor],
    fixed_axis_limits: Optional[Dict] = None):
    """Render one frame of the generation process as a legend-free 3D scatter plot.

    Temporary replacement for ``gnf_visualizer.visualize_generation_step`` that draws
    larger sample markers and omits the legend.

    Args:
        current_points: Point coordinates; columns 0, 1, 2 are plotted as x, y, z.
        iteration: Iteration number shown in the title.
        save_path: Image file to write; its parent directory is created if needed.
        current_types: Per-point atom-type indices used for colouring, or ``None``
            to draw every point in blue.
        fixed_axis_limits: Optional dict with ``x_min``/``x_max``/``y_min``/``y_max``/
            ``z_min``/``z_max``; when ``None`` the axes are fitted to the points.
    """
    # One marker size for both branches, matching the reconstruction view.
    marker_size = 120

    fig = plt.figure(figsize=(8, 8))
    ax = fig.add_subplot(111, projection='3d')

    points_np = current_points.detach().cpu().numpy()
    if current_types is not None:
        types_np = current_types.detach().cpu().numpy()
        n_atom_types = _get_max_atom_type(current_types)
        atom_names = _get_atom_names(n_atom_types)
        for atom_type in range(n_atom_types):
            mask = (types_np == atom_type)
            if mask.sum() > 0:
                color = ATOM_COLORS.get(atom_type, 'gray')
                atom_name = atom_names[atom_type] if atom_type < len(atom_names) else f"Type{atom_type}"
                ax.scatter(points_np[mask, 0], points_np[mask, 1], points_np[mask, 2],
                           c=color, marker='.', s=marker_size,
                           label=atom_name, alpha=0.5)
    else:
        ax.scatter(points_np[:, 0], points_np[:, 1], points_np[:, 2],
                   c='blue', marker='.', s=marker_size, label='Generated Points', alpha=0.5)

    # Axis setup: fixed limits keep consecutive frames comparable
    if fixed_axis_limits is not None:
        ax.set_xlim(fixed_axis_limits['x_min'], fixed_axis_limits['x_max'])
        ax.set_ylim(fixed_axis_limits['y_min'], fixed_axis_limits['y_max'])
        ax.set_zlim(fixed_axis_limits['z_min'], fixed_axis_limits['z_max'])
        ax.set_xlabel('X (Å)')
        ax.set_ylabel('Y (Å)')
        ax.set_zlabel('Z (Å)')
        ax.grid(True, alpha=0.3)
    elif len(points_np) > 0:
        _setup_3d_axis(ax, [points_np], margin=1.0)

    ax.view_init(elev=30, azim=60)
    ax.set_box_aspect([1, 1, 1])

    ax.set_title(f"Generated Molecule - Iteration {iteration}")
    # The legend is intentionally not drawn in this temporary version.

    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    # Save and close this specific figure rather than whatever pyplot considers current.
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close(fig)

def visualize_reconstruction_step_modified(
    coords: torch.Tensor,
    current_points: torch.Tensor,
    iteration: int,
    save_path: str,
    coords_types: Optional[torch.Tensor] = None,
    points_types: Optional[torch.Tensor] = None,
    fixed_axis_limits: Optional[Dict] = None):
    """Draw the original atoms and the current sample points in one legend-free 3D frame.

    Temporary replacement for ``gnf_visualizer.visualize_reconstruction_step`` with
    larger sample markers and no legend.

    Args:
        coords: Original atom coordinates (drawn as 'o' markers).
        current_points: Current sample points (drawn as '.' markers).
        iteration: Iteration number shown in the title.
        save_path: Image file to write; a missing parent directory is created.
        coords_types: Optional atom-type indices for ``coords``.
        points_types: Optional atom-type indices for ``current_points``.
        fixed_axis_limits: Optional dict of ``{x,y,z}_{min,max}`` axis bounds.
    """
    fig = plt.figure(figsize=(8, 8))
    ax = fig.add_subplot(111, projection='3d')

    def _type_count():
        # Number of type channels to iterate; a missing type tensor contributes 8.
        from_coords = _get_max_atom_type(coords_types) if coords_types is not None else 8
        from_points = _get_max_atom_type(points_types) if points_types is not None else 8
        return max(from_coords, from_points)

    def _scatter_by_type(xyz, type_ids, marker, size, alpha, label_prefix):
        # One scatter call per atom type that actually occurs in type_ids.
        total_types = _type_count()
        names = _get_atom_names(total_types)
        for type_id in range(total_types):
            selected = (type_ids == type_id)
            if selected.sum() <= 0:
                continue
            name = names[type_id] if type_id < len(names) else f"Type{type_id}"
            ax.scatter(
                xyz[selected, 0], xyz[selected, 1], xyz[selected, 2],
                c=ATOM_COLORS.get(type_id, 'gray'),
                marker=marker,
                s=size,
                label=f'{label_prefix} {name}',
                alpha=alpha,
            )

    # Original atoms
    coords_np = coords.detach().cpu().numpy()
    if coords_types is None:
        ax.scatter(coords_np[:, 0], coords_np[:, 1], coords_np[:, 2],
                   c='blue', marker='o', s=100, label='Original', alpha=0.8)
    else:
        _scatter_by_type(coords_np, coords_types.detach().cpu().numpy(),
                         marker='o', size=100, alpha=0.7, label_prefix='Original')

    # Current sample points
    points_np = current_points.detach().cpu().numpy()
    if points_types is None:
        ax.scatter(points_np[:, 0], points_np[:, 1], points_np[:, 2],
                   c='red', marker='.', s=120, label='Current Points', alpha=0.5)
    else:
        _scatter_by_type(points_np, points_types.detach().cpu().numpy(),
                         marker='.', size=120, alpha=0.5, label_prefix='Current')

    # Axes: either the caller's fixed bounds or bounds fitted to both point sets
    if fixed_axis_limits is None:
        _setup_3d_axis(ax, [coords_np, points_np], margin=1.0)
    else:
        ax.set_xlim(fixed_axis_limits['x_min'], fixed_axis_limits['x_max'])
        ax.set_ylim(fixed_axis_limits['y_min'], fixed_axis_limits['y_max'])
        ax.set_zlim(fixed_axis_limits['z_min'], fixed_axis_limits['z_max'])
        ax.set_xlabel('X (Å)')
        ax.set_ylabel('Y (Å)')
        ax.set_zlabel('Z (Å)')
        ax.grid(True, alpha=0.3)

    ax.view_init(elev=30, azim=60)
    ax.set_box_aspect([1, 1, 1])

    ax.set_title(f"Iteration {iteration}")
    # No legend is drawn in this temporary version.

    target_dir = os.path.dirname(save_path)
    if target_dir and not os.path.exists(target_dir):
        os.makedirs(target_dir, exist_ok=True)
        print(f"Created directory: {target_dir}")

    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()

# Override the original functions by rebinding the module attributes of gnf_visualizer.
# NOTE(review): this only affects callers that look the functions up on the module at call
# time; names already imported with `from ... import` elsewhere keep the originals — confirm.
# NOTE(review): with `%autoreload 2` active, a reload of gnf_visualizer presumably restores
# the originals, so re-run this cell after editing that module — confirm.
from funcmol.utils import gnf_visualizer
gnf_visualizer.visualize_generation_step = visualize_generation_step_modified
gnf_visualizer.visualize_reconstruction_step = visualize_reconstruction_step_modified

print("已临时修改可视化函数：采样点大小改为两倍，去掉右上角图注")

已临时修改可视化函数：采样点大小改为两倍，去掉右上角图注


In [40]:
if option == 'denoiser_only':
    # denoiser_only: turn the codes prepared by the model-loading cell into a molecule,
    # saving animation frames, a final image and an SDF file.
    print("\n=== 执行 DDPM 分子生成 ===")

    # Kept so these notebook-level names reflect the neural-field config after this cell.
    grid_size = config.dset.grid_size
    code_dim = config.encoder.code_dim

    # File index: codes_idx for loaded codes, 0 for freshly sampled ones
    if codes_source == 'load':
        file_idx = codes_idx
    else:
        file_idx = 0

    # Visualization callback provided by gnf_visualizer; it also returns the frame list
    visualization_callback, frame_paths, fixed_axis_limits_dict = create_visualization_callback(
        output_dir=output_dir,
        frame_prefix=f"frame_gen_sample_{file_idx}",
        codes_device=codes.device,
        n_atom_types=5
    )

    # Reconstruct the molecule from the codes of the previous cell, with visualization
    print("Generating molecular field and reconstructing molecule (loaded codes)...")
    save_interval = 100

    # Arguments for converter.gnf2mol; clustering-history options are added below if enabled
    gnf2mol_kwargs = {
        "decoder": decoder,
        "codes": codes,
        "save_interval": save_interval,
        "visualization_callback": visualization_callback,
        "enable_timing": True
    }

    # When the converter records clustering history, have it written to disk
    if converter.enable_clustering_history:
        clustering_history_dir = os.path.join(output_dir, "clustering_history")
        os.makedirs(clustering_history_dir, exist_ok=True)
        gnf2mol_kwargs["save_clustering_history"] = True
        gnf2mol_kwargs["clustering_history_dir"] = clustering_history_dir
        gnf2mol_kwargs["sample_id"] = file_idx  # file_idx doubles as the sample identifier
        print(f"聚类历史将保存到: {clustering_history_dir}")
        print(f"  - Converter enable_clustering_history: {converter.enable_clustering_history}")
        print(f"  - 样本标识符 (sample_id): {file_idx}")
    else:
        print(f"警告: Converter未启用聚类历史记录 (enable_clustering_history={converter.enable_clustering_history})")

    recon_coords, recon_types = converter.gnf2mol(**gnf2mol_kwargs)

    # Report whether the clustering-history files were actually produced
    if converter.enable_clustering_history and gnf2mol_kwargs.get("save_clustering_history", False):
        clustering_history_dir = gnf2mol_kwargs.get("clustering_history_dir")
        if clustering_history_dir:
            sdf_file = os.path.join(clustering_history_dir, f"sample_{file_idx:04d}_clustering_history.sdf")
            txt_file = os.path.join(clustering_history_dir, f"sample_{file_idx:04d}_clustering_history.txt")
            if os.path.exists(sdf_file):
                print(f"✓ 聚类历史SDF文件已生成: {sdf_file}")
            else:
                print(f"✗ 聚类历史SDF文件未生成: {sdf_file}")
            if os.path.exists(txt_file):
                print(f"✓ 聚类历史TXT文件已生成: {txt_file}")
            else:
                print(f"✗ 聚类历史TXT文件未生成: {txt_file}")

    print(f"Generated molecule: {recon_coords[0].shape[0]} atoms")
    print(f"Atom types: {recon_types[0].unique().tolist()}")

    # Build the GIF from the saved frames (gnf_visualizer helper)
    print("Creating generation process animation from saved frames...")
    gif_path = os.path.join(output_dir, f"funcmol_gen_sample_{file_idx}.gif")
    create_gif_from_frames(
        frame_paths=frame_paths,
        gif_path=gif_path,
        duration=0.1,
        fps=15,
        loop=1,
        cleanup_frames=True
    )

    # Save the final generated molecule
    final_path = os.path.join(output_dir, f"funcmol_gen_sample_{file_idx}_final.png")
    # Drop padded atoms (type -1).
    # NOTE(review): ground-truth padding elsewhere uses PADDING_INDEX; confirm that
    # gnf2mol pads its output with -1.
    valid_mask = recon_types[0] != -1
    if valid_mask.any():
        final_coords_valid = recon_coords[0][valid_mask]
        final_types_valid = recon_types[0][valid_mask]
        visualize_generated_molecule(
            final_coords_valid, final_types_valid, save_path=final_path
        )

        # Also save the molecule in SDF format
        from funcmol.utils.utils_base import xyz_to_sdf
        from funcmol.utils.constants import ELEMENTS_HASH_INV

        sdf_path = os.path.join(output_dir, f"funcmol_gen_sample_{file_idx}.sdf")
        try:
            # Element symbols for every type index up to the largest one present
            max_atom_type = int(final_types_valid.max().item()) if len(final_types_valid) > 0 else 7
            element_list = [ELEMENTS_HASH_INV.get(i, f"X{i}") for i in range(max_atom_type + 1)]

            # Convert to numpy arrays
            final_coords_np = final_coords_valid.detach().cpu().numpy()
            final_types_np = final_types_valid.detach().cpu().numpy()

            # Build the SDF string and write it out
            sdf_string = xyz_to_sdf(final_coords_np, final_types_np, element_list)
            if sdf_string:
                with open(sdf_path, 'w', encoding='utf-8') as sdf_file:
                    sdf_file.write(sdf_string)
                print(f"SDF文件已保存: {sdf_path}")
            else:
                print(f"警告: 无法生成SDF文件（无有效原子）")
                sdf_path = None
        except Exception as e:
            # Best effort: an SDF failure must not discard the images produced above
            print(f"警告: 保存SDF文件时出错: {e}")
            sdf_path = None
    else:
        print("Warning: No valid atoms generated")
        sdf_path = None

    # Plain-int type histogram; zipping the tensors directly would print tensor-keyed entries
    type_values, type_counts = torch.unique(recon_types[0], return_counts=True)
    type_distribution = dict(zip(type_values.tolist(), type_counts.tolist()))

    print(f"\n=== DDPM Field 生成结果 ===")
    print(f"Generated atoms: {recon_coords[0].shape[0]}")
    print(f"Atom type distribution: {type_distribution}")
    print(f"最终分子图: {final_path}")
    print(f"生成过程动画: {gif_path}")
    if sdf_path:
        print(f"SDF文件: {sdf_path}")

else:
    class DummyDecoder:
        """Decoder stand-in that ignores the codes and returns the ground-truth field."""

        def __init__(self, converter, gt_coords, gt_types):
            self.converter = converter
            self.gt_coords = gt_coords
            self.gt_types = gt_types

        def __call__(self, query_points, codes):
            # A batch dimension is added because the atoms are stored unbatched here
            return self.converter.mol2gnf(
                self.gt_coords.unsqueeze(0),
                self.gt_types.unsqueeze(0),
                query_points
            )

    # Which reconstructions to run for the selected option
    if option == 'gt_only':
        rec_list = ['gt_field']
    else:
        rec_list = ['predicted_field', 'gt_field']

    # All reconstruction animations go to <model_dir>/recon_animation
    recon_output_dir = os.path.join(model_dir, "recon_animation")
    os.makedirs(recon_output_dir, exist_ok=True)
    visualizer = GNFVisualizer(recon_output_dir)

    # Run the visualization once per reconstruction type
    for rec_type in rec_list:
        print(f"\n=== 执行 {rec_type} 重建 ===")

        # Choose decoder and codes for this reconstruction type
        if rec_type == 'gt_field':
            # gt_field: dummy decoder plus placeholder codes.
            # The random draw stays at this point so RNG consumption order is unchanged.
            grid_size = config.dset.grid_size
            code_dim = config.encoder.code_dim
            dummy_codes = torch.randn(1, grid_size**3, code_dim, device=gt_coords.device)

            # Drop padded atoms before handing the ground truth to the dummy decoder
            gt_mask = (gt_types[0] != PADDING_INDEX)
            gt_valid_coords_for_decoder = gt_coords[0][gt_mask]
            gt_valid_types_for_decoder = gt_types[0][gt_mask]

            dummy_decoder = DummyDecoder(converter, gt_valid_coords_for_decoder, gt_valid_types_for_decoder)
            rec_decoder = dummy_decoder
            rec_codes = dummy_codes
        else:  # predicted_field
            # predicted_field: use the trained decoder and the encoder's codes directly
            rec_decoder = decoder
            rec_codes = codes

        # Run the reconstruction animation (driven through the converter and decoder)
        results = visualizer.create_reconstruction_animation(
            gt_coords=gt_coords,
            gt_types=gt_types,
            converter=converter,
            decoder=rec_decoder,
            codes=rec_codes,
            save_interval=100,
            animation_name=f"recon_sample_{sample_idx}_{rec_type}",
            sample_idx=sample_idx
        )

        print(f"\n=== {rec_type} 重建结果 ===")
        print(f"RMSD: {results['final_rmsd']:.4f}")
        print(f"Reconstruction Loss: {results['final_loss']:.4f}")
        print(f"KL Divergence (orig->recon): {results['final_kl_1to2']:.4f}")
        print(f"KL Divergence (recon->orig): {results['final_kl_2to1']:.4f}")
        print(f"GIF动画: {results['gif_path']}")
        print(f"对比图: {results['comparison_path']}")
        if results.get('sdf_path'):
            print(f"SDF文件: {results['sdf_path']}")


=== 执行 predicted_field 重建 ===

Starting reconstruction for molecule 3362
Ground truth atoms: 13

[Iteration 0] 原子类型: C, min_samples=3
  处理簇数: 5 (新簇: 5, 待重试: 0)
  结果: ✓通过=5, ✗拒绝=0
  当前参考点: 5 个, 类型分布: {'C': 5}

[Iteration 0] 原子类型: H, min_samples=3
  处理簇数: 4 (新簇: 4, 待重试: 0)
  结果: ✓通过=4, ✗拒绝=0
  当前参考点: 9 个, 类型分布: {'C': 5, 'H': 4}

[Iteration 0] 原子类型: O, min_samples=3
  处理簇数: 2 (新簇: 2, 待重试: 0)
  结果: ✓通过=2, ✗拒绝=0
  当前参考点: 11 个, 类型分布: {'C': 5, 'H': 4, 'O': 2}

[Iteration 0] 原子类型: N, min_samples=3
  处理簇数: 1 (新簇: 1, 待重试: 0)
  结果: ✓通过=1, ✗拒绝=0
  当前参考点: 12 个, 类型分布: {'C': 5, 'H': 4, 'O': 2, 'N': 1}

[Iteration 0] 原子类型: F, min_samples=3
  处理簇数: 1 (新簇: 1, 待重试: 0)
  结果: ✓通过=1, ✗拒绝=0
  当前参考点: 13 个, 类型分布: {'C': 5, 'H': 4, 'O': 2, 'N': 1, 'F': 1}
SDF文件已保存: /datapool/data2/home/pxg/data/hyc/funcmol-main-neuralfield/exps/neural_field/nf_qm9/20260123/recon_animation/recon/recon_sample_3362_predicted_field.sdf

=== predicted_field 重建结果 ===
RMSD: 0.0309
Reconstruction Loss: 0.0360
KL Divergence (orig->recon