In [1]:
from ase import Atoms
from ase.calculators.calculator import Calculator, all_changes
from ase.neighborlist import neighbor_list
from mace.calculators import MACECalculator
from xtb.ase.calculator import XTB
import torch
import numpy as np
from ase.io import read
from mp_api.client import MPRester

  _Jd, _W3j_flat, _W3j_indices = torch.load(os.path.join(os.path.dirname(__file__), 'constants.pt'))


cuequivariance or cuequivariance_torch is not available. Cuequivariance acceleration will be disabled.


In [19]:
API = "6UhsVQseZ2tkPsqxcMyWnvh82v1fmthq"
with MPRester(API) as mpr:
    docs = mpr.materials.search(formula='Li6PS5Cl')
    for doc in docs:
        structure = doc.structure
        cif_path = "Li6PS5Cl.cif"
        structure.to(filename=str(cif_path), fmt='cif')

Retrieving MaterialsDoc documents:   0%|          | 0/1 [00:00<?, ?it/s]

In [2]:
def switching_function(r, R_inner=3.0, R_outer=5.0):
    """
    平滑切换函数：r < R_inner → 1.0 (xTB), r > R_outer → 0.0 (MLP)
    中间区域平滑插值
    """
    if r <= R_inner:
        return 1.0
    elif r >= R_outer:
        return 0.0
    else:
        x = (r - R_inner) / (R_outer - R_inner)
        # 使用cosine平滑 or cubic spline
        return 0.5 * (1 + np.cos(np.pi * x))
def detect_reactive_atoms(atoms, cutoff_dict=None):
    if cutoff_dict is None:
        cutoff_dict = {
            ('Li', 'S'): 2.0,   # 正常 ~2.5Å，>3.0 可能断键
            ('P', 'S'): 2.5,    # 正常 ~2.1Å
            ('S', 'Cl'): 3.2,   # 若出现异常接近，可能反应
        }
    
    reactive_mask = np.zeros(len(atoms), dtype=bool)
    i_list, j_list, d_list = neighbor_list('ijd', atoms, cutoff=4.0)
    
    for i, j, d in zip(i_list, j_list, d_list):
        sym_i, sym_j = atoms[i].symbol, atoms[j].symbol
        key = tuple(sorted([sym_i, sym_j]))
        if key in cutoff_dict and d > cutoff_dict[key] * 1.2:  # 超过20%阈值
            reactive_mask[i] = reactive_mask[j] = True
    return reactive_mask
class SmoothHybridCalculator(Calculator):
    implemented_properties = ['energy', 'forces', 'stress']
    def __init__(self, mlp_calc, xtb_calc, 
                 R_inner=3.0, R_outer=5.0,
                 reactive_detector=None,
                 device='cpu'):
        super().__init__()
        self.mlp_calc = mlp_calc
        self.xtb_calc = xtb_calc
        self.R_inner = R_inner
        self.R_outer = R_outer
        self.reactive_detector = reactive_detector or detect_reactive_atoms
        self.device = device
    
    def calculate(self, atoms=None, properties=['energy'], system_changes=all_changes):
        super().calculate(atoms, properties, system_changes)

        # Step 1: 用 MLP 计算全局能量和力
        atoms.set_calculator(self.mlp_calc)
        E_mlp = atoms.get_potential_energy()
        F_mlp = atoms.get_forces()
        if 'stress' in properties:
            S_mlp = atoms.get_stress(voigt=False)

        # Step 2: 检测反应原子
        reactive_mask = self.reactive_detector(atoms)  # bool array (N,)
        reactive_indices = np.where(reactive_mask)[0]

        if len(reactive_indices) == 0:
            self.results = {'energy': E_mlp, 'forces': F_mlp}
            if 'stress' in properties:
                self.results['stress'] = S_mlp
            return

        # Step 3: 构建 xTB 计算的局部超胞（反应原子 + 缓冲区）
        # 获取反应原子邻域（缓冲区原子）
        i_list, j_list, d_list = neighbor_list('ijd', atoms, cutoff=self.R_outer + 1.0)
        buffer_mask = np.zeros(len(atoms), dtype=bool)
        for i in reactive_indices:
            neighbors = j_list[i_list == i]
            buffer_mask[neighbors] = True
        buffer_mask[reactive_indices] = True  # 包含反应原子自身
        cluster_indices = np.where(buffer_mask)[0]

        # 构建局部结构（带最小周期镜像）
        cluster_atoms = atoms[cluster_indices]
        # 注意：如需周期性，建议用 atoms.get_cell().create_extended_atoms(...) 或 ASE 的 cut()

        # Step 4: 用 xTB 计算局部簇
        cluster_atoms.set_calculator(self.xtb_calc)
        try:
            E_xtb_cluster = cluster_atoms.get_potential_energy()
            F_xtb_cluster = cluster_atoms.get_forces()  # shape (M, 3)
        except Exception as e:
            print("xTB failed, falling back to MLP:", e)
            self.results = {'energy': E_mlp, 'forces': F_mlp}
            return

        # Step 5: 平滑混合力和能量
        F_final = F_mlp.copy()
        E_delta = 0.0

        # 对每个缓冲区原子，计算其到最近反应原子的距离，应用切换权重
        positions = atoms.get_positions()
        for idx in cluster_indices:
            # 找到该原子到所有反应原子的最小距离
            dists = np.linalg.norm(positions[idx] - positions[reactive_indices], axis=1)
            r_min = np.min(dists) if len(dists) > 0 else 0.0
            w = switching_function(r_min, self.R_inner, self.R_outer)

            if w > 1e-6:  # 避免数值误差
                local_idx = np.where(cluster_indices == idx)[0][0]  # 在cluster中的索引
                F_interp = w * F_xtb_cluster[local_idx] + (1 - w) * F_mlp[idx]
                F_final[idx] = F_interp
                # 能量校正：按权重分配局部能量差（简化处理）
                E_delta += w * (F_xtb_cluster[local_idx] - F_mlp[idx]) @ (positions[idx] - positions.mean(axis=0)) * 0.1

        # 能量：MLP全局 + 局部校正（简化模型，也可用簇能量差加权分配）
        E_final = E_mlp + E_delta

        self.results = {'energy': E_final, 'forces': F_final}
        if 'stress' in properties:
            # 应力暂不混合（或按体积比例混合），通常MLP应力足够
            self.results['stress'] = S_mlp

In [3]:
mlp_calc = MACECalculator(
    model_paths= "/home/netszx/models/2024-01-07-mace-128-L2_epoch-199.model",
    device='cuda' if torch.cuda.is_available() else 'cpu',
    default_dtype='float64'
)
xtb_calc = XTB(method='GFN2-xTB', accuracy='normal')
hybrid_calc = SmoothHybridCalculator(
    mlp_calc=mlp_calc,
    xtb_calc=xtb_calc,
    R_inner=3.0,    # 3Å 内完全用 xTB
    R_outer=5.0,    # 5Å 外完全用 MLP
    reactive_detector=detect_reactive_atoms
)

atoms = read("Li6PS5Cl.cif")

atoms.set_calculator(hybrid_calc)

from ase.optimize import BFGS
dyn = BFGS(atoms)
dyn.run(fmax=0.05)

  torch.load(f=model_path, map_location=device)
  atoms.set_calculator(hybrid_calc)
  atoms.set_calculator(self.mlp_calc)


Using head Default out of ['Default']
xTB failed, falling back to MLP: must be real number, not str
      Step     Time          Energy          fmax
BFGS:    0 13:58:47      -53.574308        0.027425


  cluster_atoms.set_calculator(self.xtb_calc)


np.True_