Skip to content

Commit 0c1d638

Browse files
committed
Add force clipping
1 parent 86806d5 commit 0c1d638

File tree

3 files changed

+16
-14
lines changed

3 files changed

+16
-14
lines changed

main.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
parser.add_argument('--epochs', type=int, default=10_000, help="Number of epochs the system is training for.")
4747
parser.add_argument('--BS', type=int, default=512, help="Batch size used for training.")
4848
parser.add_argument('--lr', type=float, default=1e-4, help="Learning rate")
49+
parser.add_argument('--force_clip', type=float, default=1e8, help="Clipping value for the force")
4950

5051
parser.add_argument('--seed', type=int, default=1, help="The seed that will be used for initialization")
5152

@@ -55,8 +56,9 @@
5556

5657

5758
def main():
58-
# TODO: force clipping
5959
print("!!!!Next todos: plot ALDP")
60+
# TODO: internal coordinates
61+
# TODO: neural network parameterization
6062

6163
args = parse_args(parser)
6264
assert args.test_system or args.start and args.target, "Either specify a test system or provide start and target structures"
@@ -70,9 +72,9 @@ def main():
7072
os.makedirs(args.save_dir, exist_ok=True)
7173

7274
if args.test_system:
73-
system = System.from_name(args.test_system)
75+
system = System.from_name(args.test_system, args.force_clip)
7476
else:
75-
system = System.from_pdb(args.start, args.target, args.forcefield, args.cv)
77+
system = System.from_pdb(args.start, args.target, args.forcefield, args.cv, args.force_clip)
7678

7779
if args.xi:
7880
xi = args.xi

potentials.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,11 @@
1-
import jax
21
import jax.numpy as jnp
32

43

5-
@jax.jit
64
def U_double_well(xs, a=1.0, b=-4.0, c=0, d=1.0, beta=1.0):
75
x, y = xs[:, 0], xs[:, 1]
86
return beta * (a * (x ** 4) + b * (x ** 2) + c * x + 0.5 * d * (y ** 2))
97

108

11-
@jax.jit
129
def U_double_well_hard(xs, beta=1.0):
1310
A = jnp.array([[-3, 0]])
1411
B = jnp.array([[3, 0]])
@@ -18,7 +15,6 @@ def U_double_well_hard(xs, beta=1.0):
1815
return beta * out
1916

2017

21-
@jax.jit
2218
def U_double_well_dual_channel(xs, beta=1.0):
2319
x, y = xs[:, 0], xs[:, 1]
2420
borders = x ** 6 + y ** 6
@@ -28,7 +24,6 @@ def U_double_well_dual_channel(xs, beta=1.0):
2824
return beta * (borders + e1 + e2 + e3)
2925

3026

31-
@jax.jit
3227
def U_mueller_brown(xs, beta=1.0):
3328
x, y = xs[:, 0], xs[:, 1]
3429
e1 = -200 * jnp.exp(-(x - 1) ** 2 - 10 * y ** 2)

systems.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,21 +11,26 @@
1111
from utils.pdb import assert_same_molecule
1212
from utils.rmsd import kabsch_align
1313
from dmff import Hamiltonian, NeighborList # This sets jax to use 64-bit precision
14+
import mdtraj as md
15+
1416

1517
class System:
16-
def __init__(self, U: Callable[[ArrayLike], ArrayLike], A: ArrayLike, B: ArrayLike, mass: ArrayLike, plot):
18+
def __init__(self, U: Callable[[ArrayLike], ArrayLike], A: ArrayLike, B: ArrayLike, mass: ArrayLike, plot,
19+
force_clip: float):
1720
assert A.shape == B.shape == mass.shape
1821

1922
self.U = U
20-
self.dUdx = jax.jit(jax.grad(lambda _x: U(_x).sum()))
23+
24+
dUdx = jax.grad(lambda _x: U(_x).sum())
25+
self.dUdx = jax.jit(jax.jit(lambda _x: jnp.clip(dUdx(_x), -force_clip, force_clip)))
2126

2227
self.A, self.B = A, B
2328
self.mass = mass
2429

2530
self.plot = plot
2631

2732
@classmethod
28-
def from_name(cls, name: str) -> Self:
33+
def from_name(cls, name: str, force_clip: float) -> Self:
2934
if name == 'double_well':
3035
U, A, B = potentials.double_well
3136
elif name == 'double_well_hard':
@@ -45,10 +50,10 @@ def from_name(cls, name: str) -> Self:
4550
U=U, states=list(zip(['A', 'B'], [A, B])), xlim=xlim, ylim=ylim, alpha=1.0
4651
)
4752
mass = jnp.array([1.0, 1.0])
48-
return cls(U, A, B, mass, plot)
53+
return cls(U, A, B, mass, plot, force_clip)
4954

5055
@classmethod
51-
def from_pdb(cls, A: str, B: str, forcefield: [str], cv: Optional[str]) -> Self:
56+
def from_pdb(cls, A: str, B: str, forcefield: [str], cv: Optional[str], force_clip: float) -> Self:
5257
A_pdb, B_pdb = app.PDBFile(A), app.PDBFile(B)
5358
assert_same_molecule(A_pdb, B_pdb)
5459

@@ -87,4 +92,4 @@ def U(_x):
8792
else:
8893
raise ValueError(f"Unknown cv: {cv}")
8994

90-
return cls(U, A, B, mass, None)
95+
return cls(U, A, B, mass, plot, force_clip)

0 commit comments

Comments
 (0)