In [1]:
import os
from pathlib import Path

import numpy as np
import pandas as pd
from ase import Atom, Atoms
from ase.data import chemical_symbols, covalent_radii, vdw_alvarez
from ase.io import read, write
from pymatgen.core import Element
from scipy import stats
from tqdm.auto import tqdm

from mlip_arena.models.utils import REGISTRY, MLIPEnum

# model_name = "MACE-MP(M)"

# calc = MLIPEnum[model_name].value()

In [None]:
for model in MLIPEnum:
    
    model_name = model.name
    
    if model_name == 'EquiformerV2(OC22)':
        continue
    
    json_fpath = Path(REGISTRY[model_name]["family"]) / "homonuclear-diatomics.json"
    
    if json_fpath.exists():
        continue
    
    print(f"========== {model_name} ==========")

    calc = MLIPEnum[model_name].value()

    for symbol in tqdm(chemical_symbols[1:]):

        s = set([symbol])

        if "X" in s:
            continue

        try:
            atom = Atom(symbol)
            rmin = 0.9 * covalent_radii[atom.number]
            rvdw = vdw_alvarez.vdw_radii[atom.number] if atom.number < len(vdw_alvarez.vdw_radii) else np.nan
            rmax = 3.1 * rvdw if not np.isnan(rvdw) else 6
            rstep = 0.01

            a = 2 * rmax

            npts = int((rmax - rmin)/rstep)

            rs = np.linspace(rmin, rmax, npts)
            es = np.zeros_like(rs)

            da = symbol + symbol

            out_dir = Path(REGISTRY[model_name]["family"]) / str(da)
            os.makedirs(out_dir, exist_ok=True)

            skip = 0

            element = Element(symbol)

            try:
                m = element.valence[1]
                if element.valence == (0, 2):
                    m = 0
            except:
                m = 0


            r = rs[0]

            positions = [
                [a/2-r/2, a/2, a/2],
                [a/2+r/2, a/2, a/2],
            ]

            traj_fpath = out_dir / f"{model_name}.extxyz"

            if traj_fpath.exists():
                traj = read(traj_fpath, index=":")
                skip = len(traj)
                atoms = traj[-1]
            else:
                # Create the unit cell with two atoms
                atoms = Atoms(
                    da,
                    positions=positions,
                    # magmoms=magmoms,
                    cell=[a, a+0.001, a+0.002],
                    pbc=True
                )

            print(atoms)

            atoms.calc = calc

            for i, r in enumerate(tqdm(rs)):

                if i < skip:
                    continue

                positions = [
                    [a/2-r/2, a/2, a/2],
                    [a/2+r/2, a/2, a/2],
                ]

                # atoms.set_initial_magnetic_moments(magmoms)

                atoms.set_positions(positions)

                es[i] = atoms.get_potential_energy()

                write(traj_fpath, atoms, append="a")
        except Exception as e:
            print(e)




  "mean": torch.tensor(state_dict["mean"]),


  0%|          | 0/118 [00:00<?, ?it/s]

Atoms(symbols='H2', pbc=True, cell=[7.4399999999999995, 7.441, 7.441999999999999])


  0%|          | 0/344 [00:00<?, ?it/s]

Atoms(symbols='He2', pbc=True, cell=[8.866, 8.866999999999999, 8.868])


  0%|          | 0/418 [00:00<?, ?it/s]

Atoms(symbols='Li2', pbc=True, cell=[13.144000000000002, 13.145000000000001, 13.146000000000003])


  0%|          | 0/542 [00:00<?, ?it/s]

Atoms(symbols='Be2', pbc=True, cell=[12.276, 12.277, 12.278])


  0%|          | 0/527 [00:00<?, ?it/s]

Atoms(symbols='B2', pbc=True, cell=[11.842, 11.843, 11.844000000000001])


  0%|          | 0/516 [00:00<?, ?it/s]

Atoms(symbols='C2', pbc=True, cell=[10.974, 10.975, 10.976])


  0%|          | 0/480 [00:00<?, ?it/s]

Atoms(symbols='N2', pbc=True, cell=[10.292, 10.293, 10.294])


  0%|          | 0/450 [00:00<?, ?it/s]

Atoms(symbols='O2', pbc=True, cell=[9.3, 9.301, 9.302000000000001])


  0%|          | 0/405 [00:00<?, ?it/s]

Atoms(symbols='F2', pbc=True, cell=[9.052, 9.052999999999999, 9.054])


  0%|          | 0/401 [00:00<?, ?it/s]

Atoms(symbols='Ne2', pbc=True, cell=[9.796000000000001, 9.797, 9.798000000000002])


  0%|          | 0/437 [00:00<?, ?it/s]

Atoms(symbols='Na2', pbc=True, cell=[15.5, 15.501, 15.502])


  0%|          | 0/625 [00:00<?, ?it/s]

Atoms(symbols='Mg2', pbc=True, cell=[15.562, 15.562999999999999, 15.564])


  0%|          | 0/651 [00:00<?, ?it/s]

Atoms(symbols='Al2', pbc=True, cell=[13.950000000000001, 13.951, 13.952000000000002])


  0%|          | 0/588 [00:00<?, ?it/s]

Atoms(symbols='Si2', pbc=True, cell=[13.578, 13.578999999999999, 13.58])


  0%|          | 0/578 [00:00<?, ?it/s]

Atoms(symbols='P2', pbc=True, cell=[11.78, 11.780999999999999, 11.782])


  0%|          | 0/492 [00:00<?, ?it/s]

Atoms(symbols='S2', pbc=True, cell=[11.718, 11.719, 11.72])


  0%|          | 0/491 [00:00<?, ?it/s]

Atoms(symbols='Cl2', pbc=True, cell=[11.284, 11.285, 11.286000000000001])


  0%|          | 0/472 [00:00<?, ?it/s]

Atoms(symbols='Ar2', pbc=True, cell=[11.346, 11.347, 11.348])


  0%|          | 0/471 [00:00<?, ?it/s]

Atoms(symbols='K2', pbc=True, cell=[16.926000000000002, 16.927000000000003, 16.928])


  0%|          | 0/663 [00:00<?, ?it/s]

Atoms(symbols='Ca2', pbc=True, cell=[16.244, 16.245, 16.246])


  0%|          | 0/653 [00:00<?, ?it/s]

Atoms(symbols='Sc2', pbc=True, cell=[15.996, 15.997, 15.998000000000001])


  0%|          | 0/646 [00:00<?, ?it/s]

Atoms(symbols='Ti2', pbc=True, cell=[15.252, 15.253, 15.254000000000001])


  0%|          | 0/618 [00:00<?, ?it/s]

Atoms(symbols='V2', pbc=True, cell=[15.004, 15.004999999999999, 15.006])


  0%|          | 0/612 [00:00<?, ?it/s]

Atoms(symbols='Cr2', pbc=True, cell=[15.190000000000001, 15.191, 15.192000000000002])


  0%|          | 0/634 [00:00<?, ?it/s]

Atoms(symbols='Mn2', pbc=True, cell=[15.190000000000001, 15.191, 15.192000000000002])


  0%|          | 0/634 [00:00<?, ?it/s]

Atoms(symbols='Fe2', pbc=True, cell=[15.128, 15.129, 15.13])


  0%|          | 0/637 [00:00<?, ?it/s]

Atoms(symbols='Co2', pbc=True, cell=[14.879999999999999, 14.880999999999998, 14.882])


  0%|          | 0/630 [00:00<?, ?it/s]

Atoms(symbols='Ni2', pbc=True, cell=[14.879999999999999, 14.880999999999998, 14.882])


  0%|          | 0/632 [00:00<?, ?it/s]

Atoms(symbols='Cu2', pbc=True, cell=[14.756, 14.757, 14.758000000000001])


  0%|          | 0/618 [00:00<?, ?it/s]

Atoms(symbols='Zn2', pbc=True, cell=[14.818000000000001, 14.819, 14.820000000000002])


  0%|          | 0/631 [00:00<?, ?it/s]

Atoms(symbols='Ga2', pbc=True, cell=[14.383999999999999, 14.384999999999998, 14.386])


  0%|          | 0/609 [00:00<?, ?it/s]

Atoms(symbols='Ge2', pbc=True, cell=[14.198, 14.199, 14.200000000000001])


  0%|          | 0/601 [00:00<?, ?it/s]

Atoms(symbols='As2', pbc=True, cell=[11.655999999999999, 11.656999999999998, 11.658])


  0%|          | 0/475 [00:00<?, ?it/s]

Atoms(symbols='Se2', pbc=True, cell=[11.284, 11.285, 11.286000000000001])


  0%|          | 0/456 [00:00<?, ?it/s]

Atoms(symbols='Br2', pbc=True, cell=[11.532000000000002, 11.533000000000001, 11.534000000000002])


  0%|          | 0/468 [00:00<?, ?it/s]

Atoms(symbols='Kr2', pbc=True, cell=[13.950000000000001, 13.951, 13.952000000000002])


  0%|          | 0/593 [00:00<?, ?it/s]

Atoms(symbols='Rb2', pbc=True, cell=[19.902, 19.903000000000002, 19.904])


  0%|          | 0/797 [00:00<?, ?it/s]

Atoms(symbols='Sr2', pbc=True, cell=[17.608, 17.609, 17.61])


  0%|          | 0/704 [00:00<?, ?it/s]

Atoms(symbols='Y2', pbc=True, cell=[17.05, 17.051000000000002, 17.052])


  0%|          | 0/681 [00:00<?, ?it/s]

Atoms(symbols='Zr2', pbc=True, cell=[15.624, 15.625, 15.626000000000001])


  0%|          | 0/623 [00:00<?, ?it/s]

Atoms(symbols='Nb2', pbc=True, cell=[15.872000000000002, 15.873000000000001, 15.874000000000002])


  0%|          | 0/646 [00:00<?, ?it/s]

Atoms(symbols='Mo2', pbc=True, cell=[15.190000000000001, 15.191, 15.192000000000002])


  0%|          | 0/620 [00:00<?, ?it/s]

In [3]:


for model in MLIPEnum:
    
    model_name = model.name
    
    print(f"========== {model_name} ==========")
    
    df = pd.DataFrame(columns=[
        "name", 
        "method", 
        "R", "E", "F", "S^2", 
        "force-flip-times",
        "force-total-variation",
        "energy-diff-flip-times",
        "energy-grad-norm-max",
        "energy-jump",
        "energy-total-variation",
        "conservation-deviation",
        "spearman-descending-force",
        "spearman-ascending-force",
        "spearman-repulsion-energy",
        "spearman-attraction-energy"
    ])
    

    for symbol in tqdm(chemical_symbols[1:]):

        da = symbol + symbol

        out_dir = Path(REGISTRY[model_name]["family"]) / da

        traj_fpath = out_dir / f"{model_name}.extxyz"


        if traj_fpath.exists():
            traj = read(traj_fpath, index=":")
        else:
            continue

        Rs, Es, Fs, S2s = [], [], [], []
        for atoms in traj:

            vec = atoms.positions[1] - atoms.positions[0]
            r = np.linalg.norm(vec)
            e = atoms.get_potential_energy()
            f = np.inner(vec/r, atoms.get_forces()[1])
            # s2 = np.mean(np.power(atoms.get_magnetic_moments(), 2))

            Rs.append(r)
            Es.append(e)
            Fs.append(f)
            # S2s.append(s2)

        rs = np.array(Rs)
        es = np.array(Es)
        fs = np.array(Fs)

        indices = np.argsort(rs)[::-1]
        rs = rs[indices]
        es = es[indices]
        fs = fs[indices]

        iminf = np.argmin(fs)
        imine = np.argmin(es)

        de_dr = np.gradient(es, rs)
        d2e_dr2 = np.gradient(de_dr, rs)

        rounded_fs = np.copy(fs)
        rounded_fs[np.abs(rounded_fs) < 1e-2] = 0
        fs_sign = np.sign(rounded_fs)
        fs_sign = fs_sign[fs_sign != 0]

        # rounded_ediff = np.diff(es)
        # rounded_ediff[np.abs(rounded_ediff) < zero_threshold] = 0
        ediff = np.diff(es)
        ediff[np.abs(ediff) < 1e-3] = 0
        ediff_sign = np.sign(ediff)
        mask = ediff_sign != 0
        ediff = ediff[mask]
        ediff_sign = ediff_sign[mask]
        ediff_flip = np.diff(ediff_sign) != 0
        ejump = np.abs(ediff[:-1][ediff_flip]).sum() + np.abs(ediff[1:][ediff_flip]).sum()

        conservation_deviation = np.mean(np.abs(fs + de_dr))

        data = {
            "name": da,
            "method": model_name,
            "R": rs,
            "E": es,
            "F": fs,
            "S^2": S2s,
            "force-flip-times": np.sum(np.diff(fs_sign)!=0),
            "force-total-variation": np.sum(np.abs(np.diff(fs))),
            "energy-diff-flip-times": np.sum(ediff_flip),
            "energy-grad-norm-max": np.max(np.abs(de_dr)),
            "energy-jump": ejump,
            # "energy-grad-norm-mean": np.mean(de_dr_abs),
            "energy-total-variation": np.sum(np.abs(np.diff(es))),
            "conservation-deviation": conservation_deviation,
            "spearman-descending-force": stats.spearmanr(rs[iminf:], fs[iminf:]).statistic,
            "spearman-ascending-force": stats.spearmanr(rs[:iminf], fs[:iminf]).statistic,
            "spearman-repulsion-energy": stats.spearmanr(rs[imine:], es[imine:]).statistic,
            "spearman-attraction-energy": stats.spearmanr(rs[:imine], es[:imine]).statistic,
        }

        df = pd.concat([df, pd.DataFrame([data])], ignore_index=True)


    json_fpath = Path(REGISTRY[model_name]["family"]) / "homonuclear-diatomics.json"

    if json_fpath.exists():
        df0 = pd.read_json(json_fpath)
        df = pd.concat([df0, df], ignore_index=True)
        df.drop_duplicates(inplace=True, subset=["name", "method"], keep='last')

    df.to_json(json_fpath, orient="records")



  0%|          | 0/118 [00:00<?, ?it/s]

  df = pd.concat([df, pd.DataFrame([data])], ignore_index=True)
  "spearman-descending-force": stats.spearmanr(rs[iminf:], fs[iminf:]).statistic,
  "spearman-repulsion-energy": stats.spearmanr(rs[imine:], es[imine:]).statistic,




  0%|          | 0/118 [00:00<?, ?it/s]

  df = pd.concat([df, pd.DataFrame([data])], ignore_index=True)




  0%|          | 0/118 [00:00<?, ?it/s]

  df = pd.concat([df, pd.DataFrame([data])], ignore_index=True)
  "spearman-ascending-force": stats.spearmanr(rs[:iminf], fs[:iminf]).statistic,




  0%|          | 0/118 [00:00<?, ?it/s]

  df = pd.concat([df, pd.DataFrame([data])], ignore_index=True)




  0%|          | 0/118 [00:00<?, ?it/s]

  df = pd.concat([df, pd.DataFrame([data])], ignore_index=True)
  "spearman-attraction-energy": stats.spearmanr(rs[:imine], es[:imine]).statistic,




  0%|          | 0/118 [00:00<?, ?it/s]

  df = pd.concat([df, pd.DataFrame([data])], ignore_index=True)
  "spearman-repulsion-energy": stats.spearmanr(rs[imine:], es[imine:]).statistic,
  "spearman-attraction-energy": stats.spearmanr(rs[:imine], es[:imine]).statistic,




  0%|          | 0/118 [00:00<?, ?it/s]



  df = pd.concat([df0, df], ignore_index=True)


  0%|          | 0/118 [00:00<?, ?it/s]



  0%|          | 0/118 [00:00<?, ?it/s]



  0%|          | 0/118 [00:00<?, ?it/s]

OSError: Cannot save file into a non-existent directory: 'alignn'

In [12]:
df

Unnamed: 0,name,method,R,E,F,S^2,force-flip-times,force-total-variation,energy-diff-flip-times,energy-grad-norm-max,energy-jump,energy-total-variation,conservation-deviation,spearman-descending-force,spearman-ascending-force,spearman-repulsion-energy,spearman-attraction-energy
89,HH,MACE-MP(M),"[3.7199999999999998, 3.70996794, 3.69993586, 3...","[-2.348365306854248, -2.3483591079711914, -2.3...","[1.7e-07, 0.00035892, 0.00068535, 0.00112034, ...",[],3,80.309463,2,65.875268,0.009835,15.816339,0.043480,-0.999421,0.840215,-1.000000,0.794565
90,HeHe,MACE-MP(M),"[4.433, 4.4229736200000005, 4.41294724, 4.4029...","[0.022882699966430664, 0.022882461547851562, 0...","[0.0, -0.00010788, -0.00021458, -0.00032269, -...",[],2,1.485210,0,0.821236,0.000000,0.344755,0.001872,-0.904433,0.077295,-1.000000,0.093575
91,LiLi,MACE-MP(M),"[6.572000000000001, 6.561981520000001, 6.55196...","[-1.8418288230895996, -1.8418288230895996, -1....","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",[],1,8.583572,1,8.122311,0.002224,5.332280,0.013703,-0.999858,0.993021,-0.999996,0.995153
92,BeBe,MACE-MP(M),"[6.138000000000001, 6.12797338, 6.117946759999...","[-0.6769685745239258, -0.6769685745239258, -0....","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",[],2,29.133944,2,10.346194,0.016283,6.239225,0.022079,,-0.044580,-0.998691,0.999892
93,BB,MACE-MP(M),"[5.921000000000001, 5.91097088, 5.900941739999...","[-1.404820442199707, -1.404820442199707, -1.40...","[-6e-08, -0.00011319, -0.00023184, -0.0003534,...",[],2,330.078666,2,123.581618,0.543755,40.039458,0.138081,,0.201590,-0.999921,0.999947
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
173,ThTh,MACE-MP(M),"[9.082999999999998, 9.07297364, 9.062947300000...","[-4.095590591430664, -4.095590591430664, -4.09...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",[],2,16.226657,2,6.413536,0.006823,3.392294,0.024297,,0.026348,-0.997169,0.918191
174,PaPa,MACE-MP(M),"[8.927999999999999, 8.91797468, 8.90794936, 8....","[-8.707170486450195, -8.707170486450195, -8.70...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",[],2,6.407403,2,2.536645,0.006878,2.349456,0.018977,-0.627256,0.460505,,0.950999
175,UU,MACE-MP(M),"[8.401, 8.390974320000002, 8.38094864, 8.37092...","[-11.437246322631836, -11.437246322631836, -11...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",[],1,34.878494,1,25.824510,0.004967,11.734995,0.025810,-0.974165,0.939272,-1.000000,0.966465
176,NpNp,MACE-MP(M),"[8.741999999999999, 8.7319829, 8.72196582, 8.7...","[-18.34619903564453, -18.34619903564453, -18.3...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",[],3,14.386345,3,8.632403,0.010139,2.824085,0.015532,-1.000000,0.189479,-1.000000,0.439673
