Skip to content

Commit dc065e9

Browse files
committed
Fix plotting for single mixture
1 parent bc62c26 commit dc065e9

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

main.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -186,11 +186,11 @@ def main():
186186
t = args.T * jnp.linspace(0, 1, args.BS, dtype=jnp.float32).reshape((-1, 1))
187187
key, path_key = jax.random.split(key)
188188
mu_t, _, w_logits = state_q.apply_fn(state_q.params, t)
189-
w = jax.nn.softmax(w_logits)[None, :, None]
189+
w = jax.nn.softmax(w_logits)
190190
print('Weights of mixtures:', w)
191191

192192
mu_t_no_vel = mu_t[:, :, :system.A.shape[0]]
193-
num_trajectories = jnp.array((w.squeeze() * 100).round(), dtype=int)
193+
num_trajectories = jnp.array((w * 100).round(), dtype=int)
194194

195195
trajectories = jnp.swapaxes(mu_t_no_vel, 0, 1)
196196
trajectories = (jnp.vstack([trajectories[i].repeat(n, axis=0) for i, n in enumerate(num_trajectories) if n > 0])
@@ -200,6 +200,7 @@ def main():
200200
show_or_save_fig(args.save_dir, 'mean_paths', args.extension)
201201

202202
if system.plot and system.A.shape[0] == 2:
203+
print('Animating gif, this might take a few seconds ...')
203204
plot_u_t(system, setup, state_q, args.T, args.save_dir, 'u_t', frames=100)
204205

205206
key, init_key = jax.random.split(key)

0 commit comments

Comments
 (0)