Skip to content

Commit

Permalink
Merge pull request #255 from carterbox/typing
Browse files Browse the repository at this point in the history
REF: Add type hints to some modules
  • Loading branch information
carterbox committed Feb 7, 2023
2 parents e10ca08 + d565b40 commit d6e56da
Show file tree
Hide file tree
Showing 11 changed files with 294 additions and 156 deletions.
7 changes: 4 additions & 3 deletions src/tike/communicators/comm.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,9 @@ class Comm:
def __init__(
self,
gpu_count,
mpi=NoMPIComm,
pool=ThreadPool,
mpi: typing.Union[typing.Type[MPIComm],
typing.Type[NoMPIComm]] = NoMPIComm,
pool: typing.Type[ThreadPool] = ThreadPool,
):
if isinstance(mpi, NoMPIComm):
self.use_mpi = False
Expand Down Expand Up @@ -95,7 +96,7 @@ def Allreduce_reduce_cpu(
def Allreduce_mean(
self,
x: typing.List[cp.ndarray],
axis: typing.Union[int, typing.List[int]] = 0,
axis: typing.Union[int, None] = 0,
) -> cp.ndarray:
"""Multi-process multi-GPU based mean."""
with cp.cuda.Device(self.pool.workers[0]):
Expand Down
53 changes: 26 additions & 27 deletions src/tike/communicators/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import warnings

import cupy as cp
import cupy.typing as cpt
import numpy as np


Expand Down Expand Up @@ -82,25 +81,25 @@ def num_workers(self):

def _copy_to(
self,
x: typing.Union[cp.array, np.array],
x: typing.Union[cp.ndarray, np.ndarray],
worker: int,
) -> cp.array:
) -> cp.ndarray:
with self.Device(worker):
return self.xp.asarray(x)

def _copy_host(
self,
x: cp.array,
x: cp.ndarray,
worker: int,
) -> np.array:
) -> np.ndarray:
with self.Device(worker):
return self.xp.asnumpy(x)

def bcast(
self,
x: typing.List[typing.Union[cp.array, np.array]],
x: typing.List[typing.Union[cp.ndarray, np.ndarray]],
stride: int = 1,
) -> typing.List[cp.array]:
) -> typing.List[cp.ndarray]:
"""Send each x to all device groups.
Parameters
Expand All @@ -125,10 +124,10 @@ def f(worker):

def gather(
self,
x: typing.List[cp.array],
worker: int = None,
x: typing.List[cp.ndarray],
worker: typing.Union[int, None] = None,
axis: typing.Union[int, None] = 0,
) -> cp.array:
) -> cp.ndarray:
"""Concatenate x on a single worker along the given axis.
Parameters
Expand All @@ -153,9 +152,9 @@ def gather(

def gather_host(
self,
x: typing.List[cp.array],
x: typing.List[cp.ndarray],
axis: typing.Union[int, None] = 0,
) -> np.array:
) -> np.ndarray:
"""Concatenate x on host along the given axis.
Parameters
Expand All @@ -182,9 +181,9 @@ def f(x, worker):

def all_gather(
self,
x: typing.List[cp.array],
x: typing.List[cp.ndarray],
axis: typing.Union[int, None] = 0,
) -> typing.List[cp.array]:
) -> typing.List[cp.ndarray]:
"""Concatenate x on all workers along the given axis.
Parameters
Expand All @@ -201,9 +200,9 @@ def f(worker):

def scatter(
self,
x: typing.List[cpt.NDArray],
x: typing.List[cp.ndarray],
stride: int = 1,
) -> typing.List[cpt.NDArray]:
) -> typing.List[cp.ndarray]:
"""Scatter each x with given stride.
scatter_bcast(x=[0, 1], stride=3) -> [0, 0, 0, 1, 1, 1]
Expand Down Expand Up @@ -231,9 +230,9 @@ def f(worker):

def scatter_bcast(
self,
x: typing.List[cpt.NDArray],
x: typing.List[cp.ndarray],
stride: int = 1,
) -> typing.List[cpt.NDArray]:
) -> typing.List[cp.ndarray]:
"""Scatter each x with given stride and then broadcast nearby.
scatter_bcast(x=[0, 1], stride=3) -> [0, 0, 0, 1, 1, 1]
Expand Down Expand Up @@ -268,10 +267,10 @@ def f(worker):

def reduce_gpu(
self,
x: typing.List[cp.array],
x: typing.List[cp.ndarray],
stride: int = 1,
workers: typing.Union[typing.List[int], None] = None,
) -> typing.List[cp.array]:
workers: typing.Union[typing.Tuple[int, ...], None] = None,
) -> typing.List[cp.ndarray]:
"""Reduce x by addition to a device group from all other devices.
reduce_gpu([0, 1, 2, 3, 4], stride=2) -> [6, 4]
Expand Down Expand Up @@ -303,18 +302,18 @@ def f(worker):
workers = self.workers[:stride] if workers is None else workers
return self.map(f, workers, workers=workers)

def reduce_cpu(self, x: typing.List[cp.array]) -> np.array:
def reduce_cpu(self, x: typing.List[cp.ndarray]) -> np.ndarray:
"""Reduce x by addition from all GPUs to a CPU buffer."""
assert len(x) <= self.num_workers, (
f"{len(x)} work is more than {self.num_workers} workers")
return np.sum(self.map(self._copy_host, x, self.workers), axis=0)

def reduce_mean(
self,
x: typing.List[cp.array],
axis: typing.Union[int, typing.List[int]] = 0,
x: typing.List[cp.ndarray],
axis: typing.Union[int, None] = 0,
worker: typing.Union[int, None] = None,
) -> cp.array:
) -> cp.ndarray:
"""Reduce x by taking the mean onto one GPU from all other GPUs."""
worker = self.workers[0] if worker is None else worker
return cp.mean(
Expand All @@ -325,9 +324,9 @@ def reduce_mean(

def allreduce(
self,
x: typing.List[cp.array],
x: typing.List[cp.ndarray],
stride: typing.Union[int, None] = None,
) -> typing.List[cp.array]:
) -> typing.List[cp.ndarray]:
"""All-reduce x by addition within device groups.
allreduce([0, 1, 2, 3, 4, 5, 6], stride=2) -> [1, 1, 5, 5, 9, 9, 6]
Expand Down
7 changes: 6 additions & 1 deletion src/tike/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
This module exists because support for broadcasting and complex values is
spotty in the NumPy and CuPy libraries.
"""
import typing

import numpy as np

Expand Down Expand Up @@ -56,7 +57,11 @@ def lstsq(a, b, weights=None):
return x


def orthogonalize_gs(x, axis=-1, N=None):
def orthogonalize_gs(
x,
axis: typing.Union[int, typing.Tuple[int, ...]] = -1,
N: typing.Union[int, None] = None,
):
"""Gram-Schmidt orthogonalization for complex arrays.
Parameters
Expand Down
44 changes: 38 additions & 6 deletions src/tike/operators/cupy/cache.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
__author__ = "Daniel Ching"
__copyright__ = "Copyright (c) 2020, UChicago Argonne, LLC."

import typing

from cupyx.scipy.fft import fftn, ifftn, get_fft_plan
import cupy.cuda.cufft
import cupy.cuda.runtime
import numpy.typing as npt
import numpy as np


class CachedFFT():
Expand All @@ -24,9 +29,14 @@ def __exit__(self, type, value, traceback):
self.plan_cache.clear()
del self.plan_cache

def _get_fft_plan(self, a, axes=None, **kwargs):
def _get_fft_plan(
self,
a: npt.NDArray,
axes: typing.Tuple[int, ...] = (),
**kwargs,
) -> typing.Union[cupy.cuda.cufft.Plan1d, cupy.cuda.cufft.PlanNd]:
"""Cache multiple FFT plans at the same time."""
axes = tuple(range(a.ndim)) if axes is None else axes
axes = tuple(range(a.ndim)) if axes == () else axes
key = (*a.shape, *axes, cupy.cuda.runtime.getDevice())
if key in self.plan_cache:
plan = self.plan_cache[key]
Expand All @@ -35,16 +45,38 @@ def _get_fft_plan(self, a, axes=None, **kwargs):
self.plan_cache[key] = plan
return plan

def _fft2(self, a, *args, axes=(-2, -1), **kwargs):
def _fft2(
self,
a: npt.NDArray,
*args,
axes: typing.Tuple[int, int] = (-2, -1),
**kwargs,
) -> npt.NDArray[np.csingle]:
return self._fftn(a, *args, axes=axes, **kwargs)

def _ifft2(self, a, *args, axes=(-2, -1), **kwargs):
def _ifft2(
self,
a: npt.NDArray,
*args,
axes: typing.Tuple[int, int] = (-2, -1),
**kwargs,
) -> npt.NDArray[np.csingle]:
return self._ifftn(a, *args, axes=axes, **kwargs)

def _ifftn(self, a, *args, **kwargs):
def _ifftn(
self,
a: npt.NDArray,
*args,
**kwargs,
) -> npt.NDArray[np.csingle]:
with self._get_fft_plan(a, **kwargs):
return ifftn(a, *args, **kwargs)

def _fftn(self, a, *args, **kwargs):
def _fftn(
self,
a: npt.NDArray,
*args,
**kwargs,
) -> npt.NDArray[np.csingle]:
with self._get_fft_plan(a, **kwargs):
return fftn(a, *args, **kwargs)
55 changes: 28 additions & 27 deletions src/tike/operators/cupy/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from importlib_resources import files

import cupy as cp
import numpy.typing as npt
import numpy as np

from .operator import Operator
Expand All @@ -17,7 +18,7 @@
_adj_patch = cp.RawKernel(_cu_source, "adj_patch")


def _next_power_two(v):
def _next_power_two(v: int) -> int:
"""Return the next highest power of 2 of 32-bit v.
https://graphics.stanford.edu/~seander/bithacks.html#RoundUpPowerOf2
Expand Down Expand Up @@ -49,30 +50,30 @@ class Patch(Operator):

def fwd(
self,
images,
positions,
patches=None,
patch_width=None,
height=None,
width=None,
nrepeat=1,
images: npt.NDArray[np.csingle],
positions: npt.NDArray[np.single],
patches: npt.NDArray[np.csingle] = None,
patch_width: int = 0,
height: int = 0,
width: int = 0,
nrepeat: int = 1,
):
patch_width = patches.shape[-1] if patch_width is None else patch_width
patch_width = patches.shape[-1] if patch_width == 0 else patch_width
if patches is None:
patches = cp.zeros(
(*positions.shape[:-2], positions.shape[-2] * nrepeat,
patch_width, patch_width),
dtype='complex64',
shape=(*positions.shape[:-2], positions.shape[-2] * nrepeat,
patch_width, patch_width),
dtype=np.csingle,
)
assert patch_width <= patches.shape[-1]
assert images.shape[:-2] == positions.shape[:-2]
assert positions.shape[:-2] == patches.shape[:-3], (positions.shape,
patches.shape)
assert positions.shape[-2] * nrepeat == patches.shape[-3]
assert positions.shape[-1] == 2, positions.shape
assert images.dtype == 'complex64', f"{images.dtype}"
assert patches.dtype == 'complex64', f"{patches.dtype}"
assert positions.dtype == 'float32', f"{positions.dtype}"
assert images.dtype == np.csingle, f"{images.dtype}"
assert patches.dtype == np.csingle, f"{patches.dtype}"
assert positions.dtype == np.single, f"{positions.dtype}"
nimage = int(np.prod(images.shape[:-2]))
grids = (
positions.shape[-2],
Expand Down Expand Up @@ -100,20 +101,20 @@ def fwd(

def adj(
self,
positions,
patches,
images=None,
patch_width=None,
height=None,
width=None,
nrepeat=1,
positions: npt.NDArray[np.single],
patches: npt.NDArray[np.csingle],
images: npt.NDArray[np.csingle] = None,
patch_width: int = 0,
height: int = 0,
width: int = 0,
nrepeat: int = 1,
):
patch_width = patches.shape[-1] if patch_width is None else patch_width
patch_width = patches.shape[-1] if patch_width == 0 else patch_width
assert patch_width <= patches.shape[-1]
if images is None:
images = cp.zeros(
(*positions.shape[:-2], height, width),
dtype='complex64',
dtype=cp.csingle,
)
leading = images.shape[:-2]
height, width = images.shape[-2:]
Expand All @@ -124,9 +125,9 @@ def adj(
K = patches.shape[-3]
assert (N * nrepeat) % K == 0 and K >= nrepeat
assert patches.shape[-1] == patches.shape[-2]
assert images.dtype == 'complex64'
assert patches.dtype == 'complex64'
assert positions.dtype == 'float32'
assert images.dtype == np.csingle
assert patches.dtype == np.csingle
assert positions.dtype == np.single
nimage = int(np.prod(images.shape[:-2]))
grids = (
positions.shape[-2],
Expand Down

0 comments on commit d6e56da

Please sign in to comment.