@@ -186,11 +186,11 @@ def main():
186
186
t = args .T * jnp .linspace (0 , 1 , args .BS , dtype = jnp .float32 ).reshape ((- 1 , 1 ))
187
187
key , path_key = jax .random .split (key )
188
188
mu_t , _ , w_logits = state_q .apply_fn (state_q .params , t )
189
- w = jax .nn .softmax (w_logits )[ None , :, None ]
189
+ w = jax .nn .softmax (w_logits )
190
190
print ('Weights of mixtures:' , w )
191
191
192
192
mu_t_no_vel = mu_t [:, :, :system .A .shape [0 ]]
193
- num_trajectories = jnp .array ((w . squeeze () * 100 ).round (), dtype = int )
193
+ num_trajectories = jnp .array ((w * 100 ).round (), dtype = int )
194
194
195
195
trajectories = jnp .swapaxes (mu_t_no_vel , 0 , 1 )
196
196
trajectories = (jnp .vstack ([trajectories [i ].repeat (n , axis = 0 ) for i , n in enumerate (num_trajectories ) if n > 0 ])
@@ -200,6 +200,7 @@ def main():
200
200
show_or_save_fig (args .save_dir , 'mean_paths' , args .extension )
201
201
202
202
if system .plot and system .A .shape [0 ] == 2 :
203
+ print ('Animating gif, this might take a few seconds ...' )
203
204
plot_u_t (system , setup , state_q , args .T , args .save_dir , 'u_t' , frames = 100 )
204
205
205
206
key , init_key = jax .random .split (key )
0 commit comments