Autograd rules for Pytorch inputs #40

Closed
jemisjoky opened this issue Mar 28, 2023 · 10 comments

As someone who benefits a lot from high-performance tensor network methods, I want to say a huge thanks for putting this excellent library together (and for creating Python bindings for greater ease of use)!

I'm attempting to use these tools (in particular, cuquantum.contract) to accelerate GPU tensor network contraction with Pytorch, in a setting where automatic differentiation is used to optimize model parameters. Everything works great in the forward pass, but as soon as I try to compute gradients, I get the error RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn.

I'm including a minimal working example below for reproducing this error, but more generally I wanted to ask whether implementing autograd rules for Pytorch inputs is something that is planned for cuQuantum. I understand this isn't a trivial thing to add, but for anyone using these tools for ML (and also for many physics users), being able to efficiently contract and backpropagate would be a huge bonus for the library.

And of course, if all of this functionality is available already and I'm just doing something wrong, that would be wonderful news 😁

MWE:

import torch
from cuquantum import contract

# Dummy class whose parameters will get trained
class Contractor(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.ones(2, 3))

    def forward(self, x):
        return contract("ab,ba->", self.param, x)

# Initialization of model and evaluation on input
model = Contractor()
data = torch.ones(3, 2)
loss = model(data)

# This is where things break
loss.backward()
leofang assigned leofang and tlubowe and unassigned leofang Mar 29, 2023
@tlubowe (Collaborator) commented Mar 31, 2023

Hi @jemisjoky thank you for reaching out - it's great to hear from users like yourself! This is pretty cool work and hopefully we can discuss more in an upcoming meeting.

We are currently working on supporting gradients of Pauli strings from a tensor network, with optimizations based on intelligent reuse of intermediates for better performance. We hope to have this in the release after next (it will not be included in 23.03, which is coming shortly).

@jemisjoky (Author)

That's great to hear @tlubowe, thanks for that info! Looking forward to talking about this soon 😁

leofang assigned leofang and unassigned tlubowe Jul 14, 2023
@leofang (Member) commented Jul 14, 2023

Hi @jemisjoky, just FYI, we released cuQuantum v23.06 / cuTensorNet v2.2.0. It includes an experimental gradient API; see the sample code here and how we test it.

For this release there are only low-level Python bindings; no pythonic API has been built yet. But they should be enough for you to hook into your torch.nn.Module snippet, if you have time to give it a shot. Basically, you want to call contract_slices() in the forward() method to perform the usual (einsum) contraction, and then compute_gradient_backward() in the backward() method. Note that for complex-valued tensors, our gradients differ from PyTorch's by a complex conjugation.
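
To sketch the idea, here's roughly how these calls could be hooked into a custom torch.autograd.Function. The tn_contract()/tn_gradients() helpers below are only stand-ins (backed by torch.einsum so the snippet runs on its own), not the actual binding signatures; in practice they would wrap the network setup plus the contract_slices() and compute_gradient_backward() calls:

import torch

def tn_contract(expr, *operands):
    # Stand-in for the low-level contraction path: a real implementation would
    # create the cuTensorNet network descriptor, mark which operands require
    # gradients, allocate workspace, and call contract_slices().
    return torch.einsum(expr, *operands)

def tn_gradients(expr, operands, output_grad):
    # Stand-in for compute_gradient_backward(): here we fall back to PyTorch
    # autograd on detached copies just to keep the sketch self-contained.
    with torch.enable_grad():
        ops = [op.detach().requires_grad_(True) for op in operands]
        out = torch.einsum(expr, *ops)
    return torch.autograd.grad(out, ops, output_grad)

class TNContract(torch.autograd.Function):
    @staticmethod
    def forward(ctx, expr, *operands):
        ctx.expr = expr
        ctx.save_for_backward(*operands)
        return tn_contract(expr, *operands)

    @staticmethod
    def backward(ctx, output_grad):
        grads = tn_gradients(ctx.expr, ctx.saved_tensors, output_grad)
        # For complex-valued tensors, remember that the library's gradients
        # differ from PyTorch's convention by a complex conjugation.
        return (None, *grads)  # no gradient for the einsum expression string

In the Contractor module from the MWE above, forward() would then return TNContract.apply("ab,ba->", self.param, x).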

If you do try it out, please kindly let us know your feedback; we'd appreciate it and take it into consideration for making this a formal API offering (+ a pythonic API at least for PyTorch users). Thanks 🙂

@mtjrider (Collaborator) commented Sep 3, 2023

@jemisjoky are you working with our low-level interface or are you waiting for the upcoming support alluded to by @tlubowe?

I recommend we move this issue to a GitHub discussion.

@jemisjoky (Author)

@mtjrider @leofang huge thanks for putting together this gradient API! My apologies for not replying sooner; I somehow missed the notification when you put the feature together.

I think that should meet our needs for automatic differentiation support. I'm going to try integrating this as a custom autograd rule in the wrapper class we made for cuquantum.contract, and see if any issues pop up with that. I'm a bit swamped with some other projects right now, but will let you know when I have anything to report along those lines. Many thanks again for putting this together, I really appreciate it!

@leofang (Member) commented Sep 5, 2023

FYI, we already have an internal implementation that meets your needs. It's tentatively slated for the upcoming (23.10) release. If you can't wait that long, please let @sam-stanwyck know and we'll figure out how to help you better.

@jemisjoky (Author)

Ah yes, that internal implementation sounds perfect! I'll keep an eye out for that release and let you know if I have any feedback on that once it's released. Thanks everyone!

@mtjrider (Collaborator) commented Sep 7, 2023

@jemisjoky I will leave this issue open and provide you with an update when the features are released.

Please feel free to post here with any other questions.

@leofang (Member) commented Nov 3, 2023

@jemisjoky Thanks for waiting. I am pleased to share that the feature for differentiable contract is available in the new 23.10 release. Please check out the documentation and the code samples, and let us know if you have any questions/issues.
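
Assuming the new pythonic support means the original MWE from this issue now works unchanged (i.e. contract participates in PyTorch autograd directly for torch tensor inputs), a minimal check would look like:

import torch
from cuquantum import contract

class Contractor(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.ones(2, 3))

    def forward(self, x):
        return contract("ab,ba->", self.param, x)

model = Contractor()
loss = model(torch.ones(3, 2))
loss.backward()           # previously raised RuntimeError; should now populate gradients
print(model.param.grad)   # for this all-ones input, expect a (2, 3) tensor of ones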

leofang closed this as completed Nov 3, 2023
@jemisjoky (Author)

This looks PERFECT, massive thanks for this @leofang (and everyone who worked on this functionality)!!! This feature is a huge boost for anyone using Pytorch for ML work with TNs; looking forward to using it soon!
