cudart.cudaSetDevice before cudart.cudaGetDevice produces invalid results #24

@wence-

Description

Using an environment created with:

mamba create -n testing -c nvidia -c conda-forge python=3.9 'cuda-toolkit>=11.7' 'cuda-python>=11.7'

and running the following on a machine with at least three GPUs:

from cuda import cudart

print(cudart.cudaSetDevice(2))
print(cudart.cudaGetDevice())

the output is:

(<cudaError_t.cudaSuccess: 0>,)
(<cudaError_t.cudaSuccess: 0>, 0)

Expected result: the cudaGetDevice() call should return device 2, not device 0.

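The problem appears to be that cudaSetDevice only calls ccudart.utils.lazyInitGlobal, whereas cudaGetDevice calls ccudart.utils.lazyInit, which in turn calls lazyInitDevice(0). As a result, cudaGetDevice ends up initializing device 0 even after a different device has been selected.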
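The suspected interaction can be illustrated with a toy, GPU-free model. This is not the cuda-python implementation: ToyRuntime and its methods are hypothetical stand-ins, and the model assumes that lazyInitDevice(0) leaves device 0's context current, which is one way to arrive at the reported result.

```python
from typing import Optional


class ToyRuntime:
    """Hypothetical stand-in for the runtime's lazy-init state (not cuda-python code)."""

    def __init__(self, num_devices: int) -> None:
        self.num_devices = num_devices
        self.global_initialized = False
        # Models the driver's current context as a device ordinal; None = no context.
        self.current_device: Optional[int] = None

    def lazy_init_global(self) -> None:
        # Global setup only; leaves the current device alone.
        self.global_initialized = True

    def lazy_init_device(self, ordinal: int) -> None:
        # Assumption: initializing a device leaves its context current.
        self.current_device = ordinal

    def lazy_init(self) -> None:
        # As reported in the issue: lazyInit also calls lazyInitDevice(0).
        self.lazy_init_global()
        self.lazy_init_device(0)

    def set_device(self, ordinal: int) -> None:
        # Like cudaSetDevice: only the global init runs before selecting the device.
        if not 0 <= ordinal < self.num_devices:
            raise ValueError(f"invalid device ordinal {ordinal}")
        self.lazy_init_global()
        self.current_device = ordinal

    def get_device(self) -> int:
        self.lazy_init()  # the call the issue proposes removing
        if self.current_device is None:
            self.set_device(0)
        return self.current_device


rt = ToyRuntime(num_devices=4)
rt.set_device(2)
print(rt.get_device())  # prints 0, mirroring the output reported above
```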

I think cudaGetDevice simply needs to not call lazyInit: the case where no context is current is already handled by the branch that calls cudaSetDevice(0).
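The same toy model with the lazyInit call dropped from get_device behaves as expected in both cases: an explicitly selected device is reported, and with no current context the fallback selects device 0. As before, ToyRuntimeFixed is a hypothetical, GPU-free sketch of the control flow, not the real Cython code.

```python
from typing import Optional


class ToyRuntimeFixed:
    """Hypothetical model of cudaGetDevice with the lazyInit call removed."""

    def __init__(self, num_devices: int) -> None:
        self.num_devices = num_devices
        self.global_initialized = False
        self.current_device: Optional[int] = None  # None models "no current context"

    def lazy_init_global(self) -> None:
        self.global_initialized = True

    def set_device(self, ordinal: int) -> None:
        if not 0 <= ordinal < self.num_devices:
            raise ValueError(f"invalid device ordinal {ordinal}")
        self.lazy_init_global()
        self.current_device = ordinal

    def get_device(self) -> int:
        # No lazy_init() here: only the no-context case falls back to device 0.
        if self.current_device is None:
            self.set_device(0)
        return self.current_device


selected = ToyRuntimeFixed(num_devices=4)
selected.set_device(2)
print(selected.get_device())  # prints 2

fresh = ToyRuntimeFixed(num_devices=4)
print(fresh.get_device())  # prints 0, via the set_device(0) fallback
```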

https://github.com/NVIDIA/cuda-python/blob/main/cuda/_lib/ccudart/ccudart.pyx#L1039-L1045

Plausibly a patch like this?

diff --git a/cuda/_lib/ccudart/ccudart.pyx b/cuda/_lib/ccudart/ccudart.pyx
index d42d594..d7f3602 100644
--- a/cuda/_lib/ccudart/ccudart.pyx
+++ b/cuda/_lib/ccudart/ccudart.pyx
@@ -1032,9 +1032,6 @@ cdef cudaError_t _cudaGetDevice(int* device) nogil except ?cudaErrorCallRequires
     cdef cudaError_t err
     cdef ccuda.CUresult err_driver
     cdef ccuda.CUcontext context
-    err = m_global.lazyInit()
-    if err != cudaSuccess:
-        return err
 
     err_driver = ccuda._cuCtxGetCurrent(&context)
     if err_driver == ccuda.cudaError_enum.CUDA_ERROR_INVALID_CONTEXT or (err_driver == ccuda.cudaError_enum.CUDA_SUCCESS and context == NULL):
@@ -1045,14 +1042,16 @@ cdef cudaError_t _cudaGetDevice(int* device) nogil except ?cudaErrorCallRequires
         err_driver = ccuda._cuCtxGetCurrent(&context)
 
     if err_driver != ccuda.cudaError_enum.CUDA_SUCCESS:
-        _setLastError(err)
-        return err
+        _setLastError(<cudaError_t>err_driver)
+        return <cudaError_t>err_driver
 
     found = False
     for deviceOrdinal in range(m_global._numDevices):
         if m_global._driverContext[deviceOrdinal] == context:
             found = True
             break
+    else:
+        return cudaErrorDeviceUninitialized
     device[0] = deviceOrdinal if found else 0
     return cudaSuccess
 

Note that this patch also includes two other fixes:

  1. When err_driver != CUDA_SUCCESS, actually return the driver's error code rather than err.
  2. If, after all this, no device's context matches the current one, return cudaErrorDeviceUninitialized instead of silently reporting device 0 (I'm not sure this is the correct error code).
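The control flow of the patched tail of _cudaGetDevice, with both extra fixes, can be modelled in plain Python. This is a sketch under assumptions: Status is a hypothetical stand-in for the real CUresult and cudaError_t values (no numeric codes implied), contexts are arbitrary objects, and the (error, value) return shape only imitates the tuple convention visible in the cudart output above.

```python
from enum import Enum, auto
from typing import Optional, Sequence, Tuple


class Status(Enum):
    # Hypothetical stand-ins for CUresult / cudaError_t values.
    SUCCESS = auto()
    DRIVER_ERROR = auto()
    DEVICE_UNINITIALIZED = auto()


def lookup_device(
    driver_status: Status,
    context: object,
    driver_contexts: Sequence[Optional[object]],
) -> Tuple[Status, Optional[int]]:
    """Hypothetical model of the patched lookup in _cudaGetDevice."""
    # Fix 1: propagate the driver failure instead of reporting success.
    if driver_status is not Status.SUCCESS:
        return driver_status, None
    # Fix 2: for/else returns an error when no ordinal owns the current context.
    for ordinal, known in enumerate(driver_contexts):
        if known is not None and known is context:  # empty slots never match
            break
    else:
        return Status.DEVICE_UNINITIALIZED, None
    return Status.SUCCESS, ordinal


ctx0, ctx2 = object(), object()
table = [ctx0, None, ctx2, None]

status, device = lookup_device(Status.SUCCESS, ctx2, table)
print(status.name, device)  # prints "SUCCESS 2"
status, device = lookup_device(Status.SUCCESS, object(), table)
print(status.name, device)  # prints "DEVICE_UNINITIALIZED None"
status, device = lookup_device(Status.DRIVER_ERROR, ctx2, table)
print(status.name, device)  # prints "DRIVER_ERROR None"
```

With the for/else in place, the patch's found flag is always true by the time device[0] is assigned, so that line could be reduced to device[0] = deviceOrdinal.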
