In [1]:
from jax.config import config
config.update("jax_enable_x64", True)
import haiku as hk
from ase import Atoms
import numpy as np
from typing import List, Optional
from jax_md import partition
from jax_md import partition, space
from jax_md.util import high_precision_sum
import jax
import jax.numpy as jnp
import einops
from typing import Sequence, Any
from jax.tree_util import Partial
from ase.io import read

from gmnn_jax.utils.weight_transfer import transfer_parameters
from gmnn_jax.model import GMNN


def convert_ase_to_data_dict(atoms_list: List[Atoms]):
    """Converts a list of ASE Atoms objects into a dictionary containing numpy arrays.
    All arrays containing per-atom quantities (e.g. positions, forces) are zero-padded
    to have the same size as the largest structure in the list.
    """
    n_data = len(atoms_list)
    n_maxat = max([len(atoms) for atoms in atoms_list])

    key_map = {"energy": "energy", "forces": "forces", "stress": "stress"}

    data = {
        "numbers": np.zeros((n_data, n_maxat), dtype=np.int64),  # atomic numbers
        "n_atoms": np.zeros((n_data,), dtype=np.int64),  # number of atoms
        "positions": np.zeros(
            (n_data, n_maxat, 3), dtype=np.float64
        ),  # Cartesian coordinates
        "energy": np.zeros((n_data,), dtype=np.float64),  # total energies
        "forces": np.zeros((n_data, n_maxat, 3), dtype=np.float64),  # atomic forces
        "cell": np.zeros((n_data, 3, 3), dtype=np.float64),  # periodic cell
        "stress": np.zeros((n_data, 3, 3), dtype=np.float64),
        "charge": np.zeros((n_data,), dtype=np.float64),
        "dipole": np.zeros((n_data, 6), dtype=np.float64),
        "mat": np.zeros((n_data, 6), dtype=np.float64),
    }

    for ii, atoms in enumerate(atoms_list):
        data["numbers"][ii, : len(atoms)] = atoms.get_atomic_numbers()
        data["n_atoms"][ii] = len(atoms)
        data["positions"][ii, : len(atoms), :] = atoms.get_positions()
        data["cell"][ii] = atoms.get_cell()

        if atoms.calc is not None:
            results = {key_map.get(k, k): v for k, v in atoms.calc.results.items()}
            for k, v in results.items():
                if k == "forces":
                    data["forces"][ii, : len(atoms), :] = v
                else:
                    data[k][ii] = v

    pruned_data = {k: v for k, v in data.items() if np.any(np.abs(v) > 1e-6)}

    return pruned_data


def get_model(
    atomic_numbers,
    units,
    displacement,
    box_size: float=10.0,
    cutoff_distance=6.0,
    n_basis=7,
    n_radial=5,
    dr_threshold=0.5,
    nl_format: partition.NeighborListFormat=partition.Sparse,
    **neighbor_kwargs
    ):
    neighbor_fn = partition.neighbor_list(
        displacement,
        box_size,
        cutoff_distance,
        dr_threshold,
        fractional_coordinates=False,
        format=nl_format,
        **neighbor_kwargs
        )

    n_atoms = atomic_numbers.shape[0]
    Z = jnp.asarray(atomic_numbers)
    n_species = 9#10#jnp.max(Z)
    
    @hk.without_apply_rng
    @hk.transform
    def model(R, neighbor):
        gmnn = GMNN(
            units,
            displacement,
            n_atoms=n_atoms,
            n_basis=n_basis,
            n_radial=n_radial,
            n_species=n_species,
        )
        out = gmnn(R, Z, neighbor)
        # mask = partition.neighbor_list_mask(neighbor)
        # out = out * mask
        return high_precision_sum(out) # jnp.sum(out)

    return neighbor_fn, model.init, model.apply


trained_params = np.load("./etoh_model_params.npz")
# print(trained_params.files)
# for name in trained_params.files:
#     print(name)
# quit()


# atoms = read("raw_data/ds.extxyz")
atoms = read("ethanol.traj")

In [2]:
data = convert_ase_to_data_dict([atoms])
data = {k: jnp.asarray(v[0]) for k,v in data.items()}

# box_size = data["cell"][0,0]
box_size = None
r_cutoff = 6.0
nl_format = partition.Sparse
# displacement_fn, shift_fn = space.periodic(box_size)
displacement_fn, shift_fn = space.free()

neighbor_fn, model_init, model = get_model(
    atomic_numbers=data["numbers"],
    units=[512,512],
    displacement=displacement_fn,
    box_size=box_size,
    cutoff_distance=r_cutoff,
    dr_threshold=0.0,
    )
neighbor = neighbor_fn.allocate(data["positions"], extra_capacity=0)
# mask = partition.neighbor_list_mask(neighbor)
# intmask = mask.astype(int)
# n_neighbors = jnp.sum(intmask)
# print(set(np.asarray(neighbor.idx[0])))


# print(n_neighbors)
# quit()
rng_key = jax.random.PRNGKey(42)
params = model_init(rng=rng_key, R=data["positions"], neighbor=neighbor)
# print(params.keys())
# quit()

# quit()
transfered_params = transfer_parameters(params, trained_params)

result = model(params=transfered_params, R=data["positions"], neighbor=neighbor) # , data["numbers"]
print(result)

F = jax.grad(model, argnums=1)(transfered_params, data["positions"], neighbor)
print(F)
# coords are equal
# n_neighborsare equal





-97159.6679311938
[[-5.30348755 -3.28318722 -3.44235098]
 [-8.86077721 -4.4744465   4.40773533]
 [ 2.57754753  2.04141467 -1.77634654]
 [ 4.99685481  5.66539417  2.39515222]
 [-0.66234355 -0.28186489  1.91485807]
 [-0.07846298 -1.92709607 -1.25942134]
 [ 5.49081969  2.38676672 -1.14964303]
 [ 3.47984819  3.25736062 -1.14081617]
 [-1.63999893 -3.38434149  0.05083245]]


In [3]:
from functools import partial

energy_fn = partial(model,transfered_params)
result = energy_fn(data["positions"], neighbor)
result


DeviceArray(-97159.66793119, dtype=float64)

In [4]:
data["numbers"], atoms.get_masses()

(DeviceArray([6, 6, 1, 1, 1, 1, 8, 1, 1], dtype=int64),
 array([12.011, 12.011,  1.008,  1.008,  1.008,  1.008, 15.999,  1.008,
         1.008]))

In [6]:
from jax_md import simulate
from ase import units
K_B = 8.617e-5
dt = 0.5 * units.fs
kT = K_B * 200
# Si_mass = 2.91086E-3
masses = jnp.array(atoms.get_masses())

init_fn, apply_fn = simulate.nvt_nose_hoover(energy_fn, shift_fn, dt, kT)

apply_fn = jax.jit(apply_fn)

In [7]:
state = init_fn(jax.random.PRNGKey(0), data["positions"], masses, neighbor=neighbor)
state



NVTNoseHooverState(position=DeviceArray([[-3.68683679,  2.36186247, -0.29336568],
             [-3.78000175,  0.92056215,  0.21397447],
             [-3.42405028,  2.37994027, -1.36304911],
             [-2.93796942,  2.97120768,  0.25672536],
             [-4.67591426,  2.82840412, -0.18108502],
             [-2.82737916,  0.37769472,  0.05450457],
             [-4.82847316,  0.30119435, -0.53156614],
             [-3.96906495,  0.92205407,  1.30280436],
             [-4.89959806, -0.62270479, -0.21510487]], dtype=float64), momentum=DeviceArray([[-0.85630635, -0.05643766, -0.30167575],
             [ 0.20351371,  0.40557841,  0.39048464],
             [ 0.15760368,  0.30039093, -0.15663619],
             [ 0.11643982, -0.06655513,  0.09720758],
             [ 0.22542999, -0.2626623 , -0.1294574 ],
             [ 0.10284429,  0.06528132, -0.12404784],
             [-0.30186973,  0.02585869,  0.20906192],
             [ 0.27107071, -0.1171036 ,  0.08199368],
             [ 0.08127389, -

In [8]:
import time

In [9]:
@jax.jit
def sim(state, neighbor):
    def body_fn(i, state):
        state, neighbor = state
        neighbor = neighbor.update(state.position)
        state = apply_fn(state, neighbor=neighbor)
        return state, neighbor
    return jax.lax.fori_loop(0, 5, body_fn, (state, neighbor))

positions = []
forces = []
start = time.time()
step = 0
while step < 500*4:
    new_state, neighbor = sim(state, neighbor)
    if neighbor.did_buffer_overflow:
        print('Neighbor list overflowed, reallocating.')
        neighbor = neighbor_fn.allocate(state.position)
    else:
        state = new_state
        positions += [state.position]
        forces += [state.force]
        step += 1

end = time.time()
print(f"elapsed time: {end - start}") #
# steps: 20, time: 18.952657222747803
# steps: 5, time: 18.685712099075317



elapsed time: 18.821195602416992


In [11]:
positions[-1]

DeviceArray([[-5.3700625 ,  0.93344789, -0.02903904],
             [-3.99913346,  1.53128384,  0.22864531],
             [-5.75533277,  1.28144812, -1.00352859],
             [-6.10013318,  1.19934828,  0.74430096],
             [-5.26454138, -0.1690307 , -0.11271822],
             [-4.03551175,  2.63355587,  0.2033728 ],
             [-3.12844699,  0.99017069, -0.77384016],
             [-3.60759242,  1.25281025,  1.22729059],
             [-2.28584843,  1.46637783, -0.68298644]], dtype=float64)

In [14]:
from ase.visualize import view
atoms_list = []

for pos in positions:
    new_atoms = Atoms(atoms.symbols, pos, cell=atoms.cell)
    atoms_list.append(new_atoms)
view(atoms_list)


<Popen: returncode: None args: ['/home/ms/miniconda3/envs/gmax/bin/python', ...>