In [42]:
%env XLA_PYTHON_CLIENT_PREALLOCATE=false
import jax.numpy as jnp
import numpy as np
from jax import jit, grad, random, tree_util
import jax, optax
import flax.linen as nn
from deepmd_jax.data import DataSystem
from deepmd_jax.model import DPModel
import pickle
from time import time
# jax.config.update('jax_enable_x64', True)
jax.config.update('jax_default_matmul_precision', 'float32')
np.set_printoptions(precision=5, suppress=True)
print('Starting program on device', jax.devices())

save_name    = 'model_polaron.pkl'
train_data   = DataSystem('polaron_data', ['coord', 'box', 'force', 'energy'])
orthorombic  = True
rcut         = 7.6
embed_widths = [16, 32, 32]
fit_widths   = [128, 128]
axis_neuron  = 16
batch_size   = 1
lr           = 0.002 
beta2        = 0.99
s_pref_e     = 0.02
l_pref_e     = 1
s_pref_f     = 1000
l_pref_f     = 100
total_steps  = 400000
decay_steps  = 4000
decay_rate   = 0.95

RANDOM_SEED  = np.random.randint(1000)
l_smoothing  = 20
getstat_bs   = 64

train_data.compute_lattice_candidate(rcut)
model = DPModel({'embed_widths':embed_widths,
                 'fit_widths':fit_widths,
                 'axis_neuron':axis_neuron,
                 'Ebias':train_data.compute_Ebias()})
batch, lattice_args = train_data.get_batch(getstat_bs)
static_args = nn.FrozenDict({'lattice': lattice_args | {'ortho':orthorombic},
                            'rcut':rcut,
                            'type_index':tuple(train_data.type_index),
                            'ntype_index':tuple(lattice_args['lattice_max']*train_data.type_index)})
model.get_stats(batch['coord'], batch['box'], static_args)
variables = model.init(random.PRNGKey(RANDOM_SEED), batch['coord'][0], batch['box'][0], static_args)
lr_scheduler = optax.exponential_decay(init_value=lr, transition_steps=decay_steps,
                    decay_rate=decay_rate, transition_begin=0, staircase=True)
optimizer = optax.adam(learning_rate=lr_scheduler, b2=beta2)
opt_state = optimizer.init(variables)
loss, loss_and_grad = model.get_loss_ef_fn()
print('System initialized with', sum(i.size for i in tree_util.tree_flatten(variables)[0]), 'parameters.')

model.params['normalizer'] *= 1.
model.params['e3norm'] = 1.

def train_step(batch, variables, opt_state, state_args, static_args):
    r = lr_scheduler(state_args['iteration']) / lr
    pref = {'e': s_pref_e*r + l_pref_e*(1-r), 'f': s_pref_f*r + l_pref_f*(1-r)}
    (loss_total, (loss_e, loss_f)), grads = loss_and_grad(variables, batch, pref, static_args)
    updates, opt_state = optimizer.update(grads, opt_state)
    variables = optax.apply_updates(variables, updates)
    state_args['loss_avg'] = state_args['loss_avg'] * (1-1/l_smoothing) + loss_total
    state_args['le_avg'] = state_args['le_avg'] * (1-1/l_smoothing) + loss_e
    state_args['lf_avg'] = state_args['lf_avg'] * (1-1/l_smoothing) + loss_f
    state_args['iteration'] += 1
    return variables, opt_state, state_args
train_step = jit(train_step, static_argnums=(4,))

state_args = {'le_avg':0., 'lf_avg':0., 'loss_avg':0., 'iteration':0}
tic = time()
# for iteration in range(0):
for iteration in range(total_steps):
    batch, lattice_args = train_data.get_batch(batch_size)
    variables, opt_state, state_args = train_step(batch, variables, opt_state, state_args, static_args)
    if iteration % 1000 == 0:
        beta = l_smoothing * (1 - (1/l_smoothing)**(iteration+1))
        print('Iter ', iteration,
              ': Loss = %.5f' % (state_args['loss_avg']/beta)**0.5,
              ' Loss_E = %.5f' % ((state_args['le_avg']/beta)**0.5/train_data.natoms),
              ' Loss_F = %.5f' % (state_args['lf_avg']/beta)**0.5,
              ' Time = %.2f' % (time()-tic))
        tic = time()

with open(save_name, 'wb') as file:
    pickle.dump({'model':model, 'variables':variables}, file)

env: XLA_PYTHON_CLIENT_PREALLOCATE=false
Starting program on device [gpu(id=0)]


Loaded data from 'polaron_data' with 28928 frames of 384 atoms.
System initialized with 171426 parameters.
Iter  0 : Loss = 6.38994  Loss_E = 0.10773  Loss_F = 0.20185  Time = 2.71
Iter  1000 : Loss = 4.56492  Loss_E = 0.06633  Loss_F = 0.14424  Time = 5.38
Iter  2000 : Loss = 3.76446  Loss_E = 0.10843  Loss_F = 0.11866  Time = 2.88
Iter  3000 : Loss = 3.45661  Loss_E = 0.10709  Loss_F = 0.10890  Time = 2.88
Iter  4000 : Loss = 3.21452  Loss_E = 0.03912  Loss_F = 0.10174  Time = 2.88
Iter  5000 : Loss = 2.87790  Loss_E = 0.03391  Loss_F = 0.09296  Time = 2.88
Iter  6000 : Loss = 2.59096  Loss_E = 0.02107  Loss_F = 0.08377  Time = 2.88
Iter  7000 : Loss = 2.62747  Loss_E = 0.03716  Loss_F = 0.08480  Time = 2.88
Iter  8000 : Loss = 2.66845  Loss_E = 0.02868  Loss_F = 0.08631  Time = 2.88
Iter  9000 : Loss = 2.25745  Loss_E = 0.01204  Loss_F = 0.07469  Time = 2.88
Iter  10000 : Loss = 2.37409  Loss_E = 0.02002  Loss_F = 0.07848  Time = 2.88
Iter  11000 : Loss = 2.23793  Loss_E = 0.02104  

KeyboardInterrupt: 

In [40]:
i = 3
e, debug = model.apply(variables, train_data.data['coord'][i], train_data.data['box'][i], static_args)
(r_NM, embed_nmC, R_4NM, G_N4C, Feat_NX, fit_n1) = debug

In [41]:
np.std(G_N4C[:,0]) , np.std(G_N4C[:,0]) / np.std(G_N4C[:,1:]), np.std(Feat_NX)

(Array(0.5569, dtype=float32),
 Array(7.99721, dtype=float32),
 Array(0.37572, dtype=float32))

In [35]:
np.std(G_N4C[:,0]) , np.std(G_N4C[:,0]) / np.std(G_N4C[:,1:]), np.std(Feat_NX)

(Array(0.42991, dtype=float32),
 Array(7.5261, dtype=float32),
 Array(0.18197, dtype=float32))

In [32]:
np.std(G_N4C[:,0]) , np.std(G_N4C[:,0]) / np.std(G_N4C[:,1:]), np.std(Feat_NX)

(Array(0.32271, dtype=float32),
 Array(7.7443, dtype=float32),
 Array(0.09806, dtype=float32))

In [29]:
np.std(G_N4C[:,0]) , np.std(G_N4C[:,0]) / np.std(G_N4C[:,1:]), np.std(Feat_NX)

(Array(0.43412, dtype=float32),
 Array(7.55578, dtype=float32),
 Array(0.18473, dtype=float32))

In [22]:
np.std(G_N4C[:,0]) , np.std(G_N4C[:,0]) / np.std(G_N4C[:,1:]), np.std(Feat_NX)

(Array(0.4334, dtype=float32),
 Array(8.00204, dtype=float32),
 Array(0.19608, dtype=float32))

In [18]:
np.std(G_N4C[:,0]) , np.std(G_N4C[:,0]) / np.std(G_N4C[:,1:]), np.std(Feat_NX)

(Array(0.31626, dtype=float32),
 Array(7.32899, dtype=float32),
 Array(0.10737, dtype=float32))

In [8]:
model.params['normalizer']

Array(45.24406, dtype=float32)

In [9]:
model.params['srstd']

[Array(0.15338, dtype=float32), Array(0.13168, dtype=float32)]