85
85
parser .add_argument ('--dt' , type = float , required = True )
86
86
87
87
# plotting
88
+ parser .add_argument ('--no_plots' , type = str2bool , nargs = '?' , const = True , default = False , help = "Disable all plots." )
88
89
parser .add_argument ('--log_plots' , type = str2bool , nargs = '?' , const = True , default = False ,
89
90
help = "Save plots in log scale where possible" )
90
91
parser .add_argument ('--extension' , type = str , default = 'pdf' , help = "Extension of the saved plots." )
@@ -107,6 +108,9 @@ def main():
107
108
else :
108
109
system = System .from_pdb (args .start , args .target , args .forcefield , args .cv , args .force_clip )
109
110
111
+ if args .no_plots :
112
+ system .plot = None
113
+
110
114
if args .xi :
111
115
xi = args .xi
112
116
else :
@@ -182,13 +186,12 @@ def main():
182
186
log_scale (args .log_plots , False , True )
183
187
show_or_save_fig (args .save_dir , 'loss_plot' , args .extension )
184
188
189
+ t = args .T * jnp .linspace (0 , 1 , args .BS , dtype = jnp .float32 ).reshape ((- 1 , 1 ))
190
+ key , path_key = jax .random .split (key )
191
+ mu_t , _ , w_logits = state_q .apply_fn (state_q .params , t )
192
+ w = jax .nn .softmax (w_logits )
193
+ print ('Weights of mixtures:' , w )
185
194
if system .plot :
186
- t = args .T * jnp .linspace (0 , 1 , args .BS , dtype = jnp .float32 ).reshape ((- 1 , 1 ))
187
- key , path_key = jax .random .split (key )
188
- mu_t , _ , w_logits = state_q .apply_fn (state_q .params , t )
189
- w = jax .nn .softmax (w_logits )
190
- print ('Weights of mixtures:' , w )
191
-
192
195
mu_t_no_vel = mu_t [:, :, :system .A .shape [0 ]]
193
196
num_trajectories = jnp .array ((w * 100 ).round (), dtype = int )
194
197
0 commit comments