In [57]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [58]:
import os
import sys
import torch
import hydra
from pathlib import Path
from lightning import Fabric

# torch.compile compatibility: when the dynamo module can be imported, tell it
# to suppress its errors; otherwise just report that it is unavailable.
try:
    import torch._dynamo
except ImportError:
    # PyTorch versions older than 2.0 do not ship torch._dynamo.
    print("Warning: torch._dynamo not available in this PyTorch version")
else:
    torch._dynamo.config.suppress_errors = True

## Set up the environment
# The parent of the current working directory is treated as the project root
# and put first on sys.path so that the `funcmol` package can be imported.
project_root = Path.cwd().parent
sys.path.insert(0, os.fspath(project_root))

from funcmol.utils.constants import PADDING_INDEX
from funcmol.utils.gnf_visualizer import (
    load_config, load_model, create_converter, 
    prepare_data, prepare_data_with_sample_idx, visualize_1d_gradient_field_comparison,
    GNFVisualizer
)

# Root directory of the neural-field experiments. The FUNCMOL_NF_MODEL_ROOT
# environment variable overrides it so the notebook can run on another machine
# without editing this cell; when unset, the original cluster path is used.
model_root = os.environ.get(
    "FUNCMOL_NF_MODEL_ROOT",
    "/datapool/data3/storage/pengxingang/pxg/hyc/funcmol-main-neuralfield/exps/neural_field",
)

In [59]:
# TODO: choose the mode by hand: 'gt_only', 'gt_pred' or 'denoiser_only'.
option = 'denoiser_only'  # 'gt_only', 'gt_pred', 'denoiser_only'

# TODO: set the checkpoint paths by hand; exp_name is extracted from the path.
nf_ckpt_path = '/datapool/data3/storage/pengxingang/pxg/hyc/funcmol-main-neuralfield/exps/neural_field/nf_qm9/20250911/lightning_logs/version_1/checkpoints/model-epoch=39.ckpt'
fm_ckpt_path = '/datapool/data3/storage/pengxingang/pxg/hyc/funcmol-main-neuralfield/exps/funcmol/fm_qm9/20250912/lightning_logs/version_0/checkpoints/last.ckpt'

# TODO: set sample_idx by hand (only used by the gt_only and gt_pred modes).
sample_idx = 1
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fail fast on a typo instead of silently falling through to the gt branch.
VALID_OPTIONS = ('gt_only', 'gt_pred', 'denoiser_only')
if option not in VALID_OPTIONS:
    raise ValueError(f"Unknown option {option!r}; expected one of {VALID_OPTIONS}")


def resolve_experiment_dirs(ckpt_path, marker, experiments_root):
    """Derive experiment name and output locations from a checkpoint path.

    The two path components that follow the ``marker`` directory in
    ``ckpt_path`` form the experiment name (e.g. ``nf_qm9/20250911``), and the
    checkpoint file stem names the output sub-directory.

    Args:
        ckpt_path: Path to the checkpoint file.
        marker: Directory name to look for in ``ckpt_path``
            (``'funcmol'`` or ``'neural_field'`` in this notebook).
        experiments_root: Directory under which ``exp_name`` is resolved.

    Returns:
        Tuple ``(exp_name, ckpt_name, model_dir, output_dir)`` of strings.

    Raises:
        ValueError: if ``marker`` is not a component of ``ckpt_path`` or fewer
            than two components follow it.
    """
    parts = Path(ckpt_path).parts
    if marker not in parts:
        raise ValueError(f"'{marker}' is not a directory component of {ckpt_path}")
    marker_idx = parts.index(marker)
    if len(parts) <= marker_idx + 2:
        raise ValueError(f"Expected '<exp>/<date>' after '{marker}' in {ckpt_path}")
    exp_name = f"{parts[marker_idx + 1]}/{parts[marker_idx + 2]}"
    ckpt_name = Path(ckpt_path).stem
    model_dir = os.path.join(experiments_root, exp_name)
    output_dir = os.path.join(model_dir, ckpt_name)
    return exp_name, ckpt_name, model_dir, output_dir


if option == 'denoiser_only':
    # denoiser_only works from the FuncMol checkpoint. The FuncMol experiment
    # root is the 'funcmol' sibling of the neural-field root, which resolves
    # to the previously hard-coded path when model_root has its default value.
    funcmol_root = str(Path(model_root).parent / "funcmol")
    exp_name, ckpt_name, model_dir, output_dir = resolve_experiment_dirs(
        fm_ckpt_path, 'funcmol', funcmol_root
    )  # e.g. fm_qm9/20250912 and 'last'
    os.makedirs(output_dir, exist_ok=True)
    print(f"Option: {option}")
    print(f"FuncMol model directory: {model_dir}")
    print(f"FuncMol checkpoint: {ckpt_name}")
    print(f"Neural Field checkpoint: {nf_ckpt_path}")
    print(f"Output directory: {output_dir}")
else:
    # gt_only and gt_pred work from the Neural Field checkpoint.
    exp_name, ckpt_name, model_dir, output_dir = resolve_experiment_dirs(
        nf_ckpt_path, 'neural_field', model_root
    )  # e.g. nf_qm9/20250911 and 'model-epoch=39'
    os.makedirs(output_dir, exist_ok=True)
    print(f"Option: {option}")
    print(f"Model directory: {model_dir}")
    print(f"Checkpoint: {ckpt_name}")
    print(f"Output directory: {output_dir}")

Option: denoiser_only
FuncMol model directory: /datapool/data3/storage/pengxingang/pxg/hyc/funcmol-main-neuralfield/exps/funcmol/fm_qm9/20250912
FuncMol checkpoint: last
Neural Field checkpoint: /datapool/data3/storage/pengxingang/pxg/hyc/funcmol-main-neuralfield/exps/neural_field/nf_qm9/20250911/lightning_logs/version_1/checkpoints/model-epoch=39.ckpt
Output directory: /datapool/data3/storage/pengxingang/pxg/hyc/funcmol-main-neuralfield/exps/funcmol/fm_qm9/20250912/last


In [60]:
## Load data
# Single-device Fabric context in full 32-bit precision; accelerator and
# strategy are left for Lightning to choose.
fabric_settings = {
    "accelerator": "auto",
    "devices": 1,
    "precision": "32-true",
    "strategy": "auto",
}
fabric = Fabric(**fabric_settings)
fabric.launch()

# Read the experiment configuration by name through the project's load_config helper.
config = load_config("train_nf_qm9")

if option != 'denoiser_only':
    # Prepare the data that contains the requested sample.
    batch, gt_coords, gt_types = prepare_data_with_sample_idx(fabric, config, device, sample_idx)
    print(f"Data loaded for sample {sample_idx}: {gt_coords.shape}, {gt_types.shape}")
else:
    # denoiser_only generates from noise, so no dataset is loaded.
    batch = gt_coords = gt_types = None
    print("Denoiser-only mode: No dataset loading required")

Dataset directory: /datapool/data2/home/pxg/data/hyc/funcmol-main-neuralfield/funcmol/dataset/data
Config loaded successfully: train_nf_qm9
n_iter from converter config: 1000
Denoiser-only mode: No dataset loading required


In [61]:
print(f"\nProcessing model from: {model_dir}")


def _as_batched_points(points):
    """Give query points a leading batch axis.

    A 2-D tensor gets a new axis 0 via ``unsqueeze(0)``; a 3-D tensor is
    returned unchanged.

    Raises:
        ValueError: if ``points`` is neither 2-D nor 3-D.
    """
    if points.dim() == 2:
        return points.unsqueeze(0)
    if points.dim() == 3:
        return points
    raise ValueError(f"Unexpected points shape: {points.shape}")


def _drop_batch_axis(result):
    """Return ``result[0]`` when ``result`` is 4-D, otherwise ``result`` unchanged."""
    return result[0] if result.dim() == 4 else result


def _require_checkpoint(ckpt_path):
    """Raise ``FileNotFoundError`` when ``ckpt_path`` does not exist on disk."""
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"Checkpoint file not found: {ckpt_path}")


def _load_neural_field(ckpt_path):
    """Load the encoder/decoder pair, wrap both with Fabric and switch to eval mode.

    Returns:
        Tuple ``(encoder, decoder)``.
    """
    enc, dec = load_model(fabric, config, model_path=ckpt_path)
    enc = fabric.setup_module(enc)
    dec = fabric.setup_module(dec)
    enc.eval()
    dec.eval()
    return enc, dec


## Load model
if option == 'denoiser_only':
    # Both the Neural Field model and the FuncMol denoiser are needed here.
    from funcmol.models.funcmol import FuncMol
    from funcmol.utils.utils_fm import load_checkpoint_fm

    # Check both checkpoints before any loading starts, as the gt_pred branch does.
    _require_checkpoint(nf_ckpt_path)
    _require_checkpoint(fm_ckpt_path)

    # Neural Field model
    print(f"Loading Neural Field model from: {nf_ckpt_path}")
    encoder, decoder = _load_neural_field(nf_ckpt_path)

    # FuncMol model
    print(f"Loading FuncMol model from: {fm_ckpt_path}")
    # NOTE(review): these denoiser hyper-parameters are written out by hand;
    # presumably they must match the ones the checkpoint was trained with.
    funcmol_config = {
        "smooth_sigma": 0.5,
        "denoiser": {
            "use_gnn": True,
            "n_hidden_units": 256,
            "num_blocks": 4,
            "dropout": 0.1,
            "k_neighbors": 8,
            "cutoff": 5.0,
            "radius": 3.0,
            "use_radius_graph": True
        },
        "encoder": config.encoder,
        "decoder": config.decoder,
        "dset": config.dset
    }
    funcmol = FuncMol(funcmol_config, fabric)
    # Move to the notebook-wide `device` instead of calling .cuda(), so the
    # cell also runs on hosts without a GPU (device falls back to CPU).
    funcmol = funcmol.to(device)
    # The second return value is not used; it is not bound to `_` so that
    # IPython's last-output variable stays intact.
    funcmol, _unused = load_checkpoint_fm(funcmol, fm_ckpt_path, fabric=fabric)
    funcmol.eval()

    def denoiser_field_func(points):
        """Evaluate the decoder field of a freshly denoised random code at ``points``.

        NOTE(review): new noise is drawn on every call, so successive calls do
        not describe the same field — confirm this is intended.

        Raises:
            ValueError: if ``points`` is neither 2-D nor 3-D.
        """
        points = _as_batched_points(points)

        grid_size = config.dset.grid_size
        code_dim = config.encoder.code_dim
        batch_size = 1

        # Random noise codes on the same device as the query points.
        noise_codes = torch.randn(batch_size, grid_size**3, code_dim, device=points.device)

        # Run the noise through the denoiser without tracking gradients.
        with torch.no_grad():
            denoised_codes = funcmol(noise_codes)

        return _drop_batch_axis(decoder(points, denoised_codes[0:1]))

    field_func = denoiser_field_func
    codes = None  # no pre-computed codes in denoiser mode

elif option == 'gt_pred':
    # Use the manually specified Neural Field checkpoint.
    _require_checkpoint(nf_ckpt_path)

    print(f"Loading model from: {nf_ckpt_path}")
    encoder, decoder = _load_neural_field(nf_ckpt_path)

    # Encode the prepared batch into codes.
    print(f"Batch device: {batch.pos.device}")
    print(f"Encoder device: {next(encoder.parameters()).device}")
    with torch.no_grad():
        codes = encoder(batch)

    def predicted_field_func(points):
        """Evaluate the decoder field for the encoded sample at ``points``.

        Raises:
            ValueError: if ``points`` is neither 2-D nor 3-D.
        """
        points = _as_batched_points(points)
        # Only the first code is decoded; the batch was prepared for one sample.
        return _drop_batch_axis(decoder(points, codes[0:1]))

    field_func = predicted_field_func

else:  # gt only
    encoder, decoder = None, None

    def gt_field_func(points):
        """Evaluate the ground-truth field of the loaded sample at ``points``.

        Padding atoms (type equal to ``PADDING_INDEX``) are masked out before
        the remaining atoms are passed to ``converter.mol2gnf``.

        Raises:
            ValueError: if ``points`` is neither 2-D nor 3-D.
        """
        points = _as_batched_points(points)

        # Index 0: the batch was prepared for a single sample.
        gt_mask = (gt_types[0] != PADDING_INDEX)
        gt_valid_coords = gt_coords[0][gt_mask]
        gt_valid_types = gt_types[0][gt_mask]

        result = converter.mol2gnf(
            gt_valid_coords.unsqueeze(0),
            gt_valid_types.unsqueeze(0),
            points
        )
        return _drop_batch_axis(result)

    field_func = gt_field_func
    codes = None

# Created after the closures above; gt_field_func looks `converter` up at call
# time, so defining it here is sufficient.
converter = create_converter(config, device)
print("Model loaded successfully!")


Processing model from: /datapool/data3/storage/pengxingang/pxg/hyc/funcmol-main-neuralfield/exps/funcmol/fm_qm9/20250912
Loading Neural Field model from: /datapool/data3/storage/pengxingang/pxg/hyc/funcmol-main-neuralfield/exps/neural_field/nf_qm9/20250911/lightning_logs/version_1/checkpoints/model-epoch=39.ckpt
Loading Lightning checkpoint from: /datapool/data3/storage/pengxingang/pxg/hyc/funcmol-main-neuralfield/exps/neural_field/nf_qm9/20250911/lightning_logs/version_1/checkpoints/model-epoch=39.ckpt
Model loaded successfully!
Loading FuncMol model from: /datapool/data3/storage/pengxingang/pxg/hyc/funcmol-main-neuralfield/exps/funcmol/fm_qm9/20250912/lightning_logs/version_0/checkpoints/last.ckpt
>> loaded denoiser
>> loaded model trained for 1284 epochs
GNF Converter created with n_iter: 1000, gradient_field_method: tanh, n_atom_types: 5
Model loaded successfully!


In [62]:
if option == 'denoiser_only':
    print("Denoiser-only mode: Skipping gradient field comparison")
else:
    # 1D gradient-field comparison across every atom type.
    atom_types = [0, 1, 2, 3, 4]  # C, H, O, N, F
    save_path = os.path.join(output_dir, f"field1d_sample_{sample_idx}")

    comparison_kwargs = dict(
        gt_coords=gt_coords,
        gt_types=gt_types,
        converter=converter,
        field_func=field_func,
        sample_idx=0,  # index inside the loaded data, which holds a single sample
        atom_types=atom_types,  # passed as a list, so no loop is needed here
        x_range=None,
        y_coord=0.0,
        z_coord=0.0,
        save_path=save_path,  # already carries the original sample_idx
        display_sample_idx=sample_idx,  # original sample index, for file names and display
    )
    gradient_results = visualize_1d_gradient_field_comparison(**comparison_kwargs)

    if gradient_results:
        print(f"Gradient field comparison (model: {model_dir}):")
        print(f"  Available atom types: {gradient_results['available_atom_types']}")

        # Per-atom-type error statistics and where each plot was written.
        per_atom_results = gradient_results['all_results']
        for atom_name in per_atom_results:
            stats = per_atom_results[atom_name]
            print(f"  {atom_name}: MSE={stats['mse']:.6f}, MAE={stats['mae']:.6f}")
            print(f"    Saved to: {stats['save_path']}")

Denoiser-only mode: Skipping gradient field comparison


In [63]:
if option == 'denoiser_only':
    # denoiser_only: generate one molecule from noise and visualise it.
    print("\n=== 执行 denoiser_field 分子生成 ===")

    from funcmol.utils.gnf_visualizer import visualize_single_molecule

    # Random noise codes for a single molecule.
    grid_size = config.dset.grid_size
    code_dim = config.encoder.code_dim
    batch_size = 1
    noise_codes = torch.randn(batch_size, grid_size**3, code_dim, device=device)

    # Denoise the codes without tracking gradients.
    with torch.no_grad():
        denoised_codes = funcmol(noise_codes)

    # Decode the field and reconstruct the molecule from it.
    print("Generating molecular field and reconstructing molecule...")
    recon_coords, recon_types = converter.gnf2mol(
        decoder,
        denoised_codes,
        fabric=fabric
    )

    print(f"Generated molecule: {recon_coords[0].shape[0]} atoms")
    print(f"Atom types: {recon_types[0].unique().tolist()}")

    # Still image of the generated molecule.
    print("Creating single molecule visualization...")
    single_mol_path = os.path.join(output_dir, "generated_molecule_single.png")
    visualize_single_molecule(
        coords=recon_coords[0],
        types=recon_types[0],
        save_path=single_mol_path,
        title="Generated Molecule (Denoiser Field)",
        figsize=(10, 8)
    )

    # Animation of the generation process.
    print("Creating generation process animation...")
    visualizer = GNFVisualizer(output_dir)

    # The animation reuses `denoiser_field_func` from the model-loading cell
    # rather than redefining an identical closure here.
    # NOTE(review): that closure draws fresh noise on every call, so the
    # animated field is not tied to `denoised_codes` above — confirm that is
    # intended.
    results = visualizer.create_generation_animation(
        converter=converter,
        field_func=denoiser_field_func,
        sample_idx=0,
        save_interval=100,
        create_1d_plots=False,  # no 1D field plots
        use_recon_dir=False,    # write files directly into the output directory
        fixed_axis=True         # fixed axes, to avoid view jitter in the GIF
    )

    # Convert the unique values and counts to plain Python numbers so the
    # summary prints as {type: count} rather than a dict of 0-d tensors.
    type_values, type_counts = torch.unique(recon_types[0], return_counts=True)
    atom_type_distribution = dict(zip(type_values.tolist(), type_counts.tolist()))

    print("\n=== Denoiser Field 生成结果 ===")
    print(f"Generated atoms: {recon_coords[0].shape[0]}")
    print(f"Atom type distribution: {atom_type_distribution}")
    print(f"单张分子图: {single_mol_path}")
    print(f"生成过程动画: {results['gif_path']}")

else:
    # Which reconstructions to run depends on the mode.
    if option == 'gt_only':
        rec_list = ['gt_field']
    else:
        rec_list = ['predicted_field', 'gt_field']

    visualizer = GNFVisualizer(output_dir)

    def gt_recon_field_func(points):
        """Ground-truth field of the loaded sample, exactly as ``converter.mol2gnf`` returns it.

        Unlike the field functions of the model-loading cell, ``points`` is
        passed through untouched and the result is not unwrapped. It has its
        own name so that it does not overwrite ``gt_field_func`` from that
        cell.
        """
        # Index 0: the loaded data holds a single sample.
        gt_mask = (gt_types[0] != PADDING_INDEX)
        gt_valid_coords = gt_coords[0][gt_mask]
        gt_valid_types = gt_types[0][gt_mask]
        return converter.mol2gnf(
            gt_valid_coords.unsqueeze(0),
            gt_valid_types.unsqueeze(0),
            points
        )

    # Run the reconstruction visualisation once per reconstruction type.
    for rec_type in rec_list:
        print(f"\n=== 执行 {rec_type} 重建 ===")

        # Choose the field function under a cell-local name, so the global
        # `field_func` used by the 1D comparison cell is left untouched.
        if rec_type == 'gt_field':
            recon_field_func = gt_recon_field_func
        else:  # predicted_field
            # Only reached in gt_pred mode, where the model-loading cell
            # defined `predicted_field_func` (same logic as the closure that
            # used to be redefined here on every iteration).
            recon_field_func = predicted_field_func

        results = visualizer.create_reconstruction_animation(
            gt_coords=gt_coords,
            gt_types=gt_types,
            converter=converter,
            field_func=recon_field_func,
            save_interval=100,
            animation_name=f"recon_sample_{sample_idx}_{rec_type}",
            sample_idx=0  # index inside the loaded data, which holds a single sample
        )

        print(f"\n=== {rec_type} 重建结果 ===")
        print(f"RMSD: {results['final_rmsd']:.4f}")
        print(f"Reconstruction Loss: {results['final_loss']:.4f}")
        print(f"KL Divergence (orig->recon): {results['final_kl_1to2']:.4f}")
        print(f"KL Divergence (recon->orig): {results['final_kl_2to1']:.4f}")
        print(f"GIF动画: {results['gif_path']}")
        print(f"对比图: {results['comparison_path']}")


=== 执行 denoiser_field 分子生成 ===
Generating molecular field and reconstructing molecule...
>>     Memory status at iteration 0: Allocated=0.02GB, Reserved=0.98GB
>>     Memory status at iteration 50: Allocated=0.02GB, Reserved=0.98GB
>>     Memory status at iteration 100: Allocated=0.02GB, Reserved=0.98GB
>>     Memory status at iteration 150: Allocated=0.02GB, Reserved=0.98GB
>>     Memory status at iteration 200: Allocated=0.02GB, Reserved=0.98GB
>>     Memory status at iteration 250: Allocated=0.02GB, Reserved=0.98GB
>>     Memory status at iteration 300: Allocated=0.02GB, Reserved=0.98GB
>>     Memory status at iteration 350: Allocated=0.02GB, Reserved=0.98GB
>>     Memory status at iteration 400: Allocated=0.02GB, Reserved=0.98GB
>>     Memory status at iteration 450: Allocated=0.02GB, Reserved=0.98GB
>>     Memory status at iteration 500: Allocated=0.02GB, Reserved=0.98GB
>>     Memory status at iteration 550: Allocated=0.02GB, Reserved=0.98GB
>>     Memory status at iteration 600