Skip to content

Commit 7af981e

Browse files
committed
Save pdb trajectory
1 parent 1d8a4a4 commit 7af981e

File tree

2 files changed

+16
-7
lines changed

2 files changed

+16
-7
lines changed

main.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
from argparse import ArgumentParser
2+
3+
from utils.animation import save_trajectory
24
from utils.args import parse_args, str2bool
35
from systems import System
46
import matplotlib.pyplot as plt
@@ -203,11 +205,14 @@ def main():
203205
x_0 += args.base_sigma * eps
204206

205207
x_t_det = setup.sample_paths(state_q, x_0, args.dt, args.T, args.BS, None)
208+
# In case we have a second order integration scheme, we remove the velocity for plotting
209+
x_t_det_no_vel = x_t_det[:, :, :system.A.shape[0]]
206210

207-
if system.plot:
208-
# In case we have a second order integration scheme, we remove the velocity for plotting
209-
x_t_det_no_vel = x_t_det[:, :, :system.A.shape[0]]
211+
if system.mdtraj_topology:
212+
save_trajectory(system.mdtraj_topology, x_t_det_no_vel[0].reshape(1, -1, 3), f'{args.save_dir}/det_0.pdb')
213+
save_trajectory(system.mdtraj_topology, x_t_det_no_vel[-1].reshape(1, -1, 3), f'{args.save_dir}/det_-1.pdb')
210214

215+
if system.plot:
211216
plot_energy(system, [x_t_det_no_vel[0], x_t_det_no_vel[-1]], args.log_plots)
212217
show_or_save_fig(args.save_dir, 'path_energy_deterministic', args.extension)
213218

@@ -216,10 +221,13 @@ def main():
216221

217222
key, path_key = jax.random.split(key)
218223
x_t_stoch = setup.sample_paths(state_q, x_0, args.dt, args.T, args.BS, path_key)
224+
x_t_stoch_no_vel = x_t_stoch[:, :, :system.A.shape[0]]
219225

220-
if system.plot:
221-
x_t_stoch_no_vel = x_t_stoch[:, :, :system.A.shape[0]]
226+
if system.mdtraj_topology:
227+
save_trajectory(system.mdtraj_topology, x_t_stoch_no_vel[0].reshape(1, -1, 3), f'{args.save_dir}/stoch_0.pdb')
228+
save_trajectory(system.mdtraj_topology, x_t_stoch_no_vel[-1].reshape(1, -1, 3), f'{args.save_dir}/stoch_-1.pdb')
222229

230+
if system.plot:
223231
plot_energy(system, [x_t_stoch_no_vel[0], x_t_stoch_no_vel[-1]], args.log_plots)
224232
show_or_save_fig(args.save_dir, 'path_energy_stochastic', args.extension)
225233

systems.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
class System:
1818
def __init__(self, U: Callable[[ArrayLike], ArrayLike], A: ArrayLike, B: ArrayLike, mass: ArrayLike, plot,
19-
force_clip: float):
19+
force_clip: float, mdtraj_topology: Optional[md.Topology] = None):
2020
assert A.shape == B.shape == mass.shape
2121

2222
self.U = jax.jit(U)
@@ -28,6 +28,7 @@ def __init__(self, U: Callable[[ArrayLike], ArrayLike], A: ArrayLike, B: ArrayLi
2828
self.mass = mass
2929

3030
self.plot = plot
31+
self.mdtraj_topology = mdtraj_topology
3132

3233
@classmethod
3334
def from_name(cls, name: str, force_clip: float) -> Self:
@@ -109,4 +110,4 @@ def U(_x):
109110
else:
110111
raise ValueError(f"Unknown cv: {cv}")
111112

112-
return cls(U, A, B, mass, plot, force_clip)
113+
return cls(U, A, B, mass, plot, force_clip, md.Topology.from_openmm(A_pdb.topology))

0 commit comments

Comments
 (0)