In [1]:
import os
import jraph
import ase
import ase.io
import jax
import jax.numpy as jnp
import numpy as np
import optax

from phonax.predictors import predict_energy_forces_stress

from phonax.trained_models import (
    NequIP_JAXMD_uniIAP_model,
    NequIP_JAXMD_uniIAP_PBEsol_finetuned_model,
    MACE_uniIAP_model,
    MACE_uniIAP_PBEsol_finetuned_model,
    NequIP_JAX_uniIAP_model,
)

from phonax.data_utils import (
    crystal_struct_to_period_graph,
    crystal_atoms_to_period_graph,
)

# Fail fast while exploring: ask JAX to raise as soon as a computation
# produces an Inf or a NaN instead of silently propagating it.
jax.config.update("jax_debug_infs", True)
jax.config.update("jax_debug_nans", True)
# Compact fixed-point array printing (3 decimals, no scientific notation) for the outputs below.
np.set_printoptions(suppress=True, precision=3)


# Atomic structure force examination

A prerequisite for deriving a physical phonon (vibrational) spectrum of an atomic structure is that the structure be at a (local) equilibrium.
That is, the atomic forces should be (close to) zero; only then can the eigenvalues of the (mass-weighted) Hessian be interpreted in terms of the vibrational mode frequencies.
For out-of-equilibrium structures, diagonalizing the Hessian matrix can yield unstable (negative) modes or spurious mode frequencies.

Therefore, to make consistent phonon predictions for a given atomic structure, one must first check the predicted atomic forces and make sure they vanish, i.e. that the structure is close to a local equilibrium.
For interatomic potentials parametrized by equivariant graph neural networks, different training data and model architectures can lead to different predicted local equilibrium structures.
For example, training data computed with different exchange-correlation functionals in DFT, such as LDA, PBE, PBEsol, or the more advanced hybrid functionals, can lead to different predicted equilibrium structures.

In this tutorial, we demonstrate how to check the (residual) atomic forces of an atomic structure given a trained energy model. For out-of-equilibrium structures, we also show how to relax the atomic coordinates to reach a local equilibrium state.



In [2]:
# Pick exactly one pretrained energy model by leaving only its line uncommented.
# Every loader returns the tuple (model_fn, params, num_message_passing, r_max).

# Directory holding the pretrained checkpoints, relative to the notebook's working directory.
TRAINED_MODELS_DIR = os.path.join(os.getcwd(), 'trained-models')

# NequIP model (JAX-MD implementation) trained with universal-IAP (PBE)
model_fn, params, num_message_passing, r_max = NequIP_JAXMD_uniIAP_model(TRAINED_MODELS_DIR)

# NequIP model (JAX-MD implementation) trained with universal-IAP (PBE) + PBEsol fine-tuning
#model_fn, params, num_message_passing, r_max = NequIP_JAXMD_uniIAP_PBEsol_finetuned_model()

# MACE model trained with universal-IAP (PBE)
#model_fn, params, num_message_passing, r_max = MACE_uniIAP_model()

# MACE model trained with universal-IAP (PBE) + PBEsol fine-tuning
#model_fn, params, num_message_passing, r_max = MACE_uniIAP_PBEsol_finetuned_model()

# NequIP model (plain-JAX implementation, judging by the loader name) trained with universal-IAP (PBE)
#model_fn, params, num_message_passing, r_max = NequIP_JAX_uniIAP_model()

# NOTE(review): only the active loader is shown receiving the checkpoint directory;
# confirm whether the alternatives also need TRAINED_MODELS_DIR as an argument.

Create NequIP (JAX-MD version) with parameters {'use_sc': True, 'graph_net_steps': 2, 'hidden_irreps': '128x0e+ 128x0o  + 64x1e +64x1o +64x2e +64x2o', 'nonlinearities': {'e': 'swish', 'o': 'tanh'}, 'r_max': 5.0, 'avg_num_neighbors': 36.712880186304396, 'avg_r_min': None, 'num_species': 100, 'radial_basis': <function bessel_basis at 0x7fd2910ce840>, 'radial_envelope': <function soft_envelope at 0x7fd2910ce7a0>}


In [3]:
# Read the structure stored in 'data/mp-149/mp-149.vasp' and build its periodic graph,
# using the cutoff radius and number of message-passing steps of the loaded model.
atoms, graph = crystal_struct_to_period_graph(
    'data/mp-149/mp-149.vasp',
    r_max,
    num_message_passing,
)

In [4]:
def _predict_with_weights(weights, g):
    """Evaluate energy, forces and stress of graph ``g`` using model weights ``weights``."""
    def apply_model(*model_args):
        # Bind the weights so the predictor only supplies the remaining model inputs.
        return model_fn(weights, *model_args)
    return predict_energy_forces_stress(apply_model, g)


# jax.jit traces and compiles on the first call; the same `predictor` is reused below.
predictor = jax.jit(_predict_with_weights)
pred_out = predictor(params, graph)

print(pred_out)


{'energy': Array([-10.781], dtype=float32), 'forces': Array([[ 0., -0., -0.],
       [ 0.,  0.,  0.]], dtype=float32), 'pressure': Array([0.014], dtype=float32), 'stress': Array([[[ 0.005,  0.   ,  0.   ],
        [-0.   ,  0.005, -0.   ],
        [ 0.   , -0.   ,  0.005]]], dtype=float32), 'stress_cell': Array([[[ 0.005,  0.   ,  0.   ],
        [-0.   ,  0.005, -0.   ],
        [ 0.   , -0.   ,  0.005]]], dtype=float32), 'stress_forces': Array([[[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]]], dtype=float32)}


## Atomic relaxations

Given an out-of-equilibrium atomic structure, we would like to update the atomic coordinates so that the structure approaches a local equilibrium with vanishing atomic forces.
Since the forces are the negative gradient of the energy with respect to the atomic positions, the negated predicted forces serve as the gradients that guide the relaxation of the atomic positions.
Here we use Optax, a gradient-processing and optimization library for JAX, to perform the relaxation.


In [6]:
# Load the MoS2 monolayer and report the model's predictions for the structure
# exactly as read from file.
struct_atoms_init = ase.io.read('data/crystals/MoS2-mono.vasp')
struct_graph_init = crystal_atoms_to_period_graph(struct_atoms_init,
                                                  r_max,
                                                  num_message_passing)
struct_pred_init = predictor(params, struct_graph_init)
print('Init stage:', struct_pred_init)

# Working copy whose positions are overwritten at every relaxation step;
# struct_atoms_init itself is never mutated (no need to re-read the file).
struct_atoms = struct_atoms_init.copy()


def struct_relax_grad(pos):
    """Return the energy gradient w.r.t. the atomic positions, i.e. minus the predicted forces.

    Side effect: writes ``pos`` into the module-level ``struct_atoms`` and rebuilds
    the periodic graph from it, so the prediction reflects the new positions.

    Args:
        pos: atomic positions with one row per atom, in the layout returned by
            ``struct_atoms_init.get_positions()``.

    Returns:
        ``-forces`` as returned by ``predictor``.
    """
    struct_atoms.set_positions(np.asarray(pos))
    # NOTE(review): if the rebuilt graph's edge count changes between steps, the
    # jitted predictor may recompile — TODO confirm whether the graph is padded.
    struct_graph = crystal_atoms_to_period_graph(struct_atoms,
                                                 r_max,
                                                 num_message_passing)
    struct_pred = predictor(params, struct_graph)
    return -struct_pred['forces']


# Relaxation settings.
relax_lr = 0.01          # Adam learning rate
N_RELAX_STEPS = 100      # fixed number of optimizer steps
PERTURB_SCALE = 1.10     # stretch factor applied to the y and z coordinates

# Build an out-of-equilibrium starting point: stretch y and z by PERTURB_SCALE,
# then re-centre the layer at z = 0.
initpos = struct_atoms_init.get_positions()
initpos[:, 1] *= PERTURB_SCALE
initpos[:, 2] *= PERTURB_SCALE
initpos[:, 2] -= np.mean(initpos[:, 2])

optimizer = optax.adam(relax_lr)
relax_pos = jnp.array(initpos)
opt_state = optimizer.init(relax_pos)
print(relax_pos)

# train_loss: largest per-atom force norm before each step (the convergence metric).
train_loss = []
for _ in range(N_RELAX_STEPS):
    grads = struct_relax_grad(relax_pos)
    train_loss.append(float(jnp.max(jnp.linalg.norm(grads, axis=-1))))
    updates, opt_state = optimizer.update(grads, opt_state)
    relax_pos = optax.apply_updates(relax_pos, updates)

# Check the residual forces at the final positions: they should be close to zero.
residual_forces = -struct_relax_grad(relax_pos)
print('Max |force| before relaxation:', train_loss[0])
print('Max |force| after relaxation:',
      float(jnp.max(jnp.linalg.norm(residual_forces, axis=-1))))

# Re-centre the relaxed layer at z = 0 so it can be compared with the start.
relax_pos = np.array(relax_pos)
relax_pos[:, 2] -= np.mean(relax_pos[:, 2])
print(relax_pos)




Init stage: {'energy': Array([-21.427], dtype=float32), 'forces': Array([[ 0.   ,  0.   , -0.   ],
       [ 0.   , -0.   , -0.004],
       [-0.   , -0.   ,  0.004]], dtype=float32), 'pressure': Array([0.002], dtype=float32), 'stress': Array([[[ 0.001, -0.   , -0.   ],
        [-0.   ,  0.001,  0.   ],
        [-0.   ,  0.   ,  0.   ]]], dtype=float32), 'stress_cell': Array([[[ 0.001,  0.   , -0.   ],
        [-0.   ,  0.001, -0.   ],
        [ 0.   , -0.   , -0.   ]]], dtype=float32), 'stress_forces': Array([[[ 0., -0., -0.],
        [ 0., -0.,  0.],
        [-0.,  0.,  0.]]], dtype=float32)}
[[ 1.596  1.014  0.   ]
 [ 0.     2.027 -1.721]
 [ 0.     2.027  1.721]]
[[ 1.595  1.06  -0.   ]
 [-0.001  1.981 -1.565]
 [-0.001  1.981  1.565]]
