In [None]:
import torch
import ase
import mace
from xtb.ase.calculator import XTB
import numpy as np
import os
from ase import Atoms
from ase.io import read, write
from ase.neighborlist import neighbor_list
from ase.md.langevin import Langevin
from ase import units
from mace.calculators import MACECalculator
import pickle

In [None]:
# Starting structure for the workflow, read from an XYZ file in the working directory.
STRUCTURE_FILE = 'Li_Li6PS5Cl_interface.xyz'
atoms = read(STRUCTURE_FILE)

In [None]:
def detect_reaction_event(atoms, verbose=False):
    """Report whether the structure shows any of three reaction signatures.

    The structure is flagged as reactive when at least one of these holds:

    * three or more distinct Li-S pairs are closer than 2.4 Å,
    * any P-S pair found within the 4.0 Å neighbour search is farther apart
      than 2.7 Å,
    * the first Cl atom is more than 4.0 Å from every P and S atom.

    Parameters
    ----------
    atoms : ase.Atoms
        Structure to inspect.
    verbose : bool, optional
        When True and the structure is reactive, print the three indicators.

    Returns
    -------
    bool
        True if a reaction signature was found, False otherwise.
    """
    symbols = atoms.get_chemical_symbols()
    positions = atoms.get_positions()

    # Pairs within 4.0 Å; the neighbour list reports each pair in both
    # directions, as (i, j) and (j, i).
    i_list, j_list, d_list = neighbor_list('ijd', atoms, cutoff=4.0)

    li_s_bond_count = 0
    p_s_broken = False
    for i, j, d in zip(i_list, j_list, d_list):
        # Keep one direction only so that every bond is counted once.
        if i >= j:
            continue
        pair = {symbols[i], symbols[j]}
        if pair == {'Li', 'S'}:
            if d < 2.4:
                li_s_bond_count += 1  # Li-S bond
        elif pair == {'P', 'S'}:
            if d > 2.7:
                p_s_broken = True  # P-S bond broken

    # Cl check. NOTE(review): only the first Cl atom is examined and the
    # distances are plain Cartesian (no minimum-image convention) -- confirm
    # that this is intended for a periodic cell.
    cl_moved = False
    cl_indices = [k for k, s in enumerate(symbols) if s == 'Cl']
    ps_indices = [k for k, s in enumerate(symbols) if s in ('P', 'S')]
    # Without any P/S atom there is nothing to measure against, so the check
    # is skipped instead of calling min() on an empty sequence.
    if cl_indices and ps_indices:
        cl_pos = positions[cl_indices[0]]
        min_dist_to_ps = np.linalg.norm(positions[ps_indices] - cl_pos, axis=1).min()
        cl_moved = bool(min_dist_to_ps > 4.0)

    is_reactive = (li_s_bond_count >= 3) or p_s_broken or cl_moved
    if verbose and is_reactive:
        print(f"Li-S bonds: {li_s_bond_count}, P-S broken: {p_s_broken}, Cl moved: {cl_moved}")
    return is_reactive
def get_reactive_indices(atoms, cutoff_dict=None):
    """Return the indices of atoms taking part in a flagged pair interaction.

    Every neighbour pair whose element combination appears in ``cutoff_dict``
    and whose distance lies inside that entry's ``(lower, upper)`` window
    marks both of its atoms as reactive.

    NOTE(review): the original body stopped right after building the
    neighbour list; the "distance inside the window" rule is inferred from the
    ``(lower, upper)`` tuples and the ``reactive_mask`` array -- confirm that
    this is the intended selection criterion.

    Parameters
    ----------
    atoms : ase.Atoms
        Structure to analyse.
    cutoff_dict : dict, optional
        Maps an element pair, e.g. ``('Li', 'S')``, to a ``(lower, upper)``
        distance window in Å. The order of the two symbols does not matter.
        Defaults to the Li-S, P-S, S-Cl and Li-P windows listed below.

    Returns
    -------
    numpy.ndarray
        Sorted integer indices of the reactive atoms (possibly empty).
    """
    symbols = atoms.get_chemical_symbols()
    if cutoff_dict is None:
        cutoff_dict = {
            ('Li', 'S'): (2.0, 3.0),   # Li-S bond
            ('P', 'S'): (1.8, 2.6),    # P-S bond
            ('S', 'Cl'): (2.2, 3.5),   # S-Cl interaction
            ('Li', 'P'): (2.2, 3.2),   # Li-P interaction
        }
    reactive_mask = np.zeros(len(atoms), dtype=bool)
    if not cutoff_dict:
        return np.flatnonzero(reactive_mask)

    # Order-insensitive lookup: ('Li', 'S') and ('S', 'Li') are the same pair.
    windows = {frozenset(pair): bounds for pair, bounds in cutoff_dict.items()}
    # Search at least as far as the widest window so no entry is silently cut off.
    search_cutoff = max(5.0, max(upper for _, upper in windows.values()))

    i_list, j_list, d_list = neighbor_list('ijd', atoms, cutoff=search_cutoff)
    for i, j, d in zip(i_list, j_list, d_list):
        bounds = windows.get(frozenset((symbols[i], symbols[j])))
        if bounds is None:
            continue
        lower, upper = bounds
        if lower <= d <= upper:
            reactive_mask[i] = True
            reactive_mask[j] = True

    return np.flatnonzero(reactive_mask)

In [None]:
# Location of the MACE model file. Set the MACE_MODEL_PATH environment variable
# to use the notebook on another machine; the fallback is the original
# hard-coded location.
MACE_MODEL = os.environ.get(
    'MACE_MODEL_PATH',
    '/home/netszx/models/2024-01-07-mace-128-L2_epoch-199.model',
)
# Run MACE on the GPU when torch reports one, otherwise on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

mace_calc = MACECalculator(model_paths=MACE_MODEL, device=DEVICE, default_dtype="float64")
xtb_calc = XTB(method="GFN2-xTB")

class HybridCalculator:
    """Mix MLP and xTB energies/forces around a set of reactive atoms.

    The whole structure is evaluated with the MLP calculator. A cluster made
    of the reactive atoms and their neighbours within ``R_outer + 1.0`` Å is
    evaluated with xTB. Per-atom weights switch smoothly (cosine) from 1 at a
    distance of ``R_inner`` from the nearest reactive atom to 0 at ``R_outer``
    and are used to mix the two results.

    Parameters
    ----------
    mace_calc : calculator
        ASE-style calculator used for the whole structure.
    xtb_calc : calculator
        ASE-style calculator used for the reactive cluster.
    R_inner, R_outer : float
        Inner and outer radius (Å) of the switching region.
    reference_atoms : ase.Atoms, optional
        Structure used to compute ``energy_offset`` at construction time.
        When omitted, the offset is computed from the first structure whose
        hybrid energy is requested, so construction no longer depends on a
        notebook-level ``atoms`` variable.
    """

    def __init__(self, mace_calc, xtb_calc, R_inner=2.5, R_outer=4.5, reference_atoms=None):
        self.mlp_calc = mace_calc
        self.xtb_calc = xtb_calc
        self.R_inner = R_inner
        self.R_outer = R_outer
        # The cluster depends on the current reactive atoms, so it is rebuilt
        # on every evaluation instead of once in the constructor.
        self.cluster_atoms = None
        self.cluster_indices = np.array([], dtype=int)
        self.energy_offset = None
        if reference_atoms is not None:
            self.energy_offset = self.energy_alignment(reference_atoms)

    @staticmethod
    def _as_index_array(reactive_indices):
        """Normalise ``reactive_indices`` (None, list or array) to a 1-D int array."""
        if reactive_indices is None:
            return np.array([], dtype=int)
        return np.asarray(reactive_indices, dtype=int).ravel()

    def cluster(self, atoms, reactive_indices):
        """Select the reactive atoms plus their neighbours within ``R_outer + 1.0`` Å.

        Returns
        -------
        (cluster_atoms, cluster_indices)
            The sub-structure ``atoms[cluster_indices]`` and the sorted
            indices of its atoms in ``atoms``.
        """
        reactive = self._as_index_array(reactive_indices)
        cluster_mask = np.zeros(len(atoms), dtype=bool)
        if reactive.size:
            # Uses the notebook-level `neighbor_list` import with one global
            # cutoff; every pair is reported in both directions, so selecting
            # on the first index finds all neighbours of the reactive atoms.
            i_list, j_list = neighbor_list('ij', atoms, self.R_outer + 1.0)
            cluster_mask[j_list[np.isin(i_list, reactive)]] = True
            cluster_mask[reactive] = True
        cluster_indices = np.where(cluster_mask)[0]

        cluster_atoms = atoms[cluster_indices]
        return cluster_atoms, cluster_indices

    def energy_alignment(self, atoms):
        """Return ``E_xtb - E_mlp`` evaluated on a copy of ``atoms``."""
        atoms_ref = atoms.copy()
        atoms_ref.calc = self.mlp_calc
        E_mlp_ref = atoms_ref.get_potential_energy()

        atoms_ref_xtb = atoms_ref.copy()
        atoms_ref_xtb.calc = self.xtb_calc
        E_xtb_ref = atoms_ref_xtb.get_potential_energy()
        return E_xtb_ref - E_mlp_ref

    def get_potential_energy(self, atoms, reactive_indices):
        """Hybrid energy of ``atoms``; the plain MLP energy when nothing is reactive.

        Side effect: attaches the MLP calculator to ``atoms``.
        """
        return self._evaluate(atoms, reactive_indices, want_energy=True, want_forces=False)[0]

    def get_forces(self, atoms, reactive_indices):
        """Hybrid forces on ``atoms``; the plain MLP forces when nothing is reactive.

        Side effect: attaches the MLP calculator to ``atoms``.
        """
        return self._evaluate(atoms, reactive_indices, want_energy=False, want_forces=True)[1]

    def get_energy_and_forces(self, atoms, reactive_indices):
        """Return ``(energy, forces)`` from one shared cluster construction."""
        return self._evaluate(atoms, reactive_indices, want_energy=True, want_forces=True)

    def _evaluate(self, atoms, reactive_indices, want_energy, want_forces):
        """Run the MLP on ``atoms`` and xTB on the cluster, then blend as requested."""
        reactive = self._as_index_array(reactive_indices)

        atoms.calc = self.mlp_calc
        E_mlp = atoms.get_potential_energy() if want_energy else None
        F_mlp = atoms.get_forces() if want_forces else None
        if reactive.size == 0:
            return E_mlp, F_mlp

        self.cluster_atoms, self.cluster_indices = self.cluster(atoms, reactive)
        # The xTB calculator must be attached before any cluster property is requested.
        self.cluster_atoms.calc = self.xtb_calc

        E_final = None
        F_final = None
        if want_energy:
            if self.energy_offset is None:
                self.energy_offset = self.energy_alignment(atoms)
            E_xtb_cluster = self.cluster_atoms.get_potential_energy() - self.energy_offset
            E_final = self._blend_energy(atoms, E_mlp, E_xtb_cluster, reactive)
        if want_forces:
            F_xtb_cluster = self.cluster_atoms.get_forces()
            # Atoms outside the cluster keep their MLP force in the xTB array.
            F_xtb_global = F_mlp.copy()
            F_xtb_global[self.cluster_indices] = F_xtb_cluster
            F_final = self._blend_forces(atoms, F_mlp, F_xtb_global, reactive)
        return E_final, F_final

    def _switch_weights(self, atoms, reactive_indices):
        """Per-atom xTB weight: 1 inside ``R_inner``, 0 beyond ``R_outer``, cosine in between.

        NOTE(review): distances are plain Cartesian (no minimum-image
        convention), as in the original force blending -- confirm for
        periodic cells.
        """
        pos = atoms.positions
        reactive_pos = pos[reactive_indices]
        # Distance from every atom to its nearest reactive atom.
        r = np.linalg.norm(pos[:, None, :] - reactive_pos[None, :, :], axis=2).min(axis=1)

        weights = np.zeros(len(pos))
        weights[r <= self.R_inner] = 1.0
        mid = (r > self.R_inner) & (r < self.R_outer)
        weights[mid] = 0.5 * (1 + np.cos(np.pi * (r[mid] - self.R_inner) / (self.R_outer - self.R_inner)))
        return weights

    def _blend_energy(self, atoms, E_mlp, E_xtb, reactive_indices):
        """Mix the two energies with the mean per-atom switching weight.

        NOTE(review): the original code called this method but never defined
        it. Weighting by the mean switching weight mirrors the force blending;
        mixing a cluster xTB energy with a whole-structure MLP energy is a
        modelling choice that should be confirmed.
        """
        reactive = self._as_index_array(reactive_indices)
        if reactive.size == 0:
            return E_mlp
        w = float(np.mean(self._switch_weights(atoms, reactive)))
        return w * E_xtb + (1 - w) * E_mlp

    def _blend_forces(self, atoms, F_mlp, F_xtb, reactive_indices):
        """Mix MLP and xTB forces atom by atom with the switching weights."""
        reactive = self._as_index_array(reactive_indices)
        F_mlp = np.asarray(F_mlp)
        if reactive.size == 0:
            return F_mlp.copy()
        w = self._switch_weights(atoms, reactive)[:, None]
        return w * np.asarray(F_xtb) + (1 - w) * F_mlp
    
# Pickle file holding the active-learning samples collected so far.
ACTIVE_LEARNING_FILE = "active_learning_pool.pkl"

if not os.path.exists(ACTIVE_LEARNING_FILE):
    # Nothing has been saved yet: begin with an empty pool.
    active_learning_pool = []
    print("Initialized empty active learning pool.")
else:
    # NOTE: unpickling can execute arbitrary code -- only load pool files
    # that come from a trusted source.
    with open(ACTIVE_LEARNING_FILE, 'rb') as pool_file:
        active_learning_pool = pickle.load(pool_file)
    print(f"Loaded {len(active_learning_pool)} samples from active learning pool.")

In [None]:
MAX_STEPS = 10000      # Maximum MD steps
CHECK_INTERVAL = 100   # Check for reactions every N steps
ERROR_THRESHOLD_ENERGY = 0.1   # eV
ERROR_THRESHOLD_FORCE = 0.1    # eV/Å
FINETUNE_EVERY_N_SAMPLES = 30  # Number of samples to trigger fine-tuning
HYBRID_CALC = HybridCalculator(mace_calc, xtb_calc)

# The MD itself always runs on the MLP; the hybrid calculator is used only to
# score configurations in which a reaction was detected.
atoms.calc = mace_calc
dyn = Langevin(
    atoms,
    timestep=1.0 * units.fs,
    temperature_K=300,
    friction=0.01 / units.fs,
    trajectory='adaptive_mlmd.traj'
)

print("Starting adaptive ML-MD...")

for step in range(MAX_STEPS):
    dyn.run(steps=1)

    if step % CHECK_INTERVAL != 0:
        continue
    if not detect_reaction_event(atoms):
        continue

    print(f"\n Reaction detected at step {step}")
    # Reactive atoms are needed only once a reaction has been detected.
    reactive_indices = get_reactive_indices(atoms)

    # Hybrid reference values, obtained through the calculator's
    # get_potential_energy / get_forces methods.
    E_hybrid = HYBRID_CALC.get_potential_energy(atoms, reactive_indices)
    F_hybrid = HYBRID_CALC.get_forces(atoms, reactive_indices)

    # Pure-MLP values for the same configuration.
    E_mlp = atoms.get_potential_energy()
    F_mlp = atoms.get_forces()

    # Energy deviation and mean per-atom norm of the force deviation.
    dE = abs(E_hybrid - E_mlp)
    dF = np.mean(np.linalg.norm(F_hybrid - F_mlp, axis=1))
    print(f"dE = {dE:.4f} eV, dF = {dF:.4f} eV/Å")

    # Keep the configuration only when the MLP disagrees beyond a threshold.
    if dE > ERROR_THRESHOLD_ENERGY or dF > ERROR_THRESHOLD_FORCE:
        sample = {'atoms': atoms.copy(), 'energy': E_hybrid, 'forces': F_hybrid.copy(), 'step': step}
        active_learning_pool.append(sample)
        print("Added to active learning pool")

        # Persist after every new sample so an interrupted run loses nothing.
        with open(ACTIVE_LEARNING_FILE, 'wb') as f:
            pickle.dump(active_learning_pool, f)

        # Stop the MD once enough samples are available for fine-tuning.
        if len(active_learning_pool) >= FINETUNE_EVERY_N_SAMPLES:
            print("\n Triggering MACE fine-tuning...")
            print("Please run Cell 6 to fine-tune the model.")
            break
print("Adaptive MD finished.")