In [55]:
import time
from collections import defaultdict
import logging

import h5py
import numpy as np
from ase import Atoms
from ase.io import read, write
from dscribe.descriptors import SOAP
from dscribe.kernels import AverageKernel
from joblib import Parallel, delayed


In [56]:
def compute_soap_descriptors(structures, njobs, species, r_cut, n_max, l_max):
    """
    Function: Compute SOAP descriptors for a list of structures
    Input:
        structures: list of ase.Atoms objects to describe
        njobs: int, number of parallel jobs handed to ``SOAP.create``
        species: list of str, species to consider
        r_cut: float, cutoff radius for SOAP calculation
        n_max: int, number of radial basis functions
        l_max: int, maximum degree of spherical harmonics
    Output:
        List with one SOAP descriptor (numpy.ndarray) per structure
    """
    start_time = time.time()
    soap = SOAP(
        species=species,
        r_cut=r_cut,
        n_max=n_max,
        l_max=l_max,
    )
    soap_descriptors = soap.create(structures, n_jobs=njobs)
    end_time = time.time()
    logging.info(f"SOAP descriptors computed in {end_time - start_time:.2f} seconds")

    # Callers normally pass structures that share one chemical formula, in which
    # case the result comes back as a single stacked array. Callers expect a
    # list of per-structure arrays, so split it along the first axis.
    # isinstance (rather than an exact type comparison) also covers ndarray
    # subclasses.
    if isinstance(soap_descriptors, np.ndarray):
        return list(soap_descriptors)

    # With differing atom counts the result is already a list; pass it through.
    return soap_descriptors

def compute_similarity(cand_soap, ref_soap, kernel_metric="laplacian"):
    """
    Function: Evaluate an AverageKernel between candidate and reference SOAP descriptors.
    Input:
        cand_soap: Candidate SOAP descriptor(s), passed as the first kernel argument
        ref_soap: Reference SOAP descriptor(s), passed as the second kernel argument
        kernel_metric: Metric name given to AverageKernel (default: laplacian)
    Output:
        Whatever ``AverageKernel.create(cand_soap, ref_soap)`` returns, i.e. the
        similarity between the candidate and reference descriptors
    """
    return AverageKernel(metric=kernel_metric).create(cand_soap, ref_soap)

In [57]:
def _max_similarity_to_refs(soap_cand, soap_refs, njobs, round_num):
    """Return, per candidate descriptor, its highest similarity to any descriptor in soap_refs."""
    start_time = time.time()
    # One parallel task per candidate; the one-element slice keeps the
    # "list of descriptors" input form that compute_similarity is called with.
    kernel_rows = Parallel(n_jobs=njobs)(
        delayed(compute_similarity)(soap_cand[i:i + 1], soap_refs)
        for i in range(len(soap_cand))
    )
    kernel = np.vstack(kernel_rows)
    end_time = time.time()
    logging.info(f"Round {round_num}: Similarity computation completed in {end_time - start_time:.2f} seconds")
    return np.max(kernel, axis=1)


def compare_and_update_structures(ref_structures, cand_structures, njobs=8, species=["H", "C", "O", "N"], r_cut=10.0, n_max=3, l_max=3, threshold=0.99):
    """
    Function:
    Compare candidate structures with reference structures.
    Update the reference database one molecule by one molecule: in every round
    the candidates that are already similar enough to the references are
    discarded, and the remaining candidate that is least similar to the
    references is promoted to a reference.

    Input:
        ref_structures: list of ase.Atoms objects, reference structures
        cand_structures: list of ase.Atoms objects, candidate structures
        njobs: int, number of jobs to run in parallel
        species: list of str, species to consider (the default list is only
            read, never modified)
        r_cut: float, cutoff radius for SOAP calculation
        n_max: int, number of radial basis functions
        l_max: int, maximum degree of spherical harmonics
        threshold: float, similarity threshold for reducing candidate
            structures; a candidate is dropped once its maximum similarity,
            rounded to 5 decimals, is not below this value

    Output:
        ref_structures: list of ase.Atoms objects, updated reference structures
            (a new list; the list passed in is left untouched)
        soap_ref: list of soap_descriptors, SOAP descriptors for updated reference structures

    Raises:
        ValueError: if ref_structures is empty, because there is nothing to
            compare the candidates against
    """
    # Work on copies so the caller's lists are never modified in place.
    ref_structures = list(ref_structures)
    cand_structures = list(cand_structures)
    if not ref_structures:
        raise ValueError("ref_structures must contain at least one structure")

    # All descriptors are computed once up front; later rounds only reuse them.
    # (Possible improvement noted by the author: precompute or load descriptors.)
    soap_ref = list(compute_soap_descriptors(ref_structures, njobs, species, r_cut, n_max, l_max))
    if cand_structures:
        soap_cand = compute_soap_descriptors(cand_structures, njobs, species, r_cut, n_max, l_max)
    else:
        # Nothing to screen: skip the descriptor and kernel work entirely.
        soap_cand = []

    max_similarity_values = None
    round_num = 0

    while cand_structures:
        round_num += 1

        if max_similarity_values is None:
            # First round: compare every candidate with every reference.
            max_similarity_values = _max_similarity_to_refs(soap_cand, soap_ref, njobs, round_num)
        else:
            # Later rounds: only the newest reference can raise a candidate's
            # maximum, so compare against that one and merge element-wise.
            newest = _max_similarity_to_refs(soap_cand, [soap_ref[-1]], njobs, round_num)
            max_similarity_values = np.maximum(max_similarity_values, newest)

        # Drop every candidate whose rounded maximum similarity reaches the
        # threshold, keeping the three parallel containers aligned.
        old_cand_num = len(cand_structures)
        keep = [round(value, 5) < threshold for value in max_similarity_values]
        soap_cand = [descriptor for descriptor, kept in zip(soap_cand, keep) if kept]
        cand_structures = [structure for structure, kept in zip(cand_structures, keep) if kept]
        max_similarity_values = max_similarity_values[np.array(keep, dtype=bool)]
        new_cand_num = len(cand_structures)
        logging.info(f"Round {round_num}: Cand structures reduced from {old_cand_num} to {new_cand_num}")

        # Nothing left to add: the reference set is complete.
        if new_cand_num == 0:
            break

        # Promote the candidate that is least similar to the current references.
        # NOTE(review): the promoted structure stays in the candidate list and
        # is expected to be filtered out next round through its similarity to
        # itself; presumably that is 1.0 for this kernel, so threshold should
        # not exceed 1 -- confirm.
        min_max_similarity_index = int(np.argmin(max_similarity_values))
        min_max_similarity = round(max_similarity_values[min_max_similarity_index], 5)
        ref_structures.append(cand_structures[min_max_similarity_index])
        soap_ref.append(soap_cand[min_max_similarity_index])
        logging.info(f"Round {round_num}: Added structure with min max similarity {min_max_similarity}.")
        logging.info(f"Ref structures: {len(ref_structures)}, Cand structures: {len(cand_structures)}")
        logging.info("---------")

    logging.info("No structures remaining in candidate list.")
    logging.info(f"Ref structures: {len(ref_structures)}, Cand structures: {len(cand_structures)}")
    logging.info("---------")

    return ref_structures, soap_ref

In [58]:
def save_soap_to_hdf5(soap_dict, hdf5_name):
    """
    Write SOAP descriptors to an HDF5 file, one dataset per formula.

    Example input structure:
    defaultdict(list, {
        'C2H3N3O': [
            array([[...], [...], ...]),  # SOAP descriptors for the first molecule
            array([[...], [...], ...])   # SOAP descriptors for the second molecule
        ],
        'C2H6': [
            array([[...], [...], ...])   # SOAP descriptors for another molecule
        ]
    })

    - Every list entry is the 2-D descriptor array of one molecule.
    - The entries of a formula are combined with ``np.array`` into a single
    array, which is stored under the formula as dataset name.
    - The file is opened in "w" mode.
    """
    with h5py.File(hdf5_name, "w") as hdf:
        for formula in soap_dict:
            # Combine the per-molecule arrays of this formula into one dataset.
            hdf.create_dataset(formula, data=np.array(soap_dict[formula]))
    logging.info(f"SOAP descriptors saved to '{hdf5_name}'")

def read_soap_from_hdf5(hdf5_name):
    """
    Load SOAP descriptors from an HDF5 file.

    Returns:
        defaultdict(list): maps every dataset name (molecule formula) to a
        list holding ONE element: the complete dataset read into memory as a
        single array. The per-molecule arrays are therefore found along the
        first axis of that element, not as separate list entries.
    """
    loaded = defaultdict(list)

    with h5py.File(hdf5_name, "r") as hdf:
        for formula, dataset in hdf.items():
            # dataset[:] reads the whole dataset into memory before the file closes.
            loaded[formula].append(dataset[:])

    return loaded

def defaultdict_profiler(soap_data):
    """
    Print the available formulas, then for each formula the shape of the
    first entry stored under it, followed by a separator line.
    """
    formulas = list(soap_data)
    print("Available formulas:", formulas)

    # Report the shape of the first stored entry of every formula.
    for formula in formulas:
        first_entry = soap_data[formula][0]
        print(f"Formula: {formula}, Shape of SOAP descriptors: {first_entry.shape}")
    print("---------")

In [62]:
# Configure logging: every record goes to output.log and is echoed to the console.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("output.log", mode='w'),  # write to the log file
        logging.StreamHandler(),  # also echo to the terminal
    ],
)
logging.info('Log begin')

start_time = time.time()

# Load the reference and candidate structures (index=':' reads every frame).
ref_structures = read('./t1x/rxn000x.xyz', index=':')
cand_structures = read('./t1x/rxn000x_all.xyz', index=':')

# Screen the candidates against the references and grow the reference set.
updated_ref_structures, updated_soap_list = compare_and_update_structures(ref_structures, cand_structures, njobs=24)

# Group the descriptors of the updated references by chemical formula and
# store them in an HDF5 file.
soap_dict = defaultdict(list)
for structure, soap_result in zip(updated_ref_structures, updated_soap_list):
    formula = structure.get_chemical_formula()
    soap_dict[formula].append(soap_result)
save_soap_to_hdf5(soap_dict, "updated_ref_soap_descriptors.h5")

# Store the updated reference structures themselves.
write("updated_ref_structures.xyz", updated_ref_structures)
logging.info("Updated reference structures saved to 'updated_ref_structures.xyz'")

end_time = time.time()
logging.info(f"Done! Total time elapsed: {end_time - start_time:.2f} seconds")
logging.info('Log end')

2024-10-29 18:46:16,652 - INFO - Log begin
2024-10-29 18:46:19,962 - INFO - SOAP descriptors computed in 2.31 seconds
2024-10-29 18:46:20,791 - INFO - SOAP descriptors computed in 0.83 seconds
2024-10-29 18:52:55,091 - INFO - Round 1: Similarity computation completed in 394.30 seconds
2024-10-29 18:52:55,180 - INFO - Round 1: Cand structures reduced from 8068 to 4327
2024-10-29 18:52:55,193 - INFO - Round 1: Added structure with min max similarity 0.95547.
2024-10-29 18:52:55,193 - INFO - Ref structures: 982, Cand structures: 4327
2024-10-29 18:52:55,194 - INFO - ---------
2024-10-29 18:52:55,796 - INFO - Round 2: Similarity computation completed in 0.60 seconds
2024-10-29 18:52:55,837 - INFO - Round 2: Cand structures reduced from 4327 to 4326
2024-10-29 18:52:55,851 - INFO - Round 2: Added structure with min max similarity 0.95671.
2024-10-29 18:52:55,851 - INFO - Ref structures: 983, Cand structures: 4326
2024-10-29 18:52:55,851 - INFO - ---------
2024-10-29 18:52:56,253 - INFO - Ro

In [63]:
# Read the descriptor file back from disk and print, per formula, the shape of
# the stored array as a quick sanity check of what was saved.
soap_data = read_soap_from_hdf5("updated_ref_soap_descriptors.h5")
defaultdict_profiler(soap_data)

Available formulas: ['C2H3N3O', 'C3H8O2', 'C6H10', 'CHN3O']
Formula: C2H3N3O, Shape of SOAP descriptors: (682, 9, 312)
Formula: C3H8O2, Shape of SOAP descriptors: (1567, 13, 312)
Formula: C6H10, Shape of SOAP descriptors: (206, 16, 312)
Formula: CHN3O, Shape of SOAP descriptors: (139, 6, 312)
---------
