In [1]:
import functools
import e3x
from flax import linen as nn
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax

# Disable future warnings.
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
jax.devices()

[cuda(id=0)]

# Moment of inertia tensor

In [2]:
def calculate_moment_of_inertia_tensor(masses, positions):
  diag = jnp.sum(positions**2, axis=-1)[..., None, None]*jnp.eye(3)
  outer = positions[..., None, :] * positions[..., :, None]
  return jnp.sum(masses[..., None, None] * (diag - outer), axis=-3)

def generate_datasets(key, num_train=1000, num_valid=100, num_points=10, min_mass=0.0, max_mass=1.0, stdev=1.0):
  # Generate random keys.
  train_position_key, train_masses_key, valid_position_key, valid_masses_key = jax.random.split(key, num=4)

  # Draw random point masses with random positions.
  train_positions = stdev * jax.random.normal(train_position_key,  shape=(num_train, num_points, 3))
  train_masses = jax.random.uniform(train_masses_key, shape=(num_train, num_points), minval=min_mass, maxval=max_mass)
  valid_positions = stdev * jax.random.normal(valid_position_key,  shape=(num_valid, num_points, 3))
  valid_masses = jax.random.uniform(valid_masses_key, shape=(num_valid, num_points), minval=min_mass, maxval=max_mass)

  # Calculate moment of inertia tensors.
  train_inertia_tensor = calculate_moment_of_inertia_tensor(train_masses, train_positions)
  valid_inertia_tensor = calculate_moment_of_inertia_tensor(valid_masses, valid_positions)

  # Return final train and validation datasets.
  train_data = dict(positions=train_positions, masses=train_masses, inertia_tensor=train_inertia_tensor)
  valid_data = dict(positions=valid_positions, masses=valid_masses, inertia_tensor=valid_inertia_tensor)
  return train_data, valid_data

In [3]:
class Model(nn.Module):
  features = 8
  max_degree = 1

  @nn.compact
  def __call__(self, masses, positions):  # Shapes (..., N) and (..., N, 3).
    # 1. Initialize features.
    x = jnp.concatenate((masses[..., None], positions), axis=-1) # Shape (..., N, 4).
    x = x[..., None, :, None]  # Shape (..., N, 1, 4, 1).

    # 2. Apply transformations.
    x = e3x.nn.Dense(features=self.features)(x)  # Shape (..., N, 1, 4, features).
    x = e3x.nn.TensorDense(max_degree=self.max_degree)(x)  # Shape (..., N, 2, (max_degree+1)**2, features).
    x = e3x.nn.TensorDense(  # Shape (..., N, 2, 9, 1).
        features=1,
        max_degree=2,
    )(x)
    # Try it: Zero-out irrep of degree 1 to only produce symmetric output tensors.
    # x = x.at[..., :, 1:4, :].set(0)

    # 3. Collect even irreps from feature channel 0 and sum over contributions from individual points.
    x = jnp.sum(x[..., 0, :, 0], axis=-2)  # Shape (..., (max_degree+1)**2).

    # 4. Convert output irreps to 3x3 matrix and return.
    cg = e3x.so3.clebsch_gordan(max_degree1=1, max_degree2=1, max_degree3=2)  # Shape (4, 4, 9).
    y = jnp.einsum('...l,nml->...nm', x, cg[1:, 1:, :])  # Shape (..., 3, 3).
    return y

In [4]:
def mean_squared_loss(prediction, target):
  return jnp.mean(optax.l2_loss(prediction, target))

In [5]:
@functools.partial(jax.jit, static_argnames=('model_apply', 'optimizer_update'))
def train_step(model_apply, optimizer_update, batch, opt_state, params):
  def loss_fn(params):
    inertia_tensor = model_apply(params, batch['masses'], batch['positions'])
    loss = mean_squared_loss(inertia_tensor, batch['inertia_tensor'])
    return loss
  loss, grad = jax.value_and_grad(loss_fn)(params)
  updates, opt_state = optimizer_update(grad, opt_state, params)
  params = optax.apply_updates(params, updates)
  return params, opt_state, loss

@functools.partial(jax.jit, static_argnames=('model_apply',))
def eval_step(model_apply, batch, params):
  inertia_tensor = model_apply(params, batch['masses'], batch['positions'])
  loss = mean_squared_loss(inertia_tensor, batch['inertia_tensor'])
  return loss

def train_model(key, model, train_data, valid_data, num_epochs, learning_rate, batch_size):
  # Initialize model parameters and optimizer state.
  key, init_key = jax.random.split(key)
  optimizer = optax.adam(learning_rate)
  params = model.init(init_key, train_data['masses'][0:1], train_data['positions'][0:1])
  opt_state = optimizer.init(params)

  # Determine the number of training steps per epoch.
  train_size = len(train_data['masses'])
  steps_per_epoch = train_size//batch_size

  # Train for 'num_epochs' epochs.
  for epoch in range(1, num_epochs + 1):
    # Draw random permutations for fetching batches from the train data.
    key, shuffle_key = jax.random.split(key)
    perms = jax.random.permutation(shuffle_key, train_size)
    perms = perms[:steps_per_epoch * batch_size]  # Skip the last batch (if incomplete).
    perms = perms.reshape((steps_per_epoch, batch_size))

    # Loop over all batches.
    train_loss = 0.0  # For keeping a running average of the loss.
    for i, perm in enumerate(perms):
      batch = {k: v[perm, ...] for k, v in train_data.items()}
      params, opt_state, loss = train_step(
          model_apply=model.apply,
          optimizer_update=optimizer.update,
          batch=batch,
          opt_state=opt_state,
          params=params
      )
      train_loss += (loss - train_loss)/(i+1)

    # Evaluate on the test set after each training epoch.
    valid_loss = eval_step(
        model_apply=model.apply,
        batch=valid_data,
        params=params
    )

    # Print progress.
    print(f"epoch {epoch : 4d} train loss {train_loss : 8.6f} valid loss {valid_loss : 8.6f}")

  # Return final model parameters.
  return params

In [6]:
# Initialize PRNGKey for random number generation.
key = jax.random.PRNGKey(0)

# Generate train and test datasets.
key, data_key = jax.random.split(key)
train_data, valid_data = generate_datasets(data_key)

# Define training hyperparameters.
learning_rate = 0.002
num_epochs = 100
batch_size = 32

In [7]:
'''print(train_data['masses'][0:1].shape)
print(train_data['positions'][0:1].shape)'''

"print(train_data['masses'][0:1].shape)\nprint(train_data['positions'][0:1].shape)"

In [10]:
valid_data["positions"]

Array([[[-0.494029  ,  0.08081957,  0.58778834],
        [-0.6221061 ,  1.006088  ,  0.6207147 ],
        [ 0.5767194 , -0.40320373, -0.9176049 ],
        ...,
        [-0.59026754,  0.19322182, -0.01509924],
        [ 0.5556615 , -1.4325528 ,  0.3230501 ],
        [-1.0281883 ,  1.8839624 , -1.7259156 ]],

       [[ 0.20453966,  0.41676185,  1.6043137 ],
        [-1.8780847 ,  1.7667238 , -0.5527455 ],
        [-1.2981076 , -0.17131835,  0.19391784],
        ...,
        [-0.22122882, -0.9467521 ,  1.0407714 ],
        [-0.73827523, -0.212966  ,  0.11027703],
        [-0.39328244, -1.3861412 ,  0.3157856 ]],

       [[ 1.2104156 , -0.2244492 ,  1.456738  ],
        [ 0.48320487, -1.1021247 ,  0.21420795],
        [ 1.1749548 ,  0.6107907 ,  1.1698654 ],
        ...,
        [-0.08263119,  0.5027424 ,  1.3183354 ],
        [-0.9717356 ,  1.7407076 ,  2.3585498 ],
        [-1.2621491 ,  0.57098824, -0.36478522]],

       ...,

       [[-1.311947  , -0.9668753 , -0.42978182],
        [ 0

In [11]:
key, train_key = jax.random.split(key)
model = Model()
params = train_model(
  key=train_key,
  model=model,
  train_data=train_data,
  valid_data=valid_data,
  num_epochs=num_epochs,
  learning_rate=learning_rate,
  batch_size=batch_size,
)

epoch    1 train loss  1.363581 valid loss  0.649740
epoch    2 train loss  0.471599 valid loss  0.361932
epoch    3 train loss  0.355637 valid loss  0.331024
epoch    4 train loss  0.335985 valid loss  0.314108
epoch    5 train loss  0.313791 valid loss  0.308226
epoch    6 train loss  0.295876 valid loss  0.261249
epoch    7 train loss  0.269312 valid loss  0.236298
epoch    8 train loss  0.248002 valid loss  0.230964
epoch    9 train loss  0.231901 valid loss  0.205514
epoch   10 train loss  0.225195 valid loss  0.209436
epoch   11 train loss  0.207692 valid loss  0.189051
epoch   12 train loss  0.200844 valid loss  0.185525
epoch   13 train loss  0.190828 valid loss  0.175023
epoch   14 train loss  0.178839 valid loss  0.169128
epoch   15 train loss  0.161747 valid loss  0.144348
epoch   16 train loss  0.147330 valid loss  0.133286
epoch   17 train loss  0.129805 valid loss  0.109766
epoch   18 train loss  0.106006 valid loss  0.088292
epoch   19 train loss  0.090337 valid loss  0.

In [12]:
i = 0
masses, positions, target = valid_data['masses'][i], valid_data['positions'][i], valid_data['inertia_tensor'][i]
prediction = model.apply(params, masses, positions)
print('target')
print(target)
print('prediction')
print(prediction)
print('mean squared error', jnp.mean((prediction-target)**2))

target
[[ 6.013584    1.6290329  -0.17871115]
 [ 1.6290329   4.8540945   0.73430276]
 [-0.17871115  0.73430276  6.1854286 ]]
prediction
[[ 6.0244517   1.6350874  -0.1778086 ]
 [ 1.6350869   4.856238    0.73851514]
 [-0.1778061   0.7385125   6.189898  ]]
mean squared error 2.8120308e-05


In [22]:
i = 0
masses, positions, target = (
    valid_data["masses"][i],
    valid_data["positions"][i].at[:, 0].add(1000),
    valid_data["inertia_tensor"][i],
)
prediction = model.apply(params, masses, positions)
print("target")
print(target)
print("prediction")
print(prediction)
print("mean squared error", jnp.mean((prediction - target) ** 2))

target
[[ 6.013584    1.6290329  -0.17871115]
 [ 1.6290329   4.8540945   0.73430276]
 [-0.17871115  0.73430276  6.1854286 ]]
prediction
[[ 1.9709243e+09  4.8340735e+02  4.7554764e+02]
 [-3.7784225e+01  1.9760028e+09  9.1536945e-01]
 [ 2.5883383e+02  5.6165820e-01  1.9760037e+09]]
mean squared error 1.2993022e+18


In [17]:
valid_data["positions"][i]

Array([[-0.494029  ,  0.08081957,  0.58778834],
       [-0.6221061 ,  1.006088  ,  0.6207147 ],
       [ 0.5767194 , -0.40320373, -0.9176049 ],
       [-1.2885472 , -0.13124795, -0.3965258 ],
       [ 0.23475617, -0.8042504 ,  0.161851  ],
       [ 1.0541666 , -0.06808183, -0.14599147],
       [ 1.5903791 ,  1.1758258 ,  0.75588095],
       [-0.59026754,  0.19322182, -0.01509924],
       [ 0.5556615 , -1.4325528 ,  0.3230501 ],
       [-1.0281883 ,  1.8839624 , -1.7259156 ]], dtype=float32)

In [23]:
positions_trans = valid_data["positions"][i].at[:, 0].add(1000)
positions_trans

Array([[ 9.99505981e+02,  8.08195695e-02,  5.87788343e-01],
       [ 9.99377869e+02,  1.00608802e+00,  6.20714724e-01],
       [ 1.00057672e+03, -4.03203726e-01, -9.17604923e-01],
       [ 9.98711426e+02, -1.31247953e-01, -3.96525800e-01],
       [ 1.00023474e+03, -8.04250419e-01,  1.61851004e-01],
       [ 1.00105414e+03, -6.80818260e-02, -1.45991474e-01],
       [ 1.00159039e+03,  1.17582583e+00,  7.55880952e-01],
       [ 9.99409729e+02,  1.93221822e-01, -1.50992395e-02],
       [ 1.00055566e+03, -1.43255281e+00,  3.23050112e-01],
       [ 9.98971802e+02,  1.88396239e+00, -1.72591555e+00]],      dtype=float32)

# Dataset 

In [31]:
filename = "test_data.npz"
dataset= np.load(filename)
for key in dataset.keys():
    print(key)
print('Dipole moment shape array',dataset['D'].shape)
print('Dipole moment units', dataset['D_units'])

print('Atomic numbers', dataset['z'])

type
R
R_units
z
E
E_units
F
F_units
D
D_units
Q
name
README
theory
Dipole moment shape array (5042, 3)
Dipole moment units eAng
Atomic numbers [23 14 14 14 14 14 14 14 14 14 14 14 14 14 14 14 14]


# Dipole Moment

In [32]:
def prepare_datasets(filename, key, num_train, num_valid):
    # Load the dataset.
    dataset = np.load(filename)
    num_data = len(dataset["E"])
    Z=jnp.full(16,14)
    Z=jnp.append(Z,23)
    Z=jnp.expand_dims(Z,axis=0)
    Z=jnp.repeat(Z, num_data, axis=0)
    num_draw = num_train + num_valid
    if num_draw > num_data:
        raise RuntimeError(
            f"datasets only contains {num_data} points, requested num_train={num_train}, num_valid={num_valid}"
        )

    # Randomly draw train and validation sets from dataset.
    choice = np.asarray(
        jax.random.choice(key, num_data, shape=(num_draw,), replace=False)
    )
    train_choice = choice[:num_train]
    valid_choice = choice[num_train:]

    # Collect and return train and validation sets.
    train_data = dict(
        #energy=jnp.asarray(dataset["E"][train_choice, 0] - mean_energy),
        #forces=jnp.asarray(dataset["F"][train_choice]),
        dipole_moment= jnp.asarray(dataset["D"][train_choice]),
        atomic_numbers=jnp.asarray(Z[train_choice]),
        # atomic_numbers=jnp.asarray(z_hack),
        positions=jnp.asarray(dataset["R"][train_choice]),
    )
    valid_data = dict(
        #energy=jnp.asarray(dataset["E"][valid_choice, 0] - mean_energy),
        #forces=jnp.asarray(dataset["F"][valid_choice]),
        atomic_numbers=jnp.asarray(Z[valid_choice]),
        dipole_moment= jnp.asarray(dataset["D"][valid_choice]),
        # atomic_numbers=jnp.asarray(z_hack),
        positions=jnp.asarray(dataset["R"][valid_choice]),
    )
    return train_data, valid_data

In [33]:
key = jax.random.PRNGKey(0)
num_train=1000
num_val=200
train_data,valid_data=prepare_datasets(filename,key, num_train,num_val)
print(train_data['dipole_moment'].shape)

(1000, 3)


In [34]:
class Dipole_Moment(nn.Module):
  #features = 1
  #max_degree = 1
  @nn.compact
  def __call__(self,atomic_numbers, positions):  # Shapes (..., N) and (..., N, 3).
    # 1. Initialize features.
    x = jnp.concatenate((atomic_numbers[...,None], positions), axis=-1) # Shape (..., N, 4).
    #print("Initial shape:", x.shape)
    x = x[..., None, :, None]  # Shape (..., N, 1, 3, 1).
    #print("x shape:", x.shape)
    # 2. Apply transformations.
    x = e3x.nn.Dense(features=1)(x) 
    #print("After Dense layer:", x.shape)
    x = e3x.nn.TensorDense(max_degree=1)(x)  
    #print("After TensorDense layer:", x.shape)
    x=jnp.sum(x, axis=-4) 
    #print("After sum:", x.shape)
    y = x[..., 1, 1:4, 0]
    #print("After slicing:", y.shape)

    return y

In [35]:
!pip list requierement


os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.



Package                            Version
---------------------------------- ---------------------
absl-py                            2.1.0
alembic                            1.13.1
aniso8601                          9.0.1
ase                                3.22.1
asttokens                          2.4.1
blinker                            1.8.2
cachetools                         5.3.3
certifi                            2024.2.2
charset-normalizer                 3.3.2
chex                               0.1.86
click                              8.1.7
cloudpickle                        3.0.0
comm                               0.2.2
contourpy                          1.2.1
cycler                             0.12.1
debugpy                            1.8.1
decorator                          5.1.1
Deprecated                         1.2.14
docker                             7.0.0
e3x                                1.0.2
entrypoints                        0.4
etils                            

In [84]:
'''dm_model = Dipole_Moment()
key = jax.random.PRNGKey(0)

# Generate train and test datasets.
key, data_key = jax.random.split(key)
train_data, valid_data = generate_datasets(data_key)
params = dm_model.init(key, train_data['masses'][0:1], train_data['positions'][0:1])
moment=dm_model.apply(params,train_data['masses'][0:1], train_data['positions'][0:1])
print(moment.shape)'''

"dm_model = Dipole_Moment()\nkey = jax.random.PRNGKey(0)\n\n# Generate train and test datasets.\nkey, data_key = jax.random.split(key)\ntrain_data, valid_data = generate_datasets(data_key)\nparams = dm_model.init(key, train_data['masses'][0:1], train_data['positions'][0:1])\nmoment=dm_model.apply(params,train_data['masses'][0:1], train_data['positions'][0:1])\nprint(moment.shape)"

In [36]:
@functools.partial(jax.jit, static_argnames=('model_apply', 'optimizer_update'))
def train_step(model_apply, optimizer_update, batch, opt_state, params):
  def loss_fn(params):
    dipole_moment = model_apply(params,batch['atomic_numbers'] ,batch['positions'])
    loss = mean_squared_loss(dipole_moment, batch['dipole_moment'])
    return loss
  loss, grad = jax.value_and_grad(loss_fn)(params)
  updates, opt_state = optimizer_update(grad, opt_state, params)
  params = optax.apply_updates(params, updates)
  return params, opt_state, loss

@functools.partial(jax.jit, static_argnames=('model_apply',))
def eval_step(model_apply, batch, params):
  dipole_moment = model_apply(params,batch['atomic_numbers'],batch['positions'])
  loss = mean_squared_loss(dipole_moment, batch['dipole_moment'])
  return loss

def train_model(key, model, train_data, valid_data, num_epochs, learning_rate, batch_size):
  # Initialize model parameters and optimizer state.
  key, init_key = jax.random.split(key)
  optimizer = optax.adam(learning_rate)
  params = model.init(init_key,train_data['atomic_numbers'][0:1],train_data['positions'][0:1])
  opt_state = optimizer.init(params)

  # Determine the number of training steps per epoch.
  train_size = len(train_data['positions'])
  steps_per_epoch = train_size//batch_size

  # Train for 'num_epochs' epochs.
  list_train_loss = []
  list_val_loss = []
  for epoch in range(1, num_epochs + 1):
    # Draw random permutations for fetching batches from the train data.
    key, shuffle_key = jax.random.split(key)
    perms = jax.random.permutation(shuffle_key, train_size)
    perms = perms[:steps_per_epoch * batch_size]  # Skip the last batch (if incomplete).
    perms = perms.reshape((steps_per_epoch, batch_size))

    # Loop over all batches.
    train_loss = 0.0  # For keeping a running average of the loss.

    for i, perm in enumerate(perms):
      batch = {k: v[perm, ...] for k, v in train_data.items()}
      #print(batch['dipole_moment'].shape)

      params, opt_state, loss = train_step(
          model_apply=model.apply,
          optimizer_update=optimizer.update,
          batch=batch,
          opt_state=opt_state,
          params=params
      )
      train_loss += (loss - train_loss)/(i+1)
      
    # Evaluate on the test set after each training epoch.
    valid_loss = eval_step(
        model_apply=model.apply,
        batch=valid_data,
        params=params
    )
    list_val_loss.append(valid_loss)
    list_train_loss.append(train_loss)
    # Print progress.
    print(f"epoch {epoch : 4d} train loss {train_loss : 8.6f} valid loss {valid_loss : 8.6f}")

  # Return final model parameters.
  return params ,list_train_loss , list_val_loss

In [37]:
# Initialize PRNGKey for random number generation.
key = jax.random.PRNGKey(0)
num_train=3000
num_val=200
train_data, valid_data = prepare_datasets(filename,key, num_train,num_val)

# Define training hyperparameters.
learning_rate = 0.002
num_epochs = 10000
batch_size = 512

In [38]:
'''print(train_data['positions'][0:1].shape)
print(train_data['atomic_numbers'][0:1].shape)'''

"print(train_data['positions'][0:1].shape)\nprint(train_data['atomic_numbers'][0:1].shape)"

In [39]:
key, train_key = jax.random.split(key)
model = Dipole_Moment()
params, list_train_loss, list_val_loss = train_model(
    key=train_key,
    model=model,
    train_data=train_data,
    valid_data=valid_data,
    num_epochs=num_epochs,
    learning_rate=learning_rate,
    batch_size=batch_size,
)

epoch    1 train loss  35.072781 valid loss  28.456921
epoch    2 train loss  24.920635 valid loss  20.241196
epoch    3 train loss  17.798740 valid loss  14.573094
epoch    4 train loss  12.910925 valid loss  10.720650
epoch    5 train loss  9.593145 valid loss  8.105575
epoch    6 train loss  7.336629 valid loss  6.317753
epoch    7 train loss  5.781274 valid loss  5.068070
epoch    8 train loss  4.688513 valid loss  4.175456
epoch    9 train loss  3.900592 valid loss  3.522735
epoch   10 train loss  3.315850 valid loss  3.027928
epoch   11 train loss  2.870950 valid loss  2.645716
epoch   12 train loss  2.521299 valid loss  2.340809
epoch   13 train loss  2.238468 valid loss  2.092879
epoch   14 train loss  2.009790 valid loss  1.885348
epoch   15 train loss  1.815598 valid loss  1.709185
epoch   16 train loss  1.648358 valid loss  1.557613
epoch   17 train loss  1.504528 valid loss  1.424069
epoch   18 train loss  1.378565 valid loss  1.306026
epoch   19 train loss  1.264315 valid 

In [40]:
import plotly.graph_objs as go
import plotly.io as pio
from typing import List

def create_loss_plot(
    train_loss: List[np.ndarray],
    val_loss: List[np.ndarray],
    train_label: str,
    val_label: str,
    title: str,
    filename: str,
) -> None:
    """
    Create a Plotly figure with training and validation loss curves and save it as an HTML file.

    Args:
        train_loss (List[np.ndarray]): List of training loss values.
        val_loss (List[np.ndarray]): List of validation loss values.
        train_label (str): Label for the training loss curve.
        val_label (str): Label for the validation loss curve.
        title (str): Title of the plot.
        filename (str): Filename to save the HTML file.

    Returns:
        None
    """
    train_loss_list = [float(loss) for loss in train_loss]
    val_loss_list = [float(loss) for loss in val_loss]

    trace_train = go.Scatter(y=train_loss_list, mode="lines", name=train_label)
    trace_val = go.Scatter(y=val_loss_list, mode="lines", name=val_label)

    fig = go.Figure()
    fig.add_trace(trace_train)
    fig.add_trace(trace_val)
    fig.update_layout(
        title=title, xaxis_title="Epoch", yaxis_title="Loss", legend_title="Legend"
    )
    pio.write_html(fig, filename)


create_loss_plot(
    list_train_loss,
    list_val_loss,
    "Training Loss",
    "Validation Loss",
    "Training vs Validation Loss (Train)",
    "train_vs_val_train_dipole_moment.html",
)

In [41]:
i = 45
Z, positions, target = valid_data['atomic_numbers'][i], valid_data['positions'][i], valid_data['dipole_moment'][i]
prediction = model.apply(params, Z, positions)

print('target')
print(target)
print('prediction')
print(prediction)
print('mean squared error', jnp.mean((prediction-target)**2))

target
[ 1.5010834  -0.35230795  1.6916788 ]
prediction
[ 1.414227   -0.34050953  1.760156  ]
mean squared error 0.0041241227


In [None]:
positions_dst = e3x.ops.gather_dst(positions, dst_idx=dst_idx)