
Any plans for double backward / second-order gradients? I.e. backward for backward functions. #58

Closed
ventusff opened this issue Mar 10, 2022 · 9 comments

Comments

@ventusff
Contributor

ventusff commented Mar 10, 2022

Hi,
First of all, thanks for the great repo! I've already built a project based on tcnn and found it extremely helpful.

However, during usage I found that since the backward functions are implemented in C++, they are not tracked by PyTorch's autograd, so autograd.grad(..., create_graph=True) fails to generate a grad_fn for the grads (i.e. second-order gradients).

This functionality is helpful when training losses depend on first-order gradients. For example, when training an SDF MLP, an eikonal loss is typically used, i.e. a loss applied to dy_dx (the nablas) of the network. To backpropagate it, d(dy_dx)_dparam is needed.
Ref: https://arxiv.org/abs/2002.10099
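To make the requirement concrete, here is a minimal sketch of this pattern in plain PyTorch (the sdf_net below is just a placeholder nn.Module standing in for a tcnn-backed network):

```python
import torch
import torch.nn as nn

# Placeholder SDF network; in practice this would be a tcnn-backed model.
sdf_net = nn.Sequential(nn.Linear(3, 64), nn.Softplus(beta=100), nn.Linear(64, 1))

x = torch.rand(1024, 3, requires_grad=True)
y = sdf_net(x)

# First-order gradients dy_dx (the nablas). create_graph=True is needed so
# that the eikonal loss below can itself be backpropagated.
nablas = torch.autograd.grad(y, x, torch.ones_like(y), create_graph=True)[0]

# Eikonal loss: push the gradient norm towards 1 (cf. the paper above).
loss_eikonal = ((nablas.norm(dim=-1) - 1.0) ** 2).mean()

# This backward needs d(dy_dx)_dparam, i.e. a backward of the backward.
# If the backward kernel is not itself differentiable, nablas has no grad_fn
# and this step cannot reach the network parameters.
loss_eikonal.backward()
```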

Currently I'm writing custom backward_backward functions on top of tcnn's grid.h and fully_fused_mlp.cu, but it would be really nice if this could be officially supported. 😄

BR,
Ventus


🎉🎉🎉 UPDATE for everyone who reaches this issue:

For now, partial support for double backward is implemented within the tiny-cuda-nn repo, and only for grid encodings.

An example usage script can be found here.

For implementation details, please check the original PR #69.
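
A rough sketch of how the merged support can be used, assuming a grid encoding chained with a regular PyTorch MLP (the encoding config values below are illustrative, not the exact ones from the script; see the usage script and PR for the authoritative version):

```python
import torch
import tinycudann as tcnn

# HashGrid encoding from tiny-cuda-nn; the config values here are illustrative.
encoding = tcnn.Encoding(
    n_input_dims=3,
    encoding_config={
        "otype": "HashGrid",
        "n_levels": 16,
        "n_features_per_level": 2,
        "log2_hashmap_size": 19,
        "base_resolution": 16,
        "per_level_scale": 1.5,
    },
)

# The decoder is a plain PyTorch MLP, which is already twice differentiable.
decoder = torch.nn.Sequential(
    torch.nn.Linear(encoding.n_output_dims, 64),
    torch.nn.Softplus(beta=100),
    torch.nn.Linear(64, 1),
).cuda()

x = torch.rand(4096, 3, device="cuda", requires_grad=True)
sdf = decoder(encoding(x).float())

# With the double-backward support in the grid encoding, nablas carries a
# grad_fn, so the eikonal term below can reach the encoding parameters.
nablas = torch.autograd.grad(sdf, x, torch.ones_like(sdf), create_graph=True)[0]
loss_eikonal = ((nablas.norm(dim=-1) - 1.0) ** 2).mean()
loss_eikonal.backward()
```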

@Tom94
Collaborator

Tom94 commented Mar 14, 2022

Hi there, thanks for the kind words! :)

Unfortunately, supporting second-order derivatives throughout the entire framework would be quite an undertaking. While nice-to-have, it's not a high priority at the moment.

Still, I'm going to leave this issue open as a TODO marker.

Cheers!

@za-cheng

Hi,

First huge thanks for the nice paper and pytorch extension.

About this issue - is there any chance you can make the hash encoding module twice differentiable? I.e. support calling backward on (d_encoding/d_input_coords)?

This should be easier to implement than the full MLP and would still be useful in scenarios like chaining the hash encoding with PyTorch MLPs to support SDF gradients.

Thanks
Z

@ventusff
Contributor Author

Hi @za-cheng, my custom implementation of backward_backward for grid.h has passed the checks, i.e. from dL_ddLdinput to dL_dparam and dL_ddLdoutput.

Both torch.autograd.gradcheck and torch.autograd.gradgradcheck pass, and a simple SDF training run with an eikonal loss also works.
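
For reference, this is the general shape of those checks (a toy double-precision function stands in for the actual CUDA ops here; the real tests live in the PR):

```python
import torch

# Stand-in for the encoding: a small, twice-differentiable function of an
# input x and parameters w, in float64 as gradcheck requires.
def f(x, w):
    return torch.sin(x @ w)

x = torch.rand(8, 3, dtype=torch.float64, requires_grad=True)
w = torch.rand(3, 4, dtype=torch.float64, requires_grad=True)

# First-order check: compares analytic gradients against finite differences.
assert torch.autograd.gradcheck(f, (x, w))

# Second-order check: does the same for the backward of the backward,
# i.e. the d(dL_dinput) -> dL_dparam / dL_ddLdoutput path.
assert torch.autograd.gradgradcheck(f, (x, w))
```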

I'm cleaning up the code and plan to submit a PR here within the week. :)

@za-cheng

Great, thanks @ventusff! I'll keep an eye out for it.

@ventusff
Contributor Author

Hi @za-cheng, the PR #69 is submitted 😄!

I managed to add partial support for second-order derivatives, only for grid.h, and only for d(dL_dinput)_d(...).
Detailed theoretical derivations and code tests are in the above PR.

After compiling this implementation, you can try my toy SDF training script test_train() with an eikonal term, provided here:
https://gist.github.com/ventusff/57f47588eaff5f8b77a382260e7da8a3

BR,
Ventus

@za-cheng

Hi @ventusff

Thanks so much for the PR. I tested the script; however, I got a CUDA illegal memory access error with n_features_per_level > 2. The error does not appear with n_features_per_level <= 2.

Best,
Z

@ventusff
Contributor Author

ventusff commented Mar 18, 2022

@za-cheng Fixed now. You can pull and compile again :)
Gradcheck now passes with n_features_per_level > 2 as well.
NOTE: you might need to uninstall and clean previously built binaries, since only a header file has changed.

@Tom94
Collaborator

Tom94 commented Apr 24, 2022

@ventusff's PR has since been merged.

Tom94 closed this as completed Apr 24, 2022
@juuso-oskari

@ventusff Thank you so much for your work. Are you still working on the double backward for fully_fused_mlp.cu? I would really like to test it in my thesis project.
