
Commit 577ccba

Add first refactor structure
1 parent 9444fd0 commit 577ccba

File tree

10 files changed (+434, -9 lines)


environment.yml

Lines changed: 7 additions & 7 deletions
@@ -1,20 +1,20 @@
 name: tps-flow-md
 channels:
-  - pytorch
   - defaults
   - conda-forge
 dependencies:
   - python=3.11.7
   - pip=23.3.2
-  - pytorch=2.1.2
   - openmm=8.1.1
   - mdtraj=1.9.8
-  - jax=0.4.23
-  - flax=0.8.2
   - tqdm=4.65.0
+  - jax=0.4.26
+  - jaxlib=0.4.26
+  - flax=0.8.3
+  - matplotlib=3.8.4
+  - scipy=1.13.1
+  - scikit-image=0.23.2
+  - ParmEd=4.2.2
   - pip:
     - dmff @ git+https://github.com/deepmodeling/DMFF@v1.0.0
-    - matplotlib==3.8.2
     - rdkit==2023.3.3
-    - ParmEd==4.2.2
-    - scikit-image==0.23.2

main.py

Lines changed: 151 additions & 0 deletions
@@ -0,0 +1,151 @@
from argparse import ArgumentParser

from utils.args import parse_args
from systems import System
import matplotlib.pyplot as plt

parser = ArgumentParser()
parser.add_argument('--out', type=str, default=None, help="Specify a path where the data will be stored.")
parser.add_argument('--config', type=str, help='Path to the config yaml file')

# system configuration
parser.add_argument('--test_system', type=str,
                    choices=['double_well', 'double_well_hard', 'double_well_dual_channel', 'mueller_brown'])
parser.add_argument('--start', type=str, help="Path to pdb file with the start structure A")
parser.add_argument('--target', type=str, help="Path to pdb file with the target structure B")

parser.add_argument('--T', type=float, required=True,
                    help="Transition time in the base unit of the system. For molecular simulations, this is in picoseconds.")
parser.add_argument('--xi', type=float, required=True)

# training
parser.add_argument('--epochs', type=int, default=10_000, help="Number of epochs the system is training for.")
parser.add_argument('--BS', type=int, default=512, help="Batch size used for training.")
parser.add_argument('--lr', type=float, default=1e-4, help="Learning rate")

parser.add_argument('--seed', type=int, default=1, help="The seed that will be used for initialization")

# inference
parser.add_argument('--num_paths', type=int, default=1000, help="The number of paths that will be generated.")
parser.add_argument('--dt', type=float, required=True)
# TODO: add a sampling method. It would be easy to just do a few MD steps from A and then use those. Might also be out of distribution, not sure.
# TODO: I think this could also be a reason why the paths are all the same.
# TODO: maybe we can also use MD_STEP(A) and MD_STEP(B) as a dynamic input to the neural network instead of using fixed A and B.


# TODO: remove this
# parser.add_argument('--mechanism', type=str, choices=['one-way-shooting', 'two-way-shooting'], required=True)
# parser.add_argument('--states', type=str, default='phi-psi', choices=['phi-psi', 'rmsd'])
# parser.add_argument('--fixed_length', type=int, default=0)
# parser.add_argument('--warmup', type=int, default=0)
# parser.add_argument('--num_steps', type=int, default=10,
#                     help='The number of MD steps taken at once. More takes longer to compile but runs faster in the end.')
# parser.add_argument('--resume', action='store_true')
# parser.add_argument('--override', action='store_true')
# parser.add_argument('--ensure_connected', action='store_true',
#                     help='Ensure that the initial path connects A with B by prepending A and appending B.')

if __name__ == '__main__':
    args = parse_args(parser)
    assert args.test_system or (args.start and args.target), "Either specify a test system or provide start and target structures"
    assert not (
        args.test_system and args.start and args.target), "Specify either a test system or provide start and target structures, not both"

    print(f'Config: {args}')

    if args.test_system:
        system = System.from_name(args.test_system)
    else:
        raise NotImplementedError
        # system = System.from_forcefield(args.start, args.target)

    import jax.numpy as jnp
    import jax
    from tqdm import trange
    from flax.training import train_state
    import optax
    import model.diagonal as diagonal
    from model.train import train
    from model import MLPq

    N = int(args.T / args.dt)

    # You can play around with any model here
    model = MLPq([128, 128, 128])

    # TODO: parameterize mixtures, weights, and base_sigma
    base_sigma = 2.5 * 1e-2
    setup = diagonal.FirstOrderSetup(system, model, args.T, 1, False, base_sigma)

    key = jax.random.PRNGKey(args.seed)
    key, init_key = jax.random.split(key)
    params_q = setup.model_q.init(init_key, jnp.ones([args.BS, 1]))

    optimizer_q = optax.adam(learning_rate=args.lr)
    state_q = train_state.TrainState.create(apply_fn=setup.model_q.apply, params=params_q, tx=optimizer_q)
    loss_fn = setup.construct_loss(state_q, args.xi, args.BS)

    key, train_key = jax.random.split(key)
    state_q, loss_plot = train(loss_fn, state_q, args.epochs, train_key)
    print("Number of potential evaluations", args.BS * args.epochs)

    plt.plot(loss_plot)
    plt.show()

    t = args.T * jnp.linspace(0, 1, args.BS).reshape((-1, 1))
    key, path_key = jax.random.split(key)
    eps = jax.random.normal(path_key, [args.BS, 2])
    mu_t, sigma_t, _ = state_q.apply_fn(state_q.params, t)
    samples = mu_t + sigma_t * eps
    # plot_energy_surface()
    # plt.scatter(samples[:, 0], samples[:, 1])
    # plt.scatter(A[0, 0], A[0, 1], color='red')
    # plt.scatter(B[0, 0], B[0, 1], color='orange')
    # plt.show()

    mu_t = lambda _t: state_q.apply_fn(state_q.params, _t)[0]
    sigma_t = lambda _t: state_q.apply_fn(state_q.params, _t)[1]


    def dmudt(_t):
        _dmudt = jax.jacrev(lambda _t: mu_t(_t).sum(0), argnums=0)
        return _dmudt(_t).squeeze().T


    def dsigmadt(_t):
        _dsigmadt = jax.jacrev(lambda _t: sigma_t(_t).sum(0))
        return _dsigmadt(_t).squeeze().T


    # deterministic rollout of the learned path (probability-flow drift)
    u_t = jax.jit(lambda _t, _x: dmudt(_t) + dsigmadt(_t) / sigma_t(_t) * (_x - mu_t(_t)))

    key, loc_key = jax.random.split(key)
    x_t = jnp.ones((args.BS, N + 1, 2)) * system.A[None, :]
    eps = jax.random.normal(loc_key, shape=(args.BS, 2))
    x_t = x_t.at[:, 0, :].set(x_t[:, 0, :] + sigma_t(jnp.zeros((args.BS, 1))) * eps)
    t = jnp.zeros((args.BS, 1))
    for i in trange(N):
        dx = args.dt * u_t(t, x_t[:, i, :])
        x_t = x_t.at[:, i + 1, :].set(x_t[:, i, :] + dx)
        t += args.dt

    x_t_det = x_t.copy()

    # stochastic rollout (SDE with noise scale xi)
    u_t = jax.jit(
        lambda _t, _x: dmudt(_t) + (dsigmadt(_t) / sigma_t(_t) - 0.5 * (args.xi / sigma_t(_t)) ** 2) * (_x - mu_t(_t)))

    # TODO: find a better way than resetting BS
    BS = args.num_paths
    key, loc_key = jax.random.split(key)
    x_t = jnp.ones((BS, N + 1, 2)) * system.A[None, :]
    eps = jax.random.normal(loc_key, shape=(BS, 2))
    x_t = x_t.at[:, 0, :].set(x_t[:, 0, :] + sigma_t(jnp.zeros((BS, 1))) * eps)
    t = jnp.zeros((BS, 1))
    for i in trange(N):
        key, loc_key = jax.random.split(key)
        eps = jax.random.normal(loc_key, shape=(BS, 2))
        dx = args.dt * u_t(t, x_t[:, i, :]) + jnp.sqrt(args.dt) * args.xi * eps
        x_t = x_t.at[:, i + 1, :].set(x_t[:, i, :] + dx)
        t += args.dt

    x_t_stoch = x_t.copy()
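
For reference, the two integration loops at the end of main.py implement the following dynamics. This is only a restatement of the code above in equation form, with q_t = N(mu_t, sigma_t^2 I) the learned marginal, xi the noise scale from --xi, and Euler(-Maruyama) steps of size dt:

Deterministic rollout (x_t_det):
    dx = \Big( \dot\mu_t + \frac{\dot\sigma_t}{\sigma_t}\,(x - \mu_t) \Big)\, dt

Stochastic rollout (x_t_stoch):
    dx = \Big( \dot\mu_t + \Big( \frac{\dot\sigma_t}{\sigma_t} - \frac{\xi^2}{2\sigma_t^2} \Big)(x - \mu_t) \Big)\, dt + \xi\, dW_t

Both start from x_0 = A + \sigma_0\,\varepsilon with \varepsilon \sim \mathcal{N}(0, I).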

model/__init__.py

Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
from abc import ABC, abstractmethod
from flax import linen as nn
from jax.typing import ArrayLike


class WrappedModule(ABC, nn.Module):
    other: nn.Module
    T: float

    def __call__(self, t: ArrayLike):
        t = t / self.T

        h = self.other(t)
        return self._post_process(t, h)

    @abstractmethod
    def _post_process(self, t: ArrayLike, h: ArrayLike):
        raise NotImplementedError


class MLPq(nn.Module):
    hidden_dims: ArrayLike

    @nn.compact
    def __call__(self, t):
        h = t - 0.5
        for dim in self.hidden_dims:
            h = nn.Dense(dim)(h)
            h = nn.swish(h)

        return h
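
A hypothetical standalone smoke test for the MLPq backbone (not part of this commit; it only assumes the module path model used above), a minimal sketch of what the class returns before a WrappedModule subclass post-processes it:

import jax
import jax.numpy as jnp

from model import MLPq

mlp = MLPq([128, 128, 128])
t = jnp.linspace(0.0, 1.0, 4).reshape(-1, 1)   # normalized times in [0, 1]
params = mlp.init(jax.random.PRNGKey(0), t)
h = mlp.apply(params, t)
print(h.shape)  # (4, 128): raw features that a wrapper turns into (mu, sigma, ...)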

model/diagonal.py

Lines changed: 72 additions & 0 deletions
@@ -0,0 +1,72 @@
from flax import linen as nn
import jax.numpy as jnp
from typing import Union, Dict, Any, Callable
from flax.training.train_state import TrainState
import jax
from flax.typing import FrozenVariableDict
from jax.typing import ArrayLike
from model import WrappedModule
from model.setup import TrainSetup
from systems import System


class DiagonalWrapper(WrappedModule):
    A: ArrayLike
    B: ArrayLike
    num_mixtures: int
    trainable_weights: bool
    base_sigma: float

    @nn.compact
    def _post_process(self, t: ArrayLike, h: ArrayLike):
        print('WARNING: Gaussian Mixture Model not implemented yet')
        ndim = self.A.shape[0]
        h = nn.Dense(2 * ndim)(h)

        mu = (1 - t) * self.A[None, :] + t * self.B[None, :] + (1 - t) * t * h[:, :ndim]
        sigma = (1 - t) * self.base_sigma + t * self.base_sigma + (1 - t) * t * jnp.exp(h[:, ndim:])
        w_logits = self.param('w_logits', nn.initializers.zeros_init(),
                              (self.num_mixtures,)) if self.trainable_weights else jnp.zeros(self.num_mixtures)

        return mu, sigma, w_logits


class FirstOrderSetup(TrainSetup):

    def __init__(self, system: System, model: nn.Module, T: float, num_mixtures: int, trainable_weights: bool,
                 base_sigma: float):
        model_q = DiagonalWrapper(model, T, system.A, system.B, num_mixtures, trainable_weights, base_sigma)
        super().__init__(system, model_q)
        self.system = system
        self.T = T

    def construct_loss(self, state_q: TrainState, xi: float, BS: int) -> Callable[
            [Union[FrozenVariableDict, Dict[str, Any]], ArrayLike], float]:
        print('WARNING: Gaussian Mixture Loss not implemented yet')

        def loss_fn(params_q: Union[FrozenVariableDict, Dict[str, Any]], key: ArrayLike) -> float:
            key = jax.random.split(key)
            t = self.T * jax.random.uniform(key[0], [BS, 1])
            eps = jax.random.normal(key[1], [BS, 2])

            mu_t = lambda _t: state_q.apply_fn(params_q, _t)[0]
            sigma_t = lambda _t: state_q.apply_fn(params_q, _t)[1]

            def dmudt(_t):
                _dmudt = jax.jacrev(lambda _t: mu_t(_t).sum(0))
                return _dmudt(_t).squeeze().T

            def dsigmadt(_t):
                _dsigmadt = jax.jacrev(lambda _t: sigma_t(_t).sum(0))
                return _dsigmadt(_t).squeeze().T

            def v_t(_eps, _t):
                u_t = dmudt(_t) + dsigmadt(_t) * _eps
                _x = mu_t(_t) + sigma_t(_t) * _eps
                out = (u_t + self.system.dUdx(_x)) - 0.5 * (xi ** 2) * _eps / sigma_t(_t)
                return out

            loss = 0.5 * ((v_t(eps, t) / xi) ** 2).sum(1, keepdims=True)
            return loss.mean()

        return loss_fn
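
In equation form, DiagonalWrapper and construct_loss amount to the following. This is a restatement of the code above, where s = t/T is the normalized time produced by WrappedModule, h_mu and h_sigma are the two halves of the final Dense output, and the single-mode case num_mixtures = 1 is assumed:

\mu(s) = (1 - s)\, A + s\, B + s(1 - s)\, h_\mu(s), \qquad
\sigma(s) = \sigma_{\mathrm{base}} + s(1 - s)\, e^{h_\sigma(s)},

\mathcal{L} = \tfrac{1}{2}\, \mathbb{E}_{t \sim \mathcal{U}(0, T),\ \varepsilon \sim \mathcal{N}(0, I)}
\Big\| \tfrac{1}{\xi} \Big( \dot\mu_t + \dot\sigma_t\,\varepsilon + \nabla U(\mu_t + \sigma_t\varepsilon) - \tfrac{\xi^2}{2\sigma_t}\,\varepsilon \Big) \Big\|^2 .

The endpoints are therefore pinned to A and B with standard deviation sigma_base, and the network only shapes the interior of the path.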

model/setup.py

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Callable

from flax import linen as nn

from systems import System


@dataclass
class TrainSetup(ABC):
    system: System
    model_q: nn.Module

    @abstractmethod
    def construct_loss(self, *args, **kwargs) -> Callable:
        raise NotImplementedError

model/train.py

Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
from typing import Callable, Tuple
from flax.training.train_state import TrainState
import jax
from jax.typing import ArrayLike
from tqdm import trange


def train(loss_fn: Callable, state_q: TrainState, epochs: int, key: ArrayLike) -> Tuple[TrainState, list[float]]:
    @jax.jit
    def train_step(_state_q: TrainState, _key: ArrayLike) -> Tuple[TrainState, float]:
        grad_fn = jax.value_and_grad(loss_fn, argnums=0)
        loss, grads = grad_fn(_state_q.params, _key)
        _state_q = _state_q.apply_gradients(grads=grads)
        return _state_q, loss

    loss_plot = []

    for _ in trange(epochs):
        key, loc_key = jax.random.split(key)
        state_q, loss = train_step(state_q, loc_key)
        loss_plot.append(loss)

    return state_q, loss_plot
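
A minimal, hypothetical usage sketch of train() with a toy regression loss (not part of this commit; the Dense model and the 2x target are made up for illustration, but the (params, key) -> scalar loss signature is the one train() expects):

import jax
import jax.numpy as jnp
import optax
from flax import linen as nn
from flax.training import train_state

from model.train import train

net = nn.Dense(1)
params = net.init(jax.random.PRNGKey(0), jnp.ones([8, 1]))
state = train_state.TrainState.create(apply_fn=net.apply, params=params, tx=optax.adam(1e-3))

def loss_fn(p, key):
    # fit f(x) = 2x on fresh random batches drawn from the per-step key
    x = jax.random.normal(key, [8, 1])
    return ((net.apply(p, x) - 2.0 * x) ** 2).mean()

state, losses = train(loss_fn, state, epochs=200, key=jax.random.PRNGKey(1))
print(losses[-1])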

potentials.py

Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
import jax
import jax.numpy as jnp


@jax.jit
def U_double_well(xs, a=1.0, b=-4.0, c=0, d=1.0, beta=1.0):
    x, y = xs[:, 0], xs[:, 1]
    return beta * (a * (x ** 4) + b * (x ** 2) + c * x + 0.5 * d * (y ** 2))


@jax.jit
def U_double_well_hard(xs, beta=1.0):
    A = jnp.array([[-3, 0]])
    B = jnp.array([[3, 0]])
    U1 = -(((xs - A) @ jnp.array([[1, 0.5], [0.5, 1.0]])) * (xs - A)).sum(1)
    U2 = -(((xs - B) @ jnp.array([[1, -0.5], [-0.5, 1.0]])) * (xs - B)).sum(1)
    out = -jnp.log(jnp.exp(U1 - jnp.maximum(U1, U2)) + jnp.exp(U2 - jnp.maximum(U1, U2))) - jnp.maximum(U1, U2)
    return beta * out


@jax.jit
def U_double_well_dual_channel(xs, beta=1.0):
    x, y = xs[:, 0], xs[:, 1]
    borders = x ** 6 + y ** 6
    e1 = +2.0 * jnp.exp(-(12.0 * (x - 0.00) ** 2 + 12.0 * (y - 0.00) ** 2))
    e2 = -1.0 * jnp.exp(-(12.0 * (x + 0.50) ** 2 + 12.0 * (y + 0.00) ** 2))
    e3 = -1.0 * jnp.exp(-(12.0 * (x - 0.50) ** 2 + 12.0 * (y + 0.00) ** 2))
    return beta * (borders + e1 + e2 + e3)


@jax.jit
def U_mueller_brown(xs, beta=1.0):
    x, y = xs[:, 0], xs[:, 1]
    e1 = -200 * jnp.exp(-(x - 1) ** 2 - 10 * y ** 2)
    e2 = -100 * jnp.exp(-x ** 2 - 10 * (y - 0.5) ** 2)
    e3 = -170 * jnp.exp(-6.5 * (0.5 + x) ** 2 + 11 * (x + 0.5) * (y - 1.5) - 6.5 * (y - 1.5) ** 2)
    e4 = 15.0 * jnp.exp(0.7 * (1 + x) ** 2 + 0.6 * (x + 1) * (y - 1) + 0.7 * (y - 1) ** 2)
    return beta * (e1 + e2 + e3 + e4)


double_well = (U_double_well,)
double_well_hard = (U_double_well_hard,)
double_well_dual_channel = (U_double_well_dual_channel,)
mueller_brown = (U_mueller_brown, jnp.array([-0.55828035, 1.44169]), jnp.array([0.62361133, 0.02804632]))
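
For reference, U_mueller_brown above is the standard four-term Müller–Brown surface (here with beta = 1), restated from the code:

U(x, y) = \sum_{k=1}^{4} A_k \exp\!\big( a_k (x - \bar{x}_k)^2 + b_k (x - \bar{x}_k)(y - \bar{y}_k) + c_k (y - \bar{y}_k)^2 \big)

with A = (-200, -100, -170, 15), a = (-1, -1, -6.5, 0.7), b = (0, 0, 11, 0.6), c = (-10, -10, -6.5, 0.7) and centers (1, 0), (0, 0.5), (-0.5, 1.5), (-1, 1). The two arrays stored in the mueller_brown tuple are the minima used as the endpoint states A and B.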

0 commit comments
