In [2]:
import os
from pathlib import Path

import numpy as np
import pandas as pd
from ase import Atom, Atoms
from ase.data import chemical_symbols, covalent_radii, vdw_alvarez
from ase.io import read, write
from pymatgen.core import Element
from scipy import stats
from tqdm.auto import tqdm

from mlip_arena.models.utils import REGISTRY, MLIPEnum

# model_name = "MACE-MP(M)"

# calc = MLIPEnum[model_name].value()

In [3]:
for model in MLIPEnum:
    
    model_name = model.name
    
    json_fpath = Path(REGISTRY[model_name]["family"]) / "homonuclear-diatomics.json"
    
    if json_fpath.exists():
        continue
    
    print(f"========== {model_name} ==========")

    calc = MLIPEnum[model_name].value()

    for symbol in tqdm(chemical_symbols[1:]):

        s = set([symbol])

        if "X" in s:
            continue

        try:
            atom = Atom(symbol)
            rmin = 0.9 * covalent_radii[atom.number]
            rvdw = vdw_alvarez.vdw_radii[atom.number] if atom.number < len(vdw_alvarez.vdw_radii) else np.nan
            rmax = 3.1 * rvdw if not np.isnan(rvdw) else 6
            rstep = 0.01

            a = 2 * rmax

            npts = int((rmax - rmin)/rstep)

            rs = np.linspace(rmin, rmax, npts)
            es = np.zeros_like(rs)

            da = symbol + symbol

            out_dir = Path(REGISTRY[model_name]["family"]) / str(da)
            os.makedirs(out_dir, exist_ok=True)

            skip = 0

            element = Element(symbol)

            try:
                m = element.valence[1]
                if element.valence == (0, 2):
                    m = 0
            except:
                m = 0


            r = rs[0]

            positions = [
                [a/2-r/2, a/2, a/2],
                [a/2+r/2, a/2, a/2],
            ]

            traj_fpath = out_dir / f"{model_name}.extxyz"

            if traj_fpath.exists():
                traj = read(traj_fpath, index=":")
                skip = len(traj)
                atoms = traj[-1]
            else:
                # Create the unit cell with two atoms
                atoms = Atoms(
                    da,
                    positions=positions,
                    # magmoms=magmoms,
                    cell=[a, a+0.001, a+0.002],
                    pbc=True
                )

            print(atoms)

            atoms.calc = calc

            for i, r in enumerate(tqdm(rs)):

                if i < skip:
                    continue

                positions = [
                    [a/2-r/2, a/2, a/2],
                    [a/2+r/2, a/2, a/2],
                ]

                # atoms.set_initial_magnetic_moments(magmoms)

                atoms.set_positions(positions)

                es[i] = atoms.get_potential_energy()

                write(traj_fpath, atoms, append="a")
        except Exception as e:
            print(e)


Selected GPU cuda:0 with 40338.06 MB free memory from 4 GPUs
Default dtype float32 does not match model dtype float64, converting models to float32.


  0%|          | 0/118 [00:00<?, ?it/s]

Atoms(symbols='H2', pbc=True, cell=[7.4399999999999995, 7.441, 7.441999999999999])


  0%|          | 0/344 [00:00<?, ?it/s]

Atoms(symbols='He2', pbc=True, cell=[8.866, 8.866999999999999, 8.868])


  0%|          | 0/418 [00:00<?, ?it/s]

2 is not in list
Atoms(symbols='Li2', pbc=True, cell=[13.144000000000002, 13.145000000000001, 13.146000000000003])


  0%|          | 0/542 [00:00<?, ?it/s]

3 is not in list
Atoms(symbols='Be2', pbc=True, cell=[12.276, 12.277, 12.278])


  0%|          | 0/527 [00:00<?, ?it/s]

4 is not in list
Atoms(symbols='B2', pbc=True, cell=[11.842, 11.843, 11.844000000000001])


  0%|          | 0/516 [00:00<?, ?it/s]

5 is not in list
Atoms(symbols='C2', pbc=True, cell=[10.974, 10.975, 10.976])


  0%|          | 0/480 [00:00<?, ?it/s]

Atoms(symbols='N2', pbc=True, cell=[10.292, 10.293, 10.294])


  0%|          | 0/450 [00:00<?, ?it/s]

Atoms(symbols='O2', pbc=True, cell=[9.3, 9.301, 9.302000000000001])


  0%|          | 0/405 [00:00<?, ?it/s]

Atoms(symbols='F2', pbc=True, cell=[9.052, 9.052999999999999, 9.054])


  0%|          | 0/401 [00:00<?, ?it/s]

Atoms(symbols='Ne2', pbc=True, cell=[9.796000000000001, 9.797, 9.798000000000002])


  0%|          | 0/437 [00:00<?, ?it/s]

10 is not in list
Atoms(symbols='Na2', pbc=True, cell=[15.5, 15.501, 15.502])


  0%|          | 0/625 [00:00<?, ?it/s]

11 is not in list
Atoms(symbols='Mg2', pbc=True, cell=[15.562, 15.562999999999999, 15.564])


  0%|          | 0/651 [00:00<?, ?it/s]

12 is not in list
Atoms(symbols='Al2', pbc=True, cell=[13.950000000000001, 13.951, 13.952000000000002])


  0%|          | 0/588 [00:00<?, ?it/s]

13 is not in list
Atoms(symbols='Si2', pbc=True, cell=[13.578, 13.578999999999999, 13.58])


  0%|          | 0/578 [00:00<?, ?it/s]

14 is not in list
Atoms(symbols='P2', pbc=True, cell=[11.78, 11.780999999999999, 11.782])


  0%|          | 0/492 [00:00<?, ?it/s]

Atoms(symbols='S2', pbc=True, cell=[11.718, 11.719, 11.72])


  0%|          | 0/491 [00:00<?, ?it/s]

Atoms(symbols='Cl2', pbc=True, cell=[11.284, 11.285, 11.286000000000001])


  0%|          | 0/472 [00:00<?, ?it/s]

Atoms(symbols='Ar2', pbc=True, cell=[11.346, 11.347, 11.348])


  0%|          | 0/471 [00:00<?, ?it/s]

18 is not in list
Atoms(symbols='K2', pbc=True, cell=[16.926000000000002, 16.927000000000003, 16.928])


  0%|          | 0/663 [00:00<?, ?it/s]

19 is not in list
Atoms(symbols='Ca2', pbc=True, cell=[16.244, 16.245, 16.246])


  0%|          | 0/653 [00:00<?, ?it/s]

20 is not in list
Atoms(symbols='Sc2', pbc=True, cell=[15.996, 15.997, 15.998000000000001])


  0%|          | 0/646 [00:00<?, ?it/s]

21 is not in list
Atoms(symbols='Ti2', pbc=True, cell=[15.252, 15.253, 15.254000000000001])


  0%|          | 0/618 [00:00<?, ?it/s]

22 is not in list
Atoms(symbols='V2', pbc=True, cell=[15.004, 15.004999999999999, 15.006])


  0%|          | 0/612 [00:00<?, ?it/s]

23 is not in list
Atoms(symbols='Cr2', pbc=True, cell=[15.190000000000001, 15.191, 15.192000000000002])


  0%|          | 0/634 [00:00<?, ?it/s]

24 is not in list
Atoms(symbols='Mn2', pbc=True, cell=[15.190000000000001, 15.191, 15.192000000000002])


  0%|          | 0/634 [00:00<?, ?it/s]

25 is not in list
Atoms(symbols='Fe2', pbc=True, cell=[15.128, 15.129, 15.13])


  0%|          | 0/637 [00:00<?, ?it/s]

26 is not in list
Atoms(symbols='Co2', pbc=True, cell=[14.879999999999999, 14.880999999999998, 14.882])


  0%|          | 0/630 [00:00<?, ?it/s]

27 is not in list
Atoms(symbols='Ni2', pbc=True, cell=[14.879999999999999, 14.880999999999998, 14.882])


  0%|          | 0/632 [00:00<?, ?it/s]

28 is not in list
Atoms(symbols='Cu2', pbc=True, cell=[14.756, 14.757, 14.758000000000001])


  0%|          | 0/618 [00:00<?, ?it/s]

29 is not in list
Atoms(symbols='Zn2', pbc=True, cell=[14.818000000000001, 14.819, 14.820000000000002])


  0%|          | 0/631 [00:00<?, ?it/s]

30 is not in list
Atoms(symbols='Ga2', pbc=True, cell=[14.383999999999999, 14.384999999999998, 14.386])


  0%|          | 0/609 [00:00<?, ?it/s]

31 is not in list
Atoms(symbols='Ge2', pbc=True, cell=[14.198, 14.199, 14.200000000000001])


  0%|          | 0/601 [00:00<?, ?it/s]

32 is not in list
Atoms(symbols='As2', pbc=True, cell=[11.655999999999999, 11.656999999999998, 11.658])


  0%|          | 0/475 [00:00<?, ?it/s]

33 is not in list
Atoms(symbols='Se2', pbc=True, cell=[11.284, 11.285, 11.286000000000001])


  0%|          | 0/456 [00:00<?, ?it/s]

34 is not in list
Atoms(symbols='Br2', pbc=True, cell=[11.532000000000002, 11.533000000000001, 11.534000000000002])


  0%|          | 0/468 [00:00<?, ?it/s]

Atoms(symbols='Kr2', pbc=True, cell=[13.950000000000001, 13.951, 13.952000000000002])


  0%|          | 0/593 [00:00<?, ?it/s]

36 is not in list
Atoms(symbols='Rb2', pbc=True, cell=[19.902, 19.903000000000002, 19.904])


  0%|          | 0/797 [00:00<?, ?it/s]

37 is not in list
Atoms(symbols='Sr2', pbc=True, cell=[17.608, 17.609, 17.61])


  0%|          | 0/704 [00:00<?, ?it/s]

38 is not in list
Atoms(symbols='Y2', pbc=True, cell=[17.05, 17.051000000000002, 17.052])


  0%|          | 0/681 [00:00<?, ?it/s]

39 is not in list
Atoms(symbols='Zr2', pbc=True, cell=[15.624, 15.625, 15.626000000000001])


  0%|          | 0/623 [00:00<?, ?it/s]

40 is not in list
Atoms(symbols='Nb2', pbc=True, cell=[15.872000000000002, 15.873000000000001, 15.874000000000002])


  0%|          | 0/646 [00:00<?, ?it/s]

41 is not in list
Atoms(symbols='Mo2', pbc=True, cell=[15.190000000000001, 15.191, 15.192000000000002])


  0%|          | 0/620 [00:00<?, ?it/s]

42 is not in list
Atoms(symbols='Tc2', pbc=True, cell=[15.128, 15.129, 15.13])


  0%|          | 0/624 [00:00<?, ?it/s]

43 is not in list
Atoms(symbols='Ru2', pbc=True, cell=[15.252, 15.253, 15.254000000000001])


  0%|          | 0/631 [00:00<?, ?it/s]

44 is not in list
Atoms(symbols='Rh2', pbc=True, cell=[15.128, 15.129, 15.13])


  0%|          | 0/628 [00:00<?, ?it/s]

45 is not in list
Atoms(symbols='Pd2', pbc=True, cell=[13.33, 13.331, 13.332])


  0%|          | 0/541 [00:00<?, ?it/s]

46 is not in list
Atoms(symbols='Ag2', pbc=True, cell=[15.686, 15.687, 15.688])


  0%|          | 0/653 [00:00<?, ?it/s]

47 is not in list
Atoms(symbols='Cd2', pbc=True, cell=[15.438000000000002, 15.439000000000002, 15.440000000000003])


  0%|          | 0/642 [00:00<?, ?it/s]

48 is not in list
Atoms(symbols='In2', pbc=True, cell=[15.066, 15.067, 15.068000000000001])


  0%|          | 0/625 [00:00<?, ?it/s]

49 is not in list
Atoms(symbols='Sn2', pbc=True, cell=[15.004, 15.004999999999999, 15.006])


  0%|          | 0/625 [00:00<?, ?it/s]

50 is not in list
Atoms(symbols='Sb2', pbc=True, cell=[15.314000000000002, 15.315000000000001, 15.316000000000003])


  0%|          | 0/640 [00:00<?, ?it/s]

51 is not in list
Atoms(symbols='Te2', pbc=True, cell=[12.338000000000001, 12.339, 12.340000000000002])


  0%|          | 0/492 [00:00<?, ?it/s]

52 is not in list
Atoms(symbols='I2', pbc=True, cell=[12.648000000000001, 12.649000000000001, 12.650000000000002])


  0%|          | 0/507 [00:00<?, ?it/s]

Atoms(symbols='Xe2', pbc=True, cell=[12.772, 12.773, 12.774000000000001])


  0%|          | 0/512 [00:00<?, ?it/s]

54 is not in list
Atoms(symbols='Cs2', pbc=True, cell=[21.576, 21.577, 21.578])


  0%|          | 0/859 [00:00<?, ?it/s]

55 is not in list
Atoms(symbols='Ba2', pbc=True, cell=[18.785999999999998, 18.787, 18.787999999999997])


  0%|          | 0/745 [00:00<?, ?it/s]

56 is not in list
Atoms(symbols='La2', pbc=True, cell=[18.476, 18.477, 18.477999999999998])


  0%|          | 0/737 [00:00<?, ?it/s]

57 is not in list
Atoms(symbols='Ce2', pbc=True, cell=[17.855999999999998, 17.857, 17.857999999999997])


  0%|          | 0/709 [00:00<?, ?it/s]

58 is not in list
Atoms(symbols='Pr2', pbc=True, cell=[18.104, 18.105, 18.105999999999998])


  0%|          | 0/722 [00:00<?, ?it/s]

59 is not in list
Atoms(symbols='Nd2', pbc=True, cell=[18.290000000000003, 18.291000000000004, 18.292])


  0%|          | 0/733 [00:00<?, ?it/s]

60 is not in list
Atoms(symbols='Pm2', pbc=True, cell=[12.0, 12.001, 12.002])


  0%|          | 0/420 [00:00<?, ?it/s]

61 is not in list
Atoms(symbols='Sm2', pbc=True, cell=[17.98, 17.981, 17.982])


  0%|          | 0/720 [00:00<?, ?it/s]

62 is not in list
Atoms(symbols='Eu2', pbc=True, cell=[17.794, 17.795, 17.796])


  0%|          | 0/711 [00:00<?, ?it/s]

63 is not in list
Atoms(symbols='Gd2', pbc=True, cell=[17.546, 17.547, 17.548])


  0%|          | 0/700 [00:00<?, ?it/s]

64 is not in list
Atoms(symbols='Tb2', pbc=True, cell=[17.298000000000002, 17.299000000000003, 17.3])


  0%|          | 0/690 [00:00<?, ?it/s]

65 is not in list
Atoms(symbols='Dy2', pbc=True, cell=[17.794, 17.795, 17.796])


  0%|          | 0/716 [00:00<?, ?it/s]

66 is not in list
Atoms(symbols='Ho2', pbc=True, cell=[17.422, 17.423000000000002, 17.424])


  0%|          | 0/698 [00:00<?, ?it/s]

67 is not in list
Atoms(symbols='Er2', pbc=True, cell=[17.546, 17.547, 17.548])


  0%|          | 0/707 [00:00<?, ?it/s]

68 is not in list
Atoms(symbols='Tm2', pbc=True, cell=[17.298000000000002, 17.299000000000003, 17.3])


  0%|          | 0/693 [00:00<?, ?it/s]

69 is not in list
Atoms(symbols='Yb2', pbc=True, cell=[17.36, 17.361, 17.362])


  0%|          | 0/699 [00:00<?, ?it/s]

70 is not in list
Atoms(symbols='Lu2', pbc=True, cell=[16.988000000000003, 16.989000000000004, 16.990000000000002])


  0%|          | 0/681 [00:00<?, ?it/s]

71 is not in list
Atoms(symbols='Hf2', pbc=True, cell=[16.306, 16.307000000000002, 16.308])


  0%|          | 0/657 [00:00<?, ?it/s]

72 is not in list
Atoms(symbols='Ta2', pbc=True, cell=[15.686, 15.687, 15.688])


  0%|          | 0/631 [00:00<?, ?it/s]

73 is not in list
Atoms(symbols='W2', pbc=True, cell=[15.934, 15.934999999999999, 15.936])


  0%|          | 0/650 [00:00<?, ?it/s]

74 is not in list
Atoms(symbols='Re2', pbc=True, cell=[15.438000000000002, 15.439000000000002, 15.440000000000003])


  0%|          | 0/636 [00:00<?, ?it/s]

75 is not in list
Atoms(symbols='Os2', pbc=True, cell=[15.376, 15.376999999999999, 15.378])


  0%|          | 0/639 [00:00<?, ?it/s]

76 is not in list
Atoms(symbols='Ir2', pbc=True, cell=[14.942000000000002, 14.943000000000001, 14.944000000000003])


  0%|          | 0/620 [00:00<?, ?it/s]

77 is not in list
Atoms(symbols='Pt2', pbc=True, cell=[14.198, 14.199, 14.200000000000001])


  0%|          | 0/587 [00:00<?, ?it/s]

78 is not in list
Atoms(symbols='Au2', pbc=True, cell=[14.383999999999999, 14.384999999999998, 14.386])


  0%|          | 0/596 [00:00<?, ?it/s]

79 is not in list
Atoms(symbols='Hg2', pbc=True, cell=[15.190000000000001, 15.191, 15.192000000000002])


  0%|          | 0/640 [00:00<?, ?it/s]

80 is not in list
Atoms(symbols='Tl2', pbc=True, cell=[15.314000000000002, 15.315000000000001, 15.316000000000003])


  0%|          | 0/635 [00:00<?, ?it/s]

81 is not in list
Atoms(symbols='Pb2', pbc=True, cell=[16.12, 16.121000000000002, 16.122])


  0%|          | 0/674 [00:00<?, ?it/s]

82 is not in list
Atoms(symbols='Bi2', pbc=True, cell=[15.748000000000001, 15.749, 15.750000000000002])


  0%|          | 0/654 [00:00<?, ?it/s]

83 is not in list
Atoms(symbols='Po2', pbc=True, cell=[12.0, 12.001, 12.002])


  0%|          | 0/474 [00:00<?, ?it/s]

84 is not in list
Atoms(symbols='At2', pbc=True, cell=[12.0, 12.001, 12.002])


  0%|          | 0/465 [00:00<?, ?it/s]

85 is not in list
Atoms(symbols='Rn2', pbc=True, cell=[12.0, 12.001, 12.002])


  0%|          | 0/465 [00:00<?, ?it/s]

86 is not in list
Atoms(symbols='Fr2', pbc=True, cell=[12.0, 12.001, 12.002])


  0%|          | 0/365 [00:00<?, ?it/s]

87 is not in list
Atoms(symbols='Ra2', pbc=True, cell=[12.0, 12.001, 12.002])


  0%|          | 0/401 [00:00<?, ?it/s]

88 is not in list
Atoms(symbols='Ac2', pbc=True, cell=[17.36, 17.361, 17.362])


  0%|          | 0/674 [00:00<?, ?it/s]

89 is not in list
Atoms(symbols='Th2', pbc=True, cell=[18.166, 18.167, 18.168])


  0%|          | 0/722 [00:00<?, ?it/s]

90 is not in list
Atoms(symbols='Pa2', pbc=True, cell=[17.855999999999998, 17.857, 17.857999999999997])


  0%|          | 0/712 [00:00<?, ?it/s]

91 is not in list
Atoms(symbols='U2', pbc=True, cell=[16.802, 16.803, 16.804])


  0%|          | 0/663 [00:00<?, ?it/s]

92 is not in list
Atoms(symbols='Np2', pbc=True, cell=[17.483999999999998, 17.485, 17.485999999999997])


  0%|          | 0/703 [00:00<?, ?it/s]

93 is not in list
Atoms(symbols='Pu2', pbc=True, cell=[17.422, 17.423000000000002, 17.424])


  0%|          | 0/702 [00:00<?, ?it/s]

94 is not in list
Atoms(symbols='Am2', pbc=True, cell=[17.546, 17.547, 17.548])


  0%|          | 0/715 [00:00<?, ?it/s]

95 is not in list
Atoms(symbols='Cm2', pbc=True, cell=[18.91, 18.911, 18.912])


  0%|          | 0/793 [00:00<?, ?it/s]

96 is not in list
Atoms(symbols='Bk2', pbc=True, cell=[21.08, 21.081, 21.081999999999997])


  0%|          | 0/1036 [00:00<?, ?it/s]

97 is not in list
Atoms(symbols='Cf2', pbc=True, cell=[18.91, 18.911, 18.912])


  0%|          | 0/927 [00:00<?, ?it/s]

98 is not in list
Atoms(symbols='Es2', pbc=True, cell=[16.740000000000002, 16.741000000000003, 16.742])


  0%|          | 0/819 [00:00<?, ?it/s]

99 is not in list
Atoms(symbols='Fm2', pbc=True, cell=[12.0, 12.001, 12.002])


  0%|          | 0/582 [00:00<?, ?it/s]

100 is not in list
Atoms(symbols='Md2', pbc=True, cell=[12.0, 12.001, 12.002])


  0%|          | 0/582 [00:00<?, ?it/s]

101 is not in list
Atoms(symbols='No2', pbc=True, cell=[12.0, 12.001, 12.002])


  0%|          | 0/582 [00:00<?, ?it/s]

102 is not in list
Atoms(symbols='Lr2', pbc=True, cell=[12.0, 12.001, 12.002])


  0%|          | 0/582 [00:00<?, ?it/s]

103 is not in list
Atoms(symbols='Rf2', pbc=True, cell=[12.0, 12.001, 12.002])


  0%|          | 0/582 [00:00<?, ?it/s]

104 is not in list
Atoms(symbols='Db2', pbc=True, cell=[12.0, 12.001, 12.002])


  0%|          | 0/582 [00:00<?, ?it/s]

105 is not in list
Atoms(symbols='Sg2', pbc=True, cell=[12.0, 12.001, 12.002])


  0%|          | 0/582 [00:00<?, ?it/s]

106 is not in list
Atoms(symbols='Bh2', pbc=True, cell=[12.0, 12.001, 12.002])


  0%|          | 0/582 [00:00<?, ?it/s]

107 is not in list
Atoms(symbols='Hs2', pbc=True, cell=[12.0, 12.001, 12.002])


  0%|          | 0/582 [00:00<?, ?it/s]

108 is not in list
Atoms(symbols='Mt2', pbc=True, cell=[12.0, 12.001, 12.002])


  0%|          | 0/582 [00:00<?, ?it/s]

109 is not in list
Atoms(symbols='Ds2', pbc=True, cell=[12.0, 12.001, 12.002])


  0%|          | 0/582 [00:00<?, ?it/s]

110 is not in list
Atoms(symbols='Rg2', pbc=True, cell=[12.0, 12.001, 12.002])


  0%|          | 0/582 [00:00<?, ?it/s]

111 is not in list
Atoms(symbols='Cn2', pbc=True, cell=[12.0, 12.001, 12.002])


  0%|          | 0/582 [00:00<?, ?it/s]

112 is not in list
Atoms(symbols='Nh2', pbc=True, cell=[12.0, 12.001, 12.002])


  0%|          | 0/582 [00:00<?, ?it/s]

113 is not in list
Atoms(symbols='Fl2', pbc=True, cell=[12.0, 12.001, 12.002])


  0%|          | 0/582 [00:00<?, ?it/s]

114 is not in list
Atoms(symbols='Mc2', pbc=True, cell=[12.0, 12.001, 12.002])


  0%|          | 0/582 [00:00<?, ?it/s]

115 is not in list
Atoms(symbols='Lv2', pbc=True, cell=[12.0, 12.001, 12.002])


  0%|          | 0/582 [00:00<?, ?it/s]

116 is not in list
Atoms(symbols='Ts2', pbc=True, cell=[12.0, 12.001, 12.002])


  0%|          | 0/582 [00:00<?, ?it/s]

117 is not in list
Atoms(symbols='Og2', pbc=True, cell=[12.0, 12.001, 12.002])


  0%|          | 0/582 [00:00<?, ?it/s]

118 is not in list


In [4]:


for model in MLIPEnum:
    
    model_name = model.name
    
    print(f"========== {model_name} ==========")
    
    df = pd.DataFrame(columns=[
        "name", 
        "method", 
        "R", "E", "F", "S^2", 
        "force-flip-times",
        "force-total-variation",
        "energy-diff-flip-times",
        "energy-grad-norm-max",
        "energy-jump",
        "energy-total-variation",
        "tortuosity",
        "conservation-deviation",
        "spearman-descending-force",
        "spearman-ascending-force",
        "spearman-repulsion-energy",
        "spearman-attraction-energy"
    ])
    

    for symbol in tqdm(chemical_symbols[1:]):

        da = symbol + symbol

        out_dir = Path(REGISTRY[model_name]["family"]) / da

        traj_fpath = out_dir / f"{model_name}.extxyz"


        if traj_fpath.exists():
            traj = read(traj_fpath, index=":")
        else:
            continue

        Rs, Es, Fs, S2s = [], [], [], []
        for atoms in traj:

            vec = atoms.positions[1] - atoms.positions[0]
            r = np.linalg.norm(vec)
            e = atoms.get_potential_energy()
            f = np.inner(vec/r, atoms.get_forces()[1])
            # s2 = np.mean(np.power(atoms.get_magnetic_moments(), 2))

            Rs.append(r)
            Es.append(e)
            Fs.append(f)
            # S2s.append(s2)

        rs = np.array(Rs)
        es = np.array(Es)
        fs = np.array(Fs)

        indices = np.argsort(rs)[::-1]
        rs = rs[indices]
        es = es[indices]
        fs = fs[indices]

        iminf = np.argmin(fs)
        imine = np.argmin(es)

        de_dr = np.gradient(es, rs)
        d2e_dr2 = np.gradient(de_dr, rs)

        rounded_fs = np.copy(fs)
        rounded_fs[np.abs(rounded_fs) < 1e-2] = 0
        fs_sign = np.sign(rounded_fs)
        fs_sign = fs_sign[fs_sign != 0]

        # rounded_ediff = np.diff(es)
        # rounded_ediff[np.abs(rounded_ediff) < zero_threshold] = 0
        ediff = np.diff(es)
        ediff[np.abs(ediff) < 1e-3] = 0
        ediff_sign = np.sign(ediff)
        mask = ediff_sign != 0
        ediff = ediff[mask]
        ediff_sign = ediff_sign[mask]
        ediff_flip = np.diff(ediff_sign) != 0
        ejump = np.abs(ediff[:-1][ediff_flip]).sum() + np.abs(ediff[1:][ediff_flip]).sum()

        conservation_deviation = np.mean(np.abs(fs + de_dr))
        
        etv = np.sum(np.abs(np.diff(es)))

        data = {
            "name": da,
            "method": model_name,
            "R": rs,
            "E": es,
            "F": fs,
            "S^2": S2s,
            "force-flip-times": np.sum(np.diff(fs_sign)!=0),
            "force-total-variation": np.sum(np.abs(np.diff(fs))),
            "energy-diff-flip-times": np.sum(ediff_flip),
            "energy-grad-norm-max": np.max(np.abs(de_dr)),
            "energy-jump": ejump,
            # "energy-grad-norm-mean": np.mean(de_dr_abs),
            "energy-total-variation": etv,
            "tortuosity": etv / (abs(es[0] - es.min()) + (es[-1] - es.min())),
            "conservation-deviation": conservation_deviation,
            "spearman-descending-force": stats.spearmanr(rs[iminf:], fs[iminf:]).statistic,
            "spearman-ascending-force": stats.spearmanr(rs[:iminf], fs[:iminf]).statistic,
            "spearman-repulsion-energy": stats.spearmanr(rs[imine:], es[imine:]).statistic,
            "spearman-attraction-energy": stats.spearmanr(rs[:imine], es[:imine]).statistic,
        }

        df = pd.concat([df, pd.DataFrame([data])], ignore_index=True)


    json_fpath = Path(REGISTRY[model_name]["family"]) / "homonuclear-diatomics.json"

    if json_fpath.exists():
        df0 = pd.read_json(json_fpath)
        df = pd.concat([df0, df], ignore_index=True)
        df.drop_duplicates(inplace=True, subset=["name", "method"], keep='last')

    df.to_json(json_fpath, orient="records")



  0%|          | 0/118 [00:00<?, ?it/s]

  df = pd.concat([df, pd.DataFrame([data])], ignore_index=True)
  "tortuosity": etv / (abs(es[0] - es.min()) + (es[-1] - es.min())),
  "spearman-descending-force": stats.spearmanr(rs[iminf:], fs[iminf:]).statistic,
  "spearman-repulsion-energy": stats.spearmanr(rs[imine:], es[imine:]).statistic,




  0%|          | 0/118 [00:00<?, ?it/s]

  df = pd.concat([df, pd.DataFrame([data])], ignore_index=True)




  0%|          | 0/118 [00:00<?, ?it/s]

  df = pd.concat([df, pd.DataFrame([data])], ignore_index=True)
  "spearman-ascending-force": stats.spearmanr(rs[:iminf], fs[:iminf]).statistic,




  0%|          | 0/118 [00:00<?, ?it/s]

  df = pd.concat([df, pd.DataFrame([data])], ignore_index=True)




  0%|          | 0/118 [00:00<?, ?it/s]

  df = pd.concat([df, pd.DataFrame([data])], ignore_index=True)
  "spearman-attraction-energy": stats.spearmanr(rs[:imine], es[:imine]).statistic,




  0%|          | 0/118 [00:00<?, ?it/s]

  df = pd.concat([df, pd.DataFrame([data])], ignore_index=True)
  "tortuosity": etv / (abs(es[0] - es.min()) + (es[-1] - es.min())),
  "spearman-repulsion-energy": stats.spearmanr(rs[imine:], es[imine:]).statistic,
  "spearman-attraction-energy": stats.spearmanr(rs[:imine], es[:imine]).statistic,
  "tortuosity": etv / (abs(es[0] - es.min()) + (es[-1] - es.min())),




  0%|          | 0/118 [00:00<?, ?it/s]

  df = pd.concat([df, pd.DataFrame([data])], ignore_index=True)




  0%|          | 0/118 [00:00<?, ?it/s]



  0%|          | 0/118 [00:00<?, ?it/s]

  df = pd.concat([df, pd.DataFrame([data])], ignore_index=True)




  0%|          | 0/118 [00:00<?, ?it/s]

  df = pd.concat([df, pd.DataFrame([data])], ignore_index=True)
  "spearman-attraction-energy": stats.spearmanr(rs[:imine], es[:imine]).statistic,


In [5]:
df

Unnamed: 0,name,method,R,E,F,S^2,force-flip-times,force-total-variation,energy-diff-flip-times,energy-grad-norm-max,energy-jump,energy-total-variation,conservation-deviation,spearman-descending-force,spearman-ascending-force,spearman-repulsion-energy,spearman-attraction-energy,tortuosity
118,HH,ALIGNN,"[3.7199999999999998, 3.70996794, 3.69993586, 3...","[-1.2249419689178467, -1.2238645553588867, -1....","[1.91e-06, 0.00826454, 0.00533009, -0.00355052...",[],29,443.614255,30,40.656501,4.074248,13.676080,2.022960,-0.200684,-0.119952,-0.986572,0.844786,1.797762
119,HeHe,ALIGNN,"[4.433, 4.4229736200000005, 4.41294724, 4.4029...","[2.4748411178588867, 2.4748411178588867, 2.474...","[0.0, -1e-08, 0.0, 0.0, 1e-08, 0.0, 0.0, 0.0, ...",[],44,1448.979436,43,160.175544,13.831072,32.544741,4.849522,-0.021494,-0.195001,-0.720218,0.609519,3.921846
120,LiLi,ALIGNN,"[6.572000000000001, 6.561981520000001, 6.55196...","[-0.21738338470458984, -0.21738338470458984, -...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",[],22,1127.749591,20,281.659136,2.067203,20.053466,2.469147,-0.714286,-0.444127,-0.991010,-0.654715,1.386696
121,BeBe,ALIGNN,"[6.138000000000001, 6.12797338, 6.117946759999...","[2.665587902069092, 2.665587902069092, 2.66558...","[0.0, 0.0, 0.0, 0.0, 1e-08, 0.0, 0.0, 0.0, 0.0...",[],25,1645.960357,21,145.040462,9.422094,51.647955,5.076523,0.544090,-0.158055,-0.989962,0.342476,1.715953
122,BB,ALIGNN,"[5.921000000000001, 5.91097088, 5.900941739999...","[0.6220548152923584, 0.6220548152923584, 0.622...","[0.0, 0.0, 1e-08, 0.0, 0.0, 1e-08, -1e-08, -1e...",[],35,1811.413732,36,131.340791,13.088368,52.691856,5.551787,0.052632,-0.171460,-0.985024,0.757880,1.617372
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
231,FlFl,ALIGNN,"[6.0, 5.989982779999999, 5.979965579999999, 5....","[-10.127323150634766, -10.127323150634766, -10...","[0.0, 0.0, -9.5e-07, 0.0, 0.0, 0.0, 0.0, 9.5e-...",[],39,3329.770831,44,181.338142,12.292844,84.239434,8.151059,-0.249255,-0.316556,-0.946770,0.159576,2.694147
232,McMc,ALIGNN,"[6.0, 5.989982779999999, 5.979965579999999, 5....","[-10.127323150634766, -10.127323150634766, -10...","[0.0, 0.0, -9.5e-07, 0.0, 0.0, 0.0, 0.0, 9.5e-...",[],39,3329.770831,44,181.338142,12.292844,84.239434,8.151059,-0.249255,-0.316556,-0.946770,0.159576,2.694147
233,LvLv,ALIGNN,"[6.0, 5.989982779999999, 5.979965579999999, 5....","[-10.127323150634766, -10.127323150634766, -10...","[0.0, 0.0, -9.5e-07, 0.0, 0.0, 0.0, 0.0, 9.5e-...",[],39,3329.770831,44,181.338142,12.292844,84.239434,8.151059,-0.249255,-0.316556,-0.946770,0.159576,2.694147
234,TsTs,ALIGNN,"[6.0, 5.989982779999999, 5.979965579999999, 5....","[-10.127323150634766, -10.127323150634766, -10...","[0.0, 0.0, -9.5e-07, 0.0, 0.0, 0.0, 0.0, 9.5e-...",[],39,3329.770831,44,181.338142,12.292844,84.239434,8.151059,-0.249255,-0.316556,-0.946770,0.159576,2.694147
