In [194]:
import jax
import flax
import jax.numpy as jnp
from flax import linen as nn
from jax import random
import numpy as np
#from flax import nnx
from typing import Dict

import jax_dataloader as jdl
from functools import partial
#######

import torch
from torch_geometric.loader import DataLoader
from torch.utils.data import Dataset
from torch_geometric.data import Data#, DataLoader


In [195]:
class WaterDataset(Dataset):
    """Water molecules dataset."""

    def __init__(self, npz_file):
        """
        Arguments:
            npz_file (string): Path to the npz file with data.
        """
        data = np.load(npz_file)
        self.E = data["E"]
        self.Z = data["z"]
        self.R = data["R"]
        self.E_norm = (data["E"] - np.mean(data["E"]))/np.std(data["E"])

    def __len__(self):
        return len(self.E)
    
    def energy_std(self):
        return np.std(self.E)
    
    def energy_mean(self):
        return np.mean(self.E)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        sample = {'positions': self.R[idx], 'atom_types': self.Z, 'energy': self.E_norm[idx]}

        return sample

In [196]:
water_dataset = WaterDataset("data_water.npz")

In [197]:
water_loader = jdl.DataLoader(water_dataset, backend='pytorch', batch_size=5, shuffle=True, drop_last=False)

In [198]:
batch=next(iter(water_loader))

In [199]:
batch["energy"]

array([[-0.73456615],
       [ 3.46393816],
       [-0.0459205 ],
       [-0.75315944],
       [ 1.30973884]])

In [200]:
water_dataset.energy_std()

np.float64(9.767767751286364)

In [201]:
class AtomEmbedding(nn.Module):
    num_atom_types: int
    embedding_dim: int
    def setup(self):
        self.embedding = nn.Embed(
            num_embeddings=self.num_atom_types,
            features=self.embedding_dim
        )

    def __call__(self, atom_types):
        ord_atom_types = jnp.unique(atom_types, size=self.num_atom_types)
        mask = jnp.asarray([jnp.where(ord_atom_types==elem, size=1) for elem in atom_types]).squeeze()
        return jnp.array(self.embedding(mask)).squeeze()

In [202]:
# Given a positions thing, r = (r_1, ..., r_n), returns
# a matrix with the relative distance among them
# (d_ij)_ij = |r_i - r_j|^2

class R_distances(nn.Module):

    def __call__(self, R):
        Rij = jnp.array([[r_i - r_j for r_j in R] for r_i in R])
        d_ij = jnp.linalg.norm(Rij, axis=-1)
        return d_ij

In [203]:
class RadialBasisFunctions(nn.Module):
    rbf_min: float
    rbf_max: float
    n_rbf: int
    gamma: float = 10

    def setup(self):
        self.centers = jnp.linspace(self.rbf_min, self.rbf_max, self.n_rbf).reshape(1, -1)

    def __call__(self, d_ij):
        diff = d_ij[..., None] - self.centers
        return jnp.exp(-self.gamma * jnp.pow(diff, 2))

In [204]:
# ReLU function

class relu_layer(nn.Module):
    def __call__(self, x):
        return nn.relu(x)

# Shift Softplus layer

class ssp_layer(nn.Module):
    def __call__(self, x):
        return jnp.log(0.5 * jnp.exp(x) + 0.5)

In [205]:
class filter_generator(nn.Module):
    atom_embeddings_dim: int
    rbf_min: float
    rbf_max: float
    n_rbf: int
    activation: nn.Module = ssp_layer

    def setup(self):
        self.rbf = RadialBasisFunctions(
            self.rbf_min, 
            self.rbf_max, 
            self.n_rbf
            )
        
        self.w_layers = nn.Sequential([
            nn.Dense(self.n_rbf, self.atom_embeddings_dim),
            self.activation(),
            nn.Dense(self.atom_embeddings_dim, self.atom_embeddings_dim),
            self.activation()
            ])

    def __call__(self, d_ij):
        rbfs = self.rbf(d_ij)
        Wij = self.w_layers(rbfs)
        return Wij


In [206]:
class CfConv(nn.Module):
    
    def __call__(self, X, fij):
        ones = jnp.ones(shape=fij.shape)
        jax.lax.fori_loop(0, len(ones), lambda i, ones_i: ones_i.at[i,i].set(jnp.zeros(shape=ones_i[i,i].shape)), ones)
        #filters = jnp.sum(fij, axis = 1)
        X_j = X[None, :, :]
        Xij = X_j * fij
        #jnp.sum(Xij, axis = 0, where = ones)
        return X + jnp.sum(Xij, axis=1, where = ones)

In [207]:
class InteractionBlock(nn.Module):
    atom_embeddings_dim: int
    rbf_min: float
    rbf_max: float
    n_rbf: int
    activation: nn.Module = ssp_layer

    def setup(self):
        self.in_atom_wise = nn.Dense(
            self.atom_embeddings_dim,
            self.atom_embeddings_dim
            )
        
        self.out_atom_wise = nn.Sequential([
            nn.Dense(self.atom_embeddings_dim, self.atom_embeddings_dim),
            self.activation(),
            nn.Dense(self.atom_embeddings_dim, self.atom_embeddings_dim)
            ])

        self.filters = filter_generator(
            self.atom_embeddings_dim,
            self.rbf_min,
            self.rbf_max,
            self.n_rbf
            )

        self.cf_conv = CfConv()

        self.distances = R_distances()

    def __call__(self, X, R):
        X_in = self.in_atom_wise(X)
        R_distances = self.distances(R)
        fils = self.filters(R_distances)
        X_conv = self.cf_conv(X_in, fils)
        V = self.out_atom_wise(X_conv)
        return X + V, R

In [208]:
class SchNet(nn.Module):
    atom_embedding_dim: int = 64
    atom_wise_out: int = 32
    n_interactions: int = 3
    n_atom_types: int = 2
    rbf_min: float = 0.
    rbf_max: float = 30.
    n_rbf: int = 300
    activation: nn.Module = ssp_layer

    def setup(self):
        self.embedding = AtomEmbedding(
            self.n_atom_types,
            self.atom_embedding_dim
        )

        self.interaction = InteractionBlock(self.atom_embedding_dim, self.rbf_min, self.rbf_max, self.n_rbf, self.activation)

        self.interactions = nn.Sequential([
            InteractionBlock(self.atom_embedding_dim, self.rbf_min, self.rbf_max, self.n_rbf, self.activation)
            for _ in range(self.n_interactions)
        ])

        self.output_layers = nn.Sequential([
            nn.Dense(self.atom_embedding_dim, self.atom_wise_out),
            self.activation(),
            nn.Dense(self.atom_wise_out, 1)
        ])

    @partial(jax.vmap, in_axes = (None, 0, 0), out_axes = 0)
    def __call__(self, Z, R):
        X = self.embedding(Z)
        X_interacted = X
        X_interacted, R = self.interactions(X = X_interacted, R = R)
        atom_outputs = self.output_layers(X_interacted)
        predicted_energies = jnp.sum(atom_outputs)
        return predicted_energies

In [209]:
sch = SchNet()
Z = batch["atom_types"]
R = batch["positions"]
key1, key2 = random.split(random.key(0), 2)
params_sch = sch.init(key2, Z, R)
energies = sch.apply(params_sch, Z, R)
print(energies.shape)


(5,)


In [210]:
########################### Training ###########################

In [211]:
@partial(jax.jit, static_argnums=(1,))
def mse(params, model, Z, R, E):
    E_pred = model.apply(params, Z, R)
    mse_v = jnp.inner(E - E_pred, E - E_pred) / 2.0
    return jnp.mean(mse_v)

In [212]:
loss_grad_fn = jax.value_and_grad(mse)

In [213]:
@jax.jit
def update_params(params, learning_rate, grads):
    params = jax.tree_util.tree_map(
        lambda p, g: p - learning_rate * g, params, grads
    )
    return params

In [218]:
# Training shit

learning_rate = 0.005
momentum = 0.005
batchsize = 1_000
epochs = 1_000
freq = 1

model = SchNet(atom_embedding_dim=16)

water_dataset = WaterDataset("data_water.npz")

train_dataset, test_dataset = torch.utils.data.random_split(water_dataset, [0.8, 0.2])

train_loader = jdl.DataLoader(train_dataset, backend='pytorch', batch_size=batchsize, shuffle=True, drop_last=False)
test_loader = jdl.DataLoader(test_dataset, backend="pytorch", batch_size=batchsize, drop_last=False)

In [221]:
batch = next(iter(test_loader))
Z = batch["atom_types"]
R = batch["positions"]
key1, key2 = random.split(random.key(0), 2)
params = model.init(key2, Z, R)
total_loss = 0.
test_loss = 0.
learning_rate=1E-5
for epoch in range(epochs):
    j = 0
    total_loss = 0.
    test_loss = 0.
    for i, batch in enumerate(train_loader):
        Z = batch["atom_types"]
        R = batch["positions"]
        E = batch["energy"]
        loss, grads = loss_grad_fn(params, model, Z, R, E)
        params = update_params(params, learning_rate, grads)
        total_loss += loss
        j += 1
    print(f"Train loss in epoch {epoch} | {total_loss / j}")

    j = 0
    for i, batch in enumerate(test_loader):
        Z = batch["atom_types"]
        R = batch["positions"]
        E = batch["energy"]
        test_loss += mse(params, model, Z, R, E)
        j += 1
    print(f"Test error in epoch {epoch} | {test_loss / j}")

Train loss in epoch 0 | nan
Test error in epoch 0 | nan


KeyboardInterrupt: 

In [None]:
model = SchNet()
Z = batch["atom_types"]
R = batch["positions"]
E = batch["energy"]
key1, key2 = random.split(random.key(0), 2)
params = model.init(key2, Z, R)
mse(params, Z, R, E)

Array(371.93015, dtype=float32)

In [190]:
import optax

def create_train_step(key, Z, R, model, optimizer):
    params = model.init(key, Z, R)
    opt_state = optimizer.init(params)

    @jax.jit
    def loss_fn(params, Z_batched, R_batched, E_batched):
        pred_energies = model.apply(params, Z_batched, R_batched)
        return jnp.mean(jnp.pow(pred_energies - E, 2))
    
    @jax.jit
    def train_step(params, opt_state, Z, R, E):
        loss, grads = jax.value_and_grad(loss_fn)(params, Z, R, E)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        return params, opt_state, loss
    
    @jax.jit
    def test_fn(params, Z, R, E):
        pred_energies = model.apply(params, Z, R)
        return jnp.mean(jnp.pow(pred_energies - E, 2))

    return train_step, test_fn, params, opt_state

In [None]:
# Training shit

import optax

learning_rate = 0.01
momentum = 0.005
batchsize = 5_000

model = SchNet(atom_embedding_dim=32, n_rbf=30)

water_dataset = WaterDataset("data_water.npz")

train_dataset, test_dataset = torch.utils.data.random_split(water_dataset, [0.8, 0.2])

train_loader = jdl.DataLoader(train_dataset, backend='pytorch', batch_size=batchsize, shuffle=True, drop_last=False)
test_loader = jdl.DataLoader(test_dataset, backend="pytorch", batch_size=batchsize, drop_last=False)

batch = next(iter(water_loader))

Z = batch["atom_types"]
R = batch["positions"]

optimizer = optax.adamw(learning_rate=learning_rate)

train_step, test_fn, params, opt_state = create_train_step(key2, Z, R, model, optimizer)

In [192]:
freq = 10
epochs = 1_000

num_test = len(test_dataset)
for epoch in range(epochs):
  total_loss, total_mse, total_kl = 0.0, 0.0, 0.0
  for i, batch in enumerate(train_loader):

    Z = batch["atom_types"]
    R = batch["positions"]
    E = batch["energy"]

    params, opt_state, loss = train_step(params, opt_state, Z, R, E)

    total_loss += loss

    if i > 0 and not i % freq:
      print(f"epoch {epoch} | step {i} | loss: {total_loss / freq}")
      total_loss = 0.

  test_loss = 0.
  j = 0
  for i, batch in enumerate(test_loader):
    Z = batch["atom_types"]
    R = batch["positions"]
    E = batch["energy"]
    test_loss += test_fn(params, Z, R, E)
    j += 1
  print(f"Test error in epoch {epoch} | loss: {test_loss / j}")

epoch 0 | step 10 | loss: 12.767768859863281
Test error in epoch 0 | loss: 1.0893431901931763
epoch 1 | step 10 | loss: 1.4749388694763184
Test error in epoch 1 | loss: 1.068005084991455
epoch 2 | step 10 | loss: 1.1945760250091553
Test error in epoch 2 | loss: 0.9841251969337463
epoch 3 | step 10 | loss: 1.1398496627807617
Test error in epoch 3 | loss: 0.9777716398239136
epoch 4 | step 10 | loss: 1.1300773620605469
Test error in epoch 4 | loss: 0.9771242737770081
epoch 5 | step 10 | loss: 1.1282262802124023
Test error in epoch 5 | loss: 0.9765509366989136
epoch 6 | step 10 | loss: 1.127841591835022
Test error in epoch 6 | loss: 0.9767917394638062
epoch 7 | step 10 | loss: 1.1277782917022705
Test error in epoch 7 | loss: 0.976589560508728
epoch 8 | step 10 | loss: 1.1277592182159424
Test error in epoch 8 | loss: 0.9766359329223633
epoch 9 | step 10 | loss: 1.1277509927749634
Test error in epoch 9 | loss: 0.9765997529029846
epoch 10 | step 10 | loss: 1.127746343612671
Test error in epoc

KeyboardInterrupt: 

In [None]:
def opt_params(params, loss_fn):
    losses, grads = jax.value_and_grad(loss_fn)(params,)