
Commit 4e372c9

Properly handle float64
1 parent d24085a commit 4e372c9

4 files changed: +12 −15 lines

main.py
systems.py
training/diagonal.py
training/qsetup.py
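All four diffs below do the same thing: they pin freshly created arrays to jnp.float32. The reason, per the comment added in systems.py, is that importing dmff switches JAX into 64-bit mode, after which array constructors default to float64. A minimal sketch of that default-dtype behavior (illustrative, not part of the commit):

```python
import jax
import jax.numpy as jnp

print(jnp.zeros([2, 1]).dtype)  # float32: JAX's usual default

# This is the switch the `from dmff import ...` line flips at import time.
jax.config.update("jax_enable_x64", True)

print(jnp.zeros([2, 1]).dtype)                     # now float64
print(jnp.zeros([2, 1], dtype=jnp.float32).dtype)  # float32, pinned explicitly
```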

main.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -90,7 +90,7 @@ def main():
 
     key = jax.random.PRNGKey(args.seed)
     key, init_key = jax.random.split(key)
-    params_q = setup.model_q.init(init_key, jnp.zeros([args.BS, 1]))
+    params_q = setup.model_q.init(init_key, jnp.zeros([args.BS, 1], dtype=jnp.float32))
 
     optimizer_q = optax.adam(learning_rate=args.lr)
     state_q = train_state.TrainState.create(apply_fn=setup.model_q.apply, params=params_q, tx=optimizer_q)
@@ -104,7 +104,7 @@ def main():
     show_or_save_fig(args.save_dir, 'loss_plot.pdf')
 
     # TODO: how to plot this nicely?
-    t = args.T * jnp.linspace(0, 1, args.BS).reshape((-1, 1))
+    t = args.T * jnp.linspace(0, 1, args.BS, dtype=jnp.float32).reshape((-1, 1))
     key, path_key = jax.random.split(key)
     eps = jax.random.normal(path_key, [args.BS, args.num_gaussians, setup.A.shape[-1]])
     mu_t, sigma_t, w_logits = state_q.apply_fn(state_q.params, t)
@@ -118,7 +118,7 @@ def main():
     # plt.show()
 
     key, init_key = jax.random.split(key)
-    x_0 = jnp.ones((args.num_paths, setup.A.shape[0])) * setup.A
+    x_0 = jnp.ones((args.num_paths, setup.A.shape[0]), dtype=jnp.float32) * setup.A
     eps = jax.random.normal(key, shape=x_0.shape)
     x_0 += args.base_sigma * eps
 
```
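One subtlety worth noting: multiplying by args.T (a Python float) does not undo the fix, because Python scalars are weakly typed in JAX and adopt the array's dtype. A small sketch with hypothetical stand-ins for args.T and args.BS:

```python
import jax.numpy as jnp

T, BS = 10.0, 8  # hypothetical stand-ins for args.T and args.BS

# Without dtype=..., linspace would return float64 once 64-bit mode is on.
t = T * jnp.linspace(0, 1, BS, dtype=jnp.float32).reshape((-1, 1))
assert t.dtype == jnp.float32  # the weak Python scalar T does not upcast
```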

systems.py

Lines changed: 4 additions & 7 deletions
```diff
@@ -10,7 +10,7 @@
 from typing import Self
 from utils.pdb import assert_same_molecule
 from utils.rmsd import kabsch_align
-
+from dmff import Hamiltonian, NeighborList  # This sets jax to use 64-bit precision
 
 class System:
     def __init__(self, U: Callable[[ArrayLike], ArrayLike], A: ArrayLike, B: ArrayLike, mass: ArrayLike, plot):
@@ -49,17 +49,14 @@ def from_name(cls, name: str) -> Self:
 
     @classmethod
     def from_pdb(cls, A: str, B: str, forcefield: [str], cv: Optional[str]) -> Self:
-        print("WARNING!!!! This changes jax to double precision")
-        from dmff import Hamiltonian, NeighborList
-
         A_pdb, B_pdb = app.PDBFile(A), app.PDBFile(B)
         assert_same_molecule(A_pdb, B_pdb)
 
         mass = [a.element.mass.value_in_unit(unit.dalton) for a in A_pdb.topology.atoms()]
-        mass = jnp.broadcast_to(jnp.array(mass).reshape(-1, 1), (len(mass), 3)).reshape(-1)
+        mass = jnp.broadcast_to(jnp.array(mass, dtype=jnp.float32).reshape(-1, 1), (len(mass), 3)).reshape(-1)
 
-        A = jnp.array(A_pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer))
-        B = jnp.array(B_pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer))
+        A = jnp.array(A_pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer), dtype=jnp.float32)
+        B = jnp.array(B_pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer), dtype=jnp.float32)
         A, B = kabsch_align(A, B)
         A, B = A.reshape(-1), B.reshape(-1)
 
```
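This file also moves the dmff import from inside from_pdb to module scope, so the precision switch happens once at import time rather than mid-call, and the shouting print warning becomes a comment. A hypothetical sanity check (not in the repo) for the resulting mixed-precision setup:

```python
import jax
import jax.numpy as jnp
from dmff import Hamiltonian, NeighborList  # noqa: F401 -- side effect: enables 64-bit mode

assert jax.config.jax_enable_x64                                 # dmff switched JAX to double precision
assert jnp.array([1.0]).dtype == jnp.float64                     # new default dtype
assert jnp.array([1.0], dtype=jnp.float32).dtype == jnp.float32  # explicit pin still wins
```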

training/diagonal.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -3,7 +3,7 @@
 from jax.typing import ArrayLike
 from flax import linen as nn
 import jax.numpy as jnp
-from typing import Union, Dict, Any, Callable, Tuple
+from typing import Union, Dict, Any, Callable, Tuple, Optional
 from flax.training.train_state import TrainState
 import jax
 from flax.typing import FrozenVariableDict
@@ -60,8 +60,8 @@ def loss_fn(params_q: Union[FrozenVariableDict, Dict[str, Any]], key: ArrayLike)
             ndim = self.model_q.A.shape[-1]
             key = jax.random.split(key)
 
-            t = self.T * jax.random.uniform(key[0], [BS, 1])
-            eps = jax.random.normal(key[1], [BS, 1, ndim])
+            t = self.T * jax.random.uniform(key[0], [BS, 1], dtype=jnp.float32)
+            eps = jax.random.normal(key[1], [BS, 1, ndim], dtype=jnp.float32)
 
             def v_t(_eps, _t):
                 """This function is equal to v_t * xi ** 2."""
```

training/qsetup.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -35,10 +35,10 @@ def sample_paths(self, state_q: TrainState, x_0: ArrayLike, dt: float, T: float,
 
         num_paths = x_0.shape[0]
         ndim = x_0.shape[1]
-        x_t = jnp.ones((num_paths, N, ndim))
+        x_t = jnp.zeros((num_paths, N, ndim), dtype=jnp.float32)
         x_t = x_t.at[:, 0, :].set(x_0)
 
-        t = jnp.zeros((BS, 1))
+        t = jnp.zeros((BS, 1), dtype=jnp.float32)
         if key is None:
             u = jax.jit(lambda _t, _x: self.u_t(state_q, _t, _x, 0, *args, **kwargs))
         else:
```
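Note that the first hunk changes the fill value as well as the dtype: the trajectory buffer is now preallocated with jnp.zeros instead of jnp.ones, which is harmless as long as every slot is overwritten during sampling, with the first frame written functionally via .at[...].set(...). A toy version with hypothetical shapes:

```python
import jax.numpy as jnp

num_paths, N, ndim = 2, 5, 3  # hypothetical shapes
x_0 = jnp.zeros((num_paths, ndim), dtype=jnp.float32)

# Preallocate the trajectory buffer, then write the first frame.
# JAX arrays are immutable, so .at[...].set(...) returns an updated copy.
x_t = jnp.zeros((num_paths, N, ndim), dtype=jnp.float32)
x_t = x_t.at[:, 0, :].set(x_0)
assert x_t.dtype == jnp.float32 and x_t.shape == (num_paths, N, ndim)
```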
