CUSPARSE_STATUS_INVALID_VALUE when using features_at_coordinates method #308

Closed
Milogav opened this issue Jan 26, 2021 · 11 comments
Labels
bug Something isn't working

Comments

@Milogav

Milogav commented Jan 26, 2021

Describe the bug

I get the following error when calling the features_at_coordinates method of a sparse tensor on the GPU. The error does not occur on the CPU.

** On entry to cusparseSpMM_bufferSize() parameter number 1 (handle) had an illegal value: bad initialization or already destroyed

Traceback (most recent call last):
   ...
    feats = tensor.features_at_coordinates(query_xyz.to(device))
  File "/home/miguel/pyenvs/pycloud/lib/python3.8/site-packages/MinkowskiEngine/MinkowskiSparseTensor.py", line 667, in features_at_coordinates
    return MinkowskiInterpolationFunction().apply(
  File "/home/miguel/pyenvs/pycloud/lib/python3.8/site-packages/MinkowskiEngine/MinkowskiInterpolation.py", line 52, in forward
    out_feat, in_map, out_map, weights = fw_fn(
RuntimeError: CUSPARSE_STATUS_INVALID_VALUE at /tmp/pip-req-build-eio2y4yy/src/spmm.cu:243

To Reproduce
Steps to reproduce the behavior.

import MinkowskiEngine as ME
import torch


xyz = torch.Tensor([[0, 0, 0.1],
                    [0, 0, 0.2],
                    [0, 0, 0.3],
                    [0, 0, 0.4]])

query_xyz = torch.Tensor([[0, 0, 0, 0.1],
                          [0, 0, 0, 0.21],
                          [0, 0, 0, 0.31],
                          [0, 0, 0, 0.41]])

features = torch.Tensor([0.9, 0.8, 0.9, 0.3])[:, None]
q_size = 0.1
q_xyz = torch.round(xyz / q_size).type(torch.int64)
bq_xyz = ME.utils.batched_coordinates([q_xyz])

device = torch.device('cpu')
tensor = ME.SparseTensor(features, bq_xyz, device=device)
feats = tensor.features_at_coordinates(query_xyz.to(device) / q_size)  # in cpu, this runs ok

device = torch.device('cuda')
tensor = ME.SparseTensor(features, bq_xyz, device=device)
feats = tensor.features_at_coordinates(query_xyz.to(device) / q_size)

Desktop (please complete the following information):

  • OS: Ubuntu 20.04
  • Python version: 3.8.5
  • CUDA version: 11.2
  • NVIDIA Driver version: 460.32.03
  • Minkowski Engine version: 0.5.0
  • Output of the following command. (If you installed the latest MinkowskiEngine, simply call MinkowskiEngine.print_diagnostics())

==========System==========
Linux-5.8.0-40-generic-x86_64-with-glibc2.29
DISTRIB_ID=Ubuntu
DISTRIB_RELEASE=20.04
DISTRIB_CODENAME=focal
DISTRIB_DESCRIPTION="Ubuntu 20.04.1 LTS"
3.8.5 (default, Jul 28 2020, 12:59:40)
[GCC 9.3.0]
==========Pytorch==========
1.7.1
torch.cuda.is_available(): True
==========NVIDIA-SMI==========
/usr/bin/nvidia-smi
Driver Version 460.32.03
CUDA Version 11.2
VBIOS Version 90.02.30.00.39
Image Version G001.0000.02.04
==========NVCC==========
/usr/local/cuda/bin/nvcc
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2020 NVIDIA Corporation
Built on Mon_Nov_30_19:08:53_PST_2020
Cuda compilation tools, release 11.2, V11.2.67
Build cuda_11.2.r11.2/compiler.29373293_0
==========CC==========
/usr/bin/c++
c++ (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0
Copyright (C) 2019 Free Software Foundation, Inc.
This is free software; see the source for copying conditions. There is NO
warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.

==========MinkowskiEngine==========
0.5.0
MinkowskiEngine compiled with CUDA Support: True
NVCC version MinkowskiEngine is compiled: 11020
CUDART version MinkowskiEngine is compiled: 11020
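
The last two lines of the diagnostics report the CUDA toolkit as a single integer (11020). As a minimal sketch, assuming the usual CUDA encoding of 1000 * major + 10 * minor, the helpers below decode that integer so it can be compared with the nvcc release line. Both function names are hypothetical and are not part of MinkowskiEngine.

```python
import re


def decode_cuda_version(code):
    """Decode a CUDA version integer (1000 * major + 10 * minor) into (major, minor)."""
    code = int(code)
    return code // 1000, (code % 1000) // 10


def parse_nvcc_release(line):
    """Extract (major, minor) from an nvcc 'release X.Y' line, or None if absent."""
    match = re.search(r"release (\d+)\.(\d+)", line)
    return (int(match.group(1)), int(match.group(2))) if match else None


compiled_with = decode_cuda_version(11020)
installed = parse_nvcc_release("Cuda compilation tools, release 11.2, V11.2.67")
print(compiled_with, installed, compiled_with == installed)
```

For the diagnostics above, both values decode to CUDA 11.2, so the build toolkit and the installed nvcc agree.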

@chrischoy chrischoy added the bug Something isn't working label Feb 5, 2021
@chrischoy
Contributor

chrischoy commented Feb 5, 2021

The code runs fine on

==========System==========
Linux-5.4.0-65-generic-x86_64-with-glibc2.10
DISTRIB_ID=Ubuntu
DISTRIB_RELEASE=20.04
DISTRIB_CODENAME=focal
DISTRIB_DESCRIPTION="Ubuntu 20.04.1 LTS"
3.8.5 (default, Sep  4 2020, 07:30:14) 
[GCC 7.3.0]
==========Pytorch==========
1.7.0
torch.cuda.is_available(): True
==========NVIDIA-SMI==========
/usr/bin/nvidia-smi
Driver Version 450.102.04
CUDA Version 11.0
VBIOS Version 90.02.2E.00.0C
Image Version G001.0000.02.04
==========NVCC==========
/usr/local/cuda/bin/nvcc
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2020 NVIDIA Corporation
Built on Mon_Nov_30_19:08:53_PST_2020
Cuda compilation tools, release 11.2, V11.2.67
Build cuda_11.2.r11.2/compiler.29373293_0
==========CC==========
CC=g++-7
/usr/bin/g++-7
g++-7 (Ubuntu 7.5.0-6ubuntu2) 7.5.0
Copyright (C) 2017 Free Software Foundation, Inc.
This is free software; see the source for copying conditions.  There is NO
warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.

==========MinkowskiEngine==========
0.5.0
MinkowskiEngine compiled with CUDA Support: True
NVCC version MinkowskiEngine is compiled: 11020
CUDART version MinkowskiEngine is compiled: 11020

and

==========System==========
Linux-5.4.0-65-generic-x86_64-with-glibc2.10
DISTRIB_ID=Ubuntu
DISTRIB_RELEASE=20.04
DISTRIB_CODENAME=focal
DISTRIB_DESCRIPTION="Ubuntu 20.04.1 LTS"
3.8.5 (default, Sep  4 2020, 07:30:14) 
[GCC 7.3.0]
==========Pytorch==========
1.7.0
torch.cuda.is_available(): True
==========NVIDIA-SMI==========
/usr/bin/nvidia-smi
Driver Version 460.32.03
CUDA Version 11.2
VBIOS Version 90.02.2E.00.0C
Image Version G001.0000.02.04
==========NVCC==========
/usr/local/cuda/bin/nvcc
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2020 NVIDIA Corporation
Built on Mon_Nov_30_19:08:53_PST_2020
Cuda compilation tools, release 11.2, V11.2.67
Build cuda_11.2.r11.2/compiler.29373293_0
==========CC==========
/usr/bin/c++
c++ (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0
Copyright (C) 2019 Free Software Foundation, Inc.
This is free software; see the source for copying conditions.  There is NO
warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.

==========MinkowskiEngine==========
0.5.0
MinkowskiEngine compiled with CUDA Support: True
NVCC version MinkowskiEngine is compiled: 11020
CUDART version MinkowskiEngine is compiled: 11020

  1. Please post the full output of print_diagnostics.

  2. Can you try building the latest MinkowskiEngine from source?

@chrischoy
Contributor

Similar issue on #312

@Milogav
Author

Milogav commented Feb 5, 2021

Thanks for checking this out Chris.

I've edited my previous comment to include the full output of print_diagnostics.
I've tried installing from source but unfortunately I am getting the same error.

@chrischoy
Contributor

I've tried pytorch==1.7.1 to match your setup closely, but could not reproduce the error on

==========System==========
Linux-5.4.0-65-generic-x86_64-with-glibc2.10
DISTRIB_ID=Ubuntu
DISTRIB_RELEASE=20.04
DISTRIB_CODENAME=focal
DISTRIB_DESCRIPTION="Ubuntu 20.04.1 LTS"
3.8.5 (default, Sep  4 2020, 07:30:14) 
[GCC 7.3.0]
==========Pytorch==========
1.7.1
torch.cuda.is_available(): True
==========NVIDIA-SMI==========
/usr/bin/nvidia-smi
Driver Version 460.32.03
CUDA Version 11.2
VBIOS Version 90.02.2E.00.0C
Image Version G001.0000.02.04
==========NVCC==========
/usr/local/cuda/bin/nvcc
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2020 NVIDIA Corporation
Built on Mon_Nov_30_19:08:53_PST_2020
Cuda compilation tools, release 11.2, V11.2.67
Build cuda_11.2.r11.2/compiler.29373293_0
==========CC==========
/usr/bin/c++
c++ (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0
Copyright (C) 2019 Free Software Foundation, Inc.
This is free software; see the source for copying conditions.  There is NO
warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.

==========MinkowskiEngine==========
0.5.1
MinkowskiEngine compiled with CUDA Support: True
NVCC version MinkowskiEngine is compiled: 11020
CUDART version MinkowskiEngine is compiled: 11020

It is very strange: there are now two issues reporting this, but I could not reproduce it on any of the systems that I have.

Do you have a different machine, or could you try the docker image nvcr.io/nvidia/pytorch:20.10-py3 to see if it works there?

@Milogav
Author

Milogav commented Feb 8, 2021

I tried on another machine and I am getting the same error. The setup on this new machine is similar to the first one, though:

==========System==========
Linux-5.8.0-41-generic-x86_64-with-glibc2.29
DISTRIB_ID=Ubuntu
DISTRIB_RELEASE=20.04
DISTRIB_CODENAME=focal
DISTRIB_DESCRIPTION="Ubuntu 20.04.2 LTS"
3.8.5 (default, Jul 28 2020, 12:59:40) 
[GCC 9.3.0]
==========Pytorch==========
1.7.1
torch.cuda.is_available(): True
==========NVIDIA-SMI==========
/usr/bin/nvidia-smi
Driver Version 460.32.03
CUDA Version 11.2
VBIOS Version 86.04.50.40.59
Image Version G001.0000.01.04
==========NVCC==========
/usr/local/cuda/bin/nvcc
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2020 NVIDIA Corporation
Built on Mon_Nov_30_19:08:53_PST_2020
Cuda compilation tools, release 11.2, V11.2.67
Build cuda_11.2.r11.2/compiler.29373293_0
==========CC==========
/usr/bin/c++
c++ (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0
Copyright (C) 2019 Free Software Foundation, Inc.
This is free software; see the source for copying conditions.  There is NO
warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.

==========MinkowskiEngine==========
0.5.1
MinkowskiEngine compiled with CUDA Support: True
NVCC version MinkowskiEngine is compiled: 11020
CUDART version MinkowskiEngine is compiled: 11020

On the docker image you suggested, the code runs fine with no error.
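
Comparing the diagnostics dumps by eye is error-prone, so here is a rough sketch of how two dumps could be diffed section by section. It assumes only the `==========Name==========` header layout seen in this thread; the helper names are hypothetical, and the sample strings are short excerpts copied from the dumps above (a setup that works and the setup that fails).

```python
import re


def parse_diagnostics(text):
    """Split a print_diagnostics dump into {section name: [non-empty lines]}."""
    sections, current = {}, None
    for line in text.splitlines():
        stripped = line.strip()
        header = re.fullmatch(r"=+(.+?)=+", stripped)
        if header:
            current = header.group(1)
            sections[current] = []
        elif current is not None and stripped:
            sections[current].append(stripped)
    return sections


def diff_sections(a, b):
    """Return {section: (lines only in a, lines only in b)} for sections that differ."""
    delta = {}
    for name in sorted(set(a) | set(b)):
        lines_a, lines_b = a.get(name, []), b.get(name, [])
        if lines_a != lines_b:
            delta[name] = ([x for x in lines_a if x not in lines_b],
                           [x for x in lines_b if x not in lines_a])
    return delta


working = """==========System==========
Linux-5.4.0-65-generic-x86_64-with-glibc2.10
[GCC 7.3.0]
==========Pytorch==========
1.7.1
==========NVIDIA-SMI==========
Driver Version 460.32.03
CUDA Version 11.2
"""
failing = """==========System==========
Linux-5.8.0-41-generic-x86_64-with-glibc2.29
[GCC 9.3.0]
==========Pytorch==========
1.7.1
==========NVIDIA-SMI==========
Driver Version 460.32.03
CUDA Version 11.2
"""

delta = diff_sections(parse_diagnostics(working), parse_diagnostics(failing))
for name, (only_working, only_failing) in delta.items():
    print(name, only_working, only_failing)
```

For these excerpts only the System section differs (kernel, glibc, and the compiler that built Python). A difference found this way is a lead to investigate, not a confirmed cause.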

@chrischoy
Contributor

chrischoy commented Feb 9, 2021

Can you try the latest MinkowskiEngine built with the debug flag?

git clone http://github.com/NVIDIA/MinkowskiEngine
cd MinkowskiEngine
python setup.py install --debug

And then, can you please post the entire error message when you run the above code?

@Milogav
Author

Milogav commented Feb 9, 2021

Sure, this is the full output after installing from source with the debug flag:

/home/.../repos/MinkowskiEngine/src/coordinate_map_manager.cpp:386 initializing a map with tensor stride: [1, 1, 1] string id: 
/home/.../repos/MinkowskiEngine/src/coordinate_map_manager.cpp:63 initialize_and_map
/home/.../repos/MinkowskiEngine/src/coordinate_map.hpp:186 Allocate 4 coordinates.
/home/.../repos/MinkowskiEngine/src/coordinate_map.hpp:133 tensor stride: [1, 1, 1]
/home/.../repos/MinkowskiEngine/src/coordinate_map_manager.cpp:71 mapping size: 4
/home/.../repos/MinkowskiEngine/src/coordinate_map_manager.hpp:247 insert map with tensor_stride [1, 1, 1]
/home/.../repos/MinkowskiEngine/src/coordinate_map_manager.hpp:251 map insertion 1
/home/.../repos/MinkowskiEngine/src/coordinate_map_cpu.hpp:149 neighbor_volume : 8 num_tfield: 4
/home/.../repos/MinkowskiEngine/src/coordinate_map_cpu.hpp:173 kernel map with 24 chunks and 1 stride.
/home/.../repos/MinkowskiEngine/src/coordinate_map_cpu.hpp:261 kernel index 0 size: 4
/home/.../repos/MinkowskiEngine/src/coordinate_map_cpu.hpp:261 kernel index 1 size: 3
/home/.../repos/MinkowskiEngine/src/coordinate_map_cpu.hpp:261 kernel index 2 size: 0
/home/.../repos/MinkowskiEngine/src/coordinate_map_cpu.hpp:261 kernel index 3 size: 0
/home/.../repos/MinkowskiEngine/src/coordinate_map_cpu.hpp:261 kernel index 4 size: 0
/home/.../repos/MinkowskiEngine/src/coordinate_map_cpu.hpp:261 kernel index 5 size: 0
/home/.../repos/MinkowskiEngine/src/coordinate_map_cpu.hpp:261 kernel index 6 size: 0
/home/.../repos/MinkowskiEngine/src/coordinate_map_cpu.hpp:261 kernel index 7 size: 0
/home/.../repos/MinkowskiEngine/src/interpolation_cpu.cpp:68 out_feat with size 4 1
/home/.../repos/MinkowskiEngine/src/interpolation_cpu.cpp:74 InterpolationForwardKernelCPU
/home/.../repos/MinkowskiEngine/src/coordinate_map_manager.cpp:386 initializing a map with tensor stride: [1, 1, 1] string id: 
/home/.../repos/MinkowskiEngine/src/coordinate_map.hpp:186 Allocate 4 coordinates.
/home/.../repos/MinkowskiEngine/src/coordinate_map.hpp:133 tensor stride: [1, 1, 1]
/home/.../repos/MinkowskiEngine/src/coordinate_map_gpu.cuh:205 Reserve map of 16 for concurrent_unordered_map of size 4
/home/.../repos/MinkowskiEngine/src/coordinate_map_gpu.cuh:211 Done concurrent_unordered_map creation
/home/.../repos/MinkowskiEngine/src/coordinate_map_gpu.cuh:214 Reserved concurrent_unordered_map
/home/.../repos/MinkowskiEngine/src/coordinate_map_gpu.cuh:176 device tensor_stride: [1, 1, 1]
/home/.../repos/MinkowskiEngine/src/coordinate_map_manager.cu:63 inserting 4 coordinates with coordinate_size: 4
/home/.../repos/MinkowskiEngine/src/coordinate_map_manager.cu:68 insert_and_map
/home/.../repos/MinkowskiEngine/src/coordinate_map_gpu.cu:287 insert_and_map
/home/.../repos/MinkowskiEngine/src/coordinate_map_gpu.cu:204 key iterator length 4
/home/.../repos/MinkowskiEngine/src/coordinate_map_gpu.cu:223 Reserved and copied 4 x 4 coordinates
/home/.../repos/MinkowskiEngine/src/coordinate_map_gpu.cu:237 Map size: 4
/home/.../repos/MinkowskiEngine/src/coordinate_map_gpu.cu:253 Number of successful insertion 4
/home/.../repos/MinkowskiEngine/src/coordinate_map_manager.cu:71 mapping size: 4
/home/.../repos/MinkowskiEngine/src/coordinate_map_manager.hpp:247 insert map with tensor_stride [1, 1, 1]
/home/.../repos/MinkowskiEngine/src/coordinate_map_manager.hpp:251 map insertion 1
/home/.../repos/MinkowskiEngine/src/coordinate_map_manager.cu:81 Reserve mapping torch output tensors.
/home/.../repos/MinkowskiEngine/src/coordinate_map_manager.cu:92 cuda_copy_n with num_blocks: 1 mapping.size(): 4
/home/.../repos/MinkowskiEngine/src/coordinate_map_manager.cu:102 cuda_copy_n with num_inv_blocks: 1 inverse_mapping.size(): 4
/home/.../repos/MinkowskiEngine/src/coordinate_map_gpu.cu:2122 map size 4
/home/.../repos/MinkowskiEngine/src/coordinate_map_gpu.cu:1961 neighbor_volume: 8 num_tfield: 4 num_threads: 32
/home/.../repos/MinkowskiEngine/src/coordinate_map_gpu.cu:1975 Shared memory size: 4112
/home/.../repos/MinkowskiEngine/src/coordinate_map_gpu.cu:2001 number_of_valid: 7
/home/.../repos/MinkowskiEngine/src/spmm.cu:214 Is sorted 0
/home/.../repos/MinkowskiEngine/src/spmm.cu:222 Allocated sorted row col val 7
/home/.../repos/MinkowskiEngine/src/spmm.cu:243 sorted row 0
/home/.../repos/MinkowskiEngine/src/spmm.cu:277 initialized matrices 0
 ** On entry to cusparseSpMM_bufferSize() parameter number 1 (handle) had an illegal value: bad initialization or already destroyed

Traceback (most recent call last):
  File "/home/.../repos/pycloud/scripts/test_me_bug.py", line 27, in <module>
    feats = tensor.features_at_coordinates(query_xyz.to(device) / q_size)
  File "/home/.../pyenvs/mink_clean_env/lib/python3.8/site-packages/MinkowskiEngine-0.5.1-py3.8-linux-x86_64.egg/MinkowskiEngine/MinkowskiSparseTensor.py", line 662, in features_at_coordinates
    return MinkowskiInterpolationFunction().apply(
  File "/home/.../pyenvs/mink_clean_env/lib/python3.8/site-packages/MinkowskiEngine-0.5.1-py3.8-linux-x86_64.egg/MinkowskiEngine/MinkowskiInterpolation.py", line 52, in forward
    out_feat, in_map, out_map, weights = fw_fn(
RuntimeError: CUSPARSE_STATUS_INVALID_VALUE at /home/.../repos/MinkowskiEngine/src/spmm.cu:280
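
The kernel-map numbers in this log (kernel index 0 size: 4, kernel index 1 size: 3, number_of_valid: 7) can be checked by hand, which suggests the coordinate stage completed before the cuSPARSE call failed. The sketch below is plain Python, not MinkowskiEngine code. It assumes that each query point is matched against the 8 corners of its enclosing unit cell (inferred from neighbor_volume: 8), and it drops the batch column of the query coordinates; which corner maps to which kernel index is also an assumption.

```python
import math
from itertools import product

q_size = 0.1
xyz = [(0, 0, 0.1), (0, 0, 0.2), (0, 0, 0.3), (0, 0, 0.4)]
query_xyz = [(0, 0, 0.1), (0, 0, 0.21), (0, 0, 0.31), (0, 0, 0.41)]

# Quantized voxel coordinates, mirroring torch.round(xyz / q_size) in the script.
voxels = {tuple(round(c / q_size) for c in point) for point in xyz}

# Assumption: a query touches the 2 x 2 x 2 corners of the cell that contains it.
offsets = list(product((0, 1), repeat=3))
counts = {offset: 0 for offset in offsets}
for point in query_xyz:
    base = tuple(math.floor(c / q_size) for c in point)
    for offset in offsets:
        corner = tuple(b + o for b, o in zip(base, offset))
        if corner in voxels:
            counts[offset] += 1

total_valid = sum(counts.values())
print(counts)
print(total_valid)  # 7 for this data
```

Under these assumptions the lower corner matches all 4 queries, the corner one step up in z matches 3 (the last query has no voxel above it), and the other six corners match nothing, which agrees with the sizes in the log.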

@chrischoy
Contributor

chrischoy commented Feb 9, 2021

Thanks for the quick reply.

The log says

 ** On entry to cusparseSpMM_bufferSize() parameter number 1 (handle) had an illegal value: bad initialization or already destroyed

This means that the cuSPARSE handle passed to the call is invalid.
I used PyTorch's auto cusparse_handle = at::cuda::getCurrentCUDASparseHandle(); to get the cuSPARSE handle, but it seems that getCurrentCUDASparseHandle might be faulty on some systems.

I created a branch, cusparse_handle_issue308, that creates a custom cuSPARSE handle instead of using PyTorch's getCurrentCUDASparseHandle. Could you install the branch with the debug flag, following the instructions below, and post the output of your test script?

git clone https://github.com/NVIDIA/MinkowskiEngine
cd MinkowskiEngine
git checkout cusparse_handle_issue308
python setup.py install --debug

@heiwang1997

Hi Chris, I'm also experiencing this issue after upgrading to the newest version (CPU version works fine). But I am able to get rid of this cusparse error using your cusparse_handle_issue308 branch. Thanks.

Here is my environment:

==========System==========
Linux-5.8.0-38-generic-x86_64-with-glibc2.10
DISTRIB_ID=Ubuntu
DISTRIB_RELEASE=20.04
DISTRIB_CODENAME=focal
DISTRIB_DESCRIPTION="Ubuntu 20.04.1 LTS"
3.8.3 (default, Jul  2 2020, 16:21:59)
[GCC 7.3.0]
==========Pytorch==========
1.7.1+cu110
torch.cuda.is_available(): True
==========NVIDIA-SMI==========
/usr/bin/nvidia-smi
Driver Version 460.32.03
CUDA Version 11.2
VBIOS Version 86.04.60.00.B4
Image Version G001.0000.01.04
==========NVCC==========
/usr/local/cuda/bin/nvcc
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2020 NVIDIA Corporation
Built on Mon_Oct_12_20:09:46_PDT_2020
Cuda compilation tools, release 11.1, V11.1.105
Build cuda_11.1.TC455_06.29190527_0
==========CC==========
/usr/bin/c++
c++ (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0
Copyright (C) 2019 Free Software Foundation, Inc.
This is free software; see the source for copying conditions.  There is NO
warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.

==========MinkowskiEngine==========
0.5.1
MinkowskiEngine compiled with CUDA Support: True
NVCC version MinkowskiEngine is compiled: 11010
CUDART version MinkowskiEngine is compiled: 11010

@Milogav
Author

Milogav commented Feb 9, 2021

Same here, the test script runs fine for me using the cusparse_handle_issue308 branch. Thanks!
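
For anyone re-checking after installing the fix, the reproduction script can be turned into a small CPU versus GPU comparison. This is only a sketch: the function name, the tolerance, and the "skipped" fallback are assumptions of mine, and it uses only the MinkowskiEngine calls that already appear in the script at the top of this issue.

```python
def check_cpu_gpu_agreement(atol=1e-5):
    """Return 'skipped', 'ok', or 'mismatch' for the issue's reproduction data."""
    try:
        import torch
        import MinkowskiEngine as ME
    except ImportError:
        return "skipped"  # libraries are not installed in this environment
    if not torch.cuda.is_available():
        return "skipped"  # no GPU available to compare against

    q_size = 0.1
    xyz = torch.Tensor([[0, 0, 0.1], [0, 0, 0.2], [0, 0, 0.3], [0, 0, 0.4]])
    query = torch.Tensor([[0, 0, 0, 0.1], [0, 0, 0, 0.21],
                          [0, 0, 0, 0.31], [0, 0, 0, 0.41]]) / q_size
    features = torch.Tensor([0.9, 0.8, 0.9, 0.3])[:, None]
    q_xyz = torch.round(xyz / q_size).type(torch.int64)
    coords = ME.utils.batched_coordinates([q_xyz])

    results = []
    for name in ("cpu", "cuda"):
        device = torch.device(name)
        tensor = ME.SparseTensor(features, coords, device=device)
        out = tensor.features_at_coordinates(query.to(device))
        if isinstance(out, (tuple, list)):  # defensive: keep only the features
            out = out[0]
        results.append(out.cpu())
    # The tolerance is an arbitrary choice for this sketch.
    return "ok" if torch.allclose(results[0], results[1], atol=atol) else "mismatch"


status = check_cpu_gpu_agreement()
print(status)
```

On an affected install the GPU call raises the RuntimeError shown earlier instead of returning a status, so an exception here still points at the original bug.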

@chrischoy
Contributor

Merged to the master.

AlexeyGB pushed a commit to AlexeyGB/MinkowskiEngine that referenced this issue Feb 23, 2021
…IDIA#308) (NVIDIA#315)

* force initialize cusparse handle

* replace all at::cuda::getCurrentCUDASparseHandle with custom func

* change log
Tanazzah pushed a commit to Tanazzah/MinkowskiEngine that referenced this issue Feb 9, 2024
…IDIA#308) (NVIDIA#315)

* force initialize cusparse handle

* replace all at::cuda::getCurrentCUDASparseHandle with custom func

* change log