refactor!: Remove generalized gramians #692

Draft
ValerianRey wants to merge 4 commits into main from remove-generalized-gramians

Conversation

@ValerianRey
Contributor

Closes #690.

Summary

  • Engine.compute_gramian now always returns a flat [m, m] PSD matrix where m = output.numel(), regardless of the output shape
  • Removed PSDTensor, GeneralizedWeighting, and Flattening — they are no longer needed
  • Updated the IWMTL example to use UPGradWeighting directly and reshape the weights before calling backward
  • Added a migration guide in CHANGELOG.md
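
As a rough illustration of the first bullet, the sketch below builds a flat Gramian with plain PyTorch rather than with `Engine.compute_gramian`. It assumes the Gramian is that of the Jacobian of the output with respect to the parameters; the shapes `m1`, `m2`, `n` and the uniform stand-in for `weighting(gramian)` are hypothetical and only there to show the shapes involved.

```python
import torch

# Hypothetical sizes: an output of shape [m1, m2] and n parameters.
m1, m2, n = 2, 3, 4

# Stand-in for the Jacobian of the output w.r.t. the parameters.
jacobian = torch.randn(m1, m2, n)

# Flatten the output dimensions so that m = m1 * m2, then form the Gramian.
flat_jacobian = jacobian.reshape(m1 * m2, n)
gramian = flat_jacobian @ flat_jacobian.T  # shape: [m, m]

# Stand-in for weighting(gramian): one weight per flattened output element.
flat_weights = torch.full((m1 * m2,), 1.0 / (m1 * m2))

# Reshape the weights back to the output shape before calling backward.
weights = flat_weights.reshape(m1, m2)  # shape: [m1, m2]
```

A matrix built this way is symmetric and positive semi-definite up to floating-point error, which is what lets a weighting that expects a PSD matrix consume it directly.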

Migration guide

# Before
weighting = Flattening(UPGradWeighting())
gramian = engine.compute_gramian(losses)  # shape: [m1, m2, m2, m1]
weights = weighting(gramian)              # shape: [m1, m2]
losses.backward(weights)

# After
weighting = UPGradWeighting()
gramian = engine.compute_gramian(losses)           # shape: [m1*m2, m1*m2]
weights = weighting(gramian).reshape(losses.shape) # shape: [m1, m2]
losses.backward(weights)

🤖 Generated with Claude Code

Simplify `Engine.compute_gramian` to always return a flat `[m, m]` PSD matrix
where `m = output.numel()`, removing the concepts of generalized Gramians,
`PSDTensor`, `GeneralizedWeighting`, and `Flattening`.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@ValerianRey added the package: aggregation, cc: refactor, breaking-change, and package: autogram labels on May 20, 2026
@github-actions (bot) changed the title from "feat(autogram): Remove generalized gramians" to "refactor!: Remove generalized gramians" on May 20, 2026
@ValerianRey force-pushed the remove-generalized-gramians branch from ce15075 to 5ad48d2 on May 20, 2026 23:59
@ValerianRey
Contributor Author

No idea what Claude changed in autogram. I need to look at that in detail before you bother reviewing this, @PierreQuinton.

Contributor

@PierreQuinton PierreQuinton left a comment

I think I like it, but let me think about it. Also, I don't think we want to erase all the private generalized_gramian utilities and inline their implementations in autogram.

Comment on lines +10 to 22
+def compute_gramian(t: Tensor) -> PSDMatrix: ...


 @overload
-def compute_gramian(t: Tensor, contracted_dims: Literal[-1]) -> PSDMatrix:
-    pass
+def compute_gramian(t: Tensor, contracted_dims: Literal[-1]) -> PSDMatrix: ...


 @overload
-def compute_gramian(t: Matrix, contracted_dims: Literal[1]) -> PSDMatrix:
-    pass
+def compute_gramian(t: Matrix, contracted_dims: Literal[1]) -> PSDMatrix: ...

Contributor

It removed all the `pass` statements here.

Contributor Author

Yeah, no idea why. I think we had a problem with `...` and another problem (a code coverage issue) with `pass`, and we fixed the latter. Not sure why `...` even works there.
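
On why `...` works as a body: in Python, `...` is the `Ellipsis` literal, so a line containing only `...` is an ordinary expression statement and forms a valid function body, just like `pass`. The sketch below uses a hypothetical function named `describe` (not from this PR) to show both stub styles side by side. The stub bodies never run, because the final undecorated definition replaces them at runtime.

```python
from typing import overload


@overload
def describe(x: int) -> str: ...


@overload
def describe(x: str) -> str:
    pass


def describe(x):
    # Runtime implementation; the two stubs above only inform type checkers.
    return f"{type(x).__name__}: {x}"


int_result = describe(3)
str_result = describe("a")
```

Both forms are accepted by the interpreter; which one a project prefers is mostly a matter of linting and coverage configuration.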

@@ -1,87 +0,0 @@
from math import prod
Contributor

Why do we remove this file? We still essentially need some of its functions in autogram (since we do have a [m, k, k, m] Gramian there). I would rather not replace every call to them with its implementation.

Contributor Author

I agree that if this is necessary in autogram, we shouldn't remove this. So we first have to get a good autogram implementation before deciding if we keep this or not.

m = square_gramian.shape[0]
internal_indices = torch.arange(m, device=output.device).reshape(ordered_output.shape)
perm = internal_indices.movedim(-1, self._batch_dim).reshape(-1)
gramian = cast(PSDMatrix, square_gramian[perm, :][:, perm])
Contributor

I think this is less efficient than reshape + movedim, which essentially just changes the strides. The reason is that it is hard to detect that this indexing operation is effectively a re-striding. I liked the old implementation and would vote for restoring it. Otherwise, I would port it somewhere here.
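
To make the comparison concrete, here is a minimal sketch of the two approaches on a toy example. It is not the autogram code: the shapes, `batch_dim = 0`, and the random stand-in for `square_gramian` are assumptions chosen for illustration. It only checks that indexing with `perm` and reshape + movedim produce the same matrix.

```python
import torch

# Hypothetical setup: the output has shape [batch, k] = [3, 2], and the
# "ordered" output has the batch dimension moved last: [k, batch] = [2, 3].
batch_dim = 0
ordered_shape = (2, 3)
m = 6

# Stand-in for the square Gramian indexed by the flattened ordered output.
jac = torch.randn(m, 5)
square_gramian = jac @ jac.T

# Approach from the diff: build an index permutation, then gather rows and columns.
internal_indices = torch.arange(m).reshape(ordered_shape)
perm = internal_indices.movedim(-1, batch_dim).reshape(-1)
by_indexing = square_gramian[perm, :][:, perm]

# Alternative: view the matrix as a 4D tensor and move both batch dimensions.
k = len(ordered_shape)
as_tensor = square_gramian.reshape(*ordered_shape, *ordered_shape)
moved = as_tensor.movedim((k - 1, 2 * k - 1), (batch_dim, k + batch_dim))
by_movedim = moved.reshape(m, m)

assert torch.equal(by_indexing, by_movedim)
```

In this sketch `moved` is a view sharing storage with `square_gramian`, whereas indexing with an index tensor always allocates a new tensor. Note that the final `reshape(m, m)` on the non-contiguous `moved` also has to copy, so any saving depends on whether a flat matrix is needed at that point.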

Labels

  • breaking-change (This PR introduces a breaking change.)
  • cc: refactor (Conventional commit type for any refactoring, not user-facing, and not typing or perf improvements)
  • package: aggregation
  • package: autogram

Development

Successfully merging this pull request may close these issues.

Remove generalized gramians

2 participants