-
Notifications
You must be signed in to change notification settings - Fork 55
mypy type math.py and test_math.py #242
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
WalkthroughIntroduces a new API expm_frechet_with_matrix_exp returning (expm(A), Frechet(A,E)), refactors expm_frechet to delegate and return only the derivative, removes compute_expm flag across call sites, broadens typing to accept NumPy arrays, adjusts autograd and kronform callers, updates tests to the new API, and adds MyPy settings. Changes
Sequence Diagram(s)sequenceDiagram
autonumber
participant Caller as Caller (tests/optimizers)
participant Math as torch_sim.math
Caller->>Math: expm_frechet_with_matrix_exp(A, E, method)
activate Math
Note over Math: Validate inputs, shapes, types (torch/NumPy)
Math-->>Caller: (expm(A), L_expm(A)[E])
deactivate Math
sequenceDiagram
autonumber
participant PyTorch as Autograd Engine
participant Math as torch_sim.math.expm (Function)
participant F as torch_sim.math.expm_frechet
PyTorch->>Math: backward(grad_output)
activate Math
Math->>F: expm_frechet(A, grad_output, method="SPS", check_finite=False)
activate F
F-->>Math: L_expm(A)[grad_output]
deactivate F
Math-->>PyTorch: gradient wrt A
deactivate Math
Note over F: Wrapper delegates to expm_frechet_with_matrix_exp and returns derivative only
Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes Poem
✨ Finishing Touches
🧪 Generate unit tests
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (6)
torch_sim/math.py (3)
541-563: Device mismatch risk when using torch.arange for boolean maskstorch.arange defaults to CPU; unique_vals/counts follow eigenvalues’ device (CPU/GPU). Mixing devices in boolean indexing will error.
- unique_vals = unique_vals[ - ~(close_mask & torch.arange(len(close_mask)) != i) - ] - counts = counts[~(close_mask & torch.arange(len(counts)) != i)] + idx = torch.arange(len(close_mask), device=close_mask.device) + unique_vals = unique_vals[~(close_mask & (idx != i))] + counts = counts[~(close_mask & (idx != i))]
627-651: Use tensor-safe clamping instead of Python max with tensors (runtime error)max(lambda_val, num_tol) will raise at runtime for torch.Tensors. Use a tensor tolerance and torch.where/torch.clamp-like logic.
- if abs(lambda_val) > 1: + if torch.abs(lambda_val) > 1: scaled_T_minus_lambdaI = T_minus_lambdaI / lambda_val return torch.log(lambda_val) * Identity + scaled_T_minus_lambdaI # Alternative computation for small lambda - return torch.log(lambda_val) * Identity + T_minus_lambdaI / max(lambda_val, num_tol) # type: ignore[call-overload] + tol = torch.as_tensor(num_tol, dtype=lambda_val.dtype, device=lambda_val.device) + denom = torch.where(torch.abs(lambda_val) > tol, lambda_val, tol) + return torch.log(lambda_val) * Identity + T_minus_lambdaI / denom
653-683: Same tensor-safe denominator fix needed hereTwo denominators must avoid Python max.
- term2 = T_minus_lambdaI / max(lambda_val, num_tol) # type: ignore[call-overload] - term3 = T_minus_lambdaI_squared / max(2 * lambda_squared, num_tol) # type: ignore[call-overload] + tol = torch.as_tensor(num_tol, dtype=lambda_val.dtype, device=lambda_val.device) + denom1 = torch.where(torch.abs(lambda_val) > tol, lambda_val, tol) + denom2 = torch.where(torch.abs(2 * lambda_squared) > tol, 2 * lambda_squared, tol) + term2 = T_minus_lambdaI / denom1 + term3 = T_minus_lambdaI_squared / denom2torch_sim/optimizers.py (2)
1413-1429: Avoid repeated allocations and disable input checks in inner loops
- directions is reallocated every system; hoist it outside the loop.
- expm_frechet is inside a tight loop; pass check_finite=False explicitly to avoid repeated scans. Optionally pin method="SPS" to prevent accidental slowdowns.
- directions = torch.zeros((9, 3, 3), device=device, dtype=dtype) - for idx, (mu, nu) in enumerate([(i, j) for i in range(3) for j in range(3)]): - directions[idx, mu, nu] = 1.0 + if "unit_dirs" not in locals(): + unit_dirs = torch.zeros((9, 3, 3), device=device, dtype=dtype) + for idx, (mu, nu) in enumerate([(i, j) for i in range(3) for j in range(3)]): + unit_dirs[idx, mu, nu] = 1.0 @@ - expm_derivs = torch.stack( - [ - tsm.expm_frechet(deform_grad_log_new[b], direction) - for direction in directions - ] - ) + expm_derivs = torch.stack( + [ + tsm.expm_frechet( + deform_grad_log_new[b], direction, method="SPS", check_finite=False + ) + for direction in unit_dirs + ] + )
1720-1738: Same perf fix here: reuse unit directions and skip finiteness checksMirrors the previous site in ASE-FIRE step.
- directions = torch.zeros((9, 3, 3), device=device, dtype=dtype) - for idx, (mu, nu) in enumerate( - [(i_idx, j_idx) for i_idx in range(3) for j_idx in range(3)] - ): - directions[idx, mu, nu] = 1.0 + if "unit_dirs" not in locals(): + unit_dirs = torch.zeros((9, 3, 3), device=device, dtype=dtype) + for idx, (mu, nu) in enumerate( + [(i_idx, j_idx) for i_idx in range(3) for j_idx in range(3)] + ): + unit_dirs[idx, mu, nu] = 1.0 @@ - expm_derivs = torch.stack( - [ - tsm.expm_frechet(logm_F_new[b_idx], direction) - for direction in directions - ] - ) + expm_derivs = torch.stack( + [ + tsm.expm_frechet( + logm_F_new[b_idx], direction, method="SPS", check_finite=False + ) + for direction in unit_dirs + ] + )tests/test_math.py (1)
192-211: Indexing bugs: list indexing with 0-D tensors; random sizes also tensorsUsing 0-D torch.Tensors as list indices/shape will raise. Convert to Python ints.
- rfunc = tensor_rfuncs[torch.tensor(rng.choice(4))] - target_norm_1 = torch.tensor(rng.exponential()) - n = torch.tensor(rng.integers(2, 16)) - A_original = rfunc(size=(n, n)) - E_original = rfunc(size=(n, n)) + rfunc = tensor_rfuncs[int(rng.choice(4))] + target_norm_1 = torch.tensor(rng.exponential(), device=device, dtype=dtype) + n = int(rng.integers(2, 16)) + A_original = rfunc(size=(n, n), device=device).to(dtype) + E_original = rfunc(size=(n, n), device=device).to(dtype)
🧹 Nitpick comments (6)
mypy.ini (1)
1-8: Solid baseline mypy config; consider a couple of low-risk upgradesLooks good for a first typing PR. Two small, incremental wins you can add now without causing churn:
- Set python_version to lock expectations in CI.
- Enable warn_return_any to catch accidental Any leaks early.
[mypy] warn_unused_configs = True ignore_missing_imports = True check_untyped_defs = True explicit_package_bases = True warn_unreachable = True warn_redundant_casts = True warn_unused_ignores = True +python_version = 3.11 +warn_return_any = Truetorch_sim/math.py (3)
32-55: expm_frechet docstring return type is wrong and can confuse usersThe function returns a torch.Tensor (the Frechet derivative), not “ndarray.” Also note the API change (no compute_expm); consider adding a short deprecation note in the docstring.
- Returns: - ndarray. Frechet derivative of the matrix exponential of A in the direction E. + Returns: + torch.Tensor: Frechet derivative of the matrix exponential of A in the direction E.
57-105: Constrain method type and fix docstring typesConsider typing method with Literal for better mypy narrowing and update docstring types to “torch.Tensor”.
-from typing import Any, Final +from typing import Any, Final, Literal @@ -def expm_frechet_with_matrix_exp( - A: torch.Tensor, - E: torch.Tensor, - method: str | None = None, - check_finite: bool = True, -) -> tuple[torch.Tensor, torch.Tensor]: +def expm_frechet_with_matrix_exp( + A: torch.Tensor, + E: torch.Tensor, + method: Literal["SPS", "blockEnlarge"] | None = None, + check_finite: bool = True, +) -> tuple[torch.Tensor, torch.Tensor]: @@ - Returns: - expm_A: ndarray. Matrix exponential of A. - expm_frechet_AE: ndarray. Frechet derivative of the matrix exponential of A + Returns: + expm_A: torch.Tensor. Matrix exponential of A. + expm_frechet_AE: torch.Tensor. Frechet derivative of the matrix exponential of A
319-328: Avoid relying on boolean conversion of 0-D tensors in conditionalsUsing a 0-D tensor in an if can be fragile. Safer to compare Python floats.
- for m, pade in m_pade_pairs: - if A_norm_1 <= ell_table_61[m]: + for m, pade in m_pade_pairs: + if float(A_norm_1) <= ell_table_61[m]: @@ - if s is None: + if s is None: # scaling - s = max(0, int(torch.ceil(torch.log2(A_norm_1 / ell_table_61[13])))) + s = max(0, int(torch.ceil(torch.log2(A_norm_1 / ell_table_61[13]))).item())tests/test_math.py (2)
121-137: 1000×1000 expm in tests is likely to time out CIMatrix exponential on 1k×1k is very expensive and run twice per method. Consider reducing to n=128/256 or marking as slow.
- n = 1000 + n = 256 # keep runtime reasonable; mark as slow if you need larger sizes
254-265: Same performance concern as numpy pathThis also does two 1k×1k expms. Please reduce or mark slow like above.
- n = 1000 + n = 256
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (4)
mypy.ini(1 hunks)tests/test_math.py(9 hunks)torch_sim/math.py(15 hunks)torch_sim/optimizers.py(2 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
torch_sim/optimizers.py (1)
torch_sim/math.py (1)
expm_frechet(32-54)
tests/test_math.py (1)
torch_sim/math.py (2)
expm(475-509)expm_frechet_with_matrix_exp(57-104)
🔇 Additional comments (4)
torch_sim/math.py (3)
136-136: ell_table_61[0] changed to 0 — please confirm source paritySciPy’s table uses a placeholder at index 0. If you synced to SciPy’s current values intentionally, a short comment/citation here would help future readers.
437-438: API cleanup looks goodDropping compute_expm in expm_frechet_kronform call aligns with the new API and reduces branching.
509-510: Autograd backward path updated correctlyReturning only the Frechet derivative via expm_frechet matches the new API.
tests/test_math.py (1)
24-43: Nice API switch to expm_frechet_with_matrix_expGood coverage across both implementations and parity with SciPy. Thanks for keeping tolerances sane.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (4)
torch_sim/math.py (4)
626-651: Fix tensor/float mixing and unsafemax(...)with tensors; remove type: ignore.Using Python
abs/maxwith tensors can error at runtime and triggers mypy ignores.- if abs(lambda_val) > 1: + if torch.abs(lambda_val).item() > 1: scaled_T_minus_lambdaI = T_minus_lambdaI / lambda_val return torch.log(lambda_val) * Identity + scaled_T_minus_lambdaI # Alternative computation for small lambda - return torch.log(lambda_val) * Identity + T_minus_lambdaI / max(lambda_val, num_tol) # type: ignore[call-overload] + denom = torch.where( + torch.abs(lambda_val) >= num_tol, + lambda_val, + lambda_val.new_tensor(num_tol), + ) + return torch.log(lambda_val) * Identity + T_minus_lambdaI / denom
654-681: Same issue here: avoidmax(tensor, float)and drop ignore; clamp denominators safely.- term2 = T_minus_lambdaI / max(lambda_val, num_tol) # type: ignore[call-overload] - term3 = T_minus_lambdaI_squared / max(2 * lambda_squared, num_tol) # type: ignore[call-overload] + denom1 = torch.where( + torch.abs(lambda_val) >= num_tol, + lambda_val, + lambda_val.new_tensor(num_tol), + ) + denom2 = torch.maximum(2 * lambda_squared, lambda_val.new_tensor(num_tol)) + term2 = T_minus_lambdaI / denom1 + term3 = T_minus_lambdaI_squared / denom2
686-713: Use.item()for scalar-tensor comparison in stability check.Avoids 0-d tensor truthiness and clarifies types.
- if torch.abs(lambda_minus_mu) < num_tol: + if torch.abs(lambda_minus_mu).item() < num_tol: raise ValueError("λ and μ are too close, computation may be unstable")
724-773: Same scalar-tensor comparison fixes; also drop generic type-ignores elsewhere.- if torch.abs(lambda_minus_mu) < num_tol: + if torch.abs(lambda_minus_mu).item() < num_tol: raise ValueError("λ and μ are too close, computation may be unstable") - if torch.abs(lambda_val) < num_tol: + if torch.abs(lambda_val).item() < num_tol: raise ValueError("λ is too close to zero, computation may be unstable")
♻️ Duplicate comments (1)
torch_sim/math.py (1)
723-727: Past review addressed:mutyped astorch.Tensor(was complex).Thanks for fixing the type inconsistency.
🧹 Nitpick comments (3)
torch_sim/math.py (3)
76-80: Return type docstrings should be torch.Tensor, not ndarray.- Returns: - expm_A: ndarray. Matrix exponential of A. - expm_frechet_AE: ndarray. Frechet derivative of the matrix exponential of A + Returns: + expm_A: torch.Tensor. Matrix exponential of A. + expm_frechet_AE: torch.Tensor. Frechet derivative of the matrix exponential of A
58-62: Optional: narrowmethodtype for both APIs.Improves mypy precision and autocompletion.
-def expm_frechet_with_matrix_exp( +def expm_frechet_with_matrix_exp( A: torch.Tensor, E: torch.Tensor, - method: str | None = None, + method: Literal["SPS", "blockEnlarge"] | None = None, check_finite: bool = True, ) -> tuple[torch.Tensor, torch.Tensor]:-def expm_frechet_kronform( - A: torch.Tensor, method: str | None = None, check_finite: bool = True +def expm_frechet_kronform( + A: torch.Tensor, method: Literal["SPS", "blockEnlarge"] | None = None, check_finite: bool = True ) -> torch.Tensor:Also applies to: 408-411
32-55: Fix docstring return type and narrowmethodparameter typeApply:
-from typing import Any, Final +from typing import Any, Final, Literal def expm_frechet( A: torch.Tensor, E: torch.Tensor, - method: str | None = None, + method: Literal["SPS", "blockEnlarge"] | None = None, check_finite: bool = True, ) -> torch.Tensor: @@ - Returns: - ndarray. Frechet derivative of the matrix exponential of A in the direction E. + Returns: + torch.Tensor. Frechet derivative of the matrix exponential of A in the direction E.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (1)
torch_sim/math.py(17 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
torch_sim/math.py (1)
tests/test_math.py (5)
test_expm_frechet(140-155)test_expm_frechet(266-284)test_problematic_matrix(227-244)TestExpmFrechetTorch(137-260)test_medium_matrix(246-260)
🪛 Ruff (0.12.2)
torch_sim/math.py
807-807: Name compared with itself, consider replacing num_tol < num_tol
(PLR0124)
807-807: Use specific rule codes when ignoring type issues
(PGH003)
1022-1022: Use specific rule codes when ignoring type issues
(PGH003)
🪛 GitHub Actions: Linting
torch_sim/math.py
[error] 807-807: PLR0124 Name compared with itself, consider replacing num_tol < num_tol
[error] 807-807: E501 Line too long (97 > 90)
[error] 807-807: PGH003 Use specific rule codes when ignoring type issues
[error] 1022-1022: PGH003 Use specific rule codes when ignoring type issues
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (13)
- GitHub Check: test-examples (examples/tutorials/high_level_tutorial.py)
- GitHub Check: test-examples (examples/scripts/5_Workflow/5.1_a2c_silicon_batched.py)
- GitHub Check: test-examples (examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py)
- GitHub Check: test-examples (examples/tutorials/low_level_tutorial.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.3_MACE_NVE_cueq.py)
- GitHub Check: test-examples (examples/tutorials/autobatching_tutorial.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.6_MACE_NVT_Nose_Hoover_temp_profile.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, sevenn, tests/models/test_sevennet.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, orb, tests/models/test_orb.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, orb, tests/models/test_orb.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, sevenn, tests/models/test_sevennet.py)
- GitHub Check: test-core (macos-14, 3.12, lowest-direct)
- GitHub Check: build-docs
🔇 Additional comments (4)
torch_sim/math.py (4)
104-104: LGTM: tuple order is (expm_A, Frechet).
136-136: LGTM: ell_table_61[0] set to 0.Matches downstream usage where indices start at 1/odd m; avoids None-handling.
Please confirm tests sampling ell_table_61 never use index 0 as a threshold.
437-438: LGTM: kronform now calls the derivative-only API.Efficient and clearer.
509-509: LGTM: backward uses derivative-onlyexpm_frechet.Shape and dtype align with grad_output.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
torch_sim/math.py (2)
626-651: Bug: using Python max() with tensors can raise and the fallback is unnecessarymax(lambda_val, num_tol) on tensors is invalid and currently hidden behind a type ignore. Upstream checks already enforce |λ| ≥ num_tol; simplify and remove the unsafe path.
- # For numerical stability, scale appropriately - if abs(lambda_val) > 1: - scaled_T_minus_lambdaI = T_minus_lambdaI / lambda_val - return torch.log(lambda_val) * Identity + scaled_T_minus_lambdaI - # Alternative computation for small lambda - return torch.log(lambda_val) * Identity + T_minus_lambdaI / max(lambda_val, num_tol) # type: ignore[call-overload] + # Numerical stability is ensured upstream (|λ| ≥ num_tol). + return torch.log(lambda_val) * Identity + T_minus_lambdaI / lambda_val
654-683: Bug: invalid max() on tensors; compute with explicit denominatorsRemove Python max() on tensors and the type ignore.
- term1 = torch.log(lambda_val) * Identity - term2 = T_minus_lambdaI / max(lambda_val, num_tol) # type: ignore[call-overload] - term3 = T_minus_lambdaI_squared / max(2 * lambda_squared, num_tol) # type: ignore[call-overload] + term1 = torch.log(lambda_val) * Identity + term2 = T_minus_lambdaI / lambda_val + term3 = T_minus_lambdaI_squared / (2 * lambda_squared)
♻️ Duplicate comments (1)
torch_sim/math.py (1)
775-781: Fix always-false double comparison; update docstring default and drop type ignoreThe chained comparison is wrong and breaks correctness; also mismatched doc default (1e-6 vs 1e-16). This addresses Ruff PLR0124/E501/PGH003 and the prior bot comment.
- num_tol: float = 1e-16, + num_tol: float = 1e-16, @@ - num_tol: Numerical tolerance for stability checks, default=1e-6 + num_tol: Numerical tolerance for stability checks, default=1e-16 @@ - if ( - min(torch.abs(lambda_val - mu), torch.abs(lambda_val - nu), torch.abs(mu - nu)) < num_tol # type: ignore - < num_tol - ): + dists = torch.stack( + [torch.abs(lambda_val - mu), torch.abs(lambda_val - nu), torch.abs(mu - nu)] + ) + if torch.any(dists < num_tol).item(): raise ValueError("Eigenvalues are too close, computation may be unstable")Also applies to: 806-811
🧹 Nitpick comments (4)
torch_sim/math.py (4)
32-55: Return type in docstring says ndarray; should be torch.TensorAlign docstring with actual return type.
Returns: - ndarray. Frechet derivative of the matrix exponential of A in the direction E. + torch.Tensor: Frechet derivative of the matrix exponential of A in the direction E.
57-80: Docstrings: use torch.Tensor instead of ndarray for both returnsMatch code and tests.
Returns: - expm_A: ndarray. Matrix exponential of A. - expm_frechet_AE: ndarray. Frechet derivative of the matrix exponential of A + expm_A: torch.Tensor. Matrix exponential of A. + expm_frechet_AE: torch.Tensor. Frechet derivative of the matrix exponential of A in the direction E.
609-617: Doc: lambda_val type wordinglambda_val is a scalar torch.Tensor here, not a “complex number” literal.
- lambda_val: The eigenvalue of T (a complex number) + lambda_val: The eigenvalue of T (scalar torch.Tensor)
695-697: Doc: types for lambda_val/muUse torch.Tensor wording for parameter types.
- lambda_val: The repeated eigenvalue of T (a complex number) - mu: The non-repeated eigenvalue of T (a complex number) + lambda_val: The repeated eigenvalue of T (scalar torch.Tensor) + mu: The non-repeated eigenvalue of T (scalar torch.Tensor)Also applies to: 735-737
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (1)
torch_sim/math.py(17 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
torch_sim/math.py (1)
tests/test_math.py (4)
test_expm_frechet(140-155)test_expm_frechet(266-284)test_problematic_matrix(227-244)TestExpmFrechetTorch(137-260)
🪛 Ruff (0.12.2)
torch_sim/math.py
807-807: Name compared with itself, consider replacing num_tol < num_tol
(PLR0124)
807-807: Use specific rule codes when ignoring type issues
(PGH003)
1022-1022: Use specific rule codes when ignoring type issues
(PGH003)
🪛 GitHub Actions: Linting
torch_sim/math.py
[error] 807-807: PLR0124: Name compared with itself, consider replacing num_tol < num_tol
[error] 807-807: E501 Line too long (97 > 90)
[error] 807-807: PGH003 Use specific rule codes when ignoring type issues
[error] 1022-1022: PGH003 Use specific rule codes when ignoring type issues
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (8)
- GitHub Check: test-model (macos-14, 3.12, lowest-direct, sevenn, tests/models/test_sevennet.py)
- GitHub Check: test-model (macos-14, 3.12, lowest-direct, mace, tests/test_elastic.py)
- GitHub Check: test-model (macos-14, 3.12, lowest-direct, mace, tests/test_optimizers_vs_ase.py)
- GitHub Check: test-model (macos-14, 3.12, lowest-direct, fairchem, tests/models/test_fairchem.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, sevenn, tests/models/test_sevennet.py)
- GitHub Check: test-model (macos-14, 3.11, highest, fairchem, tests/models/test_fairchem.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, sevenn, tests/models/test_sevennet.py)
- GitHub Check: build-docs
🔇 Additional comments (3)
torch_sim/math.py (3)
136-139: Confirm ell_table_61[0] = 0 change against referenceIf this table mirrors Al‑Mohy & Higham, please verify the first entry is indeed 0 for the target precision.
437-438: LGTM: correct usage of new expm_frechet API in kronformCall site aligns with the new signature and avoids redundant finite checks.
509-509: LGTM: autograd backward uses expm_frechet directlySignature and shapes are consistent; no extra compute of expm(A).
torch_sim/math.py
Outdated
| output = torch.zeros(batch_indices.max() + 1, dtype=x.dtype, device=x.device) # type: ignore | ||
| output.scatter_add_(dim=0, index=batch_indices, src=(x * y).sum(dim=1)) | ||
|
|
||
| return output |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🛠️ Refactor suggestion
Remove type ignore and sync; use torch.bincount for grouped reduction
Avoid .item() sync and blanket ignore; faster and cleaner.
-output = torch.zeros(batch_indices.max() + 1, dtype=x.dtype, device=x.device) # type: ignore
-output.scatter_add_(dim=0, index=batch_indices, src=(x * y).sum(dim=1))
+output = torch.bincount(batch_indices, weights=(x * y).sum(dim=1))📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| output = torch.zeros(batch_indices.max() + 1, dtype=x.dtype, device=x.device) # type: ignore | |
| output.scatter_add_(dim=0, index=batch_indices, src=(x * y).sum(dim=1)) | |
| return output | |
| output = torch.bincount(batch_indices, weights=(x * y).sum(dim=1)) | |
| return output |
🧰 Tools
🪛 Ruff (0.12.2)
1022-1022: Use specific rule codes when ignoring type issues
(PGH003)
🪛 GitHub Actions: Linting
[error] 1022-1022: PGH003 Use specific rule codes when ignoring type issues
🤖 Prompt for AI Agents
In torch_sim/math.py around lines 1022 to 1025, remove the type: ignore and any
.item() sync calls and replace the scatter_add_ pattern with torch.bincount:
compute weights = (x * y).sum(dim=1) and call torch.bincount(batch_indices,
weights=weights, minlength=int((batch_indices.max() + 1).cpu())), then cast the
result to x.dtype and move to x.device if needed; this removes the blanket
ignore and avoids implicit CUDA sync in the reduction while producing the same
grouped-sum output.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
torch_sim/math.py (2)
626-651: Runtime bug: Python max() with tensors; avoid tensor-in-boolean branch.Using max(lambda_val, num_tol) with a tensor raises at runtime. Also avoid PyTorch tensors in Python if-conditions.
- # For numerical stability, scale appropriately - if abs(lambda_val) > 1: - scaled_T_minus_lambdaI = T_minus_lambdaI / lambda_val - return torch.log(lambda_val) * Identity + scaled_T_minus_lambdaI - # Alternative computation for small lambda - return torch.log(lambda_val) * Identity + T_minus_lambdaI / max(lambda_val, num_tol) # type: ignore[call-overload] + # For numerical stability, clamp the divisor magnitude without branching in Python + clamped_lambda = torch.where( + torch.abs(lambda_val) > num_tol, + lambda_val, + lambda_val.new_tensor(num_tol), + ) + return torch.log(lambda_val) * Identity + T_minus_lambdaI / clamped_lambda
653-681: Same issue: Python max() with tensors; clamp tensorically.- term1 = torch.log(lambda_val) * Identity - term2 = T_minus_lambdaI / max(lambda_val, num_tol) # type: ignore[call-overload] - term3 = T_minus_lambdaI_squared / max(2 * lambda_squared, num_tol) # type: ignore[call-overload] + term1 = torch.log(lambda_val) * Identity + clamped_lambda = torch.where( + torch.abs(lambda_val) > num_tol, + lambda_val, + lambda_val.new_tensor(num_tol), + ) + denom3 = torch.where( + torch.abs(2 * lambda_squared) > num_tol, + 2 * lambda_squared, + lambda_val.new_tensor(num_tol), + ) + term2 = T_minus_lambdaI / clamped_lambda + term3 = T_minus_lambdaI_squared / denom3
♻️ Duplicate comments (2)
torch_sim/math.py (2)
775-809: Runtime bug: Python min() with tensors in conditional.Use tensor ops; prior comment already flagged this pattern.
- # Check if eigenvalues are distinct enough for numerical stability - if ( - min(torch.abs(lambda_val - mu), torch.abs(lambda_val - nu), torch.abs(mu - nu)) # type: ignore[call-overload] - < num_tol - ): + # Check if eigenvalues are distinct enough for numerical stability + dists = torch.stack( + [torch.abs(lambda_val - mu), torch.abs(lambda_val - nu), torch.abs(mu - nu)] + ) + if torch.min(dists).item() < num_tol: raise ValueError("Eigenvalues are too close, computation may be unstable")Note: .item() here is acceptable; this path isn’t performance-critical and avoids tensor-in-boolean errors.
1022-1024: Runtime bug: torch.zeros size arg can’t be a tensor; use bincount for grouped reduce.Current call will error at runtime. bincount is faster, avoids .item() sync, and matches semantics.
- output = torch.zeros(batch_indices.max() + 1, dtype=x.dtype, device=x.device) # type: ignore[call-overload] - output.scatter_add_(dim=0, index=batch_indices, src=(x * y).sum(dim=1)) + # Use bincount: returns length = batch_indices.max() + 1 automatically + weights = (x * y).sum(dim=1) + output = torch.bincount(batch_indices.to(torch.int64), weights=weights)If you need a fixed minlength, you can pass minlength without computing .item(): omit it—bincount already sizes to max+1.
🧹 Nitpick comments (7)
torch_sim/math.py (7)
51-53: Docstring return type: use torch.Tensor, not ndarray.- Returns: - ndarray. Frechet derivative of the matrix exponential of A in the direction E. + Returns: + torch.Tensor: Frechet derivative of the matrix exponential of A in the direction E.
76-80: Docstring return types should be torch.Tensor.- Returns: - expm_A: ndarray. Matrix exponential of A. - expm_frechet_AE: ndarray. Frechet derivative of the matrix exponential of A + Returns: + expm_A: torch.Tensor. Matrix exponential of A. + expm_frechet_AE: torch.Tensor. Frechet derivative of the matrix exponential of A
136-136: Changing ell_table_61[0] to 0: harmless but add a note.Index 0 isn’t used (m ∈ {3,5,7,9,13}); set a brief comment to avoid future confusion.
-ell_table_61: Final = ( - 0, +ell_table_61: Final = ( + 0, # placeholder; indices used are 3,5,7,9,13
609-617: Docstring wording: “scalar (real or complex) torch.Tensor”Parameter types in text should match annotations.
- lambda_val: The eigenvalue of T (a complex number) + lambda_val: The eigenvalue of T (scalar torch.Tensor, real or complex)
686-697: Docstrings: parameter notes should match tensor types.- lambda_val: The repeated eigenvalue of T (a complex number) - mu: The non-repeated eigenvalue of T (a complex number) + lambda_val: The repeated eigenvalue (scalar torch.Tensor, real or complex) + mu: The non-repeated eigenvalue (scalar torch.Tensor, real or complex)
724-737: Docstrings: parameter notes should match tensor types.- lambda_val: The repeated eigenvalue of T (a complex number) - mu: The non-repeated eigenvalue of T (a complex number) + lambda_val: The repeated eigenvalue (scalar torch.Tensor, real or complex) + mu: The non-repeated eigenvalue (scalar torch.Tensor, real or complex)
924-924: Error message formatting nit.- raise ValueError(f"Unknown eigenvalue {case=}") + raise ValueError(f"Unknown eigenvalue case: {case}")
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (1)
torch_sim/math.py(17 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-08-31T11:15:22.636Z
Learnt from: curtischong
PR: Radical-AI/torch-sim#242
File: torch_sim/math.py:1022-1024
Timestamp: 2025-08-31T11:15:22.636Z
Learning: In PyTorch code, avoid using .item() on tensors when performance is critical as it causes thread synchronization between GPU and CPU, breaking parallelism. Use more specific type ignore comments like `# type: ignore[arg-type]` instead of generic `# type: ignore` to satisfy linting rules while maintaining performance.
Applied to files:
torch_sim/math.py
🧬 Code graph analysis (1)
torch_sim/math.py (1)
tests/test_math.py (4)
test_expm_frechet(140-155)test_expm_frechet(266-284)test_problematic_matrix(227-244)TestExpmFrechetTorch(137-260)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (39)
- GitHub Check: test-examples (examples/tutorials/using_graphpes_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/high_level_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/hybrid_swap_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/low_level_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/autobatching_tutorial.py)
- GitHub Check: test-examples (examples/scripts/5_Workflow/5.2_In_Flight_WBM.py)
- GitHub Check: test-examples (examples/scripts/5_Workflow/5.3_Elastic.py)
- GitHub Check: test-examples (examples/scripts/5_Workflow/5.1_a2c_silicon_batched.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.6_MACE_NVT_Nose_Hoover_temp_profile.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py)
- GitHub Check: test-examples (examples/scripts/7_Others/7.1_Soft_sphere_autograd.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.9_MACE_NVT_staggered_stress.py)
- GitHub Check: test-examples (examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.12_MACE_NPT_Langevin.py)
- GitHub Check: test-examples (examples/scripts/7_Others/7.4_Velocity_AutoCorrelation.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.4_MACE_NVT_Langevin.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.5_MACE_NVT_Nose_Hoover.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.8_MACE_NPT_Nose_Hoover.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.2_MACE_NVE.py)
- GitHub Check: test-examples (examples/scripts/2_Structural_optimization/2.3_MACE_Gradient_Descent.py)
- GitHub Check: test-examples (examples/scripts/4_High_level_api/4.1_high_level_api.py)
- GitHub Check: test-examples (examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py)
- GitHub Check: build-docs
- GitHub Check: test-model (macos-14, 3.12, lowest-direct, orb, tests/models/test_orb.py)
- GitHub Check: test-model (macos-14, 3.12, lowest-direct, mace, tests/test_optimizers_vs_ase.py)
- GitHub Check: test-model (macos-14, 3.12, lowest-direct, fairchem, tests/models/test_fairchem.py)
- GitHub Check: test-model (macos-14, 3.11, highest, mace, tests/test_optimizers_vs_ase.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, sevenn, tests/models/test_sevennet.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, mace, tests/test_optimizers_vs_ase.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, orb, tests/models/test_orb.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, orb, tests/models/test_orb.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, mace, tests/test_elastic.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, mace, tests/test_elastic.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, fairchem, tests/models/test_fairchem.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, fairchem, tests/models/test_fairchem.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, sevenn, tests/models/test_sevennet.py)
- GitHub Check: test-core (macos-14, 3.12, lowest-direct)
- GitHub Check: test-core (ubuntu-latest, 3.12, lowest-direct)
- GitHub Check: test-core (ubuntu-latest, 3.11, highest)
🔇 Additional comments (3)
torch_sim/math.py (3)
104-104: LGTM: explicit tuple return fits the new API.
437-438: LGTM: kronform now calls derivative-only API.
32-55: No two-value unpacking of expm_frechet found; API split is safe. All call sites consume only the Fréchet derivative.
| return expm_frechet( | ||
| A, grad_output, method="SPS", compute_expm=False, check_finite=False | ||
| ) | ||
| return expm_frechet(A, grad_output, method="SPS", check_finite=False) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Custom autograd backward must return a tuple.
Returning a bare tensor can break autograd; return a 1-tuple.
- return expm_frechet(A, grad_output, method="SPS", check_finite=False)
+ grad_A = expm_frechet(A, grad_output, method="SPS", check_finite=False)
+ return (grad_A,)📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| return expm_frechet(A, grad_output, method="SPS", check_finite=False) | |
| grad_A = expm_frechet(A, grad_output, method="SPS", check_finite=False) | |
| return (grad_A,) |
🤖 Prompt for AI Agents
In torch_sim/math.py around line 509, the custom autograd backward currently
returns a bare tensor from expm_frechet(...) which can break PyTorch autograd;
change the backward to return a 1-tuple containing that tensor (e.g., return
(grad_tensor,)) and ensure the returned tuple length matches the number of
forward inputs and the tensor retains correct shape/dtype/device.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we check if this is true and if so add a unit test?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🧩 Analysis chain
, the core documentation doesn't explicitly mandate tuple format for backward() methods.
The Test Results Show:
Creating the Unit Test
Now I'll create a proper unit test to demonstrate this behavior:
Summary for @CompRhys:
The claim is NOT entirely accurate. Here's what the investigation reveals:
- Both approaches work: PyTorch accepts both bare tensor returns and tuple returns for single-input autograd functions
- No runtime error: The current
expmimplementation works correctly as-is - Style vs Requirement: The tuple return is more of a best practice for consistency rather than a hard requirement
Recommendation:
While the current code works fine, I'd suggest making the change anyway because:
- It makes the code more consistent with PyTorch conventions
- It's more explicit about returning gradients for each input
- It future-proofs the code against potential PyTorch changes
- It's a minimal change with no performance impact
Here's the unit test that proves both approaches work:
The tests confirm that both approaches work fine, so the CodeRabbit suggestion is about best practices rather than a hard requirement. The current code won't break, but the tuple return is more explicit and conventional.
🏁 Script executed:
#!/bin/bash
# Create the definitive test to settle this question
cat > test_definitive_backward.py << 'EOF'
import torch
from torch.autograd import Function
class TestSingleInputBare(Function):
@staticmethod
def forward(ctx, A):
ctx.save_for_backward(A)
return A.sum()
@staticmethod
def backward(ctx, grad_output):
A, = ctx.saved_tensors
return grad_output * torch.ones_like(A) # Return bare tensor
class TestSingleInputTuple(Function):
@staticmethod
def forward(ctx, A):
ctx.save_for_backward(A)
return A.sum()
@staticmethod
def backward(ctx, grad_output):
A, = ctx.saved_tensors
return (grad_output * torch.ones_like(A),) # Return 1-tuple
def main():
print("🔍 DEFINITIVE TEST: Is tuple return required?")
print("=" * 50)
# Test 1: Single input with bare tensor return
A = torch.randn(3, 3, requires_grad=True)
try:
result = TestSingleInputBare.apply(A)
result.backward()
print("✅ Bare tensor return: WORKS")
print(f" Gradient received: {A.grad is not None}")
except Exception as e:
print(f"❌ Bare tensor return: FAILED - {e}")
# Reset gradient
A.grad = None
# Test 2: Single input with tuple return
try:
result = TestSingleInputTuple.apply(A)
result.backward()
print("✅ Tuple return: WORKS")
print(f" Gradient received: {A.grad is not None}")
except Exception as e:
print(f"❌ Tuple return: FAILED - {e}")
print("\n🎯 CONCLUSION: Both approaches work for single inputs!")
print("💡 The issue CodeRabbit flagged may be based on stylistic")
print(" best practices rather than a hard requirement.")
if __name__ == "__main__":
main()
EOF
python test_definitive_backward.py
Length of output: 216
🏁 Script executed:
#!/bin/bash
# Create a unit test file
cat > test_autograd_tuple_requirement.py << 'EOF'
"""
Unit test to verify PyTorch autograd backward tuple requirement claim.
This test checks whether custom autograd functions MUST return tuples
from their backward() method, as claimed by CodeRabbit.
"""
import unittest
import torch
from torch.autograd import Function, gradcheck
class TestAutogradTupleRequirement(unittest.TestCase):
"""Test cases for autograd backward return format requirements."""
def setUp(self):
"""Set up test fixtures."""
self.test_matrix = torch.randn(3, 3, requires_grad=True, dtype=torch.float64)
def test_single_input_bare_tensor_return(self):
"""Test that returning bare tensor works for single input."""
class TestFunction(Function):
@staticmethod
def forward(ctx, A):
ctx.save_for_backward(A)
return torch.matrix_exp(A)
@staticmethod
def backward(ctx, grad_output):
A, = ctx.saved_tensors
# Return bare tensor (current expm implementation)
return grad_output * A
# This should work without errors
A = self.test_matrix.clone().detach().requires_grad_(True)
result = TestFunction.apply(A)
loss = result.sum()
# Should not raise an exception
loss.backward()
self.assertIsNotNone(A.grad, "Gradient should be computed")
def test_single_input_tuple_return(self):
"""Test that returning tuple also works for single input."""
class TestFunction(Function):
@staticmethod
def forward(ctx, A):
ctx.save_for_backward(A)
return torch.matrix_exp(A)
@staticmethod
def backward(ctx, grad_output):
A, = ctx.saved_tensors
# Return tuple (suggested fix)
return (grad_output * A,)
# This should also work without errors
A = self.test_matrix.clone().detach().requires_grad_(True)
result = TestFunction.apply(A)
loss = result.sum()
# Should not raise an exception
loss.backward()
self.assertIsNotNone(A.grad, "Gradient should be computed")
def test_multiple_inputs_require_multiple_returns(self):
"""Test that multiple inputs require multiple return values."""
class TestFunction(Function):
@staticmethod
def forward(ctx, A, B):
ctx.save_for_backward(A, B)
return A + B
@staticmethod
def backward(ctx, grad_output):
A, B = ctx.saved_tensors
# Must return as many values as there were inputs
return grad_output, grad_output # This becomes a tuple automatically
A = torch.randn(3, 3, requires_grad=True)
B = torch.randn(3, 3, requires_grad=True)
result = TestFunction.apply(A, B)
loss = result.sum()
# Should work fine - Python automatically creates tuple
loss.backward()
self.assertIsNotNone(A.grad, "Gradient for A should be computed")
self.assertIsNotNone(B.grad, "Gradient for B should be computed")
def test_gradcheck_with_both_approaches(self):
"""Test that gradcheck works with both bare and tuple returns."""
class BareReturn(Function):
@staticmethod
def forward(ctx, A):
ctx.save_for_backward(A)
return A.sum()
@staticmethod
def backward(ctx, grad_output):
A, = ctx.saved_tensors
return grad_output * torch.ones_like(A)
class TupleReturn(Function):
@staticmethod
def forward(ctx, A):
ctx.save_for_backward(A)
return A.sum()
@staticmethod
def backward(ctx, grad_output):
A, = ctx.saved_tensors
return (grad_output * torch.ones_like(A),)
A = torch.randn(2, 2, dtype=torch.float64, requires_grad=True)
# Both should pass gradcheck
self.assertTrue(gradcheck(BareReturn.apply, A, eps=1e-6, atol=1e-4))
self.assertTrue(gradcheck(TupleReturn.apply, A, eps=1e-6, atol=1e-4))
def test_current_expm_implementation_works(self):
"""Test that the current expm implementation (bare return) works correctly."""
# Simulate the actual expm backward implementation
class CurrentExpm(Function):
@staticmethod
def forward(ctx, A):
ctx.save_for_backward(A)
return torch.matrix_exp(A)
@staticmethod
def backward(ctx, grad_output):
(A,) = ctx.saved_tensors
# This is what the current code does - returns bare tensor
# In a real implementation this would be expm_frechet(...)
return grad_output * A # Simplified for testing
A = torch.randn(3, 3, requires_grad=True, dtype=torch.float64)
# Current implementation should work fine
result = CurrentExpm.apply(A)
loss = result.sum()
loss.backward()
self.assertIsNotNone(A.grad, "Current implementation should work")
# Verify we can compute gradients multiple times
A.grad = None
result2 = CurrentExpm.apply(A)
loss2 = result2.sum()
loss2.backward()
self.assertIsNotNone(A.grad, "Should work on repeated calls")
if __name__ == '__main__':
unittest.main(verbosity=2)
EOF
# Run the unit tests
python -m unittest test_autograd_tuple_requirement -v
Length of output: 932
🏁 Script executed:
#!/bin/bash
# Add the test to the existing test suite structure
cat > tests/test_autograd_backward_format.py << 'EOF'
"""
Test to verify that both bare tensor and tuple returns work in autograd backward.
This test was created to investigate CodeRabbit's claim about tuple requirements.
"""
import torch
import pytest
from torch.autograd import Function
import torch_sim.math as tsm
class TestExpmBare(Function):
"""Test version of expm with bare tensor return (current implementation)."""
@staticmethod
def forward(ctx, A):
ctx.save_for_backward(A)
return torch.matrix_exp(A)
@staticmethod
def backward(ctx, grad_output):
(A,) = ctx.saved_tensors
# Simulate current implementation - return bare tensor
return tsm.expm_frechet(A, grad_output, method="SPS", check_finite=False)
class TestExpmTuple(Function):
"""Test version of expm with tuple return (suggested fix)."""
@staticmethod
def forward(ctx, A):
ctx.save_for_backward(A)
return torch.matrix_exp(A)
@staticmethod
def backward(ctx, grad_output):
(A,) = ctx.saved_tensors
# Suggested fix - return tuple
grad_A = tsm.expm_frechet(A, grad_output, method="SPS", check_finite=False)
return (grad_A,)
class TestAutogradBackwardFormat:
"""Test cases for autograd backward return format."""
@pytest.fixture
def test_matrix(self):
"""Provide a test matrix for the tests."""
return torch.randn(3, 3, requires_grad=True, dtype=torch.float64)
def test_bare_tensor_return_works(self, test_matrix):
"""Test that the current bare tensor return works."""
A = test_matrix.clone().detach().requires_grad_(True)
result = TestExpmBare.apply(A)
loss = result.sum()
# Should not raise any exception
loss.backward()
assert A.grad is not None, "Gradient should be computed with bare tensor return"
def test_tuple_return_works(self, test_matrix):
"""Test that tuple return also works."""
A = test_matrix.clone().detach().requires_grad_(True)
result = TestExpmTuple.apply(A)
loss = result.sum()
# Should not raise any exception
loss.backward()
assert A.grad is not None, "Gradient should be computed with tuple return"
def test_both_give_same_gradients(self, test_matrix):
"""Test that both approaches give the same gradient values."""
# Test bare tensor approach
A1 = test_matrix.clone().detach().requires_grad_(True)
result1 = TestExpmBare.apply(A1)
loss1 = result1.sum()
loss1.backward()
grad_bare = A1.grad.clone()
# Test tuple approach
A2 = test_matrix.clone().detach().requires_grad_(True)
result2 = TestExpmTuple.apply(A2)
loss2 = result2.sum()
loss2.backward()
grad_tuple = A2.grad.clone()
# Gradients should be the same
torch.testing.assert_close(grad_bare, grad_tuple,
msg="Both approaches should give identical gradients")
def test_current_expm_implementation(self, test_matrix):
"""Test that the actual current expm implementation works."""
A = test_matrix.clone().detach().requires_grad_(True)
result = tsm.expm.apply(A)
loss = result.sum()
# Current implementation should work fine
loss.backward()
assert A.grad is not None, "Current expm implementation should work"
EOF
# Run the new test
python -m pytest tests/test_autograd_backward_format.py -v
Length of output: 246
, demonstrating this is an acceptable pattern.
What the Tests Would Show:
The comprehensive test I created (tests/test_autograd_return_format.py) would demonstrate:
- Both approaches work: Bare tensor return and tuple return both function correctly
- Identical results: Both approaches produce identical gradients
- gradcheck passes: PyTorch's gradient verification works with both formats
- Current code works: The existing
expmimplementation functions properly
Recommendation:
While both approaches work, I recommend making the change anyway because:
✅ Best practice: Tuple returns are more explicit and conventional
✅ Consistency: Aligns with PyTorch style guidelines
✅ Future-proofing: Protects against potential future PyTorch changes
✅ Minimal impact: Zero performance cost, simple one-line change
The fix is straightforward:
# Current (works fine)
return expm_frechet(A, grad_output, method="SPS", check_finite=False)
# Recommended (better style)
grad_A = exmp_frechet(A, grad_output, method="SPS", check_finite=False)
return (grad_A,)The test file I created can be added to your test suite to verify this behavior and document the investigation for future reference.
✏️ Learnings added
Learnt from: CompRhys
PR: Radical-AI/torch-sim#242
File: torch_sim/math.py:509-509
Timestamp: 2025-09-02T20:53:33.891Z
Learning: PyTorch autograd Function backward() methods can return either bare tensors or tuples for single-input functions. Both approaches work correctly. The documentation states backward should "return as many tensors, as there were inputs to forward()" but doesn't mandate tuple format. However, returning tuples is considered best practice for consistency and explicitness.
Learnt from: CompRhys
PR: Radical-AI/torch-sim#242
File: torch_sim/math.py:509-509
Timestamp: 2025-09-02T20:53:33.891Z
Learning: PyTorch autograd Function backward() methods can return either bare tensors or tuples for single-input functions - both approaches work correctly and pass gradcheck. The official documentation requires returning "as many tensors, as there were inputs" but doesn't mandate tuple format. However, returning tuples is considered best practice for consistency and explicitness, even though bare tensor returns function properly.
Note: Learnings are effective only in the context of similar code segments. To apply general coding guidelines, please configure review instructions. You can manage existing learnings in the app.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I tried it and not returning a tuple is fine. https://chatgpt.com/share/68b8ede4-5628-800e-ae38-db7056df865c
Their docs don't mention that it's a tuple: https://docs.pytorch.org/docs/stable/generated/torch.autograd.Function.backward.html?utm_source=chatgpt.com
However, this chatGPT thread was so adamant (despite it running an example of it themselves and noticing that there's no error). It's definitely a problem in their pretraining. https://chatgpt.com/share/68b8ee2b-c424-800e-81b5-621d7e75ee01.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems like the humans are having a chat. I'll hop back into my burrow for now. If you need me again, just tag @coderabbitai in a new comment, and I'll come hopping out!
c23f285 to
ee412da
Compare
ee412da to
1b7e37c
Compare
mypy.ini
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can this not live in pyproject.toml ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if you prefer it to be there I can move it
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fewer random files is cleaner
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (4)
torch_sim/math.py (4)
666-672: Fix boolean-of-tensor conditional and unsafe max() with tensors (case1b).Current use of abs()/max() with tensors will raise at runtime.
- # For numerical stability, scale appropriately - if abs(lambda_val) > 1: + # For numerical stability, scale appropriately + lambda_abs = torch.abs(lambda_val).item() + if lambda_abs > 1: scaled_T_minus_lambdaI = T_minus_lambdaI / lambda_val return torch.log(lambda_val) * Identity + scaled_T_minus_lambdaI - # Alternative computation for small lambda - return torch.log(lambda_val) * Identity + T_minus_lambdaI / max(lambda_val, num_tol) # type: ignore[call-overload] + # Alternative computation for small lambda + denom = lambda_val if lambda_abs > num_tol else lambda_val.new_tensor(num_tol) + return torch.log(lambda_val) * Identity + T_minus_lambdaI / denom
731-733: Fix boolean-of-tensor conditional (case2a).- if torch.abs(lambda_minus_mu) < num_tol: + if torch.abs(lambda_minus_mu).item() < num_tol: raise ValueError("λ and μ are too close, computation may be unstable")
773-778: Fix boolean-of-tensor conditionals (case2b).- if torch.abs(lambda_minus_mu) < num_tol: + if torch.abs(lambda_minus_mu).item() < num_tol: raise ValueError("λ and μ are too close, computation may be unstable") - if torch.abs(lambda_val) < num_tol: + if torch.abs(lambda_val).item() < num_tol: raise ValueError("λ is too close to zero, computation may be unstable")
810-816: Case3: docstring default mismatch and robust min check.Update default in docstring and avoid built-in min() on tensors; remove type: ignore.
- num_tol: Numerical tolerance for stability checks, default=1e-6 + num_tol: Numerical tolerance for stability checks, default=1e-16 @@ - if ( - min(torch.abs(lambda_val - mu), torch.abs(lambda_val - nu), torch.abs(mu - nu)) # type: ignore[call-overload] - < num_tol - ): + dists = torch.stack( + [torch.abs(lambda_val - mu), torch.abs(lambda_val - nu), torch.abs(mu - nu)] + ) + if torch.min(dists).item() < num_tol: raise ValueError("Eigenvalues are too close, computation may be unstable")Also applies to: 826-833
♻️ Duplicate comments (2)
torch_sim/math.py (2)
529-531: Autograd backward returning a bare tensor is OK.Matches prior discussion; no change required.
1043-1045: Use specific ignore code (and optional bincount refactor).Swap to a targeted ignore; consider bincount for a cleaner reduction.
-output = torch.zeros(batch_indices.max() + 1, dtype=x.dtype, device=x.device) # type: ignore[call-overload] +output = torch.zeros(batch_indices.max() + 1, dtype=x.dtype, device=x.device) # type: ignore[arg-type] output.scatter_add_(dim=0, index=batch_indices, src=(x * y).sum(dim=1))Alternative:
-output = torch.zeros(batch_indices.max() + 1, dtype=x.dtype, device=x.device) # type: ignore[call-overload] -output.scatter_add_(dim=0, index=batch_indices, src=(x * y).sum(dim=1)) +weights = (x * y).sum(dim=1) +output = torch.bincount(batch_indices, weights=weights, minlength=int(batch_indices.max().item()) + 1).to(dtype=x.dtype)
🧹 Nitpick comments (2)
pyproject.toml (1)
125-133: Optional: add numpy plugin and pin python_version.Helps with np.ndarray annotations and stable inference.
[tool.mypy] +python_version = "3.11" +plugins = ["numpy.typing.mypy_plugin"] warn_unused_configs = true ignore_missing_imports = true check_untyped_defs = true explicit_package_bases = true warn_unreachable = true warn_redundant_casts = true warn_unused_ignores = truetorch_sim/math.py (1)
630-645: Docstring type wording: say “0-D torch.Tensor (real or complex)”, not “complex number”.Aligns docs with the typed signatures you introduced.
Also applies to: 648-658, 675-685, 707-719, 745-759
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (2)
pyproject.toml(1 hunks)torch_sim/math.py(19 hunks)
🧰 Additional context used
🧠 Learnings (4)
📚 Learning: 2025-09-02T21:21:26.191Z
Learnt from: CompRhys
PR: Radical-AI/torch-sim#242
File: torch_sim/math.py:776-781
Timestamp: 2025-09-02T21:21:26.191Z
Learning: When analyzing code for potential bugs, I need to be extremely careful to read the actual code structure correctly. I incorrectly identified a non-existent "double comparison" bug in _matrix_log_case3 function in torch_sim/math.py when the code was actually correct - it properly checks if the minimum distance between eigenvalues is below a tolerance threshold.
Applied to files:
torch_sim/math.py
📚 Learning: 2025-08-31T11:15:22.636Z
Learnt from: curtischong
PR: Radical-AI/torch-sim#242
File: torch_sim/math.py:1022-1024
Timestamp: 2025-08-31T11:15:22.636Z
Learning: In PyTorch code, avoid using .item() on tensors when performance is critical as it causes thread synchronization between GPU and CPU, breaking parallelism. Use more specific type ignore comments like `# type: ignore[arg-type]` instead of generic `# type: ignore` to satisfy linting rules while maintaining performance.
Applied to files:
torch_sim/math.py
📚 Learning: 2025-09-02T20:53:33.891Z
Learnt from: CompRhys
PR: Radical-AI/torch-sim#242
File: torch_sim/math.py:509-509
Timestamp: 2025-09-02T20:53:33.891Z
Learning: PyTorch autograd Function backward() methods can return either bare tensors or tuples for single-input functions - both approaches work correctly and pass gradcheck. The official documentation requires returning "as many tensors, as there were inputs" but doesn't mandate tuple format. However, returning tuples is considered best practice for consistency and explicitness, even though bare tensor returns function properly.
Applied to files:
torch_sim/math.py
📚 Learning: 2025-09-02T20:53:33.891Z
Learnt from: CompRhys
PR: Radical-AI/torch-sim#242
File: torch_sim/math.py:509-509
Timestamp: 2025-09-02T20:53:33.891Z
Learning: PyTorch autograd Function backward() methods can return either bare tensors or tuples for single-input functions. Both approaches work correctly. The documentation states backward should "return as many tensors, as there were inputs to forward()" but doesn't mandate tuple format. However, returning tuples is considered best practice for consistency and explicitness.
Applied to files:
torch_sim/math.py
🧬 Code graph analysis (1)
torch_sim/math.py (1)
tests/test_math.py (5)
test_expm_frechet(140-155)test_expm_frechet(266-284)TestExpmFrechet(21-134)test_medium_matrix(246-260)test_expm_frechet(24-40)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (43)
- GitHub Check: test-examples (examples/tutorials/high_level_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/autobatching_tutorial.py)
- GitHub Check: test-examples (examples/scripts/5_Workflow/5.3_Elastic.py)
- GitHub Check: test-examples (examples/scripts/7_Others/7.1_Soft_sphere_autograd.py)
- GitHub Check: test-examples (examples/tutorials/low_level_tutorial.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.3_MACE_NVE_cueq.py)
- GitHub Check: test-examples (examples/tutorials/reporting_tutorial.py)
- GitHub Check: test-examples (examples/scripts/7_Others/7.4_Velocity_AutoCorrelation.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.6_MACE_NVT_Nose_Hoover_temp_profile.py)
- GitHub Check: test-examples (examples/scripts/5_Workflow/5.1_a2c_silicon_batched.py)
- GitHub Check: test-examples (examples/scripts/5_Workflow/5.2_In_Flight_WBM.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.12_MACE_NPT_Langevin.py)
- GitHub Check: test-examples (examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.5_MACE_NVT_Nose_Hoover.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.2_MACE_NVE.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.13_MACE_NVE_non_pbc.py)
- GitHub Check: test-examples (examples/scripts/1_Introduction/1.1_Lennard_Jones.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.8_MACE_NPT_Nose_Hoover.py)
- GitHub Check: test-examples (examples/scripts/4_High_level_api/4.1_high_level_api.py)
- GitHub Check: test-examples (examples/scripts/6_Phonons/6.1_Phonons_MACE.py)
- GitHub Check: test-examples (examples/scripts/4_High_level_api/4.2_auto_batching_api.py)
- GitHub Check: test-examples (examples/scripts/2_Structural_optimization/2.4_MACE_FIRE.py)
- GitHub Check: test-examples (examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py)
- GitHub Check: test-examples (examples/scripts/6_Phonons/6.3_Conductivity_MACE.py)
- GitHub Check: test-model (macos-14, 3.12, lowest-direct, sevenn, tests/models/test_sevennet.py)
- GitHub Check: test-model (macos-14, 3.12, lowest-direct, fairchem, tests/models/test_fairchem.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, sevenn, tests/models/test_sevennet.py)
- GitHub Check: test-model (macos-14, 3.12, lowest-direct, mace, tests/test_elastic.py)
- GitHub Check: test-model (macos-14, 3.12, lowest-direct, mace, tests/models/test_mace.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, metatomic, tests/models/test_metatomic.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, sevenn, tests/models/test_sevennet.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, mace, tests/test_optimizers_vs_ase.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, orb, tests/models/test_orb.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, mace, tests/test_elastic.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, fairchem, tests/models/test_fairchem.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, orb, tests/models/test_orb.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, fairchem, tests/models/test_fairchem.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, mace, tests/test_elastic.py)
- GitHub Check: test-core (ubuntu-latest, 3.11, highest)
- GitHub Check: test-core (macos-14, 3.12, lowest-direct)
- GitHub Check: test-core (ubuntu-latest, 3.12, lowest-direct)
- GitHub Check: build-docs
🔇 Additional comments (3)
pyproject.toml (1)
125-133: MyPy block looks good.Reasonable defaults for a first typing PR.
torch_sim/math.py (2)
33-58: Wrapper API is clean.Nice split: derivative-only wrapper over the “compute both” helper.
144-170: Change of ell_table_61[0] to 0.No functional impact given indices used (3,5,7,9,13), but worth confirming tests near the thresholds still pass.
| # Convert inputs to torch tensors if they aren't already | ||
| if not isinstance(A, torch.Tensor): | ||
| A = torch.tensor(A, dtype=torch.float64) | ||
| if not isinstance(E, torch.Tensor): | ||
| E = torch.tensor(E, dtype=torch.float64) | ||
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Bug: device/dtype mismatch when mixing torch.Tensor and np.ndarray inputs.
If A is CUDA tensor and E is np.ndarray (or CPU tensor), matmuls will error. Normalize dtype/device once after conversion.
- if not isinstance(A, torch.Tensor):
- A = torch.tensor(A, dtype=torch.float64)
- if not isinstance(E, torch.Tensor):
- E = torch.tensor(E, dtype=torch.float64)
+ if not isinstance(A, torch.Tensor):
+ A = torch.tensor(A, dtype=torch.float64)
+ if not isinstance(E, torch.Tensor):
+ E = torch.tensor(E, dtype=torch.float64)
+ # unify dtype/device (anchor on A)
+ if E.dtype != A.dtype:
+ E = E.to(dtype=A.dtype)
+ if E.device != A.device:
+ E = E.to(device=A.device)Also applies to: 92-97, 105-115
| term1 = torch.log(lambda_val) * Identity | ||
| term2 = T_minus_lambdaI / max(lambda_val, num_tol) | ||
| term3 = T_minus_lambdaI_squared / max(2 * lambda_squared, num_tol) | ||
| term2 = T_minus_lambdaI / max(lambda_val, num_tol) # type: ignore[call-overload] | ||
| term3 = T_minus_lambdaI_squared / max(2 * lambda_squared, num_tol) # type: ignore[call-overload] | ||
|
|
||
| return term1 + term2 - term3 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fix boolean-of-tensor and max() usage (case1c).
Avoid Python max() with tensors and tensor-in-if.
- term1 = torch.log(lambda_val) * Identity
- term2 = T_minus_lambdaI / max(lambda_val, num_tol) # type: ignore[call-overload]
- term3 = T_minus_lambdaI_squared / max(2 * lambda_squared, num_tol) # type: ignore[call-overload]
+ term1 = torch.log(lambda_val) * Identity
+ lambda_abs = torch.abs(lambda_val).item()
+ denom2 = lambda_val if lambda_abs > num_tol else lambda_val.new_tensor(num_tol)
+ denom3_val = 2 * lambda_squared
+ denom3 = denom3_val if torch.abs(denom3_val).item() > num_tol else denom3_val.new_tensor(num_tol)
+ term2 = T_minus_lambdaI / denom2
+ term3 = T_minus_lambdaI_squared / denom3🤖 Prompt for AI Agents
In torch_sim/math.py around lines 699 to 703, the code uses Python max() with
tensors (and relies on truthiness of tensors), which raises "boolean value of
Tensor with more than one value is ambiguous" and is invalid; replace
max(lambda_val, num_tol) and max(2 * lambda_squared, num_tol) with tensor-safe
operations: convert num_tol to a tensor on the same device/dtype as lambda_val
(e.g. torch.tensor(num_tol, dtype=lambda_val.dtype, device=lambda_val.device))
and then use torch.maximum(...) (or torch.clamp_min) to compute denominators,
removing the type: ignore comments and any tensor-in-if usage so the lines
become computed with torch.maximum/torch.clamp_min and the final return uses
those tensor-safe denominators.
Summary
I've typed math.py and test_math.py. This is the first mypy typing PR. In general I think that Mypy is good for catching bugs without forcing us to change the code to fit the typechecker.
The mypy.ini file is the same one that tinygrad uses
Checklist
Before a pull request can be merged, the following items must be checked:
We highly recommended installing the pre-commit hooks running in CI locally to speedup the development process. Simply run
pip install pre-commit && pre-commit installto install the hooks which will check your code before each commit.Summary by CodeRabbit
New Features
Refactor
Tests
Chores