refactor!: Remove generalized gramians #692
Simplify `Engine.compute_gramian` to always return a flat `[m, m]` PSD matrix where `m = output.numel()`, removing the concepts of generalized Gramians, `PSDTensor`, `GeneralizedWeighting`, and `Flattening`. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
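For readers unfamiliar with the target shape, here is a minimal sketch of what a flat `[m, m]` Gramian with `m = output.numel()` looks like. It materializes a Jacobian explicitly, which is only for illustration and is not how `Engine` computes it; `flat_gramian`, `jacobian`, and the shapes are hypothetical names chosen for this example.

```python
import torch


def flat_gramian(jacobian: torch.Tensor, output_shape: torch.Size) -> torch.Tensor:
    """Gramian of a Jacobian whose leading dims match `output_shape` (illustrative helper)."""
    m = output_shape.numel()
    # Collapse all output dimensions into a single row index of size m.
    flat_jacobian = jacobian.reshape(m, -1)
    # Rows are per-output-element gradients, so this product is symmetric PSD.
    return flat_jacobian @ flat_jacobian.T


output_shape = torch.Size([3, 2])
# Hypothetical Jacobian: 5 partial derivatives per output element.
jacobian = torch.randn(*output_shape, 5)
gramian = flat_gramian(jacobian, output_shape)
print(gramian.shape)
```

Whatever the output shape is, the result only depends on the total number of output elements.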
ce15075 to 5ad48d2
No idea what Claude changed in autogram. I need to look at that in detail before you bother reviewing this, @PierreQuinton.
PierreQuinton left a comment
I think I like it, but let me think about it. Also, I don't think we want to erase all the private `generalized_gramian` utilities and inline their implementations in autogram.
+def compute_gramian(t: Tensor) -> PSDMatrix: ...


 @overload
-def compute_gramian(t: Tensor, contracted_dims: Literal[-1]) -> PSDMatrix:
-    pass
+def compute_gramian(t: Tensor, contracted_dims: Literal[-1]) -> PSDMatrix: ...


 @overload
-def compute_gramian(t: Matrix, contracted_dims: Literal[1]) -> PSDMatrix:
-    pass
+def compute_gramian(t: Matrix, contracted_dims: Literal[1]) -> PSDMatrix: ...
It removed all the `pass` statements here.
Yeah, no idea why. I think we had a problem with `...` and another problem (a code coverage issue) with `pass`, and we fixed the latter. Not sure why `...` even works there.
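On why `...` works as a body: it is the `Ellipsis` literal, and any expression is a valid statement, so a body consisting of `...` parses just like a body consisting of `pass`. The `@overload` stubs are also never called at runtime, since the final undecorated definition rebinds the name. A small sketch with a hypothetical `scale` function (not from this codebase):

```python
from typing import Literal, Union, overload


@overload
def scale(x: int) -> int: ...
@overload
def scale(x: int, mode: Literal["float"]) -> float: ...


def scale(x: int, mode: str = "int") -> Union[int, float]:
    # This undecorated definition is the only one that runs; the stubs above
    # exist purely for static type checkers.
    doubled = x * 2
    return float(doubled) if mode == "float" else doubled


print(scale(3), scale(3, "float"))
```

Whether coverage tools count the stub bodies depends on their configuration, which this sketch does not address.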
@@ -1,87 +0,0 @@
-from math import prod
Why do we remove this file? We still essentially need some of its functions in autogram (as we do have an `[m, k, k, m]` Gramian there). I would rather not replace all calls to these functions with their implementations.
I agree that if this is necessary in autogram, we shouldn't remove this. So we first have to get a good autogram implementation before deciding if we keep this or not.
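To make the `[m, k, k, m]` shape concrete, here is a rough sketch of how such a tensor relates to a flat Gramian. The layout `G[i, j, j2, i2] = <J[i, j], J[i2, j2]>` is an assumption made for this example (only the shape is quoted in the thread), and the explicit Jacobian is purely illustrative.

```python
import torch

m, k, n = 3, 2, 5
jacobian = torch.randn(m, k, n)  # hypothetical Jacobian of an [m, k] output

# Assumed generalized layout: output dims appear reversed in the second half.
generalized = torch.einsum("ijn,abn->ijba", jacobian, jacobian)  # [m, k, k, m]

# Flat Gramian of the same Jacobian: [m * k, m * k].
flat_jacobian = jacobian.reshape(m * k, n)
flat = flat_jacobian @ flat_jacobian.T

# Swap the two trailing dims to get [m, k, m, k], then merge each pair of dims.
recovered = generalized.movedim(-1, -2).reshape(m * k, m * k)
print(torch.allclose(recovered, flat, atol=1e-5))
```

Under that assumed layout, going from one form to the other is a `movedim` plus a `reshape`, which is the kind of helper the removed file could keep providing.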
+m = square_gramian.shape[0]
+internal_indices = torch.arange(m, device=output.device).reshape(ordered_output.shape)
+perm = internal_indices.movedim(-1, self._batch_dim).reshape(-1)
+gramian = cast(PSDMatrix, square_gramian[perm, :][:, perm])
I think this is less efficient than reshape + movedim, which essentially just changes the strides. The reason is that it is hard to detect that this slicing operation is effectively a re-striding. I liked the old implementation, and I would vouch for restoring it. Otherwise, I would translate it somewhere here.
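A minimal sketch comparing the two approaches on a stand-in matrix, assuming `ordered_output` is the output with its batch dimension moved last. The names `batch_dim`, `ordered_shape`, and the shapes are hypothetical; this only checks that both variants give the same values and makes no timing claim.

```python
import torch

batch_dim = 0
ordered_shape = (3, 4)  # assumed: a [4, 3] output with its batch dim moved last
m = 12

# Stand-in for a symmetric [m, m] Gramian laid out in the "ordered" flattening.
a = torch.randn(m, m)
square_gramian = a @ a.T

# Variant from the diff: build an explicit index permutation, then gather twice.
internal_indices = torch.arange(m).reshape(ordered_shape)
perm = internal_indices.movedim(-1, batch_dim).reshape(-1)
by_indexing = square_gramian[perm, :][:, perm]

# Variant from the comment: view as [*ordered, *ordered], then move the batch
# dim back in each half. Both steps return views, so only strides change here.
nd = len(ordered_shape)
viewed = square_gramian.reshape(*ordered_shape, *ordered_shape)
viewed = viewed.movedim(nd - 1, batch_dim).movedim(2 * nd - 1, nd + batch_dim)
# The final reshape to [m, m] copies, because the view is no longer contiguous.
by_movedim = viewed.reshape(m, m)

print(torch.equal(by_indexing, by_movedim))
```

The indexing variant performs two gathers, while the view-based variant defers any copy to the last `reshape`.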
Closes #690.
Summary
- `Engine.compute_gramian` now always returns a flat `[m, m]` PSD matrix where `m = output.numel()`, regardless of the output shape
- `PSDTensor`, `GeneralizedWeighting`, and `Flattening`: they are no longer needed
- `UPGradWeighting` directly and reshape the weights before calling `backward`
- `CHANGELOG.md`
Migration guide
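A rough sketch of the weight-reshaping step mentioned in the summary, assuming a weighting yields one weight per output element (a vector of length `m`). Uniform weights stand in for whatever the weighting would actually return, and `x` and `output` are hypothetical.

```python
import torch

x = torch.randn(4, 3, requires_grad=True)
output = x**2  # non-scalar output of shape [4, 3]
m = output.numel()

# Placeholder for weights extracted from a flat [m, m] Gramian (assumption:
# the weighting returns a vector of length m).
weights = torch.full((m,), 1.0 / m)

# The weights are flat, so reshape them to the output's shape before backward.
output.backward(gradient=weights.reshape(output.shape))
print(x.grad.shape)
```

With a flat Gramian the weights no longer carry the output's shape, which is why the explicit `reshape` is needed before `backward`.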
🤖 Generated with Claude Code