Skip to content

Commit

Permalink
Merge pull request #252 from carterbox/fix-disabled-threadpool
Browse files Browse the repository at this point in the history
BUG: Fix disabled threadpool
  • Loading branch information
carterbox committed Jan 31, 2023
2 parents f22501a + ecd11ca commit 3a4b85c
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 50 deletions.
16 changes: 14 additions & 2 deletions .github/workflows/unit-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@ jobs:
steps:
- uses: actions/checkout@v3

- run: >
conda config --remove channels defaults || true
conda config --add channels conda-forge
conda config --show channels
name: Configure Conda to only use conda-forge
- run: >
conda create --quiet --yes
-n tike
Expand Down Expand Up @@ -63,7 +69,7 @@ jobs:
export TIKE_TEST_CI
pytest -vs tests
name: Run tests
- run: |
cd tests/result
zip -r9 ../../result.zip .
Expand All @@ -89,6 +95,12 @@ jobs:
steps:
- uses: actions/checkout@v3

- run: >
conda config --remove channels defaults || true
conda config --add channels conda-forge
conda config --show channels
name: Configure Conda to only use conda-forge
- run: >
conda create --quiet --yes
-n tike
Expand Down Expand Up @@ -118,7 +130,7 @@ jobs:
export TIKE_TEST_CI
mpiexec -n 2 python -m pytest -vs tests
name: Run tests with MPI
- run: |
cd tests/result
zip -r9 ../../result.zip .
Expand Down
72 changes: 29 additions & 43 deletions src/tike/communicators/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import warnings

import cupy as cp
import cupy.typing as cpt
import numpy as np


Expand Down Expand Up @@ -41,13 +42,13 @@ class ThreadPool(ThreadPoolExecutor):

def __init__(
self,
workers,
workers: typing.Union[int, typing.Tuple[int, ...]],
xp=cp,
device_count=None,
device_count: typing.Union[int, None] = None,
):
self.device_count = cp.cuda.runtime.getDeviceCount(
) if device_count is None else device_count
if type(workers) is int:
if isinstance(workers, int):
if workers < 1:
raise ValueError(f"Provide workers > 0, not {workers}.")
if workers > self.device_count:
Expand All @@ -57,7 +58,7 @@ def __init__(
workers = min(workers, self.device_count)
if workers == 1:
# Respect "with cp.cuda.Device()" blocks for single thread
workers = (cp.cuda.Device().id,)
workers = (int(cp.cuda.Device().id),)
else:
workers = tuple(range(workers))
for w in workers:
Expand Down Expand Up @@ -113,7 +114,8 @@ def bcast(
"""
assert stride >= 1, f"Stride cannot be less than 1; it is {stride}."
assert stride <= len(x), f"Stride cannot be greater than {len(x)}; it is {stride}."
assert stride <= len(
x), f"Stride cannot be greater than {len(x)}; it is {stride}."

def f(worker):
idx = self.workers.index(worker) % stride
Expand All @@ -140,7 +142,8 @@ def gather(
merge = self.xp.stack
axis = 0
else:
assert x[0].ndim > 0, "Cannot concatenate zero-dimensional arrays; use `axis=None`"
assert x[
0].ndim > 0, "Cannot concatenate zero-dimensional arrays; use `axis=None`"
merge = self.xp.concatenate
with self.Device(worker):
return merge(
Expand Down Expand Up @@ -169,7 +172,8 @@ def f(x, worker):
merge = np.stack
axis = 0
else:
assert x[0].ndim > 0, "Cannot concatenate zero-dimensional arrays; use `axis=None`"
assert x[
0].ndim > 0, "Cannot concatenate zero-dimensional arrays; use `axis=None`"
merge = np.concatenate
return merge(
self.map(f, x, self.workers),
Expand Down Expand Up @@ -197,9 +201,9 @@ def f(worker):

def scatter(
self,
x: typing.List[cp.array],
x: typing.List[cpt.NDArray],
stride: int = 1,
) -> typing.List[cp.array]:
) -> typing.List[cpt.NDArray]:
"""Scatter each x with given stride.
scatter_bcast(x=[0, 1], stride=3) -> [0, 0, 0, 1, 1, 1]
Expand All @@ -218,7 +222,6 @@ def scatter(
"""
assert stride >= 1, f"Stride cannot be less than 1; it is {stride}."
assert stride <= len(x), f"Stride cannot be greater than {len(x)}; it is {stride}."

def f(worker):
idx = self.workers.index(worker) // stride
Expand All @@ -228,9 +231,9 @@ def f(worker):

def scatter_bcast(
self,
x: typing.List[cp.array],
x: typing.List[cpt.NDArray],
stride: int = 1,
) -> typing.List[cp.array]:
) -> typing.List[cpt.NDArray]:
"""Scatter each x with given stride and then broadcast nearby.
scatter_bcast(x=[0, 1], stride=3) -> [0, 0, 0, 1, 1, 1]
Expand All @@ -250,36 +253,18 @@ def scatter_bcast(
"""
assert stride >= 1, f"Stride cannot be less than 1; it is {stride}."
assert stride <= len(x), f"Stride cannot be greater than {len(x)}; it is {stride}."

def s(bworkers, chunk):
# First, scatter to leader of each group
leaders = self.workers[::stride]

def b(worker):
return self._copy_to(chunk, worker)
def f(worker):
idx = leaders.index(worker)
return self._copy_to(x[idx], worker)

return list(self.map(b, bworkers, workers=bworkers))
scattered = self.map(f, leaders)

bworkers = []
if stride == 1:
sworkers = self.workers[:len(x)]
for i in range(len(x)):
bworkers.append(self.workers[i::len(x)])
else:
sworkers = self.workers[::stride]
for i in sworkers:
bworkers.append(self.workers[i:(i + stride)])

a = self.map(s, bworkers, x, workers=sworkers)
output = [None] * self.num_workers
i, j = 0, 0
for si in bworkers:
for bi in si:
output[bi] = a[i][j]
j += 1
i += 1
j = 0

return output
# Then, broadcast within each group
return self.scatter(scattered, stride=stride)

def reduce_gpu(
self,
Expand All @@ -302,7 +287,8 @@ def reduce_gpu(
"""
assert stride >= 1, f"Stride cannot be less than 1; it is {stride}."
assert stride <= len(x), f"Stride cannot be greater than {len(x)}; it is {stride}."
assert stride <= len(
x), f"Stride cannot be greater than {len(x)}; it is {stride}."

# if self.num_workers == 1:
# return x
Expand Down Expand Up @@ -361,7 +347,8 @@ def allreduce(

stride = len(x) if stride is None else stride
assert stride >= 1, f"Stride cannot be less than 1; it is {stride}."
assert stride <= len(x), f"Stride cannot be greater than {len(x)}; it is {stride}."
assert stride <= len(
x), f"Stride cannot be greater than {len(x)}; it is {stride}."

def f(worker):
group_start = stride * (self.workers.index(worker) // stride)
Expand All @@ -379,7 +366,7 @@ def map(
self,
func,
*iterables,
workers: typing.Union[typing.List[int], None] = None,
workers: typing.Union[typing.Tuple[int, ...], None] = None,
**kwargs,
) -> list:
"""ThreadPoolExecutor.map, but wraps call in a cuda.Device context."""
Expand All @@ -390,5 +377,4 @@ def f(worker, *args):

workers = self.workers if workers is None else workers

# return list(super().map(f, workers, *iterables))
return list(map(f, workers, *iterables))
return list(super().map(f, workers, *iterables))
10 changes: 5 additions & 5 deletions src/tike/lamino/bucket.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,11 +178,11 @@ def reconstruct(
"iterations.", rtol, i)
break

result['cost'] = operator.asarray(costs)
result['cost'] = operator.asnumpy(result['cost'])
result['obj'] = comm.pool.gather_host(result['obj'][:obj_split])
result['step_length'] = operator.asnumpy(result['step_length'])
return result
result['cost'] = operator.asarray(costs)
result['cost'] = operator.asnumpy(result['cost'])
result['obj'] = comm.pool.gather_host(result['obj'][:obj_split])
result['step_length'] = operator.asnumpy(result['step_length'])
return result
else:
raise ValueError(
"The '{}' algorithm is not an available.".format(algorithm))

0 comments on commit 3a4b85c

Please sign in to comment.