<a href="https://colab.research.google.com/github/knc6/jarvis-tools-notebooks/blob/master/jarvis-tools-notebooks/Li-P-O_Neural_Network_Potentials.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# ALIGNN on Li-P-O dataset

Reference: JAX MD tutorial, Chapter 3 (Neural Network Potentials): https://github.com/jax-md/jax-md/blob/main/notebooks/tutorial/Chapter_3_Neural_Network_Potentials.ipynb



In [None]:
!pip install --pre dgl -f https://data.dgl.ai/wheels/cu117/repo.html
!pip install --pre dglgo -f https://data.dgl.ai/wheels-test/repo.html
!pip install alignn

In [None]:
# Download the small Li-P-O dataset archive (lipo_small.zip) from a jax-md fork on GitHub.
!wget https://github.com/amilmerchant/jax-md/raw/nequip_scratch/notebooks/data/lipo_small.zip

In [None]:
import jax.numpy as jnp
import numpy as np
from jax import ShapedArray
from jax import device_put
from jax import tree_map
from jax.config import config
from flax import serialization
# Short aliases for the single- and double-precision JAX float types.
f32, f64 = jnp.float32, jnp.float64
     
# @title Data Utilities
def read_lipo(filename):
  """Read a msgpack-serialised dataset file and return it in float64.

  The raw bytes are decoded with ``serialization.from_bytes`` against a
  template dict of ``ShapedArray`` placeholders keyed by 'box', 'atoms',
  'position', 'force' and 'energy'. Every leaf of the decoded tree is then
  cast to ``f64``.

  Args:
    filename: Path of the file to read (opened in binary mode).

  Returns:
    The decoded tree with ``.astype(f64)`` applied to each leaf.
  """
  # Template handed to flax as the restore target.
  # NOTE(review): the placeholder shapes here do not look like the real array
  # shapes (e.g. 'atoms' is indexed as 2-D later in the notebook) — presumably
  # only the key structure matters to from_bytes; confirm.
  template = {
    'box': ShapedArray((3, 3), f32),
    'atoms': ShapedArray((94,), f32),
    'position': ShapedArray((1, 3), f32),
    'force': ShapedArray((1, 3), f32),
    'energy': ShapedArray((1,), f32),
  }

  def as_double(leaf):
    return leaf.astype(f64)

  with open(filename, 'rb') as handle:
    payload = handle.read()

  decoded = serialization.from_bytes(template, payload)
  return tree_map(as_double, decoded)

# Extract the downloaded archive; -o overwrites existing files without
# prompting, so this cell can be re-run safely.
!unzip -o lipo_small.zip

In [None]:
# Load the training and test splits from the extracted lipo_small/ directory.
# Each is the dict returned by read_lipo ('box', 'atoms', 'position', 'force',
# 'energy'), with every array cast to float64.
train = read_lipo('lipo_small/train.msgpack')
test = read_lipo('lipo_small/test.msgpack')


In [None]:
from jarvis.core.specie import atomic_numbers_to_symbols
from jarvis.core.atoms import Atoms
# Atomic number of each atom: column index of every entry equal to 1 in
# train['atoms'], shifted by one.
# (presumably a one-hot species matrix, one row per atom — TODO confirm)
Z = np.where(train['atoms'] == 1)[1] + 1
symbols = atomic_numbers_to_symbols(Z)

# Structure for the first training frame; coordinates are passed as
# fractional (cartesian=False).
atoms = Atoms(
    lattice_mat=train['box'],
    elements=symbols,
    coords=train['position'][0],
    cartesian=False,
)

# Mean of the training energies, subtracted from every energy when the
# records are built.
mean_energy = train['energy'].mean()

In [None]:
def _frames_to_records(split, symbols, energy_shift, first_id):
  """Convert every frame of one data split into a list of record dicts.

  Args:
    split: Dict with 'box', 'position', 'energy' and 'force' entries, as
      returned by read_lipo.
    symbols: Element symbols, shared by all frames.
    energy_shift: Value subtracted from each frame energy before it is
      divided by the number of atoms.
    first_id: Integer id given to the first frame; later frames count up
      from it. Stored as a string under 'jid'.

  Returns:
    One dict per frame with keys 'jid', 'atoms', 'total_energy', 'forces'
    and 'stresses', in that order.
  """
  records = []
  shifted_energy = split['energy'] - energy_shift
  frames = zip(split['position'], shifted_energy, split['force'])
  for offset, (p, e, f) in enumerate(frames):
    # Coordinates are passed as fractional (cartesian=False).
    atoms = Atoms(lattice_mat=split['box'], elements=symbols, coords=p, cartesian=False)
    records.append({
      'jid': str(first_id + offset),
      "atoms": atoms.to_dict(),
      "total_energy": e / atoms.num_atoms,  # shifted energy per atom
      "forces": f.tolist(),
      'stresses': np.zeros((3, 3)).tolist(),  # all-zero 3x3 placeholder
    })
  return records


# Records are appended in the order train, validation, test, with jids
# numbered consecutively from 1.
# NOTE(review): the test split is used for both the validation and the test
# records, so those two sets contain the same frames — kept as in the
# original; confirm this is intended.
mem = []
for split in (train, test, test):
  mem.extend(_frames_to_records(split, symbols, mean_energy, first_id=len(mem) + 1))

# Total number of records (same final value the running counter used to hold).
count = len(mem)

In [None]:
from jarvis.db.jsonutils import dumpjson
# Write all records (train + validation + test) to id_prop.json in the
# current working directory.
dumpjson(data=mem, filename="id_prop.json")

In [None]:
# Training configuration; it is written to config.json at the end of this
# cell and passed to the training script through its --config option.

# Aliases that make the JSON spellings null/false/true valid Python, so the
# dict below can be written exactly as it would appear in a JSON file.
null=None
false=False
true=True
conf= dict({
    "version": "112bbedebdaecf59fb18e11c929080fb2f358246",
    # --- dataset and split selection ---
    "dataset": "user_data",
    "target": "target",
    "atom_features": "cgcnn",
    "neighbor_strategy": "k-nearest",
    "id_tag": "jid",  # matches the 'jid' key stored in each record
    "random_seed": 123,
    "classification_threshold": null,
    # NOTE(review): both absolute counts (n_*) and ratios (*_ratio) are given;
    # which one takes precedence is decided inside ALIGNN — confirm.
    "n_val": 100,
    "n_test": 100,
    "n_train": 100,
    "train_ratio": 0.9,
    "val_ratio": 0.05,
    "test_ratio": 0.05,
    "target_multiplication_factor": null,
    # --- optimisation ---
    "epochs": 5,
    "batch_size": 2,
    "weight_decay": 1e-05,
    "learning_rate": 0.001,
    "filename": "sample",
    "warmup_steps": 2000,
    "criterion": "l1",
    "optimizer": "adamw",
    "scheduler": "onecycle",
    # --- runtime, logging and output flags ---
    "pin_memory": false,
    "save_dataloader": false,
    "write_checkpoint": true,
    "write_predictions": true,
    "store_outputs": false,
    "progress": true,
    "log_tensorboard": false,
    "standard_scalar_and_pca": false,
    "use_canonize": false,
    "num_workers": 0,
    # --- neighbour settings ---
    "cutoff": 4.0,
    "max_neighbors": 12,
    # presumably keeps records in the order they appear in id_prop.json
    # instead of shuffling before the split — TODO confirm
    "keep_data_order": true,
    "normalize_graph_level_loss": false,
    "distributed": false,
    "n_early_stopping": null,
    "output_dir": "./",
    # --- model ---
    "model": {
        "name": "alignn_atomwise",
        "alignn_layers": 2,
        "gcn_layers": 2,
        "atom_input_features": 92,
        "edge_input_features": 80,
        "triplet_input_features": 40,
        "embedding_features": 64,
        "hidden_features": 256,
        "output_features": 1,
        "grad_multiplier": -1,
        "calculate_gradient": true,
        "atomwise_output_features": 3,
        # presumably loss weights for the graph-level (energy), gradient
        # (force), stress and per-atom terms — TODO confirm; the stress and
        # atomwise weights are 0.0 here.
        "graphwise_weight": 1.0,
        "gradwise_weight": 10.0,
        "stresswise_weight": 0.0,
        "atomwise_weight": 0.0,
        "link": "identity",
        "zero_inflated": false,
        "classification": false
    }
})
# Serialise the configuration to config.json in the current directory.
dumpjson(data=conf,filename='config.json')

In [None]:
# Launch training with the train_folder_ff.py script (presumably installed by
# the alignn package). --root_dir "." is the directory where id_prop.json and
# config.json were written above; outputs go to the "temp" directory.
!train_folder_ff.py --root_dir "." --config "config.json" --output_dir=temp