In [None]:
import numpy as np
from ase import Atoms
from ase.io import read
from scipy.spatial import KDTree, cKDTree
import matplotlib.pyplot as plt
from collections import defaultdict

class SurfaceSiteFinder:
    """
    Locate and classify adsorption sites on a slab by "shrink-wrapping" a probe grid.

    Typical workflow::

        finder = SurfaceSiteFinder(slab)
        finder.create_grid(grid_spacing=0.1, height_above_surface=3.0)
        finder.wrap_grid_to_surface(contact_distance=2, step_size=0.1, height_above_surface=3.0)
        site_atoms, site_positions, special_vectors = finder.find_sites(contact_distance=2)
        site_types = finder.classify_sites(multi_site_threshold=2)

    Site keys are sorted tuples of indices into ``self.all_positions`` (original atoms
    first, then their periodic images); ``self.original_indices`` maps such an index
    back to the atom index in ``atoms``.
    """

    def __init__(self, atoms: "Atoms", surface_direction: int = 2):
        """
        Initialise the surface-site finder.

        Parameters:
        atoms: ASE Atoms object describing the slab (only ``get_positions``,
            ``get_cell`` and ``get_pbc`` are used).
        surface_direction: axis of the surface normal (0=x, 1=y, 2=z; any other
            value is treated as z).
        """
        self.atoms = atoms
        self.surface_direction = surface_direction
        self.grid_points = None       # planar probe grid built by create_grid
        self.wrapped_points = None    # probe grid after wrap_grid_to_surface
        self.site_atoms = defaultdict(list)  # site key -> probe points touching exactly those atoms
        self.site_positions = {}             # site key -> representative site position
        self.site_special_vectors = {}       # site key -> unit bridge axis (None for other sites)
        self.site_vectors = {}               # not filled anywhere in this class
        self.site_types = {}                 # filled by classify_sites
        self.cell = self.atoms.get_cell()
        self.pbc = self.atoms.get_pbc()
        self._generate_replicas()

    def _normal_axis(self) -> int:
        """Return the surface-normal axis index (values other than 0 or 1 mean z)."""
        if self.surface_direction == 0:
            return 0
        if self.surface_direction == 1:
            return 1
        return 2

    def _generate_replicas(self):
        """
        Build ``all_positions`` / ``original_indices``, including periodic images.

        Along every periodic axis the neighbouring images (-1, 0, +1) are added, which
        is enough for the short-range neighbour searches done here.  The original atoms
        always come first, so indices ``0 .. len(atoms)-1`` refer to the real atoms.
        """
        replicas = [[-1, 0, 1] if periodic else [0] for periodic in self.pbc]
        # Keep the historical offset ordering so image indices (and site keys) are unchanged.
        replica_offsets = np.array(np.meshgrid(*replicas)).T.reshape(-1, 3)

        original_positions = np.asarray(self.atoms.get_positions(), dtype=float).reshape(-1, 3)
        cell = np.asarray(self.cell, dtype=float)  # explicit ndarray instead of relying on Cell.__array__
        blocks = [original_positions]
        for offset in replica_offsets:
            if np.all(offset == 0):  # the un-shifted copy is already blocks[0]
                continue
            blocks.append(original_positions + offset @ cell)

        self.all_positions = np.vstack(blocks)
        self.original_indices = np.tile(np.arange(len(original_positions)), len(blocks))

    def create_grid(self, grid_spacing: float = 0.1, height_above_surface: float = 5.0, margin: float = 2.0):
        """
        Create a dense planar grid of probe points above the slab.

        Parameters:
        grid_spacing: spacing between grid points (Å).
        height_above_surface: height of the grid above the top-most atom (Å).
        margin: extra in-plane extent added on every side of the atoms (Å).

        Returns the (M, 3) grid (also stored in ``self.grid_points``).
        """
        positions = self.atoms.get_positions()
        normal = self._normal_axis()
        in_plane = [axis for axis in range(3) if axis != normal]
        surface_coords = positions[:, in_plane]
        max_height = np.max(positions[:, normal])

        lower = np.min(surface_coords, axis=0) - margin
        upper = np.max(surface_coords, axis=0) + margin
        axis_u = np.arange(lower[0], upper[0], grid_spacing)
        axis_v = np.arange(lower[1], upper[1], grid_spacing)
        uu, vv = np.meshgrid(axis_u, axis_v)

        grid = np.empty((uu.size, 3))
        grid[:, in_plane[0]] = uu.ravel()
        grid[:, in_plane[1]] = vv.ravel()
        grid[:, normal] = max_height + height_above_surface
        self.grid_points = grid
        return self.grid_points

    def wrap_grid_to_surface(self, contact_distance: float = 2.0, step_size: float = 0.1, height_above_surface=3.0):
        """
        Lower the grid points along -normal until they touch the slab.

        Each point moves in steps of ``step_size`` until its nearest atom (periodic
        images included) is within ``contact_distance``; a point that has touched stops.
        At most ``int(height_above_surface / step_size) + 10`` steps are taken, so
        ``height_above_surface`` should match the value given to ``create_grid``;
        points that never make contact simply stop after the last step.

        Parameters:
        contact_distance: contact threshold (Å).
        step_size: translation per step (Å).
        height_above_surface: grid height used to bound the number of steps (Å).

        Returns the wrapped points (also stored in ``self.wrapped_points``).
        """
        if self.grid_points is None:
            raise ValueError("请先创建网格点")

        tree = cKDTree(self.all_positions)
        wrapped_points = self.grid_points.copy()
        step = np.zeros(3)
        step[self._normal_axis()] = -step_size

        # A point in contact never moves again (so it stays in contact); only the
        # remaining "active" points need to be queried and moved.
        active = np.arange(len(wrapped_points))
        max_steps = int(height_above_surface / step_size) + 10
        for _ in range(max_steps):
            if active.size == 0:
                break
            distances, _ = tree.query(wrapped_points[active])
            active = active[distances > contact_distance]
            if active.size == 0:
                break
            wrapped_points[active] += step

        self.wrapped_points = wrapped_points
        return wrapped_points

    def find_sites(self, contact_distance: float = 2.0, multi_site_threshold: float = 2):
        """
        Group wrapped probe points by the set of atoms they touch and derive site positions.

        Parameters:
        contact_distance: radius (Å) within which an atom counts as touched.
        multi_site_threshold: unused here; kept for API compatibility (see classify_sites).

        Returns ``(site_atoms, site_positions, site_special_vectors)``: top sites sit on
        their atom, other sites at the centroid of their atoms; bridge sites also get
        the unit vector from their first to their last atom, all others ``None``.
        """
        if self.wrapped_points is None:
            raise ValueError("请先执行网格包裹")

        atom_positions = self.all_positions
        tree = cKDTree(atom_positions)

        # Reset in place (existing references stay valid); previously every call
        # appended the same probe points again.
        self.site_atoms.clear()
        self.site_positions.clear()
        self.site_special_vectors.clear()

        neighbour_lists = tree.query_ball_point(self.wrapped_points, contact_distance)
        for point, indices in zip(self.wrapped_points, neighbour_lists):
            if indices:
                self.site_atoms[tuple(sorted(indices))].append(point)

        for atom_indices in self.site_atoms:
            if len(atom_indices) == 1:
                # Top site: the atom itself.
                self.site_positions[atom_indices] = atom_positions[atom_indices[0]]
                self.site_special_vectors[atom_indices] = None
                continue
            members = atom_positions[list(atom_indices)]
            self.site_positions[atom_indices] = np.mean(members, axis=0)
            if len(atom_indices) == 2:
                # Bridge site: also record the unit vector along the bridge.
                axis_vec = members[-1] - members[0]
                self.site_special_vectors[atom_indices] = axis_vec / np.linalg.norm(axis_vec)
            else:
                # Multifold / complex site: centroid only.
                self.site_special_vectors[atom_indices] = None

        return self.site_atoms, self.site_positions, self.site_special_vectors

    def classify_sites(self, multi_site_threshold: float = 2):
        """
        Classify every site found by ``find_sites``.

        Parameters:
        multi_site_threshold: for sites with 3+ atoms, the largest atom-to-centroid
            distance (Å) below which the site is "<n>th_multifold" instead of "complex".

        Returns a dict site key -> "top" | "bridge" | "<n>th_multifold" | "complex"
        (also stored in ``self.site_types``).
        """
        site_types = {}
        for atom_indices in self.site_atoms:
            if len(atom_indices) == 1:
                site_types[atom_indices] = "top"
            elif len(atom_indices) == 2:
                site_types[atom_indices] = "bridge"
            else:
                members = self.all_positions[list(atom_indices)]
                centroid = np.mean(members, axis=0)
                max_distance = np.max(np.linalg.norm(members - centroid, axis=1))
                if max_distance < multi_site_threshold:
                    site_types[atom_indices] = f"{len(atom_indices)}th_multifold"
                else:
                    site_types[atom_indices] = "complex"
        self.site_types = site_types
        return site_types

    def visualize(self, show_grid: bool = False, show_wrapped: bool = True):
        """
        Plot the atoms, optionally the probe grids, and the identified sites in 3-D.

        Parameters:
        show_grid: also plot the initial planar grid.
        show_wrapped: also plot the wrapped grid.

        Sites are classified once with ``classify_sites()`` defaults.  Markers:
        top = red circle, bridge = orange square, multifold = cyan diamond,
        anything else = grey cross.
        """
        fig = plt.figure(figsize=(10, 8))
        ax = fig.add_subplot(111, projection='3d')

        positions = self.atoms.get_positions()
        ax.scatter(positions[:, 0], positions[:, 1], positions[:, 2],
                   c='blue', s=150, label='atom')

        if show_grid and self.grid_points is not None:
            ax.scatter(self.grid_points[:, 0], self.grid_points[:, 1], self.grid_points[:, 2],
                       c='gray', s=5, alpha=0.3, label='initial grid')
        if show_wrapped and self.wrapped_points is not None:
            ax.scatter(self.wrapped_points[:, 0], self.wrapped_points[:, 1], self.wrapped_points[:, 2],
                       c='lightgreen', s=5, alpha=0.5, label='wrapped grid')

        # Classify once instead of re-classifying every site inside the loop (was O(n^2)).
        site_types = self.classify_sites()
        for atom_indices, position in self.site_positions.items():
            site_type = site_types.get(atom_indices, "未知")
            if site_type == "top":
                color, marker, size = 'red', 'o', 50
            elif site_type == "bridge":
                color, marker, size = 'orange', 's', 40
            elif "multifold" in site_type:
                color, marker, size = 'cyan', 'D', 30
            else:
                color, marker, size = 'gray', 'x', 20
            ax.scatter(position[0], position[1], position[2],
                       c=color, marker=marker, s=size, label=site_type, alpha=1)

        ax.set_xlabel('X (Å)')
        ax.set_ylabel('Y (Å)')
        ax.set_zlabel('Z (Å)')
        ax.set_title('sites of surface')

        # One legend entry per label (all sites of a type share the same label).
        handles, labels = ax.get_legend_handles_labels()
        by_label = dict(zip(labels, handles))
        ax.legend(by_label.values(), by_label.keys())

        plt.show()

# Usage example: build a Ru(0001) slab, locate and classify its adsorption sites,
# print them and plot them.
# NOTE(review): everything bound below (finder, positions, special_vector, site_types, ...)
# is a module/notebook global; the next notebook cell reads `positions`.
if __name__ == "__main__":
    from ase.build import hcp0001, add_adsorbate
    # Build the surface: 4x4x4 Ru hcp(0001) slab with 10 Å of vacuum
    slab = hcp0001('Ru', size=(4, 4, 4), vacuum=10.0)
    
    # Initialise the site finder
    finder = SurfaceSiteFinder(slab)
    # Probe grid 3 Å above the top atom, 0.1 Å spacing
    grid_points = finder.create_grid(grid_spacing=0.1, height_above_surface=3.0)
    
    # Shrink-wrap the grid onto the surface (same 3 Å height bounds the number of steps)
    wrapped_points = finder.wrap_grid_to_surface(contact_distance=2, step_size=0.1,height_above_surface=3.0)

    # Find sites: site key (tuple of atom/image indices) -> probe points / position / bridge vector
    sites, positions, special_vector = finder.find_sites(contact_distance=2)

    # Classify sites (top / bridge / <n>th_multifold / complex)
    site_types = finder.classify_sites(multi_site_threshold=2)

    # Print the results
    print("识别到的位点:")
    for atom_indices, site_type in site_types.items():
            print(f"{site_type}: Atoms {atom_indices}, Position {positions[atom_indices]}, vector {special_vector[atom_indices]}")

    # Plot sites only (both probe grids hidden)
    finder.visualize(show_grid=False, show_wrapped=False)
    
    # Example 2: read the structure from a file
    # (assuming a POSCAR file exists)
    # atoms = read('POSCAR')
    # finder = SurfaceSiteFinder(atoms)
    # ... the remaining steps are the same

In [None]:
from scipy.spatial import KDTree
import numpy as np

def find_nearest_point_kdtree(points, target_point):
    """
    Find the point of ``points`` closest to ``target_point`` using a KD-tree.

    ``target_point`` must be an ndarray (it is reshaped to a single query row).
    Returns ``(nearest_point, distance, index)`` where ``distance`` and ``index`` are
    the length-1 arrays produced by the single-row ``KDTree.query`` call, so callers
    index them with ``[0]``.
    """
    query_row = target_point.reshape(1, -1)
    dist_arr, idx_arr = KDTree(points).query(query_row)
    return points[idx_arr[0]], dist_arr, idx_arr

# Example: find which identified site lies closest to a given Cartesian position.
# `positions` is the site-position dict (site key -> xyz) bound by the previous cell.
key = list(positions)
points = np.array([positions[site] for site in key])
target = np.array([2.7181167669610438, 1.5688513598637766, 17.9147179067864251])

nearest, distance, idx = find_nearest_point_kdtree(points, target)
print(f"最近的点: {nearest}")
print(f"距离: {distance}")
print(f"索引: {key[idx[0]]}")

In [None]:
from scipy.spatial import KDTree
import numpy as np
from rdkit import Chem
from rdkit.Chem import rdmolops
from build_ISFS_model.CheckNN import *
from ase.io import write
import copy
import os
import re
from rdkit import Chem
'''查询吸附位点'''
class SurfaceSiteFinder:
    """
    Locate and classify adsorption sites on a slab by "shrink-wrapping" a probe grid.

    Typical workflow::

        finder = SurfaceSiteFinder(slab)
        finder.create_grid(grid_spacing=0.1, height_above_surface=3.0)
        finder.wrap_grid_to_surface(contact_distance=2, step_size=0.1, height_above_surface=3.0)
        site_atoms, site_positions, special_vectors = finder.find_sites(contact_distance=2)
        site_types = finder.classify_sites(multi_site_threshold=2)

    Site keys are sorted tuples of indices into ``self.all_positions`` (original atoms
    first, then their periodic images); ``self.original_indices`` maps such an index
    back to the atom index in ``atoms``.
    """

    def __init__(self, atoms: "Atoms", surface_direction: int = 2):
        """
        Initialise the surface-site finder.

        Parameters:
        atoms: ASE Atoms object describing the slab (only ``get_positions``,
            ``get_cell`` and ``get_pbc`` are used).
        surface_direction: axis of the surface normal (0=x, 1=y, 2=z; any other
            value is treated as z).
        """
        self.atoms = atoms
        self.surface_direction = surface_direction
        self.grid_points = None       # planar probe grid built by create_grid
        self.wrapped_points = None    # probe grid after wrap_grid_to_surface
        self.site_atoms = defaultdict(list)  # site key -> probe points touching exactly those atoms
        self.site_positions = {}             # site key -> representative site position
        self.site_special_vectors = {}       # site key -> unit bridge axis (None for other sites)
        self.site_vectors = {}               # not filled anywhere in this class
        self.site_types = {}                 # filled by classify_sites
        self.cell = self.atoms.get_cell()
        self.pbc = self.atoms.get_pbc()
        self._generate_replicas()

    def _normal_axis(self) -> int:
        """Return the surface-normal axis index (values other than 0 or 1 mean z)."""
        if self.surface_direction == 0:
            return 0
        if self.surface_direction == 1:
            return 1
        return 2

    def _generate_replicas(self):
        """
        Build ``all_positions`` / ``original_indices``, including periodic images.

        Along every periodic axis the neighbouring images (-1, 0, +1) are added, which
        is enough for the short-range neighbour searches done here.  The original atoms
        always come first, so indices ``0 .. len(atoms)-1`` refer to the real atoms.
        """
        replicas = [[-1, 0, 1] if periodic else [0] for periodic in self.pbc]
        # Keep the historical offset ordering so image indices (and site keys) are unchanged.
        replica_offsets = np.array(np.meshgrid(*replicas)).T.reshape(-1, 3)

        original_positions = np.asarray(self.atoms.get_positions(), dtype=float).reshape(-1, 3)
        cell = np.asarray(self.cell, dtype=float)  # explicit ndarray instead of relying on Cell.__array__
        blocks = [original_positions]
        for offset in replica_offsets:
            if np.all(offset == 0):  # the un-shifted copy is already blocks[0]
                continue
            blocks.append(original_positions + offset @ cell)

        self.all_positions = np.vstack(blocks)
        self.original_indices = np.tile(np.arange(len(original_positions)), len(blocks))

    def create_grid(self, grid_spacing: float = 0.1, height_above_surface: float = 5.0, margin: float = 2.0):
        """
        Create a dense planar grid of probe points above the slab.

        Parameters:
        grid_spacing: spacing between grid points (Å).
        height_above_surface: height of the grid above the top-most atom (Å).
        margin: extra in-plane extent added on every side of the atoms (Å).

        Returns the (M, 3) grid (also stored in ``self.grid_points``).
        """
        positions = self.atoms.get_positions()
        normal = self._normal_axis()
        in_plane = [axis for axis in range(3) if axis != normal]
        surface_coords = positions[:, in_plane]
        max_height = np.max(positions[:, normal])

        lower = np.min(surface_coords, axis=0) - margin
        upper = np.max(surface_coords, axis=0) + margin
        axis_u = np.arange(lower[0], upper[0], grid_spacing)
        axis_v = np.arange(lower[1], upper[1], grid_spacing)
        uu, vv = np.meshgrid(axis_u, axis_v)

        grid = np.empty((uu.size, 3))
        grid[:, in_plane[0]] = uu.ravel()
        grid[:, in_plane[1]] = vv.ravel()
        grid[:, normal] = max_height + height_above_surface
        self.grid_points = grid
        return self.grid_points

    def wrap_grid_to_surface(self, contact_distance: float = 2.0, step_size: float = 0.1, height_above_surface=3.0):
        """
        Lower the grid points along -normal until they touch the slab.

        Each point moves in steps of ``step_size`` until its nearest atom (periodic
        images included) is within ``contact_distance``; a point that has touched stops.
        At most ``int(height_above_surface / step_size) + 10`` steps are taken, so
        ``height_above_surface`` should match the value given to ``create_grid``;
        points that never make contact simply stop after the last step.

        Parameters:
        contact_distance: contact threshold (Å).
        step_size: translation per step (Å).
        height_above_surface: grid height used to bound the number of steps (Å).

        Returns the wrapped points (also stored in ``self.wrapped_points``).
        """
        if self.grid_points is None:
            raise ValueError("请先创建网格点")

        tree = cKDTree(self.all_positions)
        wrapped_points = self.grid_points.copy()
        step = np.zeros(3)
        step[self._normal_axis()] = -step_size

        # A point in contact never moves again (so it stays in contact); only the
        # remaining "active" points need to be queried and moved.
        active = np.arange(len(wrapped_points))
        max_steps = int(height_above_surface / step_size) + 10
        for _ in range(max_steps):
            if active.size == 0:
                break
            distances, _ = tree.query(wrapped_points[active])
            active = active[distances > contact_distance]
            if active.size == 0:
                break
            wrapped_points[active] += step

        self.wrapped_points = wrapped_points
        return wrapped_points

    def find_sites(self, contact_distance: float = 2.0, multi_site_threshold: float = 2):
        """
        Group wrapped probe points by the set of atoms they touch and derive site positions.

        Parameters:
        contact_distance: radius (Å) within which an atom counts as touched.
        multi_site_threshold: unused here; kept for API compatibility (see classify_sites).

        Returns ``(site_atoms, site_positions, site_special_vectors)``: top sites sit on
        their atom, other sites at the centroid of their atoms; bridge sites also get
        the unit vector from their first to their last atom, all others ``None``.
        """
        if self.wrapped_points is None:
            raise ValueError("请先执行网格包裹")

        atom_positions = self.all_positions
        tree = cKDTree(atom_positions)

        # Reset in place (existing references stay valid); previously every call -
        # e.g. one per structure via find_site - appended the same probe points again.
        self.site_atoms.clear()
        self.site_positions.clear()
        self.site_special_vectors.clear()

        neighbour_lists = tree.query_ball_point(self.wrapped_points, contact_distance)
        for point, indices in zip(self.wrapped_points, neighbour_lists):
            if indices:
                self.site_atoms[tuple(sorted(indices))].append(point)

        for atom_indices in self.site_atoms:
            if len(atom_indices) == 1:
                # Top site: the atom itself.
                self.site_positions[atom_indices] = atom_positions[atom_indices[0]]
                self.site_special_vectors[atom_indices] = None
                continue
            members = atom_positions[list(atom_indices)]
            self.site_positions[atom_indices] = np.mean(members, axis=0)
            if len(atom_indices) == 2:
                # Bridge site: also record the unit vector along the bridge.
                axis_vec = members[-1] - members[0]
                self.site_special_vectors[atom_indices] = axis_vec / np.linalg.norm(axis_vec)
            else:
                # Multifold / complex site: centroid only.
                self.site_special_vectors[atom_indices] = None

        return self.site_atoms, self.site_positions, self.site_special_vectors

    def classify_sites(self, multi_site_threshold: float = 2):
        """
        Classify every site found by ``find_sites``.

        Parameters:
        multi_site_threshold: for sites with 3+ atoms, the largest atom-to-centroid
            distance (Å) below which the site is "<n>th_multifold" instead of "complex".

        Returns a dict site key -> "top" | "bridge" | "<n>th_multifold" | "complex"
        (also stored in ``self.site_types``).
        """
        site_types = {}
        for atom_indices in self.site_atoms:
            if len(atom_indices) == 1:
                site_types[atom_indices] = "top"
            elif len(atom_indices) == 2:
                site_types[atom_indices] = "bridge"
            else:
                members = self.all_positions[list(atom_indices)]
                centroid = np.mean(members, axis=0)
                max_distance = np.max(np.linalg.norm(members - centroid, axis=1))
                if max_distance < multi_site_threshold:
                    site_types[atom_indices] = f"{len(atom_indices)}th_multifold"
                else:
                    site_types[atom_indices] = "complex"
        self.site_types = site_types
        return site_types

    def visualize(self, show_grid: bool = False, show_wrapped: bool = True):
        """
        Plot the atoms, optionally the probe grids, and the identified sites in 3-D.

        Parameters:
        show_grid: also plot the initial planar grid.
        show_wrapped: also plot the wrapped grid.

        Sites are classified once with ``classify_sites()`` defaults.  Markers:
        top = red circle, bridge = orange square, multifold = cyan diamond,
        anything else = grey cross.
        """
        fig = plt.figure(figsize=(10, 8))
        ax = fig.add_subplot(111, projection='3d')

        positions = self.atoms.get_positions()
        ax.scatter(positions[:, 0], positions[:, 1], positions[:, 2],
                   c='blue', s=150, label='atom')

        if show_grid and self.grid_points is not None:
            ax.scatter(self.grid_points[:, 0], self.grid_points[:, 1], self.grid_points[:, 2],
                       c='gray', s=5, alpha=0.3, label='initial grid')
        if show_wrapped and self.wrapped_points is not None:
            ax.scatter(self.wrapped_points[:, 0], self.wrapped_points[:, 1], self.wrapped_points[:, 2],
                       c='lightgreen', s=5, alpha=0.5, label='wrapped grid')

        # Classify once instead of re-classifying every site inside the loop (was O(n^2)).
        site_types = self.classify_sites()
        for atom_indices, position in self.site_positions.items():
            site_type = site_types.get(atom_indices, "未知")
            if site_type == "top":
                color, marker, size = 'red', 'o', 50
            elif site_type == "bridge":
                color, marker, size = 'orange', 's', 40
            elif "multifold" in site_type:
                color, marker, size = 'cyan', 'D', 30
            else:
                color, marker, size = 'gray', 'x', 20
            ax.scatter(position[0], position[1], position[2],
                       c=color, marker=marker, s=size, label=site_type, alpha=1)

        ax.set_xlabel('X (Å)')
        ax.set_ylabel('Y (Å)')
        ax.set_zlabel('Z (Å)')
        ax.set_title('sites of surface')

        # One legend entry per label (all sites of a type share the same label).
        handles, labels = ax.get_legend_handles_labels()
        by_label = dict(zip(labels, handles))
        ax.legend(by_label.values(), by_label.keys())

        plt.show()
def find_nearest_point_kdtree(points, target_point):
    """
    Find the point of ``points`` closest to ``target_point`` using a KD-tree.

    ``target_point`` must be an ndarray (it is reshaped to a single query row).
    Returns ``(nearest_point, distance, index)`` with ``distance`` and ``index``
    unwrapped to scalars (this variant, unlike the earlier cell's, drops the
    length-1 array wrapper).
    """
    tree = KDTree(points)
    dists, idxs = tree.query(target_point.reshape(1, -1))
    best_dist = dists[0]
    best_idx = idxs[0]
    return points[best_idx], best_dist, best_idx
class DistanceQuery:
    """KD-tree index over a fixed point set, for distance-shell queries."""

    def __init__(self, points):
        """
        Build the cKDTree index.

        Parameters:
        points: array-like of point coordinates (one row per point).
        """
        self.points = np.array(points)
        self.tree = cKDTree(self.points)

    def find_points_at_distance(self, query_point, target_distance, tolerance=0.1):
        """
        Return indices of points whose distance d to ``query_point`` satisfies
        ``abs(d - target_distance) < tolerance`` (strict inequality).

        The KD-tree first collects every point within the outer radius
        ``target_distance + tolerance``; the exact shell test is then applied to those
        candidates only.  Indices are returned as a list of ints, in the order
        produced by ``query_ball_point``.
        """
        outer_radius = target_distance + tolerance
        candidates = np.asarray(self.tree.query_ball_point(query_point, outer_radius), dtype=int)
        if candidates.size == 0:
            return []
        dists = np.linalg.norm(self.points[candidates] - np.asarray(query_point), axis=1)
        return candidates[np.abs(dists - target_distance) < tolerance].tolist()

    def find_all_distances(self, query_point):
        """Return the distance of every stored point to ``query_point`` (for further analysis)."""
        return np.linalg.norm(self.points - np.asarray(query_point), axis=1)
def find_site(atoms, adsatom: list, finder, contact_distance: float = 2, multi_site_threshold: float = 2):
    """
    Assign every adsorbed atom to its nearest surface site.

    Parameters:
    atoms: structure with a ``positions`` array (e.g. ASE Atoms).
    adsatom: adsorbed-atom records; each needs an ``id`` indexing ``atoms.positions``.
    finder: SurfaceSiteFinder whose grid has already been wrapped; its ``find_sites``
        and ``classify_sites`` are (re)run here.
    contact_distance, multi_site_threshold: forwarded to ``find_sites`` /
        ``classify_sites`` (the defaults are the previously hard-coded values).

    Returns a list with one entry per adsorbed atom:
    ``[site_position, distance, atom_id, site_key, site_type, atom_position - site_position]``.

    Raises ValueError if ``adsatom`` is non-empty but the finder found no sites.
    """
    _, positions, _ = finder.find_sites(contact_distance=contact_distance)
    site_types = finder.classify_sites(multi_site_threshold=multi_site_threshold)
    adsatom = list(adsatom)
    out = []
    if not adsatom:
        return out

    site_keys = list(positions)
    if not site_keys:
        raise ValueError("find_site: the finder returned no surface sites")
    site_coords = np.array([positions[key] for key in site_keys])
    # One tree for all adsorbed atoms instead of rebuilding it per atom.
    tree = KDTree(site_coords)

    for adsA in adsatom:
        tp = atoms.positions[adsA.id]
        # A 1-D query returns a scalar distance and a scalar index; indexing the
        # index with [0] (as before) raised "invalid index to scalar variable".
        distance, idx = tree.query(tp)
        nearest = site_coords[idx]
        atom_indices = site_keys[idx]
        out.append([nearest, distance, adsA.id, atom_indices, site_types[atom_indices], tp - nearest])
    return out
def svd_rotation_matrix(a, b):
    """
    Build a proper rotation matrix that maps the direction of ``a`` onto that of ``b``.

    Both vectors are normalised first, so any non-zero length works.  The rotation
    comes from an SVD of the outer product ``a b^T`` (a one-point Kabsch fit); if
    that yields a reflection (det < 0), the last right-singular vector is flipped.
    The twist about ``b`` is not constrained; it is whatever the SVD produces.

    Parameters:
    a, b: 3-component vectors (numpy arrays).

    Returns:
    R: 3x3 rotation matrix with ``R @ a_hat == b_hat``.
    """
    src = a / np.linalg.norm(a)
    dst = b / np.linalg.norm(b)

    u_mat, _, vt_mat = np.linalg.svd(np.outer(src, dst))
    rot = vt_mat.T @ u_mat.T

    if np.linalg.det(rot) < 0:
        # Reflection: flip the last singular direction to get a proper rotation.
        vt_mat[2, :] *= -1
        rot = vt_mat.T @ u_mat.T
    return rot
def find_min_sum_distance_point_vectorized(points, point_a, point_b):
    """
    Pick the point whose summed distance to ``point_a`` and ``point_b`` is smallest.

    Parameters:
    points: (n, 3) numpy array of candidate points.
    point_a, point_b: the two reference points.

    Returns:
    (min_point, min_distance, min_index) - the winning point, its distance sum and
    its row index; ties resolve to the lowest index (``np.argmin``).
    """
    distance_sums = sum(np.linalg.norm(points - anchor, axis=1) for anchor in (point_a, point_b))
    best = np.argmin(distance_sums)
    return points[best], distance_sums[best], best
def select_site_with_max_dist(result_idxlist, base_mol, points, centeridx):
    """
    Among candidate sites, pick the one with the largest summed distance to a molecule.

    Parameters:
    result_idxlist: candidate indices into ``points``.
    base_mol: iterable of atoms exposing ``.index`` and ``.position`` (e.g. ASE Atoms).
    points: indexable collection of site coordinates.
    centeridx: ``.index`` of the atom of ``base_mol`` to leave out of the sum.

    Returns:
    (index, position) of the candidate maximising the sum of distances to every atom
    of ``base_mol`` except ``centeridx``.  Ties, and the case where no atom
    contributes, resolve to the first candidate.

    Raises:
    ValueError if ``result_idxlist`` is empty.
    """
    idxlist = list(result_idxlist)
    if not idxlist:
        raise ValueError("select_site_with_max_dist: result_idxlist is empty")
    sitepositions = np.array([points[idx] for idx in idxlist], dtype=float)

    total_dists = np.zeros(len(idxlist))
    for atom in base_mol:
        if atom.index != centeridx:
            total_dists += np.linalg.norm(sitepositions - atom.position, axis=1)

    best = int(np.argmax(total_dists))
    # Index the list (the old code *called* it: idxlist(max_index) -> TypeError).
    return idxlist[best], sitepositions[best]
'''查询成键反应原子对'''
def str2list(reaction: str):
    """
    Split a reaction string on '>' and each resulting part on whitespace.

    >>> str2list("CO + H > Add H to C > CHO")
    [['CO', '+', 'H'], ['Add', 'H', 'to', 'C'], ['CHO']]
    """
    return [part.split() for part in reaction.split(">")]
def checkbond(reaction:list,a1,a2,a3):
    mol1 = a1.bms.mol
    mol2 = a2.bms.mol
    mol3 = a3.bms.mol
    reactiontype = reaction[1][0]
    clean_list_reaction =[]
    for id in range(len(reaction[1])):
        if bool(re.search(r'\d', reaction[1][id])) == False:
            clean_list_reaction.append(reaction[1][id])
    addatom = clean_list_reaction[1]#reaction[1][-3]
    bondedatom = clean_list_reaction[-1]#reaction[1][-1]
    def addATOM():
        if reactiontype == 'Add':
            if len(addatom) > 1 and 'C' in addatom:
                return 'C'
            elif addatom =='OH':
                return 'O'
            else:
                return addatom
        else:
            if addatom == 'O/OH':
                return 'O'
            else:
                return addatom
    aset = {addATOM,bondedatom}
    adsAtomsIna1 = a1.ads
    indices_to_mol = [atom.index for atom in a1.poscar if atom.symbol != 'Ru']
    adsAtomsIna2 = a2.ads
    for adsA1 in adsAtomsIna1:
        for adsA2 in adsAtomsIna2:
            qset = {adsA1.elesymbol,adsA2.elesymbol}
            if qset == aset:
                mol_broken = rdmolops.CombineMols(mol1,mol2)
                rwmol = Chem.RWMol(mol_broken)
                ida1 = adsA1.id-64
                ida2 = adsA2.id-64+mol1.GetNumAtoms()
                rwmol.AddBond(ida1, ida2, Chem.BondType.SINGLE)
                if Chem.MolToSmiles(rwmol) == Chem.MolToSmiles(mol3):
                    return adsA1.id,adsA2.id,'ad2ad'
                else:pass
    for A1idx in indices_to_mol:
        A1 = a1.atoms[A1idx]
        for adsA2 in adsAtomsIna2:
            qset = {A1.elesymbol,adsA2.elesymbol}
            if qset == aset:
                mol_broken = rdmolops.CombineMols(mol1,mol2)
                rwmol = Chem.RWMol(mol_broken)
                ida1 = A1.id-64
                ida2 = adsA2.id-64+mol1.GetNumAtoms()
                rwmol.AddBond(ida1, ida2, Chem.BondType.SINGLE)
                if Chem.MolToSmiles(rwmol) == Chem.MolToSmiles(mol3):
                    return A1.id,adsA2.id,'Nad2ad'
                else:pass
    return False, False,False
class NN_system():
    """
    Bond analysis (checkBonds + BuildMol2Smiles) of one structure, plus the surface
    site assigned to each of its adsorbed atoms.
    """

    def __init__(self):
        self.cb = None        # checkBonds analysis object
        self.bms = None       # BuildMol2Smiles built from self.cb
        self.ads = None       # cb.adsorption: adsorbed-atom records
        self.only_mol = None  # sub-structure made of all non-Ru atoms
        self.ads_data = None  # find_site(...) output, one entry per adsorbed atom
        self.atoms = None     # cb.poscar: the full structure

    def RunCheckNN_FindSite(self, file, finder):
        """
        Analyse one structure and assign its adsorbed atoms to surface sites.

        Parameters:
        file: a path (str or os.PathLike) handed to ``checkBonds.input``, or an
            already-loaded structure assigned directly to ``checkBonds.poscar``.
        finder: site finder passed to ``find_site`` (a SurfaceSiteFinder whose grid
            has been wrapped).

        Returns self, so the call can be chained.
        """
        cb = checkBonds()
        if isinstance(file, (str, os.PathLike)):
            # pathlib.Path objects used to be stored as the structure itself.
            cb.input(os.fspath(file))
        else:
            cb.poscar = file
        cb.AddAtoms()
        cb.CheckAllBonds()
        bms = BuildMol2Smiles(cb)
        bms.build()
        self.cb = cb
        self.bms = bms
        self.ads = cb.adsorption
        atoms = cb.poscar
        self.atoms = atoms
        indices_to_mol = [atom.index for atom in atoms if atom.symbol != 'Ru']
        self.only_mol = atoms[indices_to_mol]
        self.ads_data = find_site(cb.poscar, cb.adsorption, finder)
        return self
        
"""
1.复数吸附位点吸附的中间体在催化剂表面难以发生迁移
2.复数吸附的中间体基元反应前后整体质心移动幅度小
#3.化学键的形成与断裂发生在吸附原子之间,至少有一个吸附原子
"""

class model1():
    """Build an initial-state model for a reaction a1 + a2 -> a3 on a slab.

    a1 is the bigger reactant, a2 the smaller one, and a3 the (bigger) product:
    a1(big) + a2(small) = a3(bigger).
    Call :meth:`site_finder` with the bare slab before :meth:`run`.
    """

    def __init__(self, atoms1, atoms2, atoms3):
        # Structures (paths or Atoms, as accepted by NN_system) of a1, a2, a3.
        self.atoms = (atoms1, atoms2, atoms3)

    def site_finder(self, slab):
        """Locate and classify the adsorption sites of ``slab``; return the finder."""
        finder = SurfaceSiteFinder(slab)
        # Build a grid above the surface, then wrap it down onto the surface.
        finder.create_grid(grid_spacing=0.1, height_above_surface=3.0)
        finder.wrap_grid_to_surface(contact_distance=2, step_size=0.1, height_above_surface=3.0)
        # Find the sites and classify them (top / bridge / other).
        _, positions, special_vector = finder.find_sites(contact_distance=2)
        site_types = finder.classify_sites(multi_site_threshold=2)
        self.site = finder
        self.site_types = site_types
        self.site_positions = positions
        self.special_vectors = special_vector
        return self.site

    @staticmethod
    def _transform(positions, anchor, target, R=None):
        """Rotate ``positions`` by ``R`` about ``anchor``, then move ``anchor`` onto ``target``.

        With ``R=None`` this is a pure translation by ``target - anchor``.
        """
        positions = np.asarray(positions, dtype=float)
        anchor = np.asarray(anchor, dtype=float)
        target = np.asarray(target, dtype=float)
        if R is None:
            return positions + (target - anchor)
        # Row-vector form of R @ (p - anchor) + target.
        return (positions - anchor) @ np.asarray(R, dtype=float).T + target

    def run(self, reaction: str):
        """Place reactant a2 next to a1 on matching sites and return the combined structure.

        Returns the combined a1 + a2 Atoms, or the string 'the Reaction wrong'
        when ``checkbond`` cannot identify the reacting atom pair.  As a side
        effect, for single-site products the adsorbate atoms of a3's structure
        are moved in place onto the site between the two reacting atoms.
        """
        if not hasattr(self, 'site'):
            raise AttributeError("site_finder(slab) must be called before run()")
        self.r = str2list(reaction)
        atoms1, atoms2, atoms3 = self.atoms
        a1, a2, a3 = NN_system(), NN_system(), NN_system()
        a1.RunCheckNN_FindSite(atoms1, self.site)
        a2.RunCheckNN_FindSite(atoms2, self.site)
        a3.RunCheckNN_FindSite(atoms3, self.site)

        # Split the classified sites by type; anything not top/bridge is hollow.
        top, bridge, hcc = {}, {}, {}
        for atom_indices, site_type in self.site_types.items():
            if site_type == 'top':
                top[atom_indices] = self.site_positions[atom_indices]
            elif site_type == 'bridge':
                bridge[atom_indices] = self.site_positions[atom_indices]
            else:
                hcc[atom_indices] = self.site_positions[atom_indices]
        topsitepl = list(top.values())
        bridgesitepl = list(bridge.values())
        bridgesitekl = list(bridge.keys())
        hccsitepl = list(hcc.values())

        def sites_of(site_type):
            """Site positions of the given type (non top/bridge -> hollow)."""
            if site_type == 'top':
                return topsitepl
            if site_type == 'bridge':
                return bridgesitepl
            return hccsitepl

        def warp(a1, a2, a3):
            o1, o2, _ = checkbond(self.r, a1, a2, a3)  # atom ids in the full structures
            if o1 is False or o2 is False:
                return 'the Reaction wrong'
            base_mol = a1.atoms
            total_atoms = len(base_mol)
            mola2 = a2.only_mol.copy()

            # ads_data entries: [nearest, distance, adsA.id, atom_indices, site_type, vector]
            a2_ads_data0 = a2.ads_data[0]
            a2st = a2_ads_data0[-2]
            a2ai = a2_ads_data0[-3]
            a2sp = self.site_positions[a2ai]

            # Candidate sites of a2's type roughly 3 +/- 1 Å from a1's reacting atom.
            sitepl = sites_of(a2st)
            bap_a1 = base_mol[o1].position
            DQ = DistanceQuery(sitepl)
            result_idxlist = DQ.find_points_at_distance(query_point=bap_a1, target_distance=3, tolerance=1)
            _, sp4a2 = select_site_with_max_dist(result_idxlist, base_mol, sitepl, o1)

            R = None
            if a2st == 'bridge':
                # Align a2's bridge vector with the target bridge vector.
                # NOTE(review): uses the first candidate's vector, which may not
                # be the site chosen by select_site_with_max_dist — confirm.
                sai = bridgesitekl[result_idxlist[0]]
                R = svd_rotation_matrix(self.special_vectors[a2ai], self.special_vectors[sai])
            # Rotate about a2's own site, then move that site onto the chosen one.
            mola2.positions = self._transform(mola2.positions, a2sp, sp4a2, R)
            a1a2sys = base_mol + mola2

            a3_ads_data = a3.ads_data
            if len(a3_ads_data) == 1:
                nearest, distance, adsAid, atom_indices, site_type, vector = a3_ads_data[0]
                BeginAtom_p = a1a2sys[o1].position
                # a2's atoms follow a1's in a1a2sys; only_mol assumes 64 leading slab atoms.
                EndAtom_p = a1a2sys[o2 - 64 + total_atoms].position
                R = None
                if site_type == 'bridge':
                    min_point, _, min_index = find_min_sum_distance_point_vectorized(
                        bridgesitepl, BeginAtom_p, EndAtom_p)
                    R = svd_rotation_matrix(self.special_vectors[atom_indices],
                                            self.special_vectors[bridgesitekl[min_index]])
                else:
                    min_point, _, _ = find_min_sum_distance_point_vectorized(
                        sites_of(site_type), BeginAtom_p, EndAtom_p)
                a3sys = a3.atoms
                mol_idx = [atom.index for atom in a3sys if atom.symbol != 'Ru']
                a3sys.positions[mol_idx] = self._transform(
                    a3sys.positions[mol_idx], nearest, min_point, R)
            elif len(a3_ads_data) > 1:
                # TODO: products adsorbed on several sites are not handled yet.
                pass
            return a1a2sys

        print(a1.bms.smiles, len(a1.ads_data), a2.bms.smiles, len(a2.ads_data))
        # The singly-adsorbed reactant is the one that gets moved.
        if len(a2.ads_data) == 1:
            IS = warp(a1, a2, a3)
        else:
            IS = warp(a2, a1, a3)
        return IS

In [None]:
from scipy.spatial import KDTree

class DistanceQuery:
    """KD-tree backed lookup of points lying (approximately) at a given distance."""

    def __init__(self, points):
        """Build the KD-tree index over ``points`` (a sequence of coordinates)."""
        self.points = np.asarray(points, dtype=float)
        # cKDTree cannot index an empty array; remember that case instead.
        self.tree = cKDTree(self.points) if len(self.points) else None

    def find_points_at_distance(self, query_point, target_distance, tolerance=1e-9):
        """Return the sorted indices i with |dist(points[i], query_point) - target_distance| < tolerance.

        A ball query with radius ``target_distance + tolerance`` prunes the
        candidates; the shell condition is then applied exactly.
        """
        if self.tree is None:
            return []
        query_point = np.asarray(query_point, dtype=float)
        candidates = np.asarray(
            self.tree.query_ball_point(query_point, target_distance + tolerance), dtype=int)
        if candidates.size == 0:
            return []
        dists = np.linalg.norm(self.points[candidates] - query_point, axis=1)
        keep = np.abs(dists - target_distance) < tolerance
        # Sorted so callers that take the first hit get a deterministic result.
        return sorted(int(i) for i in candidates[keep])

    def find_all_distances(self, query_point):
        """Return the distance from every stored point to ``query_point``."""
        if self.tree is None:
            return np.empty(0)
        return np.linalg.norm(self.points - np.asarray(query_point, dtype=float), axis=1)

# Usage example: four points on a line; find those exactly 2 Å from the origin.
demo_points = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [2.0, 0.0, 0.0], [3.0, 0.0, 0.0]])
demo_query_point = np.zeros(3)
demo_target_distance = 2.0
query_system = DistanceQuery(demo_points)
indices = query_system.find_points_at_distance(demo_query_point, demo_target_distance)

In [None]:
from rdkit import Chem
from rdkit.Chem import rdmolops

# Build two simple molecules.
mol1 = Chem.AddHs(Chem.MolFromSmiles('CCO'))  # ethanol with explicit H: 9 atoms (3 heavy + 6 H)
mol2 = Chem.MolFromSmiles('O')    # water with implicit H only: 1 atom

# Stamp each atom's index into its atom-map number and show the SMILES.
def mark_and_show(mol, description):
    """Label every atom of ``mol`` (in place) with map number index + 1, print its SMILES, return it.

    Map numbers are 1-based because RDKit treats map number 0 as "unmapped".
    """
    for atom in mol.GetAtoms():
        atom.SetAtomMapNum(atom.GetIdx() + 1)
    print(f"{description}: {Chem.MolToSmiles(mol)}")
    return mol

mol1_marked = mark_and_show(mol1, "原始mol1")
mol2_marked = mark_and_show(mol2, "原始mol2")

# Combine the molecules: mol2's atoms are appended after mol1's, so mol2 atom k
# gets index k + mol1.GetNumAtoms() in the combined molecule.
combined_mol = rdmolops.CombineMols(mol1, mol2)

# Re-label and show the combined molecule.
combined_marked = mark_and_show(combined_mol, "合并后的分子")

# Check: list the index and symbol of every atom in the combined molecule.
print("\n合并分子原子详情:")
print(type(mol1.GetNumAtoms()))
for atom in combined_mol.GetAtoms():
    print(f"原子索引: {atom.GetIdx()}, 原子符号: {atom.GetSymbol()}")

In [None]:
import numpy as np

def svd_rotation_matrix(a, b):
    """
    Return the proper rotation matrix R that rotates direction ``a`` onto ``b``.

    R is the minimal rotation about the axis a x b.  It is obtained with the
    Kabsch/SVD method applied to the two complete orthonormal frames
    {a, n, a x n} -> {b, n, b x n}.  A rank-1 fit of a -> b alone would leave
    the rotation of all directions perpendicular to ``a`` arbitrary, which can
    tilt a molecule out of the surface plane.  When a and b are (anti)parallel,
    the axis is taken perpendicular to ``a``, preferring the one closest to z.

    Args:
        a, b: non-zero 3-vectors (need not be normalised).

    Returns:
        R: 3x3 rotation matrix (det +1) with R @ a_hat == b_hat.

    Raises:
        ValueError: if ``a`` or ``b`` is the zero vector.
    """
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    norm_a, norm_b = np.linalg.norm(a), np.linalg.norm(b)
    if norm_a == 0 or norm_b == 0:
        raise ValueError("a and b must be non-zero vectors")
    a = a / norm_a
    b = b / norm_b

    axis = np.cross(a, b)
    axis_norm = np.linalg.norm(axis)
    if axis_norm > 1e-8:
        axis = axis / axis_norm
    else:
        # (Anti)parallel: pick any axis perpendicular to a, z first.
        for ref in (np.array([0.0, 0.0, 1.0]), np.array([1.0, 0.0, 0.0])):
            axis = ref - np.dot(ref, a) * a
            axis_norm = np.linalg.norm(axis)
            if axis_norm > 1e-8:
                axis = axis / axis_norm
                break

    src = np.array([a, axis, np.cross(a, axis)])
    dst = np.array([b, axis, np.cross(b, axis)])
    # Cross-covariance of the paired frame vectors: sum_i p_i q_i^T.
    H = src.T @ dst
    U, _, Vt = np.linalg.svd(H)
    R = Vt.T @ U.T

    # Guard against a reflection from numerical noise.
    if np.linalg.det(R) < 0:
        Vt[2, :] *= -1
        R = Vt.T @ U.T

    return R

# Example: rotate the y axis onto the x axis.
a = np.array([0, 1, 0])
b = np.array([1, 0, 0])
R = svd_rotation_matrix(a, b)
print("旋转矩阵 R:")
print(R)

# Verify: R must map a onto b, so the error should be ~0.
result = np.dot(R, a)
print(f"R * a = {result}")
print(f"b = {b}")
print(f"误差: {np.linalg.norm(result - b)}")

In [None]:
from ase.io import read as aseread

def std_angle(x):
    """Shift angles of 180 degrees or more down by 360; smaller angles are returned unchanged."""
    if x >= 180:
        return x - 360
    return x

def genCartesianString(xyz_file, order=(1, 0, 3, 2, 4, 5), exclude_symbol='Ru'):
    """Return a Z-matrix (internal-coordinate) string for the adsorbate in ``xyz_file``.

    Despite the name, the output is a Z-matrix: each row gives the symbol and
    then 1-based reference atoms with distance, bond angle and dihedral.

    Args:
        xyz_file: any structure file readable by ``ase.io.read``.
        order: reordering applied to the remaining atoms (the default matches
            the original 6-atom adsorbate); pass None to keep the file order.
        exclude_symbol: element removed before building the Z-matrix (slab).
    """
    atoms = aseread(xyz_file)
    atoms = atoms[[a.index for a in atoms if a.symbol != exclude_symbol]]
    if order is not None:
        atoms = atoms[list(order)]
    symbols = list(atoms.symbols)
    rows = []
    for idx, sym in enumerate(symbols):
        if idx == 0:
            rows.append('%3s\n' % sym)
        elif idx == 1:
            rows.append('%3s %3d %20.8f\n' % (sym, 1, atoms.get_distance(0, idx)))
        elif idx == 2:
            # Bond angles from get_angle lie in [0, 180]; do not wrap them
            # (wrapping would turn a linear 180 degree angle into -180).
            rows.append('%3s %3d %20.8f %3d %20.8f\n' % (sym, 2, atoms.get_distance(1, idx), 1,
                                                         atoms.get_angle(0, 1, idx)))
        else:
            # Only dihedrals are wrapped into the (-180, 180] style range.
            rows.append('%3s %3d %20.8f %3d %20.8f %3d %20.8f\n' % (
                sym, idx, atoms.get_distance(idx, idx - 1), idx - 1,
                atoms.get_angle(idx, idx - 1, idx - 2),
                idx - 2, std_angle(atoms.get_dihedral(idx - 3, idx - 2, idx - 1, idx))))
    return ''.join(rows)

# Example: print the Z-matrices of the optimised initial and final states.
print(genCartesianString('IS_opt.vasp'))
print(genCartesianString('FS_opt.vasp'))

In [None]:
from rdkit import Chem

# One Mol object holding two disconnected molecules ('.' separates components).
# Cyclohexane must use aliphatic 'C': "c1CCCCC1" marks an aromatic atom in a
# saturated ring, fails sanitisation, and makes MolFromSmiles return None.
combined_smiles = "c1ccccc1.C1CCCCC1"  # benzene + cyclohexane
mol = Chem.MolFromSmiles(combined_smiles)

In [None]:
if mol is None:
    raise ValueError("SMILES parsing failed; check combined_smiles")

# Empty list that GetMolFrags fills with the atom indices of each fragment.
frags_mol_atom_mapping = []

# Split into fragment molecules and record which atoms went where.
fragments_mols = Chem.GetMolFrags(
    mol,
    asMols=True,
    sanitizeFrags=True,
    frags=None,
    fragsMolAtomMapping=frags_mol_atom_mapping  # receives per-fragment atom indices
)

print(f"共获得 {len(fragments_mols)} 个片段")
print(f"原子映射信息: {frags_mol_atom_mapping}")

In [None]:
import build_ISFS_model.pre4SearchTS as pre4TS
# Prepare the NEB transition-state searches (the jobs are run in parallel).
model_root = 'model/'
Pre4TS = pre4TS.PREforSearchTS(model_root)
Pre4TS.readDataPath()


In [None]:
reactions_list_path = 'model/RN/reactionslist.txt'
Pre4TS.buildmodel(reactions_list_path)

[H]C([H])C([H])([H])O_[H]OC([H])([H])C([H])([H])O

In [12]:
import numpy as np
from ase.optimize import BFGS
from ase import Atoms
from ase.io import read,write
from nequip.ase import NequIPCalculator
class DistanceAwareOptimizer(BFGS):
    """
    BFGS optimizer that watches one atom pair and, when the pair is stretched
    beyond ``max_distance``, adds an attractive force between the two atoms.
    """
    def __init__(self, atoms, atom_indices, max_distance, force_scale=0.1,
                 trajectory=None, logfile=None, master=None):
        """
        Args:
            atoms: Atoms object to optimise.
            atom_indices: indices of the monitored pair, e.g. (0, 1).
            max_distance: largest allowed pair distance (Å).
            force_scale: scale factor of the extra attractive force.
            trajectory, logfile, master: forwarded to ``BFGS``.
        """
        # BFGS's positional order is (atoms, restart, logfile, trajectory, ...);
        # passing these positionally made `trajectory` the restart file and
        # `master` the trajectory, so pass them by keyword.
        kwargs = {'trajectory': trajectory, 'logfile': logfile}
        if master is not None:
            kwargs['master'] = master
        super().__init__(atoms, **kwargs)
        self.atom_indices = atom_indices
        self.max_distance = max_distance
        self.force_scale = force_scale

    def check_and_adjust_forces(self, forces):
        """Return a copy of ``forces`` with an extra pull on the pair if it is too far apart.

        The extra magnitude is force_scale * |F_i - F_j| * (distance - max_distance).
        """
        forces = np.array(forces, dtype=float, copy=True)
        i, j = self.atom_indices
        positions = self.atoms.positions

        distance = np.linalg.norm(positions[i] - positions[j])
        print(f"当前原子 {i} 和 {j} 之间的距离: {distance:.4f} Å")

        if distance > self.max_distance:
            print(f"距离超过阈值 {self.max_distance} Å，调整受力...")

            direction = positions[j] - positions[i]
            norm = np.linalg.norm(direction)
            if norm > 0:
                direction_unit = direction / norm
                current_force_magnitude = np.linalg.norm(forces[i] - forces[j])
                extra_force_magnitude = (self.force_scale * current_force_magnitude) * (distance - self.max_distance)

                # Pull i towards j and j towards i (equal and opposite).
                forces[i] += extra_force_magnitude * direction_unit
                forces[j] -= extra_force_magnitude * direction_unit

                print(f"已调整受力，额外力大小: {extra_force_magnitude:.6f} eV/Å")

        return forces

    def step(self, forces=None):
        """Take one BFGS step using the distance-adjusted forces."""
        if forces is None:
            forces = self.atoms.get_forces()
        adjusted_forces = self.check_and_adjust_forces(forces)
        super().step(adjusted_forces)

# ========== 使用示例 ==========
def example_usage(model_path='prototypeModel.pth', traj_file='000.traj',
                  atom_pair=(66, 67), max_distance=1.5, force_scale=2,
                  fmax=0.05, steps=1000):
    """
    Demo: relax the first frame of ``traj_file`` with a deployed NequIP model
    while restraining ``atom_pair`` to stay within ``max_distance``.

    Writes the starting structure to IS.vasp and the relaxed one to FS.vasp.
    All arguments default to the values of the original hard-coded demo.
    """
    calc = NequIPCalculator.from_deployed_model(model_path, device='cpu')

    # Load the starting structure (first trajectory frame) and save it.
    atoms = read(traj_file, index=0)
    write('IS.vasp', atoms, format='vasp', vasp5=True)
    atoms.calc = calc
    # Monitored pair; presumably the C and O of adsorbed CO in the default file.
    carbon_idx, oxygen_idx = atom_pair

    print("初始结构:")
    print(f"C 原子位置: {atoms.positions[carbon_idx]}")
    print(f"O 原子位置: {atoms.positions[oxygen_idx]}")

    initial_distance = np.linalg.norm(atoms.positions[carbon_idx] - atoms.positions[oxygen_idx])
    print(f"初始C-O距离: {initial_distance:.4f} Å")

    # Default cap 1.5 Å (free CO bond length is about 1.13 Å).
    opt = DistanceAwareOptimizer(
        atoms=atoms,
        atom_indices=(carbon_idx, oxygen_idx),
        max_distance=max_distance,
        force_scale=force_scale,
        logfile='-',
    )

    print("\n开始结构优化...")
    opt.run(fmax=fmax, steps=steps)

    print("\n优化后结构:")
    print(f"C 原子位置: {atoms.positions[carbon_idx]}")
    print(f"O 原子位置: {atoms.positions[oxygen_idx]}")
    final_distance = np.linalg.norm(atoms.positions[carbon_idx] - atoms.positions[oxygen_idx])
    print(f"最终C-O距离: {final_distance:.4f} Å")
    write('FS.vasp', atoms, format='vasp', vasp5=True)
# ========== 高级版本：监控多个原子对 ==========
class MultiDistanceAwareOptimizer(BFGS):
    """BFGS optimizer that adds attractive forces to several over-stretched atom pairs."""

    def __init__(self, atoms, distance_constraints, force_scale=0.1,
                 trajectory=None, logfile=None, master=None):
        """
        Args:
            distance_constraints: list of (atom_i, atom_j, max_distance) tuples.
            force_scale: scale factor of the extra attractive force.
            trajectory, logfile, master: forwarded to ``BFGS``.
        """
        # Pass by keyword: BFGS's second positional parameter is `restart`.
        kwargs = {'trajectory': trajectory, 'logfile': logfile}
        if master is not None:
            kwargs['master'] = master
        super().__init__(atoms, **kwargs)
        self.distance_constraints = distance_constraints
        self.force_scale = force_scale

    def check_and_adjust_forces(self, forces):
        """Return a copy of ``forces`` with extra pulls on every pair beyond its limit."""
        forces = np.array(forces, dtype=float, copy=True)
        positions = self.atoms.positions

        for i, j, max_dist in self.distance_constraints:
            distance = np.linalg.norm(positions[i] - positions[j])

            if distance > max_dist and distance > 0:
                direction_unit = (positions[j] - positions[i]) / distance

                current_force = np.linalg.norm(forces[i] - forces[j])
                extra_force = self.force_scale * current_force * (distance - max_dist)

                forces[i] += extra_force * direction_unit
                forces[j] -= extra_force * direction_unit

                print(f"调整原子对 ({i},{j}) 受力，距离: {distance:.4f} Å")

        return forces

    def step(self, forces=None):
        """Take one BFGS step using the distance-adjusted forces."""
        if forces is None:
            forces = self.atoms.get_forces()
        adjusted_forces = self.check_and_adjust_forces(forces)
        super().step(adjusted_forces)

if __name__ == "__main__":
    # Entry point: run the NequIP distance-restraint demo defined above.
    example_usage()

初始结构:
C 原子位置: [ 5.41323252  7.49622097 19.29009626]
O 原子位置: [ 5.4024266   7.71309066 17.90771699]
初始C-O距离: 1.3993 Å

开始结构优化...
                        Step     Time          Energy          fmax
DistanceAwareOptimizer:    0 15:48:41     -611.280701        1.530114
当前原子 66 和 67 之间的距离: 1.3993 Å
DistanceAwareOptimizer:    1 15:48:42     -611.377258        1.506197
当前原子 66 和 67 之间的距离: 1.4350 Å
DistanceAwareOptimizer:    2 15:48:42     -611.559753        1.710781
当前原子 66 和 67 之间的距离: 1.5048 Å
距离超过阈值 1.5 Å，调整受力...
已调整受力，额外力大小: 0.010841 eV/Å
DistanceAwareOptimizer:    3 15:48:43     -611.781250        2.065255
当前原子 66 和 67 之间的距离: 1.5790 Å
距离超过阈值 1.5 Å，调整受力...
已调整受力，额外力大小: 0.345704 eV/Å
DistanceAwareOptimizer:    4 15:48:43     -612.191467        2.651088
当前原子 66 和 67 之间的距离: 1.6724 Å
距离超过阈值 1.5 Å，调整受力...
已调整受力，额外力大小: 0.778545 eV/Å
DistanceAwareOptimizer:    5 15:48:43     -612.865417        3.422809
当前原子 66 和 67 之间的距离: 1.7713 Å
距离超过阈值 1.5 Å，调整受力...
已调整受力，额外力大小: 1.573673 eV/Å
DistanceAwareOptimi

In [None]:
import numpy as np
from ase import Atoms
from ase.optimize import BFGS, FIRE
from ase.io import read, write
from ase.calculators.emt import EMT
from typing import List, Tuple, Dict
import matplotlib.pyplot as plt

class DistanceConstrainedOptimizer:
    """
    Structure optimisation with soft distance constraints.

    Selected atom pairs are monitored; whenever a pair is farther apart than
    its threshold, a pulling force proportional to the excess is added to the
    forces handed to the ASE optimizer.
    """

    def __init__(self,
                 atoms: Atoms,
                 atom_pairs: List[Tuple[int, int]],
                 distance_thresholds: List[float],
                 constraint_strength: float = 0.5,
                 log_file: str = "distance_log.txt"):
        """
        Args:
            atoms: structure to optimise (modified in place).
            atom_pairs: atom-index pairs to monitor, e.g. [(0, 1), (1, 2)].
            distance_thresholds: one threshold (Å) per pair.
            constraint_strength: force per Å of excess distance (eV/Å²).
            log_file: path of the per-step distance log.
        """
        self.atoms = atoms
        self.atom_pairs = atom_pairs
        self.distance_thresholds = distance_thresholds
        self.constraint_strength = constraint_strength
        self.log_file = log_file

        assert len(atom_pairs) == len(distance_thresholds), \
            "原子对数量和阈值数量必须相同"

        self.distance_history = []  # pair distances per logged step
        self.forces_history = []    # unconstrained forces per step

    def calculate_distances(self) -> List[float]:
        """Return the current distance of every monitored pair."""
        positions = self.atoms.get_positions()
        return [np.linalg.norm(positions[a] - positions[b]) for a, b in self.atom_pairs]

    def apply_distance_constraints(self, forces: np.ndarray) -> np.ndarray:
        """
        Return a copy of ``forces`` where each over-stretched pair gets an
        equal-and-opposite pull of constraint_strength * (distance - threshold).
        """
        modified_forces = forces.copy()
        positions = self.atoms.get_positions()

        for (a, b), threshold in zip(self.atom_pairs, self.distance_thresholds):
            dist_vec = positions[b] - positions[a]
            distance = np.linalg.norm(dist_vec)
            if distance <= threshold:
                continue

            delta = distance - threshold
            direction = dist_vec / distance if distance > 0 else np.array([1, 0, 0])
            constraint_force = self.constraint_strength * delta

            # Pull a towards b and b towards a.
            modified_forces[a] += constraint_force * direction
            modified_forces[b] -= constraint_force * direction

            print(f"原子对 ({a},{b}): 距离={distance:.3f} Å > 阈值={threshold} Å")
            print(f"  施加约束力: {constraint_force:.3f} eV/Å")

        return modified_forces

    def log_distances(self, step: int):
        """Append the current distances to the history and the log file; return them."""
        distances = self.calculate_distances()
        self.distance_history.append(distances.copy())

        with open(self.log_file, 'a') as f:
            f.write(f"优化步数: {step}\n")
            for i, (a, b) in enumerate(self.atom_pairs):
                f.write(f"  原子对({a},{b}): {distances[i]:.3f} Å ")
                if distances[i] > self.distance_thresholds[i]:
                    f.write(f"(超过阈值 {self.distance_thresholds[i]} Å)\n")
                else:
                    f.write("\n")
            f.write("\n")

        return distances

    def optimize_with_constraints(self,
                                  optimizer_type: str = 'BFGS',
                                  fmax: float = 0.05,
                                  steps: int = 200,
                                  trajectory: str = 'optimization.traj',
                                  calculator=None):
        """
        Run the constrained optimisation.

        Args:
            optimizer_type: 'BFGS' or 'FIRE'.
            fmax: convergence threshold on the constrained forces (eV/Å).
            steps: maximum number of optimiser steps.
            trajectory: trajectory file name passed to the ASE optimiser.
            calculator: calculator to attach; defaults to EMT() as before.

        Returns:
            True if converged within ``steps``.

        Raises:
            ValueError: for an unsupported ``optimizer_type``.
        """
        print("="*60)
        print("开始带距离约束的结构优化")
        print("="*60)

        self.atoms.calc = EMT() if calculator is None else calculator

        kind = optimizer_type.upper()
        if kind == 'BFGS':
            opt = BFGS(self.atoms, trajectory=trajectory)
        elif kind == 'FIRE':
            opt = FIRE(self.atoms, trajectory=trajectory)
        else:
            raise ValueError(f"不支持的优化器类型: {optimizer_type}")

        with open(self.log_file, 'w') as f:
            f.write("距离约束优化日志\n")
            f.write("="*40 + "\n")

        step = 0
        converged = False

        while step < steps:
            energy = self.atoms.get_potential_energy()
            forces = self.atoms.get_forces()
            self.forces_history.append(forces.copy())
            self.log_distances(step)
            # Record the current frame (opt.step alone never calls observers).
            opt.call_observers()

            modified_forces = self.apply_distance_constraints(forces)

            max_force = np.max(np.linalg.norm(modified_forces, axis=1))
            converged = max_force < fmax

            print(f"步数 {step}: 能量 = {energy:.4f} eV, 最大受力 = {max_force:.4f} eV/Å")

            if converged:
                print("优化已收敛！")
                break

            # Hand the constrained forces to the optimiser; a bare opt.step()
            # would recompute the raw forces and silently drop the constraints.
            opt.step(modified_forces)
            step += 1

        print(f"\n优化完成，总步数: {step}")
        if converged:
            print(f"在受力阈值 fmax={fmax} eV/Å 下收敛")
        else:
            print(f"达到最大步数 {steps}")

        return converged

    def plot_distance_history(self, save_fig: bool = True):
        """Plot each pair's distance per step together with its threshold line."""
        if not self.distance_history:
            print("没有距离历史数据")
            return

        distance_array = np.array(self.distance_history)
        steps = np.arange(len(distance_array))

        plt.figure(figsize=(10, 6))

        for i, (a, b) in enumerate(self.atom_pairs):
            plt.plot(steps, distance_array[:, i],
                     marker='o', markersize=4, linewidth=2,
                     label=f'原子对 ({a},{b})')
            # Threshold line in the same colour as the pair's curve.
            plt.axhline(y=self.distance_thresholds[i],
                        color=plt.gca().lines[-1].get_color(),
                        linestyle='--', alpha=0.5,
                        label=f'阈值 ({a},{b}) = {self.distance_thresholds[i]} Å')

        plt.xlabel('优化步数', fontsize=12)
        plt.ylabel('距离 (Å)', fontsize=12)
        plt.title('优化过程中原子对距离变化', fontsize=14)
        plt.legend(fontsize=10)
        plt.grid(True, alpha=0.3)
        plt.tight_layout()

        if save_fig:
            plt.savefig('distance_evolution.png', dpi=300)
        plt.show()

    def plot_forces_history(self, atom_index: int = 0, save_fig: bool = True):
        """Plot the x/y/z force components of ``atom_index`` per step."""
        if not self.forces_history:
            print("没有受力历史数据")
            return

        forces_array = np.array(self.forces_history)
        steps = np.arange(len(forces_array))

        plt.figure(figsize=(12, 8))

        directions = ['X', 'Y', 'Z']
        for i in range(3):
            plt.subplot(3, 1, i+1)
            plt.plot(steps, forces_array[:, atom_index, i],
                     'b-', linewidth=2)
            plt.xlabel('优化步数', fontsize=10)
            plt.ylabel(f'F_{directions[i]} (eV/Å)', fontsize=10)
            plt.title(f'原子 {atom_index} 的 {directions[i]} 方向受力变化', fontsize=11)
            plt.grid(True, alpha=0.3)

        plt.tight_layout()
        if save_fig:
            plt.savefig(f'forces_atom_{atom_index}.png', dpi=300)
        plt.show()


# 使用示例
def example_usage():
    """
    Demo: optimise a stretched H4 cluster while three atom pairs are
    distance-constrained; print distances, plot histories, save the result.

    Returns the DistanceConstrainedOptimizer that was used.
    """
    # Stretched starting geometry: pairs (0,1) and (2,3) are 3.0 Å apart.
    start_coords = np.array([
        [0.0, 0.0, 0.0],
        [0.0, 0.0, 3.0],
        [0.0, 2.0, 0.0],
        [0.0, 2.0, 3.0],
    ])
    atoms = Atoms('H4', positions=start_coords)

    monitored = [(0, 1), (2, 3), (0, 2)]
    limits = [2.0, 2.0, 1.5]  # Å, one per monitored pair

    optimizer = DistanceConstrainedOptimizer(
        atoms=atoms,
        atom_pairs=monitored,
        distance_thresholds=limits,
        constraint_strength=0.3,
        log_file="optimization_log.txt"
    )

    print("优化前的原子对距离:")
    for (a, b), d in zip(monitored, optimizer.calculate_distances()):
        print(f"  原子对({a},{b}): {d:.3f} Å")

    converged = optimizer.optimize_with_constraints(
        optimizer_type='BFGS',
        fmax=0.05,
        steps=50,
        trajectory='constrained_opt.traj'
    )

    print("\n优化后的原子对距离:")
    for (a, b), d, limit in zip(monitored, optimizer.calculate_distances(), limits):
        status = "超过阈值" if d > limit else "在阈值内"
        print(f"  原子对({a},{b}): {d:.3f} Å ({status})")

    optimizer.plot_distance_history()
    optimizer.plot_forces_history(atom_index=0)

    write('optimized_structure.xyz', atoms)
    print("\n优化后的结构已保存到 'optimized_structure.xyz'")

    return optimizer


# 高级版本：自定义优化器类，完全集成到ASE框架中
class CustomBFGS(BFGS):
    """BFGS that pulls selected atom pairs together once they exceed a distance threshold."""

    def __init__(self, atoms, distance_constraints=None, **kwargs):
        """
        Args:
            atoms: Atoms object to optimise.
            distance_constraints: optional dict such as
                {'atom_pairs': [(0, 1), (1, 2)], 'thresholds': [2.0, 2.0], 'strength': 0.5};
                missing keys default to no pairs, no thresholds and strength 0.5.
            **kwargs: forwarded to ``BFGS``.
        """
        super().__init__(atoms, **kwargs)

        settings = {} if distance_constraints is None else distance_constraints
        self.atom_pairs = settings.get('atom_pairs', [])
        self.thresholds = settings.get('thresholds', [])
        self.strength = settings.get('strength', 0.5)
        self.distance_history = []

    def step(self, f=None):
        """Record pair distances, add the constraint pulls, then take a BFGS step."""
        coords = self.atoms.get_positions()
        separations = [np.linalg.norm(coords[a] - coords[b]) for a, b in self.atom_pairs]
        self.distance_history.append(separations)

        base_forces = self.atoms.get_forces() if f is None else f
        adjusted = base_forces.copy()

        n_limits = len(self.thresholds)
        for k, ((a, b), sep) in enumerate(zip(self.atom_pairs, separations)):
            # Pairs without a threshold, within their limit, or coincident are skipped.
            if k >= n_limits or sep <= self.thresholds[k] or sep <= 0:
                continue
            unit = (coords[b] - coords[a]) / sep
            pull = self.strength * (sep - self.thresholds[k])
            adjusted[a] += pull * unit
            adjusted[b] -= pull * unit

        return super().step(adjusted)


# 使用自定义优化器的示例
def advanced_example():
    """Demo: relax H2O with CustomBFGS under O-H distance limits and plot the distances."""
    from ase.build import molecule

    atoms = molecule('H2O')
    atoms.center(vacuum=5.0)

    # Attach via the property; Atoms.set_calculator is deprecated.
    atoms.calc = EMT()

    # Monitor both O-H bonds (O is atom 0).
    distance_constraints = {
        'atom_pairs': [(0, 1), (0, 2)],
        'thresholds': [1.2, 1.2],
        'strength': 0.4
    }

    print("使用自定义BFGS优化器进行优化...")
    opt = CustomBFGS(atoms,
                     distance_constraints=distance_constraints,
                     trajectory='custom_opt.traj')

    opt.run(fmax=0.05, steps=100)

    print("\n优化完成!")
    print(f"最终能量: {atoms.get_potential_energy():.4f} eV")

    distance_array = np.array(opt.distance_history)
    if distance_array.size == 0:
        # Already converged before the first step(): nothing to plot.
        return

    plt.figure(figsize=(10, 6))
    steps = np.arange(len(distance_array))

    for i, pair in enumerate(distance_constraints['atom_pairs']):
        plt.plot(steps, distance_array[:, i],
                 label=f'O-H{i+1} 距离', linewidth=2)
        plt.axhline(y=distance_constraints['thresholds'][i],
                    linestyle='--', alpha=0.5,
                    label=f'阈值 {distance_constraints["thresholds"][i]} Å')

    plt.xlabel('优化步数')
    plt.ylabel('距离 (Å)')
    plt.title('O-H键距离变化')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig('oh_distance_evolution.png', dpi=300)
    plt.show()


if __name__ == "__main__":
    # Run both distance-constraint demos in sequence.
    separator = "-" * 40

    print("示例1：基本距离约束优化")
    print(separator)
    optimizer = example_usage()

    print("\n\n示例2：使用自定义优化器")
    print(separator)
    advanced_example()

In [None]:
import numpy as np
from scipy.spatial.distance import cdist, pdist, squareform
from scipy.spatial.transform import Rotation as R
from scipy.optimize import linear_sum_assignment, minimize
import ase
from ase import Atoms
from typing import Tuple, List, Dict, Optional, Any
import itertools
from collections import defaultdict, Counter
import warnings

warnings.filterwarnings('ignore')


class EnhancedPointSetMatcher:
    """
    Enhanced 3-D point-set matcher that considers the overall structure as
    well as relative distances within and between element types.

    ``align_whole_structure`` pairs input points with reference points
    (Hungarian assignment on Euclidean distances, done per element when
    symbols are available) and removes the rigid rotation/translation with
    the Kabsch algorithm. ``compute_distance_similarity`` then scores two
    point sets whose rows are already paired index-by-index.
    """

    def __init__(self, reference_points: np.ndarray,
                 reference_symbols: Optional[List[str]] = None):
        """
        Store the reference point set.

        Parameters
        ----------
        reference_points : (N, 3) array of reference coordinates.
        reference_symbols : optional list of N element symbols, one per point.
            When given, per-element groups, the reference distance matrix and
            the reference distance features are precomputed.

        Raises
        ------
        ValueError
            If ``reference_symbols`` does not have exactly one entry per point.
        """
        self.reference_points = np.asarray(reference_points)
        self.n_points = len(reference_points)
        self.reference_symbols = reference_symbols

        if reference_symbols is not None:
            if len(reference_symbols) != self.n_points:
                raise ValueError("元素符号数量必须与点数量相同")
            self._setup_element_groups()
            # Full pairwise distance matrix of the reference structure.
            self.reference_distance_matrix = self._compute_distance_matrix(self.reference_points)
            # Intra-/inter-element distance features of the reference.
            self.reference_distance_features = self._compute_distance_features()

    @staticmethod
    def _unique_symbols(symbols: List[str]) -> List[str]:
        """
        Return the distinct symbols in sorted order.

        Sorting (instead of iterating a ``set``) makes every derived ordering
        and every key such as ``"Cl-Na"`` independent of Python's
        per-process string-hash randomisation.
        """
        return sorted(set(symbols))

    @staticmethod
    def _similarity_score(similarity_result: Dict[str, Any]) -> float:
        """
        Pick the scalar score from a ``compute_distance_similarity`` result.

        ``weighted_similarity`` only exists when both symbol lists were
        supplied; otherwise fall back to ``overall_similarity``.
        """
        return similarity_result.get('weighted_similarity',
                                     similarity_result['overall_similarity'])

    def _setup_element_groups(self):
        """Group the reference points and their indices by element symbol."""
        self.element_groups = {}
        self.element_indices = {}
        for symbol in self._unique_symbols(self.reference_symbols):
            indices = [i for i, s in enumerate(self.reference_symbols) if s == symbol]
            self.element_groups[symbol] = self.reference_points[indices]
            self.element_indices[symbol] = indices

    def _compute_distance_matrix(self, points: np.ndarray) -> np.ndarray:
        """Return the (N, N) Euclidean distance matrix of ``points``."""
        return squareform(pdist(points))

    def _compute_distance_features(self) -> Dict:
        """
        Compute distance features of the reference structure.

        Returns a dict with
        - ``'all_distances'``: condensed pairwise distances of all points;
        - ``'intra_element_distances'``: symbol -> condensed distances among
          the atoms of that symbol (only for symbols with >= 2 atoms);
        - ``'inter_element_distances'``: ``"A-B"`` -> flattened distances
          between every A atom (outer) and every B atom (inner).
        """
        features = {
            'all_distances': pdist(self.reference_points),
            'intra_element_distances': {},
            'inter_element_distances': {}
        }

        if self.reference_symbols is None:
            return features

        # Distances among atoms of the same element.
        for symbol, indices in self.element_indices.items():
            if len(indices) > 1:
                features['intra_element_distances'][symbol] = pdist(self.reference_points[indices])

        # Distances between atoms of two different elements; row-major ravel
        # keeps the "every symbol1 atom vs every symbol2 atom" ordering.
        for symbol1, symbol2 in itertools.combinations(list(self.element_indices), 2):
            points1 = self.reference_points[self.element_indices[symbol1]]
            points2 = self.reference_points[self.element_indices[symbol2]]
            features['inter_element_distances'][f"{symbol1}-{symbol2}"] = cdist(points1, points2).ravel()

        return features

    def rotate_points_around_axis(self, points: np.ndarray,
                                 axis_point: np.ndarray,
                                 axis_direction: np.ndarray,
                                 angle_rad: float) -> np.ndarray:
        """
        Rotate a point set about an arbitrary axis.

        Parameters
        ----------
        points : (N, 3) point set.
        axis_point : a point (3,) on the rotation axis.
        axis_direction : direction vector (3,) of the axis; need not be unit length.
        angle_rad : rotation angle in radians (right-hand rule about the axis).

        Returns
        -------
        (N, 3) rotated points.

        Raises
        ------
        ValueError
            If ``axis_direction`` is the zero vector (the axis is undefined).
        """
        axis_direction = np.asarray(axis_direction, dtype=float)
        norm = np.linalg.norm(axis_direction)
        if norm == 0:
            raise ValueError("axis_direction must be a non-zero vector")

        rotation = R.from_rotvec(axis_direction / norm * angle_rad)

        # Shift so the axis passes through the origin, rotate, shift back.
        return rotation.apply(points - axis_point) + axis_point

    def compute_distance_similarity(self, points1: np.ndarray,
                                   points2: np.ndarray,
                                   symbols1: Optional[List[str]] = None,
                                   symbols2: Optional[List[str]] = None,
                                   weight_intra: float = 1.0,
                                   weight_inter: float = 1.5) -> Dict[str, Any]:
        """
        Score the similarity of two index-paired point sets.

        Row i of ``points1`` is compared with row i of ``points2``. When both
        symbol lists are given, internal distance patterns are also compared
        per element and per element pair.

        Parameters
        ----------
        points1, points2 : (N, 3) point sets, paired by row.
        symbols1, symbols2 : optional element symbols for each set.
        weight_intra : weight for same-element distances (also scales the
            exponent of the per-element similarity).
        weight_inter : weight for different-element distances (likewise).

        Returns
        -------
        Dict with ``overall_rmse``, ``overall_max_dist``, ``overall_similarity``
        and, if symbols are given, ``distance_matrix_rmse``,
        ``distance_matrix_similarity``, ``intra_element_similarities``,
        ``inter_element_similarities`` and ``weighted_similarity``. Element
        groups whose atom counts differ between the two sets cannot be paired
        and are left out (a warning is issued for the mismatch).

        Raises
        ------
        ValueError
            If a symbol list length differs from its point count.
        """
        results = {}

        # 1. Point-to-point similarity.
        overall_distances = np.linalg.norm(points1 - points2, axis=1)
        results['overall_rmse'] = np.sqrt(np.mean(overall_distances**2))
        results['overall_max_dist'] = np.max(overall_distances)
        results['overall_similarity'] = np.exp(-results['overall_rmse'])

        # 2. Element-resolved distance similarity needs both symbol lists.
        if symbols1 is None or symbols2 is None:
            return results

        if len(symbols1) != len(points1) or len(symbols2) != len(points2):
            raise ValueError("元素符号数量必须与点数量相同")

        if Counter(symbols1) != Counter(symbols2):
            warnings.warn("元素类型分布不匹配，元素间距离相似性可能不准确")

        dist_matrix1 = self._compute_distance_matrix(points1)
        dist_matrix2 = self._compute_distance_matrix(points2)

        # Whole distance-matrix similarity.
        dist_diff = np.abs(dist_matrix1 - dist_matrix2)
        results['distance_matrix_rmse'] = np.sqrt(np.mean(dist_diff**2))
        results['distance_matrix_similarity'] = np.exp(-results['distance_matrix_rmse'])

        symbols_sorted = self._unique_symbols(symbols1)
        indices1 = {s: [i for i, t in enumerate(symbols1) if t == s] for s in symbols_sorted}
        indices2 = {s: [i for i, t in enumerate(symbols2) if t == s] for s in symbols_sorted}
        # Only groups of equal size can be compared entry-by-entry.
        comparable = {s for s in symbols_sorted if len(indices1[s]) == len(indices2[s])}

        # Same-element distances (upper triangle, diagonal excluded).
        intra_element_similarities = {}
        for symbol in symbols_sorted:
            idx1, idx2 = indices1[symbol], indices2[symbol]
            if len(idx1) <= 1 or symbol not in comparable:
                continue
            intra_dist1 = dist_matrix1[np.ix_(idx1, idx1)]
            intra_dist2 = dist_matrix2[np.ix_(idx2, idx2)]
            upper = np.triu(np.ones_like(intra_dist1, dtype=bool), k=1)
            dists1 = intra_dist1[upper]
            dists2 = intra_dist2[upper]
            intra_rmse = np.sqrt(np.mean((dists1 - dists2)**2))
            intra_element_similarities[symbol] = {
                'rmse': intra_rmse,
                'similarity': np.exp(-intra_rmse * weight_intra),
                'num_pairs': len(dists1)
            }
        results['intra_element_similarities'] = intra_element_similarities

        # Different-element distances.
        inter_element_similarities = {}
        for symbol1, symbol2 in itertools.combinations(symbols_sorted, 2):
            if symbol1 not in comparable or symbol2 not in comparable:
                continue
            inter_dist1 = dist_matrix1[np.ix_(indices1[symbol1], indices1[symbol2])]
            inter_dist2 = dist_matrix2[np.ix_(indices2[symbol1], indices2[symbol2])]
            inter_rmse = np.sqrt(np.mean((inter_dist1.ravel() - inter_dist2.ravel())**2))
            inter_element_similarities[f"{symbol1}-{symbol2}"] = {
                'rmse': inter_rmse,
                'similarity': np.exp(-inter_rmse * weight_inter),
                'num_pairs': inter_dist1.size
            }
        results['inter_element_similarities'] = inter_element_similarities

        # Pair-count-weighted average of all group similarities.
        terms = ([(data, weight_intra) for data in intra_element_similarities.values()]
                 + [(data, weight_inter) for data in inter_element_similarities.values()])
        total_weight = 0
        weighted_sum = 0
        for data, weight_factor in terms:
            weight = data['num_pairs'] * weight_factor
            weighted_sum += data['similarity'] * weight
            total_weight += weight

        if total_weight > 0:
            results['weighted_similarity'] = weighted_sum / total_weight
        else:
            results['weighted_similarity'] = results['overall_similarity']

        return results

    def find_best_match_with_elements(self, input_points: np.ndarray,
                                     input_symbols: List[str]) -> Tuple[np.ndarray, List[str], Dict]:
        """
        Element-aware point-to-point correspondence with the reference.

        For each element, the reference atoms and input atoms of that element
        are paired by a minimum-total-distance (Hungarian) assignment.

        Parameters
        ----------
        input_points : (N, 3) input point set.
        input_symbols : element symbols of the input points.

        Returns
        -------
        matched_points : input points re-ordered so row i pairs with reference row i.
        matched_symbols : input symbols re-ordered the same way.
        match_info : per-element details plus ``total_distance``,
            ``avg_distance`` and ``similarity = exp(-avg_distance)``.

        Raises
        ------
        ValueError
            If the reference has no symbols, if the input symbol count differs
            from the point count, or if the element distributions differ.
        """
        if self.reference_symbols is None:
            raise ValueError("参考点集必须有元素符号")

        if len(input_symbols) != len(input_points):
            raise ValueError("输入元素符号数量必须与点数量相同")

        ref_counter = Counter(self.reference_symbols)
        input_counter = Counter(input_symbols)
        if ref_counter != input_counter:
            raise ValueError(f"元素类型分布不匹配: 参考={dict(ref_counter)}, 输入={dict(input_counter)}")

        input_points = np.asarray(input_points)
        all_matched_indices = np.zeros(len(input_points), dtype=int)
        match_details = {}

        for symbol in ref_counter.keys():
            ref_indices = [i for i, s in enumerate(self.reference_symbols) if s == symbol]
            input_indices = [i for i, s in enumerate(input_symbols) if s == symbol]

            dist_matrix = cdist(self.reference_points[ref_indices], input_points[input_indices])
            row_ind, col_ind = linear_sum_assignment(dist_matrix)
            matched_input = [input_indices[c] for c in col_ind]

            match_details[symbol] = {
                'ref_indices': ref_indices,
                'input_indices': matched_input,
                'distances': dist_matrix[row_ind, col_ind],
                'avg_distance': dist_matrix[row_ind, col_ind].mean()
            }

            # row_ind indexes ref_indices; map back to global reference rows.
            for r, input_idx in zip(row_ind, matched_input):
                all_matched_indices[ref_indices[r]] = input_idx

        matched_points = input_points[all_matched_indices]
        matched_symbols = [input_symbols[i] for i in all_matched_indices]

        total_distance = sum(detail['distances'].sum() for detail in match_details.values())
        avg_distance = total_distance / len(input_points)
        match_info = {
            'matched_indices': all_matched_indices,
            'match_details': match_details,
            'total_distance': total_distance,
            'avg_distance': avg_distance,
            'similarity': np.exp(-avg_distance)
        }

        return matched_points, matched_symbols, match_info

    def align_whole_structure(self, input_points: np.ndarray,
                             input_symbols: List[str],
                             use_elements: bool = True) -> Tuple[np.ndarray, List[str], Dict]:
        """
        Pair the input with the reference, then Kabsch-align it.

        Parameters
        ----------
        input_points : (N, 3) input point set.
        input_symbols : input element symbols (may be None when
            ``use_elements`` is False).
        use_elements : use element-aware matching when the reference has symbols.

        Returns
        -------
        aligned_points : matched input points after optimal rotation/translation,
            in reference order.
        aligned_symbols : symbols re-ordered in the same way as the points
            (``'X'`` placeholders if no symbols were given).
        alignment_info : matching info plus ``rotation_matrix`` and ``aligned_points``.
        """
        if use_elements and self.reference_symbols is not None:
            matched_points, matched_symbols, match_info = \
                self.find_best_match_with_elements(input_points, input_symbols)
        else:
            # Element-blind matching over all points.
            input_points = np.asarray(input_points)
            distance_matrix = cdist(self.reference_points, input_points)
            row_ind, col_ind = linear_sum_assignment(distance_matrix)
            matched_points = input_points[col_ind]
            if input_symbols is not None:
                # Re-order the symbols exactly like the points so row i keeps
                # its own element symbol.
                matched_symbols = [input_symbols[i] for i in col_ind]
            else:
                matched_symbols = ['X'] * len(col_ind)
            match_info = {
                'matched_indices': col_ind,
                'avg_distance': distance_matrix[row_ind, col_ind].mean()
            }

        aligned_points, rotation_matrix = self._kabsch_alignment(
            matched_points, self.reference_points
        )

        alignment_info = {
            **match_info,
            'rotation_matrix': rotation_matrix,
            'aligned_points': aligned_points
        }

        return aligned_points, matched_symbols, alignment_info

    def _kabsch_alignment(self, points1: np.ndarray, points2: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """
        Kabsch least-squares superposition of ``points1`` onto ``points2``.

        Both sets must be paired row by row. Returns the transformed
        ``points1`` and the proper rotation matrix (det = +1) applied to the
        centred ``points1``.
        """
        centroid1 = np.mean(points1, axis=0)
        centroid2 = np.mean(points2, axis=0)

        centered1 = points1 - centroid1
        centered2 = points2 - centroid2

        # Cross-covariance matrix.
        H = centered1.T @ centered2

        U, S, Vt = np.linalg.svd(H)

        # d = -1 flips the last axis to avoid returning a reflection.
        d = np.sign(np.linalg.det(Vt.T @ U.T))
        rotation_matrix = Vt.T @ np.diag([1, 1, d]) @ U.T

        aligned_points = (rotation_matrix @ (points1 - centroid1).T).T + centroid2

        return aligned_points, rotation_matrix

    def optimize_rotation_angle(self, input_points: np.ndarray,
                               input_symbols: List[str],
                               axis_point: np.ndarray,
                               axis_direction: np.ndarray,
                               use_elements: bool = True) -> Dict:
        """
        Search the rotation angle about a fixed axis that maximises similarity.

        For each trial angle the input is rotated, matched/aligned with
        ``align_whole_structure`` and scored with ``compute_distance_similarity``
        using the re-ordered symbols. L-BFGS-B is started from 0 within
        [-pi, pi].

        Parameters
        ----------
        input_points : (N, 3) input point set.
        input_symbols : input element symbols.
        axis_point : a point on the rotation axis.
        axis_direction : rotation axis direction.
        use_elements : use element-aware matching.

        Returns
        -------
        Dict with ``best_angle_rad``, ``best_angle_deg``, ``best_similarity``,
        ``aligned_points``, ``matched_symbols``, ``similarity_result``,
        ``alignment_info`` and ``optimization_success``.
        """
        def objective(angle_rad):
            rotated_points = self.rotate_points_around_axis(
                input_points, axis_point, axis_direction, angle_rad
            )
            aligned_points, matched_symbols, _ = self.align_whole_structure(
                rotated_points, input_symbols, use_elements
            )
            # aligned_points is in reference order, so it must be scored with
            # the re-ordered symbols rather than the raw input order.
            similarity_result = self.compute_distance_similarity(
                self.reference_points, aligned_points,
                self.reference_symbols, matched_symbols
            )
            # Negative because we maximise similarity with a minimiser.
            return -self._similarity_score(similarity_result)

        result = minimize(
            objective,
            x0=0.0,
            bounds=[(-np.pi, np.pi)],
            method='L-BFGS-B'
        )

        best_angle = result.x[0]
        best_similarity = -result.fun

        # Recompute full details at the best angle.
        rotated_points = self.rotate_points_around_axis(
            input_points, axis_point, axis_direction, best_angle
        )
        aligned_points, matched_symbols, alignment_info = self.align_whole_structure(
            rotated_points, input_symbols, use_elements
        )
        similarity_result = self.compute_distance_similarity(
            self.reference_points, aligned_points,
            self.reference_symbols, matched_symbols
        )

        return {
            'best_angle_rad': best_angle,
            'best_angle_deg': np.degrees(best_angle),
            'best_similarity': best_similarity,
            'aligned_points': aligned_points,
            'matched_symbols': matched_symbols,
            'similarity_result': similarity_result,
            'alignment_info': alignment_info,
            'optimization_success': result.success
        }


class AdvancedASEMatcher:
    """
    Structure matcher for ASE ``Atoms`` objects.

    Wraps ``EnhancedPointSetMatcher`` around the reference structure's
    Cartesian positions and chemical symbols, and adds simple structural
    metrics: per-element nearest-neighbour distances, a two-nearest-neighbour
    bond angle, and a convex-hull volume comparison.
    """

    def __init__(self, reference_structure: Atoms):
        """
        Store the reference structure and precompute its features.

        Parameters
        ----------
        reference_structure : ASE ``Atoms`` object used as the reference.
        """
        self.reference_structure = reference_structure
        self.reference_positions = reference_structure.get_positions()
        self.reference_symbols = reference_structure.get_chemical_symbols()

        # Point-set matcher over the reference coordinates and symbols.
        self.matcher = EnhancedPointSetMatcher(
            self.reference_positions, self.reference_symbols
        )

        self.structure_features = self._extract_structure_features()

    @staticmethod
    def _cell_volume(structure: Atoms) -> Optional[float]:
        """
        Unit-cell volume, or None when it is not defined.

        ``Atoms.get_volume`` raises ``ValueError`` when the cell has fewer than
        three lattice vectors (e.g. an isolated molecule without a cell).
        """
        if not hasattr(structure, 'get_volume'):
            return None
        try:
            return structure.get_volume()
        except ValueError:
            return None

    @staticmethod
    def _hull_volume(points: np.ndarray) -> Optional[float]:
        """
        Convex-hull volume of ``points``, or None for degenerate input.

        Qhull raises ``QhullError`` (a ``RuntimeError`` subclass) for fewer
        than four or coplanar points, and ``ValueError`` for malformed input.
        """
        from scipy.spatial import ConvexHull
        try:
            return ConvexHull(points).volume
        except (RuntimeError, ValueError):
            return None

    def _extract_structure_features(self) -> Dict:
        """Collect composition, geometry and cell descriptors of the reference."""
        structure = self.reference_structure
        features = {
            'elements': Counter(self.reference_symbols),
            'center_of_mass': structure.get_center_of_mass(),
            'cell': structure.get_cell() if hasattr(structure, 'get_cell') else None,
            'volume': self._cell_volume(structure),
            'formula': structure.get_chemical_formula(),
            # Hull volume of the reference atoms: the like-for-like baseline
            # for the volume metric of an aligned input structure.
            'hull_volume': self._hull_volume(self.reference_positions),
        }

        # Simplified coordination descriptor.
        features['nearest_neighbor_distances'] = self._compute_nearest_neighbor_distances()

        return features

    def _compute_nearest_neighbor_distances(self) -> Dict:
        """
        Mean nearest-neighbour distance per element.

        Uses plain Cartesian positions (no periodic images).
        """
        from scipy.spatial import KDTree

        tree = KDTree(self.reference_positions)
        distances, _ = tree.query(self.reference_positions, k=2)  # column 0 is the atom itself

        nearest_distances = distances[:, 1]

        element_distances = defaultdict(list)
        for symbol, dist in zip(self.reference_symbols, nearest_distances):
            element_distances[symbol].append(dist)

        return {k: np.mean(v) for k, v in element_distances.items()}

    def match_with_rotation(self, input_structure: Atoms,
                          axis_point: Optional[np.ndarray] = None,
                          axis_direction: Optional[np.ndarray] = None,
                          angle_rad: Optional[float] = None,
                          optimize_angle: bool = True) -> Dict:
        """
        Match an input structure against the reference, optionally rotating it.

        Parameters
        ----------
        input_structure : input ASE structure.
        axis_point : point on the rotation axis; defaults to the reference
            centre of mass.
        axis_direction : rotation axis direction; defaults to the x axis [1, 0, 0].
        angle_rad : rotation angle used when ``optimize_angle`` is False
            (defaults to 0.0); ignored when ``optimize_angle`` is True.
        optimize_angle : search the angle that maximises similarity.

        Returns
        -------
        Match result dict (angle, similarity, aligned points, matched symbols,
        similarity details, alignment info) plus formulas, the axis used and
        ``structure_metrics``.
        """
        input_positions = input_structure.get_positions()
        input_symbols = input_structure.get_chemical_symbols()

        if axis_point is None:
            axis_point = self.structure_features['center_of_mass']

        if axis_direction is None:
            axis_direction = np.array([1, 0, 0])  # default: rotate about x

        if angle_rad is None and not optimize_angle:
            angle_rad = 0.0

        # Warn early on composition mismatch.
        self._validate_elements(input_symbols)

        if optimize_angle:
            result = self.matcher.optimize_rotation_angle(
                input_positions, input_symbols, axis_point, axis_direction, use_elements=True
            )
        else:
            rotated_points = self.matcher.rotate_points_around_axis(
                input_positions, axis_point, axis_direction, angle_rad
            )

            aligned_points, matched_symbols, alignment_info = self.matcher.align_whole_structure(
                rotated_points, input_symbols, use_elements=True
            )

            similarity_result = self.matcher.compute_distance_similarity(
                self.reference_positions, aligned_points,
                self.reference_symbols, matched_symbols
            )

            result = {
                'best_angle_rad': angle_rad,
                'best_angle_deg': np.degrees(angle_rad),
                'best_similarity': similarity_result['weighted_similarity'],
                'aligned_points': aligned_points,
                'matched_symbols': matched_symbols,
                'similarity_result': similarity_result,
                'alignment_info': alignment_info,
                'optimization_success': True
            }

        result['reference_formula'] = self.structure_features['formula']
        result['input_formula'] = input_structure.get_chemical_formula()
        result['axis_point'] = axis_point
        result['axis_direction'] = axis_direction

        result['structure_metrics'] = self._compute_structure_metrics(
            result['aligned_points'], result['matched_symbols']
        )

        return result

    def _validate_elements(self, input_symbols: List[str]):
        """Warn (do not raise) when the element counts differ from the reference."""
        ref_counter = Counter(self.reference_symbols)
        input_counter = Counter(input_symbols)

        if ref_counter != input_counter:
            warnings.warn(f"元素类型分布不匹配: 参考={dict(ref_counter)}, 输入={dict(input_counter)}")

    def _compute_structure_metrics(self, aligned_points: np.ndarray,
                                 matched_symbols: List[str]) -> Dict:
        """
        Compute simple structural metrics of an aligned structure.

        ``volume_ratio`` compares the convex-hull volume of the aligned atoms
        with the convex-hull volume of the reference atoms, so an exact match
        gives 1.0. Both volume entries are None when either hull is degenerate.
        """
        metrics = {}

        # 1. Bond-length (nearest-neighbour distance) distribution.
        metrics['bond_length_distribution'] = self._analyze_bond_lengths(
            aligned_points, matched_symbols
        )

        # 2. Simplified bond-angle analysis.
        metrics['bond_angle_analysis'] = self._analyze_bond_angles(
            aligned_points, matched_symbols
        )

        # 3. Volume match: hull of aligned atoms vs. hull of reference atoms.
        input_volume = self._hull_volume(aligned_points)
        ref_volume = self.structure_features.get('hull_volume')
        if input_volume is not None and ref_volume:
            metrics['volume_ratio'] = input_volume / ref_volume
            metrics['volume_similarity'] = np.exp(-abs(metrics['volume_ratio'] - 1))
        else:
            metrics['volume_ratio'] = None
            metrics['volume_similarity'] = None

        return metrics

    def _analyze_bond_lengths(self, points: np.ndarray, symbols: List[str]) -> Dict:
        """Nearest-neighbour distance statistics, overall and per element."""
        from scipy.spatial import KDTree

        tree = KDTree(points)
        distances, indices = tree.query(points, k=2)  # column 0 is the atom itself

        results = {
            'all_bond_lengths': distances[:, 1],
            'element_specific': {}
        }

        for symbol in set(symbols):
            symbol_indices = [i for i, s in enumerate(symbols) if s == symbol]
            symbol_distances = distances[symbol_indices, 1]
            results['element_specific'][symbol] = {
                'mean': np.mean(symbol_distances),
                'std': np.std(symbol_distances),
                'min': np.min(symbol_distances),
                'max': np.max(symbol_distances)
            }

        return results

    def _analyze_bond_angles(self, points: np.ndarray, symbols: List[str]) -> Dict:
        """
        Simplified bond-angle analysis.

        For every atom, the angle (degrees) between the vectors to its two
        nearest neighbours. Needs at least three atoms; atoms whose neighbour
        coincides with them are skipped.
        """
        from scipy.spatial import KDTree

        points = np.asarray(points)
        n_points = len(points)
        angles = []

        if n_points >= 3:
            tree = KDTree(points)
            # Self + up to three neighbours; never ask for more than exist.
            _, indices = tree.query(points, k=min(4, n_points))

            for i in range(n_points):
                vec1 = points[indices[i, 1]] - points[i]
                vec2 = points[indices[i, 2]] - points[i]

                norm_product = np.linalg.norm(vec1) * np.linalg.norm(vec2)
                if norm_product == 0:
                    continue
                cos_angle = np.clip(np.dot(vec1, vec2) / norm_product, -1.0, 1.0)
                angles.append(np.degrees(np.arccos(cos_angle)))

        return {
            'mean_angle': np.mean(angles) if angles else None,
            'std_angle': np.std(angles) if angles else None,
            'angles': angles
        }

    def compare_multiple_rotations(self, input_structure: Atoms,
                                 n_angles: int = 36,
                                 axis_point: Optional[np.ndarray] = None,
                                 axis_direction: Optional[np.ndarray] = None) -> Dict:
        """
        Evaluate ``n_angles`` evenly spaced rotations in [0, 2*pi).

        Parameters
        ----------
        input_structure : input ASE structure.
        n_angles : number of sampled angles.
        axis_point : point on the rotation axis (default: reference centre of mass).
        axis_direction : rotation axis direction (default: x axis).

        Returns
        -------
        Dict with every sampled result, the best one, its angle in degrees
        and its weighted similarity.
        """
        if axis_point is None:
            axis_point = self.structure_features['center_of_mass']

        if axis_direction is None:
            axis_direction = np.array([1, 0, 0])

        input_positions = input_structure.get_positions()
        input_symbols = input_structure.get_chemical_symbols()

        angles_rad = np.linspace(0, 2*np.pi, n_angles, endpoint=False)
        results = []

        for angle_rad in angles_rad:
            rotated_points = self.matcher.rotate_points_around_axis(
                input_positions, axis_point, axis_direction, angle_rad
            )

            aligned_points, matched_symbols, _ = self.matcher.align_whole_structure(
                rotated_points, input_symbols, use_elements=True
            )

            similarity_result = self.matcher.compute_distance_similarity(
                self.reference_positions, aligned_points,
                self.reference_symbols, matched_symbols
            )

            results.append({
                'angle_rad': angle_rad,
                'angle_deg': np.degrees(angle_rad),
                'similarity': similarity_result['weighted_similarity'],
                'overall_similarity': similarity_result['overall_similarity'],
                'rmse': similarity_result['overall_rmse']
            })

        best_result = max(results, key=lambda x: x['similarity'])

        return {
            'all_results': results,
            'best_result': best_result,
            'best_angle_deg': best_result['angle_deg'],
            'best_similarity': best_result['similarity']
        }


def create_test_crystal_structures(lattice_constant: float = 5.64,
                                   angle: float = np.pi / 3,
                                   noise_std: float = 0.05,
                                   seed: int = 42) -> Tuple[Atoms, Atoms]:
    """
    Build a (reference, input) pair of rock-salt NaCl test structures.

    The reference is the 8-atom conventional NaCl cell: atoms on the 0 / 0.5
    fractional grid, Na where 2*(x + y + z) is even and Cl where it is odd,
    so the two species alternate along every cube edge.

    The input is the reference rotated by ``angle`` about the [111] direction
    through (a/2, a/2, a/2), perturbed with Gaussian noise of standard
    deviation ``noise_std`` (a crude stand-in for thermal motion), and with
    its atom order shuffled.

    Parameters
    ----------
    lattice_constant : cubic lattice constant in Å (default 5.64).
    angle : rotation of the input structure in radians (default pi/3).
    noise_std : standard deviation of the positional noise in Å (default 0.05).
    seed : seed of a private random generator (default 42); the global NumPy
        random state is not modified.

    Returns
    -------
    (ref_structure, input_structure) : two ``Atoms`` with a cubic cell of
        edge ``lattice_constant`` and ``pbc=True``.
    """
    fractional = np.array([
        [0.0, 0.0, 0.0],    # Na
        [0.5, 0.5, 0.0],    # Na
        [0.5, 0.0, 0.5],    # Na
        [0.0, 0.5, 0.5],    # Na
        [0.0, 0.0, 0.5],    # Cl
        [0.5, 0.5, 0.5],    # Cl
        [0.5, 0.0, 0.0],    # Cl
        [0.0, 0.5, 0.0],    # Cl
    ])
    positions_ref = fractional * lattice_constant
    symbols_ref = ['Na', 'Na', 'Na', 'Na', 'Cl', 'Cl', 'Cl', 'Cl']

    # Rotate about the body diagonal through the cube-centre point.
    axis_point = np.full(3, lattice_constant / 2)
    axis_direction = np.array([1.0, 1.0, 1.0])
    rot = R.from_rotvec(axis_direction / np.linalg.norm(axis_direction) * angle)
    positions_input = rot.apply(positions_ref - axis_point) + axis_point

    # A private RandomState seeded like np.random.seed(seed) yields the same
    # noise and permutation without touching the global RNG state.
    rng = np.random.RandomState(seed)
    positions_input += rng.normal(0, noise_std, positions_input.shape)

    # Shuffle atom order so matching cannot rely on indices.
    order = np.arange(len(positions_input))
    rng.shuffle(order)
    positions_input = positions_input[order]
    symbols_input = [symbols_ref[i] for i in order]

    cell = np.eye(3) * lattice_constant

    ref_structure = Atoms(symbols=symbols_ref, positions=positions_ref, cell=cell, pbc=True)
    input_structure = Atoms(symbols=symbols_input, positions=positions_input, cell=cell, pbc=True)

    return ref_structure, input_structure


def _print_structure_summary(title, structure: "Atoms") -> None:
    """Print the formula, atom count and per-element counts of ``structure`` under ``title``."""
    print(f"\n{title}:")
    print(f"化学式: {structure.get_chemical_formula()}")
    print(f"原子数: {len(structure)}")
    print(f"元素分布: {dict(Counter(structure.get_chemical_symbols()))}")


def _format_optional_float(value, fmt: str = ".6f", missing: str = "N/A") -> str:
    """Format ``value`` with ``fmt``, or return ``missing`` when ``value`` is None."""
    return missing if value is None else format(value, fmt)


def main():
    """
    Run a crystal-structure matching demo on a synthetic NaCl test case.

    Steps:
    1. Build a reference/input pair with ``create_test_crystal_structures``.
    2. Let ``AdvancedASEMatcher`` optimise the rotation angle about the body
       diagonal through the centre of the reference cell.
    3. Scan 12 rotation angles with ``compare_multiple_rotations``.
    4. Print the structure metrics, a detailed similarity report and a
       pass/fail verdict against a 0.8 weighted-similarity threshold.
    """
    print("=" * 60)
    print("晶体结构匹配示例")
    print("=" * 60)

    # Build the test structures.
    ref_structure, input_structure = create_test_crystal_structures()

    _print_structure_summary("参考结构", ref_structure)
    _print_structure_summary("输入结构", input_structure)
    print("（已旋转60度并添加随机扰动）")

    matcher = AdvancedASEMatcher(ref_structure)

    # Rotation axis: body diagonal through the centre of the reference cell.
    # The centre is derived from the cell itself instead of re-hard-coding
    # the lattice constant, so it stays in sync with the structure builder.
    axis_point = np.asarray(ref_structure.get_cell()).sum(axis=0) / 2
    axis_direction = np.array([1, 1, 1])

    print("\n旋转参数:")
    print(f"旋转轴点: {axis_point}")
    print(f"旋转轴方向: {axis_direction}")

    # Method 1: let the matcher optimise the rotation angle.
    print("\n" + "=" * 60)
    print("方法1：自动优化旋转角度")
    print("=" * 60)

    result = matcher.match_with_rotation(
        input_structure,
        axis_point=axis_point,
        axis_direction=axis_direction,
        optimize_angle=True,
    )
    print(f"最佳旋转角度: {result['best_angle_deg']:.2f} 度")
    print(f"加权相似性: {result['best_similarity']:.6f}")
    print(f"整体RMSE: {result['similarity_result']['overall_rmse']:.6f} Å")

    # Similarity of inter-element distances.
    print("\n元素间距离相似性:")
    for key, data in result['similarity_result']['inter_element_similarities'].items():
        print(f"  {key}: {data['similarity']:.6f} (RMSE={data['rmse']:.6f} Å)")

    # Method 2: compare a fixed set of rotation angles.
    print("\n" + "=" * 60)
    print("方法2：比较多个旋转角度")
    print("=" * 60)

    multi_result = matcher.compare_multiple_rotations(
        input_structure,
        n_angles=12,
        axis_point=axis_point,
        axis_direction=axis_direction,
    )

    print(f"最佳角度: {multi_result['best_angle_deg']:.2f} 度")
    print(f"最佳相似性: {multi_result['best_similarity']:.6f}")

    # Top three angles by similarity.
    print("\n前3个最佳旋转角度:")
    sorted_results = sorted(multi_result['all_results'], key=lambda x: x['similarity'], reverse=True)
    for rank, res in enumerate(sorted_results[:3], start=1):
        print(f"  {rank}. {res['angle_deg']:6.1f}° - 相似性: {res['similarity']:.6f}")

    # Structure metrics.
    print("\n" + "=" * 60)
    print("结构指标分析")
    print("=" * 60)

    metrics = result['structure_metrics']

    print("\n键长分析:")
    for element, data in metrics['bond_length_distribution']['element_specific'].items():
        print(f"  {element}: 平均键长={data['mean']:.3f} Å, 标准差={data['std']:.3f} Å")

    if metrics['volume_ratio'] is not None:
        print(f"\n体积匹配: 输入/参考 = {metrics['volume_ratio']:.4f}")
        print(f"体积相似性: {metrics['volume_similarity']:.6f}")

    if metrics['bond_angle_analysis']['mean_angle'] is not None:
        print(f"\n键角分析: 平均角度={metrics['bond_angle_analysis']['mean_angle']:.1f}°, "
              f"标准差={metrics['bond_angle_analysis']['std_angle']:.1f}°")

    # Detailed similarity report.
    print("\n" + "=" * 60)
    print("详细相似性报告")
    print("=" * 60)

    sim_result = result['similarity_result']
    print(f"整体相似性: {sim_result['overall_similarity']:.6f}")
    # The key may be absent. Formatting the 'N/A' fallback with ':.6f' would
    # raise ValueError, so only apply the float format to real values.
    dm_text = _format_optional_float(sim_result.get('distance_matrix_similarity'))
    print(f"距离矩阵相似性: {dm_text}")
    print(f"加权相似性: {sim_result['weighted_similarity']:.6f}")

    # Match quality.
    print("\n匹配质量评估:")
    print(f"  平均匹配距离: {result['alignment_info']['avg_distance']:.6f} Å")
    print(f"  最大原子位移: {sim_result['overall_max_dist']:.6f} Å")

    # Success verdict against a fixed weighted-similarity threshold.
    threshold = 0.8
    if sim_result['weighted_similarity'] > threshold:
        print(f"\n✅ 匹配成功！相似性高于阈值 ({threshold})")
    else:
        print(f"\n⚠️  匹配质量一般，相似性低于阈值 ({threshold})")


if __name__ == "__main__":
    main()