In [1]:
import os
import sys
import torch
import imageio
import time
from pathlib import Path

# torch.compile compatibility: ask TorchDynamo to suppress compilation errors
# (config.suppress_errors) instead of raising them.
try:
    import torch._dynamo as _dynamo
except ImportError:
    # torch._dynamo does not exist in PyTorch versions < 2.0.
    print("Warning: torch._dynamo not available in this PyTorch version")
else:
    _dynamo.config.suppress_errors = True

# Set up the environment. The notebook runs from <repo>/vecmol/notebooks, so the
# directory two levels up (the parent of the vecmol package) is put at the front of
# sys.path to make `import vecmol...` resolve to this checkout.
notebook_dir = Path.cwd()
project_root = notebook_dir.parent.parent
sys.path[:0] = [str(project_root)]
print(f"Project root: {project_root}")
print(f"Python path: {sys.path[0]}")


from vecmol.dataset.dataset_field import create_gnf_converter, prepare_data_with_sample_idx
from vecmol.utils.utils_nf import load_neural_field, compute_rmsd

from vecmol.utils.constants import PADDING_INDEX
from vecmol.utils.gnf_visualizer import (
    visualize_1d_gradient_field_augmentation_comparison,
    create_visualization_callback,
    visualize_generated_molecule
)
from vecmol.utils.misc import load_nf_config, create_field_function
from vecmol.utils.data_augmentation import apply_rotation_3d, apply_translation_3d

# Root directory holding the neural-field experiment folders. It can be overridden
# with the VECMOL_MODEL_ROOT environment variable when running on another machine;
# the default is the original hard-coded location.
model_root = os.environ.get("VECMOL_MODEL_ROOT", "/home/huayuchen/Neurl-voxel/exps/neural_field")

Project root: /home/huayuchen/Neurl-voxel
Python path: /home/huayuchen/Neurl-voxel


In [2]:
# Checkpoint to analyse and the dataset sample to reconstruct.
nf_ckpt_path = '/home/huayuchen/Neurl-voxel/exps/neural_field/nf_qm9/20251024/lightning_logs/version_0/checkpoints/model-epoch=409.ckpt'
sample_idx = 1
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The experiment name is made of the two path components that follow "neural_field"
# (e.g. "nf_qm9/20251024"). tuple.index raises ValueError if that component is missing.
ckpt_parts = Path(nf_ckpt_path).parts
neural_field_idx = ckpt_parts.index('neural_field')
exp_group, exp_run = ckpt_parts[neural_field_idx + 1], ckpt_parts[neural_field_idx + 2]
exp_name = f"{exp_group}/{exp_run}"
# Checkpoint file name without its extension, e.g. "model-epoch=409".
ckpt_name = Path(nf_ckpt_path).stem

# Results are written to <model_root>/<exp_name>/<ckpt_name>.
model_dir = os.path.join(model_root, exp_name)
output_dir = os.path.join(model_dir, ckpt_name)
os.makedirs(output_dir, exist_ok=True)

for _label, _value in (
    ("Model directory", model_dir),
    ("Checkpoint", ckpt_name),
    ("Output directory", output_dir),
):
    print(f"{_label}: {_value}")

# Load the training config and the single dataset sample to reconstruct.
config = load_nf_config("train_nf_qm9")
batch, gt_coords, gt_types = prepare_data_with_sample_idx(config, device, sample_idx)

print(f"Data loaded for sample {sample_idx}: {gt_coords.shape}, {gt_types.shape}")

# Load the trained encoder/decoder. Move them to `device` rather than calling .cuda(),
# so the models end up on the same device as the data and the notebook also runs on
# CPU-only machines.
encoder, decoder = load_neural_field(nf_ckpt_path, config)
encoder = encoder.to(device)
decoder = decoder.to(device)
encoder.eval()
decoder.eval()

# Encode the sample into latent codes; the encoder pass needs no gradients.
with torch.no_grad():
    codes = encoder(batch)
# Field function backed by the decoder's prediction for these codes.
field_func = create_field_function(
    mode='predicted',
    decoder=decoder,
    codes=codes
)

# Build the GNF converter (following test_qm9.ipynb); its gnf2mol is used below to
# reconstruct atoms from the decoded field.
converter = create_gnf_converter(config)

# Report the converter hyper-parameters that drive the reconstruction.
print("\n=== Converter 参数 ===")
for _param_name in ("step_size", "eps", "min_samples"):
    print(f"{_param_name}: {getattr(converter, _param_name)}")

Model directory: /home/huayuchen/Neurl-voxel/exps/neural_field/nf_qm9/20251024
Checkpoint: model-epoch=409
Output directory: /home/huayuchen/Neurl-voxel/exps/neural_field/nf_qm9/20251024/model-epoch=409
Dataset directory: /home/huayuchen/Neurl-voxel/vecmol/dataset/data
Config loaded successfully: train_nf_qm9
n_iter from converter config: 500
>> val set size: 20042
Data loaded for sample 1: torch.Size([1, 9, 3]), torch.Size([1, 9])
Loading Lightning checkpoint from: /home/huayuchen/Neurl-voxel/exps/neural_field/nf_qm9/20251024/lightning_logs/version_0/checkpoints/model-epoch=409.ckpt
>> loaded dec
>> loaded enc
Model loaded successfully!

=== Converter 参数 ===
step_size: 0.05
eps: 0.05
min_samples: 20


In [None]:
print("\n=== 可视化数据增强后的Codes重建 ===")

# Summarise the latent codes produced by the encoder.
print(f"原始Codes shape: {codes.shape}")
n_batches, n_codes = codes.shape[0], codes.shape[1]
print(f"Codes数量: {n_batches} 个batch, 每个batch有 {n_codes} 个codes (grid_size^3)")
if len(codes.shape) == 3:
    print(f"每个code的维度: {codes.shape[2]}")

# Grid parameters for the augmentations: read from config.dset when a config is loaded.
# Otherwise, and for any missing attribute, fall back to 9 (grid_size) and 1.5
# (anchor_spacing). getattr(None, ..., default) yields the default when no config is loaded.
_dset_cfg = config.dset if 'config' in locals() else None
grid_size = getattr(_dset_cfg, 'grid_size', 9)
anchor_spacing = getattr(_dset_cfg, 'anchor_spacing', 1.5)

save_interval = 100

def _reconstruct_codes(codes_in, frame_prefix, fixed_axis_limits_dict=None, n_atom_types=5):
    """Decode `codes_in` into a molecule with converter.gnf2mol while saving animation frames.

    Args:
        codes_in: latent codes to reconstruct (original or augmented copy of `codes`).
        frame_prefix: file-name prefix for the frames written by the visualization callback.
        fixed_axis_limits_dict: forwarded to create_visualization_callback when not None,
            so several runs can share the same plot axes; omitted otherwise.
        n_atom_types: number of atom types passed to the visualization callback.

    Returns:
        (callback, frame_paths, fixed_axis_limits, recon_coords, recon_types): the first
        three values come from create_visualization_callback, the last two from gnf2mol.
    """
    callback_kwargs = {}
    if fixed_axis_limits_dict is not None:
        callback_kwargs['fixed_axis_limits_dict'] = fixed_axis_limits_dict
    callback, frame_paths, fixed_axis_limits = create_visualization_callback(
        output_dir=output_dir,
        frame_prefix=frame_prefix,
        codes_device=codes.device,
        n_atom_types=n_atom_types,
        **callback_kwargs,
    )
    recon_coords, recon_types = converter.gnf2mol(
        decoder,
        codes_in,
        save_interval=save_interval,
        visualization_callback=callback,
    )
    return callback, frame_paths, fixed_axis_limits, recon_coords, recon_types


def _n_atoms(recon_coords):
    """Number of atoms in the first reconstructed molecule, or 0 if nothing was reconstructed."""
    return recon_coords[0].shape[0] if len(recon_coords) > 0 else 0


# 1. Reconstruct the unaugmented codes (reference for the comparisons below).
print("1. 重建原始codes...")
(visualization_callback_original, frame_paths_original, fixed_axis_limits_original,
 recon_coords_original, recon_types_original) = _reconstruct_codes(codes, "frame_aug_original")
print(f"   原始分子: {_n_atoms(recon_coords_original)} 个原子")

# The augmented runs reuse the axis limits from the original run so that all animations
# share the same frame.
# NOTE(review): assumes the callback fills fixed_axis_limits_original['limits'] during
# gnf2mol — TODO confirm in gnf_visualizer.
# NOTE(review): if apply_rotation_3d / apply_translation_3d are random, results change
# between runs unless a seed is set — confirm.

# 2. Rotation augmentation.
print("2. 应用旋转增强并重建...")
codes_rotated = apply_rotation_3d(codes.clone(), grid_size)
(visualization_callback_rot, frame_paths_rot, fixed_axis_limits_rot,
 recon_coords_rot, recon_types_rot) = _reconstruct_codes(
    codes_rotated,
    "frame_aug_rotation",
    fixed_axis_limits_dict={'limits': fixed_axis_limits_original['limits']},
)
print(f"   旋转后分子: {_n_atoms(recon_coords_rot)} 个原子")

# 3. Translation augmentation.
print("3. 应用平移增强并重建...")
codes_translated = apply_translation_3d(codes.clone(), grid_size, anchor_spacing)
(visualization_callback_trans, frame_paths_trans, fixed_axis_limits_trans,
 recon_coords_trans, recon_types_trans) = _reconstruct_codes(
    codes_translated,
    "frame_aug_translation",
    fixed_axis_limits_dict={'limits': fixed_axis_limits_original['limits']},
)
print(f"   平移后分子: {_n_atoms(recon_coords_trans)} 个原子")

# 4. Rotation followed by translation.
print("4. 应用旋转+平移增强并重建...")
codes_both = apply_translation_3d(
    apply_rotation_3d(codes.clone(), grid_size), grid_size, anchor_spacing
)
(visualization_callback_both, frame_paths_both, fixed_axis_limits_both,
 recon_coords_both, recon_types_both) = _reconstruct_codes(
    codes_both,
    "frame_aug_both",
    fixed_axis_limits_dict={'limits': fixed_axis_limits_original['limits']},
)
print(f"   旋转+平移后分子: {_n_atoms(recon_coords_both)} 个原子")


=== 可视化数据增强后的Codes重建 ===
原始Codes shape: torch.Size([1, 729, 128])
Codes数量: 1 个batch, 每个batch有 729 个codes (grid_size^3)
每个code的维度: 128
1. 重建原始codes...


[DBSCAN] Total points: 1000, Clusters found: 2, Noise points: 0
[DBSCAN] Total points: 1000, Clusters found: 6, Noise points: 0
[DBSCAN] Total points: 1000, Clusters found: 1, Noise points: 0
[DBSCAN] Total points: 1000, Clusters found: 0, Noise points: 1000
[DBSCAN] Total points: 1000, Clusters found: 0, Noise points: 1000
   原始分子: 9 个原子
2. 应用旋转增强并重建...
[DBSCAN] Total points: 1000, Clusters found: 3, Noise points: 0
[DBSCAN] Total points: 1000, Clusters found: 8, Noise points: 0
[DBSCAN] Total points: 1000, Clusters found: 1, Noise points: 0
[DBSCAN] Total points: 1000, Clusters found: 3, Noise points: 929
[DBSCAN] Total points: 1000, Clusters found: 1, Noise points: 976
   旋转后分子: 16 个原子
3. 应用平移增强并重建...
[DBSCAN] Total points: 1000, Clusters found: 2, Noise points: 0
[DBSCAN] Total points: 1000, Clusters found: 6, Noise points: 0
[DBSCAN] Total points: 1000, Clusters found: 1, Noise points: 0
[DBSCAN] Total points: 1000, Clusters found: 0, Noise points: 1000
[DBSCAN] Total points: 1000

In [4]:
# 5. Compare the 1D fields predicted from the four code variants.
print("\n5. 可视化4种codes的1d field对比...")
# One predicted-field function per augmentation, keyed by augmentation name.
_codes_by_augmentation = {
    'original': codes,
    'rotation': codes_rotated,
    'translation': codes_translated,
    'both': codes_both,
}
field_funcs = {
    aug_name: create_field_function(mode='predicted', decoder=decoder, codes=aug_codes)
    for aug_name, aug_codes in _codes_by_augmentation.items()
}

atom_types = list(range(5))  # C, H, O, N, F
save_path_1d = os.path.join(output_dir, f"field1d_augmentation_sample_{sample_idx}")

# Field along the x axis at y = z = 0; x_range=None lets the helper choose the range.
gradient_results = visualize_1d_gradient_field_augmentation_comparison(
    field_funcs=field_funcs, atom_types=atom_types, x_range=None,
    y_coord=0.0, z_coord=0.0, save_path=save_path_1d, sample_idx=sample_idx,
)

if gradient_results:
    print(f"\n1D Field Augmentation Comparison (model: {model_dir}):")
    print(f"  Available atom types: {gradient_results['available_atom_types']}")

    # Per-atom-type magnitude statistics for each augmentation.
    for atom_label, atom_result in gradient_results['all_results'].items():
        print(f"\n  {atom_label}:")
        print(f"    Saved to: {atom_result['save_path']}")
        for aug_label, aug_stats in atom_result['stats'].items():
            print(
                f"    {aug_label}: Mean={aug_stats['magnitude_mean']:.6f}, Std={aug_stats['magnitude_std']:.6f}, "
                f"Max={aug_stats['magnitude_max']:.6f}, Min={aug_stats['magnitude_min']:.6f}"
            )

# Save a static rendering of each reconstructed molecule.
print(f"Output directory: {output_dir}")
save_paths = {
    'original': os.path.join(output_dir, "augmentation_original.png"),
    'rotation': os.path.join(output_dir, "augmentation_rotation.png"),
    'translation': os.path.join(output_dir, "augmentation_translation.png"),
    'both': os.path.join(output_dir, "augmentation_both.png"),
}

# (printed label, reconstructed coords, reconstructed types, key into save_paths)
_molecule_views = [
    ("原始分子图", recon_coords_original, recon_types_original, 'original'),
    ("旋转后分子图", recon_coords_rot, recon_types_rot, 'rotation'),
    ("平移后分子图", recon_coords_trans, recon_types_trans, 'translation'),
    ("旋转+平移后分子图", recon_coords_both, recon_types_both, 'both'),
]
for view_label, view_coords, view_types, view_key in _molecule_views:
    # Skip variants with no reconstruction or an empty first molecule; the same guard
    # applies to all four variants.
    if len(view_coords) == 0 or len(view_coords[0]) == 0:
        continue
    # Atoms whose type is -1 are dropped as invalid.
    # NOTE(review): confirm -1 is the padding value gnf2mol emits (cf. PADDING_INDEX).
    valid_mask = view_types[0] != -1
    if not valid_mask.any():
        continue
    visualize_generated_molecule(
        view_coords[0][valid_mask],
        view_types[0][valid_mask],
        save_path=save_paths[view_key]
    )
    print(f"   {view_label}: {save_paths[view_key]}")

# Compute the RMSD of each augmented reconstruction against the original one
# (only when the atom counts match).
# NOTE(review): coordinates are compared index-by-index. The atom order produced by
# gnf2mol is probably not consistent across runs, and whether compute_rmsd aligns the
# structures is not visible here — interpret these numbers with care.


def _report_rmsd(label, other_coords):
    """Print and return the RMSD between the first original and `other_coords` molecules.

    Returns None (and prints the reason) when either reconstruction is empty or the atom
    counts differ, because compute_rmsd is only called on equally sized coordinate sets.
    """
    if len(recon_coords_original) == 0 or len(other_coords) == 0:
        print(f"   原始 vs {label}: 无重建结果，跳过RMSD")
        return None
    n_ref, n_other = len(recon_coords_original[0]), len(other_coords[0])
    if n_ref != n_other:
        print(f"   原始 vs {label}: 原子数量不匹配 ({n_ref} vs {n_other})，跳过RMSD")
        return None
    rmsd = compute_rmsd(
        recon_coords_original[0].cpu(),
        other_coords[0].cpu()
    )
    print(f"   原始 vs {label}: RMSD = {rmsd:.4f} Å")
    return rmsd


print("\n=== RMSD 对比 ===")
# Each value is None when the comparison was skipped.
rmsd_rot = _report_rmsd("旋转", recon_coords_rot)
rmsd_trans = _report_rmsd("平移", recon_coords_trans)
rmsd_both = _report_rmsd("旋转+平移", recon_coords_both)

# Assemble one GIF per augmentation from the frames saved during reconstruction,
# deleting each temporary frame file afterwards.
print("\n=== 创建重建过程动画 ===")
gif_paths = {
    'original': os.path.join(output_dir, "augmentation_original.gif"),
    'rotation': os.path.join(output_dir, "augmentation_rotation.gif"),
    'translation': os.path.join(output_dir, "augmentation_translation.gif"),
    'both': os.path.join(output_dir, "augmentation_both.gif"),
}

# The top-level imageio.imread emits a deprecation warning in recent imageio (visible in
# this cell's stored output). Use imageio.v2.imread when available and fall back to
# imageio.imread on older versions without the v2 namespace.
_imread_frame = getattr(imageio, "v2", imageio).imread

for aug_type, frame_paths, gif_path in [
    ('原始', frame_paths_original, gif_paths['original']),
    ('旋转', frame_paths_rot, gif_paths['rotation']),
    ('平移', frame_paths_trans, gif_paths['translation']),
    ('旋转+平移', frame_paths_both, gif_paths['both'])
]:
    if not frame_paths:
        print(f"   {aug_type}: 没有帧文件，跳过GIF创建")
        continue

    print(f"   创建 {aug_type} 重建过程动画...")
    n_skipped = 0
    try:
        # NOTE(review): both `duration` and `fps` are passed. Which one takes effect, and
        # whether `duration` is in seconds or milliseconds, depends on the installed
        # imageio GIF backend — confirm. Per imageio docs loop=0 means "loop forever";
        # loop=1 is kept unchanged.
        with imageio.get_writer(gif_path, mode='I', duration=0.1, fps=15, loop=1) as writer:
            for frame_path in frame_paths:
                try:
                    if not os.path.exists(frame_path):
                        n_skipped += 1
                        continue
                    time.sleep(0.01)  # brief pause in case the frame is still being written
                    if os.path.getsize(frame_path) == 0:
                        n_skipped += 1
                        continue
                    frame = _imread_frame(frame_path)
                    writer.append_data(frame)
                except Exception as e:
                    # Best effort: one unreadable frame must not abort the GIF, but say so.
                    n_skipped += 1
                    print(f"      跳过帧 {frame_path}: {e}")
                finally:
                    # Remove the temporary frame file whether or not it was used.
                    try:
                        if os.path.exists(frame_path):
                            os.remove(frame_path)
                    except OSError as e:
                        print(f"      无法删除帧文件 {frame_path}: {e}")
        if n_skipped:
            print(f"      {aug_type}: 跳过了 {n_skipped}/{len(frame_paths)} 帧")
        print(f"   {aug_type} GIF动画: {gif_path}")
    except Exception as e:
        print(f"   {aug_type} GIF创建失败: {e}")

print("\n可视化完成！所有图像和动画已保存到输出目录。")


5. 可视化4种codes的1d field对比...
使用默认 x 轴范围: (-11.0, 11.0)
Field 1D augmentation comparison (atom_type=C) saved to: /home/huayuchen/Neurl-voxel/exps/neural_field/nf_qm9/20251024/model-epoch=409/field1d_augmentation_sample_1_atom_C.png
Field 1D augmentation comparison (atom_type=H) saved to: /home/huayuchen/Neurl-voxel/exps/neural_field/nf_qm9/20251024/model-epoch=409/field1d_augmentation_sample_1_atom_H.png
Field 1D augmentation comparison (atom_type=O) saved to: /home/huayuchen/Neurl-voxel/exps/neural_field/nf_qm9/20251024/model-epoch=409/field1d_augmentation_sample_1_atom_O.png
Field 1D augmentation comparison (atom_type=N) saved to: /home/huayuchen/Neurl-voxel/exps/neural_field/nf_qm9/20251024/model-epoch=409/field1d_augmentation_sample_1_atom_N.png
Field 1D augmentation comparison (atom_type=F) saved to: /home/huayuchen/Neurl-voxel/exps/neural_field/nf_qm9/20251024/model-epoch=409/field1d_augmentation_sample_1_atom_F.png

1D Field Augmentation Comparison (model: /home/huayuchen/Neurl-v

  frame = imageio.imread(frame_path)


   原始 GIF动画: /home/huayuchen/Neurl-voxel/exps/neural_field/nf_qm9/20251024/model-epoch=409/augmentation_original.gif
   创建 旋转 重建过程动画...
   旋转 GIF动画: /home/huayuchen/Neurl-voxel/exps/neural_field/nf_qm9/20251024/model-epoch=409/augmentation_rotation.gif
   创建 平移 重建过程动画...
   平移 GIF动画: /home/huayuchen/Neurl-voxel/exps/neural_field/nf_qm9/20251024/model-epoch=409/augmentation_translation.gif
   创建 旋转+平移 重建过程动画...
   旋转+平移 GIF动画: /home/huayuchen/Neurl-voxel/exps/neural_field/nf_qm9/20251024/model-epoch=409/augmentation_both.gif

可视化完成！所有图像和动画已保存到输出目录。
