In [1]:
import jax
import flax
import jax.numpy as jnp
from flax import linen as nn
from flax import nnx
from typing import Dict

In [None]:
# This works for one sample. I want to know how to do it
# for a batch of samples

In [None]:

# Atom Embedding
# It takes the number of different atoms in the molecular systems.
# For each atom, create an atom embedding.

@jax.jit
def count_unique(x):
  x = jnp.sort(x)
  return 1 + (x[1:] != x[:-1]).sum()

class AtomEmbedding(nnx.Module):
    def __init__(self, num_atom_types, embedding_dim, rngs: nnx.Rngs):
        super(AtomEmbedding, self).__init__()

        self.embedding = nnx.Embed(
            num_embeddings=num_atom_types,
            features=embedding_dim,
            rngs=rngs
        )
        self.embedding_nums = jnp.arange(0, num_atom_types)

    def __call__(self, atom_types):

        ord_atom_types = jnp.unique(atom_types)
        mask = jnp.array([jnp.where(ord_atom_types==atom_types[i]) for i in range(len(atom_types))]).flatten()


        return self.embedding(mask)

In [None]:
# Given a positions thing, r = (r_1, ..., r_n), returns
# a matrix with the relative distance among them
# (d_ij)_ij = |r_i - r_j|^2

class R_distances(nnx.Module):
    def __init__(self):
        super().__init__()

    def __call__(self, R):
        num_atoms = len(R[0])
        Rij = jnp.array([[R[i] - R[j] for j in range(num_atoms)] for i in range(num_atoms)])
        d_ij = jnp.linalg.norm(Rij, axis=-1)
        return d_ij

In [None]:
class RadialBasisFunctions(nnx.Module):
    def __init__(self, rbf_min, rbf_max, n_rbf, gamma=10):
        super().__init__()
        self.rbf_min = rbf_min
        self.rbf_max = rbf_max
        self.n_rbf = n_rbf
        self.gamma = gamma
        self.centers = jnp.linspace(rbf_min, rbf_max, n_rbf).reshape(1, -1)

    def __call__(self, d_ij):
        diff = d_ij[..., None] - self.centers
        return jnp.exp(-self.gamma * jnp.pow(diff, 2))

In [None]:
# ReLU function

class relu_layer(nnx.Module):
    def __init__(self):
        super().__init__()

    def __call__(self, x):
        return nnx.relu(x)

# Shift Softplus layer

class ssp_layer(nnx.Module):
    def __init__(self):
        super().__init__()

    def __call__(self, x):
        return jnp.log(0.5 * jnp.exp(x) + 0.5)

In [None]:
class filter_generator(nnx.Module):
    def __init__(
            self,
            atom_embeddings_dim,
            rbf_min,
            rbf_max,
            n_rbf,
            rngs: nnx.Rngs,
            activation=ssp_layer
            ):

        super().__init__()
        self.rbf = RadialBasisFunctions(rbf_min, rbf_max, n_rbf)
        self.w_layers = nnx.Sequential(
            nnx.Linear(n_rbf, atom_embeddings_dim, rngs=rngs),
            activation(),
            nnx.Linear(atom_embeddings_dim, atom_embeddings_dim, rngs=rngs),
            activation()
        )

    def __call__(self, d_ij):
        rbfs = self.rbf(d_ij)
        Wij = self.w_layers(rbfs)
        return Wij


In [None]:
class CfConv(nnx.Module):
    def __init__(
            self,
            atom_embeddings_dim,
            rbf_min,
            rbf_max,
            n_rbf,
            rngs: nnx.Rngs,
            activation=ssp_layer):
        super().__init__()

        self.rbf = RadialBasisFunctions(rbf_min, rbf_max, n_rbf)
        self.filters = filter_generator(
            atom_embeddings_dim,
            rbf_min,
            rbf_max,
            n_rbf,
            rngs,
            activation
            )

    def __call__(self, X, d_ij):
        fij = self.filters(d_ij)
        X_ij = X * fij
        return X + jnp.sum(X_ij, axis=0)


In [None]:
class InteractionBlock(nnx.Module):

    def __init__(
            self,
            atom_embeddings_dim,
            rbf_min,
            rbf_max,
            n_rbf,
            rngs: nnx.Rngs,
            activation=ssp_layer,
    ):
        super().__init__()
        self.in_atom_wise = nnx.Linear(
            atom_embeddings_dim,
            atom_embeddings_dim,
            rngs=rngs
        )

        self.cf_conv = CfConv(
            atom_embeddings_dim,
            rbf_min,
            rbf_max,
            n_rbf,
            rngs=rngs,
            activation=activation
        )

        self.out_atom_wise = nnx.Sequential(
            nnx.Linear(atom_embeddings_dim, atom_embeddings_dim, rngs=rngs),
            activation(),
            nnx.Linear(atom_embeddings_dim, atom_embeddings_dim, rngs=rngs)
        )

    def __call__(self, X, R_distances):
        X_in = self.in_atom_wise(X)
        X_conv = self.cf_conv(X_in, R_distances)
        V = self.out_atom_wise(X_conv)
        return X + V


In [None]:
class SchNet(nnx.Module):
    def __init__(
            self,
            atom_embedding_dim=64,
            n_interactions=3,
            n_atom_types=2,
            rbf_min=0.,
            rbf_max=30.,
            n_rbf=300,
            rngs: nnx.Rngs = nnx.Rngs(0),
            activation: nnx.Module = ssp_layer
    ):
        super().__init__()
        self.n_atom_types = n_atom_types
        self.embedding = AtomEmbedding(n_atom_types, atom_embedding_dim, rngs=rngs)

        self.interactions = [
            InteractionBlock(
                atom_embedding_dim, rbf_min, rbf_max, n_rbf, rngs, activation
            )
            for _ in range(n_interactions)
        ]

        self.output_layers = nnx.Sequential(
            nnx.Linear(atom_embedding_dim, 32, rngs=rngs),
            activation(),
            nnx.Linear(32, 1, rngs=rngs)
        )

        self.distances = R_distances()

    def __call__(self, Z, R):
        R_distances = self.distances(R)
        X = self.embedding(Z)
        X_interacted = X
        for _, interaction in enumerate(self.interactions):
            X_interacted = interaction(X_interacted, R_distances)

        atom_outputs = self.output_layers(X_interacted)
        predicted_energy = jnp.sum(atom_outputs)
        return predicted_energy


In [None]:
class SchNet(nnx.Module):
    def __init__(
            self,
            atom_embedding_dim=64,
            n_interactions=3,
            n_atom_types=2,
            rbf_min=0.,
            rbf_max=30.,
            n_rbf=300,
            rngs: nnx.Rngs = nnx.Rngs(0),
            activation: nnx.Module = ssp_layer
    ):
        super().__init__()
        self.n_atom_types = n_atom_types
        self.embedding = AtomEmbedding(n_atom_types, atom_embedding_dim, rngs=rngs)

        self.interactions = [
            InteractionBlock(
                atom_embedding_dim, rbf_min, rbf_max, n_rbf, rngs, activation
            )
            for _ in range(n_interactions)
        ]

        self.output_layers = nnx.Sequential(
            nnx.Linear(atom_embedding_dim, 32, rngs=rngs),
            activation(),
            nnx.Linear(32, 1, rngs=rngs)
        )

        self.distances = R_distances()

    def __call__(self, Z, R):
        R_distances = self.distances(R)
        X = self.embedding(Z)
        X_interacted = X
        return X_interacted, R_distances


In [47]:
model = SchNet()

In [48]:
r = jnp.array([[[1., 2., 3.], [4., 5., 6.]],[[1., 2., 3.], [4., 5., 6.]],[[1., 2., 3.], [4., 5., 6.]]])
z = jnp.array([[1,8], [1,8], [1,8]])

In [38]:
z[0]

Array([1, 8], dtype=int32)

In [49]:
x_i, r_d = model(z, r)

In [53]:
x_i.shape

(6, 64)

In [50]:
r_d

Array([[[0.      , 5.196152],
        [5.196152, 0.      ]],

       [[0.      , 5.196152],
        [5.196152, 0.      ]],

       [[0.      , 5.196152],
        [5.196152, 0.      ]]], dtype=float32)

In [36]:
x_i.shape

(6, 64)

In [29]:
r.shape

(3, 2, 3)

In [30]:
z.shape

(3, 2)

In [28]:
model(z, r)

ValueError: Incompatible shapes for broadcasting: shapes=[(6, 64), (3, 3, 2, 64)]

In [16]:
x = jnp.array([1,2,3])
x_2 = x
x_2 = x_2 + 2
print(x)
print(x_2)

[1 2 3]
[3 4 5]


In [9]:
distances = R_distances()
rbf = RadialBasisFunctions(0.0, 10.0, 30)
conv = CfConv(20, 0., 10., 30, rngs=nnx.Rngs(0))
intblock = InteractionBlock(64, 0., 30., 300, rngs=nnx.Rngs(0))
atomemb = AtomEmbedding(2, 64, rngs = nnx.Rngs(0))

In [231]:
distances(r)

Array([[0.      , 5.196152],
       [5.196152, 0.      ]], dtype=float32)

In [10]:
r = jnp.array([[1., 2., 3.], [4., 5., 6.]])
z = jnp.array([1,8])
x = atomemb(z)
d_ij = distances(r)
rbfs = rbf(d_ij)
convs = intblock(x, d_ij)

In [13]:
convs.shape

(2, 64)

In [15]:
[[r[i] - r[j] if i!=j else 0 for j in range(len(r))] for i in range(len(r))]

[[0, Array([-3., -3., -3.], dtype=float32)],
 [Array([3., 3., 3.], dtype=float32), 0]]

In [162]:
convs.shape

(2, 2, 20)

In [170]:
jnp.sum(convs, axis = 0).shape

(2, 20)

In [149]:
(convs * convs).shape

(2, 2, 20)

In [125]:
d_ij

Array([[0.      , 5.196152],
       [5.196152, 0.      ]], dtype=float32)

In [128]:
rbfs

Array([[[1.0000000e+00, 3.0450717e-01, 8.5978527e-03, 2.2510118e-05,
         5.4646212e-09, 1.2300879e-13, 2.5675032e-19, 4.9691211e-26,
         8.9174416e-34, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
         0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
         0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
         0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
         0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
         0.0000000e+00, 0.0000000e+00],
        [0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
         0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 2.3932948e-34,
         1.5708493e-26, 9.5601916e-20, 5.3950432e-14, 2.8230238e-09,
         1.3697130e-05, 6.1622960e-03, 2.5706941e-01, 9.9438071e-01,
         3.5665613e-01, 1.1861547e-02, 3.6578665e-05, 1.0459439e-08,
         2.7732200e-13, 6.8180962e-19, 1.5542854e-25, 3.2854171e-33,
         0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000

In [126]:
rbfs.shape

(2, 2, 30)

In [90]:
jnp.linspace(0., 30., 10).reshape(1,-1)

Array([[ 0.       ,  3.3333333,  6.6666665, 10.       , 13.333333 ,
        16.666666 , 20.       , 23.333332 , 26.666666 , 30.       ]],      dtype=float32)

In [88]:
import torch

torch.linspace(0., 30., 10).view(1, -1)

tensor([[ 0.0000,  3.3333,  6.6667, 10.0000, 13.3333, 16.6667, 20.0000, 23.3333,
         26.6667, 30.0000]])

In [102]:
R = jnp.array([[1., 2., 3.], [4., 5., 6.]])

In [105]:
d_ij = jnp.linalg.norm(R, axis=-1);d_ij

Array([3.7416575, 8.774964 ], dtype=float32)

In [106]:
centers = jnp.linspace(0.,30.,30).reshape(1, -1)

In [109]:
diff = d_ij[..., None] - centers

In [112]:
jnp.exp(-10*jnp.pow(diff, 2))

Array([[0.0000000e+00, 1.4840380e-32, 7.0612179e-13, 1.7024335e-02,
        2.0797738e-01, 1.2874062e-09, 4.0380931e-27, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00],
       [0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 2.2824060e-29, 6.1079398e-11,
        8.2824282e-02, 5.6907900e-02, 1.9812488e-11, 3.4952461e-30,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0

In [71]:
rdis = R_distances()

In [73]:
coso=rdis(R);coso

Array([[[ 0.,  0.,  0.],
        [ 3.,  3.,  3.]],

       [[-3., -3., -3.],
        [ 0.,  0.,  0.]]], dtype=float32)

In [81]:
jnp.linalg.norm(coso, axis=-1)

Array([[0.      , 5.196152],
       [5.196152, 0.      ]], dtype=float32)

In [82]:
jnp.sqrt(27)

Array(5.196152, dtype=float32, weak_type=True)

In [67]:
import torch
torch.linspace(0., 10, 30).view(1, -1)

tensor([[ 0.0000,  0.3448,  0.6897,  1.0345,  1.3793,  1.7241,  2.0690,  2.4138,
          2.7586,  3.1034,  3.4483,  3.7931,  4.1379,  4.4828,  4.8276,  5.1724,
          5.5172,  5.8621,  6.2069,  6.5517,  6.8966,  7.2414,  7.5862,  7.9310,
          8.2759,  8.6207,  8.9655,  9.3103,  9.6552, 10.0000]])

In [None]:
# Radial Basis Functions
#

In [None]:
# Atom-Wise layer


In [61]:
emb = AtomEmbedding(num_atom_types=2, embedding_dim=5, rngs=nnx.Rngs(0))

In [62]:
atom_types = jnp.array([1,8,8])

In [63]:
emb(atom_types=atom_types)

Array([[-0.16571221, -0.03163584, -0.7436971 ,  0.13232894, -0.08278736],
       [ 0.54843533, -0.27384827,  0.3879608 , -0.69536334,  0.4358869 ],
       [ 0.54843533, -0.27384827,  0.3879608 , -0.69536334,  0.4358869 ]],      dtype=float32)

In [54]:
atom_types = jnp.array([8,1,1,2,3,4,4,7,23,1,1,1,1,1,2])

@jax.jit
def count_unique(x):
  x = jnp.sort(x)
  return 1 + (x[1:] != x[:-1]).sum()

print(count_unique(atom_types))

7


In [37]:
jnp.arange(0, 10)

Array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32)

In [55]:
ordered_and_unique_atom_types = jnp.unique(atom_types); ordered_and_unique_atom_types

Array([ 1,  2,  3,  4,  7,  8, 23], dtype=int32)

In [50]:
num_unique = count_unique(atom_types)

In [51]:
array_emb = jnp.arange(0, num_unique)

In [52]:
array_emb[xd]

Array([5, 0, 0, 1, 2, 3, 3, 4, 6, 0, 0, 0, 0, 0, 1], dtype=int32)

In [46]:
jnp.array([jnp.where(ordered_and_unique_atom_types==atom_types[i]) for i in range(len(atom_types))]).flatten()

Array([5, 0, 0, 1, 2, 3, 3, 4, 6, 0, 0, 0, 0, 0, 1], dtype=int32)

In [47]:
xd = jnp.array([jnp.where(ordered_and_unique_atom_types==atom_types[i]) for i in range(len(atom_types))]).flatten()

In [48]:
xd[0]

Array(5, dtype=int32)

In [49]:
atom_types[xd[0]]

Array(4, dtype=int32)

In [28]:
atom_types[1:]!=atom_types[:-1]

Array([ True, False,  True,  True,  True, False,  True,  True,  True,
       False, False, False, False,  True], dtype=bool)

In [29]:
atom_types[1:]

Array([ 1,  1,  2,  3,  4,  4,  7, 23,  1,  1,  1,  1,  1,  2], dtype=int32)

In [30]:
atom_types[:-1]

Array([ 8,  1,  1,  2,  3,  4,  4,  7, 23,  1,  1,  1,  1,  1], dtype=int32)

In [3]:
atom_types.take(0)

Array(8, dtype=int32)