In [469]:
import os
import time
import logging
import argparse
from collections import defaultdict
from itertools import compress

import numpy as np
from ase import Atoms
from ase.io import read, write
import h5py

from dscribe.descriptors import SOAP
from dscribe.kernels import AverageKernel
from joblib import Parallel, delayed

In [470]:
def compute_soap_descriptors(structures, njobs, species, r_cut, n_max, l_max, logger):
    """
    Function: Compute SOAP descriptors for a list of structures
    Input:
        SOAP inputs
        logger: logger object
    Output:
        List of SOAP descriptors in numpy.ndarray format
    """
    start_time = time.time()
    soap = SOAP(
        species=species,
        r_cut=r_cut,
        n_max=n_max,
        l_max=l_max
    )

    try:
        soap_descriptors = soap.create(structures, n_jobs=njobs)
    except IndexError:
        # 说明 structures 为空
        soap_descriptors = []

    end_time = time.time()
    logger.info(f"SOAP descriptors computed in {end_time - start_time:.2f} seconds")

    # 由于此处调用本函数计算时，structures 中的结构默认是化学式相同的，因此返回的是 np.ndarray
    # 需要返回 list 类型，其中元素为 np.ndarray
    if type(soap_descriptors) == np.ndarray:
        return [i for i in soap_descriptors]
    
    # 原子数不相同时直接返回 list 即可
    return soap_descriptors

def compute_similarity(cand_soap, ref_soap, kernel_metric="laplacian"):
    """
    Function: Compute the similarity between candidate and reference SOAP descriptors using the specified kernel metric.
    Input:
        cand_soap: Candidate SOAP descriptor
        ref_soap: Reference SOAP descriptor
        kernel_metric: Kernel metric to use for similarity computation (default: laplacian)
    Output:
        Similarity score between candidate and reference SOAP descriptors
    """
    re = AverageKernel(metric=kernel_metric)
    return re.create(cand_soap, ref_soap)

def compare_and_update_structures(ref_structures, cand_structures, njobs=8, species=["H", "C", "O", "N"], r_cut=10.0, n_max=6, l_max=4, threshold=0.9, logger=None):
    """
    Function:
    Compare candidate structures with reference structures.
    Update the reference database one molecule by one molecule.

    Input:
        ref_structures: list of ase.Atoms objects, reference structures
        cand_structures: list of ase.Atoms objects, candidate structures
        njobs: int, number of jobs to run in parallel
        species: list of str, species to consider  
        r_cut: float, cutoff radius for SOAP calculation
        n_max: int, number of radial basis functions
        l_max: int, maximum degree of spherical harmonics
        threshold: float, similarity threshold for reducing candidate structures
        logger: logging.Logger object, logger for logging

    Output:
        ref_structures: list of ase.Atoms objects, updated reference structures
        soap_ref: list of soap_descriptors, SOAP descriptors for updated reference structures
    """
    round_num = 0
    logger.info(f"njobs: {njobs}, species: {species}, r_cut: {r_cut}, n_max: {n_max}, l_max: {l_max}, threshold: {threshold}")

    while True:
        round_num += 1

        if round_num == 1:
            # 初次计算全部的 SOAP 描述符
            # 这里还可以改进，先计算描述符或直接读入描述符
            soap_ref = compute_soap_descriptors(ref_structures, njobs, species, r_cut, n_max, l_max, logger)
            soap_cand = compute_soap_descriptors(cand_structures, njobs, species, r_cut, n_max, l_max, logger)

            # 如果 soap_ref 为空，则将 soap_cand 的第一个元素添加到 soap_ref 中
            if soap_ref == []:
                ref_structures.append(cand_structures[0])
                soap_ref.append(soap_cand[0])
                logger.info("Ref structure is empty, add the first Cand structure to Ref structure")

            # 并行计算 cand_structures 中每个结构与 ref_structures 中所有结构的相似度
            start_time = time.time()
            re_kernel_results = Parallel(n_jobs=njobs)(delayed(compute_similarity)(soap_cand[i:i+1], soap_ref) for i in range(len(soap_cand)))
            re_kernel = np.vstack(re_kernel_results)
            end_time = time.time()
            logger.info(f"Round {round_num}: Similarity computation completed in {end_time - start_time:.2f} seconds")

            print('Round:', round_num)
            print('re_kernel shape:', re_kernel.shape)
            print('re_kernel:', re_kernel[:10])

            # 选取 cand_structures 中每个结构与 ref_structures 中所有结构的最大相似度
            max_similarity_values = np.max(re_kernel, axis=1)
        
        else:
            # 并行计算 cand_structures 中每个结构与 ref_structures 新加入的结构的相似度
            # 变量覆盖释放内存空间
            ### 这里可以写成一个函数，便于复用
            start_time = time.time()
            re_kernel_results = Parallel(n_jobs=njobs)(delayed(compute_similarity)(soap_cand[i:i+1], [soap_ref[-1]]) for i in range(len(soap_cand)))
            re_kernel = np.vstack(re_kernel_results)
            end_time = time.time()
            logger.info(f"Round {round_num}: Similarity computation completed in {end_time - start_time:.2f} seconds")
            ### 这里可以写成一个函数，便于复用

            print('Round:', round_num)
            print('re_kernel shape:', re_kernel.shape)
            print('re_kernel:', re_kernel[:10])

            # 将原先 max_similarity_values 与新加入的点的堆叠
            max_similarity_values = np.max(np.column_stack((max_similarity_values, np.max(re_kernel, axis=1))), axis=1)


        # 删除 cand_structures 中与 ref_structures 中相似度高于 threshold 的所有结构
        # 更新 soap_cand, cand_structures, max_similarity_values
        old_cand_num = len(cand_structures)

        preserve_condition = max_similarity_values < threshold # 减少不必要的 round(5) 开销
        soap_cand = list(compress(soap_cand, preserve_condition)) # itertools.compress() 更高效
        cand_structures = list(compress(cand_structures, preserve_condition))
        max_similarity_values = max_similarity_values[preserve_condition] # np 布尔索引更高效
        
        new_cand_num = len(cand_structures)
        logger.info(f"Round {round_num}: Cand structures reduced from {old_cand_num} to {new_cand_num}")

        # 如果 cand_structures 中没有元素，则退出循环
        if new_cand_num == 0:
            break

        # 将 cand_structures 与 ref_structures 中最不相似的结构添加到 ref_structures 中
        min_max_similarity = np.min(max_similarity_values).round(5) # 减少不必要的 round(5) 开销
        min_max_similarity_index = np.argmin(max_similarity_values)
        ref_structures.append(cand_structures[min_max_similarity_index])
        soap_ref.append(soap_cand[min_max_similarity_index])
        logger.info(f"Round {round_num}: Added structure with min max similarity {min_max_similarity}.")
        logger.info(f"Ref structures: {len(ref_structures)}, Cand structures: {len(cand_structures)}")
        logger.info("---------")


    logger.info("No structures remaining in candidate list.")
    logger.info(f"Ref structures: {len(ref_structures)}, Cand structures: {len(cand_structures)}")
    logger.info("---------")
        
    return ref_structures, soap_ref

def save_soap_to_hdf5(soap_dict, hdf5_name):
    """
    Save SOAP descriptors to an HDF5 file.

    Example input structure:
    defaultdict(list, {
        'C2H3N3O': [
            array([[...], [...], ...]),  # SOAP descriptors for the first molecule
            array([[...], [...], ...])   # SOAP descriptors for the second molecule
        ],
        'C2H6': [
            array([[...], [...], ...])   # SOAP descriptors for another molecule
        ]
    })

    - Each value is an array representing the SOAP descriptors for a molecule, with shape (N, M),
    where N is the number of descriptors and M is the dimension of each descriptor.
    """
    with h5py.File(hdf5_name, "w") as hdf:
        for formula, soap_list in soap_dict.items():
            stacked_soap = np.array(soap_list)
            hdf.create_dataset(formula, data=stacked_soap)

# 在 .py 中暂时未使用
def read_soap_from_hdf5(hdf5_name):
    """
    Read SOAP descriptors from an HDF5 file.

    Returns:
        defaultdict(list): A dictionary with molecule formulas as keys
        and lists of corresponding SOAP descriptors as values.
    """
    soap_dict = defaultdict(list)

    with h5py.File(hdf5_name, "r") as hdf:
        for formula in hdf.keys():
            soap_descriptors = hdf[formula][:]
            soap_dict[formula].append(soap_descriptors)

    return soap_dict

# 在 .py 中暂时未使用
def defaultdict_profiler(soap_data):
    """
    Print the available formulas and the shape of their corresponding SOAP descriptors.
    """
    formulas = list(soap_data.keys())
    print("Available formulas:", formulas)

    # 遍历每个分子式并读取 SOAP 描述符
    for formula in formulas:
        soap_descriptors = soap_data[formula][:][0]
        print(f"Formula: {formula}, Shape of SOAP descriptors: {soap_descriptors.shape}")
    print("---------")

# 设置总的日志记录
def setup_total_logging():
    '''
    Setup the total logging for the algorithm.
    '''
    # 创建总的 Logger
    total_logger = logging.getLogger("total_logger")
    total_logger.setLevel(logging.INFO)

    # 清除之前的处理器，确保每次都干净
    for handler in total_logger.handlers[:]:
        total_logger.removeHandler(handler)

    # 创建文件处理器
    file_handler = logging.FileHandler("total_output.log", mode='w')
    file_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
    total_logger.addHandler(file_handler)

    return total_logger

# 设置各化学式的日志记录
def setup_logging(reaction_formula):
    '''
    Setup the logging for each chemical formula.
    '''
    # 创建新的 Logger
    logger = logging.getLogger(reaction_formula)
    logger.setLevel(logging.INFO)
        
    # 清除之前的处理器，确保每次都干净
    for handler in logger.handlers[:]:
        logger.removeHandler(handler)

    # 创建文件夹
    os.makedirs(reaction_formula, exist_ok=True)

    # 创建文件处理器
    file_handler = logging.FileHandler(os.path.join(reaction_formula, f"{reaction_formula}_output.log"), mode='w')
    file_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
    
    # 添加处理器到 Logger
    logger.addHandler(file_handler)
    
    return logger

# 主程序
def main(ref_file, cand_file, njobs, r_cut, n_max, l_max, threshold):
    total_logger = setup_total_logging()
    total_logger.info('Total Log begin')
    start_time = time.time()

    # 读取数据
    if ref_file == '':
        ref_structures = []
    else:
        ref_structures = read(ref_file, index=':')
    cand_structures = read(cand_file, index=':')

    # 根据 chemical_formula 分组
    ref_dict = defaultdict(list)
    cand_dict = defaultdict(list)

    for structure in ref_structures:
        formula = structure.get_chemical_formula()
        ref_dict[formula].append(structure)

    for structure in cand_structures:
        formula = structure.get_chemical_formula()
        cand_dict[formula].append(structure)

    formula_num = len(cand_dict.keys())
    total_logger.info(f"There are {formula_num} formulas to process.")
    total_logger.info("---------")


    for i, formula in enumerate(cand_dict.keys()):
        # 如果 ref_dict 中没有该组，则会返回空列表，程序可以正常运行
        total_logger.info(f"Processing formula {i+1:>}/{formula_num:>}: {formula}")
        total_logger.info(f"Start Ref structures: {len(ref_dict[formula])}, Cand structures: {len(cand_dict[formula])}")

        logger = setup_logging(formula)  
        formula_start_time = time.time()
        logger.info('Log begin')
        logger.info(f"Processing formula: {formula}")

        # threshold 对于 species 很敏感，这可能是自动匹配 species 后可能出现的问题
        # species=list(set(cand_dict[formula][0].get_chemical_symbols()))
        updated_structures, updated_soap_list = compare_and_update_structures(ref_dict[formula], 
                                                                            cand_dict[formula], 
                                                                            njobs=njobs,
                                                                            # species=species,
                                                                            r_cut=r_cut,
                                                                            n_max=n_max,
                                                                            l_max=l_max,
                                                                            threshold=threshold,
                                                                            logger=logger)
        
        # 逐一保存更新后的参考结构
        write(os.path.join(formula, f"updated_ref_structures_{formula}.xyz"), updated_structures)
        logger.info(f"Updated reference structures saved to '{formula}/updated_ref_structures_{formula}.xyz'")
        
        # 保存更新后的结构和 SOAP 到 HDF5
        soap_dict = defaultdict(list)
        for i in range(len(updated_soap_list)):
            soap_result = updated_soap_list[i]
            soap_dict[formula].append(soap_result)
        save_soap_to_hdf5(soap_dict, os.path.join(formula, f"updated_ref_soap_descriptors_{formula}.h5"))
        logger.info(f"SOAP descriptors saved to '{formula}/updated_ref_soap_descriptors_{formula}.h5'")

        formula_end_time = time.time()
        logger.info(f"Done! Total time elapsed: {formula_end_time - formula_start_time:.2f} seconds")
        logger.info('Log end')

        total_logger.info(f"End Ref structures: {len(updated_structures)}")
        total_logger.info(f"Done! Total time elapsed: {formula_end_time - formula_start_time:.2f} seconds")
        total_logger.info("---------")

    end_time = time.time()
    total_logger.info('All reactions processed successfully!')
    total_logger.info(f"Total processing time: {end_time - start_time:.2f} seconds")
    total_logger.info('Total Log end')

In [306]:
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Select different chemical structures.')
    parser.add_argument('--ref', type=str, default='', help='Reference XYZ file')
    parser.add_argument('--cand', type=str, required=True, help='Candidate XYZ file')
    parser.add_argument('--njobs', type=int, default=8, help='Number of jobs for parallel processing')
    parser.add_argument('--r_cut', type=float, default=10.0, help='Cutoff radius for soap descriptor')
    parser.add_argument('--n_max', type=int, default=6, help='Number of radial basis functions')
    parser.add_argument('--l_max', type=int, default=4, help='Maximum degree of spherical harmonics')
    parser.add_argument('--threshold', type=float, default=0.9, help='Similarity threshold')

    args = parser.parse_args()

    main(args.ref, args.cand, args.njobs, args.r_cut, args.n_max, args.l_max, args.threshold)

usage: ipykernel_launcher.py [-h] [--ref REF] --cand CAND [--njobs NJOBS]
                             [--r_cut R_CUT] [--n_max N_MAX] [--l_max L_MAX]
                             [--threshold THRESHOLD]
ipykernel_launcher.py: error: the following arguments are required: --cand


SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


In [471]:
ref = ''
cand = 'rxn0000.xyz'
njobs = 2
r_cut = 10.0
n_max = 6
l_max = 4
threshold = 0.95
main(ref, cand, njobs, r_cut, n_max, l_max, threshold)

Round: 1
re_kernel shape: (100, 1)
re_kernel: [[1.        ]
 [0.96039081]
 [0.92096803]
 [0.89363827]
 [0.85799494]
 [0.83940276]
 [0.82873544]
 [0.81831525]
 [0.81133494]
 [0.80643498]]
Round: 2
re_kernel shape: (86, 1)
re_kernel: [[0.84577766]
 [0.87448632]
 [0.89647599]
 [0.93384489]
 [0.94597925]
 [0.96019293]
 [0.98173376]
 [1.        ]
 [0.84368853]
 [0.9000759 ]]
Round: 3
re_kernel shape: (73, 1)
re_kernel: [[0.91029363]
 [0.93223138]
 [0.93010149]
 [0.91048445]
 [0.89514285]
 [0.96158465]
 [0.95521957]
 [0.97229951]
 [0.91670071]
 [0.9295723 ]]
Round: 4
re_kernel shape: (63, 1)
re_kernel: [[0.96237323]
 [0.95877086]
 [0.9144786 ]
 [0.88941156]
 [0.87727563]
 [0.91379412]
 [0.95866378]
 [0.96058309]
 [1.        ]
 [0.92197856]]
Round: 5
re_kernel shape: (41, 1)
re_kernel: [[0.94196679]
 [0.97565787]
 [0.9579    ]
 [0.84932731]
 [0.95744178]
 [0.91714586]
 [1.        ]
 [0.97408772]
 [0.9507878 ]
 [0.91297389]]
Round: 6
re_kernel shape: (23, 1)
re_kernel: [[0.98272746]
 [0.886702

# 尝试测试 GPU 代码，本机上没有 GPU，因此理论上结果应该与上述结果相同

In [283]:
import torch

class GPUAverageKernel:
    def __init__(self, metric="laplacian", gamma=None, degree=3, coef0=1, kernel_params=None, normalize_kernel=True):
        """
        Args:
            metric (str): The pairwise metric used for calculating the local similarity, default is "laplacian".
            gamma (float): Gamma parameter for Laplacian kernel. Default is None, use sklearn's default gamma.
            degree (int): Degree of the polynomial kernel. Ignored for Laplacian. Default is 3.
            coef0 (float): Zero coefficient for polynomial and sigmoid kernels. Ignored for Laplacian. Default is 1.
            kernel_params (dict): Additional parameters for kernel function. Default is None.
            normalize_kernel (bool): Whether to normalize the kernel. Default is True.
        """
        self.metric = metric
        self.gamma = gamma
        self.degree = degree
        self.coef0 = coef0
        self.kernel_params = kernel_params
        self.normalize_kernel = normalize_kernel

    def get_pairwise_matrix(self, X, Y=None):
        """
        Computes the pairwise similarity of atomic environments using Laplacian kernel on GPU.
        
        Args:
            X (np.ndarray): Feature vector for atoms in structure A.
            Y (np.ndarray): Feature vector for atoms in structure B.
        
        Returns:
            torch.Tensor: NxM matrix of local similarities between structures A and B.
        """
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        X = torch.tensor(X, dtype=torch.float32, device=device)
        if Y is not None:
            Y = torch.tensor(Y, dtype=torch.float32, device=device)
        else:
            Y = X
        
        print('X.shape:', X.shape)
        print('Y.shape:', Y.shape)
        
        diff = torch.abs(X.unsqueeze(1) - Y.unsqueeze(0))
        print("Difference between structures:", diff.shape)

        dist = torch.sum(diff, dim=2)
        print("Distances:", dist.shape)

        K_ij = torch.exp(-self.gamma * dist)  # Laplacian kernel
        print("Pairwise kernel:", K_ij.shape)
        
        return K_ij

    def get_global_similarity(self, localkernel):
        """
        Computes the average global similarity between two structures.
        
        Args:
            localkernel (np.ndarray): NxM matrix of local similarities between structures A and B.
        
        Returns:
            float: Average similarity between the structures.
        """
        # localkernel = localkernel.clone().detach().to(dtype=torch.float32, device='cuda' if torch.cuda.is_available() else 'cpu')
        # print('localkernel:', localkernel)
        K_ij = torch.mean(localkernel)
        # print('Average similarity:', K_ij)
        return K_ij.item()


    def create(self, x, y=None):
        """
        Creates the kernel matrix based on the given lists of local features x and y.
        
        Args:
            x (iterable): A list of local feature arrays for each structure.
            y (iterable): Optional second list of features. If not specified, it is assumed that y=x.
        
        Returns:
            torch.Tensor: The pairwise global similarity kernel K[i,j] between the given structures.
        """
        # Ensure y is provided or set y = x (symmetric case)
        symmetric = False
        if y is None:
            y = x
            symmetric = True

        n_x = len(x)
        n_y = len(y)
        print('n_x', n_x)
        print('n_y', n_y)

        # Initialize the kernel matrix
        K_ij = torch.zeros((n_x, n_y), dtype=torch.float32, device='cuda' if torch.cuda.is_available() else 'cpu')

        # Compute the kernel matrix (pairwise local similarity followed by global similarity)
        for i in range(n_x):
            for j in range(n_y):
                if symmetric and j < i:
                    continue
                # Get the pairwise similarity matrix for x_i and y_j
                C_ij = self.get_pairwise_matrix(x[i], y[j])
                
                # Calculate the global similarity (mean of pairwise similarities)
                k_ij = self.get_global_similarity(C_ij)
                
                # Store the global similarity in the kernel matrix
                K_ij[i, j] = k_ij # (1, 1)
                
                # If symmetric, copy the value to the (j, i) position as well
                if symmetric and j != i:
                    K_ij[j, i] = k_ij

        # Normalize the kernel matrix if needed
        if self.normalize_kernel:
            if symmetric:
                k_ii = torch.diagonal(K_ij)
                x_k_ii_sqrt = torch.sqrt(k_ii)
                y_k_ii_sqrt = x_k_ii_sqrt
            else:
                # Calculate self-similarity for X
                x_k_ii = torch.empty(n_x, dtype=torch.float32, device='cuda' if torch.cuda.is_available() else 'cpu')
                for i in range(n_x):
                    C_ii = self.get_pairwise_matrix(x[i])
                    x_k_ii[i] = self.get_global_similarity(C_ii)

                x_k_ii_sqrt = torch.sqrt(x_k_ii)
                print('x_k_ii_sqrt:', x_k_ii_sqrt)
                print('x_k_ii_sqrt shape:', x_k_ii_sqrt.shape)

                # Calculate self-similarity for Y
                y_k_ii = torch.empty(n_y, dtype=torch.float32, device='cuda' if torch.cuda.is_available() else 'cpu')
                for i in range(n_y):
                    C_ii = self.get_pairwise_matrix(y[i])
                    y_k_ii[i] = self.get_global_similarity(C_ii)

                y_k_ii_sqrt = torch.sqrt(y_k_ii) 
                print('y_k_ii_sqrt:', y_k_ii_sqrt)
                print('y_k_ii_sqrt shape:', y_k_ii_sqrt.shape)

            print('old K_ij shape:', K_ij.shape)
            print('old K_ij:', K_ij)
            K_ij /= torch.outer(x_k_ii_sqrt, y_k_ii_sqrt)
            print('normalized K_ij shape:', K_ij.shape)
            print('normalized K_ij:', K_ij)
            
        return K_ij

In [595]:
import torch

class AverageKernelMultiDevice:
    def __init__(self, metric, gamma, gpu_id=None):
        """
        Args:
            metric (str): The pairwise metric used for calculating the local similarity, only "laplacian" is supported now.
            gamma (float): Gamma parameter for Laplacian kernel. Use sklearn's default gamma.
            gpu_id (int): The GPU device ID to be used for computation.
        """
        self.metric = metric
        self.gamma = gamma
        self.gpu_id = gpu_id

    def get_pairwise_matrix(self, X, Y=None):
        """
        Computes the pairwise similarity of atomic environments using Laplacian kernel on GPU.
        
        Args:
            X (torch.Tensor): Feature vector for atoms in multiple structures (n_x, n_atoms_x, n_features).
            Y (torch.Tensor): Feature vector for atoms in multiple structures (n_y, n_atoms_y, n_features).
                              If None, the pairwise similarity is computed between the same structures in X.
        
        Returns:
            torch.Tensor: Tensor (n_x, n_y, n_atoms_x, n_atoms_y) representing the pairwise similarities between X and Y.
                          If Y is None, the returned matrix is of shape (n_x, n_atoms_x, n_atoms_x).
        """
        # device = 'cuda' if torch.cuda.is_available() else 'cpu'
        device = torch.device(f'cuda:{self.gpu_id}' if (torch.cuda.is_available() and self.gpu_id is not None) else 'cpu')
        X = X.to(dtype=torch.float32, device=device)  # Shape: (n_x, n_atoms_x, n_features)

        if self.metric == "laplacian":

            # Normalization
            if Y is None:
                Y = X  # Shape: (n_x, n_atoms_x, n_features)
                diff = torch.abs(X.unsqueeze(2) - Y.unsqueeze(1))  # Shape: (n_x, n_atoms_x, n_atoms_x, n_features)
                dist = torch.sum(diff, dim=-1)  # Shape: (n_x, n_atoms_x, n_atoms_x)
                K_ij = torch.exp(-self.gamma * dist) # Shape: (n_x, n_atoms_x, n_atoms_x)

            else:
                Y = Y.to(dtype=torch.float32, device=device) # Shape: (n_y, n_atoms_y, n_features)

                # Broadcast difference calculation: compute |X_i - Y_j| for all i, j pairs
                diff = torch.abs(X.unsqueeze(1).unsqueeze(3) - Y.unsqueeze(0).unsqueeze(2))  # Shape: (n_x, n_y, n_atoms_x, n_atoms_y, n_features)

                # Sum over the atoms dimension (dim=3 for X and dim=4 for Y)
                dist = torch.sum(diff, dim=-1)  # Shape: (n_x, n_y, n_atoms_x, n_atoms_y)

                # Sum over atoms (2nd and 3rd dims) to get the pairwise kernel value for each pair of molecules
                K_ij = torch.exp(-self.gamma * dist)  # Shape: (n_x, n_y, n_atoms_x, n_atoms_y)
        
        print('X shape:', X.shape, 'Y shape:', Y.shape)
        print('diff shape:', diff.shape)
        print('dist shape:', dist.shape)
        print('K_ij shape:', K_ij.shape)

        return K_ij

    def get_global_similarity(self, localkernel):
        """
        Computes the average global similarity between two structures.
        
        Args:
            localkernel (torch.Tensor): Tensor (n_x, n_y, n_atoms_x, n_atoms_y) representing the pairwise similarities between structures in X and Y.
        
        Returns:
            torch.Tensor: Tensor (n_x, n_y) representing the average similarity between the structures.
                          If normalization mode in get_pairwise_matrix(), the shape returned is tensor (n_x).
        """
        device = torch.device(f'cuda:{self.gpu_id}' if (torch.cuda.is_available() and self.gpu_id is not None) else 'cpu')
        localkernel = localkernel.clone().detach().to(dtype=torch.float32, device=device)
        
        # Average similarity across all atoms in both molecules
        K_ij = torch.mean(localkernel, dim=(-2, -1))  # Shape: (n_x, n_y) or (n_x)
        # print('K_ij shape:', K_ij.shape)
        return K_ij

    def create(self, x, y=None):
        """
        Creates the kernel matrix based on the given lists of local features x and y.
    
        Args:
            x (iterable): A list of local feature arrays for each structure. Each element is a tensor of shape (n_atoms, n_features).
            y (iterable): An optional second list of features. 
                          If not specified, y is assumed to be the same as x, and the function computes self-similarity.

        Returns:
            torch.Tensor: A tensor representing the pairwise global similarity kernel K[i,j] between the given structures. 
                          Shape: (n_x, n_y) or (n_x), depending on whether y is provided.
        """

        # If y is None, compute self-similarity using x only
        if y is None:
            x_tensor = torch.stack([torch.tensor(i, dtype=torch.float32) for i in x])
            localkernel = self.get_pairwise_matrix(x_tensor)
            K_ij = self.get_global_similarity(localkernel)
            K_ij = torch.sqrt(K_ij)

            print('y=None')

        # If y is provided, compute pairwise similarity between x and y
        else:

            # symmetric = False
            # if y is None:
            #     y = x
                # symmetric = True
            
            # Convert input features to tensors
            x_tensor = torch.stack([torch.tensor(i, dtype=torch.float32) for i in x])
            y_tensor = torch.stack([torch.tensor(i, dtype=torch.float32) for i in y])

            # Compute pairwise kernel between structures in x and y
            localkernel = self.get_pairwise_matrix(x_tensor, y_tensor)

            # Compute global similarity between structures in x and y
            K_ij = self.get_global_similarity(localkernel)

        # Normalize kernel if required
        # if self.normalize_kernel:
        #     if symmetric:
        #         print(11111)
        #         k_ii = torch.diagonal(K_ij)
        #         x_k_ii_sqrt = torch.sqrt(k_ii).view(-1)
        #         y_k_ii_sqrt = x_k_ii_sqrt
        #     else:
        #         # Calculate self-similarity for X
        #         C_ii = self.get_pairwise_matrix(x_tensor)
        #         x_k_ii = self.get_global_similarity(C_ii)
        #         print('x_k_ii', x_k_ii.shape)
        #         x_k_ii_sqrt = torch.sqrt(x_k_ii)
        #         print('x_k_ii_sqrt', x_k_ii_sqrt.shape)

        #         # Calculate self-similarity for Y
        #         C_ii = self.get_pairwise_matrix(y_tensor)
        #         y_k_ii = self.get_global_similarity(C_ii)
        #         print('y_k_ii', y_k_ii.shape)
        #         y_k_ii_sqrt = torch.sqrt(y_k_ii)
        #         print('y_k_ii_sqrt', y_k_ii_sqrt.shape)

        print('old K_ij shape:', K_ij.shape)
        print('old K_ij:', K_ij)
            # K_ij /= torch.outer(x_k_ii_sqrt, y_k_ii_sqrt)
            # print('normalized K_ij shape:', K_ij.shape)
            # print('normalized K_ij:', K_ij)

        return K_ij

In [596]:
def compute_similarity_pytorch(cand_soap, ref_soap=None, kernel_metric="laplacian", gpu_id=None):
    """
    Computes the pairwise similarity between candidate and reference SOAP descriptors using laplacian kernel metric.
    """
    # 以 sci-kit learn 相同的方法计算 gamma 值
    gamma = 1.0 / cand_soap[0].shape[1]

    re = AverageKernelMultiDevice(metric=kernel_metric, gamma=gamma, gpu_id=gpu_id)
    return re.create(cand_soap, ref_soap)

In [597]:
def compare_and_update_structures(ref_structures, cand_structures, njobs=4, gpu=1, batch_size=50, species=["H", "C", "O", "N"], r_cut=10.0, n_max=6, l_max=4, threshold=0.9, logger=None):
    round_num = 0
    logger.info(f"njobs: {njobs}, gpu: {gpu}, batch_size: {batch_size}")
    logger.info(f"species: {species}, r_cut: {r_cut}, n_max: {n_max}, l_max: {l_max}, threshold: {threshold}")

    while True:
        round_num += 1

        if round_num == 1:

            # 初次计算，先计算全部的 SOAP 描述符
            soap_ref = compute_soap_descriptors(ref_structures, njobs, species, r_cut, n_max, l_max, logger)
            soap_cand = compute_soap_descriptors(cand_structures, njobs, species, r_cut, n_max, l_max, logger)

            # 如果 soap_ref 为空，则将 soap_cand 的第一个元素添加到 soap_ref 中
            if soap_ref == []:
                ref_structures.append(cand_structures[0])
                soap_ref.append(soap_cand[0])
                logger.info("Ref structure is empty, add the first Cand structure to Ref structure")

            start_time = time.time()

            # 支持 GPU 计算
            # 目前仅支持单 GPU 计算
            if gpu:
                # 分 batch_size 批次计算 cand 中每个结构与 ref 中所有结构的相似度
                re_kernel_results = Parallel(n_jobs=gpu)(
                    delayed(compute_similarity_pytorch)(soap_cand[i:i+batch_size], soap_ref, gpu_id=(i//batch_size)%gpu) 
                    for i in range(0, len(soap_cand), batch_size))

                # 分别计算 cand 与 ref 中结构的自我相似度，用于正则化最终的相似度结果到 [0, 1]
                soap_cand_self = Parallel(n_jobs=gpu)(
                    delayed(compute_similarity_pytorch)(soap_cand[i:i+batch_size], gpu_id=(i//batch_size)%gpu) 
                    for i in range(0, len(soap_cand), batch_size))

                soap_ref_self = Parallel(n_jobs=gpu)(
                    delayed(compute_similarity_pytorch)(soap_ref[i:i+batch_size], gpu_id=(i//batch_size)%gpu) 
                    for i in range(0, len(soap_ref), batch_size))

                # 全部数据一起计算会内存不足
                # re_kernel = compute_similarity_pytorch(soap_cand, soap_ref)
                # soap_cand_self = compute_similarity_pytorch(soap_cand)
                # soap_ref_self = compute_similarity_pytorch(soap_ref)
                
                # soap_cand_self = Parallel(n_jobs=njobs)(delayed(compute_similarity_pytorch)(soap_cand[i:i+cand_batch_size]) for i in range(0, len(soap_cand), cand_batch_size))
                # soap_ref_self = Parallel(n_jobs=njobs)(delayed(compute_similarity_pytorch)(soap_ref[i:i+cand_batch_size]) for i in range(0, len(soap_ref), cand_batch_size))
                
                # 此处正则化需要根据 batch_size 修改
                # re_kernel /= torch.outer(soap_cand_self, soap_ref_self)
                # re_kernel = re_kernel.cpu()

                # 依旧在 GPU 上计算
                # nor
                # re_kernel_normalized_results

                # 非并行便于调试
                # re_kernel_results = []
                # for i in range(len(soap_cand)):
                #     # print(i)
                #     result = compute_similarity_pytorch(soap_cand[i:i+1], soap_ref)
                #     re_kernel_results.append(result)
                    
                # re_kernel = np.vstack(re_kernel_results.cpu())

            # 支持多核 CPU 并行运算
            else:
                re_kernel_results = Parallel(n_jobs=njobs)(
                    delayed(compute_similarity_pytorch)(soap_cand[i:i+1], soap_ref) 
                    for i in range(len(soap_cand)))

                soap_cand_self = Parallel(n_jobs=njobs)(
                    delayed(compute_similarity_pytorch)(soap_cand[i:i+1]) 
                    for i in range(len(soap_cand)))

                soap_ref_self = Parallel(n_jobs=njobs)(
                    delayed(compute_similarity_pytorch)(soap_ref[i:i+1]) 
                    for i in range(len(soap_ref)))

            # re_kernel = np.vstack(re_kernel_results.cpu().numpy())
            # soap_cand_self = np.vstack(soap_cand_self.cpu().numpy())
            # soap_ref_self = np.vstack(soap_ref_self.cpu().numpy())
            # re_kernel /= np.outer(soap_cand_self, soap_ref_self)
                                      
            # 合并批次/并行计算的结果
            re_kernel = torch.cat(re_kernel_results, dim=0)
            soap_cand_self = torch.cat(soap_cand_self, dim=0)
            soap_ref_self = torch.cat(soap_ref_self, dim=0)

            # 正则化相似度矩阵
            re_kernel /= torch.outer(soap_cand_self, soap_ref_self)

            # 将相似度矩阵移动到 CPU 以避免 I/O 造成的计算速度瓶颈
            re_kernel = re_kernel.cpu()

            end_time = time.time()
            logger.info(f"Round {round_num}: Similarity computation and self-similarity completed in {end_time - start_time:.2f} seconds")

            print('Round:', round_num)
            print('re_kernel shape:', re_kernel.shape)
            print('re_kernel:', re_kernel[:10])

            # 选取 cand_structures 中每个结构与 ref_structures 中所有结构的最大相似度
            max_similarity_values, _ = torch.max(re_kernel, dim=1)

            # max_similarity_values = np.max(re_kernel, axis=1)
            # print(max_similarity_values.shape)

        else:
            start_time = time.time()

            if gpu:
                re_kernel_results = Parallel(n_jobs=gpu)(
                    delayed(compute_similarity_pytorch)(soap_cand[i:i+batch_size], [soap_ref[-1]], gpu_id=(i//batch_size)%gpu) 
                    for i in range(0, len(soap_cand), batch_size))
            else:
                re_kernel_results = Parallel(n_jobs=njobs)(
                    delayed(compute_similarity_pytorch)(soap_cand[i:i+1], [soap_ref[-1]]) 
                    for i in range(len(soap_cand)))
                                
            re_kernel = torch.cat(re_kernel_results, dim=0)
            # re_kernel = np.vstack(re_kernel_results.cpu().numpy())

            # 内存不足
            # re_kernel = compute_similarity_pytorch(soap_cand, soap_ref)
            # print('sxmi')
            # print(soap_cand_self.shape, soap_ref_self.shape)

            # 此处正则化需要根据 batch_size 修改
            re_kernel /= torch.outer(soap_cand_self, soap_ref_self[-1].unsqueeze(0))
            # re_kernel /= torch.outer(soap_cand_self, soap_ref_self)
            # re_kernel = re_kernel.cpu()
            # re_kernel /= np.outer(soap_cand_self, soap_ref_self)

            # re_kernel = np.vstack(re_kernel_results.cpu())
            end_time = time.time()
            logger.info(f"Round {round_num}: Similarity computation completed in {end_time - start_time:.2f} seconds")

            print('Round:', round_num)
            print('re_kernel shape:', re_kernel.shape)
            print('re_kernel:', re_kernel[:10])

            # max_similarity_values = np.max(np.column_stack((max_similarity_values, np.max(re_kernel, axis=1))), axis=1)
            # print(max_similarity_values.shape)
            # print(torch.max(re_kernel, dim=1)[0].view(-1, 1).shape)
            # print(max_similarity_values.view(-1, 1).shape)
            # print(torch.cat((max_similarity_values.view(-1, 1), torch.max(re_kernel, dim=1)[0].view(-1, 1)), dim=1).shape)
            
            # 将原先 max_similarity_values 与新加入的 ref 计算的 max_similarity_value 合并
            max_similarity_values, _ = torch.max(
                                            torch.cat((max_similarity_values.view(-1, 1), 
                                                       torch.max(re_kernel, dim=1)[0].view(-1, 1)), 
                                                       dim=1), 
                                            dim=1)
            
        # 删除 cand_structures 中与 ref_structures 中相似度高于 threshold 的所有结构
        # 更新 soap_cand, cand_structures, max_similarity_values, soap_cand_self
        old_cand_num = len(cand_structures)

        preserve_condition = max_similarity_values < threshold
        soap_cand = list(compress(soap_cand, preserve_condition)) # itertools.compress() 更高效
        cand_structures = list(compress(cand_structures, preserve_condition))
        max_similarity_values = max_similarity_values[preserve_condition] # 布尔索引更高效
        soap_cand_self = soap_cand_self[preserve_condition]

        new_cand_num = len(cand_structures)
        logger.info(f"Round {round_num}: Cand structures reduced from {old_cand_num} to {new_cand_num}")

        # 如果 cand_structures 中没有元素，说明筛选完毕，退出循环
        if new_cand_num == 0:
            break
        
        # 将 cand_structures 与 ref_structures 中最不相似的结构添加到 ref_structures 中
        # 包含结构，SOAP 描述符，以及用于正则化的自我相似度

        # min_max_similarity = np.min(max_similarity_values).round(5)
        min_max_similarity = torch.min(max_similarity_values).item()

        # min_max_similarity_index = np.argmin(max_similarity_values)
        min_max_similarity_index = torch.argmin(max_similarity_values).item()
        
        ref_structures.append(cand_structures[min_max_similarity_index])
        soap_ref.append(soap_cand[min_max_similarity_index])

        # print(soap_ref_self)
        # print(soap_cand_self[min_max_similarity_index].unsqueeze(0))
        soap_ref_self = torch.cat((soap_ref_self, soap_cand_self[min_max_similarity_index].unsqueeze(0)))
        # soap_ref_self = np.vstak((soap_ref_self, soap_cand_self[min_max_similarity_index]))
        

        logger.info(f"Round {round_num}: Added structure with min max similarity {min_max_similarity:.5f}.")
        logger.info(f"Ref structures: {len(ref_structures)}, Cand structures: {len(cand_structures)}")
        logger.info("---------")

    logger.info("No structures remaining in candidate list.")
    logger.info(f"Ref structures: {len(ref_structures)}, Cand structures: {len(cand_structures)}")
    logger.info("---------")

    return ref_structures, soap_ref

In [574]:
# 主程序
def main(ref_file, cand_file, njobs, gpu, batch_size, r_cut, n_max, l_max, threshold):
    total_logger = setup_total_logging()
    total_logger.info('Total Log begin')
    start_time = time.time()

    # 读取数据
    if ref_file == '':
        ref_structures = []
    else:
        ref_structures = read(ref_file, index=':')
    cand_structures = read(cand_file, index=':')

    # 根据 chemical_formula 分组
    ref_dict = defaultdict(list)
    cand_dict = defaultdict(list)

    for structure in ref_structures:
        formula = structure.get_chemical_formula()
        ref_dict[formula].append(structure)

    for structure in cand_structures:
        formula = structure.get_chemical_formula()
        cand_dict[formula].append(structure)

    formula_num = len(cand_dict.keys())
    total_logger.info(f"There are {formula_num} formulas to process.")

    # 确认 species 中所含有的元素类型
    species = set()
    for key in cand_dict:
        species.update(cand_dict[key][0].get_chemical_symbols())
    for key in ref_dict:
        try:
            species.update(cand_dict[key][0].get_chemical_symbols())
        except:
            continue
    species = list(species)
    total_logger.info(f"Species: {species}")
    total_logger.info("---------")

    for i, formula in enumerate(cand_dict.keys()):
        # 如果 ref_dict 中没有该组，则会返回空列表，程序可以正常运行
        total_logger.info(f"Processing formula {i+1:>}/{formula_num:>}: {formula}")
        total_logger.info(f"Start Ref structures: {len(ref_dict[formula])}, Cand structures: {len(cand_dict[formula])}")

        logger = setup_logging(formula)  
        formula_start_time = time.time()
        logger.info('Log begin')
        logger.info(f"Processing formula: {formula}")

        # 由于后续相似度是经过正则化到 [0, 1] 的，因此不必对每一个反应自动匹配 species，直接取所有元素的并集即可
        # species=list(set(cand_dict[formula][0].get_chemical_symbols()))
        updated_structures, updated_soap_list = compare_and_update_structures(ref_dict[formula], 
                                                                              cand_dict[formula], 
                                                                              njobs=njobs,
                                                                              gpu=gpu,
                                                                              batch_size=batch_size,                                                                                  
                                                                              species=species,
                                                                              r_cut=r_cut,
                                                                              n_max=n_max,
                                                                              l_max=l_max,
                                                                              threshold=threshold,
                                                                              logger=logger)
        
        # 逐一保存更新后的参考结构
        write(os.path.join(formula, f"updated_ref_structures_{formula}.xyz"), updated_structures)
        logger.info(f"Updated reference structures saved to '{formula}/updated_ref_structures_{formula}.xyz'")
        
        # 保存更新后的结构和 SOAP 到 HDF5
        soap_dict = defaultdict(list)
        for i in range(len(updated_soap_list)):
            soap_result = updated_soap_list[i]
            soap_dict[formula].append(soap_result)
        save_soap_to_hdf5(soap_dict, os.path.join(formula, f"updated_ref_soap_descriptors_{formula}.h5"))
        logger.info(f"SOAP descriptors saved to '{formula}/updated_ref_soap_descriptors_{formula}.h5'")

        formula_end_time = time.time()
        logger.info(f"Done! Total time elapsed: {formula_end_time - formula_start_time:.2f} seconds")
        logger.info('Log end')

        total_logger.info(f"End Ref structures: {len(updated_structures)}")
        total_logger.info(f"Done! Total time elapsed: {formula_end_time - formula_start_time:.2f} seconds")
        total_logger.info("---------")

    end_time = time.time()
    total_logger.info('All reactions processed successfully!')
    total_logger.info(f"Total processing time: {end_time - start_time:.2f} seconds")
    total_logger.info('Total Log end')

In [600]:
ref = ''
cand = 'rxn0000.xyz'
njobs = 8
gpu = 0
batch_size = 50
r_cut = 10.0
n_max = 6
l_max = 4
threshold = 0.95
main(ref, cand, njobs, gpu, batch_size, r_cut, n_max, l_max, threshold)

Round: 1
re_kernel shape: torch.Size([100, 1])
re_kernel: tensor([[1.0000],
        [0.9604],
        [0.9210],
        [0.8936],
        [0.8580],
        [0.8394],
        [0.8287],
        [0.8183],
        [0.8113],
        [0.8064]])
Round: 2
re_kernel shape: torch.Size([86, 1])
re_kernel: tensor([[0.8458],
        [0.8745],
        [0.8965],
        [0.9338],
        [0.9460],
        [0.9602],
        [0.9817],
        [1.0000],
        [0.8437],
        [0.9001]])
Round: 3
re_kernel shape: torch.Size([73, 1])
re_kernel: tensor([[0.9103],
        [0.9322],
        [0.9301],
        [0.9105],
        [0.8951],
        [0.9616],
        [0.9552],
        [0.9723],
        [0.9167],
        [0.9296]])
Round: 4
re_kernel shape: torch.Size([63, 1])
re_kernel: tensor([[0.9624],
        [0.9588],
        [0.9145],
        [0.8894],
        [0.8773],
        [0.9138],
        [0.9587],
        [0.9606],
        [1.0000],
        [0.9220]])
Round: 5
re_kernel shape: torch.Size([41, 1])
r

In [125]:
x = torch.rand([2, 9, 1500])
y = torch.rand([100, 9, 1500])

In [126]:
x = x.unsqueeze(0).unsqueeze(2) 
y = y.unsqueeze(1).unsqueeze(3) 

In [127]:
x.shape, y.shape

(torch.Size([1, 2, 1, 9, 1500]), torch.Size([100, 1, 9, 1, 1500]))

In [130]:
(x-y).shape

torch.Size([100, 2, 9, 9, 1500])

In [407]:
a = torch.rand([1, ])
b = torch.rand([1, ])
print(a, b)
torch.cat([a, b], dim=0)

tensor([0.5924]) tensor([0.3228])


tensor([0.5924, 0.3228])

In [486]:
555%1

0

In [491]:
if 0:
    print(1)