# Reference: PyTorch docs, "Extending torch.autograd" - https://pytorch.org/docs/stable/notes/extending.html#extending-autograd

In [2]:
import torch
import numpy as np
import inspect

In [3]:
class AttentionOp(torch.autograd.Function):
    """Softmax attention weights with a hand-written backward pass.

    Computes ``softmax(q^T k / sqrt(C))`` over the last (key) axis, where ``C``
    is the second-to-last (channel) axis shared by ``q`` and ``k``. Any leading
    axes (batch, heads, ...) are matched or broadcast by ``einsum``.
    All arithmetic runs in float32; the weights are returned in ``q.dtype``.

    Shapes:
        q: ``(..., C, Lq)``
        k: ``(..., C, Lk)``
        returns: ``(..., Lq, Lk)``
    """

    @staticmethod
    def forward(ctx, q, k):
        scale = np.sqrt(k.shape[-2])
        # Upcast before scaling so low-precision inputs are not rounded twice.
        logits = torch.einsum('...cq,...ck->...qk', q.to(torch.float32), k.to(torch.float32) / scale)
        w = logits.softmax(dim=-1).to(q.dtype)
        ctx.save_for_backward(q, k, w)
        return w

    @staticmethod
    def backward(ctx, dw):
        q, k, w = ctx.saved_tensors
        scale = np.sqrt(k.shape[-2])
        # Gradient w.r.t. the (scaled) logits, computed in float32.
        db = torch._softmax_backward_data(
            grad_output=dw.to(torch.float32),
            output=w.to(torch.float32),
            dim=w.dim() - 1,
            input_dtype=torch.float32,
        )
        # logits = q^T k / scale, so both input gradients carry the same 1/scale.
        # Scale in float32 and cast last to avoid an extra rounding step.
        dq = (torch.einsum('...ck,...qk->...cq', k.to(torch.float32), db) / scale).to(q.dtype)
        dk = (torch.einsum('...cq,...qk->...ck', q.to(torch.float32), db) / scale).to(k.dtype)
        # If q or k was broadcast in forward, autograd sums dq/dk back to the input shape.
        return dq, dk

In [4]:
# --- Example usage ---
# Two query batches attend to a single key batch (size 1, broadcast by einsum);
# the printed shape is torch.Size([2, 192, 192]).
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
result = AttentionOp.apply(
    *(torch.randn(batch, 5, 192, device=device) for batch in (2, 1))
)
print(result.shape)

torch.Size([2, 192, 192])


In [5]:


class AttentionOp_custom(torch.autograd.Function):
    """Softmax attention weights, written in the ``forward``/``setup_context`` style.

    Mathematically the same as ``softmax(q^T k / sqrt(C))`` over the last (key)
    axis, where ``C`` is the second-to-last (channel) axis of ``q`` and ``k``.
    Any leading axes are matched or broadcast by ``einsum``. All arithmetic runs
    in float32; the weights are returned in ``q.dtype``.

    ``forward`` takes no ``ctx``; saving for backward happens in ``setup_context``.
    This is the layout PyTorch requires for use with ``torch.func`` transforms.
    vmap support would additionally need ``generate_vmap_rule`` or a ``vmap``
    staticmethod, which are not defined here.

    Shapes:
        q: ``(..., C, Lq)``
        k: ``(..., C, Lk)``
        returns: ``(..., Lq, Lk)``
    """

    @staticmethod
    def setup_context(ctx, inputs, output):
        # forward() has no ctx, so keep what backward needs (both inputs and
        # the softmax output) here.
        q, k = inputs
        ctx.save_for_backward(q, k, output)

    @staticmethod
    def forward(q, k):
        scale = np.sqrt(k.shape[-2])
        # Upcast before scaling so low-precision inputs are not rounded twice.
        logits = torch.einsum('...cq,...ck->...qk', q.to(torch.float32), k.to(torch.float32) / scale)
        return logits.softmax(dim=-1).to(q.dtype)

    @staticmethod
    def backward(ctx, dw):
        q, k, w = ctx.saved_tensors
        scale = np.sqrt(k.shape[-2])
        # Gradient w.r.t. the (scaled) logits, computed in float32.
        db = torch._softmax_backward_data(
            grad_output=dw.to(torch.float32),
            output=w.to(torch.float32),
            dim=w.dim() - 1,
            input_dtype=torch.float32,
        )
        # logits = q^T k / scale, so both input gradients carry the same 1/scale.
        # Scale in float32 and cast last to avoid an extra rounding step.
        dq = (torch.einsum('...ck,...qk->...cq', k.to(torch.float32), db) / scale).to(q.dtype)
        dk = (torch.einsum('...cq,...qk->...ck', q.to(torch.float32), db) / scale).to(k.dtype)
        # If q or k was broadcast in forward, autograd sums dq/dk back to the input shape.
        return dq, dk
# --- Hiding the signature ---
# `torch.autograd.Function.apply` is a generic classmethod taking *args/**kwargs,
# so introspection (inspect.signature, help(), IDE tooltips) does not show the
# real arguments. Wrap it in a thin function that advertises forward's (q, k).

# Keep a handle on the original bound classmethod before replacing it.
_original_apply = AttentionOp_custom.apply


def _apply_wrapper(q, k):
    """Return attention weights for ``q`` and ``k`` via AttentionOp_custom (custom backward)."""
    return _original_apply(q, k)


# Copy forward's parameter list so the advertised signature always matches what
# the wrapper really accepts (q and k, positionally or by keyword), even if
# forward's parameters are later renamed.
_apply_wrapper.__signature__ = inspect.signature(AttentionOp_custom.forward)
_apply_wrapper.__name__ = "apply"
_apply_wrapper.__qualname__ = AttentionOp_custom.__qualname__ + ".apply"

# Replace apply with the wrapper. A plain function looked up on the class is not
# bound, so AttentionOp_custom.apply(q, k) calls _apply_wrapper(q, k) directly.
AttentionOp_custom.apply = _apply_wrapper

# --- Example usage ---
# Calls the signature-wrapped apply with a 2-batch query and a 1-batch key;
# the printed shape is torch.Size([2, 192, 192]).
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
result = AttentionOp_custom.apply(
    *[torch.randn(batch, 5, 192, device=device) for batch in (2, 1)]
)
print(result.shape)

torch.Size([2, 192, 192])
