
refactor(autogram): Make VJP take flat grad_outputs. #438

Merged
ValerianRey merged 6 commits into main from make-vjp-take-flat-grad-outputs
Oct 1, 2025

Conversation

@PierreQuinton
Contributor

This allows removing the `output_spec` parameter from both `autograd.Function`s in `ModuleHookManager`.
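For context, here is a minimal sketch of what "taking flat grad_outputs" means (hypothetical names, not the actual torchjd API): the VJP receives a flat tuple of tensors directly, so no `output_spec` is needed to reassemble a nested structure first.

```python
from typing import Callable

import torch
from torch import Tensor


# Hypothetical illustration: a VJP closure that takes flat grad_outputs
# directly and forwards them to torch.autograd.grad unchanged.
def make_flat_vjp(
    outputs: tuple[Tensor, ...], inputs: tuple[Tensor, ...]
) -> Callable[[tuple[Tensor, ...]], tuple[Tensor, ...]]:
    def vjp(flat_grad_outputs: tuple[Tensor, ...]) -> tuple[Tensor, ...]:
        return torch.autograd.grad(outputs, inputs, flat_grad_outputs, retain_graph=True)

    return vjp


x = torch.tensor(2.0, requires_grad=True)
y1, y2 = x * 3, x * 5
vjp = make_flat_vjp((y1, y2), (x,))
(g,) = vjp((torch.tensor(1.0), torch.tensor(1.0)))
# g == dy1/dx + dy2/dx == 8
```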

@PierreQuinton added the labels `cc: refactor` (Conventional commit type for any refactoring, not user-facing, and not typing or perf improvements) and `package: autogram` on Sep 30, 2025
@codecov

codecov bot commented Sep 30, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.

Files with missing lines | Coverage | Δ
src/torchjd/autogram/_module_hook_manager.py | 100.00% <100.00%> | (ø)
src/torchjd/autogram/_vjp.py | 100.00% <100.00%> | (ø)

... and 1 file with indirect coverage changes


@ValerianRey
Contributor

I really like this PR. Before it, we unflattened and then re-flattened the grad_outputs for AutogradVJP; now we do neither, so this should be a performance improvement.
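The removed round trip can be sketched with `torch.utils._pytree` (illustrative only, not the actual torchjd code): the old path rebuilt the output structure just so it could be flattened again.

```python
import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

# Illustrative: the old path unflattened grad_outputs into the output
# structure, only for AutogradVJP to flatten them right back before
# calling torch.autograd.grad. The new path keeps the flat list throughout.
flat = [torch.ones(2), torch.ones(3)]
_, spec = tree_flatten({"a": torch.zeros(2), "b": torch.zeros(3)})

nested = tree_unflatten(flat, spec)  # old: rebuild the structure...
reflat, _ = tree_flatten(nested)     # ...then flatten it again

# A pure round trip: the very same tensor objects come back out.
assert all(a is b for a, b in zip(flat, reflat))
```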

@ValerianRey
Contributor

We have a mypy error with this:

src/torchjd/autogram/_vjp.py:126: error: Argument 3 to "grad" has incompatible type "list[Tensor | None]"; expected "Tensor | Sequence[Tensor] | None"  [arg-type]
Found 1 error in 1 file (checked 52 source files)

But it seems that PyTorch is in the wrong here. In the documentation of grad, they say:

grad_outputs should be a sequence of length matching output containing the "vector" in vector-Jacobian product, usually the pre-computed gradients w.r.t. each of the outputs. If an output doesn't require_grad, then the gradient can be None.

However, they type it as Optional[_TensorOrTensors], where _TensorOrTensors is defined as _TensorOrTensors: TypeAlias = Union[Tensor, Sequence[Tensor]].

@PierreQuinton did I miss something? If not, I'll probably open an issue or a PR in torch.
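A minimal reproduction of the mismatch (a sketch, not from the PR itself): at runtime `torch.autograd.grad` accepts a `None` entry in `grad_outputs` (for a scalar output it is replaced by an implicit ones tensor), while the `Tensor | Sequence[Tensor]` annotation makes mypy reject `list[Tensor | None]`.

```python
import torch

x = torch.tensor(2.0, requires_grad=True)
y = x * 3  # scalar output

# Runtime accepts a None entry in grad_outputs (treated as an implicit
# ones for a scalar output), but since the parameter is annotated
# Tensor | Sequence[Tensor], mypy flags list[Tensor | None] as arg-type.
grad_outputs: "list[torch.Tensor | None]" = [None]
(g,) = torch.autograd.grad([y], [x], grad_outputs=grad_outputs)
print(g)  # tensor(3.)
```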

@ValerianRey
Contributor

I just opened an issue in pytorch: pytorch/pytorch#164298

In the meantime, I think it's fine to replace tuple[Tensor | None, ...] with tuple[Tensor, ...], because we will always call these VJPs with only non-None grad_outputs, if I understand correctly (unless the type hint for *flat_grad_outputs in AccumulateJacobian.forward, which says Tensor rather than Tensor | None, is wrong).

@PierreQuinton
Contributor Author

PierreQuinton commented Oct 1, 2025

I just opened an issue in pytorch: pytorch/pytorch#164298

In the meantime, I think it's fine to replace tuple[Tensor | None, ...] with tuple[Tensor, ...], because we will always call these VJPs with only non-None grad_outputs, if I understand correctly (unless the type hint for *flat_grad_outputs in AccumulateJacobian.forward, which says Tensor rather than Tensor | None, is wrong).

Yeah, the type hint is wrong; we will update that in another PR, though. The thing is, when we filter by requires_grad, we are supposed to remove any None from the grad_outputs. This is slightly complicated to type, though.

@ValerianRey
Contributor

I really like 9598370 and 89d5f8f

* They're always flat now

@ValerianRey ValerianRey left a comment


LGTM

@PierreQuinton
Contributor Author

PierreQuinton commented Oct 1, 2025

There seems to be a problem with macOS runners, I think this is safe to merge anyways, so should we?

@ValerianRey ValerianRey merged commit b075f6f into main Oct 1, 2025
13 of 17 checks passed
@ValerianRey ValerianRey deleted the make-vjp-take-flat-grad-outputs branch October 1, 2025 10:45
@ValerianRey
Contributor

ValerianRey commented Oct 1, 2025

There seems to be a problem with macOS runners, I think this is safe to merge anyways, so should we?

Yes, no need to wait, especially since nothing here is macOS-specific (no numerical errors can be introduced by this PR).

Here is the status btw: https://www.githubstatus.com/
It should be working soon.
