In [8]:
import math
from typing import Callable, Dict, Optional

import flax
import jax
import jax.numpy as jnp
import torch
from flax import linen as nn
from flax import nnx
from schnetpack.nn.scatter import scatter_add

In [None]:
class RadialBasisFunctions(nnx.Module):
    """Expand distances onto a bank of Gaussian radial basis functions.

    The basis consists of ``n_rbf`` Gaussians whose centers are evenly spaced
    between ``rbf_min`` and ``rbf_max`` (both inclusive); ``gamma`` scales the
    squared offset inside the exponential.
    """

    def __init__(self,
                rbf_min,
                rbf_max,
                n_rbf,
                gamma=10):

        # Evenly spaced Gaussian centers covering [rbf_min, rbf_max].
        self.centers = jnp.linspace(rbf_min, rbf_max, n_rbf)
        self.gamma = gamma

    def __call__(self, R_distances):
        """Return ``exp(-gamma * (d - center)^2)`` for every distance/center pair.

        A trailing axis of length ``n_rbf`` is appended to the shape of
        ``R_distances``.
        """
        # Broadcast each distance against all centers via a new trailing axis.
        offsets = R_distances[..., jnp.newaxis] - self.centers
        squared_offsets = jnp.square(offsets)
        return jnp.exp(-self.gamma * squared_offsets)


class CfConv(nnx.Module):
    """Continuous-filter convolution: neighbour features weighted by a learned
    filter of the inter-atomic distance and summed per receiving atom.

    Args:
        atom_embeddings_dim: Width of the atom feature vectors (and of the filter).
        rbf_min: Smallest radial-basis center.
        rbf_max: Largest radial-basis center.
        n_rbf: Number of radial basis functions.
        activation: Zero-argument factory returning the activation callable
            (e.g. the ``ShiftedSoftPlus`` class itself).
        rngs: Random-number streams used to initialise the dense layers.
            Defaults to ``nnx.Rngs(0)`` so call sites that do not pass it keep
            working (with deterministic initialisation).
    """

    def __init__(self,
                atom_embeddings_dim: int,
                rbf_min: float, rbf_max: float,
                n_rbf: int,
                activation: Callable[[], nnx.Module],
                rngs: Optional[nnx.Rngs] = None):

        if rngs is None:
            rngs = nnx.Rngs(0)

        self.rbf = RadialBasisFunctions(rbf_min, rbf_max, n_rbf)
        # Filter-generating network: dense -> activation -> dense -> activation.
        # ``nnx.Linear`` requires an explicit ``rngs`` keyword argument.
        self.w_layers = nnx.Sequential(
            nnx.Linear(n_rbf, atom_embeddings_dim, rngs=rngs),
            activation(),
            nnx.Linear(atom_embeddings_dim, atom_embeddings_dim, rngs=rngs),
            activation(),
        )

    def __call__(self, X, R_distances, idx_i, idx_j):
        """Apply the convolution.

        Args:
            X: Atom features, one row per atom.
            R_distances: One distance per (idx_i, idx_j) pair, as produced by
                ``PairwiseDistances``.
            idx_i: Index of the receiving atom for every pair.
            idx_j: Index of the neighbouring atom for every pair.

        Returns:
            Array with ``X.shape[0]`` rows holding, for every atom, the sum of
            its filtered neighbour features.
        """
        # 1) Expand every pair distance on the radial basis.
        radial_basis_distances = self.rbf(R_distances)

        # 2) Turn the expansion into one filter vector per pair.
        Wij = self.w_layers(radial_basis_distances)

        # 3) Gather neighbour features and modulate them with the filter.
        x_j = X[idx_j]
        x_ij = x_j * Wij

        # 4) Sum the messages arriving at each atom. ``jax.ops.segment_sum`` is
        #    the JAX-native scatter-add; the torch-based schnetpack
        #    ``scatter_add`` cannot operate on JAX arrays.
        return jax.ops.segment_sum(
            x_ij, jnp.asarray(idx_i), num_segments=X.shape[0]
        )


class SchNetIteraction(nnx.Module):
    """One SchNet interaction block with a residual connection.

    The block is: atom-wise dense -> cfconv -> atom-wise dense -> activation ->
    atom-wise dense, and its output is added to the block input.

    Args:
        atom_embeddings_dim: Width of the atom feature vectors.
        rbf_min: Smallest radial-basis center (forwarded to ``CfConv``).
        rbf_max: Largest radial-basis center (forwarded to ``CfConv``).
        n_rbf: Number of radial basis functions (forwarded to ``CfConv``).
        activation: Zero-argument factory returning the activation callable.
        rngs: Random-number streams for parameter initialisation. Defaults to
            ``nnx.Rngs(0)``.
    """

    def __init__(
            self,
            atom_embeddings_dim: int,
            rbf_min: float,
            rbf_max: float,
            n_rbf: int,
            activation: Callable[[], nnx.Module],
            rngs: Optional[nnx.Rngs] = None,
    ):

        if rngs is None:
            rngs = nnx.Rngs(0)

        self.in_atom_wise = nnx.Linear(
            atom_embeddings_dim,
            atom_embeddings_dim,
            rngs=rngs,
        )

        self.cf_conv = CfConv(
            atom_embeddings_dim,
            rbf_min,
            rbf_max,
            n_rbf,
            activation,
            rngs=rngs,
        )

        # ``flax.linen`` (``nn``) has no ``Linear``; the NNX layer is used here,
        # consistent with the rest of the model.
        self.out_atom_wise = nnx.Sequential(
            nnx.Linear(
                atom_embeddings_dim,
                atom_embeddings_dim,
                rngs=rngs,
            ),
            activation(),
            nnx.Linear(
                atom_embeddings_dim,
                atom_embeddings_dim,
                rngs=rngs,
            )
        )

    # Defined at class level: when nested inside ``__init__`` it was only a
    # local function and instances were not callable.
    def __call__(
            self,
            X,
            R_distances,
            idx_i,
            idx_j,
    ):
        """Return ``X`` plus the interaction update computed from its neighbours."""
        # 1) atom-wise dense
        X_in = self.in_atom_wise(X)

        # 2) continuous-filter convolution over neighbour pairs
        X_conv = self.cf_conv(X_in, R_distances, idx_i, idx_j)

        # 3) atom-wise dense, 4) activation, 5) atom-wise dense
        V = self.out_atom_wise(X_conv)

        # Residual connection.
        X_residual = X + V
        return X_residual


# Correctly spelled alias; the original (misspelled) name is kept so existing
# references continue to work.
SchNetInteraction = SchNetIteraction


class ShiftedSoftPlus(nnx.Module):
    """Shifted softplus activation: ``softplus(x) + log(0.5)``, which is zero at ``x = 0``."""

    def __init__(self):
        # Stored as a plain Python float so it does not force a dtype on ``x``.
        self.log_one_half = math.log(0.5)

    def __call__(self, x):
        # ``nnx.softplus`` is a function of ``x``; it cannot be instantiated
        # up-front with no arguments.
        return self.log_one_half + nnx.softplus(x)


class PairwiseDistances(nnx.Module):
    """Euclidean distance for every (idx_i, idx_j) atom pair."""

    @staticmethod
    def forward(inputs: Dict[str, jnp.ndarray]) -> jnp.ndarray:
        """Compute ``||R[idx_i] - R[idx_j]||`` along the last axis.

        Args:
            inputs: Mapping with keys ``"R"`` (positions), ``"idx_i"`` and
                ``"idx_j"`` (pair indices into ``R``).

        Returns:
            One distance per pair.
        """
        R = inputs["R"]
        idx_i = inputs["idx_i"]
        idx_j = inputs["idx_j"]

        Rij = R[idx_i] - R[idx_j]
        # JAX exposes the norm under ``jnp.linalg`` and names the axis
        # argument ``axis`` (``dim`` is the torch spelling).
        d_ij = jnp.linalg.norm(Rij, axis=-1)
        return d_ij

    def __call__(self, inputs: Dict[str, jnp.ndarray]) -> jnp.ndarray:
        """Make instances callable; delegates to :meth:`forward`."""
        return self.forward(inputs)


class SchNet(nnx.Module):
    """SchNet-style energy model: embedding -> interaction blocks -> atom-wise
    read-out -> sum pooling per molecule.

    Args:
        atom_embedding_dim: Width of the atom feature vectors.
        n_interactions: Number of interaction blocks.
        max_z: Number of rows in the atom-type embedding table.
        rbf_min: Smallest radial-basis center.
        rbf_max: Largest radial-basis center.
        n_rbf: Number of radial basis functions.
        activation: Zero-argument factory returning the activation callable.
        writer: Optional object stored on the model as ``self.writer``; it is
            not used by the code in this class.
        running_mean_var: Accepted for interface compatibility; the Welford
            running statistics it was meant to enable are not implemented.
        rngs: Random-number streams for parameter initialisation. Defaults to
            ``nnx.Rngs(0)``.
    """

    def __init__(
            self,
            atom_embedding_dim=64,
            n_interactions=3,
            max_z=100,
            rbf_min=0.,
            rbf_max=30.,
            n_rbf=300,
            activation: Callable[[], nnx.Module] = ShiftedSoftPlus,
            writer = None,
            running_mean_var = True,
            rngs: Optional[nnx.Rngs] = None,
    ):

        if rngs is None:
            rngs = nnx.Rngs(0)

        self.time_step = 0
        self.writer = writer
        self.max_z = max_z

        self.embedding = nnx.Embed(
            num_embeddings=max_z,
            features=atom_embedding_dim,
            rngs=rngs,
        )

        interactions = [
            SchNetIteraction(
                atom_embedding_dim,
                rbf_min,
                rbf_max,
                n_rbf,
                activation,
                rngs=rngs,
            )
            for _ in range(n_interactions)
        ]
        # Newer Flax releases provide ``nnx.List`` for module containers;
        # older releases accept a plain Python list.
        self.interactions = (
            nnx.List(interactions) if hasattr(nnx, "List") else interactions
        )

        # Read-out head: atom-wise 32 -> activation -> atom-wise 1.
        self.output_layers = nnx.Sequential(
            nnx.Linear(atom_embedding_dim, 32, rngs=rngs),
            activation(),
            nnx.Linear(32, 1, rngs=rngs),
        )

        self.pairwise = PairwiseDistances()

    def __call__(self, inputs: Dict):
        """Predict one energy per molecule for a flattened batch of atoms.

        Args:
            inputs: Mapping with keys ``"Z"`` (atom-type indices), ``"R"``
                (positions), ``"N"`` (number of atoms in each molecule, in
                batch order), ``"idx_i"`` and ``"idx_j"`` (neighbour-pair
                indices into the flattened atom axis).

        Returns:
            1-D array with one predicted energy per entry of ``N``.
        """
        Z = inputs["Z"]
        N = inputs["N"]
        idx_i = inputs["idx_i"]
        idx_j = inputs["idx_j"]

        R_distances = self.pairwise(inputs)

        # 1) Embed the atom types.
        X = self.embedding(Z)

        # 2) Interaction blocks (each one already adds its residual).
        X_interacted = X
        for interaction in self.interactions:
            X_interacted = interaction(X_interacted, R_distances, idx_i, idx_j)

        # 3) Atom-wise read-out to a single value per atom.
        atom_outputs = self.output_layers(X_interacted)
        per_atom = jnp.squeeze(atom_outputs, axis=-1)

        # 4) Sum pooling: add up the atoms that belong to each molecule.
        #    ``N`` holds per-molecule atom counts, so every atom is tagged with
        #    its molecule index and reduced with a segment sum. (``jnp.split``
        #    would interpret a list as split positions, not as section sizes.)
        atoms_per_molecule = jnp.atleast_1d(jnp.asarray(N))
        n_molecules = atoms_per_molecule.shape[0]
        molecule_ids = jnp.repeat(
            jnp.arange(n_molecules),
            atoms_per_molecule,
            total_repeat_length=per_atom.shape[0],
        )
        predicted_energies = jax.ops.segment_sum(
            per_atom, molecule_ids, num_segments=n_molecules
        )

        self.time_step += 1
        return predicted_energies

In [None]:
# Example: Create a single water molecule
def create_water_molecule():
    """Build one example water molecule and wrap it in a ``Data`` record.

    Returns:
        The object returned by ``Data(x=atom_types, pos=pos, y=energy)``,
        where ``atom_types`` holds one type index per atom, ``pos`` one
        3-component coordinate row per atom and ``energy`` a single value.

    NOTE(review): ``Data`` is not imported anywhere in the visible notebook,
    so this function raises ``NameError`` on a fresh kernel. It looks like
    ``torch_geometric.data.Data`` -- confirm, and check that it accepts JAX
    arrays.
    """
    # Atom types: Oxygen (0), Hydrogen (1)
    # NOTE(review): these are type indices, not atomic numbers -- confirm they
    # match what the model's embedding table expects.
    atom_types = jnp.array([0, 1, 1])
    # Positions in angstroms
    pos = jnp.array(
        [
        [0.0, 0.0, 0.0],        # Oxygen
        [0.96, 0.0, 0.0],       # Hydrogen 1
        [-0.24, 0.93, 0.0]      # Hydrogen 2
        ]
    )
    # Assume total energy is -76.0 eV (for example)
    energy = jnp.array([-76.0])
    data = Data(x=atom_types, pos=pos, y=energy)
    return data

# Create a dataset of 100 identical water molecules and batch it 10 at a time,
# shuffling the order.
# NOTE(review): `DataLoader` is not imported anywhere in the visible notebook,
# so this cell raises NameError on a fresh kernel -- confirm which loader was
# intended and import it in the first cell.
dataset = [create_water_molecule() for _ in range(100)]
loader = DataLoader(dataset, batch_size=10, shuffle=True)


In [9]:
from flax.nnx import softplus

In [11]:
# Sanity check of the softplus function imported above: it is called with the
# input directly, and softplus(0.5) is approximately 0.974.
softplus(0.5)

Array(0.974077, dtype=float32, weak_type=True)

In [None]:

# PyTorch reference embedding: 10 rows of width 20, with row 0 reserved as the
# padding entry. `nn` in this notebook is `flax.linen`, which has no
# `Embedding` class, so the torch layer is referenced explicitly.
embedding = torch.nn.Embedding(10, 20, padding_idx=0)