In [1]:
%env XLA_PYTHON_CLIENT_PREALLOCATE=false
import jax.numpy as jnp
import numpy as np
from jax import jit, grad, random, tree_util
import jax, optax
import flax.linen as nn
from deepmd_jax.data import DataSystem
from deepmd_jax.model import DPModel
import pickle
from time import time
# jax.config.update('jax_enable_x64', True)
jax.config.update('jax_default_matmul_precision', 'float32')
np.set_printoptions(precision=5, suppress=True)
print('Starting program on device', jax.devices())

save_name    = 'model_polaron.pkl'
train_data   = DataSystem('polaron_data', ['coord', 'box', 'force', 'energy'])
rcut         = 7.0
embed_widths = [32, 64, 128]
fit_widths   = [128, 128, 128]
axis_neuron  = 16
batch_size   = 1
lr           = 0.002 
beta2        = 0.999
s_pref_e     = 0.02
l_pref_e     = 1
s_pref_f     = 1000
l_pref_f     = 1
total_steps  = 2001
decay_steps  = 4000
decay_rate   = 0.95

RANDOM_SEED  = np.random.randint(1000)
l_smoothing  = 20
getstat_bs   = 100

train_data.compute_lattice_candidate(rcut)
model = DPModel({'embed_widths':embed_widths,
                 'fit_widths':fit_widths,
                 'axis_neuron':axis_neuron,
                 'Ebias':train_data.compute_Ebias()})
batch, lattice_args = train_data.get_batch(getstat_bs)
static_args = nn.FrozenDict({'lattice':lattice_args,
                            'rcut':rcut,
                            'type_index':tuple(train_data.type_index),
                            'ntype_index':tuple(lattice_args['lattice_max']*train_data.type_index)})
model.get_stats(batch['coord'], batch['box'], static_args)
variables = model.init(random.PRNGKey(RANDOM_SEED), batch['coord'][0], batch['box'][0], static_args)
lr_scheduler = optax.exponential_decay(init_value=lr, transition_steps=decay_steps,
                    decay_rate=decay_rate, transition_begin=0, staircase=True)
optimizer = optax.adam(learning_rate=lr_scheduler, b2=beta2)
opt_state = optimizer.init(variables)
loss, loss_and_grad = model.get_loss_ef_fn()
print('System initialized with', sum(i.size for i in tree_util.tree_flatten(variables)[0]), 'parameters.')

def train_step(batch, variables, opt_state, state_args, static_args):
    r = lr_scheduler(state_args['iteration']) / lr
    pref = {'e': s_pref_e*r + l_pref_e*(1-r), 'f': s_pref_f*r + l_pref_f*(1-r)}
    (loss_total, (loss_e, loss_f)), grads = loss_and_grad(variables, batch, pref, static_args)
    updates, opt_state = optimizer.update(grads, opt_state)
    variables = optax.apply_updates(variables, updates)
    state_args['loss_avg'] = state_args['loss_avg'] * (1-1/l_smoothing) + loss_total
    state_args['le_avg'] = state_args['le_avg'] * (1-1/l_smoothing) + loss_e
    state_args['lf_avg'] = state_args['lf_avg'] * (1-1/l_smoothing) + loss_f
    state_args['iteration'] += 1
    return variables, opt_state, state_args
train_step = jit(train_step, static_argnums=(4,))

state_args = {'le_avg':0., 'lf_avg':0., 'loss_avg':0., 'iteration':0}
tic = time()
for iteration in range(total_steps):
    batch, lattice_args = train_data.get_batch(batch_size)
    variables, opt_state, state_args = train_step(batch, variables, opt_state, state_args, static_args)
    if iteration % 1000 == 0:
        beta = l_smoothing * (1 - (1/l_smoothing)**(iteration+1))
        print('Iter ', iteration,
              ': Loss = %.5f' % (state_args['loss_avg']/beta)**0.5,
              ' Loss_E = %.5f' % ((state_args['le_avg']/beta)**0.5/train_data.natoms),
              ' Loss_F = %.5f' % (state_args['lf_avg']/beta)**0.5,
              ' Time = %.2f' % (time()-tic))
        tic = time()

with open(save_name, 'wb') as file:
    pickle.dump({'model':model, 'variables':variables}, file)

env: XLA_PYTHON_CLIENT_PREALLOCATE=false
Starting program on device [gpu(id=0)]
Loaded data from 'polaron_data' with 29419 frames and 384 atoms.
System initialized with 633346 parameters.
Iter  0 : Loss = 5.94454  Loss_E = 0.03051  Loss_F = 0.18054  Time = 4.51
Iter  1000 : Loss = 23.45628  Loss_E = 0.03403  Loss_F = 0.73945  Time = 9.84
Iter  2000 : Loss = 11.65014  Loss_E = 0.09187  Loss_F = 0.33292  Time = 6.63


KeyboardInterrupt: 

In [2]:
with jax.profiler.trace("./trace", create_perfetto_link=True):
    for i in range(4):
        batch, _ = train_data.get_batch(batch_size)
        variables, opt_state, state_args = train_step(batch, variables, opt_state, state_args, static_args)
    jax.block_until_ready(variables)

2023-09-25 17:13:36.706241: E external/xla/xla/python/profiler/internal/python_hooks.cc:398] Can't import tensorflow.python.profiler.trace
2023-09-25 17:13:36.795763: E external/xla/xla/python/profiler/internal/python_hooks.cc:398] Can't import tensorflow.python.profiler.trace


Open URL in browser: https://ui.perfetto.dev/#!/?url=http://127.0.0.1:9001/perfetto_trace.json.gz


KeyboardInterrupt: 