Skip to content

Commit 358c7e0

Browse files
committed
Add no_plots
1 parent 459744b commit 358c7e0

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

main.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@
8585
parser.add_argument('--dt', type=float, required=True)
8686

8787
# plotting
88+
parser.add_argument('--no_plots', type=str2bool, nargs='?', const=True, default=False, help="Disable all plots.")
8889
parser.add_argument('--log_plots', type=str2bool, nargs='?', const=True, default=False,
8990
help="Save plots in log scale where possible")
9091
parser.add_argument('--extension', type=str, default='pdf', help="Extension of the saved plots.")
@@ -107,6 +108,9 @@ def main():
107108
else:
108109
system = System.from_pdb(args.start, args.target, args.forcefield, args.cv, args.force_clip)
109110

111+
if args.no_plots:
112+
system.plot = None
113+
110114
if args.xi:
111115
xi = args.xi
112116
else:
@@ -182,13 +186,12 @@ def main():
182186
log_scale(args.log_plots, False, True)
183187
show_or_save_fig(args.save_dir, 'loss_plot', args.extension)
184188

189+
t = args.T * jnp.linspace(0, 1, args.BS, dtype=jnp.float32).reshape((-1, 1))
190+
key, path_key = jax.random.split(key)
191+
mu_t, _, w_logits = state_q.apply_fn(state_q.params, t)
192+
w = jax.nn.softmax(w_logits)
193+
print('Weights of mixtures:', w)
185194
if system.plot:
186-
t = args.T * jnp.linspace(0, 1, args.BS, dtype=jnp.float32).reshape((-1, 1))
187-
key, path_key = jax.random.split(key)
188-
mu_t, _, w_logits = state_q.apply_fn(state_q.params, t)
189-
w = jax.nn.softmax(w_logits)
190-
print('Weights of mixtures:', w)
191-
192195
mu_t_no_vel = mu_t[:, :, :system.A.shape[0]]
193196
num_trajectories = jnp.array((w * 100).round(), dtype=int)
194197

0 commit comments

Comments
 (0)