Skip to content

Commit c8d7cd0

Browse files
committed
Fix second order path sampling shapes
1 parent b322960 commit c8d7cd0

File tree

4 files changed

+44
-27
lines changed

4 files changed

+44
-27
lines changed
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
save_dir: ./out/toy/mueller_second_order_single_gaussian
2+
3+
test_system: mueller_brown
4+
ode: second_order
5+
parameterization: diagonal
6+
T: 20.0
7+
xi: 1.4142135
8+
xi_pos_noise: 1e-3
9+
gamma: 1.0
10+
11+
num_gaussians: 1
12+
trainable_weights: False
13+
base_sigma: 5e-3
14+
15+
epochs: 10000
16+
BS: 512
17+
18+
num_paths: 1000
19+
dt: 0.02
20+
21+
log_plots: True

main.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@
3030
parser.add_argument('--T', type=float, required=True,
3131
help="Transition time in the base unit of the system. For molecular simulations, this is in picoseconds.")
3232
parser.add_argument('--xi', type=float)
33+
parser.add_argument('--xi_pos_noise', type=float, default=1e-4,
34+
help="For second order SDEs we have to add a small noise to the positional xi. This is the value of this noise.")
3335
parser.add_argument('--temperature', type=float,
3436
help="The temperature of the system in Kelvin. Either specify this or xi.")
3537
parser.add_argument('--gamma', type=float, required=True)
@@ -98,7 +100,7 @@ def main():
98100
from model import MLP
99101

100102
model = MLP([128, 128, 128])
101-
setup = qsetup.construct(system, model, args.ode, args.parameterization, xi, args)
103+
setup, A, B = qsetup.construct(system, model, xi, args)
102104

103105
key = jax.random.PRNGKey(args.seed)
104106
key, init_key = jax.random.split(key)
@@ -144,7 +146,7 @@ def main():
144146
print("!!!TODO: how to plot this nicely?")
145147
t = args.T * jnp.linspace(0, 1, args.BS, dtype=jnp.float32).reshape((-1, 1))
146148
key, path_key = jax.random.split(key)
147-
eps = jax.random.normal(path_key, [args.BS, args.num_gaussians, setup.A.shape[-1]])
149+
eps = jax.random.normal(path_key, [args.BS, args.num_gaussians, A.shape[-1]])
148150
mu_t, sigma_t, w_logits = state_q.apply_fn(state_q.params, t)
149151
w = jax.nn.softmax(w_logits)[None, :, None]
150152
samples = (w * (mu_t + sigma_t * eps)).sum(axis=1)
@@ -156,7 +158,7 @@ def main():
156158
# plt.show()
157159

158160
key, init_key = jax.random.split(key)
159-
x_0 = jnp.ones((args.num_paths, setup.A.shape[0]), dtype=jnp.float32) * setup.A
161+
x_0 = jnp.ones((args.num_paths, A.shape[0]), dtype=jnp.float32) * A
160162
eps = jax.random.normal(key, shape=x_0.shape)
161163
x_0 += args.base_sigma * eps
162164

training/qsetup.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import argparse
22
from abc import ABC, abstractmethod
33
from dataclasses import dataclass
4-
from typing import Callable, Optional
4+
from typing import Callable, Optional, Tuple
55
from flax import linen as nn
66
from flax.training.train_state import TrainState
77

@@ -81,32 +81,37 @@ def B(self):
8181
return self.system.B
8282

8383

84-
def construct(system: System, model: nn.module, ode: str, parameterization: str, xi: ArrayLike,
85-
args: argparse.Namespace) -> QSetup:
84+
def construct(system: System, model: nn.module, xi: float, args: argparse.Namespace) -> Tuple[
85+
QSetup, ArrayLike, ArrayLike]:
86+
"""
87+
Construct a QSetup object based on the given arguments.
88+
return: QSetup, A, B
89+
"""
8690
from training.setups import diagonal
8791

88-
if ode == 'first_order':
92+
if args.ode == 'first_order':
8993
order = 'first'
9094
A = system.A
9195
B = system.B
92-
elif ode == 'second_order':
96+
elif args.ode == 'second_order':
9397
order = 'second'
9498

9599
# We pad the A and B matrices with zeros to account for the velocity
96-
A = jnp.hstack([system.A, jnp.zeros_like(system.A)])
97-
B = jnp.hstack([system.B, jnp.zeros_like(system.B)])
100+
A = jnp.hstack([system.A, jnp.zeros_like(system.A)], dtype=jnp.float32)
101+
B = jnp.hstack([system.B, jnp.zeros_like(system.B)], dtype=jnp.float32)
98102

99103
xi_velocity = jnp.ones_like(system.A) * xi
100-
xi_pos = jnp.zeros_like(xi_velocity) + 1e-4
104+
xi_pos = jnp.zeros_like(xi_velocity) + args.xi_pos_noise
101105

102-
xi = jnp.concatenate((xi_pos, xi_velocity), axis=-1)
106+
xi = jnp.concatenate((xi_pos, xi_velocity), axis=-1, dtype=jnp.float32)
107+
print("Setting xi to", xi)
103108
else:
104-
raise ValueError(f"Unknown ODE: {ode}")
109+
raise ValueError(f"Unknown ODE: {args.ode}")
105110

106-
if parameterization == 'diagonal':
111+
if args.parameterization == 'diagonal':
107112
wrapped_module = diagonal.DiagonalWrapper(
108113
model, args.T, A, B, args.num_gaussians, args.trainable_weights, args.base_sigma
109114
)
110-
return diagonal.DiagonalSetup(system, wrapped_module, xi, order, args.T)
115+
return diagonal.DiagonalSetup(system, wrapped_module, xi, order, args.T), A, B
111116
else:
112-
raise ValueError(f"Unknown parameterization: {parameterization}")
117+
raise ValueError(f"Unknown parameterization: {args.parameterization}")

training/setups/drift.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,7 @@ class DriftedSetup(QSetup, ABC):
1212
def __init__(self, system: System, model_q: nn.Module, xi: ArrayLike, order: str):
1313
"""Either instantiate with first or second order drift."""
1414
assert order == 'first' or order == 'second', "Order must be either 'first' or 'second'."
15-
1615
self.order = order
17-
self._A = system.A
18-
self._B = system.B
1916

2017
super().__init__(system, model_q, xi)
2118

@@ -27,11 +24,3 @@ def _drift(self, _x: ArrayLike, gamma: float) -> ArrayLike:
2724
ndim = self.system.A.shape[0]
2825

2926
return jnp.hstack([_x[:, ndim:] / self.system.mass, -self.system.dUdx(_x[:, :ndim]) - _x[:, ndim:] * gamma])
30-
31-
@property
32-
def A(self):
33-
return self._A
34-
35-
@property
36-
def B(self):
37-
return self._B

0 commit comments

Comments
 (0)