
Do you plan to have a python wrapper for the fully fused MLP? #14

Closed
MultiPath opened this issue Jan 15, 2022 · 8 comments

@MultiPath

Hi, I am not an expert in CUDA coding but have more experience with PyTorch/TensorFlow...
Do you have any plans to provide this code with a Python (more specifically, PyTorch) wrapper?
Or would it be possible to point to the location of the forward/backward functions of this MLP implementation, so that we can potentially incorporate it into other Python code?

Thanks a lot

@Tom94
Collaborator

Tom94 commented Jan 15, 2022

Related to NVlabs/instant-ngp#8.

Summary: yes, we plan to have a Python wrapper for both the fully fused MLP and the hash grid encoding.

Colleagues of ours have internally built a functional (but not quite release-ready) PyTorch wrapper around tiny-cuda-nn, which we hope to release soonish. Based on experience so far, there will be a slowdown compared to the native C++ API, but it will still be significantly faster than Python-native MLPs.

@MultiPath
Author

MultiPath commented Jan 15, 2022

Sounds great, and I'm looking forward to it. Both ideas are really cool, and I particularly like the fully fused version of the MLP. It should definitely be much faster, especially for small hidden sizes!!

I would also be really grateful if you could point me to the files and lines for both the "fully fused MLP" and "hash table encoding". I can check the current implementation to understand it better, and maybe try a simple binding over the weekend.
Thanks again!

@Tom94
Collaborator

Tom94 commented Jan 15, 2022

Gladly! Those would be

  • include/tiny-cuda-nn/encodings/grid.h
  • src/fully_fused_mlp.cu and include/tiny-cuda-nn/networks/fully_fused_mlp.h
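
In case it helps with the binding experiment: below is a minimal sketch of how a forward/backward pair could be wrapped in a torch.autograd.Function. The `native_module.fwd` / `native_module.bwd` handles are hypothetical placeholders for compiled C++/CUDA entry points, not the actual tiny-cuda-nn API:

```python
import torch

class TcnnFunction(torch.autograd.Function):
    # Sketch only: `native_module` stands in for a compiled binding that
    # exposes the C++ forward/backward of the fully fused MLP.
    @staticmethod
    def forward(ctx, native_module, input, params):
        output = native_module.fwd(input, params)
        ctx.save_for_backward(input, params, output)
        ctx.native_module = native_module
        return output

    @staticmethod
    def backward(ctx, doutput):
        input, params, output = ctx.saved_tensors
        # Hypothetical: returns gradients w.r.t. the input and the weights.
        input_grad, param_grad = ctx.native_module.bwd(input, params, output, doutput)
        # One gradient per forward argument; the module handle gets None.
        return None, input_grad, param_grad
```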

@Tom94
Collaborator

Tom94 commented Feb 14, 2022

I just pushed a first version (call it "beta") for a PyTorch extension. See this section of the README for installation/usage instructions and please do report problems you encounter along the way. :)
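
For readers landing here later, usage roughly follows the pattern below. The config values are illustrative only; see the README section linked above for the authoritative installation and usage instructions:

```python
import torch
import tinycudann as tcnn

# Illustrative config: a hash grid encoding feeding a fully fused MLP.
config = {
    "encoding": {
        "otype": "HashGrid", "n_levels": 16, "n_features_per_level": 2,
        "log2_hashmap_size": 19, "base_resolution": 16, "per_level_scale": 2.0,
    },
    "network": {
        "otype": "FullyFusedMLP", "activation": "ReLU",
        "output_activation": "None", "n_neurons": 64, "n_hidden_layers": 2,
    },
}

# Encoding + network fused into a single differentiable module.
model = tcnn.NetworkWithInputEncoding(
    n_input_dims=2, n_output_dims=3,
    encoding_config=config["encoding"], network_config=config["network"],
)

xy = torch.rand(1024, 2, device="cuda")
rgb = model(xy)  # trainable with any standard PyTorch optimizer
```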

@MultiPath
Author

MultiPath commented Feb 14, 2022

Hi, thanks for releasing the PyTorch binding! I have tried it a bit: it compiled, and I ran the "mlp_learning_an_image_pytorch.py" script, which returned the following error.

Writing '0.jpg'... Traceback (most recent call last):
  File "samples/mlp_learning_an_image_pytorch.py", line 178, in <module>
    out_img = model(xy_padded)[:n_pixels,:].reshape(img_shape).clamp(0.0, 1.0).detach().cpu().numpy()
  File "/private/home/jgu/anaconda3/envs/fairseq-20210102/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/private/home/jgu/anaconda3/envs/fairseq-20210102/lib/python3.8/site-packages/tinycudann-1.4-py3.8-linux-x86_64.egg/tinycudann/ops.py", line 47, in forward
    output = _module_func.apply(
  File "/private/home/jgu/anaconda3/envs/fairseq-20210102/lib/python3.8/site-packages/tinycudann-1.4-py3.8-linux-x86_64.egg/tinycudann/ops.py", line 16, in forward
    output = native_tcnn_module.fwd(input, params)
RuntimeError: /private/home/jgu/work/tiny-cuda-nn/include/tiny-cuda-nn/cuda_graph.h:58 cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal) failed with error operation not permitted when stream is capturing

Do you have any clue what might be the problem? Thanks

It seems to fail only when we try to write the image under "with torch.no_grad():".
Training is fine and the loss goes down quickly. Without "torch.no_grad()", it works OK on my side.
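
For reference, the failing line in the sample boils down to an inference-only call of this shape (paraphrased from the traceback above; the names come from the sample script):

```python
with torch.no_grad():  # inference-only path that triggers the capture error
    out_img = model(xy_padded)[:n_pixels, :]
    out_img = out_img.reshape(img_shape).clamp(0.0, 1.0).detach().cpu().numpy()
```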

@Tom94
Collaborator

Tom94 commented Feb 15, 2022

@MultiPath should be fixed now! Can you try again?

@a-canela

It is working now on my side, both with and without torch.no_grad(). Many thanks!

@MultiPath
Author

> @MultiPath should be fixed now! Can you try again?

Confirmed, it works now. Thanks!
