In [2]:
import jax
import flax
import jax.numpy as jnp
from flax import linen as nn
from flax import nnx
from typing import Dict

#from schnetpack.nn.scatter import scatter_add

In [2]:
def radius_graph(x, r=1):
    distances = jnp.array([[jnp.linalg.norm(x[i] - x[j]) for i in range(len(x))] for j in range(len(x))])
    indices = jnp.array([[[i,j] for i in range(len(x))] for j in range(len(x))])
    mask = (distances<r)
    row, col = indices[mask][:,0], indices[mask][:,1]
    return jnp.array(row), jnp.array(col)

In [None]:
class AtomEmbedding(nnx.Module):
    def __init__(self, num_atom_types, embedding_dim, rngs: nnx.Rngs):
        super(AtomEmbedding, self).__init__()
        self.embedding = nnx.Embed(
            num_embeddings=num_atom_types,
            features=embedding_dim,
            rngs=rngs
        )

    def __call__(self, atom_types):
        return self.embedding(atom_types)

In [None]:
class InteractionBlock(nnx.Module):
    def __init__(self, embedding_dim, hidden_dim, num_filters, rngs: nnx.Rngs):
        super(InteractionBlock, self).__init__()
        self.hidden_dim = hidden_dim
        self.embedding_dim = embedding_dim

        self.rbf = nnx.Linear(1, num_filters, rngs=rngs)

        self.mlp_1 = nnx.Linear(num_filters, hidden_dim, rngs=rngs)

        self.mlp_2 = nnx.Linear(hidden_dim, embedding_dim, rngs=rngs)

        self.update_mlp_1 = nnx.Linear(embedding_dim, hidden_dim, rngs=rngs)

        self.update_mlp_2 = nnx.Linear(hidden_dim, embedding_dim, rngs=rngs)

    def __call__(self, x, edge_index, edge_distance):
        rbf = self.rbf(jnp.expand_dims(edge_distance, 1))
        filters = self.mlp_2(nnx.relu(self.mlp_1(rbf)))
        messages = x[edge_index[0]] * filters
        out = jax.lax.scatter_add(messages, edge_index[1], dimension_numbers=0)
        out = self.update_mlp_2(nnx.relu(self.update_mlp_1(out)))
        return x + out

In [None]:
class SchNet_JAX(nnx.Module):
    def __init__(
            self,
            num_atom_types,
            rngs: nnx.Rngs,
            embedding_dim=128,
            hidden_dim=128,
            num_filters=64,
            num_interactions=3
    ):
        super(SchNet_JAX, self).__init__()
        self.hidden_dim = hidden_dim
        self.rngs = rngs
        self.embedding_dim = embedding_dim
        self.atom_embedding = AtomEmbedding(num_atom_types + 1, embedding_dim, rngs)
        self.interactions = [
            InteractionBlock(
                embedding_dim,
                hidden_dim,
                num_filters,
                rngs
            )
            for _ in range(num_interactions)
        ]

    def __call__(self, data):
        atom_types = data[0]
        pos = data[1]
        energy = data[2]
        x = self.atom_embedding(atom_types)
        row, col = radius_graph(pos, r=5.0)
        edge_distance = jnp.linalg.norm((pos[row] - pos[col]))**2
        print(edge_distance)
        print((row, col))
        for interaction in self.interactions:
            x = interaction(x, (row, col), edge_distance)

        energy = nnx.Linear(self.embedding_dim, self.hidden_dim, self.rngs)(x)
        energy = nnx.relu(energy)
        energy = nnx.Linear(self.hidden_dim, 1, rngs=self.rngs)(energy)
        energy = jnp.array([jnp.sum(energy[i*3:(i+1)*3]) for i in range(3)])
        return energy


In [388]:
xd = jnp.array([0,1,2,3,4])

In [400]:
batch[1].shape

(10, 3, 3)

In [402]:
jnp.linalg.norm(batch[1][xd] - batch[1][xd], axis=0)

Array([[0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.]], dtype=float32)

In [389]:
model = SchNet_JAX(8, rngs=nnx.Rngs(0))

In [390]:
model(batch)

0.0
(Array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1,
       2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3,
       4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5,
       6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7,
       8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32), Array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4,
       4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6,
       6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8,
       8, 8, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9], dtype=int32))


ValueError: axis 1 is out of bounds for array of dimension 1

In [255]:
coso

Array([3, 6, 9], dtype=int32)

In [None]:
import jax_dataloader as jdl

jdl.manual_seed(1234)

# Example: Create a single water molecule
def create_water_molecule():
    # Atom types: Oxygen (0), Hydrogen (1)
    atom_types = jnp.array([0, 1, 1], dtype=jnp.float64)
    # Positions in angstroms
    pos = jnp.array([
        [0.0, 0.0, 0.0],        # Oxygen
        [0.96, 0.0, 0.0],       # Hydrogen 1
        [-0.24, 0.93, 0.0]      # Hydrogen 2
    ], dtype=jnp.float64)
    # Assume total energy is -76.0 eV (for example)
    energy = jnp.array([-76.0], dtype=jnp.float64)
    return atom_types, pos, energy

def create_water_dataset(n_data=100):
    atom_types_list = []
    pos_list = []
    energy_list = []
    for _ in range(n_data):
        atom_types, pos, energy = create_water_molecule()
        atom_types_list.append(atom_types)
        pos_list.append(pos)
        energy_list.append(energy)
    #array_dataset = jdl.ArrayDataset(jnp.array(atom_types_list), jnp.array(pos), jnp.array(energy))
    #return array_dataset
    # Must be: [num_atoms,]
    atom_types = jnp.array(atom_types_list)
    atom_types = atom_types.reshape(atom_types.shape[0]*atom_types.shape[1])

    # Must be: [num_atoms, 3]
    pos = jnp.array(pos_list)
    pos = pos.reshape()
    energy = jnp.array(pos_list)
    return jdl.ArrayDataset(atom_types, pos, energy)


# Create a dataset with multiple identical water molecules
dataset = create_water_dataset()
dataloader = jdl.DataLoader(dataset, 'jax', batch_size=10, shuffle=True)
#loader = jdl.DataLoader(dataset=dataset, batch_size=10, backend='jax', shuffle=True)


  atom_types = jnp.array([0, 1, 1], dtype=jnp.float64)
  pos = jnp.array([
  energy = jnp.array([-76.0], dtype=jnp.float64)


AssertionError: All arrays must have the same dimension.

In [242]:
class Block(nnx.Module):


  def __init__(self, rngs):
    self.linear = nnx.Linear(5, 10, rngs=rngs)
    self.bn = nnx.BatchNorm(10, rngs=rngs)
    self.dropout = nnx.Dropout(0.1, rngs=rngs)

  def __call__(self, x):
    return nnx.relu(self.dropout(self.bn(self.linear(x))))

x = jnp.ones((1, 5))
model = Block(nnx.Rngs(0))



y = model(x)

In [248]:
y[0]

Array([0.0000000e+00, 4.1395313e-05, 0.0000000e+00, 0.0000000e+00,
       0.0000000e+00, 0.0000000e+00, 1.0471470e-05, 1.0471470e-05,
       0.0000000e+00, 0.0000000e+00], dtype=float32)

In [191]:
import numpy as np

In [None]:
import numpy as np

In [215]:
water_data = np.load("data_water.npz")

In [227]:
water_data["z"]

array([8, 1, 1])

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import radius_graph
from torch_scatter import scatter_add
from torch_geometric.data import Data#, DataLoader
from torch_geometric.loader import DataLoader
import jax_dataloader as jdl


# Example: Create a single water molecule
def create_water_molecule_pytorch():
    # Atom types: Oxygen (0), Hydrogen (1)
    atom_types = torch.tensor([0, 1, 1], dtype=torch.long)
    # Positions in angstroms
    pos = torch.tensor([
        [0.0, 0.0, 0.0],        # Oxygen
        [0.96, 0.0, 0.0],       # Hydrogen 1
        [-0.24, 0.93, 0.0]      # Hydrogen 2
    ], dtype=torch.float)
    # Assume total energy is -76.0 eV (for example)
    energy = torch.tensor([-76.0], dtype=torch.float)
    data = Data(x=atom_types, pos=pos, y=energy)
    return data

# Create a dataset with multiple identical water molecules
dataset = [create_water_molecule_pytorch() for _ in range(100)]
loader = DataLoader(dataset, batch_size=10, shuffle=True)


In [5]:
dataset = torch.utils.data.TensorDataset(dataset)

AttributeError: 'list' object has no attribute 'size'

In [4]:
loader = jdl.DataLoader(dataset, "pytorch", batch_size=10, shuffle=True)

BeartypeCallHintParamViolation: Method jax_dataloader.loaders.torch.DataLoaderPytorch.__init__() parameter dataset=[Data(x=[3], y=[1], pos=[3, 3]), Data(x=[3], y=[1], pos=[3, 3]), Data(x=[3], y=[1], pos=[3...])] violates type hint typing.Union[jax_dataloader.datasets.Dataset, torch.utils.data.dataset.Dataset, typing.Annotated[NoneType, beartype.vale.Is[lambda _: hf_datasets is not None]]], as list [Data(x=[3], y=[1], pos=[3, 3]), Data(x=[3], y=[1], pos=[3, 3]), Data(x=[3], y=[1], pos=[3...])]:
* Not <class "jax_dataloader.datasets.Dataset">.
* Not instance of <class "torch.utils.data.dataset.Dataset">.
* Not instance of <class "builtins.NoneType">.

In [403]:
def create_water_dataset_from_npz(data_name: str):
    water_data = np.load(data_name)
    energy = water_data["E"]
    n_data = len(energy)
    atom_types = jnp.array([water_data["z"] for _ in range(n_data)])
    positions = water_data["R"]
    return jdl.ArrayDataset(atom_types, positions, energy)


# Create a dataset with multiple identical water molecules
dataset = create_water_dataset_from_npz("data_water.npz")
dataloader = jdl.DataLoader(dataset, 'jax', batch_size=10, shuffle=True)
#loader = jdl.DataLoader(dataset=dataset, batch_size=10, backend='jax', shuffle=True)


In [404]:
batch = next(iter(dataloader))

In [408]:
batch[0].shape

(10, 3)

In [409]:
batch[1].shape

(10, 3, 3)

In [219]:
water_data["R"].shape

(99999, 3, 3)

In [217]:
water_data.files

['type',
 'R',
 'R_units',
 'z',
 'E',
 'E_units',
 'F',
 'F_units',
 'Q',
 'spin_state',
 'name',
 'README',
 'theory',
 'r_unit',
 'e_unit',
 'D',
 'D_units',
 'md5']

In [214]:
f["z"]

array([8, 1, 1])

In [194]:
xd = np.load("model_water.npz")

In [201]:
xd["z"]

array([8, 1, 1])

In [195]:
xd.files

['type',
 'code_version',
 'dataset_name',
 'dataset_theory',
 'solver_name',
 'z',
 'idxs_train',
 'md5_train',
 'idxs_valid',
 'md5_valid',
 'n_test',
 'md5_test',
 'f_err',
 'R_desc',
 'R_d_desc_alpha',
 'c',
 'std',
 'sig',
 'lam',
 'alphas_F',
 'perms',
 'tril_perms_lin',
 'use_E',
 'e_err',
 'r_unit',
 'e_unit']

In [182]:
batch = next(iter(dataloader))

In [190]:
batch[1].shape

(10, 3, 3)

In [186]:
batch[0].reshape(30,)

array([0., 1., 1., 0., 1., 1., 0., 1., 1., 0., 1., 1., 0., 1., 1., 0., 1.,
       1., 0., 1., 1., 0., 1., 1., 0., 1., 1., 0., 1., 1.], dtype=float32)

In [177]:
i=0
for batch in dataloader:
    atoms = batch[0]
    pos = batch[1]
    break

In [179]:
atoms.shape

(10, 3)

In [175]:
pos[0]

array([[ 0.  ,  0.  ,  0.  ],
       [ 0.96,  0.  ,  0.  ],
       [-0.24,  0.93,  0.  ]], dtype=float32)

In [170]:
def radius_graph(x, r=1):
    distances = jnp.array([[jnp.linalg.norm(x[i] - x[j]) for i in range(len(x))] for j in range(len(x))])
    indices = jnp.array([[[i,j] for i in range(len(x))] for j in range(len(x))])
    mask = (distances<r)
    row, col = indices[mask][:,0], indices[mask][:,1]
    return row, col

In [171]:
num_filters = 64
hidden_dim = 128
embedding_dim = 128
num_atom_types = 2

# Create a radius graph (edges based on distance)
edge_index = radius_graph(pos, r=5.0)  # Adjust radius as needed

In [173]:
print(pos.shape)

(10, 3, 3)


In [161]:
jnp.linalg.norm(pos[0] - pos[1])

Array(0., dtype=float32)

In [156]:
jnp.dot(pos[0] - pos[1],pos[0] - pos[1])

Array([[0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.]], dtype=float32)

In [None]:
num_filters = 64
hidden_dim = 128
embedding_dim = 128
num_atom_types = 2

# Create a radius graph (edges based on distance)
edge_index = radius_graph(pos, r=5.0, loop=False)  # Adjust radius as needed
atom_embedding = AtomEmbedding(num_atom_types, embedding_dim)
# Compute distances for edges
edge_distance = (pos[edge_index[0]] - pos[edge_index[1]]).norm(p=2, dim=1)
rbf = nn.Linear(1, num_filters)(edge_distance.unsqueeze(1))
mlp = nn.Sequential(
            nn.Linear(num_filters, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, embedding_dim)
        )
update_mlp = nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, embedding_dim)
        )
x = atom_embedding(atom_types)
filters = mlp(rbf)
messages = x[edge_index[0]] * filters
print(messages.shape)
print(edge_index.shape)
out = scatter_add(messages, edge_index[1], dim=0)
print(out.shape)
out = update_mlp(out)
output = x + out

array([[[ 0.  ,  0.  ,  0.  ],
        [ 0.96,  0.  ,  0.  ],
        [-0.24,  0.93,  0.  ]],

       [[ 0.  ,  0.  ,  0.  ],
        [ 0.96,  0.  ,  0.  ],
        [-0.24,  0.93,  0.  ]],

       [[ 0.  ,  0.  ,  0.  ],
        [ 0.96,  0.  ,  0.  ],
        [-0.24,  0.93,  0.  ]],

       [[ 0.  ,  0.  ,  0.  ],
        [ 0.96,  0.  ,  0.  ],
        [-0.24,  0.93,  0.  ]],

       [[ 0.  ,  0.  ,  0.  ],
        [ 0.96,  0.  ,  0.  ],
        [-0.24,  0.93,  0.  ]],

       [[ 0.  ,  0.  ,  0.  ],
        [ 0.96,  0.  ,  0.  ],
        [-0.24,  0.93,  0.  ]],

       [[ 0.  ,  0.  ,  0.  ],
        [ 0.96,  0.  ,  0.  ],
        [-0.24,  0.93,  0.  ]],

       [[ 0.  ,  0.  ,  0.  ],
        [ 0.96,  0.  ,  0.  ],
        [-0.24,  0.93,  0.  ]],

       [[ 0.  ,  0.  ,  0.  ],
        [ 0.96,  0.  ,  0.  ],
        [-0.24,  0.93,  0.  ]],

       [[ 0.  ,  0.  ,  0.  ],
        [ 0.96,  0.  ,  0.  ],
        [-0.24,  0.93,  0.  ]]], dtype=float32)

In [None]:
edge_index = radius_graph(pos, r=5.0, loop=False)
atom_types = batch.x
edge_distances = (pos[edge_index[0]] - pos[edge_index[1]]).norm(p=2, dim=1)
print(edge_distances)
print(edge_index)
print(pos[18], pos[0])

In [47]:
jnp.array(dataset[0]).shape[0]

100

In [48]:
jnp.array(dataset[1]).shape[0]

100

In [36]:
jnp.array(dataset[2]).shape

(100, 1)

In [10]:
X = jnp.arange(100).reshape(10, 10)
y = jnp.arange(10)
# Create an `ArrayDataset`
arr_ds = jdl.ArrayDataset(X, y)

In [44]:
X[0].shape

(10,)

()

In [11]:
z = jnp.arange(10)
arr_dss = jdl.ArrayDataset(X, y, z)

In [19]:
dataloader = jdl.DataLoader(arr_dss, 'jax', batch_size=5, shuffle=True)


In [20]:
batch = next(iter(dataloader))

In [12]:
new_array = jnp.expand_dims(array, 2)
print(new_array)
print(new_array.shape)

[[[1]
  [2]]

 [[3]
  [4]]]
(2, 2, 1)
