From ac697274c4f6c0fdfa5a871b51815238605d1448 Mon Sep 17 00:00:00 2001 From: Stefan Doerr Date: Thu, 11 Jan 2024 10:33:49 +0200 Subject: [PATCH] reimplemented calculation of angles and dihedrals in cython. Speed up by 1000x for large molecules --- .gitignore | 3 + moleculekit/cython_utils/cython_utils.pyx | 112 ++++++++++++++++++++++ moleculekit/util.py | 45 +++------ setup.py | 1 + 4 files changed, 130 insertions(+), 31 deletions(-) create mode 100644 moleculekit/cython_utils/cython_utils.pyx diff --git a/.gitignore b/.gitignore index 8d1344f..c12ec31 100644 --- a/.gitignore +++ b/.gitignore @@ -39,3 +39,6 @@ moleculekit/occupancy_utils/occupancy_utils.html moleculekit/wrapping/wrapping.c moleculekit/wrapping/wrapping.cpp moleculekit/wrapping/wrapping.html +moleculekit/cython_utils/cython_utils.c +moleculekit/cython_utils/cython_utils.cpp +moleculekit/cython_utils/cython_utils.html diff --git a/moleculekit/cython_utils/cython_utils.pyx b/moleculekit/cython_utils/cython_utils.pyx new file mode 100644 index 0000000..c197550 --- /dev/null +++ b/moleculekit/cython_utils/cython_utils.pyx @@ -0,0 +1,112 @@ +# cython: cdivision=True + +import numpy as np +from math import sqrt +cimport numpy as np +from libcpp.vector cimport vector +from libcpp cimport bool +from libc.math cimport sqrt, round +from cython.parallel import prange + +# We now need to fix a datatype for our arrays. I've used the variable +# DTYPE for this, which is assigned to the usual NumPy runtime +# type info object. +INT32 = np.int32 +INT64 = np.int64 +UINT32 = np.uint32 +FLOAT32 = np.float32 +FLOAT64 = np.float64 + +# "ctypedef" assigns a corresponding compile-time type to DTYPE_t. For +# every type in the numpy module there's a corresponding compile-time +# type with a _t-suffix. +ctypedef np.int32_t INT32_t +ctypedef np.int64_t INT64_t +ctypedef np.uint32_t UINT32_t +ctypedef np.float32_t FLOAT32_t +ctypedef np.float64_t FLOAT64_t + +import cython + +@cython.boundscheck(False) # turn off bounds-checking for entire function +@cython.wraparound(False) # turn off negative index wrapping for entire function +def calculateAnglesAndDihedrals( + UINT32_t[:,:] bonds, + bool cyclicdih, + int n_atoms, + ): + # Same as dist_trajectory but instead of returning distances it returns index + # pairs of atoms that are within a certain distance threshold + cdef int i, j, k, a, b, b1, b2, min_v, max_v, n_neigh, n_angles, a1, a2 + cdef int n_bonds = bonds.shape[0] + cdef vector[vector[UINT32_t]] neighbors + cdef vector[vector[UINT32_t]] angles + cdef vector[vector[UINT32_t]] dihedrals + cdef vector[UINT32_t] buffer + cdef vector[UINT32_t] x, y + + for i in range(n_atoms): + for j in range(n_bonds): + b1 = bonds[j, 0] + b2 = bonds[j, 1] + if b1 == i: + buffer.push_back(b2) + elif b2 == i: + buffer.push_back(b1) + neighbors.push_back(buffer) + buffer.clear() + + for i in range(n_atoms): + n_neigh = neighbors[i].size() + for j in range(n_neigh): + for k in range(j+1, n_neigh): + a = neighbors[i][j] + b = neighbors[i][k] + if a != b: + if a < b: + min_v = a + max_v = b + else: + min_v = b + max_v = a + buffer.push_back(min_v) + buffer.push_back(i) + buffer.push_back(max_v) + angles.push_back(buffer) + buffer.clear() + + n_angles = angles.size() + for a1 in range(n_angles): + for a2 in range(a1 + 1, n_angles): + x = angles[a1] + y = angles[a2] + if x[1] == y[0] and x[2] == y[1] and (cyclicdih or (x[0] != y[2])): + buffer.push_back(x[0]) + buffer.push_back(x[1]) + buffer.push_back(x[2]) + buffer.push_back(y[2]) + dihedrals.push_back(buffer) + buffer.clear() + if x[1] == y[2] and x[2] == y[1] and (cyclicdih or (x[0] != y[0])): + buffer.push_back(x[0]) + buffer.push_back(x[1]) + buffer.push_back(x[2]) + buffer.push_back(y[0]) + dihedrals.push_back(buffer) + buffer.clear() + if y[1] == x[0] and y[2] == x[1] and (cyclicdih or (y[0] != x[2])): + buffer.push_back(y[0]) + buffer.push_back(y[1]) + buffer.push_back(y[2]) + buffer.push_back(x[2]) + dihedrals.push_back(buffer) + buffer.clear() + if y[1] == x[0] and y[0] == x[1] and (cyclicdih or (y[2] != x[2])): + buffer.push_back(y[2]) + buffer.push_back(y[1]) + buffer.push_back(y[0]) + buffer.push_back(x[2]) + dihedrals.push_back(buffer) + buffer.clear() + + return neighbors, angles, dihedrals \ No newline at end of file diff --git a/moleculekit/util.py b/moleculekit/util.py index f45d5f3..a5e2079 100644 --- a/moleculekit/util.py +++ b/moleculekit/util.py @@ -509,44 +509,17 @@ def _get_pdb_entity_sequences(entities): return results -def guessAnglesAndDihedrals(bonds, cyclicdih=False): +def calculateAnglesAndDihedrals(bonds, cyclicdih=False): """ - Generate a guess of angle and dihedral N-body terms based on a list of bond index pairs. + Calculate all angles and dihedrals from a set of bonds. """ + from moleculekit.cython_utils import calculateAnglesAndDihedrals as _calculate - import networkx as nx - - g = nx.Graph() - g.add_nodes_from(np.unique(bonds)) - g.add_edges_from([tuple(b) for b in bonds]) - - angles = [] - for n in g.nodes(): - neighbors = list(g.neighbors(n)) - for e1 in range(len(neighbors)): - for e2 in range(e1 + 1, len(neighbors)): - angles.append((neighbors[e1], n, neighbors[e2])) + _, angles, dihedrals = _calculate(bonds, cyclicdih, np.max(bonds) + 1) angles = sorted([sorted([angle, angle[::-1]])[0] for angle in angles]) angles = np.array(angles, dtype=np.uint32) - dihedrals = [] - for a1 in range(len(angles)): - for a2 in range(a1 + 1, len(angles)): - a1a = angles[a1] - a2a = angles[a2] - a2f = a2a[ - ::-1 - ] # Flipped a2a. We don't need flipped a1a as it produces the flipped versions of these 4 - if np.all(a1a[1:] == a2a[:2]) and (cyclicdih or (a1a[0] != a2a[2])): - dihedrals.append(list(a1a) + [a2a[2]]) - if np.all(a1a[1:] == a2f[:2]) and (cyclicdih or (a1a[0] != a2f[2])): - dihedrals.append(list(a1a) + [a2f[2]]) - if np.all(a2a[1:] == a1a[:2]) and (cyclicdih or (a2a[0] != a1a[2])): - dihedrals.append(list(a2a) + [a1a[2]]) - if np.all(a2f[1:] == a1a[:2]) and (cyclicdih or (a2f[0] != a1a[2])): - dihedrals.append(list(a2f) + [a1a[2]]) - dihedrals = sorted( [sorted([dihedral, dihedral[::-1]])[0] for dihedral in dihedrals] ) @@ -560,6 +533,16 @@ def guessAnglesAndDihedrals(bonds, cyclicdih=False): return angles, dihedrals +def guessAnglesAndDihedrals(bonds, cyclicdih=False): + """ + Calculate all angles and dihedrals from a set of bonds. + """ + logger.warning( + "guessAnglesAndDihedrals is deprecated. Please use calculateAnglesAndDihedrals instead." + ) + return calculateAnglesAndDihedrals(bonds, cyclicdih) + + def assertSameAsReferenceDir(compareDir, outdir="."): """Check if files in refdir are present in the directory given as second argument AND their content matches. diff --git a/setup.py b/setup.py index 3105952..138cf4c 100644 --- a/setup.py +++ b/setup.py @@ -14,6 +14,7 @@ "moleculekit/atomselect_utils/atomselect_utils.pyx", "moleculekit/distance_utils/distance_utils.pyx", "moleculekit/occupancy_utils/occupancy_utils.pyx", + "moleculekit/cython_utils/cython_utils.pyx", ] extentions = [ Extension(