Skip to content

Commit

Permalink
Type hints for lib module (#3729)
Browse files Browse the repository at this point in the history
* Type hints for mdamath.py

* fix errors in mdamath.py

* changed input from NDarray to arraylike

* Added type annotations to init.py 

Added type annotations to init to avoid mypy from raising errors when other modules are type checked .

* Allowing mypy to type check lib module

* Update changes in init.py

* Update __init__.py

* type hints for pkdtree

* Fix all errros in pkdtree.py

* Chage npt.NDArray to np.ndarray

* Update pkdtree.py

* Update pkdtree.py

* Update pkdtree.py

* Update NeighborSearch.py

* Update NeighborSearch.py

* Update pkdtree.py

Co-authored-by: Jonathan Barnoud <jonathan@barnoud.net>
  • Loading branch information
umak1106 and jbarnoud committed Sep 12, 2022
1 parent 0788165 commit e7ee5a4
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 37 deletions.
8 changes: 0 additions & 8 deletions mypy.ini
Expand Up @@ -18,9 +18,6 @@ ignore_errors = True
[mypy-MDAnalysis.core.*]
ignore_errors = True

[mypy-MDAnalysis.lib.*]
ignore_errors = True

[mypy-MDAnalysis.selections.*]
ignore_errors = True

Expand All @@ -47,8 +44,3 @@ ignore_errors = True

[mypy-MDAnalysis.version]
ignore_errors = True

[mypy-MDAnalysis.*]
ignore_errors = True


3 changes: 1 addition & 2 deletions package/MDAnalysis/__init__.py
Expand Up @@ -179,8 +179,7 @@
_CONVERTERS: Dict = {}
# Registry of TopologyAttributes
_TOPOLOGY_ATTRS: Dict = {} # {attrname: cls}
_TOPOLOGY_TRANSPLANTS: Dict = {}
# {name: [attrname, method, transplant class]}
_TOPOLOGY_TRANSPLANTS: Dict = {} # {name: [attrname, method, transplant class]}
_TOPOLOGY_ATTRNAMES: Dict = {} # {lower case name w/o _ : name}


Expand Down
16 changes: 13 additions & 3 deletions package/MDAnalysis/lib/NeighborSearch.py
Expand Up @@ -31,6 +31,9 @@
import numpy as np
from MDAnalysis.lib.distances import capped_distance
from MDAnalysis.lib.util import unique_int_1d
from MDAnalysis.core.groups import AtomGroup, SegmentGroup, ResidueGroup
import numpy.typing as npt
from typing import Optional, Union, List


class AtomNeighborSearch(object):
Expand All @@ -41,7 +44,8 @@ class AtomNeighborSearch(object):
:class:`~MDAnalysis.lib.distances.capped_distance`.
"""

def __init__(self, atom_group, box=None):
def __init__(self, atom_group: AtomGroup,
box: Optional[npt.ArrayLike] = None) -> None:
"""
Parameters
Expand All @@ -58,7 +62,10 @@ def __init__(self, atom_group, box=None):
self._u = atom_group.universe
self._box = box

def search(self, atoms, radius, level='A'):
def search(self, atoms: AtomGroup,
radius: float,
level: str = 'A'
) -> Optional[Union[AtomGroup, ResidueGroup, SegmentGroup]]:
"""
Return all atoms/residues/segments that are within *radius* of the
atoms in *atoms*.
Expand Down Expand Up @@ -102,7 +109,10 @@ def search(self, atoms, radius, level='A'):
unique_idx = unique_int_1d(np.asarray(pairs[:, 1], dtype=np.intp))
return self._index2level(unique_idx, level)

def _index2level(self, indices, level):
def _index2level(self,
indices: List[int],
level: str
) -> Union[AtomGroup, ResidueGroup, SegmentGroup]:
"""Convert list of atom_indices in a AtomGroup to either the
Atoms or segments/residues containing these atoms.
Expand Down
30 changes: 19 additions & 11 deletions package/MDAnalysis/lib/mdamath.py
Expand Up @@ -63,11 +63,13 @@
from . import util
from ._cutil import (make_whole, find_fragments, _sarrus_det_single,
_sarrus_det_multiple)
import numpy.typing as npt
from typing import Union

# geometric functions


def norm(v):
def norm(v: npt.ArrayLike) -> float:
r"""Calculate the norm of a vector v.
.. math:: v = \sqrt{\mathbf{v}\cdot\mathbf{v}}
Expand All @@ -90,7 +92,8 @@ def norm(v):
return np.sqrt(np.dot(v, v))


def normal(vec1, vec2):
# typing: numpy
def normal(vec1: npt.ArrayLike, vec2: npt.ArrayLike) -> np.ndarray:
r"""Returns the unit vector normal to two vectors.
.. math::
Expand All @@ -110,7 +113,8 @@ def normal(vec1, vec2):
return normal / n


def pdot(a, b):
# typing: numpy
def pdot(a: npt.ArrayLike, b: npt.ArrayLike) -> np.ndarray:
"""Pairwise dot product.
``a`` must be the same shape as ``b``.
Expand All @@ -127,7 +131,8 @@ def pdot(a, b):
return np.einsum('ij,ij->i', a, b)


def pnorm(a):
# typing: numpy
def pnorm(a: npt.ArrayLike) -> np.ndarray:
"""Euclidean norm of each vector in a matrix
Parameters
Expand All @@ -141,7 +146,7 @@ def pnorm(a):
return pdot(a, a)**0.5


def angle(a, b):
def angle(a: npt.ArrayLike, b: npt.ArrayLike) -> float:
"""Returns the angle between two vectors in radians
.. versionchanged:: 0.11.0
Expand All @@ -156,7 +161,7 @@ def angle(a, b):
return np.arccos(x)


def stp(vec1, vec2, vec3):
def stp(vec1: npt.ArrayLike, vec2: npt.ArrayLike, vec3: npt.ArrayLike) -> float:
r"""Takes the scalar triple product of three vectors.
Returns the volume *V* of the parallel epiped spanned by the three
Expand All @@ -172,7 +177,7 @@ def stp(vec1, vec2, vec3):
return np.dot(vec3, np.cross(vec1, vec2))


def dihedral(ab, bc, cd):
def dihedral(ab: npt.ArrayLike, bc: npt.ArrayLike, cd: npt.ArrayLike) -> float:
r"""Returns the dihedral angle in radians between vectors connecting A,B,C,D.
The dihedral measures the rotation around bc::
Expand All @@ -194,7 +199,8 @@ def dihedral(ab, bc, cd):
return (x if stp(ab, bc, cd) <= 0.0 else -x)


def sarrus_det(matrix):
# typing: numpy
def sarrus_det(matrix: np.ndarray) -> Union[float, np.ndarray]:
"""Computes the determinant of a 3x3 matrix according to the
`rule of Sarrus`_.
Expand Down Expand Up @@ -239,7 +245,8 @@ def sarrus_det(matrix):
return _sarrus_det_multiple(m.reshape((-1, 3, 3))).reshape(shape[:-2])


def triclinic_box(x, y, z):
# typing: numpy
def triclinic_box(x: npt.ArrayLike, y: npt.ArrayLike, z: npt.ArrayLike) -> np.ndarray:
"""Convert the three triclinic box vectors to
``[lx, ly, lz, alpha, beta, gamma]``.
Expand Down Expand Up @@ -301,7 +308,8 @@ def triclinic_box(x, y, z):
return np.zeros(6, dtype=np.float32)


def triclinic_vectors(dimensions, dtype=np.float32):
# typing: numpy
def triclinic_vectors(dimensions: npt.ArrayLike, dtype: npt.DTypeLike = np.float32) -> np.ndarray:
"""Convert ``[lx, ly, lz, alpha, beta, gamma]`` to a triclinic matrix
representation.
Expand Down Expand Up @@ -399,7 +407,7 @@ def triclinic_vectors(dimensions, dtype=np.float32):
return box_matrix


def box_volume(dimensions):
def box_volume(dimensions: npt.ArrayLike) -> float:
"""Return the volume of the unitcell described by `dimensions`.
The volume is computed as the product of the box matrix trace, with the
Expand Down
42 changes: 29 additions & 13 deletions package/MDAnalysis/lib/pkdtree.py
Expand Up @@ -37,6 +37,8 @@
from .util import unique_rows

from MDAnalysis.lib.distances import apply_PBC
import numpy.typing as npt
from typing import Optional, ClassVar

__all__ = [
'PeriodicKDTree'
Expand All @@ -61,7 +63,8 @@ class PeriodicKDTree(object):
:func:`MDAnalysis.lib.distances.undo_augment` function.
"""
def __init__(self, box=None, leafsize=10):

def __init__(self, box: npt.ArrayLike = None, leafsize: int = 10) -> None:
"""
Parameters
Expand All @@ -82,7 +85,7 @@ def __init__(self, box=None, leafsize=10):
self.dim = 3 # 3D systems
self.box = box
self._built = False
self.cutoff = None
self.cutoff: Optional[float] = None

@property
def pbc(self):
Expand All @@ -95,7 +98,7 @@ def pbc(self):
"""
return self.box is not None

def set_coords(self, coords, cutoff=None):
def set_coords(self, coords: npt.ArrayLike, cutoff: Optional[float] = None) -> None:
"""Constructs KDTree from the coordinates
Wrapping of coordinates to the primary unit cell is enforced
Expand Down Expand Up @@ -126,23 +129,24 @@ def set_coords(self, coords, cutoff=None):
MDAnalysis.lib.distances.augment_coordinates
"""
# If no cutoff distance is provided but PBC aware
if self.pbc and (cutoff is None):
raise RuntimeError('Provide a cutoff distance'
' with tree.set_coords(...)')

# set coords dtype to float32
# augment coordinates will work only with float32
coords = np.asarray(coords, dtype=np.float32)

# If no cutoff distance is provided but PBC aware
if self.pbc:
self.cutoff = cutoff
if cutoff is None:
raise RuntimeError('Provide a cutoff distance'
' with tree.set_coords(...)')

# Bring the coordinates in the central cell
self.coords = apply_PBC(coords, self.box)
# generate duplicate images
self.aug, self.mapping = augment_coordinates(self.coords,
self.box,
self.cutoff)
cutoff)
# Images + coords
self.all_coords = np.concatenate([self.coords, self.aug])
self.ckdt = cKDTree(self.all_coords, leafsize=self.leafsize)
Expand All @@ -155,7 +159,8 @@ def set_coords(self, coords, cutoff=None):
self.ckdt = cKDTree(self.coords, self.leafsize)
self._built = True

def search(self, centers, radius):
# typing: numpy
def search(self, centers: npt.ArrayLike, radius: float) -> np.ndarray:
"""Search all points within radius from centers and their periodic images.
All the centers coordinates are wrapped around the central cell
Expand All @@ -179,6 +184,9 @@ def search(self, centers, radius):

# Sanity check
if self.pbc:
if self.cutoff is None:
raise ValueError(
"Cutoff needs to be provided when working with PBC.")
if self.cutoff < radius:
raise RuntimeError('Set cutoff greater or equal to the radius.')
# Bring all query points to the central cell
Expand All @@ -202,17 +210,19 @@ def search(self, centers, radius):
self._indices = np.asarray(unique_int_1d(self._indices))
return self._indices

def get_indices(self):
# typing: numpy
def get_indices(self) -> np.ndarray:
"""Return the neighbors from the last query.
Returns
------
indices : list
indices : NDArray
neighbors for the last query points and search radius
"""
return self._indices

def search_pairs(self, radius):
# typing: numpy
def search_pairs(self, radius: float) -> np.ndarray:
"""Search all the pairs within a specified radius
Parameters
Expand All @@ -229,6 +239,9 @@ def search_pairs(self, radius):
raise RuntimeError(' Unbuilt Tree. Run tree.set_coords(...)')

if self.pbc:
if self.cutoff is None:
raise ValueError(
"Cutoff needs to be provided when working with PBC.")
if self.cutoff < radius:
raise RuntimeError('Set cutoff greater or equal to the radius.')

Expand All @@ -245,7 +258,7 @@ def search_pairs(self, radius):
pairs = unique_rows(pairs)
return pairs

def search_tree(self, centers, radius):
def search_tree(self, centers: npt.ArrayLike, radius: float) -> np.ndarray:
"""
Searches all the pairs within `radius` between `centers`
and ``coords``
Expand Down Expand Up @@ -285,6 +298,9 @@ class initialization

# Sanity check
if self.pbc:
if self.cutoff is None:
raise ValueError(
"Cutoff needs to be provided when working with PBC.")
if self.cutoff < radius:
raise RuntimeError('Set cutoff greater or equal to the radius.')
# Bring all query points to the central cell
Expand Down

0 comments on commit e7ee5a4

Please sign in to comment.