Skip to content

Commit

Permalink
Merge pull request #219 from DerwenAI/fix-cudf-dep
Browse files Browse the repository at this point in the history
addresses #198 with improved GPU detection/handling
  • Loading branch information
ceteri authored Feb 10, 2022
2 parents 938a2cc + 1e886d9 commit 58cf33e
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 17 deletions.
7 changes: 7 additions & 0 deletions changelog.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
# `kglab` changelog

## 0.4.3

2022-02-10

* improved GPU detection when RAPIDS is not installed


## 0.4.2

2021-12-13
Expand Down
34 changes: 18 additions & 16 deletions kglab/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,39 +9,41 @@
import numpy as np # type: ignore # pylint: disable=E0401
import pandas as pd # type: ignore # pylint: disable=E0401

GPU_COUNT: int = 0


def get_gpu_count () -> int:
"""
Special handling for detecting GPU availability: an approach
recommended by the NVIDIA RAPIDS engineering team, since `nvml`
recommended by the NVidia RAPIDS engineering team, since `nvml`
bindings are difficult for Python libraries to keep updated.
This has the side-effect of importing the `cuDF` library, when
GPUs are available.
returns:
count of available GPUs
count of available GPUs, where `0` means none or disabled.
"""
global GPU_COUNT

if GPU_COUNT < 0:
return 0

try:
import pynvml # type: ignore # pylint: disable=E0401
pynvml.nvmlInit()

gpu_count = pynvml.nvmlDeviceGetCount()

if gpu_count > 0:
import cudf # type: ignore # pylint: disable=E0401,W0611,W0621
# print(f"using {gpu_count} GPUs")

GPU_COUNT = pynvml.nvmlDeviceGetCount()
except Exception: # pylint: disable=W0703
gpu_count = 0
GPU_COUNT = -1

return gpu_count
return GPU_COUNT


## NB: workaround for GitHub CI

if get_gpu_count() > 0:
import cudf # type: ignore # pylint: disable=E0401
try:
import cudf # type: ignore # pylint: disable=E0401
except Exception as e: # pylint: disable=W0703
# turn off GPU usage
#print(e)
GPU_COUNT = -1


def calc_quantile_bins (
Expand Down
2 changes: 1 addition & 1 deletion kglab/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
## Python version checking

MIN_PY_VERSION: typing.Tuple = (3, 7,)
__version__: str = "0.4.2"
__version__: str = "0.4.3"


def _versify (
Expand Down

0 comments on commit 58cf33e

Please sign in to comment.