[python] Adjusts mpi workers based on CUDA_VISIBLE_DEVICES (deepjavalib…
frankfliu authored and KexinFeng committed Aug 16, 2023
1 parent 3e23857 commit 64d412d
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions engines/python/src/main/java/ai/djl/python/engine/PyEnv.java
@@ -261,6 +261,14 @@ public void setTensorParallelDegree(int tensorParallelDegree) {

     int getMpiWorkers() {
         int gpuCount = CudaUtils.getGpuCount();
+        String visibleDevices = Utils.getenv("CUDA_VISIBLE_DEVICES");
+        if (gpuCount > 0 && visibleDevices != null) {
+            int visibleCount = visibleDevices.split(",").length;
+            if (visibleCount > gpuCount || visibleCount < 1) {
+                throw new AssertionError("Invalid CUDA_VISIBLE_DEVICES: " + visibleDevices);
+            }
+            gpuCount = visibleCount;
+        }
         return gpuCount / getTensorParallelDegree();
     }
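
The effect of this change can be checked outside of DJL with a small standalone sketch. The class `MpiWorkersSketch` and the method `mpiWorkers` below are hypothetical names, not part of `PyEnv`; the GPU count, the `CUDA_VISIBLE_DEVICES` value, and the tensor parallel degree are passed in as plain arguments instead of being read through `CudaUtils`, `Utils.getenv`, and `getTensorParallelDegree()`. The sketch also throws `IllegalArgumentException` where the commit throws `AssertionError`, purely as an illustration choice.

```java
final class MpiWorkersSketch {

    private MpiWorkersSketch() {}

    /**
     * Mirrors the worker calculation from the diff, with all inputs explicit.
     *
     * @param gpuCount number of GPUs detected on the host (assumed input)
     * @param visibleDevices value of CUDA_VISIBLE_DEVICES, or null if unset
     * @param tensorParallelDegree GPUs used per worker, assumed to be at least 1
     */
    public static int mpiWorkers(int gpuCount, String visibleDevices, int tensorParallelDegree) {
        if (gpuCount > 0 && visibleDevices != null) {
            // Count the comma-separated entries; the IDs themselves are not parsed.
            int visibleCount = visibleDevices.split(",").length;
            if (visibleCount > gpuCount || visibleCount < 1) {
                // The commit throws AssertionError here; this sketch swaps in
                // IllegalArgumentException.
                throw new IllegalArgumentException(
                        "Invalid CUDA_VISIBLE_DEVICES: " + visibleDevices);
            }
            gpuCount = visibleCount;
        }
        // Integer division: leftover GPUs that do not fill a worker are unused.
        return gpuCount / tensorParallelDegree;
    }

    public static void main(String[] args) {
        System.out.println(mpiWorkers(8, null, 2));      // prints 4
        System.out.println(mpiWorkers(8, "0,1,2,3", 2)); // prints 2
        System.out.println(mpiWorkers(8, "0,1,2", 2));   // prints 1
    }
}
```

Two properties of `String.split` matter for the validation. An empty string splits into a single empty element, so an empty `CUDA_VISIBLE_DEVICES` is counted as one device. A value made up only of commas, such as `","`, splits into zero elements because trailing empty strings are dropped, which is the case the `visibleCount < 1` check rejects.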
