Skip to content

refactor: Improve PSD typing#522

Merged
ValerianRey merged 37 commits intomainfrom
add-generalized-matrix-psd-matrix
Jan 23, 2026
Merged

refactor: Improve PSD typing#522
ValerianRey merged 37 commits intomainfrom
add-generalized-matrix-psd-matrix

Conversation

@PierreQuinton
Copy link
Copy Markdown
Contributor

@PierreQuinton PierreQuinton commented Jan 19, 2026

  • Add PSDTensor
  • Remove GeneralizedMatrix
  • Use classes for Matrix, PSDMatrix and PSDTensor instead of type annotations
  • Use casting in compute_gramian, normalize and regularize
  • Add typeguard functions is_matrix, is_psd_tensor and is_psd_matrix
  • Move normalize and regularize to _linalg
  • Move _check_is_matrix to Aggregator.__call__
  • Improve internal type hints
  • Rename reshape_gramian to reshape and movedim_gramian to movedim
  • Add _gramian_utils.flatten
  • Rename a few tests
  • Add some parametrizations to test_gramian_is_psd
  • Add test_reshape_yields_psd, test_flatten_yields_matrix, test_flatten_yields_psd, test_movedim_yields_psd, test_normalize_yields_psd and test_regularize_yields_psd
  • Add assert_is_psd_tensor

Also fixed some typing problems on the rest of the lib (sorry).

I'm still not sure about having classes, but as far as I understand this is the only way of having intersections. The main counter argument is that it makes the MRO heavier to compute.

TODO:

  • add unit tests of the PSD property for every function that use the type guard for PSDQuadraticForm or PSDMatrix.
  • at the end: change most (all?) is_* to cast(*, *)

@PierreQuinton PierreQuinton added cc: refactor Conventional commit type for any refactoring, not user-facing, and not typing or perf improvements package: linalg labels Jan 19, 2026
@codecov
Copy link
Copy Markdown

codecov bot commented Jan 19, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.

Files with missing lines Coverage Δ
src/torchjd/_linalg/__init__.py 100.00% <100.00%> (ø)
src/torchjd/_linalg/_gramian.py 100.00% <100.00%> (ø)
src/torchjd/_linalg/_matrix.py 100.00% <100.00%> (ø)
src/torchjd/aggregation/_aggregator_bases.py 100.00% <100.00%> (ø)
src/torchjd/aggregation/_cagrad.py 100.00% <100.00%> (ø)
src/torchjd/aggregation/_config.py 100.00% <100.00%> (ø)
src/torchjd/aggregation/_dualproj.py 100.00% <100.00%> (ø)
src/torchjd/aggregation/_flattening.py 100.00% <100.00%> (ø)
src/torchjd/aggregation/_graddrop.py 100.00% <100.00%> (ø)
src/torchjd/aggregation/_pcgrad.py 100.00% <100.00%> (ø)
... and 11 more
🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@claude
Copy link
Copy Markdown

claude bot commented Jan 19, 2026

Code review

No issues found. Checked for bugs and CLAUDE.md compliance.

@PierreQuinton PierreQuinton changed the title refactor(linalg): Add PSDQuadraticForm and GeneralizedMatrix. refactor(linalg): Handle PSDQuadraticForm and GeneralizedMatrix. Jan 20, 2026
Copy link
Copy Markdown
Contributor

@ValerianRey ValerianRey left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks very good. I really like all of those changes. I made many minor comments that should be super easy to fix.

Suggested commit description:

  • Add PSDTensor
  • Remove GeneralizedMatrix
  • Use classes for Matrix, PSDMatrix and PSDTensor instead of type annotations
  • Use casting in compute_gramian, normalize and regularize
  • Add typeguard functions is_matrix, is_psd_tensor and is_psd_matrix
  • Move normalize and regularize to _linalg
  • Move _check_is_matrix to Aggregator.__call__
  • Improve internal type hints
  • Rename reshape_gramian to reshape and movedim_gramian to movedim
  • Add _gramian_utils.flatten
  • Rename a few tests
  • Add some parametrizations to test_gramian_is_psd
  • Add test_reshape_yields_psd, test_flatten_yields_matrix, test_flatten_yields_psd, test_movedim_yields_psd, test_normalize_yields_psd and test_regularize_yields_psd
  • Add assert_is_psd_tensor

@ValerianRey
Copy link
Copy Markdown
Contributor

@claude Could you double-check that we didn't forget to change some internal type hints? Also, did I miss something in my review?

@claude
Copy link
Copy Markdown

claude bot commented Jan 20, 2026

Claude finished @ValerianRey's task in 3m 14s —— View job


Type Hints Review Complete

I've performed a comprehensive review of the internal type hints across the PR. Here are my findings:

✅ Internal Type Hints Look Good

The type hint migration from Annotated type aliases to classes has been done consistently and correctly across the codebase:

Core type system:

  • _linalg/_matrix.py properly defines the new class hierarchy: GeneralizedMatrix, Matrix, PSDGeneralizedMatrix, PSDMatrix
  • Type guards correctly use TypeGuard return type
  • All imports properly use the new classes instead of old type aliases

Production code:

  • All aggregators correctly use Matrix or PSDMatrix in their type hints (checked: _aggregator_bases.py, _cagrad.py, _pcgrad.py, _graddrop.py, _dualproj.py, _upgrad.py, _config.py)
  • All weightings correctly use PSDMatrix or PSDGeneralizedMatrix (checked: _weighting_bases.py, _flattening.py)
  • autogram module correctly uses the new types (checked: _engine.py, _gramian_computer.py, _jacobian_computer.py, _gramian_utils.py)
  • _linalg/_gramian.py has proper type hints: compute_gramian(GeneralizedMatrix) -> PSDMatrix, normalize(PSDMatrix) -> PSDMatrix, regularize(PSDMatrix) -> PSDMatrix

No remaining old references:

  • ✅ No remaining PSDQuadraticForm references
  • ✅ No old Annotated[Tensor, ...] type alias usage
  • ✅ All imports from torchjd._linalg use the new classes

📝 About ValerianRey's Review

ValerianRey's review is excellent and thorough. I didn't find anything missing. The review covers:

  1. Naming consistency (e.g., GeneralizedMatrix vs PSDGeneralizedMatrix naming)
  2. Assert vs cast usage (appropriate suggestions to use cast for internal code)
  3. Documentation improvements (docstring clarifications)
  4. Test utility improvements (return type hints for forward_backwards.compute_gramian and compute_gramian_with_autograd)
  5. Code organization (moving functions, one-lining, etc.)
  6. Outdated comments removal

The only suggestion in the review related to type hints is about test utility functions (tests/utils/forward_backwards.py), not production code. ValerianRey correctly identified that compute_gramian and compute_gramian_with_autograd could have more specific return types (PSDGeneralizedMatrix), which would eliminate some unnecessary type assertions in tests.

✨ Conclusion

No internal type hints were forgotten. The migration from type aliases to classes was done comprehensively and correctly. All the suggestions in ValerianRey's review are valid improvements to make the code even better.


@claude
Copy link
Copy Markdown

claude bot commented Jan 21, 2026


Code review

No issues found. Checked for bugs and CLAUDE.md compliance.


@ValerianRey ValerianRey changed the title refactor(linalg): Handle PSDQuadraticForm and GeneralizedMatrix. refactor: Improve PSD typing Jan 23, 2026
@ValerianRey ValerianRey merged commit 2d7bf7f into main Jan 23, 2026
15 checks passed
@ValerianRey ValerianRey deleted the add-generalized-matrix-psd-matrix branch January 23, 2026 13:37
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

cc: refactor Conventional commit type for any refactoring, not user-facing, and not typing or perf improvements package: linalg

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants