In [None]:
# Linalg support
import numpy as onp

# Jax imports
import jax
from jax import lax
from jax import jit, vmap, grad
import jax.numpy as np
from jax import random

# ML imports
import optax

# Jax-md imports
from jax_md import energy, space, simulate, quantity

# Plotting.
import matplotlib.pyplot as plt

In [None]:
# Load up your saved data here.
# The three arrays are read in a fixed order: positions, energies, forces.
# NOTE(review): later cells index all three with the same indices along the
# leading axis, so they are assumed to hold one entry per configuration.
positions, energies, forces = (
    np.load(path) for path in ("trajectory.npy", "energy.npy", "forces.npy")
)

In [None]:
# Split the data into training and testing sets.
n_train_points = 800
n_total_points = energies.shape[0]
if n_train_points > n_total_points:
  raise ValueError(
      'Requested {} training points but only {} samples are available'.format(
          n_train_points, n_total_points))

# A locally seeded generator makes the split reproducible across kernel
# restarts without touching NumPy's global random state.
split_rng = onp.random.default_rng(0)
shuffled_indices = split_rng.permutation(n_total_points)

# The first n_train_points shuffled indices form the training set; the rest
# form the test set, sorted so test samples keep their original order. Slicing
# one permutation guarantees the two sets are disjoint and exhaustive, and
# avoids a Python-level membership scan over every sample.
train_indices = shuffled_indices[:n_train_points]
test_indices = onp.sort(shuffled_indices[n_train_points:])

train_positions = positions[train_indices]
train_energies = energies[train_indices]
train_forces = forces[train_indices]

test_positions = positions[test_indices]
test_energies = energies[test_indices]
test_forces = forces[test_indices]

In [None]:
# Normalize the energies.
# Statistics come from the training split only and are reused for the test
# split, so no test information leaks into the normalization.
# NOTE: this cell rebinds its own inputs; run it once per kernel session,
# otherwise the data is normalized twice.
energy_mean = np.mean(train_energies)
energy_std = np.std(train_energies)

train_energies = (train_energies - energy_mean) / energy_std
test_energies = (test_energies - energy_mean) / energy_std

# The model is trained on normalized energies and its forces are obtained as
# the negative position-gradient of that normalized energy. Dividing the energy
# by energy_std divides its gradient by the same factor (the constant mean
# shift has zero gradient), so the force labels must be rescaled identically
# for the energy and force targets to describe the same potential.
train_forces = train_forces / energy_std
test_forces = test_forces / energy_std

In [None]:
# The size of the simulation region, adjust if necessary.
box_size = 14.474693

# space.periodic returns a pair, unpacked here as the displacement function
# and the shift function for the periodic box.
periodic_space = space.periodic(box_size)
displacement, shift = periodic_space

In [None]:
# Define the Behler Parrinello Neural Network.
# The factory returns three callables: one to build neighbor lists, one to
# initialize parameters and one to evaluate the energy.
symmetry_kwargs = {"cutoff_distance": 4.0}
neighbor_fn, init_fn, energy_fn = energy.behler_parrinello_neighbor_list(
    displacement,
    box_size,
    dr_threshold=0.1,
    sym_kwargs=symmetry_kwargs,
)

# Neighbour list computation, should improve performance.
# The list is allocated once from the first training configuration and is
# only updated (never re-allocated) afterwards.
reference_positions = train_positions[0]
neighbor = neighbor_fn.allocate(reference_positions, extra_capacity=6)

print('Allocating space for at most {} edges'.format(neighbor.idx.shape[1]))

In [None]:
@jit
def train_energy_fn(params, R):
  """Evaluate the network energy of a single configuration.

  Args:
    params: Network parameters passed through to `energy_fn`.
    R: Positions of one configuration.

  Returns:
    The value of `energy_fn(params, R, updated_neighbor_list)`.
  """
  # Update the globally allocated neighbor list for these positions.
  # NOTE(review): the updated list is not checked for overflow here; confirm
  # the allocated capacity is sufficient for every configuration used.
  updated_neighbor = neighbor.update(R)
  return energy_fn(params, R, updated_neighbor)


# Vectorize over states, not parameters: the parameters are shared
# (in_axes=None) while positions are mapped over their leading axis.
vectorized_energy_fn = vmap(train_energy_fn, in_axes=(None, 0))

# Gradient of the energy with respect to the positions (argument 1).
grad_fn = grad(train_energy_fn, argnums=1)


def force_fn(params, R, **kwargs):
  """Return the negative position-gradient of the energy.

  Extra keyword arguments are accepted and ignored.
  """
  return -grad_fn(params, R)


vectorized_force_fn = vmap(force_fn, in_axes=(None, 0))

In [None]:
# Initialize the neural network parameters
# A fixed PRNG seed keeps the initial parameters the same on every run; the
# first training configuration and the allocated neighbor list are passed
# to init_fn as example inputs.
init_seed = 0
key = random.PRNGKey(init_seed)
params = init_fn(key, train_positions[0], neighbor)

In [None]:
# Look at the priors over the data before training. What do you see?

# Predictions of the untrained network. Note the two panels use different
# splits: energies are evaluated on the training configurations, forces on
# the test configurations. The already-defined batched energy function is
# reused instead of building a second vmap of train_energy_fn.
predicted_energies = vectorized_energy_fn(params, train_positions)
predicted_forces = vectorized_force_fn(params, test_positions)

fig, ax = plt.subplots(1, 2, figsize=(8, 8))

# Energy priors
ax[0].plot(train_energies, predicted_energies, 'o')
# Dashed identity line: a perfect model would put every point on it.
ax[0].plot(train_energies, train_energies, 'k--')
ax[0].set(title='Energy prior (train split)',
          xlabel='Target energy',
          ylabel='Predicted energy')

# Force priors
ax[1].plot(test_forces.flatten(), predicted_forces.flatten(), 'o')
ax[1].plot(test_forces.flatten(), test_forces.flatten(), 'k--')
ax[1].set(title='Force prior (test split)',
          xlabel='Target force component',
          ylabel='Predicted force component')

fig.tight_layout()
plt.show()

In [None]:
# Define the loss functions.
@jit
def energy_loss(params, R, energy_targets):
  """Mean squared error between batched energy predictions and targets.

  Args:
    params: Network parameters.
    R: Batch of configurations, mapped over the leading axis.
    energy_targets: Target energies compared element-wise to the predictions.

  Returns:
    The mean of the squared residuals.
  """
  residuals = vectorized_energy_fn(params, R) - energy_targets
  squared_residuals = residuals ** 2
  return np.mean(squared_residuals)

@jit
def force_loss(params, R, force_targets):
  """Squared force error, summed per configuration and averaged over the batch.

  Args:
    params: Network parameters.
    R: Batch of configurations, mapped over the leading axis.
    force_targets: Target forces with the same shape as the predicted forces.

  Returns:
    The mean over the leading axis of the squared error summed over axes 1
    and 2.
  """
  force_errors = vectorized_force_fn(params, R) - force_targets
  per_configuration_error = np.sum(force_errors ** 2, axis=(1, 2))
  return np.mean(per_configuration_error)

@jit
def loss(params, R, targets):
  """Total training loss: energy loss plus force loss, equally weighted.

  Args:
    params: Network parameters.
    R: Batch of configurations.
    targets: Indexable pair whose element 0 holds the energy targets and
      element 1 the force targets.

  Returns:
    The sum of `energy_loss` and `force_loss` for this batch.
  """
  energy_targets = targets[0]
  force_targets = targets[1]
  energy_term = energy_loss(params, R, energy_targets)
  force_term = force_loss(params, R, force_targets)
  return energy_term + force_term

In [None]:
# Optimizer: gradients are first clipped to a global norm of 1.0, then fed
# to Adam with a learning rate of 1e-3.
gradient_clipping = optax.clip_by_global_norm(1.0)
adam_optimizer = optax.adam(1e-3)
opt = optax.chain(gradient_clipping, adam_optimizer)

@jit
def update_step(params, opt_state, R, labels):
  """Apply one optimizer step on a single batch.

  Args:
    params: Current network parameters.
    opt_state: Current optimizer state.
    R: Batch of configurations.
    labels: Targets forwarded unchanged to `loss`.

  Returns:
    A `(new_params, new_opt_state)` tuple.
  """
  # grad differentiates `loss` with respect to its first argument, params.
  gradients = grad(loss)(params, R, labels)
  updates, new_opt_state = opt.update(gradients, opt_state)
  new_params = optax.apply_updates(params, updates)
  return new_params, new_opt_state

@jit
def update_epoch(params_and_opt_state, batches):
  """Run `update_step` sequentially over every batch of one epoch.

  Args:
    params_and_opt_state: `(params, opt_state)` tuple used as the scan carry.
    batches: `(positions, labels)` structure scanned over its leading axis.

  Returns:
    The `(params, opt_state)` tuple after the last batch.
  """
  def scan_body(carry, batch):
    current_params, current_opt_state = carry
    batch_positions, batch_labels = batch
    next_carry = update_step(
        current_params, current_opt_state, batch_positions, batch_labels)
    # The per-step output is unused, so a constant placeholder is emitted.
    return next_carry, 0

  final_carry, _ = lax.scan(scan_body, params_and_opt_state, batches)
  return final_carry

In [None]:
# Number of training configurations and the size of each minibatch.
dataset_size = len(train_positions)
batch_size = 128

# `lookup` holds one index per training configuration; it is shuffled in
# place using NumPy's global random state (not seeded in this cell).
lookup = onp.arange(dataset_size)
onp.random.shuffle(lookup)

@jit
def make_batches(lookup):
  """Group the training data into full minibatches following `lookup`.

  Args:
    lookup: 1-D array of indices into the training arrays.

  Returns:
    A `(positions, energies, forces)` tuple, each stacked along a new leading
    batch axis. Only complete batches of `batch_size` indices are kept; a
    trailing partial batch is dropped.
  """
  n_full_batches = len(lookup) // batch_size
  index_groups = [
      lookup[k * batch_size:(k + 1) * batch_size]
      for k in range(n_full_batches)
  ]

  stacked_positions = np.stack([train_positions[idx] for idx in index_groups])
  stacked_energies = np.stack([train_energies[idx] for idx in index_groups])
  stacked_forces = np.stack([train_forces[idx] for idx in index_groups])
  return stacked_positions, stacked_energies, stacked_forces

# Build the batches for the first epoch from the initial shuffle.
batch_Rs, batch_Es, batch_Fs = make_batches(lookup)

In [None]:
train_epochs = 5000  # Adjust as necessary.

opt_state = opt.init(params)

# Energy RMSE history, one entry per epoch, recorded before that epoch's update.
train_energy_error = []
test_energy_error = []

for iteration in range(train_epochs):
  # The training error is measured on the first minibatch only; the test
  # error uses the whole test split.
  train_rmse = np.sqrt(energy_loss(params, batch_Rs[0], batch_Es[0]))
  test_rmse = np.sqrt(energy_loss(params, test_positions, test_energies))
  train_energy_error.append(float(train_rmse))
  test_energy_error.append(float(test_rmse))

  epoch_batches = (batch_Rs, (batch_Es, batch_Fs))
  params, opt_state = update_epoch((params, opt_state), epoch_batches)

  # Reshuffle in place and rebuild the minibatches for the next epoch.
  onp.random.shuffle(lookup)
  batch_Rs, batch_Es, batch_Fs = make_batches(lookup)

In [None]:
# Parity plots of the trained model on the test split: predictions against
# targets, with a dashed identity line marking perfect agreement.
fig, ax = plt.subplots(1, 2)

predicted_energies = vectorized_energy_fn(params, test_positions)
ax[0].plot(test_energies, predicted_energies, 'o')
ax[0].plot(test_energies, test_energies, 'k--')
ax[0].set(title='Energies (test split)',
          xlabel='Target energy',
          ylabel='Predicted energy')

predicted_forces = vectorized_force_fn(params, test_positions)
# Flatten so every force component of every configuration is one point.
flat_test_forces = test_forces.reshape((-1,))
flat_predicted_forces = predicted_forces.reshape((-1,))
ax[1].plot(flat_test_forces, flat_predicted_forces, 'o')
# Same identity-line style as the energy panel.
ax[1].plot(flat_test_forces, flat_test_forces, 'k--')
ax[1].set(title='Forces (test split)',
          xlabel='Target force component',
          ylabel='Predicted force component')

fig.tight_layout()
plt.show()

In [None]:
def compute_energy_metrics(energy_predictions: np.ndarray, energy_targets: np.ndarray):
  """Summarize how well predicted energies match their targets.

  Both inputs are flattened before comparison, so a column vector and a flat
  vector with the same number of entries are compared element by element
  instead of being broadcast against each other.

  Args:
    energy_predictions: Predicted energies.
    energy_targets: Reference energies with the same number of entries.

  Returns:
    A dict of Python floats, expressed in the units of the inputs:
      "mae": mean absolute error,
      "rmse": root-mean-square error,
      "max_error": largest absolute error.

  Raises:
    ValueError: If the inputs are empty or differ in number of entries.
  """
  predictions = np.ravel(np.asarray(energy_predictions))
  targets = np.ravel(np.asarray(energy_targets))

  if predictions.size != targets.size:
    raise ValueError(
        'energy_predictions has {} entries but energy_targets has {}'.format(
            predictions.size, targets.size))
  if predictions.size == 0:
    raise ValueError('Cannot compute energy metrics on empty inputs')

  errors = predictions - targets
  absolute_errors = np.abs(errors)

  return {
      "mae": float(np.mean(absolute_errors)),
      "rmse": float(np.sqrt(np.mean(errors ** 2))),
      "max_error": float(np.max(absolute_errors)),
  }

In [None]:
# Save your model parameters.
