In [1]:
import functools
import e3x
from flax import linen as nn
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax

# Disable future warnings.
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

# Moment of inertia tensor

In [2]:
def calculate_moment_of_inertia_tensor(masses, positions):
  """Moment of inertia tensor of a collection of point masses.

  Evaluates I = sum_i m_i * (|r_i|^2 * Id - r_i r_i^T), where the sum runs over
  the point axis.

  Args:
    masses: Array of shape (..., N).
    positions: Array of shape (..., N, 3).

  Returns:
    Array of shape (..., 3, 3).
  """
  squared_norms = jnp.sum(positions**2, axis=-1)
  isotropic = squared_norms[..., None, None] * jnp.eye(3)  # |r|^2 * Id, shape (..., N, 3, 3).
  dyadic = positions[..., :, None] * positions[..., None, :]  # r r^T, shape (..., N, 3, 3).
  weighted = masses[..., None, None] * (isotropic - dyadic)
  return weighted.sum(axis=-3)  # Sum over the N points.

def generate_datasets(key, num_train=1000, num_valid=100, num_points=10, min_mass=0.0, max_mass=1.0, stdev=1.0):
  """Create random train/validation sets of point-mass systems.

  Positions are standard-normal samples scaled by `stdev`; masses are uniform
  samples between `min_mass` and `max_mass`. Each returned dict holds
  'positions' (M, num_points, 3), 'masses' (M, num_points) and the matching
  'inertia_tensor' (M, 3, 3), with M = num_train or num_valid respectively.
  """
  def make_split(position_key, masses_key, num_samples):
    # Sample one split and attach its exact moment of inertia tensors.
    positions = stdev * jax.random.normal(position_key, shape=(num_samples, num_points, 3))
    masses = jax.random.uniform(masses_key, shape=(num_samples, num_points), minval=min_mass, maxval=max_mass)
    return dict(
        positions=positions,
        masses=masses,
        inertia_tensor=calculate_moment_of_inertia_tensor(masses, positions),
    )

  # Four independent subkeys: (positions, masses) for train, then for validation.
  subkeys = jax.random.split(key, num=4)
  train_data = make_split(subkeys[0], subkeys[1], num_train)
  valid_data = make_split(subkeys[2], subkeys[3], num_valid)
  return train_data, valid_data

In [54]:
class Model(nn.Module):
  """Equivariant e3x model that predicts a 3x3 moment of inertia tensor.

  Per point, the mass (degree 0) and the position (degree 1) form the input
  features. e3x layers mix them, the contributions of all points are summed,
  and the resulting irreps are converted to a Cartesian 3x3 matrix with
  Clebsch-Gordan coefficients.

  Attributes:
    features: Number of feature channels produced by the first Dense layer.
    max_degree: Maximum irrep degree of the intermediate TensorDense layer.
  """
  # The type annotations make these Flax dataclass fields, so they can be set via
  # the constructor (e.g. `Model(features=16)`). Unannotated, they were plain
  # class attributes. Defaults are unchanged, so `Model()` behaves as before.
  features: int = 8
  max_degree: int = 1

  @nn.compact
  def __call__(self, masses, positions):  # Shapes (..., N) and (..., N, 3).
    # 1. Initialize features.
    x = jnp.concatenate((masses[..., None], positions), axis=-1) # Shape (..., N, 4).
    x = x[..., None, :, None]  # Shape (..., N, 1, 4, 1).

    # 2. Apply transformations.
    x = e3x.nn.Dense(features=self.features)(x)  # Shape (..., N, 1, 4, features).
    x = e3x.nn.TensorDense(max_degree=self.max_degree)(x)  # Shape (..., N, 2, (max_degree+1)**2, features).
    x = e3x.nn.TensorDense(  # Shape (..., N, 2, 9, 1).
        features=1,
        max_degree=2,
    )(x)
    # Try it: Zero-out irrep of degree 1 to only produce symmetric output tensors.
    # x = x.at[..., :, 1:4, :].set(0)

    # 3. Collect even irreps from feature channel 0 and sum over contributions from individual points.
    # The last TensorDense uses max_degree=2, so there are (2+1)**2 = 9 irrep components.
    x = jnp.sum(x[..., 0, :, 0], axis=-2)  # Shape (..., 9).

    # 4. Convert output irreps to 3x3 matrix and return.
    cg = e3x.so3.clebsch_gordan(max_degree1=1, max_degree2=1, max_degree3=2)  # Shape (4, 4, 9).
    y = jnp.einsum('...l,nml->...nm', x, cg[1:, 1:, :])  # Shape (..., 3, 3).
    return y

In [4]:
def mean_squared_loss(prediction, target):
  """Average of optax's element-wise L2 loss between `prediction` and `target`.

  Per the optax documentation, `optax.l2_loss` is 0.5 * (prediction - target)**2,
  so this value is half of the conventional mean squared error.
  """
  elementwise = optax.l2_loss(prediction, target)
  return elementwise.mean()

In [55]:
@functools.partial(jax.jit, static_argnames=('model_apply', 'optimizer_update'))
def train_step(model_apply, optimizer_update, batch, opt_state, params):
  """Run one jitted optimizer step of the inertia-tensor model on `batch`.

  Returns (params, opt_state, loss). The loss is evaluated at the parameters
  from before the update.
  """
  def objective(p):
    predicted_tensor = model_apply(p, batch['masses'], batch['positions'])
    return mean_squared_loss(predicted_tensor, batch['inertia_tensor'])

  loss, grads = jax.value_and_grad(objective)(params)
  updates, new_opt_state = optimizer_update(grads, opt_state, params)
  return optax.apply_updates(params, updates), new_opt_state, loss

@functools.partial(jax.jit, static_argnames=('model_apply',))
def eval_step(model_apply, batch, params):
  """Return the inertia-tensor loss on `batch` without updating parameters."""
  return mean_squared_loss(
      model_apply(params, batch['masses'], batch['positions']),
      batch['inertia_tensor'],
  )

def train_model(key, model, train_data, valid_data, num_epochs, learning_rate, batch_size):
  """Train the inertia-tensor model with Adam, printing losses after every epoch.

  Args:
    key: JAX PRNG key used for parameter initialization and batch shuffling.
    model: Flax module, called as `model.apply(params, masses, positions)`.
    train_data: Dict with 'masses', 'positions' and 'inertia_tensor' arrays
      sharing the same leading (example) axis.
    valid_data: Dict with the same keys, evaluated as one single batch.
    num_epochs: Number of passes over the training data.
    learning_rate: Adam learning rate.
    batch_size: Examples per training step. A trailing incomplete batch is
      dropped every epoch.

  Returns:
    The trained parameters.

  Raises:
    ValueError: If `batch_size` exceeds the number of training examples, in
      which case not a single training step could run.
  """
  # Determine the number of training steps per epoch.
  train_size = len(train_data['masses'])
  steps_per_epoch = train_size//batch_size
  if steps_per_epoch == 0:
    # Without this check every epoch would skip training and report a loss of 0.
    raise ValueError(
        f"batch_size={batch_size} is larger than the training set "
        f"({train_size} examples); no training step would run."
    )

  # Initialize model parameters and optimizer state.
  key, init_key = jax.random.split(key)
  optimizer = optax.adam(learning_rate)
  params = model.init(init_key, train_data['masses'][0:1], train_data['positions'][0:1])
  opt_state = optimizer.init(params)

  # Train for 'num_epochs' epochs.
  for epoch in range(1, num_epochs + 1):
    # Draw random permutations for fetching batches from the train data.
    key, shuffle_key = jax.random.split(key)
    perms = jax.random.permutation(shuffle_key, train_size)
    perms = perms[:steps_per_epoch * batch_size]  # Skip the last batch (if incomplete).
    perms = perms.reshape((steps_per_epoch, batch_size))

    # Loop over all batches.
    train_loss = 0.0  # For keeping a running average of the loss.
    for i, perm in enumerate(perms):
      batch = {k: v[perm, ...] for k, v in train_data.items()}
      params, opt_state, loss = train_step(
          model_apply=model.apply,
          optimizer_update=optimizer.update,
          batch=batch,
          opt_state=opt_state,
          params=params
      )
      train_loss += (loss - train_loss)/(i+1)

    # Evaluate on the validation set after each training epoch.
    valid_loss = eval_step(
        model_apply=model.apply,
        batch=valid_data,
        params=params
    )

    # Print progress.
    print(f"epoch {epoch : 4d} train loss {train_loss : 8.6f} valid loss {valid_loss : 8.6f}")

  # Return final model parameters.
  return params

In [56]:
# Seed the PRNG and split off a dedicated key for dataset generation.
key = jax.random.PRNGKey(0)
key, data_key = jax.random.split(key)

# Generate train and validation datasets.
train_data, valid_data = generate_datasets(data_key)

# Training hyperparameters.
learning_rate, num_epochs, batch_size = 0.002, 100, 32

In [7]:
# NOTE(review): dead debugging cell. The prints are wrapped in a string literal,
# so nothing executes; the cell only echoes the string itself.
'''print(train_data['masses'][0:1].shape)
print(train_data['positions'][0:1].shape)'''

"print(train_data['masses'][0:1].shape)\nprint(train_data['positions'][0:1].shape)"

In [57]:
# Build the inertia-tensor model and train it with a fresh subkey.
model = Model()
key, train_key = jax.random.split(key)
params = train_model(
    key=train_key,
    model=model,
    train_data=train_data,
    valid_data=valid_data,
    num_epochs=num_epochs,
    learning_rate=learning_rate,
    batch_size=batch_size,
)

epoch    1 train loss  1.359933 valid loss  0.650806
epoch    2 train loss  0.471154 valid loss  0.361696
epoch    3 train loss  0.355795 valid loss  0.330647
epoch    4 train loss  0.335975 valid loss  0.313806
epoch    5 train loss  0.313707 valid loss  0.307905
epoch    6 train loss  0.295819 valid loss  0.261203
epoch    7 train loss  0.269274 valid loss  0.236152
epoch    8 train loss  0.247977 valid loss  0.230414
epoch    9 train loss  0.231734 valid loss  0.205375
epoch   10 train loss  0.225083 valid loss  0.209193
epoch   11 train loss  0.207602 valid loss  0.188981
epoch   12 train loss  0.200761 valid loss  0.185399
epoch   13 train loss  0.190793 valid loss  0.175384
epoch   14 train loss  0.178643 valid loss  0.169232
epoch   15 train loss  0.161587 valid loss  0.144264
epoch   16 train loss  0.147181 valid loss  0.133549
epoch   17 train loss  0.129426 valid loss  0.109603
epoch   18 train loss  0.105608 valid loss  0.088042
epoch   19 train loss  0.089911 valid loss  0.

In [53]:
# Compare target and prediction for one unbatched validation system.
i = 0
masses = valid_data['masses'][i]
positions = valid_data['positions'][i]
target = valid_data['inertia_tensor'][i]
prediction = model.apply(params, masses, positions)

print('target')
print(target)
print('prediction')
print(prediction)
print('mean squared error', jnp.mean((prediction - target)**2))

Initial shape: (10, 4)
x shape: (10, 1, 4, 1)


ScopeParamShapeError: Initializer expected to generate shape (1, 8) but got shape (1, 1) instead for parameter "kernel" in "/Dense_0/0+". (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.ScopeParamShapeError)

# Dataset 

In [22]:
filename='Si16Vplus..DFT.SP-GRD.wB97X-D.tight.Data.5042.R_E_F_D_Q.npz'
dataset= np.load(filename)
# Use a dedicated loop variable: naming it `key` would overwrite the global JAX
# PRNG key used by the training cells with a string.
for field_name in dataset.keys():
    print(field_name)
print('Dipole moment shape array',dataset['D'].shape)
print('Dipole moment units', dataset['D_units'])

# NOTE(review): the stored 'z' array prints as all ones, which is why
# prepare_datasets builds the atomic numbers itself.
print('Atomic numbers', dataset['z'])

type
R
R_units
z
E
E_units
F
F_units
D
D_units
Q
name
README
theory
Dipole moment shape array (5042, 3)
Dipole moment units eAng
Atomic numbers [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]


# Dipole Moment

In [59]:
def prepare_datasets(filename, key, num_train, num_valid, atomic_numbers=None):
    """Load the dipole dataset and draw disjoint random train/validation splits.

    Args:
      filename: Path to an ``.npz`` file containing at least ``'E'`` (one entry
        per structure), ``'R'`` (positions, shape (num_data, N, 3)) and ``'D'``
        (dipole moments, shape (num_data, 3)).
      key: JAX PRNG key used to draw the split.
      num_train: Number of training structures.
      num_valid: Number of validation structures.
      atomic_numbers: Optional length-N sequence of atomic numbers shared by
        all structures. Defaults to 16 Si (Z=14) followed by one V (Z=23). The
        file's own ``'z'`` array is not used.

    Returns:
      ``(train_data, valid_data)`` dicts with keys ``'dipole_moment'``,
      ``'atomic_numbers'`` (shape (M, N)) and ``'positions'``.

    Raises:
      RuntimeError: If ``num_train + num_valid`` exceeds the dataset size.
      ValueError: If the number of atomic numbers differs from the number of
        atoms per structure in ``'R'``.
    """
    if atomic_numbers is None:
        atomic_numbers = [14] * 16 + [23]
    species = jnp.asarray(atomic_numbers, dtype=jnp.int32)

    # Read the needed arrays, then let the context manager close the file.
    with np.load(filename) as dataset:
        num_data = len(dataset["E"])
        dipoles = dataset["D"]
        positions = dataset["R"]

    num_draw = num_train + num_valid
    if num_draw > num_data:
        raise RuntimeError(
            f"datasets only contains {num_data} points, requested num_train={num_train}, num_valid={num_valid}"
        )
    if positions.shape[-2] != species.shape[0]:
        raise ValueError(
            f"got {species.shape[0]} atomic numbers but the structures in '{filename}' "
            f"have {positions.shape[-2]} atoms"
        )

    # Randomly draw disjoint train and validation sets from dataset.
    choice = np.asarray(
        jax.random.choice(key, num_data, shape=(num_draw,), replace=False)
    )
    train_choice = choice[:num_train]
    valid_choice = choice[num_train:]

    def _collect(indices):
        # Same composition for every structure, so broadcast instead of
        # materializing a (num_data, N) matrix and indexing it.
        return dict(
            dipole_moment=jnp.asarray(dipoles[indices]),
            atomic_numbers=jnp.broadcast_to(species, (len(indices), species.shape[0])),
            positions=jnp.asarray(positions[indices]),
        )

    # Collect and return train and validation sets.
    return _collect(train_choice), _collect(valid_choice)

In [82]:
# Load 1000 training and 200 validation structures and check the label shape.
key = jax.random.PRNGKey(0)
num_train, num_val = 1000, 200
train_data, valid_data = prepare_datasets(filename, key, num_train, num_val)
print(train_data['dipole_moment'].shape)

(1000, 3)


In [88]:
class Dipole_Moment(nn.Module):
  """Equivariant e3x model that predicts a dipole moment vector.

  Each atom contributes its atomic number as a raw degree-0 feature and its
  position as a degree-1 feature. These are mixed by e3x layers and summed over
  atoms. The degree-1 block at parity index 1 of feature channel 0 is returned
  as a 3-vector.

  Attributes:
    features: Number of feature channels produced by the Dense layer.
    max_degree: Maximum irrep degree of the TensorDense layer. Must be >= 1,
      because the output is the degree-1 irrep.
  """
  # Configurable dataclass fields. The defaults reproduce the previous
  # hard-coded values, Dense(features=1) and TensorDense(max_degree=1).
  features: int = 1
  max_degree: int = 1

  @nn.compact
  def __call__(self, atomic_numbers, positions):  # Shapes (..., N) and (..., N, 3).
    if self.max_degree < 1:
      raise ValueError(f"max_degree must be >= 1 to produce a dipole vector, got {self.max_degree}")

    # 1. Initialize features.
    x = jnp.concatenate((atomic_numbers[..., None], positions), axis=-1)  # Shape (..., N, 4).
    x = x[..., None, :, None]  # Shape (..., N, 1, 4, 1).

    # 2. Apply transformations.
    x = e3x.nn.Dense(features=self.features)(x)  # Shape (..., N, 1, 4, features).
    x = e3x.nn.TensorDense(max_degree=self.max_degree)(x)  # Shape (..., N, 2, (max_degree+1)**2, features).

    # 3. Sum contributions from individual atoms.
    x = jnp.sum(x, axis=-4)  # Shape (..., 2, (max_degree+1)**2, features).

    # 4. Select parity index 1 (presumably the odd-parity / proper-vector block in
    #    e3x's convention; confirm against e3x docs), degree-1 components 1:4,
    #    and feature channel 0.
    y = x[..., 1, 1:4, 0]  # Shape (..., 3).
    return y

In [112]:
!pip list requierement

Package             Version
------------------- ----------
absl-py             2.1.0
ase                 3.23.0
asttokens           2.4.1
attrs               23.2.0
chex                0.1.85
colorama            0.4.6
comm                0.2.1
contextlib2         21.6.0
contourpy           1.2.0
cycler              0.12.1
dataclasses         0.6
debugpy             1.8.1
decorator           5.1.1
dm-haiku            0.0.12
e3nn-jax            0.20.6
e3x                 1.0.1
einops              0.7.0
etils               1.7.0
exceptiongroup      1.2.0
executing           2.0.1
flax                0.8.1
fonttools           4.49.0
fsspec              2024.2.0
importlib-metadata  7.0.1
importlib_resources 6.1.3
ipykernel           6.29.2
ipython             8.22.1
jax                 0.4.25
jax-md              0.2.8
jaxlib              0.4.25
jaxtyping           0.2.28
jedi                0.19.1
jmp                 0.0.4
jraph               0.0.6.dev0
jupyter_client      8.6.0
jupyter_cor

In [84]:
# NOTE(review): dead experiment cell. The whole snippet is a string literal, so
# nothing executes; the cell only echoes the string. If revived, note that it
# uses generate_datasets (inertia-tensor data) and passes 'masses' where
# Dipole_Moment expects atomic numbers.
'''dm_model = Dipole_Moment()
key = jax.random.PRNGKey(0)

# Generate train and test datasets.
key, data_key = jax.random.split(key)
train_data, valid_data = generate_datasets(data_key)
params = dm_model.init(key, train_data['masses'][0:1], train_data['positions'][0:1])
moment=dm_model.apply(params,train_data['masses'][0:1], train_data['positions'][0:1])
print(moment.shape)'''

"dm_model = Dipole_Moment()\nkey = jax.random.PRNGKey(0)\n\n# Generate train and test datasets.\nkey, data_key = jax.random.split(key)\ntrain_data, valid_data = generate_datasets(data_key)\nparams = dm_model.init(key, train_data['masses'][0:1], train_data['positions'][0:1])\nmoment=dm_model.apply(params,train_data['masses'][0:1], train_data['positions'][0:1])\nprint(moment.shape)"

In [89]:
@functools.partial(jax.jit, static_argnames=('model_apply', 'optimizer_update'))
def train_step(model_apply, optimizer_update, batch, opt_state, params):
  """Run one jitted optimizer step of the dipole-moment model on `batch`.

  This redefines, and shadows, the inertia-tensor `train_step` defined earlier
  in the notebook. Returns (params, opt_state, loss). The loss is evaluated at
  the parameters from before the update.
  """
  def objective(p):
    predicted_dipole = model_apply(p, batch['atomic_numbers'], batch['positions'])
    return mean_squared_loss(predicted_dipole, batch['dipole_moment'])

  loss, grads = jax.value_and_grad(objective)(params)
  updates, new_opt_state = optimizer_update(grads, opt_state, params)
  return optax.apply_updates(params, updates), new_opt_state, loss

@functools.partial(jax.jit, static_argnames=('model_apply',))
def eval_step(model_apply, batch, params):
  """Return the dipole-moment loss on `batch` without updating parameters.

  This redefines, and shadows, the inertia-tensor `eval_step` defined earlier.
  """
  return mean_squared_loss(
      model_apply(params, batch['atomic_numbers'], batch['positions']),
      batch['dipole_moment'],
  )

def train_model(key, model, train_data, valid_data, num_epochs, learning_rate, batch_size):
  """Train the dipole-moment model with Adam, printing losses after every epoch.

  This redefines, and shadows, the inertia-tensor `train_model` defined earlier.

  Args:
    key: JAX PRNG key used for parameter initialization and batch shuffling.
    model: Flax module, called as `model.apply(params, atomic_numbers, positions)`.
    train_data: Dict with 'atomic_numbers', 'positions' and 'dipole_moment'
      arrays sharing the same leading (example) axis.
    valid_data: Dict with the same keys, evaluated as one single batch.
    num_epochs: Number of passes over the training data.
    learning_rate: Adam learning rate.
    batch_size: Examples per training step. A trailing incomplete batch is
      dropped every epoch.

  Returns:
    The trained parameters.

  Raises:
    ValueError: If `batch_size` exceeds the number of training examples, in
      which case not a single training step could run.
  """
  # Determine the number of training steps per epoch.
  train_size = len(train_data['positions'])
  steps_per_epoch = train_size//batch_size
  if steps_per_epoch == 0:
    # Without this check every epoch would skip training and report a loss of 0.
    raise ValueError(
        f"batch_size={batch_size} is larger than the training set "
        f"({train_size} examples); no training step would run."
    )

  # Initialize model parameters and optimizer state.
  key, init_key = jax.random.split(key)
  optimizer = optax.adam(learning_rate)
  params = model.init(init_key,train_data['atomic_numbers'][0:1],train_data['positions'][0:1])
  opt_state = optimizer.init(params)

  # Train for 'num_epochs' epochs.
  for epoch in range(1, num_epochs + 1):
    # Draw random permutations for fetching batches from the train data.
    key, shuffle_key = jax.random.split(key)
    perms = jax.random.permutation(shuffle_key, train_size)
    perms = perms[:steps_per_epoch * batch_size]  # Skip the last batch (if incomplete).
    perms = perms.reshape((steps_per_epoch, batch_size))

    # Loop over all batches.
    train_loss = 0.0  # For keeping a running average of the loss.
    for i, perm in enumerate(perms):
      batch = {k: v[perm, ...] for k, v in train_data.items()}

      params, opt_state, loss = train_step(
          model_apply=model.apply,
          optimizer_update=optimizer.update,
          batch=batch,
          opt_state=opt_state,
          params=params
      )
      train_loss += (loss - train_loss)/(i+1)

    # Evaluate on the validation set after each training epoch.
    valid_loss = eval_step(
        model_apply=model.apply,
        batch=valid_data,
        params=params
    )

    # Print progress.
    print(f"epoch {epoch : 4d} train loss {train_loss : 8.6f} valid loss {valid_loss : 8.6f}")

  # Return final model parameters.
  return params

In [105]:
# Initialize PRNGKey for random number generation.
key = jax.random.PRNGKey(0)
num_train=3000
num_val=200
train_data, valid_data = prepare_datasets(filename,key, num_train,num_val)

# Define training hyperparameters.
learning_rate = 0.002
num_epochs = 100
batch_size = 32

In [106]:
# NOTE(review): dead debugging cell. The prints are wrapped in a string literal,
# so nothing executes; the cell only echoes the string itself.
'''print(train_data['positions'][0:1].shape)
print(train_data['atomic_numbers'][0:1].shape)'''

"print(train_data['positions'][0:1].shape)\nprint(train_data['atomic_numbers'][0:1].shape)"

In [107]:
# Build the dipole-moment model and train it with a fresh subkey.
model = Dipole_Moment()
key, train_key = jax.random.split(key)
params = train_model(
    key=train_key,
    model=model,
    train_data=train_data,
    valid_data=valid_data,
    num_epochs=num_epochs,
    learning_rate=learning_rate,
    batch_size=batch_size,
)

epoch    1 train loss  7.643333 valid loss  1.240807
epoch    2 train loss  0.617081 valid loss  0.194544
epoch    3 train loss  0.061019 valid loss  0.006598
epoch    4 train loss  0.004027 valid loss  0.004015
epoch    5 train loss  0.003565 valid loss  0.004022
epoch    6 train loss  0.003572 valid loss  0.004016
epoch    7 train loss  0.003574 valid loss  0.004019
epoch    8 train loss  0.003568 valid loss  0.004015
epoch    9 train loss  0.003571 valid loss  0.004017
epoch   10 train loss  0.003567 valid loss  0.004017
epoch   11 train loss  0.003569 valid loss  0.004017
epoch   12 train loss  0.003562 valid loss  0.004013
epoch   13 train loss  0.003564 valid loss  0.004021
epoch   14 train loss  0.003562 valid loss  0.004007
epoch   15 train loss  0.003570 valid loss  0.004006
epoch   16 train loss  0.003559 valid loss  0.004008
epoch   17 train loss  0.003564 valid loss  0.004005
epoch   18 train loss  0.003560 valid loss  0.004003
epoch   19 train loss  0.003559 valid loss  0.

In [111]:
# Compare target and predicted dipole for one unbatched validation structure.
i = 45
Z = valid_data['atomic_numbers'][i]
positions = valid_data['positions'][i]
target = valid_data['dipole_moment'][i]
prediction = model.apply(params, Z, positions)

print('target')
print(target)
print('prediction')
print(prediction)
print('mean squared error', jnp.mean((prediction - target)**2))

target
[ 1.5010834  -0.35230795  1.6916788 ]
prediction
[ 1.4647644  -0.30743074  1.7500215 ]
mean squared error 0.0022456348
