In [None]:
import jax
import flax
import jax.numpy as jnp
from flax import linen as nn
from jax import random
import numpy as np
#from flax import nnx
from typing import Dict

import jax_dataloader as jdl
from functools import partial
#######

import torch
from torch_geometric.loader import DataLoader
from torch.utils.data import Dataset
from torch_geometric.data import Data#, DataLoader


In [79]:
class WaterDataset(Dataset):
    """Face Landmarks dataset."""

    def __init__(self, npz_file):
        """
        Arguments:
            npz_file (string): Path to the npz file with data.
        """
        data = np.load(npz_file)
        self.E = data["E"]
        self.Z = data["z"]
        self.R = data["R"]

    def __len__(self):
        return len(self.E)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        sample = {'positions': self.R[idx], 'atom_types': self.Z, 'energy': self.E[idx]}

        return sample

In [299]:
data = np.load("data_water.npz")

In [302]:
data["E"].shape

(99999, 1)

In [304]:
data["R"].shape

(99999, 3, 3)

In [300]:
data["z"]

array([8, 1, 1])

In [305]:
water_dataset = WaterDataset("data_water.npz")

In [306]:
next(iter(water_dataset))

{'positions': array([[0.        , 0.        , 0.        ],
        [1.88972613, 0.        , 0.        ],
        [0.        , 1.88972613, 0.        ]]),
 'atom_types': array([8, 1, 1]),
 'energy': array([-76377.68573031])}

In [312]:
data["R"]

array([[[ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
        [ 1.88972613e+00,  0.00000000e+00,  0.00000000e+00],
        [ 0.00000000e+00,  1.88972613e+00,  0.00000000e+00]],

       [[-1.27435571e-03, -1.91089106e-04,  0.00000000e+00],
        [ 1.89716851e+00,  1.27858870e-02,  0.00000000e+00],
        [ 1.27858870e-02,  1.87997344e+00,  0.00000000e+00]],

       [[-1.18107883e-04,  1.66868486e-03,  0.00000000e+00],
        [ 1.88682094e+00,  4.94677388e-03,  0.00000000e+00],
        [ 4.77983547e-03,  1.85829159e+00,  0.00000000e+00]],

       ...,

       [[ 5.58028378e-02, -1.99056947e-02,  0.00000000e+00],
        [ 1.91782596e+00,  7.38274481e-01,  0.00000000e+00],
        [-9.13878566e-01,  1.46742206e+00,  0.00000000e+00]],

       [[ 4.98592523e-02, -3.73874567e-02,  0.00000000e+00],
        [ 2.03012446e+00,  8.53936283e-01,  0.00000000e+00],
        [-9.31832533e-01,  1.62925450e+00,  0.00000000e+00]],

       [[ 4.85929090e-02, -4.77873943e-02,  0.00000000e+00],
 

In [81]:
water_loader = jdl.DataLoader(water_dataset, backend='pytorch', batch_size=5, shuffle=True, drop_last=False)

In [307]:
batch = next(iter(water_loader))

In [311]:
batch["positions"]

array([[[ 3.92995004e-03, -7.90343938e-03,  0.00000000e+00],
        [ 2.03940170e+00, -1.68464209e-01,  0.00000000e+00],
        [-2.12057224e-01,  2.18364445e+00,  0.00000000e+00]],

       [[ 3.62524493e-02,  1.48215000e-03,  0.00000000e+00],
        [ 1.92562951e+00,  1.13005207e-01,  0.00000000e+00],
        [-6.11351762e-01,  1.75319415e+00,  0.00000000e+00]],

       [[ 5.84968691e-02, -4.13975122e-02,  0.00000000e+00],
        [ 1.96645069e+00,  7.37222187e-01,  0.00000000e+00],
        [-1.00526687e+00,  1.80962190e+00,  0.00000000e+00]],

       [[ 1.82980102e-02, -3.71793222e-02,  0.00000000e+00],
        [ 1.93921977e+00,  6.39939974e-01,  0.00000000e+00],
        [-3.39944534e-01,  1.83994709e+00,  0.00000000e+00]],

       [[ 3.99603755e-02,  6.83375990e-03,  0.00000000e+00],
        [ 1.87247778e+00,  2.30835338e-02,  0.00000000e+00],
        [-6.17057280e-01,  1.75816784e+00,  0.00000000e+00]]])

In [None]:
class AtomEmbedding(nn.Module):
    num_atom_types: int
    embedding_dim: int
    def setup(self):
        self.embedding = nn.Embed(
            num_embeddings=self.num_atom_types,
            features=self.embedding_dim
        )

    @partial(jax.vmap, in_axes = (None, 0), out_axes = 0)
    def __call__(self, atom_types):
        ord_atom_types = jnp.unique(atom_types, size=self.num_atom_types)
        mask = jnp.asarray([jnp.where(ord_atom_types==elem, size=1) for elem in atom_types]).squeeze()
        return jnp.array(self.embedding(mask)).squeeze()

In [314]:
print(z)

[[8 1 1]
 [8 1 1]
 [8 1 1]
 [8 1 1]
 [8 1 1]]


In [313]:
z = batch["atom_types"]
print(z.shape)
emb = AtomEmbedding(num_atom_types=2, embedding_dim=2)
key1, key2 = random.split(random.key(0), 2)
params = emb.init(key2, z)
y = emb.apply(params, z)
print(y.shape)

(5, 3)
(5, 3, 2)


In [159]:
# Given a positions thing, r = (r_1, ..., r_n), returns
# a matrix with the relative distance among them
# (d_ij)_ij = |r_i - r_j|^2

class R_distances(nn.Module):

    @partial(jax.vmap, in_axes = (None, 0), out_axes = 0)
    def __call__(self, R):
        Rij = jnp.array([[r_i - r_j for r_j in R] for r_i in R])
        d_ij = jnp.linalg.norm(Rij, axis=-1)
        return d_ij

In [175]:
dis = R_distances()
R = batch["positions"]
distances = dis(R)
print(distances.shape)

(5, 3, 3)


In [176]:
class RadialBasisFunctions(nn.Module):
    rbf_min: float
    rbf_max: float
    n_rbf: int
    gamma: float = 10

    def setup(self):
        self.centers = jnp.linspace(self.rbf_min, self.rbf_max, self.n_rbf).reshape(1, -1)

    @partial(jax.vmap, in_axes = (None, 0), out_axes = 0)
    def __call__(self, d_ij):
        diff = d_ij[..., None] - self.centers
        return jnp.exp(-self.gamma * jnp.pow(diff, 2))

In [180]:
rbf = RadialBasisFunctions(0.0, 30., 10)
key1, key2 = random.split(key1, 2)
params_rbf = rbf.init(key2, distances)
rbfs = rbf.apply(params_rbf, distances)
print(rbfs.shape)

(5, 3, 3, 10)


In [181]:
# ReLU function

class relu_layer(nn.Module):
    def __call__(self, x):
        return nn.relu(x)

# Shift Softplus layer

class ssp_layer(nn.Module):
    def __call__(self, x):
        return jnp.log(0.5 * jnp.exp(x) + 0.5)

In [None]:
class filter_generator(nn.Module):
    atom_embeddings_dim: int
    rbf_min: float
    rbf_max: float
    n_rbf: int
    activation: nn.Module = ssp_layer

    def setup(self):
        self.rbf = RadialBasisFunctions(
            self.rbf_min, 
            self.rbf_max, 
            self.n_rbf
            )
        
        self.w_layers = nn.Sequential([
            nn.Dense(self.n_rbf, self.atom_embeddings_dim),
            self.activation(),
            nn.Dense(self.atom_embeddings_dim, self.atom_embeddings_dim),
            self.activation()
            ])

    @partial(jax.vmap, in_axes = (None, 0), out_axes = 0)
    def __call__(self, d_ij):
        rbfs = self.rbf(d_ij)
        Wij = self.w_layers(rbfs)
        return Wij


In [None]:
filter_gen = filter_generator(2, 0., 30., 10)
key1, key2 = random.split(key1, 2)
params_fil = filter_gen.init(key2, distances)
filters = filter_gen.apply(params_fil, distances)
print(filters.shape)
print(y.shape)

(5, 3, 3, 2)
(5, 3, 2)


In [265]:
class CfConv(nn.Module):
    
    @partial(jax.vmap, in_axes = (None, 0, 0), out_axes = 0)
    def __call__(self, X, fij):
        ones = jnp.ones(shape=fij.shape)
        jax.lax.fori_loop(0, len(ones), lambda i, ones_i: ones_i.at[i,i].set(jnp.zeros(shape=ones_i[i,i].shape)), ones)
        filters = jnp.sum(fij, axis = 1)
        Xij = X * filters
        jnp.sum(Xij, axis = 0, where = ones)
        return X + jnp.sum(Xij, axis=0)

In [None]:
conv = CfConv()
convs = conv(y, filters)
print(convs.shape)

(5, 3, 2)


In [None]:
class InteractionBlock(nn.Module):
    atom_embeddings_dim: int
    rbf_min: float
    rbf_max: float
    n_rbf: int
    activation: nn.Module = ssp_layer

    def setup(self):
        self.in_atom_wise = nn.Dense(
            self.atom_embeddings_dim,
            self.atom_embeddings_dim
            )
        
        self.out_atom_wise = nn.Sequential([
            nn.Dense(self.atom_embeddings_dim, self.atom_embeddings_dim),
            self.activation(),
            nn.Dense(self.atom_embeddings_dim, self.atom_embeddings_dim)
            ])

        self.filters = filter_generator(
            self.atom_embeddings_dim,
            self.rbf_min,
            self.rbf_max,
            self.n_rbf
            )

        self.cf_conv = CfConv()

        self.distances = R_distances()

    @partial(jax.vmap, in_axes = (None, 0, 0), out_axes = 0)
    def __call__(self, X, R):
        X_in = self.in_atom_wise(X)
        R_distances = self.distances(R)
        fils = self.filters(R_distances)
        X_conv = self.cf_conv(X_in, fils)
        V = self.out_atom_wise(X_conv)
        return X + V

In [290]:
emb = AtomEmbedding(num_atom_types=2, embedding_dim=2)
intblock = InteractionBlock(2, 0., 30., 10)
Z = batch["atom_types"]
R = batch["positions"]
key1, key2, key3 = random.split(key1, 3)
params = emb.init(key2, Z)
X = emb.apply(params, Z)
params_int = intblock.init(key3, X, R)
interation = intblock.apply(params_int, X, R)
print(interation.shape)

(5, 3, 2)


In [297]:
class SchNet(nn.Module):
    atom_embedding_dim: int = 64
    atom_wise_out: int = 32
    n_interactions: int = 3
    n_atom_types: int = 2
    rbf_min: float = 0.
    rbf_max: float = 30.
    n_rbf: int = 300
    activation: nn.Module = ssp_layer

    def setup(self):
        self.embedding = AtomEmbedding(
            self.n_atom_types,
            self.atom_embedding_dim
        )

        self.interactions = nn.Sequential([
            InteractionBlock(
                self.atom_embedding_dim, self.rbf_min, self.rbf_max, self.n_rbf, self.activation
            )
        for _ in range(self.n_interactions)])

        self.output_layers = nn.Sequential([
            nn.Dense(self.atom_embedding_dim, self.atom_wise_out),
            self.activation(),
            nn.Dense(self.atom_wise_out, 1)
        ])

    @partial(jax.vmap, in_axes = (None, 0, 0), out_axes = 0)
    def __call__(self, Z, R):
        print(Z.shape)
        X = self.embedding(Z)
        X_interacted = X
        X_interacted = self.interactions(X_interacted, R)
        atom_outputs = self.output_layers(X_interacted)
        predicted_energies = jnp.sum(atom_outputs, axis = -1)
        return predicted_energies

In [298]:
sch = SchNet()
Z = batch["atom_types"]
R = batch["positions"]
key1, key2 = random.split(key1, 2)
params_sch = sch.init(key2, Z, R)
energies = sch.apply(params_sch, Z, R)


(3,)


TypeError: iteration over a 0-d array

In [None]:
class SchNet(nnx.Module):
    def __init__(
            self,
            n_batch,
            atom_embedding_dim=64,
            n_interactions=3,
            n_atom_types=3,
            rbf_min=0.,
            rbf_max=30.,
            n_rbf=300,
            rngs: nnx.Rngs = nnx.Rngs(0),
            activation: nnx.Module = ssp_layer
    ):
        super().__init__()
        self.n_atom_types = n_atom_types
        self.embedding = AtomEmbedding(n_atom_types, atom_embedding_dim, rngs=rngs, n_batch=n_batch)

        self.interactions = [
            InteractionBlock(
                atom_embedding_dim, rbf_min, rbf_max, n_rbf, rngs, activation
            )
            for _ in range(n_interactions)
        ]

        self.output_layers = nnx.Sequential(
            nnx.Linear(atom_embedding_dim, 32, rngs=rngs),
            activation(),
            nnx.Linear(32, 1, rngs=rngs)
        )

        self.distances = R_distances()

    def __call__(self, batch):
        Z = batch[0]
        R = batch[1]
        R_distances = self.distances(R)
        X = self.embedding(Z)
        X_interacted = X
        for _, interaction in enumerate(self.interactions):
            X_interacted = interaction(X_interacted, R_distances)

        atom_outputs = self.output_layers(X_interacted)
        predicted_energies = jnp.sum(atom_outputs, axis=1)
        return predicted_energies


In [284]:
X.shape

(5, 3)

In [None]:
class InteractionBlock(nnx.Module):

    def __init__(
            self,
            atom_embeddings_dim,
            rbf_min,
            rbf_max,
            n_rbf,
            rngs: nnx.Rngs,
            activation=ssp_layer,
    ):
        super().__init__()
        self.in_atom_wise = nnx.Linear(
            atom_embeddings_dim,
            atom_embeddings_dim,
            rngs=rngs
        )
        
        self.cf_conv = CfConv(
            atom_embeddings_dim,
            rbf_min,
            rbf_max,
            n_rbf,
            rngs=rngs,
            activation=activation
        )

        self.out_atom_wise = nnx.Sequential(
            nnx.Linear(atom_embeddings_dim, atom_embeddings_dim, rngs=rngs),
            activation(),
            nnx.Linear(atom_embeddings_dim, atom_embeddings_dim, rngs=rngs)
        )

    def __call__(self, X, R_distances):
        X_in = self.in_atom_wise(X)
        X_conv = self.cf_conv(X_in, R_distances)
        V = self.out_atom_wise(X_conv)
        return X + V


In [205]:
sample_y = y[0]
sample_fil = filters[0]

In [207]:
print(sample_y.shape, sample_fil.shape)

(3, 2) (3, 3, 2)


In [218]:
print(sample_y)
print("----")
print(sample_fil)

[[0.5235133  0.39598924]
 [1.3146187  0.842151  ]
 [1.3146187  0.842151  ]]
----
[[[-0.04081513 -0.04983531]
  [ 0.          0.        ]
  [ 0.          0.        ]]

 [[ 0.          0.        ]
  [-0.04081513 -0.04983531]
  [ 0.00211587  0.00583054]]

 [[ 0.          0.        ]
  [ 0.00211587  0.00583054]
  [-0.04081513 -0.04983531]]]


In [223]:
-0.0408151 * 1.3146187

-0.05365629370237

In [225]:
xd = (sample_fil * sample_y);xd

Array([[[-0.02136726, -0.01973425],
        [ 0.        ,  0.        ],
        [ 0.        ,  0.        ]],

       [[ 0.        ,  0.        ],
        [-0.05365633, -0.04196886],
        [ 0.00278156,  0.0049102 ]],

       [[ 0.        ,  0.        ],
        [ 0.00278156,  0.0049102 ],
        [-0.05365633, -0.04196886]]], dtype=float32)

In [234]:
ones = jnp.ones(shape=xd.shape)

In [None]:
ones.at[i,i].set([0.,0.]) for i in range(len(ones))

Array([[[0., 0.],
        [1., 1.],
        [1., 1.]],

       [[1., 1.],
        [1., 1.],
        [1., 1.]],

       [[1., 1.],
        [1., 1.],
        [1., 1.]]], dtype=float32)

In [240]:
for i in range(len(ones)):
    ones = ones.at[i,i].set([0.,0.])

In [250]:
ones = jnp.ones(shape=xd.shape)

In [None]:
def func(i, val):
    

In [258]:
f = lambda i, ones_i: ones_i.at[i,i].set([0., 0.])

In [256]:
x = lambda a,b : a*b

In [257]:
x(1,2)

2

In [259]:
f(0,ones)

Array([[[0., 0.],
        [1., 1.],
        [1., 1.]],

       [[1., 1.],
        [1., 1.],
        [1., 1.]],

       [[1., 1.],
        [1., 1.],
        [1., 1.]]], dtype=float32)

In [262]:
jax.lax.fori_loop(0, len(ones), lambda i, ones_i: ones_i.at[i,i].set(jnp.zeros(shape=ones_i[i,i].shape)), ones)

Array([[[0., 0.],
        [1., 1.],
        [1., 1.]],

       [[1., 1.],
        [0., 0.],
        [1., 1.]],

       [[1., 1.],
        [1., 1.],
        [0., 0.]]], dtype=float32)

In [241]:
ones

Array([[[0., 0.],
        [1., 1.],
        [1., 1.]],

       [[1., 1.],
        [0., 0.],
        [1., 1.]],

       [[1., 1.],
        [1., 1.],
        [0., 0.]]], dtype=float32)

In [243]:
xd

Array([[[-0.02136726, -0.01973425],
        [ 0.        ,  0.        ],
        [ 0.        ,  0.        ]],

       [[ 0.        ,  0.        ],
        [-0.05365633, -0.04196886],
        [ 0.00278156,  0.0049102 ]],

       [[ 0.        ,  0.        ],
        [ 0.00278156,  0.0049102 ],
        [-0.05365633, -0.04196886]]], dtype=float32)

In [245]:
xdd = jnp.sum(xd, axis = 0, where = ones); xdd

Array([[0.        , 0.        ],
       [0.00278156, 0.0049102 ],
       [0.00278156, 0.0049102 ]], dtype=float32)

In [248]:
xd

Array([[[-0.02136726, -0.01973425],
        [ 0.        ,  0.        ],
        [ 0.        ,  0.        ]],

       [[ 0.        ,  0.        ],
        [-0.05365633, -0.04196886],
        [ 0.00278156,  0.0049102 ]],

       [[ 0.        ,  0.        ],
        [ 0.00278156,  0.0049102 ],
        [-0.05365633, -0.04196886]]], dtype=float32)

In [247]:
xd

Array([[[-0.02136726, -0.01973425],
        [ 0.        ,  0.        ],
        [ 0.        ,  0.        ]],

       [[ 0.        ,  0.        ],
        [-0.05365633, -0.04196886],
        [ 0.00278156,  0.0049102 ]],

       [[ 0.        ,  0.        ],
        [ 0.00278156,  0.0049102 ],
        [-0.05365633, -0.04196886]]], dtype=float32)

In [249]:
sample_y + xdd

Array([[0.5235133 , 0.39598924],
       [1.3174002 , 0.84706116],
       [1.3174002 , 0.84706116]], dtype=float32)

In [231]:
jnp.fill_diagonal(xd, 0, inplace=False)

ValueError: All dimensions of input must be of equal length

In [208]:
sample_fil

Array([[[-0.04081513, -0.04983531],
        [ 0.        ,  0.        ],
        [ 0.        ,  0.        ]],

       [[ 0.        ,  0.        ],
        [-0.04081513, -0.04983531],
        [ 0.00211587,  0.00583054]],

       [[ 0.        ,  0.        ],
        [ 0.00211587,  0.00583054],
        [-0.04081513, -0.04983531]]], dtype=float32)

In [214]:
jnp.tri(3, k=-1)*sample_fil

ValueError: Incompatible shapes for broadcasting: shapes=[(3, 3), (3, 3, 2)]

In [209]:
jnp.triu(sample_fil)

Array([[[-0.04081513, -0.04983531],
        [ 0.        ,  0.        ],
        [ 0.        ,  0.        ]],

       [[ 0.        ,  0.        ],
        [ 0.        , -0.04983531],
        [ 0.        ,  0.        ]],

       [[ 0.        ,  0.        ],
        [ 0.        ,  0.00583054],
        [ 0.        ,  0.        ]]], dtype=float32)

In [211]:
sample_fil[1,0]

Array([0., 0.], dtype=float32)

In [None]:
class CfConv(nn.Module):
    
    @partial(jax.vmap, in_axes = (None, 0, 0), out_axes = 0)
    def __call__(self, X, fij):
        ones = jnp.ones(shape=fij.shape)
        
        filters = jnp.sum(fij, axis = 1)
        Xij = X * filters
        return X + jnp.sum(Xij, axis=0)

In [203]:
distances.shape

(5, 3, 3)

In [202]:
conv = CfConv(2, 0., 30., 10)
key1, key2 = random.split(key1, 2)
params_conv = conv.init(key2, y, distances)
convs = conv.apply(params_conv, y, distances)
print(convs.shape)

(3, 3, 1, 2)
(3, 3, 1, 2)
(5, 3, 2)


In [None]:
class CfConv(nnx.Module):
    def __init__(
            self,
            atom_embeddings_dim,
            rbf_min,
            rbf_max,
            n_rbf,
            rngs: nnx.Rngs = nnx.Rngs(0),
            activation=ssp_layer):
        super().__init__()

        self.rbf = RadialBasisFunctions(rbf_min, rbf_max, n_rbf)
        self.filters = filter_generator(
            atom_embeddings_dim,
            rbf_min,
            rbf_max,
            n_rbf,
            rngs,
            activation
            )

    def __call__(self, X, d_ij):
        fij = jnp.sum(self.filters(d_ij), axis = 1)
        X_ij = X * fij
        return X + jnp.sum(X_ij, axis=0)


In [42]:
class AtomEmbedding(nn.Module):
    num_atom_types: int
    embedding_dim: int
    n_batch: int
    def setup(self):
        self.embedding = nn.Embed(
            num_embeddings=self.num_atom_types,
            features=self.embedding_dim
        )

    def __call__(self, atom_types):
        ord_atom_types = jnp.unique(atom_types)
        mask = jnp.array([jnp.where(ord_atom_types==elem) for elem in atom_types.flatten()]).flatten()


        return jnp.array(jnp.split(self.embedding(mask), self.n_batch))

In [45]:
z = sample["atom_types"]
emb = AtomEmbedding(num_atom_types=2, embedding_dim=64)
key1, key2 = random.split(random.key(0), 2)
params = emb.init(key2, z)
y = emb.apply(params, z)

In [47]:
y.shape

(3, 1, 1, 64)

In [None]:

# Atom Embedding
# It takes the number of different atoms in the molecular systems.
# For each atom, create an atom embedding.

@jax.jit
def count_unique(x):
  x = jnp.sort(x)
  return 1 + (x[1:] != x[:-1]).sum()

class AtomEmbedding(nn.Module):
    num_atom_types: int
    embedding_dim: int
    n_batch: int
    def setup(self):
        self.embedding = nn.Embed(
            num_embeddings=self.num_atom_types,
            features=self.embedding_dim
        )

    def __call__(self, atom_types):
        ord_atom_types = jnp.unique(atom_types)
        mask = jnp.array([jnp.where(ord_atom_types==elem) for elem in atom_types.flatten()]).flatten()


        return jnp.array(jnp.split(self.embedding(mask), self.n_batch))

In [3]:
r = jnp.array([[[1., 2., 3.], [4., 5., 6.]],[[1., 2., 3.], [4., 5., 6.]],[[1., 2., 3.], [4., 5., 6.]]])
z = jnp.array([[1,8], [1,8], [1,8]])

rbf_min = 0.
rbf_max = 10.
n_rbf = 4
num_atom_types = 2
embedding_dim=5
n_batch = 3
emb = AtomEmbedding(num_atom_types=num_atom_types, embedding_dim=embedding_dim, n_batch=n_batch)

In [4]:
key1, key2 = random.split(random.key(0), 2)
params = emb.init(key2, z)
y = emb.apply(params, z)

print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params)))
print('output:\n', y)

initialized parameter shapes:
 {'params': {'embedding': {'embedding': (2, 5)}}}
output:
 [[[ 0.02224981 -0.03897022  0.4141354  -0.64674884  0.04295832]
  [ 0.28187937  0.35178837  0.28512442  0.165287    0.21933578]]

 [[ 0.02224981 -0.03897022  0.4141354  -0.64674884  0.04295832]
  [ 0.28187937  0.35178837  0.28512442  0.165287    0.21933578]]

 [[ 0.02224981 -0.03897022  0.4141354  -0.64674884  0.04295832]
  [ 0.28187937  0.35178837  0.28512442  0.165287    0.21933578]]]


In [5]:
# Given a positions thing, r = (r_1, ..., r_n), returns
# a matrix with the relative distance among them
# (d_ij)_ij = |r_i - r_j|^2

class R_distances(nn.Module):

    def __call__(self, R):
        num_atoms = len(R[0])
        Rij = jnp.array([[[r[i] - r[j] for j in range(num_atoms)] for i in range(num_atoms)] for r in R])
        d_ij = jnp.linalg.norm(Rij, axis=-1)
        return d_ij

In [6]:
distances = R_distances()
dis = distances(r)

In [7]:
dis

Array([[[0.      , 5.196152],
        [5.196152, 0.      ]],

       [[0.      , 5.196152],
        [5.196152, 0.      ]],

       [[0.      , 5.196152],
        [5.196152, 0.      ]]], dtype=float32)

In [None]:
class RadialBasisFunctions(nn.Module):
    rbf_min: float
    rbf_max: float
    n_rbf: int
    gamma: float = 10

    def setup(self):
        self.centers = jnp.linspace(self.rbf_min, self.rbf_max, self.n_rbf).reshape(1, -1)

    def __call__(self, d_ij):
        diff = d_ij[..., None] - self.centers
        return jnp.exp(-self.gamma * jnp.pow(diff, 2))

In [45]:
num_atom_types = 2
embedding_dim=5
n_batch = 10
rbf = RadialBasisFunctions(rbf_min, rbf_max, n_rbf)

In [46]:
# ReLU function

class relu_layer(nn.Module):
    def __call__(self, x):
        return nn.relu(x)

# Shift Softplus layer

class ssp_layer(nn.Module):
    def __call__(self, x):
        return jnp.log(0.5 * jnp.exp(x) + 0.5)

In [None]:
class filter_generator(nn.Module):
    atom_embeddings_dim: int
    rbf_min: float
    rbf_max: float
    n_rbf: int
    activation: nn.Module = ssp_layer

    def setup(self):
        self.rbf = RadialBasisFunctions(self.rbf_min, self.rbf_max, self.n_rbf)
        self.w_layers = nn.Sequential([
            nn.Dense(self.n_rbf, self.atom_embeddings_dim),
            self.activation(),
            nn.Dense(self.atom_embeddings_dim, self.atom_embeddings_dim),
            self.activation()
        ])

    def __call__(self, d_ij):
        rbfs = self.rbf(d_ij)
        Wij = self.w_layers(rbfs)
        return Wij


In [48]:
fil_gen = filter_generator(atom_embeddings_dim=embedding_dim, rbf_min=rbf_min, rbf_max=rbf_max, n_rbf=n_rbf)

In [50]:
key1, key2 = random.split(key1)

params_f = fil_gen.init(key2, dis)
y_f = fil_gen.apply(params_f, dis)

print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params_f)))
print('output:\n', y_f, y_f.shape)

initialized parameter shapes:
 {'params': {'w_layers': {'layers_0': {'bias': (4,), 'kernel': (4, 4)}, 'layers_2': {'bias': (5,), 'kernel': (4, 5)}}}}
output:
 [[[[-0.01525517  0.36826965  0.08764708 -0.09227908 -0.18746187]
   [ 0.          0.          0.          0.          0.        ]]

  [[ 0.          0.          0.          0.          0.        ]
   [-0.01525517  0.36826965  0.08764708 -0.09227908 -0.18746187]]]


 [[[-0.01525517  0.36826965  0.08764708 -0.09227908 -0.18746187]
   [ 0.          0.          0.          0.          0.        ]]

  [[ 0.          0.          0.          0.          0.        ]
   [-0.01525517  0.36826965  0.08764708 -0.09227908 -0.18746187]]]


 [[[-0.01525517  0.36826965  0.08764708 -0.09227908 -0.18746187]
   [ 0.          0.          0.          0.          0.        ]]

  [[ 0.          0.          0.          0.          0.        ]
   [-0.01525517  0.36826965  0.08764708 -0.09227908 -0.18746187]]]] (3, 2, 2, 5)
