from argparse import ArgumentParser

from utils.args import parse_args
from systems import System
import matplotlib.pyplot as plt

parser = ArgumentParser()
parser.add_argument('--out', type=str, default=None, help="Specify a path where the data will be stored.")
parser.add_argument('--config', type=str, help='Path to the config yaml file')

# system configuration
parser.add_argument('--test_system', type=str,
                    choices=['double_well', 'double_well_hard', 'double_well_dual_channel', 'mueller_brown'])
parser.add_argument('--start', type=str, help="Path to a PDB file with the start structure A")
parser.add_argument('--target', type=str, help="Path to a PDB file with the target structure B")

parser.add_argument('--T', type=float, required=True,
                    help="Transition time in the base unit of the system. For molecular simulations, this is in picoseconds.")
parser.add_argument('--xi', type=float, required=True,
                    help="Noise scale xi used in the loss and as the diffusion coefficient when sampling stochastic paths.")

# training
parser.add_argument('--epochs', type=int, default=10_000, help="Number of epochs to train for.")
parser.add_argument('--BS', type=int, default=512, help="Batch size used for training.")
parser.add_argument('--lr', type=float, default=1e-4, help="Learning rate")

parser.add_argument('--seed', type=int, default=1, help="The seed used for initialization")

# inference
parser.add_argument('--num_paths', type=int, default=1000, help="The number of paths that will be generated.")
parser.add_argument('--dt', type=float, required=True,
                    help="Step size used when integrating the learned paths.")
# TODO: add a sampling method. It would be easy to just do a few MD steps from A and then use those. Might also be out of distribution, not sure.
# TODO: I think this could also be a reason why the paths are all the same.
# TODO: maybe we could also use MD_STEP(A) and MD_STEP(B) as a dynamic input to the neural network instead of using fixed A and B.


# TODO: remove this
# parser.add_argument('--mechanism', type=str, choices=['one-way-shooting', 'two-way-shooting'], required=True)
# parser.add_argument('--states', type=str, default='phi-psi', choices=['phi-psi', 'rmsd'])
# parser.add_argument('--fixed_length', type=int, default=0)
# parser.add_argument('--warmup', type=int, default=0)
# parser.add_argument('--num_steps', type=int, default=10,
#                     help='The number of MD steps taken at once. More takes longer to compile but runs faster in the end.')
# parser.add_argument('--resume', action='store_true')
# parser.add_argument('--override', action='store_true')
# parser.add_argument('--ensure_connected', action='store_true',
#                     help='Ensure that the initial path connects A with B by prepending A and appending B.')

if __name__ == '__main__':
    args = parse_args(parser)
    assert args.test_system or (args.start and args.target), \
        "Either specify a test system or provide start and target structures"
    assert not (args.test_system and args.start and args.target), \
        "Specify either a test system or provide start and target structures, not both"

    print(f'Config: {args}')

    if args.test_system:
        system = System.from_name(args.test_system)
    else:
        raise NotImplementedError
        # system = System.from_forcefield(args.start, args.target)

    # Deferred imports: JAX and the model code are only loaded once the arguments have been parsed.
    import jax.numpy as jnp
    import jax
    from tqdm import trange
    from flax.training import train_state
    import optax
    import model.diagonal as diagonal
    from model.train import train
    from model import MLPq

    N = int(args.T / args.dt)  # number of integration steps of size dt covering the transition time T

    # You can play around with any model here
    model = MLPq([128, 128, 128])
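    # The model maps a batch of times t of shape (BS, 1) to the parameters mu_t and
    # sigma_t of a Gaussian path; its third output is unused here.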

    # TODO: parameterize mixtures, weights, and base_sigma
    base_sigma = 2.5 * 1e-2
    # Presumably 1 mixture component with fixed (non-trainable) weights; see the TODO above.
    setup = diagonal.FirstOrderSetup(system, model, args.T, 1, False, base_sigma)

    key = jax.random.PRNGKey(args.seed)
    key, init_key = jax.random.split(key)
    # Initialize the model parameters with a dummy batch of time inputs of shape (BS, 1).
    params_q = setup.model_q.init(init_key, jnp.ones([args.BS, 1]))

    optimizer_q = optax.adam(learning_rate=args.lr)
    state_q = train_state.TrainState.create(apply_fn=setup.model_q.apply, params=params_q, tx=optimizer_q)
    loss_fn = setup.construct_loss(state_q, args.xi, args.BS)

    key, train_key = jax.random.split(key)
    state_q, loss_plot = train(loss_fn, state_q, args.epochs, train_key)
    print("Number of potential evaluations:", args.BS * args.epochs)

    plt.plot(loss_plot)
    plt.show()

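    # Visualize the learned marginals: draw one sample x ~ N(mu_t, sigma_t^2 I) at each
    # of BS time points spread uniformly over [0, T], via the reparameterization
    # x = mu_t + sigma_t * eps with eps ~ N(0, I).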
    t = args.T * jnp.linspace(0, 1, args.BS).reshape((-1, 1))
    key, path_key = jax.random.split(key)
    eps = jax.random.normal(path_key, [args.BS, 2])
    mu_t, sigma_t, _ = state_q.apply_fn(state_q.params, t)
    samples = mu_t + sigma_t * eps
    # plot_energy_surface()
    # plt.scatter(samples[:, 0], samples[:, 1])
    # plt.scatter(A[0, 0], A[0, 1], color='red')
    # plt.scatter(B[0, 0], B[0, 1], color='orange')
    # plt.show()

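    # Reconstruct the deterministic flow that transports samples along the learned
    # Gaussian path q_t = N(mu(t), sigma(t)^2 I): matching these marginals requires
    # the drift u(t, x) = mu'(t) + (sigma'(t) / sigma(t)) * (x - mu(t)).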
    # Redefine mu_t and sigma_t as functions of time (shadowing the arrays above).
    mu_t = lambda _t: state_q.apply_fn(state_q.params, _t)[0]
    sigma_t = lambda _t: state_q.apply_fn(state_q.params, _t)[1]


    def dmudt(_t):
        # Each row of mu_t(_t) depends only on its own time entry, so the Jacobian of
        # the summed output recovers the per-sample time derivatives in one jacrev call.
        _dmudt = jax.jacrev(lambda _t: mu_t(_t).sum(0), argnums=0)
        return _dmudt(_t).squeeze().T


    def dsigmadt(_t):
        _dsigmadt = jax.jacrev(lambda _t: sigma_t(_t).sum(0))
        return _dsigmadt(_t).squeeze().T


    u_t = jax.jit(lambda _t, _x: dmudt(_t) + dsigmadt(_t) / sigma_t(_t) * (_x - mu_t(_t)))

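    # Integrate the ODE dx/dt = u_t(t, x) with explicit Euler steps of size dt,
    # starting from x_0 ~ N(A, sigma(0)^2 I), to obtain deterministic paths.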
    key, loc_key = jax.random.split(key)
    x_t = jnp.ones((args.BS, N + 1, 2)) * system.A[None, :]
    eps = jax.random.normal(loc_key, shape=(args.BS, 2))
    x_t = x_t.at[:, 0, :].set(x_t[:, 0, :] + sigma_t(jnp.zeros((args.BS, 1))) * eps)
    t = jnp.zeros((args.BS, 1))
    for i in trange(N):
        dx = args.dt * u_t(t, x_t[:, i, :])
        x_t = x_t.at[:, i + 1, :].set(x_t[:, i, :] + dx)
        t += args.dt

    x_t_det = x_t.copy()

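    # Stochastic version: with noise scale xi, the drift gains a correction term,
    # u(t, x) = mu'(t) + (sigma'(t) / sigma(t) - xi^2 / (2 sigma(t)^2)) * (x - mu(t)),
    # so that the SDE dx = u_t dt + xi dW has the same marginals q_t as the ODE above.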
    u_t = jax.jit(
        lambda _t, _x: dmudt(_t) + (dsigmadt(_t) / sigma_t(_t) - 0.5 * (args.xi / sigma_t(_t)) ** 2) * (_x - mu_t(_t)))

    # TODO: find a better way than resetting BS
    BS = args.num_paths
    key, loc_key = jax.random.split(key)
    x_t = jnp.ones((BS, N + 1, 2)) * system.A[None, :]
    eps = jax.random.normal(loc_key, shape=(BS, 2))
    x_t = x_t.at[:, 0, :].set(x_t[:, 0, :] + sigma_t(jnp.zeros((BS, 1))) * eps)
    t = jnp.zeros((BS, 1))
    # Euler-Maruyama integration of the SDE with fresh noise at every step.
    for i in trange(N):
        key, loc_key = jax.random.split(key)
        eps = jax.random.normal(loc_key, shape=(BS, 2))
        dx = args.dt * u_t(t, x_t[:, i, :]) + jnp.sqrt(args.dt) * args.xi * eps
        x_t = x_t.at[:, i + 1, :].set(x_t[:, i, :] + dx)
        t += args.dt

    x_t_stoch = x_t.copy()
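    # x_t_stoch now holds num_paths stochastic transition paths of shape
    # (num_paths, N + 1, 2), analogous to the deterministic paths in x_t_det.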