
ImportError: cannot import name 'nvcc' #52

Closed
safooray opened this issue May 29, 2020 · 33 comments

@safooray

from tvm.contrib import nvcc
ImportError: cannot import name 'nvcc'

I get this when trying to compile the kernel from scratch. Did I miss something in the cmake config? I can import a lot of TVM modules but not nvcc.

My cuda version is: Cuda compilation tools, release 10.0, V10.0.130

@ibeltagy
Collaborator

Are you trying to compile the kernel? Are you using the Docker image?

@safooray
Author

Trying to compile the kernel, having installed TVM from source based on these instructions:
https://docs.tvm.ai/install/from_source.html#build-the-shared-library

@ibeltagy
Collaborator

ibeltagy commented May 29, 2020

I would strongly suggest you follow the instructions in cheatsheet.txt and use the Docker image. The TVM codebase changes all the time, and I am using a previous release (v0.6.0). If you want to use master, lots of things will need to change.

@safooray
Author

Thank you for your suggestion. This particular error was due to Python resolving the import to longformer's tvm directory instead of the installed TVM library. I renamed longformer's tvm directory and changed the tvm.module import accordingly, but now I get a seg fault!

I think I'll take your suggestion re using docker.
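For anyone hitting the same shadowing problem, it can be diagnosed without running any TVM code by asking Python which file it would actually load for a module. This is a stdlib-only sketch (which_module is my helper name, not a longformer utility); it is shown with json, but the same call works with "tvm":

```python
import importlib.util

def which_module(name):
    """Return the file path Python would load for `name`, or None if unresolvable."""
    spec = importlib.util.find_spec(name)
    return getattr(spec, "origin", None) if spec else None

# A directory named tvm on sys.path (e.g. longformer's bundled runtime files)
# shadows an installed tvm package, because sys.path is searched in order.
# Printing the resolved origin shows which copy wins:
print(which_module("json"))  # stdlib json here; for this issue, pass "tvm"
```

If the printed path points into the longformer checkout rather than site-packages, the shadowing described above is the culprit.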

@ibeltagy
Collaborator

Ah, right, this issue is already mentioned here. So yes, please try the instructions in cheatsheet.txt and run via docker.

but now I get a seg fault!

Did you add import ipdb; ipdb.set_trace() in the code somewhere?

@safooray
Author

No, I did not. Are you saying that could be the source of the seg fault, or suggesting I use ipdb for debugging?

@ibeltagy
Collaborator

TVM throws segfaults for odd reasons, and ipdb.set_trace() is one of them. For ipdb to work, you have to call ipdb.set_trace() once before import tvm, then call it again where you actually want the breakpoint 🤷‍♂️

@safooray
Author

I was finally able to compile from scratch without using the docker image.

I had to change some tvm imports in diagonaled_mm_tvm.py to the tvm.te namespace, rename longformer's tvm directory to something else, install longformer again, and then be very careful about my PATH and PYTHONPATH environment variables.

I still need to use the newly compiled module in pretraining and make sure it works.
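The 0.6 → 0.7 import churn can be absorbed with a small fallback helper. This is an illustrative stdlib-only sketch (resolve_namespace is my name, not a longformer or TVM API); it is exercised below with stand-in module names, since the same pattern applies to "tvm.te" versus "tvm":

```python
import importlib

def resolve_namespace(new, old):
    """Try the new module path first, fall back to the old one.

    For the rename above, resolve_namespace("tvm.te", "tvm") would return
    the namespace holding placeholder/compute on either TVM version
    (assumption: only the module path moved between releases)."""
    try:
        return importlib.import_module(new)
    except ImportError:
        return importlib.import_module(old)

# Demonstrated with stand-in names: the new path is missing, so the
# helper falls back to the old one.
mod = resolve_namespace("definitely_missing_mod_xyz", "json")
print(mod.__name__)
```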

@ibeltagy
Collaborator

ibeltagy commented May 30, 2020

Glad it is working.

tvm.te

Yes, this is the new API in tvm 0.7.x.

I still need to use the newly compiled module in pretraining and make sure it works.

An easier solution for testing would be to run this unit test to make sure the output of sliding_chunks perfectly matches tvm https://github.com/allenai/longformer/blob/master/tests/test_sliding_chunks.py

@safooray
Author

safooray commented Jun 1, 2020

Thank you for this pointer. Both the above test and the code snippet in the README lead to this error at dlpack.py line 40:

dlpack.py", line 40, in _wrapper
    return tvm_func(*args)
  File "/usr/local/lib/python3.6/dist-packages/tvm-0.7.dev1-py3.6-linux-x86_64.egg/tvm/runtime/module.py", line 110, in __call__
    return self.entry_func(*args)
  File "/usr/local/lib/python3.6/dist-packages/tvm-0.7.dev1-py3.6-linux-x86_64.egg/tvm/_ffi/_ctypes/packed_func.py", line 219, in __call__
    values, tcodes, num_args = _make_tvm_args(args, temp_args)
  File "/usr/local/lib/python3.6/dist-packages/tvm-0.7.dev1-py3.6-linux-x86_64.egg/tvm/_ffi/_ctypes/packed_func.py", line 185, in _make_tvm_args
    raise TypeError("Don't know how to handle type %s" % type(arg))
TypeError: Don't know how to handle type <class 'tvm_runtime.ndarray.NDArray'>

Just posting it here while I work on it in case others face the same issue.
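A plausible mechanism for that TypeError (an assumption on my part, sketched with stand-in modules rather than real TVM): longformer bundles its own copy of the ndarray class, and the installed TVM's type dispatch checks against its own class object, so the bundled one looks like an unknown type even though the class names match:

```python
import types

# Simulate two copies of the "same" class loaded from different packages,
# as happens when longformer's bundled tvm runtime and an installed tvm
# each define their own NDArray (module/class names are illustrative):
bundled = types.ModuleType("tvm_runtime.ndarray")
installed = types.ModuleType("tvm.runtime.ndarray")
exec("class NDArray: pass", bundled.__dict__)
exec("class NDArray: pass", installed.__dict__)

x = bundled.NDArray()
# The installed package dispatches on ITS OWN class object, so the
# bundled instance fails the isinstance check despite the matching name:
print(isinstance(x, installed.NDArray))  # False
```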

@ibeltagy
Collaborator

ibeltagy commented Jun 1, 2020

Maybe something changed in the 0.7.dev1 api compared to 0.6.0. If you can reproduce the error in a small example, try asking here https://discuss.tvm.ai/.

@safooray
Author

safooray commented Jun 1, 2020

So this happens when I use TVM's load_module to load the compiled kernel; something in TVM doesn't know how to handle longformer's custom ndarray type.

When I switch to longformer tvm runtime's load function instead, I get:

module.py", line 263, in load
    return _LoadFromFile(path, fmt)
NameError: name '_LoadFromFile' is not defined

I can't tell where _LoadFromFile is supposed to come from, because it's not in the imports :-?

@ibeltagy
Collaborator

ibeltagy commented Jun 1, 2020

how to handle longformer's custom ndarray type.

I don't think our code has a custom ndarray type.

_LoadFromFile

_LoadFromFile is probably one of the C++ functions that are compiled into libtvm.so.
I think the load function changed in 0.7.x, can you check the tutorial here https://docs.tvm.ai/tutorials/tensor_expr_get_started.html#load-compiled-module

@safooray
Author

safooray commented Jun 1, 2020

I don't think our code has a custom ndarray type.

This is the ndarray class I meant:
https://github.com/allenai/longformer/blob/master/tvm/ndarray.py

@ibeltagy
Collaborator

ibeltagy commented Jun 1, 2020

Got it. TVM doesn't ship a small runtime, so I copied a few of the tvm files into longformer to save users the need to pip install tvm. It is a hacky solution, but it works.
To use tvm 0.7.x, you will need to update the files in longformer/tvm/ to match the new version. That will probably take some trial and error to figure out the smallest set of relevant files.

@safooray
Author

safooray commented Jun 8, 2020

So I completely removed dependency on the small tvm runtime, and always import the whole thing. With this I can successfully compile and load the kernel.

The sliding chunks test fails, though. The non-zero elements in the tvm result match the sliding chunks result, but some blocks of the tvm result tensor are all zeros where sliding chunks gives normal non-zero values.

@safooray
Author

safooray commented Jun 8, 2020

It's weird that I don't get the all-zero blocks consistently. I run the same line attention1 = diagonaled_mm_tvm(query, key, W, D, False, 0, autoregressive) several times in a row in debug mode, and sometimes the results exactly match the sliding chunks output, and sometimes these chunks of all zeros appear.
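Flakiness like this is easiest to pin down with a repeat-run harness. A stdlib-only sketch (check_determinism is my name; the fn argument would wrap the diagonaled_mm_tvm call and return something comparable, e.g. a tensor copied to a list):

```python
def check_determinism(fn, runs=10):
    """Call fn repeatedly and return the run indices whose output
    differs from the first run (an empty list means deterministic)."""
    ref = fn()
    return [i for i in range(1, runs) if fn() != ref]

# A deterministic function produces no mismatching runs:
print(check_determinism(lambda: [1, 2, 3]))  # []
```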

@ibeltagy
Collaborator

ibeltagy commented Jun 9, 2020

Ha, interesting. Does that happen with the code out of the box, or only with the kernel you compiled?

@safooray
Author

safooray commented Jun 9, 2020

I never ran the code out of the box since I started with a different tvm version.

@ibeltagy
Collaborator

ibeltagy commented Jun 9, 2020

Is there an easy way for me to reproduce it? Or can you try it on a very small example and show me what the zero pattern looks like?

@safooray
Author

safooray commented Jun 9, 2020

N needs to be more than 16 to reproduce this, so I went with N=20 to keep it divisible by w*2, setting w=2, M=4, B=D=H=1:

      ([[[[   -inf,    -inf,  1.0454,  0.2067,  1.0842]],
         [[   -inf,  0.0000,  0.0000, -0.1133, -0.1551]],
         [[ 0.0000,  0.0000,  0.0000,  0.1974,  0.5001]],
         [[ 0.0000,  0.0000,  0.0000, -0.3928,  0.1745]],
         [[ 0.0000,  0.0000,  0.0000, -1.4956, -0.6937]],
         [[ 0.0000,  0.0000,  0.0000,  1.3766, -3.2872]],
         [[ 0.0000,  0.0000,  0.0000, -1.4435, -0.6505]],
         [[ 0.0000,  0.0000,  0.0000, -0.8083,  0.0732]],
         [[ 0.0000,  0.0000,  0.0000, -1.1314, -2.0845]],
         [[ 0.0000,  0.0000,  0.0000,  0.5186, -0.0927]],
         [[ 0.0000,  0.0000,  0.0000, -0.5307,  2.3749]],
         [[ 0.0000,  0.0000,  0.0000,  0.5193,  1.1943]],
         [[ 0.0000,  0.0000,  0.0000, -0.8338,  0.4442]],
         [[ 0.0000,  0.0000,  0.0000,  0.7164,  1.4491]],
         [[ 0.0000,  0.0000,  0.0000,  0.7989, -2.2085]],
         [[ 0.0000,  0.0000,  0.0000, -0.9621,  1.1597]],
         [[ 0.0000,  0.0000,  0.0000, -2.7438, -0.9485]],
         [[-0.1792, -3.9535,  4.3465,  3.5005, -0.0829]],
         [[-0.8578, -1.7571,  2.3984,  1.2971,    -inf]],
         [[ 1.1246,  2.1220,  1.1970,    -inf,    -inf]]]]

@safooray
Author

Is this line correct? Shouldn't something be done with ko too?

        ko, ki = s[Z].split(Z.op.reduce_axis[0], factor=b0)
        ZF = s.rfactor(Z, ki)

I see that this is what the tvm docs do as well, but I don't understand it there either.
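For what it's worth, ko does get used: rfactor turns ki into a spatial axis of a new intermediate stage (ZF), whose reduction runs over ko, and the final stage then reduces ZF over ki. That can be emulated in plain Python (a sketch of the index math only, not TVM's actual implementation):

```python
def reduce_with_rfactor(xs, b0):
    """Emulate split(reduce_axis, factor=b0) followed by rfactor(Z, ki).

    The reduce axis k is split as k = ko * b0 + ki.  The rfactor stage
    ZF computes one partial sum per ki, reducing over ko (this is where
    ko is consumed); the final stage reduces the partials over ki."""
    n = len(xs)
    n_outer = (n + b0 - 1) // b0  # ceil(n / b0)
    zf = [sum(xs[ko * b0 + ki]
              for ko in range(n_outer)
              if ko * b0 + ki < n)   # guard the padded tail
          for ki in range(b0)]
    return sum(zf)                   # final stage: reduce over ki

vals = [0.5 * i for i in range(20)]
print(reduce_with_rfactor(vals, 4) == sum(vals))  # True
```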

PS: should I move this to a new issue?

@ibeltagy
Collaborator

I copied it from the tutorial and didn't think carefully about how it works. Are you trying to make it faster? The conversation here should be helpful.

@safooray
Author

No, I'm trying to debug the block of zeros.

@ibeltagy
Collaborator

If you are suspecting it is the scheduler, you can replace the whole scheduler with a naive one for debugging, something like:

s[Z].bind(s[Z].op.axis[-1], tvm.thread_axis("blockIdx.x"))

@ibeltagy
Collaborator

ibeltagy commented Jun 12, 2020

btw, is this fp16 or fp32? I am asking because there used to be a bug in the codegen of fp16

@safooray
Author

I'm passing dtype = torch.float32 to my test tensors.

@ibeltagy
Collaborator

Right now we don't know where the bug is. It could be in our code, your code, or in TVM itself.
In addition to replacing the scheduler with a simpler one, you can also try our code out of the box, just for debugging. I also found that really short sequences don't work, so maybe increase the sequence length to see if the problem goes away. Finally, here's a small TVM example that might be useful as a diagnostic:

import torch
import tvm
from tvm.contrib import dlpack
d1 = tvm.var('d1')  # define dimensions as variables
d2 = tvm.var('d2')  # define dimensions as variables
d3 = tvm.var('d3')  # define dimensions as variables
A = tvm.placeholder((d1, d2), name='A', dtype='float32')  # first tensor 
B = tvm.placeholder((d2, d3), name='B', dtype='float32')  # second tensor
k = tvm.reduce_axis((0, d2), name='k')  # dimension to sum over
output_shape = (d1, d3)
algorithm = lambda i, j: tvm.sum(A[i, k] * B[k, j], axis=k)  # explain computation
R = tvm.compute(output_shape, algorithm, name='R')
s = tvm.create_schedule(R.op)
s[R].bind(s[R].op.axis[1], tvm.thread_axis("blockIdx.x"))  # map computation to gpu resources
tvm_fn = tvm.build(s, [A, B, R], target='cuda', target_host='llvm', name='mm1')  # generate C++ code and compile
tvm_fn.export_library('libmm.so')  # save to disk
tvm_fn = tvm.module.load('libmm.so')  # load from disk
my_mm_pytorch_fn = dlpack.to_pytorch_func(tvm_fn)  # package it as a PyTorch function
X = torch.randn(128, 256, device='cuda')  # allocate pytorch input tensors
Y = torch.randn(256, 32, device='cuda')  # allocate pytorch input tensors
Z = X.new_empty(128, 32, device='cuda')  # allocate pytorch output tensor
my_mm_pytorch_fn(X, Y, Z)  # call the tvm kernel
torch.allclose(X.matmul(Y), Z, atol=1e-04)  # compare tvm output with pytorch

@safooray
Author

I tried the code out of the box, tests pass.
I also tried replacing the scheduler with the simple version, and tests pass.

So we know the bug is in the splitting and binding.

@ibeltagy
Collaborator

Does it still happen randomly, or does it break consistently?

The next step would be to find a minimal example that reproduces the bug and post it on the TVM forum.

@ibeltagy
Collaborator

Btw, this is the reason I pin a specific version of TVM: it changes a lot, and things break after they were working.

@safooray
Author

I have now isolated the issue to this split and the corresponding binds:
j_outer, j_inner = s[Z].split(s[Z].op.axis[-1], factor=b1)

The split on axis 1 and the reduce_axis split are fine; tests pass with them.

Thank you for your help so far.
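One hedged guess at why that particular split misbehaves (an assumption on my part, not a confirmed diagnosis): when the axis extent isn't divisible by the factor, the split pads the iteration space with a tail that codegen must guard; a missing guard could leave some output blocks unwritten. The index math can be sketched in plain Python:

```python
def split_iteration(n, factor):
    """Emulate iterating axis j after split(j, factor=b1): j = jo*b1 + ji.

    Ceil-division padding creates iterations with j >= n; if codegen
    fails to guard them, some output blocks can go unwritten, which is
    one plausible source of the all-zero blocks above."""
    visited, padded = [], []
    n_outer = (n + factor - 1) // factor  # ceil(n / factor)
    for jo in range(n_outer):
        for ji in range(factor):
            j = jo * factor + ji
            (visited if j < n else padded).append(j)
    return visited, padded

visited, padded = split_iteration(20, 8)
print(len(visited), padded)  # 20 [20, 21, 22, 23]
```

With a factor that divides the extent evenly (e.g. split_iteration(16, 4)) the padded list is empty, which would match the observation that only some shapes trigger the zeros.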

@ibeltagy
Collaborator

Will close this issue for now. Please reopen if you have other questions.
