Skip to content

Commit

Permalink
reimplemented calculation of angles and dihedrals in cython. Speed up…
Browse files Browse the repository at this point in the history
… by 1000x for large molecules
  • Loading branch information
stefdoerr committed Jan 11, 2024
1 parent fc5c90c commit ac69727
Show file tree
Hide file tree
Showing 4 changed files with 130 additions and 31 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,6 @@ moleculekit/occupancy_utils/occupancy_utils.html
moleculekit/wrapping/wrapping.c
moleculekit/wrapping/wrapping.cpp
moleculekit/wrapping/wrapping.html
moleculekit/cython_utils/cython_utils.c
moleculekit/cython_utils/cython_utils.cpp
moleculekit/cython_utils/cython_utils.html
112 changes: 112 additions & 0 deletions moleculekit/cython_utils/cython_utils.pyx
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# cython: cdivision=True

import numpy as np
from math import sqrt
cimport numpy as np
from libcpp.vector cimport vector
from libcpp cimport bool
from libc.math cimport sqrt, round
from cython.parallel import prange

# We now need to fix a datatype for our arrays. I've used the variable
# DTYPE for this, which is assigned to the usual NumPy runtime
# type info object.
INT32 = np.int32
INT64 = np.int64
UINT32 = np.uint32
FLOAT32 = np.float32
FLOAT64 = np.float64

# "ctypedef" assigns a corresponding compile-time type to DTYPE_t. For
# every type in the numpy module there's a corresponding compile-time
# type with a _t-suffix.
ctypedef np.int32_t INT32_t
ctypedef np.int64_t INT64_t
ctypedef np.uint32_t UINT32_t
ctypedef np.float32_t FLOAT32_t
ctypedef np.float64_t FLOAT64_t

import cython

@cython.boundscheck(False) # turn off bounds-checking for entire function
@cython.wraparound(False) # turn off negative index wrapping for entire function
def calculateAnglesAndDihedrals(
UINT32_t[:,:] bonds,
bool cyclicdih,
int n_atoms,
):
# Same as dist_trajectory but instead of returning distances it returns index
# pairs of atoms that are within a certain distance threshold
cdef int i, j, k, a, b, b1, b2, min_v, max_v, n_neigh, n_angles, a1, a2
cdef int n_bonds = bonds.shape[0]
cdef vector[vector[UINT32_t]] neighbors
cdef vector[vector[UINT32_t]] angles
cdef vector[vector[UINT32_t]] dihedrals
cdef vector[UINT32_t] buffer
cdef vector[UINT32_t] x, y

for i in range(n_atoms):
for j in range(n_bonds):
b1 = bonds[j, 0]
b2 = bonds[j, 1]
if b1 == i:
buffer.push_back(b2)
elif b2 == i:
buffer.push_back(b1)
neighbors.push_back(buffer)
buffer.clear()

for i in range(n_atoms):
n_neigh = neighbors[i].size()
for j in range(n_neigh):
for k in range(j+1, n_neigh):
a = neighbors[i][j]
b = neighbors[i][k]
if a != b:
if a < b:
min_v = a
max_v = b
else:
min_v = b
max_v = a
buffer.push_back(min_v)
buffer.push_back(i)
buffer.push_back(max_v)
angles.push_back(buffer)
buffer.clear()

n_angles = angles.size()
for a1 in range(n_angles):
for a2 in range(a1 + 1, n_angles):
x = angles[a1]
y = angles[a2]
if x[1] == y[0] and x[2] == y[1] and (cyclicdih or (x[0] != y[2])):
buffer.push_back(x[0])
buffer.push_back(x[1])
buffer.push_back(x[2])
buffer.push_back(y[2])
dihedrals.push_back(buffer)
buffer.clear()
if x[1] == y[2] and x[2] == y[1] and (cyclicdih or (x[0] != y[0])):
buffer.push_back(x[0])
buffer.push_back(x[1])
buffer.push_back(x[2])
buffer.push_back(y[0])
dihedrals.push_back(buffer)
buffer.clear()
if y[1] == x[0] and y[2] == x[1] and (cyclicdih or (y[0] != x[2])):
buffer.push_back(y[0])
buffer.push_back(y[1])
buffer.push_back(y[2])
buffer.push_back(x[2])
dihedrals.push_back(buffer)
buffer.clear()
if y[1] == x[0] and y[0] == x[1] and (cyclicdih or (y[2] != x[2])):
buffer.push_back(y[2])
buffer.push_back(y[1])
buffer.push_back(y[0])
buffer.push_back(x[2])
dihedrals.push_back(buffer)
buffer.clear()

return neighbors, angles, dihedrals
45 changes: 14 additions & 31 deletions moleculekit/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,44 +509,17 @@ def _get_pdb_entity_sequences(entities):
return results


def guessAnglesAndDihedrals(bonds, cyclicdih=False):
def calculateAnglesAndDihedrals(bonds, cyclicdih=False):
"""
Generate a guess of angle and dihedral N-body terms based on a list of bond index pairs.
Calculate all angles and dihedrals from a set of bonds.
"""
from moleculekit.cython_utils import calculateAnglesAndDihedrals as _calculate

import networkx as nx

g = nx.Graph()
g.add_nodes_from(np.unique(bonds))
g.add_edges_from([tuple(b) for b in bonds])

angles = []
for n in g.nodes():
neighbors = list(g.neighbors(n))
for e1 in range(len(neighbors)):
for e2 in range(e1 + 1, len(neighbors)):
angles.append((neighbors[e1], n, neighbors[e2]))
_, angles, dihedrals = _calculate(bonds, cyclicdih, np.max(bonds) + 1)

angles = sorted([sorted([angle, angle[::-1]])[0] for angle in angles])
angles = np.array(angles, dtype=np.uint32)

dihedrals = []
for a1 in range(len(angles)):
for a2 in range(a1 + 1, len(angles)):
a1a = angles[a1]
a2a = angles[a2]
a2f = a2a[
::-1
] # Flipped a2a. We don't need flipped a1a as it produces the flipped versions of these 4
if np.all(a1a[1:] == a2a[:2]) and (cyclicdih or (a1a[0] != a2a[2])):
dihedrals.append(list(a1a) + [a2a[2]])
if np.all(a1a[1:] == a2f[:2]) and (cyclicdih or (a1a[0] != a2f[2])):
dihedrals.append(list(a1a) + [a2f[2]])
if np.all(a2a[1:] == a1a[:2]) and (cyclicdih or (a2a[0] != a1a[2])):
dihedrals.append(list(a2a) + [a1a[2]])
if np.all(a2f[1:] == a1a[:2]) and (cyclicdih or (a2f[0] != a1a[2])):
dihedrals.append(list(a2f) + [a1a[2]])

dihedrals = sorted(
[sorted([dihedral, dihedral[::-1]])[0] for dihedral in dihedrals]
)
Expand All @@ -560,6 +533,16 @@ def guessAnglesAndDihedrals(bonds, cyclicdih=False):
return angles, dihedrals


def guessAnglesAndDihedrals(bonds, cyclicdih=False):
"""
Calculate all angles and dihedrals from a set of bonds.
"""
logger.warning(
"guessAnglesAndDihedrals is deprecated. Please use calculateAnglesAndDihedrals instead."
)
return calculateAnglesAndDihedrals(bonds, cyclicdih)


def assertSameAsReferenceDir(compareDir, outdir="."):
"""Check if files in refdir are present in the directory given as second argument AND their content matches.
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"moleculekit/atomselect_utils/atomselect_utils.pyx",
"moleculekit/distance_utils/distance_utils.pyx",
"moleculekit/occupancy_utils/occupancy_utils.pyx",
"moleculekit/cython_utils/cython_utils.pyx",
]
extentions = [
Extension(
Expand Down

0 comments on commit ac69727

Please sign in to comment.