
Commit 86806d5

Fix xi for second order

1 parent 4e372c9 commit 86806d5
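
Summary: the noise scale xi becomes a field on QSetup and is passed once through qsetup.construct. SecondOrderSetup now builds its padded position/velocity xi in __init__, replacing the removed _xi_to_second_order helper and its construct_loss/u_t overrides, and u_t/sample_paths take a deterministic flag (derived from whether a PRNG key is supplied) instead of an xi argument.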

File tree

main.py
training/diagonal.py
training/qsetup.py

3 files changed: +35 -44 lines changed

main.py

Lines changed: 9 additions & 6 deletions
@@ -56,7 +56,8 @@
 
 def main():
     # TODO: force clipping
-    # TODO: temperature
+    print("!!!!Next todos: plot ALDP")
+
     args = parse_args(parser)
     assert args.test_system or args.start and args.target, "Either specify a test system or provide start and target structures"
     assert not (
@@ -86,24 +87,26 @@ def main():
     from model import MLP
 
     model = MLP([128, 128, 128])
-    setup = qsetup.construct(system, model, args.ode, args.parameterization, args)
+    setup = qsetup.construct(system, model, args.ode, args.parameterization, xi, args)
 
     key = jax.random.PRNGKey(args.seed)
     key, init_key = jax.random.split(key)
     params_q = setup.model_q.init(init_key, jnp.zeros([args.BS, 1], dtype=jnp.float32))
 
     optimizer_q = optax.adam(learning_rate=args.lr)
     state_q = train_state.TrainState.create(apply_fn=setup.model_q.apply, params=params_q, tx=optimizer_q)
-    loss_fn = setup.construct_loss(state_q, xi, args.gamma, args.BS)
+    loss_fn = setup.construct_loss(state_q, args.gamma, args.BS)
 
     key, train_key = jax.random.split(key)
     state_q, loss_plot = train(state_q, loss_fn, args.epochs, train_key)
     print("Number of potential evaluations", args.BS * args.epochs)
 
+    if jnp.isnan(jnp.array(loss_plot)).any():
+        print("Warning: Loss contains NaNs")
     plt.plot(loss_plot)
     show_or_save_fig(args.save_dir, 'loss_plot.pdf')
 
-    # TODO: how to plot this nicely?
+    print("!!!TODO: how to plot this nicely?")
     t = args.T * jnp.linspace(0, 1, args.BS, dtype=jnp.float32).reshape((-1, 1))
     key, path_key = jax.random.split(key)
     eps = jax.random.normal(path_key, [args.BS, args.num_gaussians, setup.A.shape[-1]])
@@ -122,15 +125,15 @@ def main():
     eps = jax.random.normal(key, shape=x_0.shape)
     x_0 += args.base_sigma * eps
 
-    x_t_det = setup.sample_paths(state_q, x_0, args.dt, args.T, args.BS, None, None)
+    x_t_det = setup.sample_paths(state_q, x_0, args.dt, args.T, args.BS, None)
 
     if system.plot:
         # In case we have a second order integration scheme, we remove the velocity for plotting
         system.plot(title='Deterministic Paths', trajectories=x_t_det[:, :, :system.A.shape[0]])
         show_or_save_fig(args.save_dir, 'paths_deterministic.pdf')
 
     key, path_key = jax.random.split(key)
-    x_t_stoch = setup.sample_paths(state_q, x_0, args.dt, args.T, args.BS, xi, path_key)
+    x_t_stoch = setup.sample_paths(state_q, x_0, args.dt, args.T, args.BS, path_key)
 
     if system.plot:
         system.plot(title='Stochastic Paths', trajectories=x_t_stoch[:, :, :system.A.shape[0]])
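
Taken together, the main.py changes mean xi is supplied once at construction and the sampling mode is selected purely by the presence of a PRNG key. A runnable toy sketch of that pattern (ToySetup and its drift are hypothetical stand-ins, not the repo's classes):

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp


# Hypothetical stand-in for QSetup: xi is fixed at construction time,
# and passing key=None selects deterministic sampling.
@dataclass
class ToySetup:
    xi: float  # noise scale, stored once instead of threaded per call

    def sample(self, x_0, dt, key=None):
        drift = -x_0  # stand-in for the learned drift u_t
        if key is None:  # deterministic: replaces the old xi=None/0 sentinel
            return x_0 + dt * drift
        noise = self.xi * jax.random.normal(key, x_0.shape)
        return x_0 + dt * drift + jnp.sqrt(dt) * noise  # Euler-Maruyama step


setup = ToySetup(xi=0.5)
x_0 = jnp.ones((4, 2))
x_det = setup.sample(x_0, dt=1e-2)                             # deterministic
x_sto = setup.sample(x_0, dt=1e-2, key=jax.random.PRNGKey(0))  # stochastic
```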

training/diagonal.py

Lines changed: 15 additions & 26 deletions
@@ -53,7 +53,7 @@ class DiagonalSetup(QSetup, ABC):
     def _drift(self, _x: ArrayLike, gamma: float) -> ArrayLike:
         raise NotImplementedError
 
-    def construct_loss(self, state_q: TrainState, xi: ArrayLike, gamma: float, BS: int) -> Callable[
+    def construct_loss(self, state_q: TrainState, gamma: float, BS: int) -> Callable[
         [Union[FrozenVariableDict, Dict[str, Any]], ArrayLike], ArrayLike]:
 
         def loss_fn(params_q: Union[FrozenVariableDict, Dict[str, Any]], key: ArrayLike) -> ArrayLike:
@@ -80,14 +80,14 @@ def v_t(_eps, _t):
                 log_q_t = -(relative_mixture_weights / (_sigma_t ** 2) * (_x - _mu_t)).sum(axis=1)
                 u_t = (relative_mixture_weights * (1 / _sigma_t * _dsigmadt * (_x - _mu_t) + _dmudt)).sum(axis=1)
 
-                return u_t - self._drift(_x.reshape(BS, ndim), gamma) + 0.5 * (xi ** 2) * log_q_t
+                return u_t - self._drift(_x.reshape(BS, ndim), gamma) + 0.5 * (self.xi ** 2) * log_q_t
 
-            loss = 0.5 * ((v_t(eps, t) / xi) ** 2).sum(-1, keepdims=True)
+            loss = 0.5 * ((v_t(eps, t) / self.xi) ** 2).sum(-1, keepdims=True)
             return loss.mean()
 
         return loss_fn
 
-    def u_t(self, state_q: TrainState, t: ArrayLike, x_t: ArrayLike, xi: ArrayLike, *args, **kwargs) -> ArrayLike:
+    def u_t(self, state_q: TrainState, t: ArrayLike, x_t: ArrayLike, deterministic: bool, *args, **kwargs) -> ArrayLike:
         _mu_t, _sigma_t, _w_logits, _dmudt, _dsigmadt = forward_and_derivatives(state_q, t)
         _x = x_t[:, None, :]
 
@@ -96,56 +96,45 @@ def u_t(self, state_q: TrainState, t: ArrayLike, x_t: ArrayLike, xi: ArrayLike,
 
         _u_t = (relative_mixture_weights * (1 / _sigma_t * _dsigmadt * (_x - _mu_t) + _dmudt)).sum(axis=1)
 
-        if xi == 0:
+        if deterministic:
             return _u_t
 
         log_q_t = -(relative_mixture_weights / (_sigma_t ** 2) * (_x - _mu_t)).sum(axis=1)
 
-        return _u_t + 0.5 * (xi ** 2) * log_q_t
+        return _u_t + 0.5 * (self.xi ** 2) * log_q_t
 
 
 class FirstOrderSetup(DiagonalSetup):
-    def __init__(self, system: System, model: nn.module, T: float, base_sigma: float, num_mixtures: int,
+    def __init__(self, system: System, model: nn.module, xi: ArrayLike, T: float, base_sigma: float, num_mixtures: int,
                  trainable_weights: bool):
         model_q = DiagonalWrapper(model, T, system.A, system.B, num_mixtures, trainable_weights, base_sigma)
-        super().__init__(system, model_q, T, base_sigma, num_mixtures)
+        super().__init__(system, model_q, xi, T, base_sigma, num_mixtures)
 
     def _drift(self, _x: ArrayLike, gamma: float) -> ArrayLike:
         return -self.system.dUdx(_x / (gamma * self.system.mass))
 
 
 class SecondOrderSetup(DiagonalSetup):
-    def __init__(self, system: System, model: nn.module, T: float, base_sigma: float, num_mixtures: int,
+    def __init__(self, system: System, model: nn.module, xi: ArrayLike, T: float, base_sigma: float, num_mixtures: int,
                  trainable_weights: bool):
         # We pad the A and B matrices with zeros to account for the velocity
         self._A = jnp.hstack([system.A, jnp.zeros_like(system.A)])
         self._B = jnp.hstack([system.B, jnp.zeros_like(system.B)])
 
+        xi_velocity = jnp.ones_like(system.A) * xi
+        xi_pos = jnp.zeros_like(xi_velocity) + 1e-4
+
+        xi_second_order = jnp.concatenate((xi_pos, xi_velocity), axis=-1)
+
         model_q = DiagonalWrapper(model, T, self._A, self._B, num_mixtures, trainable_weights, base_sigma)
-        super().__init__(system, model_q, T, base_sigma, num_mixtures)
+        super().__init__(system, model_q, xi_second_order, T, base_sigma, num_mixtures)
 
     def _drift(self, _x: ArrayLike, gamma: float) -> ArrayLike:
         # number of dimensions without velocity
         ndim = self.system.A.shape[0]
 
         return jnp.hstack([_x[:, ndim:] / self.system.mass, -self.system.dUdx(_x[:, :ndim]) - _x[:, ndim:] * gamma])
 
-    def _xi_to_second_order(self, xi: ArrayLike) -> ArrayLike:
-        if xi.shape == self.model_q.A.shape:
-            return xi
-
-        xi_velocity = jnp.ones_like(self.system.A) * xi
-        xi_pos = jnp.zeros_like(xi_velocity) + 1e-4
-
-        return jnp.concatenate((xi_pos, xi_velocity), axis=-1)
-
-    def construct_loss(self, state_q: TrainState, xi: ArrayLike, gamma: float, BS: int) -> Callable[
-        [Union[FrozenVariableDict, Dict[str, Any]], ArrayLike], ArrayLike]:
-        return super().construct_loss(state_q, self._xi_to_second_order(xi), gamma, BS)
-
-    def u_t(self, state_q: TrainState, t: ArrayLike, x_t: ArrayLike, xi: ArrayLike, *args, **kwargs) -> ArrayLike:
-        return super().u_t(state_q, t, x_t, self._xi_to_second_order(xi), *args, **kwargs)
-
     @property
     def A(self):
         return self._A
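
The core fix: instead of converting xi on every construct_loss/u_t call, SecondOrderSetup builds the padded noise vector once in __init__, giving positions near-zero noise (1e-4) and velocities the full xi. A minimal runnable sketch, with a hypothetical 2D system and scalar xi (neither value comes from the repo):

```python
import jax.numpy as jnp

# Hypothetical inputs: A is the (1, 2) position template of a 2D system,
# xi the scalar noise level.
A = jnp.zeros((1, 2))
xi = 0.5

xi_velocity = jnp.ones_like(A) * xi          # full noise on the velocity block
xi_pos = jnp.zeros_like(xi_velocity) + 1e-4  # near-zero noise on positions
xi_second_order = jnp.concatenate((xi_pos, xi_velocity), axis=-1)

print(xi_second_order)  # [[1.e-04 1.e-04 5.e-01 5.e-01]]
```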

training/qsetup.py

Lines changed: 11 additions & 12 deletions
@@ -21,16 +21,17 @@ class QSetup(ABC):
     """
     system: System
     model_q: nn.Module
+    xi: ArrayLike
 
     @abstractmethod
     def construct_loss(self, *args, **kwargs) -> Callable:
         raise NotImplementedError
 
     def sample_paths(self, state_q: TrainState, x_0: ArrayLike, dt: float, T: float, BS: int,
-                     xi: Optional[float], key: Optional[ArrayLike], *args, **kwargs) -> ArrayLike:
+                     key: Optional[ArrayLike], *args, **kwargs) -> ArrayLike:
+        """Sample paths. If key is None, the sampling is deterministic. Otherwise, it is stochastic."""
         assert x_0.ndim == 2
         assert T / dt == int(T / dt), "dt must divide T evenly"
-        assert (xi is None) == (key is None), "xi and key must be both None or both specified"
         N = int(T / dt)
 
         num_paths = x_0.shape[0]
@@ -39,10 +40,7 @@ def sample_paths(self, state_q: TrainState, x_0: ArrayLike, dt: float, T: float,
         x_t = x_t.at[:, 0, :].set(x_0)
 
         t = jnp.zeros((BS, 1), dtype=jnp.float32)
-        if key is None:
-            u = jax.jit(lambda _t, _x: self.u_t(state_q, _t, _x, 0, *args, **kwargs))
-        else:
-            u = jax.jit(lambda _t, _x: self.u_t(state_q, _t, _x, xi, *args, **kwargs))
+        u = jax.jit(lambda _t, _x: self.u_t(state_q, _t, _x, key is None, *args, **kwargs))
 
         for i in trange(N):
             for j in range(0, num_paths, BS):
@@ -61,17 +59,17 @@ def sample_paths(self, state_q: TrainState, x_0: ArrayLike, dt: float, T: float,
                 else:
                     # For stochastic sampling we compute the noise
                     key, iter_key = jax.random.split(key)
-                    noise = xi * jax.random.normal(iter_key, shape=(BS, ndim))
+                    noise = self.xi * jax.random.normal(iter_key, shape=(BS, ndim))
 
-                new_x = cur_x_t + dt * u(t, cur_x_t, *args, **kwargs) + jnp.sqrt(dt) * noise
+                new_x = cur_x_t + dt * u(t, cur_x_t) + jnp.sqrt(dt) * noise
                 x_t = x_t.at[j:j_end, i + 1, :].set(new_x[:j_end - j])
 
             t += dt
 
         return x_t
 
     @abstractmethod
-    def u_t(self, state_q: TrainState, t: ArrayLike, x_t: ArrayLike, xi: ArrayLike, *args, **kwargs) -> ArrayLike:
+    def u_t(self, state_q: TrainState, t: ArrayLike, x_t: ArrayLike, deterministic: bool, *args, **kwargs) -> ArrayLike:
         raise NotImplementedError
 
     @property
@@ -83,20 +81,21 @@ def B(self):
         return self.system.B
 
 
-def construct(system: System, model: nn.module, ode: str, parameterization: str, args: argparse.Namespace) -> QSetup:
+def construct(system: System, model: nn.module, ode: str, parameterization: str, xi: ArrayLike,
+              args: argparse.Namespace) -> QSetup:
     from training import diagonal
 
     if ode == 'first_order':
         if parameterization == 'diagonal':
-            return diagonal.FirstOrderSetup(system, model, args.T, args.base_sigma, args.num_gaussians,
+            return diagonal.FirstOrderSetup(system, model, xi, args.T, args.base_sigma, args.num_gaussians,
                                             args.trainable_weights)
         elif args.parameterization == 'low_rank':
             raise NotImplementedError("Low-rank parameterization not implemented")
         else:
             raise ValueError(f"Unknown parameterization: {args.parameterization}")
     elif args.ode == 'second_order':
         if parameterization == 'diagonal':
-            return diagonal.SecondOrderSetup(system, model, args.T, args.base_sigma, args.num_gaussians,
+            return diagonal.SecondOrderSetup(system, model, xi, args.T, args.base_sigma, args.num_gaussians,
                                              args.trainable_weights)
         else:
             raise NotImplementedError("Second-order ODE not implemented")
