In [1]:
import functools
import e3x
from flax import linen as nn
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax

# Disable future warnings.
import warnings

warnings.simplefilter(action="ignore", category=FutureWarning)
jax.devices()

[CpuDevice(id=0)]

In [38]:
class MP_Dipole_Moment(nn.Module):
    features: int = 32
    max_degree: int = 2
    num_iterations: int = 3
    num_basis_functions: int = 8
    cutoff: float = 5.0
    max_atomic_number: int = 118  # This is overkill for most applications.

    def dipole_moment(
        self, atomic_numbers, positions, dst_idx, src_idx, batch_segments, batch_size
    ):
        # 1. Calculate displacement vectors.
        print(("atomic_numbers", atomic_numbers))
        positions_dst = e3x.ops.gather_dst(positions, dst_idx=dst_idx)
        positions_src = e3x.ops.gather_src(positions, src_idx=src_idx)
        displacements = positions_src - positions_dst  # Shape (num_pairs, 3).

        # 2. Expand displacement vectors in basis functions.
        basis = e3x.nn.basis(  # Shape (num_pairs, 1, (max_degree+1)**2, num_basis_functions).
            displacements,
            num=self.num_basis_functions,
            max_degree=self.max_degree,
            radial_fn=e3x.nn.reciprocal_bernstein,
            cutoff_fn=functools.partial(e3x.nn.smooth_cutoff, cutoff=self.cutoff),
        )

        # 3. Embed atomic numbers in feature space, x has shape (num_atoms, 1, 1, features).
        x = e3x.nn.Embed(
            num_embeddings=self.max_atomic_number + 1, features=self.features
        )(atomic_numbers)
        # print('Embed',x.shape)
        # print('Basis',basis.shape)

        # 4. Perform iterations (message-passing + atom-wise refinement).
        for i in range(self.num_iterations):
            # Message-pass.
            if i == self.num_iterations - 1:  # Final iteration.
                # Since we will only use scalar features after the final message-pass, we do not want to produce non-scalar
                # features for efficiency reasons.
                y = e3x.nn.MessagePass(max_degree=2, include_pseudotensors=False)(
                    x, basis, dst_idx=dst_idx, src_idx=src_idx
                )
                # print('Final',y.shape)
                # After the final message pass, we can safely throw away all non-scalar features.
                x = e3x.nn.change_max_degree_or_type(
                    x, max_degree=2, include_pseudotensors=False
                )
            else:
                # In intermediate iterations, the message-pass should consider all possible coupling paths.
                print(x.shape, basis.shape, "intermediate iterations,")
                y = e3x.nn.MessagePass()(x, basis, dst_idx=dst_idx, src_idx=src_idx)
                # print('Message',y.shape)
            y = e3x.nn.add(x, y)

            # Atom-wise refinement MLP.
            y = e3x.nn.Dense(self.features)(y)
            y = e3x.nn.silu(y)
            y = e3x.nn.Dense(self.features, kernel_init=jax.nn.initializers.zeros)(y)

            # Residual connection.
            x = e3x.nn.add(x, y)
            # print('Residual',x.shape)

            # 5. Predict atomic energies with an ordinary dense layer.
            # element_bias = self.param(
            #    "element_bias",
            #    lambda rng, shape: jnp.zeros(shape),
            #    (self.max_atomic_number + 1),
            # )

        x = nn.Dense(1, use_bias=False, kernel_init=jax.nn.initializers.zeros)(
            x
        )  # (..., Natoms, 1, 9, 1)
        print("After dense:", x.shape)
        element_bias = self.param(
            "element_bias",
            lambda rng, shape: jnp.zeros(shape),
            (self.max_atomic_number + 1),
        )
        print('element_bias',element_bias[atomic_numbers].shape)
        bias= element_bias[atomic_numbers]
        x += bias[:,None,None,None]
        print(x.shape, ' after bias ')
        x = jax.ops.segment_sum(
            x, segment_ids=batch_segments, num_segments=batch_size
        )
        print("After segment_sum:", x.shape)
        #x = jnp.sum(x, axis=1)
        x=jnp.squeeze(x, axis=0)
        print("After sum:", x.shape)
        x = x[..., 1:4, 0]
        # x = x[..., :3]
        # x = jnp.squeeze(x)
        print('After slicing:' ,x.shape)
        # x = jnp.sum(x, axis=1)

        # x = x[:, 1:4]

        print("Forma final:", x.shape)
        return x

    @nn.compact
    def __call__(
        self,
        atomic_numbers,
        positions,
        dst_idx,
        src_idx,
        batch_segments=None,
        batch_size=None,
    ):
        if batch_segments is None:
            batch_segments = jnp.zeros_like(atomic_numbers)
            batch_size = 1
            print("pase", batch_segments, atomic_numbers)
            # Since we want to also predict forces, i.e. the gradient of the energy w.r.t. positions (argument 1), we use
            # jax.value_and_grad to create a function for predicting both energy and forces for us.
        print(batch_segments.shape, "batch", batch_size)
        dipole = self.dipole_moment(
            atomic_numbers, positions, dst_idx, src_idx, batch_segments, batch_size
        )
        #print(dipole)

        return dipole

In [39]:
def mean_squared_loss(dipole_prediction, dipole_target):
    return jnp.mean(optax.l2_loss(dipole_prediction, dipole_target))

In [40]:
def prepare_batches(key, data, batch_size):
    # Determine the number of training steps per epoch.
    data_size = len(data["dipole_moment"])
    steps_per_epoch = data_size // batch_size

    # Draw random permutations for fetching batches from the train data.
    perms = jax.random.permutation(key, data_size)
    perms = perms[
        : steps_per_epoch * batch_size
    ]  # Skip the last batch (if incomplete).
    perms = perms.reshape((steps_per_epoch, batch_size))

    # Prepare entries that are identical for each batch.
    num_atoms = len(data["atomic_numbers"])
    batch_segments = jnp.repeat(jnp.arange(batch_size), num_atoms)
    atomic_numbers = jnp.tile(data["atomic_numbers"], batch_size)
    offsets = jnp.arange(batch_size) * num_atoms
    dst_idx, src_idx = e3x.ops.sparse_pairwise_indices(num_atoms)
    dst_idx = (dst_idx + offsets[:, None]).reshape(-1)
    src_idx = (src_idx + offsets[:, None]).reshape(-1)

    # Assemble and return batches.
    return [
        dict(
            dipole_moment=data["dipole_moment"][perm].reshape(-1, 3),
            atomic_numbers=atomic_numbers,
            positions=data["positions"][perm].reshape(-1, 3),
            dst_idx=dst_idx,
            src_idx=src_idx,
            batch_segments=batch_segments,
        )
        for perm in perms
    ]

In [41]:
@functools.partial(
    jax.jit, static_argnames=("model_apply", "optimizer_update", "batch_size")
)
def train_step(model_apply, optimizer_update, batch, batch_size, opt_state, params):
    def loss_fn(params):
        dipole = model_apply(
            params,
            atomic_numbers=batch["atomic_numbers"],
            positions=batch["positions"],
            dst_idx=batch["dst_idx"],
            src_idx=batch["src_idx"],
            batch_segments=batch["batch_segments"],
            batch_size=batch_size,
        )
        loss = mean_squared_loss(
            dipole_prediction=dipole, dipole_target=batch["dipole_moment"]
        )
        return loss

    loss, grad = jax.value_and_grad(loss_fn)(params)
    updates, opt_state = optimizer_update(grad, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss


@functools.partial(jax.jit, static_argnames=("model_apply", "batch_size"))
def eval_step(model_apply, batch, batch_size, params):
    dipole = model_apply(
        params,
        atomic_numbers=batch["atomic_numbers"],
        positions=batch["positions"],
        dst_idx=batch["dst_idx"],
        src_idx=batch["src_idx"],
        batch_segments=batch["batch_segments"],
        batch_size=batch_size,
    )
    print('dipole_prediction',dipole[0])
    print('dipole_target',batch["dipole_moment"][0])
    loss = mean_squared_loss(
        dipole_prediction=dipole, dipole_target=batch["dipole_moment"]
    )
    return loss


def train_model(
    key, model, train_data, valid_data, num_epochs, learning_rate, batch_size
):
    # Initialize model parameters and optimizer state.
    key, init_key = jax.random.split(key)
    optimizer = optax.adam(learning_rate)
    dst_idx, src_idx = e3x.ops.sparse_pairwise_indices(
        len(train_data["atomic_numbers"])
    )
    params = model.init(
        init_key,
        atomic_numbers=train_data["atomic_numbers"],
        positions=train_data["positions"][0],
        dst_idx=dst_idx,
        src_idx=src_idx,
    )
    opt_state = optimizer.init(params)

    # Batches for the validation set need to be prepared only once.
    key, shuffle_key = jax.random.split(key)
    valid_batches = prepare_batches(shuffle_key, valid_data, batch_size)

    # Train for 'num_epochs' epochs.
    for epoch in range(1, num_epochs + 1):
        # Prepare batches.
        key, shuffle_key = jax.random.split(key)
        train_batches = prepare_batches(shuffle_key, train_data, batch_size)

        # Loop over train batches.
        train_loss = 0.0
        for i, batch in enumerate(train_batches):
            
            params, opt_state, loss = train_step(
                model_apply=model.apply,
                optimizer_update=optimizer.update,
                batch=batch,
                batch_size=batch_size,
                opt_state=opt_state,
                params=params,
            )
            train_loss += (loss - train_loss) / (i + 1)

        # Evaluate on validation set.
        valid_loss = 0.0
        for i, batch in enumerate(valid_batches):
            loss = eval_step(
                model_apply=model.apply,
                batch=batch,
                batch_size=batch_size,
                params=params,
            )
            valid_loss += (loss - valid_loss) / (i + 1)

        # Print progress.
        print(f"epoch: {epoch: 3d}                    train:   valid:")
        print(f"    loss [a.u.]             {train_loss : 8.6f} {valid_loss : 8.3f}")

    # Return final model parameters.
    return params

In [42]:
def prepare_datasets(filename, key, num_train, num_valid):
    # Load the dataset.
    dataset = np.load(filename)
    num_data = len(dataset["E"])
    Z = jnp.full(16, 14)
    Z = jnp.append(Z, 23)
    Z = jnp.expand_dims(Z, axis=0)
    Z = jnp.repeat(Z, num_data, axis=0)
    num_draw = num_train + num_valid
    if num_draw > num_data:
        raise RuntimeError(
            f"datasets only contains {num_data} points, requested num_train={num_train}, num_valid={num_valid}"
        )

    # Randomly draw train and validation sets from dataset.
    choice = np.asarray(
        jax.random.choice(key, num_data, shape=(num_draw,), replace=False)
    )
    train_choice = choice[:num_train]
    valid_choice = choice[num_train:]

    # Collect and return train and validation sets.
    train_data = dict(
        # energy=jnp.asarray(dataset["E"][train_choice, 0] - mean_energy),
        # forces=jnp.asarray(dataset["F"][train_choice]),
        dipole_moment=jnp.asarray(dataset["D"][train_choice]),
        #atomic_numbers=jnp.asarray(Z[train_choice]),
        atomic_numbers=jnp.asarray(dataset["z"]),
        # atomic_numbers=jnp.asarray(z_hack),
        positions=jnp.asarray(dataset["R"][train_choice]),
    )
    valid_data = dict(
        # energy=jnp.asarray(dataset["E"][valid_choice, 0] - mean_energy),
        # forces=jnp.asarray(dataset["F"][valid_choice]),
        #atomic_numbers=jnp.asarray(Z[valid_choice]),
        dipole_moment=jnp.asarray(dataset["D"][valid_choice]),
        # atomic_numbers=jnp.asarray(z_hack),
        atomic_numbers=jnp.asarray(dataset["z"]),
        positions=jnp.asarray(dataset["R"][valid_choice]),
    )
    return train_data, valid_data

In [43]:
filename = "test_data.npz"
dataset = np.load(filename)
for key in dataset.keys():
    print(key)
print("Dipole moment shape array", dataset["D"].shape)
print("Dipole moment units", dataset["D_units"])

print("Atomic numbers", dataset["z"])

type
R
R_units
z
E
E_units
F
F_units
D
D_units
Q
name
README
theory
Dipole moment shape array (5042, 3)
Dipole moment units eAng
Atomic numbers [23 14 14 14 14 14 14 14 14 14 14 14 14 14 14 14 14]


In [47]:
key = jax.random.PRNGKey(0)
num_train = 100
num_val = 20
# Define training hyperparameters.
learning_rate = 0.001
num_epochs = 100
batch_size = 1

In [48]:
# Model hyperparameters.
features = 32
max_degree = 2
num_iterations = 3
num_basis_functions = 16
cutoff = 6.0
max_atomic_number = 23

In [49]:
train_data, valid_data = prepare_datasets(filename, key, num_train, num_val)
key, train_key = jax.random.split(key)
key = jax.random.PRNGKey(0)
model = MP_Dipole_Moment(
    features=features,
    max_degree=max_degree,
    num_iterations=num_iterations,
    num_basis_functions=num_basis_functions,
    cutoff=cutoff,
    max_atomic_number=max_atomic_number
)
params = train_model(
    key=train_key,
    model=model,
    train_data=train_data,
    valid_data=valid_data,
    num_epochs=num_epochs,
    learning_rate=learning_rate,
    batch_size=batch_size,
)

pase [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0] [23 14 14 14 14 14 14 14 14 14 14 14 14 14 14 14 14]
(17,) batch 1
('atomic_numbers', Array([23, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14],      dtype=int32))
(17, 1, 1, 32) (272, 1, 9, 16) intermediate iterations,
(17, 1, 9, 32) (272, 1, 9, 16) intermediate iterations,
After dense: (17, 1, 9, 1)
element_bias (17,)
(17, 1, 9, 1)  after bias 
After segment_sum: (1, 1, 9, 1)
After sum: (1, 9, 1)
After slicing: (1, 3)
Forma final: (1, 3)
(17,) batch 1
('atomic_numbers', Traced<ShapedArray(int32[17])>with<DynamicJaxprTrace(level=1/0)>)
(17, 1, 1, 32) (272, 1, 9, 16) intermediate iterations,
(17, 1, 9, 32) (272, 1, 9, 16) intermediate iterations,
After dense: (17, 1, 9, 1)
element_bias (17,)
(17, 1, 9, 1)  after bias 
After segment_sum: (1, 1, 9, 1)
After sum: (1, 9, 1)
After slicing: (1, 3)
Forma final: (1, 3)
(17,) batch 1
('atomic_numbers', Traced<ShapedArray(int32[17])>with<DynamicJaxprTrace(level=1/0)>)
(17, 1, 1, 32) (272,

In [50]:
params['params'].keys()

dict_keys(['Dense_0', 'Dense_1', 'Dense_2', 'Dense_3', 'Embed_0', 'MessagePass_0', 'MessagePass_1', 'MessagePass_2', 'element_bias'])

In [52]:
print(params['params']['element_bias'])
print(params['params']['Dense_0'])
print(params['params']['Dense_1'])
print(params['params']['Dense_2'])
print(params['params']['Dense_3'])

[0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.05393197 0.         0.         0.
 0.         0.         0.         0.         0.         0.05393196]
{'0+': {'bias': Array([-2.53218375e-02, -9.53917354e-02,  1.57667831e-01, -7.80547559e-02,
        1.36273623e-01,  1.95839822e-01,  2.66463216e-02,  1.13846228e-01,
        6.53901175e-02, -5.57711311e-02,  2.96855648e-03, -4.23120148e-02,
        1.06172115e-01, -8.18778127e-02,  1.87666953e-01, -4.21365499e-02,
       -1.36071682e-01,  5.61018325e-02,  8.82903785e-02,  4.70986729e-03,
       -2.60044456e-01,  3.94901708e-02,  1.52000010e-01,  3.76045220e-02,
        5.37868366e-02, -1.42639047e-02, -5.71135012e-03,  7.24664715e-05,
       -6.48561791e-02,  7.30109513e-02, -1.40815064e-01,  9.64180157e-02],      dtype=float32), 'kernel': Array([[-0.0689008 ,  0.22902995,  0.46396974, ..., -0.14538912,
         0.01942987,  0.41012758],
       

In [None]:
valid_data["atomic_numbers"][0]

Array(23, dtype=int32)

In [61]:
i=4
dst_idx, src_idx = e3x.ops.sparse_pairwise_indices(17)

dm=model.apply(params,
    atomic_numbers=valid_data["atomic_numbers"],
    positions=valid_data["positions"][i],
    dst_idx=dst_idx,
    src_idx=src_idx)

print('dipole moment prediction',dm)
print('dipole moment target',valid_data['dipole_moment'][i])

pase [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0] [23 14 14 14 14 14 14 14 14 14 14 14 14 14 14 14 14]
(17,) batch 1
('atomic_numbers', Array([23, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14],      dtype=int32))
(17, 1, 1, 32) (272, 1, 9, 16) intermediate iterations,
(17, 1, 9, 32) (272, 1, 9, 16) intermediate iterations,
After dense: (17, 1, 9, 1)
element_bias (17,)
(17, 1, 9, 1)  after bias 
After segment_sum: (1, 1, 9, 1)
After sum: (1, 9, 1)
After slicing: (1, 3)
Forma final: (1, 3)
dipole moment prediction [[1.3089142  0.4336555  0.46206573]]
dipole moment target [ 1.4130961  -0.46504444  1.7085109 ]
