In [66]:
# this notebook meant to work for batched samples

In [67]:
import jax
import flax
import jax.numpy as jnp
from flax import linen as nn
from flax import nnx
from typing import Dict

In [None]:

# Atom Embedding
# It takes the number of different atoms in the molecular systems.
# For each atom, create an atom embedding.

@jax.jit
def count_unique(x):
  x = jnp.sort(x)
  return 1 + (x[1:] != x[:-1]).sum()

class AtomEmbedding(nnx.Module):
    def __init__(self, num_atom_types, embedding_dim, n_batch, rngs: nnx.Rngs = nnx.Rngs(0)):
        super(AtomEmbedding, self).__init__()

        self.embedding = nnx.Embed(
            num_embeddings=num_atom_types,
            features=embedding_dim,
            rngs=rngs
        )
        self.n_batch = n_batch

    def __call__(self, atom_types):
        ord_atom_types = jnp.unique(atom_types)
        mask = jnp.array([jnp.where(ord_atom_types==elem) for elem in atom_types.flatten()]).flatten()


        return jnp.array(jnp.split(self.embedding(mask), self.n_batch))

In [267]:
# Given a positions thing, r = (r_1, ..., r_n), returns
# a matrix with the relative distance among them
# (d_ij)_ij = |r_i - r_j|^2

class R_distances(nnx.Module):
    def __init__(self):
        super().__init__()

    def __call__(self, R):
        num_atoms = len(R[0])
        Rij = jnp.array([[[r[i] - r[j] for j in range(num_atoms)] for i in range(num_atoms)] for r in R])
        d_ij = jnp.linalg.norm(Rij, axis=-1)
        return d_ij

In [None]:
class RadialBasisFunctions(nnx.Module):
    def __init__(self, rbf_min, rbf_max, n_rbf, gamma=10):
        super().__init__()
        self.rbf_min = rbf_min
        self.rbf_max = rbf_max
        self.n_rbf = n_rbf
        self.gamma = gamma
        self.centers = jnp.linspace(rbf_min, rbf_max, n_rbf).reshape(1, -1)

    def __call__(self, d_ij):
        diff = d_ij[..., None] - self.centers
        return jnp.exp(-self.gamma * jnp.pow(diff, 2))

In [269]:
r = jnp.array([[[1., 2., 3.], [4., 5., 6.]],[[1., 2., 3.], [4., 5., 6.]],[[1., 2., 3.], [4., 5., 6.]]])
z = jnp.array([[1,8], [1,8], [1,8]])

rbf_min = 0.
rbf_max = 10.
n_rbf = 30
num_atom_types = 2
embedding_dim=64
n_batch = 10
rbf = RadialBasisFunctions(rbf_min, rbf_max, n_rbf)
distances = R_distances()
emb = AtomEmbedding(num_atom_types=num_atom_types, embedding_dim=embedding_dim, n_batch=n_batch)

In [275]:
z = batch[0]
r = batch[1]
z.shape
emb = AtomEmbedding(num_atom_types=num_atom_types, embedding_dim=embedding_dim, n_batch=n_batch)

In [276]:
at_em = emb(z)
dis = distances(r)
rbfs = rbf(dis)

In [277]:
at_em.shape

(10, 3, 64)

In [278]:
dis.shape

(10, 3, 3)

In [279]:
rbfs.shape

(10, 3, 3, 30)

In [None]:
# ReLU function

class relu_layer(nnx.Module):
    def __init__(self):
        super().__init__()

    def __call__(self, x):
        return nnx.relu(x)

# Shift Softplus layer

class ssp_layer(nnx.Module):
    def __init__(self):
        super().__init__()

    def __call__(self, x):
        return jnp.log(0.5 * jnp.exp(x) + 0.5)

In [None]:
class filter_generator(nnx.Module):
    def __init__(
            self,
            atom_embeddings_dim,
            rbf_min,
            rbf_max,
            n_rbf,
            rngs: nnx.Rngs = nnx.Rngs(0),
            activation=ssp_layer
            ):

        super().__init__()
        self.rbf = RadialBasisFunctions(rbf_min, rbf_max, n_rbf)
        self.w_layers = nnx.Sequential(
            nnx.Linear(n_rbf, atom_embeddings_dim, rngs=rngs),
            activation(),
            nnx.Linear(atom_embeddings_dim, atom_embeddings_dim, rngs=rngs),
            activation()
        )

    def __call__(self, d_ij):
        rbfs = self.rbf(d_ij)
        Wij = self.w_layers(rbfs)
        return Wij


In [282]:
fil_gen = filter_generator(atom_embeddings_dim=embedding_dim, rbf_min=rbf_min, rbf_max=rbf_max, n_rbf=n_rbf)

In [283]:
filters = fil_gen(dis); filters.shape

(10, 3, 3, 64)

In [None]:
class CfConv(nnx.Module):
    def __init__(
            self,
            atom_embeddings_dim,
            rbf_min,
            rbf_max,
            n_rbf,
            rngs: nnx.Rngs = nnx.Rngs(0),
            activation=ssp_layer):
        super().__init__()

        self.rbf = RadialBasisFunctions(rbf_min, rbf_max, n_rbf)
        self.filters = filter_generator(
            atom_embeddings_dim,
            rbf_min,
            rbf_max,
            n_rbf,
            rngs,
            activation
            )

    def __call__(self, X, d_ij):
        fij = jnp.sum(self.filters(d_ij), axis = 1)
        X_ij = X * fij
        return X + jnp.sum(X_ij, axis=0)


In [285]:
conv = CfConv(embedding_dim, rbf_min, rbf_max, n_rbf)

In [286]:
conv(at_em, dis).shape

(10, 3, 64)

In [None]:
class InteractionBlock(nnx.Module):

    def __init__(
            self,
            atom_embeddings_dim,
            rbf_min,
            rbf_max,
            n_rbf,
            rngs: nnx.Rngs,
            activation=ssp_layer,
    ):
        super().__init__()
        self.in_atom_wise = nnx.Linear(
            atom_embeddings_dim,
            atom_embeddings_dim,
            rngs=rngs
        )

        self.cf_conv = CfConv(
            atom_embeddings_dim,
            rbf_min,
            rbf_max,
            n_rbf,
            rngs=rngs,
            activation=activation
        )

        self.out_atom_wise = nnx.Sequential(
            nnx.Linear(atom_embeddings_dim, atom_embeddings_dim, rngs=rngs),
            activation(),
            nnx.Linear(atom_embeddings_dim, atom_embeddings_dim, rngs=rngs)
        )

    def __call__(self, X, R_distances):
        X_in = self.in_atom_wise(X)
        X_conv = self.cf_conv(X_in, R_distances)
        V = self.out_atom_wise(X_conv)
        return X + V


In [None]:
class SchNet(nnx.Module):
    def __init__(
            self,
            n_batch,
            atom_embedding_dim=64,
            n_interactions=3,
            n_atom_types=3,
            rbf_min=0.,
            rbf_max=30.,
            n_rbf=300,
            rngs: nnx.Rngs = nnx.Rngs(0),
            activation: nnx.Module = ssp_layer
    ):
        super().__init__()
        self.n_atom_types = n_atom_types
        self.embedding = AtomEmbedding(n_atom_types, atom_embedding_dim, rngs=rngs, n_batch=n_batch)

        self.interactions = [
            InteractionBlock(
                atom_embedding_dim, rbf_min, rbf_max, n_rbf, rngs, activation
            )
            for _ in range(n_interactions)
        ]

        self.output_layers = nnx.Sequential(
            nnx.Linear(atom_embedding_dim, 32, rngs=rngs),
            activation(),
            nnx.Linear(32, 1, rngs=rngs)
        )

        self.distances = R_distances()

    def __call__(self, batch):
        Z = batch[0]
        R = batch[1]
        R_distances = self.distances(R)
        X = self.embedding(Z)
        X_interacted = X
        for _, interaction in enumerate(self.interactions):
            X_interacted = interaction(X_interacted, R_distances)

        atom_outputs = self.output_layers(X_interacted)
        predicted_energies = jnp.sum(atom_outputs, axis=1)
        return predicted_energies


In [289]:
import jax_dataloader as jdl
import numpy as np

def create_water_dataset_from_npz(data_name: str):
    water_data = np.load(data_name)
    energy = water_data["E"]
    n_data = len(energy)
    atom_types = jnp.array([water_data["z"] for _ in range(n_data)])
    positions = water_data["R"]
    return jdl.ArrayDataset(atom_types, positions, energy)


# Create a dataset with multiple identical water molecules
dataset = create_water_dataset_from_npz("data_water.npz")
dataloader = jdl.DataLoader(dataset, 'jax', batch_size=10, shuffle=True)
#loader = jdl.DataLoader(dataset=dataset, batch_size=10, backend='jax', shuffle=True)


In [295]:
import optax

learning_rate = 0.005
momentum = 0.9
n_batch = 10

dataloader = jdl.DataLoader(dataset, 'jax', batch_size=n_batch, shuffle=True)
model = SchNet(n_batch=10, n_atom_types=3)
optimizer = nnx.Optimizer(model, optax.adam(learning_rate))
metrics = nnx.MultiMetric(
    accuracy = nnx.metric.Accuracy(),
    loss = nnx.metrics.Average("loss"),
)

nnx.display(optimizer)

ValueError: Arrays leaves are not supported, at 'interactions/0/cf_conv/filters/rbf/centers': [[ 0.          0.10033445  0.2006689   0.30100334  0.4013378   0.50167227
   0.6020067   0.70234114  0.8026756   0.9030101   1.0033445   1.103679
   1.2040133   1.3043479   1.4046823   1.5050168   1.6053512   1.7056856
   1.8060201   1.9063545   2.006689    2.1070235   2.207358    2.3076923
   2.4080267   2.5083613   2.6086957   2.7090302   2.8093646   2.909699
   3.0100336   3.110368    3.2107024   3.3110368   3.4113712   3.5117059
   3.6120403   3.7123747   3.812709    3.9130435   4.013378    4.1137123
   4.214047    4.3143816   4.414716    4.5150504   4.6153846   4.715719
   4.8160534   4.916388    5.0167227   5.117057    5.2173915   5.3177257
   5.4180603   5.518395    5.618729    5.7190638   5.819398    5.9197326
   6.020067    6.1204014   6.220736    6.32107     6.421405    6.5217395
   6.6220737   6.7224083   6.8227425   6.923077    7.0234118   7.123746
   7.2240806   7.3244147   7.4247494   7.525084    7.625418    7.725753
   7.826087    7.9264216   8.026756    8.12709     8.227425    8.32776
   8.428094    8.528428    8.628763    8.729097    8.829432    8.929766
   9.030101    9.130435    9.230769    9.331104    9.431438    9.531773
   9.632107    9.732442    9.832776    9.93311    10.033445   10.13378
  10.234114   10.334449   10.434783   10.535117   10.635451   10.735786
  10.836121   10.936455   11.03679    11.137124   11.237458   11.337793
  11.4381275  11.538462   11.638796   11.739131   11.839465   11.939799
  12.040134   12.140469   12.240803   12.341138   12.441472   12.541806
  12.64214    12.7424755  12.84281    12.943144   13.043479   13.143813
  13.244147   13.344481   13.444817   13.545151   13.645485   13.74582
  13.846154   13.946488   14.0468235  14.147158   14.247492   14.347826
  14.448161   14.548495   14.648829   14.749165   14.849499   14.949833
  15.050168   15.150502   15.250836   15.351171   15.451506   15.55184
  15.652174   15.752509   15.852843   15.953177   16.053513   16.153847
  16.25418    16.354515   16.45485    16.555183   16.65552    16.755854
  16.856188   16.956522   17.056856   17.15719    17.257526   17.35786
  17.458195   17.558529   17.658863   17.759197   17.859531   17.959867
  18.060202   18.160536   18.26087    18.361204   18.461538   18.561872
  18.662209   18.762543   18.862877   18.963211   19.063545   19.16388
  19.264214   19.36455    19.464884   19.565218   19.665552   19.765886
  19.86622    19.966557   20.06689    20.167225   20.26756    20.367893
  20.468227   20.568562   20.668898   20.769232   20.869566   20.9699
  21.070234   21.170568   21.270903   21.371239   21.471573   21.571907
  21.672241   21.772575   21.87291    21.973246   22.07358    22.173914
  22.274248   22.374582   22.474916   22.57525    22.675587   22.77592
  22.876255   22.97659    23.076923   23.177258   23.277592   23.377928
  23.478262   23.578596   23.67893    23.779264   23.879599   23.979933
  24.080269   24.180603   24.280937   24.381271   24.481606   24.58194
  24.682276   24.78261    24.882944   24.983278   25.083612   25.183947
  25.28428    25.384617   25.484951   25.585285   25.68562    25.785954
  25.886288   25.986622   26.086958   26.187292   26.287626   26.38796
  26.488295   26.588629   26.688963   26.789299   26.889633   26.989967
  27.090302   27.190636   27.29097    27.391306   27.49164    27.591974
  27.692308   27.792643   27.892977   27.99331    28.093647   28.193981
  28.294315   28.39465    28.494984   28.595318   28.695652   28.795988
  28.896322   28.996656   29.09699    29.197325   29.297659   29.397993
  29.49833    29.598663   29.698997   29.799332   29.899666   30.        ]]

In [291]:
model = SchNet(n_batch=10, n_atom_types=3)

In [296]:
flax.__version__

'0.10.0'

In [297]:
optax.__version__

'0.2.3'

In [292]:
energy = model(batch)

In [293]:
energy

Array([[-0.16561924],
       [-0.16561924],
       [-0.16561924],
       [-0.16561924],
       [-0.16561924],
       [-0.16561924],
       [-0.16561924],
       [-0.16561924],
       [-0.16561925],
       [-0.1656192 ]], dtype=float32)

In [206]:
energy

Array(0.8646013, dtype=float32)

In [192]:
energy[0].shape

(3, 2, 1)

In [80]:
dis = R_distances(3)

In [81]:
dis(r_d)

Array([[[[[0., 0.]]]],



       [[[[0., 0.]]]],



       [[[[0., 0.]]]]], dtype=float32)

In [45]:
x_s = jnp.split(x_i, 3)

In [46]:
jnp.array(x_s).shape

(3, 2, 64)