
Conversation

@curtischong (Collaborator) commented Aug 30, 2025

Summary

I've typed math.py and test_math.py. This is the first mypy typing PR. In general, I think mypy is good at catching bugs without forcing us to change the code to fit the type checker.

The mypy.ini file is the same one that tinygrad uses.

Checklist

Before a pull request can be merged, the following items must be checked:

  • Doc strings have been added in the Google docstring format.
  • Run ruff on your code.
  • Tests have been added for any new functionality or bug fixes.

We highly recommend installing the pre-commit hooks that run in CI locally to speed up the development process. Simply run pip install pre-commit && pre-commit install to install the hooks, which will check your code before each commit.

Summary by CodeRabbit

  • New Features

    • Added expm_frechet_with_matrix_exp to compute both matrix exponential and its Frechet derivative in one call.
    • Expanded NumPy input support for expm_frechet, expm_frechet_kronform, and expm_cond.
  • Refactor

    • Streamlined API: expm_frechet now returns only the derivative and no longer accepts the compute_expm parameter.
    • Updated internal call sites and autograd path to the new API.
  • Tests

    • Updated test suite to cover both supported methods via the new API.
  • Chores

    • Introduced MyPy configuration for stricter type checking.

@cla-bot added the cla-signed (Contributor license agreement signed) label on Aug 30, 2025
@coderabbitai bot commented Aug 30, 2025

Walkthrough

Introduces a new API expm_frechet_with_matrix_exp returning (expm(A), Frechet(A,E)), refactors expm_frechet to delegate and return only the derivative, removes compute_expm flag across call sites, broadens typing to accept NumPy arrays, adjusts autograd and kronform callers, updates tests to the new API, and adds MyPy settings.
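
For orientation, a minimal usage sketch of the split API (the torch_sim.math alias and argument names follow the call sites in this PR; shapes and dtypes are illustrative assumptions, not the actual code):

import torch
import torch_sim.math as tsm  # alias used elsewhere in this PR

A = torch.randn(3, 3, dtype=torch.float64)
E = torch.randn(3, 3, dtype=torch.float64)

# New combined API: returns both expm(A) and the Frechet derivative L_expm(A)[E]
expm_A, frechet_AE = tsm.expm_frechet_with_matrix_exp(A, E, method="SPS")

# Refactored wrapper: returns only the derivative (no compute_expm flag)
frechet_only = tsm.expm_frechet(A, E, method="SPS", check_finite=False)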

Changes

  • Core math API and typing (torch_sim/math.py): Added expm_frechet_with_matrix_exp(A, E, method, check_finite) returning (expm, frechet); refactored expm_frechet to delegate and return only the derivative; removed compute_expm flag usage; updated expm_frechet_kronform, expm.backward, and expm_cond to the new signatures; expanded inputs to accept np.ndarray; changed ell_table_61 first element from None to 0; adjusted internal 3×3 log helpers’ annotations.
  • Tests updated to new API (tests/test_math.py): Replaced expm_frechet usage with expm_frechet_with_matrix_exp; iterates over method in {"SPS", "blockEnlarge"}; computes expected values with SciPy expm; maintains equality checks between methods; no gradient logic changes. A minimal test sketch follows this list.
  • Optimizers call-site tweaks (torch_sim/optimizers.py): Removed the compute_expm=False argument; calls tsm.expm_frechet(...) with the new signature in _vv_fire_step and _ase_fire_step.
  • Tooling / typing config (pyproject.toml): Added a [tool.mypy] configuration enabling stricter type-checking warnings; no runtime impact.
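
As a rough illustration of the test pattern described above (not the actual test code; the pytest parametrization, matrix size, and tolerances are assumptions), both methods can be checked against SciPy and against the derivative-only wrapper:

import numpy as np
import pytest
import torch
from scipy.linalg import expm as scipy_expm

import torch_sim.math as tsm  # alias used elsewhere in this PR


@pytest.mark.parametrize("method", ["SPS", "blockEnlarge"])
def test_matches_scipy(method):
    rng = np.random.default_rng(0)
    A = torch.tensor(rng.standard_normal((4, 4)), dtype=torch.float64)
    E = torch.tensor(rng.standard_normal((4, 4)), dtype=torch.float64)

    expm_A, frechet_AE = tsm.expm_frechet_with_matrix_exp(A, E, method=method)

    # expm(A) should agree with SciPy's matrix exponential
    torch.testing.assert_close(
        expm_A, torch.tensor(scipy_expm(A.numpy())), rtol=1e-10, atol=1e-12
    )
    # the derivative-only wrapper should agree with the combined API
    torch.testing.assert_close(tsm.expm_frechet(A, E, method=method), frechet_AE)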

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant Caller as Caller (tests/optimizers)
  participant Math as torch_sim.math
  Caller->>Math: expm_frechet_with_matrix_exp(A, E, method)
  activate Math
  Note over Math: Validate inputs, shapes, types (torch/NumPy)
  Math-->>Caller: (expm(A), L_expm(A)[E])
  deactivate Math
sequenceDiagram
  autonumber
  participant PyTorch as Autograd Engine
  participant Math as torch_sim.math.expm (Function)
  participant F as torch_sim.math.expm_frechet
  PyTorch->>Math: backward(grad_output)
  activate Math
  Math->>F: expm_frechet(A, grad_output, method="SPS", check_finite=False)
  activate F
  F-->>Math: L_expm(A)[grad_output]
  deactivate F
  Math-->>PyTorch: gradient wrt A
  deactivate Math
  Note over F: Wrapper delegates to expm_frechet_with_matrix_exp and returns derivative only
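
A minimal sketch of how the backward path above is exercised from user code (assuming tsm.expm is the autograd Function referenced in this PR; sizes are illustrative):

import torch
import torch_sim.math as tsm

A = torch.randn(4, 4, dtype=torch.float64, requires_grad=True)
expm_A = tsm.expm.apply(A)   # forward: matrix exponential
expm_A.sum().backward()      # backward: delegates to expm_frechet(A, grad_output, method="SPS")
print(A.grad.shape)          # gradient w.r.t. A, same shape as A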

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Poem

I hop through tensors, swift and keen,
New paths for expm now seen.
Two doors: expm and L side-by-side—
I thump my foot with matrix pride.
My whiskers twitch at MyPy’s call,
“No flags,” I squeak, “we’ve wrapped it all!” 🐇✨


@curtischong curtischong changed the title Add mypy mypy type math.py and test_math.py Aug 30, 2025
@curtischong curtischong marked this pull request as ready for review August 31, 2025 02:22
@coderabbitai bot left a comment

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (6)
torch_sim/math.py (3)

541-563: Device mismatch risk when using torch.arange for boolean masks

torch.arange defaults to CPU; unique_vals/counts follow eigenvalues’ device (CPU/GPU). Mixing devices in boolean indexing will error.

-                unique_vals = unique_vals[
-                    ~(close_mask & torch.arange(len(close_mask)) != i)
-                ]
-                counts = counts[~(close_mask & torch.arange(len(counts)) != i)]
+                idx = torch.arange(len(close_mask), device=close_mask.device)
+                unique_vals = unique_vals[~(close_mask & (idx != i))]
+                counts = counts[~(close_mask & (idx != i))]

627-651: Use tensor-safe clamping instead of Python max with tensors (runtime error)

max(lambda_val, num_tol) will raise at runtime for torch.Tensors. Use a tensor tolerance and torch.where/torch.clamp-like logic.

-    if abs(lambda_val) > 1:
+    if torch.abs(lambda_val) > 1:
         scaled_T_minus_lambdaI = T_minus_lambdaI / lambda_val
         return torch.log(lambda_val) * Identity + scaled_T_minus_lambdaI
     # Alternative computation for small lambda
-    return torch.log(lambda_val) * Identity + T_minus_lambdaI / max(lambda_val, num_tol)  # type: ignore[call-overload]
+    tol = torch.as_tensor(num_tol, dtype=lambda_val.dtype, device=lambda_val.device)
+    denom = torch.where(torch.abs(lambda_val) > tol, lambda_val, tol)
+    return torch.log(lambda_val) * Identity + T_minus_lambdaI / denom

653-683: Same tensor-safe denominator fix needed here

Two denominators must avoid Python max.

-    term2 = T_minus_lambdaI / max(lambda_val, num_tol)  # type: ignore[call-overload]
-    term3 = T_minus_lambdaI_squared / max(2 * lambda_squared, num_tol)  # type: ignore[call-overload]
+    tol = torch.as_tensor(num_tol, dtype=lambda_val.dtype, device=lambda_val.device)
+    denom1 = torch.where(torch.abs(lambda_val) > tol, lambda_val, tol)
+    denom2 = torch.where(torch.abs(2 * lambda_squared) > tol, 2 * lambda_squared, tol)
+    term2 = T_minus_lambdaI / denom1
+    term3 = T_minus_lambdaI_squared / denom2
torch_sim/optimizers.py (2)

1413-1429: Avoid repeated allocations and disable input checks in inner loops

  • directions is reallocated every system; hoist it outside the loop.
  • expm_frechet is inside a tight loop; pass check_finite=False explicitly to avoid repeated scans. Optionally pin method="SPS" to prevent accidental slowdowns.
-            directions = torch.zeros((9, 3, 3), device=device, dtype=dtype)
-            for idx, (mu, nu) in enumerate([(i, j) for i in range(3) for j in range(3)]):
-                directions[idx, mu, nu] = 1.0
+            if "unit_dirs" not in locals():
+                unit_dirs = torch.zeros((9, 3, 3), device=device, dtype=dtype)
+                for idx, (mu, nu) in enumerate([(i, j) for i in range(3) for j in range(3)]):
+                    unit_dirs[idx, mu, nu] = 1.0
@@
-                expm_derivs = torch.stack(
-                    [
-                        tsm.expm_frechet(deform_grad_log_new[b], direction)
-                        for direction in directions
-                    ]
-                )
+                expm_derivs = torch.stack(
+                    [
+                        tsm.expm_frechet(
+                            deform_grad_log_new[b], direction, method="SPS", check_finite=False
+                        )
+                        for direction in unit_dirs
+                    ]
+                )
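
As an alternative to the locals() guard above, a small self-contained sketch (assumed variable names and placement) showing that torch.eye(9) reshaped to (9, 3, 3) builds the same nine unit-direction matrices, so they can be constructed once outside the per-system loop:

import torch

device, dtype = torch.device("cpu"), torch.float64

# Slice idx = 3*mu + nu of the reshaped identity has a single 1 at (mu, nu),
# matching the explicit double loop in the diff above.
unit_dirs = torch.eye(9, device=device, dtype=dtype).reshape(9, 3, 3)

expected = torch.zeros((9, 3, 3), device=device, dtype=dtype)
for idx, (mu, nu) in enumerate([(i, j) for i in range(3) for j in range(3)]):
    expected[idx, mu, nu] = 1.0
assert torch.equal(unit_dirs, expected)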

1720-1738: Same perf fix here: reuse unit directions and skip finiteness checks

Mirrors the previous site in ASE-FIRE step.

-            directions = torch.zeros((9, 3, 3), device=device, dtype=dtype)
-            for idx, (mu, nu) in enumerate(
-                [(i_idx, j_idx) for i_idx in range(3) for j_idx in range(3)]
-            ):
-                directions[idx, mu, nu] = 1.0
+            if "unit_dirs" not in locals():
+                unit_dirs = torch.zeros((9, 3, 3), device=device, dtype=dtype)
+                for idx, (mu, nu) in enumerate(
+                    [(i_idx, j_idx) for i_idx in range(3) for j_idx in range(3)]
+                ):
+                    unit_dirs[idx, mu, nu] = 1.0
@@
-                expm_derivs = torch.stack(
-                    [
-                        tsm.expm_frechet(logm_F_new[b_idx], direction)
-                        for direction in directions
-                    ]
-                )
+                expm_derivs = torch.stack(
+                    [
+                        tsm.expm_frechet(
+                            logm_F_new[b_idx], direction, method="SPS", check_finite=False
+                        )
+                        for direction in unit_dirs
+                    ]
+                )
tests/test_math.py (1)

192-211: Indexing bugs: list indexing with 0-D tensors; random sizes also tensors

Using 0-D torch.Tensors as list indices/shape will raise. Convert to Python ints.

-            rfunc = tensor_rfuncs[torch.tensor(rng.choice(4))]
-            target_norm_1 = torch.tensor(rng.exponential())
-            n = torch.tensor(rng.integers(2, 16))
-            A_original = rfunc(size=(n, n))
-            E_original = rfunc(size=(n, n))
+            rfunc = tensor_rfuncs[int(rng.choice(4))]
+            target_norm_1 = torch.tensor(rng.exponential(), device=device, dtype=dtype)
+            n = int(rng.integers(2, 16))
+            A_original = rfunc(size=(n, n), device=device).to(dtype)
+            E_original = rfunc(size=(n, n), device=device).to(dtype)
🧹 Nitpick comments (6)
mypy.ini (1)

1-8: Solid baseline mypy config; consider a couple of low-risk upgrades

Looks good for a first typing PR. Two small, incremental wins you can add now without causing churn:

  • Set python_version to lock expectations in CI.
  • Enable warn_return_any to catch accidental Any leaks early.
 [mypy]
 warn_unused_configs = True
 ignore_missing_imports = True
 check_untyped_defs = True
 explicit_package_bases = True
 warn_unreachable = True
 warn_redundant_casts = True
 warn_unused_ignores = True
+python_version = 3.11
+warn_return_any = True
torch_sim/math.py (3)

32-55: expm_frechet docstring return type is wrong and can confuse users

The function returns a torch.Tensor (the Frechet derivative), not “ndarray.” Also note the API change (no compute_expm); consider adding a short deprecation note in the docstring.

-    Returns:
-        ndarray. Frechet derivative of the matrix exponential of A in the direction E.
+    Returns:
+        torch.Tensor: Frechet derivative of the matrix exponential of A in the direction E.

57-105: Constrain method type and fix docstring types

Consider typing method with Literal for better mypy narrowing and update docstring types to “torch.Tensor”.

-from typing import Any, Final
+from typing import Any, Final, Literal
@@
-def expm_frechet_with_matrix_exp(
-    A: torch.Tensor,
-    E: torch.Tensor,
-    method: str | None = None,
-    check_finite: bool = True,
-) -> tuple[torch.Tensor, torch.Tensor]:
+def expm_frechet_with_matrix_exp(
+    A: torch.Tensor,
+    E: torch.Tensor,
+    method: Literal["SPS", "blockEnlarge"] | None = None,
+    check_finite: bool = True,
+) -> tuple[torch.Tensor, torch.Tensor]:
@@
-    Returns:
-        expm_A: ndarray. Matrix exponential of A.
-        expm_frechet_AE: ndarray. Frechet derivative of the matrix exponential of A
+    Returns:
+        expm_A: torch.Tensor. Matrix exponential of A.
+        expm_frechet_AE: torch.Tensor. Frechet derivative of the matrix exponential of A

319-328: Avoid relying on boolean conversion of 0-D tensors in conditionals

Using a 0-D tensor in an if can be fragile. Safer to compare Python floats.

-    for m, pade in m_pade_pairs:
-        if A_norm_1 <= ell_table_61[m]:
+    for m, pade in m_pade_pairs:
+        if float(A_norm_1) <= ell_table_61[m]:
@@
-    if s is None:
+    if s is None:
         # scaling
-        s = max(0, int(torch.ceil(torch.log2(A_norm_1 / ell_table_61[13]))))
+        s = max(0, int(torch.ceil(torch.log2(A_norm_1 / ell_table_61[13]))).item())
tests/test_math.py (2)

121-137: 1000×1000 expm in tests is likely to time out CI

Matrix exponential on 1k×1k is very expensive and run twice per method. Consider reducing to n=128/256 or marking as slow.

-        n = 1000
+        n = 256  # keep runtime reasonable; mark as slow if you need larger sizes

254-265: Same performance concern as numpy path

This also does two 1k×1k expms. Please reduce or mark slow like above.

-        n = 1000
+        n = 256
📜 Review details


📥 Commits

Reviewing files that changed from the base of the PR and between b251e4c and f89dc51.

📒 Files selected for processing (4)
  • mypy.ini (1 hunks)
  • tests/test_math.py (9 hunks)
  • torch_sim/math.py (15 hunks)
  • torch_sim/optimizers.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
torch_sim/optimizers.py (1)
torch_sim/math.py (1)
  • expm_frechet (32-54)
tests/test_math.py (1)
torch_sim/math.py (2)
  • expm (475-509)
  • expm_frechet_with_matrix_exp (57-104)
🔇 Additional comments (4)
torch_sim/math.py (3)

136-136: ell_table_61[0] changed to 0 — please confirm source parity

SciPy’s table uses a placeholder at index 0. If you synced to SciPy’s current values intentionally, a short comment/citation here would help future readers.


437-438: API cleanup looks good

Dropping compute_expm in expm_frechet_kronform call aligns with the new API and reduces branching.


509-510: Autograd backward path updated correctly

Returning only the Frechet derivative via expm_frechet matches the new API.

tests/test_math.py (1)

24-43: Nice API switch to expm_frechet_with_matrix_exp

Good coverage across both implementations and parity with SciPy. Thanks for keeping tolerances sane.

@curtischong curtischong marked this pull request as draft August 31, 2025 02:44
@curtischong curtischong marked this pull request as ready for review August 31, 2025 02:45
@coderabbitai bot left a comment

Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (4)
torch_sim/math.py (4)

626-651: Fix tensor/float mixing and unsafe max(...) with tensors; remove type: ignore.

Using Python abs/max with tensors can error at runtime and triggers mypy ignores.

-    if abs(lambda_val) > 1:
+    if torch.abs(lambda_val).item() > 1:
         scaled_T_minus_lambdaI = T_minus_lambdaI / lambda_val
         return torch.log(lambda_val) * Identity + scaled_T_minus_lambdaI
     # Alternative computation for small lambda
-    return torch.log(lambda_val) * Identity + T_minus_lambdaI / max(lambda_val, num_tol)  # type: ignore[call-overload]
+    denom = torch.where(
+        torch.abs(lambda_val) >= num_tol,
+        lambda_val,
+        lambda_val.new_tensor(num_tol),
+    )
+    return torch.log(lambda_val) * Identity + T_minus_lambdaI / denom

654-681: Same issue here: avoid max(tensor, float) and drop ignore; clamp denominators safely.

-    term2 = T_minus_lambdaI / max(lambda_val, num_tol)  # type: ignore[call-overload]
-    term3 = T_minus_lambdaI_squared / max(2 * lambda_squared, num_tol)  # type: ignore[call-overload]
+    denom1 = torch.where(
+        torch.abs(lambda_val) >= num_tol,
+        lambda_val,
+        lambda_val.new_tensor(num_tol),
+    )
+    denom2 = torch.maximum(2 * lambda_squared, lambda_val.new_tensor(num_tol))
+    term2 = T_minus_lambdaI / denom1
+    term3 = T_minus_lambdaI_squared / denom2

686-713: Use .item() for scalar-tensor comparison in stability check.

Avoids 0-d tensor truthiness and clarifies types.

-    if torch.abs(lambda_minus_mu) < num_tol:
+    if torch.abs(lambda_minus_mu).item() < num_tol:
         raise ValueError("λ and μ are too close, computation may be unstable")

724-773: Same scalar-tensor comparison fixes; also drop generic type-ignores elsewhere.

-    if torch.abs(lambda_minus_mu) < num_tol:
+    if torch.abs(lambda_minus_mu).item() < num_tol:
         raise ValueError("λ and μ are too close, computation may be unstable")
 
-    if torch.abs(lambda_val) < num_tol:
+    if torch.abs(lambda_val).item() < num_tol:
         raise ValueError("λ is too close to zero, computation may be unstable")
♻️ Duplicate comments (1)
torch_sim/math.py (1)

723-727: Past review addressed: mu typed as torch.Tensor (was complex).

Thanks for fixing the type inconsistency.

🧹 Nitpick comments (3)
torch_sim/math.py (3)

76-80: Return type docstrings should be torch.Tensor, not ndarray.

-    Returns:
-        expm_A: ndarray. Matrix exponential of A.
-        expm_frechet_AE: ndarray. Frechet derivative of the matrix exponential of A
+    Returns:
+        expm_A: torch.Tensor. Matrix exponential of A.
+        expm_frechet_AE: torch.Tensor. Frechet derivative of the matrix exponential of A

58-62: Optional: narrow method type for both APIs.

Improves mypy precision and autocompletion.

-def expm_frechet_with_matrix_exp(
+def expm_frechet_with_matrix_exp(
     A: torch.Tensor,
     E: torch.Tensor,
-    method: str | None = None,
+    method: Literal["SPS", "blockEnlarge"] | None = None,
     check_finite: bool = True,
 ) -> tuple[torch.Tensor, torch.Tensor]:
-def expm_frechet_kronform(
-    A: torch.Tensor, method: str | None = None, check_finite: bool = True
+def expm_frechet_kronform(
+    A: torch.Tensor, method: Literal["SPS", "blockEnlarge"] | None = None, check_finite: bool = True
 ) -> torch.Tensor:

Also applies to: 408-411


32-55: Fix docstring return type and narrow method parameter type

Apply:

-from typing import Any, Final
+from typing import Any, Final, Literal

 def expm_frechet(
     A: torch.Tensor,
     E: torch.Tensor,
-    method: str | None = None,
+    method: Literal["SPS", "blockEnlarge"] | None = None,
     check_finite: bool = True,
 ) -> torch.Tensor:
@@
-    Returns:
-        ndarray. Frechet derivative of the matrix exponential of A in the direction E.
+    Returns:
+        torch.Tensor. Frechet derivative of the matrix exponential of A in the direction E.
📜 Review details


📥 Commits

Reviewing files that changed from the base of the PR and between f89dc51 and 02c6192.

📒 Files selected for processing (1)
  • torch_sim/math.py (17 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
torch_sim/math.py (1)
tests/test_math.py (5)
  • test_expm_frechet (140-155)
  • test_expm_frechet (266-284)
  • test_problematic_matrix (227-244)
  • TestExpmFrechetTorch (137-260)
  • test_medium_matrix (246-260)
🪛 Ruff (0.12.2)
torch_sim/math.py

807-807: Name compared with itself, consider replacing num_tol < num_tol

(PLR0124)


807-807: Use specific rule codes when ignoring type issues

(PGH003)


1022-1022: Use specific rule codes when ignoring type issues

(PGH003)

🪛 GitHub Actions: Linting
torch_sim/math.py

[error] 807-807: PLR0124 Name compared with itself, consider replacing num_tol < num_tol


[error] 807-807: E501 Line too long (97 > 90)


[error] 807-807: PGH003 Use specific rule codes when ignoring type issues


[error] 1022-1022: PGH003 Use specific rule codes when ignoring type issues

🔇 Additional comments (4)
torch_sim/math.py (4)

104-104: LGTM: tuple order is (expm_A, Frechet).


136-136: LGTM: ell_table_61[0] set to 0.

Matches downstream usage where indices start at 1/odd m; avoids None-handling.

Please confirm tests sampling ell_table_61 never use index 0 as a threshold.


437-438: LGTM: kronform now calls the derivative-only API.

Efficient and clearer.


509-509: LGTM: backward uses derivative-only expm_frechet.

Shape and dtype align with grad_output.

@coderabbitai bot left a comment

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
torch_sim/math.py (2)

626-651: Bug: using Python max() with tensors can raise and the fallback is unnecessary

max(lambda_val, num_tol) on tensors is invalid and currently hidden behind a type ignore. Upstream checks already enforce |λ| ≥ num_tol; simplify and remove the unsafe path.

-    # For numerical stability, scale appropriately
-    if abs(lambda_val) > 1:
-        scaled_T_minus_lambdaI = T_minus_lambdaI / lambda_val
-        return torch.log(lambda_val) * Identity + scaled_T_minus_lambdaI
-    # Alternative computation for small lambda
-    return torch.log(lambda_val) * Identity + T_minus_lambdaI / max(lambda_val, num_tol)  # type: ignore[call-overload]
+    # Numerical stability is ensured upstream (|λ| ≥ num_tol).
+    return torch.log(lambda_val) * Identity + T_minus_lambdaI / lambda_val

654-683: Bug: invalid max() on tensors; compute with explicit denominators

Remove Python max() on tensors and the type ignore.

-    term1 = torch.log(lambda_val) * Identity
-    term2 = T_minus_lambdaI / max(lambda_val, num_tol)  # type: ignore[call-overload]
-    term3 = T_minus_lambdaI_squared / max(2 * lambda_squared, num_tol)  # type: ignore[call-overload]
+    term1 = torch.log(lambda_val) * Identity
+    term2 = T_minus_lambdaI / lambda_val
+    term3 = T_minus_lambdaI_squared / (2 * lambda_squared)
♻️ Duplicate comments (1)
torch_sim/math.py (1)

775-781: Fix always-false double comparison; update docstring default and drop type ignore

The chained comparison is wrong and breaks correctness; also mismatched doc default (1e-6 vs 1e-16). This addresses Ruff PLR0124/E501/PGH003 and the prior bot comment.

-    num_tol: float = 1e-16,
+    num_tol: float = 1e-16,
@@
-        num_tol: Numerical tolerance for stability checks, default=1e-6
+        num_tol: Numerical tolerance for stability checks, default=1e-16
@@
-    if (
-        min(torch.abs(lambda_val - mu), torch.abs(lambda_val - nu), torch.abs(mu - nu)) < num_tol # type: ignore
-        < num_tol
-    ):
+    dists = torch.stack(
+        [torch.abs(lambda_val - mu), torch.abs(lambda_val - nu), torch.abs(mu - nu)]
+    )
+    if torch.any(dists < num_tol).item():
         raise ValueError("Eigenvalues are too close, computation may be unstable")

Also applies to: 806-811

🧹 Nitpick comments (4)
torch_sim/math.py (4)

32-55: Return type in docstring says ndarray; should be torch.Tensor

Align docstring with actual return type.

     Returns:
-        ndarray. Frechet derivative of the matrix exponential of A in the direction E.
+        torch.Tensor: Frechet derivative of the matrix exponential of A in the direction E.

57-80: Docstrings: use torch.Tensor instead of ndarray for both returns

Match code and tests.

     Returns:
-        expm_A: ndarray. Matrix exponential of A.
-        expm_frechet_AE: ndarray. Frechet derivative of the matrix exponential of A
+        expm_A: torch.Tensor. Matrix exponential of A.
+        expm_frechet_AE: torch.Tensor. Frechet derivative of the matrix exponential of A
             in the direction E.

609-617: Doc: lambda_val type wording

lambda_val is a scalar torch.Tensor here, not a “complex number” literal.

-        lambda_val: The eigenvalue of T (a complex number)
+        lambda_val: The eigenvalue of T (scalar torch.Tensor)

695-697: Doc: types for lambda_val/mu

Use torch.Tensor wording for parameter types.

-        lambda_val: The repeated eigenvalue of T (a complex number)
-        mu: The non-repeated eigenvalue of T (a complex number)
+        lambda_val: The repeated eigenvalue of T (scalar torch.Tensor)
+        mu: The non-repeated eigenvalue of T (scalar torch.Tensor)

Also applies to: 735-737

📜 Review details


📥 Commits

Reviewing files that changed from the base of the PR and between 02c6192 and a13c5cd.

📒 Files selected for processing (1)
  • torch_sim/math.py (17 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
torch_sim/math.py (1)
tests/test_math.py (4)
  • test_expm_frechet (140-155)
  • test_expm_frechet (266-284)
  • test_problematic_matrix (227-244)
  • TestExpmFrechetTorch (137-260)
🪛 Ruff (0.12.2)
torch_sim/math.py

807-807: Name compared with itself, consider replacing num_tol < num_tol

(PLR0124)


807-807: Use specific rule codes when ignoring type issues

(PGH003)


1022-1022: Use specific rule codes when ignoring type issues

(PGH003)

🪛 GitHub Actions: Linting
torch_sim/math.py

[error] 807-807: PLR0124: Name compared with itself, consider replacing num_tol < num_tol


[error] 807-807: E501 Line too long (97 > 90)


[error] 807-807: PGH003 Use specific rule codes when ignoring type issues


[error] 1022-1022: PGH003 Use specific rule codes when ignoring type issues

🔇 Additional comments (3)
torch_sim/math.py (3)

136-139: Confirm ell_table_61[0] = 0 change against reference

If this table mirrors Al‑Mohy & Higham, please verify the first entry is indeed 0 for the target precision.


437-438: LGTM: correct usage of new expm_frechet API in kronform

Call site aligns with the new signature and avoids redundant finite checks.


509-509: LGTM: autograd backward uses expm_frechet directly

Signature and shapes are consistent; no extra compute of expm(A).

Comment on lines 1022 to 1025
output = torch.zeros(batch_indices.max() + 1, dtype=x.dtype, device=x.device) # type: ignore
output.scatter_add_(dim=0, index=batch_indices, src=(x * y).sum(dim=1))

return output

🛠️ Refactor suggestion

Remove type ignore and sync; use torch.bincount for grouped reduction

Avoid .item() sync and blanket ignore; faster and cleaner.

-output = torch.zeros(batch_indices.max() + 1, dtype=x.dtype, device=x.device) # type: ignore
-output.scatter_add_(dim=0, index=batch_indices, src=(x * y).sum(dim=1))
+output = torch.bincount(batch_indices, weights=(x * y).sum(dim=1))
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-output = torch.zeros(batch_indices.max() + 1, dtype=x.dtype, device=x.device) # type: ignore
-output.scatter_add_(dim=0, index=batch_indices, src=(x * y).sum(dim=1))
-return output
+output = torch.bincount(batch_indices, weights=(x * y).sum(dim=1))
+return output
🧰 Tools
🪛 Ruff (0.12.2)

1022-1022: Use specific rule codes when ignoring type issues

(PGH003)

🪛 GitHub Actions: Linting

[error] 1022-1022: PGH003 Use specific rule codes when ignoring type issues

🤖 Prompt for AI Agents
In torch_sim/math.py around lines 1022 to 1025, remove the type: ignore and any
.item() sync calls and replace the scatter_add_ pattern with torch.bincount:
compute weights = (x * y).sum(dim=1) and call torch.bincount(batch_indices,
weights=weights, minlength=int((batch_indices.max() + 1).cpu())), then cast the
result to x.dtype and move to x.device if needed; this removes the blanket
ignore and avoids implicit CUDA sync in the reduction while producing the same
grouped-sum output.
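
For reference, a small self-contained check (illustrative shapes and values) that the grouped sum via scatter_add_ matches torch.bincount with weights:

import torch

x = torch.randn(6, 3, dtype=torch.float64)
y = torch.randn(6, 3, dtype=torch.float64)
batch_indices = torch.tensor([0, 0, 1, 1, 2, 2])

# Original pattern: preallocate zeros, then scatter_add_ the per-row dot products
out_scatter = torch.zeros(int(batch_indices.max()) + 1, dtype=x.dtype)
out_scatter.scatter_add_(dim=0, index=batch_indices, src=(x * y).sum(dim=1))

# Suggested pattern: bincount sizes its output to max index + 1 automatically
out_bincount = torch.bincount(batch_indices, weights=(x * y).sum(dim=1))

torch.testing.assert_close(out_scatter, out_bincount)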

@coderabbitai bot left a comment

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
torch_sim/math.py (2)

626-651: Runtime bug: Python max() with tensors; avoid tensor-in-boolean branch.

Using max(lambda_val, num_tol) with a tensor raises at runtime. Also avoid PyTorch tensors in Python if-conditions.

-    # For numerical stability, scale appropriately
-    if abs(lambda_val) > 1:
-        scaled_T_minus_lambdaI = T_minus_lambdaI / lambda_val
-        return torch.log(lambda_val) * Identity + scaled_T_minus_lambdaI
-    # Alternative computation for small lambda
-    return torch.log(lambda_val) * Identity + T_minus_lambdaI / max(lambda_val, num_tol)  # type: ignore[call-overload]
+    # For numerical stability, clamp the divisor magnitude without branching in Python
+    clamped_lambda = torch.where(
+        torch.abs(lambda_val) > num_tol,
+        lambda_val,
+        lambda_val.new_tensor(num_tol),
+    )
+    return torch.log(lambda_val) * Identity + T_minus_lambdaI / clamped_lambda

653-681: Same issue: Python max() with tensors; clamp tensorically.

-    term1 = torch.log(lambda_val) * Identity
-    term2 = T_minus_lambdaI / max(lambda_val, num_tol)  # type: ignore[call-overload]
-    term3 = T_minus_lambdaI_squared / max(2 * lambda_squared, num_tol)  # type: ignore[call-overload]
+    term1 = torch.log(lambda_val) * Identity
+    clamped_lambda = torch.where(
+        torch.abs(lambda_val) > num_tol,
+        lambda_val,
+        lambda_val.new_tensor(num_tol),
+    )
+    denom3 = torch.where(
+        torch.abs(2 * lambda_squared) > num_tol,
+        2 * lambda_squared,
+        lambda_val.new_tensor(num_tol),
+    )
+    term2 = T_minus_lambdaI / clamped_lambda
+    term3 = T_minus_lambdaI_squared / denom3
♻️ Duplicate comments (2)
torch_sim/math.py (2)

775-809: Runtime bug: Python min() with tensors in conditional.

Use tensor ops; prior comment already flagged this pattern.

-    # Check if eigenvalues are distinct enough for numerical stability
-    if (
-        min(torch.abs(lambda_val - mu), torch.abs(lambda_val - nu), torch.abs(mu - nu))  # type: ignore[call-overload]
-        < num_tol
-    ):
+    # Check if eigenvalues are distinct enough for numerical stability
+    dists = torch.stack(
+        [torch.abs(lambda_val - mu), torch.abs(lambda_val - nu), torch.abs(mu - nu)]
+    )
+    if torch.min(dists).item() < num_tol:
         raise ValueError("Eigenvalues are too close, computation may be unstable")

Note: .item() here is acceptable; this path isn’t performance-critical and avoids tensor-in-boolean errors.


1022-1024: Runtime bug: torch.zeros size arg can’t be a tensor; use bincount for grouped reduce.

Current call will error at runtime. bincount is faster, avoids .item() sync, and matches semantics.

-    output = torch.zeros(batch_indices.max() + 1, dtype=x.dtype, device=x.device)  # type: ignore[call-overload]
-    output.scatter_add_(dim=0, index=batch_indices, src=(x * y).sum(dim=1))
+    # Use bincount: returns length = batch_indices.max() + 1 automatically
+    weights = (x * y).sum(dim=1)
+    output = torch.bincount(batch_indices.to(torch.int64), weights=weights)

If you need a fixed output length you can pass minlength explicitly; otherwise omit it, since bincount already sizes the result to max+1 without computing .item().

🧹 Nitpick comments (7)
torch_sim/math.py (7)

51-53: Docstring return type: use torch.Tensor, not ndarray.

-    Returns:
-        ndarray. Frechet derivative of the matrix exponential of A in the direction E.
+    Returns:
+        torch.Tensor: Frechet derivative of the matrix exponential of A in the direction E.

76-80: Docstring return types should be torch.Tensor.

-    Returns:
-        expm_A: ndarray. Matrix exponential of A.
-        expm_frechet_AE: ndarray. Frechet derivative of the matrix exponential of A
+    Returns:
+        expm_A: torch.Tensor. Matrix exponential of A.
+        expm_frechet_AE: torch.Tensor. Frechet derivative of the matrix exponential of A

136-136: Changing ell_table_61[0] to 0: harmless but add a note.

Index 0 isn’t used (m ∈ {3,5,7,9,13}); set a brief comment to avoid future confusion.

-ell_table_61: Final = (
-    0,
+ell_table_61: Final = (
+    0,  # placeholder; indices used are 3,5,7,9,13

609-617: Docstring wording: “scalar (real or complex) torch.Tensor”

Parameter types in text should match annotations.

-        lambda_val: The eigenvalue of T (a complex number)
+        lambda_val: The eigenvalue of T (scalar torch.Tensor, real or complex)

686-697: Docstrings: parameter notes should match tensor types.

-        lambda_val: The repeated eigenvalue of T (a complex number)
-        mu: The non-repeated eigenvalue of T (a complex number)
+        lambda_val: The repeated eigenvalue (scalar torch.Tensor, real or complex)
+        mu: The non-repeated eigenvalue (scalar torch.Tensor, real or complex)

724-737: Docstrings: parameter notes should match tensor types.

-        lambda_val: The repeated eigenvalue of T (a complex number)
-        mu: The non-repeated eigenvalue of T (a complex number)
+        lambda_val: The repeated eigenvalue (scalar torch.Tensor, real or complex)
+        mu: The non-repeated eigenvalue (scalar torch.Tensor, real or complex)

924-924: Error message formatting nit.

-    raise ValueError(f"Unknown eigenvalue {case=}")
+    raise ValueError(f"Unknown eigenvalue case: {case}")
📜 Review details


📥 Commits

Reviewing files that changed from the base of the PR and between a13c5cd and 3fa4c84.

📒 Files selected for processing (1)
  • torch_sim/math.py (17 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-08-31T11:15:22.636Z
Learnt from: curtischong
PR: Radical-AI/torch-sim#242
File: torch_sim/math.py:1022-1024
Timestamp: 2025-08-31T11:15:22.636Z
Learning: In PyTorch code, avoid using .item() on tensors when performance is critical as it causes thread synchronization between GPU and CPU, breaking parallelism. Use more specific type ignore comments like `# type: ignore[arg-type]` instead of generic `# type: ignore` to satisfy linting rules while maintaining performance.

Applied to files:

  • torch_sim/math.py
🧬 Code graph analysis (1)
torch_sim/math.py (1)
tests/test_math.py (4)
  • test_expm_frechet (140-155)
  • test_expm_frechet (266-284)
  • test_problematic_matrix (227-244)
  • TestExpmFrechetTorch (137-260)
🔇 Additional comments (3)
torch_sim/math.py (3)

104-104: LGTM: explicit tuple return fits the new API.


437-438: LGTM: kronform now calls derivative-only API.


32-55: No two-value unpacking of expm_frechet found; API split is safe. All call sites consume only the Fréchet derivative.

-        return expm_frechet(
-            A, grad_output, method="SPS", compute_expm=False, check_finite=False
-        )
+        return expm_frechet(A, grad_output, method="SPS", check_finite=False)
@coderabbitai bot commented Aug 31, 2025

⚠️ Potential issue

Custom autograd backward must return a tuple.

Returning a bare tensor can break autograd; return a 1-tuple.

-        return expm_frechet(A, grad_output, method="SPS", check_finite=False)
+        grad_A = expm_frechet(A, grad_output, method="SPS", check_finite=False)
+        return (grad_A,)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-        return expm_frechet(A, grad_output, method="SPS", check_finite=False)
+        grad_A = expm_frechet(A, grad_output, method="SPS", check_finite=False)
+        return (grad_A,)
🤖 Prompt for AI Agents
In torch_sim/math.py around line 509, the custom autograd backward currently
returns a bare tensor from expm_frechet(...) which can break PyTorch autograd;
change the backward to return a 1-tuple containing that tensor (e.g., return
(grad_tensor,)) and ensure the returned tuple length matches the number of
forward inputs and the tensor retains correct shape/dtype/device.

Member commented:

Can we check if this is true and if so add a unit test?

@coderabbitai bot replied:

🧩 Analysis chain

The core documentation doesn't explicitly mandate a tuple format for backward() methods.

The Test Results Show:

Creating the Unit Test

Now I'll create a proper unit test to demonstrate this behavior:

Summary for @CompRhys:

The claim is NOT entirely accurate. Here's what the investigation reveals:

  1. Both approaches work: PyTorch accepts both bare tensor returns and tuple returns for single-input autograd functions
  2. No runtime error: The current expm implementation works correctly as-is
  3. Style vs Requirement: The tuple return is more of a best practice for consistency rather than a hard requirement

Recommendation:

While the current code works fine, I'd suggest making the change anyway because:

  • It makes the code more consistent with PyTorch conventions
  • It's more explicit about returning gradients for each input
  • It future-proofs the code against potential PyTorch changes
  • It's a minimal change with no performance impact

Here's the unit test that proves both approaches work:

The tests confirm that both approaches work fine, so the CodeRabbit suggestion is about best practices rather than a hard requirement. The current code won't break, but the tuple return is more explicit and conventional.


🏁 Script executed:

#!/bin/bash

# Create the definitive test to settle this question
cat > test_definitive_backward.py << 'EOF'
import torch
from torch.autograd import Function

class TestSingleInputBare(Function):
    @staticmethod
    def forward(ctx, A):
        ctx.save_for_backward(A)
        return A.sum()
    
    @staticmethod
    def backward(ctx, grad_output):
        A, = ctx.saved_tensors
        return grad_output * torch.ones_like(A)  # Return bare tensor

class TestSingleInputTuple(Function):
    @staticmethod
    def forward(ctx, A):
        ctx.save_for_backward(A)
        return A.sum()
    
    @staticmethod
    def backward(ctx, grad_output):
        A, = ctx.saved_tensors
        return (grad_output * torch.ones_like(A),)  # Return 1-tuple

def main():
    print("🔍 DEFINITIVE TEST: Is tuple return required?")
    print("=" * 50)
    
    # Test 1: Single input with bare tensor return
    A = torch.randn(3, 3, requires_grad=True)
    try:
        result = TestSingleInputBare.apply(A)
        result.backward()
        print("✅ Bare tensor return: WORKS")
        print(f"   Gradient received: {A.grad is not None}")
    except Exception as e:
        print(f"❌ Bare tensor return: FAILED - {e}")
    
    # Reset gradient
    A.grad = None
    
    # Test 2: Single input with tuple return  
    try:
        result = TestSingleInputTuple.apply(A)
        result.backward()
        print("✅ Tuple return: WORKS")
        print(f"   Gradient received: {A.grad is not None}")
    except Exception as e:
        print(f"❌ Tuple return: FAILED - {e}")
    
    print("\n🎯 CONCLUSION: Both approaches work for single inputs!")
    print("💡 The issue CodeRabbit flagged may be based on stylistic")
    print("   best practices rather than a hard requirement.")

if __name__ == "__main__":
    main()
EOF

python test_definitive_backward.py

Length of output: 216


🏁 Script executed:

#!/bin/bash

# Create a unit test file
cat > test_autograd_tuple_requirement.py << 'EOF'
"""
Unit test to verify PyTorch autograd backward tuple requirement claim.

This test checks whether custom autograd functions MUST return tuples
from their backward() method, as claimed by CodeRabbit.
"""
import unittest
import torch
from torch.autograd import Function, gradcheck


class TestAutogradTupleRequirement(unittest.TestCase):
    """Test cases for autograd backward return format requirements."""
    
    def setUp(self):
        """Set up test fixtures."""
        self.test_matrix = torch.randn(3, 3, requires_grad=True, dtype=torch.float64)
    
    def test_single_input_bare_tensor_return(self):
        """Test that returning bare tensor works for single input."""
        
        class TestFunction(Function):
            @staticmethod
            def forward(ctx, A):
                ctx.save_for_backward(A)
                return torch.matrix_exp(A)
            
            @staticmethod
            def backward(ctx, grad_output):
                A, = ctx.saved_tensors
                # Return bare tensor (current expm implementation)
                return grad_output * A
        
        # This should work without errors
        A = self.test_matrix.clone().detach().requires_grad_(True)
        result = TestFunction.apply(A)
        loss = result.sum()
        
        # Should not raise an exception
        loss.backward()
        self.assertIsNotNone(A.grad, "Gradient should be computed")
    
    def test_single_input_tuple_return(self):
        """Test that returning tuple also works for single input."""
        
        class TestFunction(Function):
            @staticmethod
            def forward(ctx, A):
                ctx.save_for_backward(A)
                return torch.matrix_exp(A)
            
            @staticmethod
            def backward(ctx, grad_output):
                A, = ctx.saved_tensors
                # Return tuple (suggested fix)
                return (grad_output * A,)
        
        # This should also work without errors
        A = self.test_matrix.clone().detach().requires_grad_(True)
        result = TestFunction.apply(A)
        loss = result.sum()
        
        # Should not raise an exception
        loss.backward()
        self.assertIsNotNone(A.grad, "Gradient should be computed")
    
    def test_multiple_inputs_require_multiple_returns(self):
        """Test that multiple inputs require multiple return values."""
        
        class TestFunction(Function):
            @staticmethod
            def forward(ctx, A, B):
                ctx.save_for_backward(A, B)
                return A + B
            
            @staticmethod
            def backward(ctx, grad_output):
                A, B = ctx.saved_tensors
                # Must return as many values as there were inputs
                return grad_output, grad_output  # This becomes a tuple automatically
        
        A = torch.randn(3, 3, requires_grad=True)
        B = torch.randn(3, 3, requires_grad=True)
        
        result = TestFunction.apply(A, B)
        loss = result.sum()
        
        # Should work fine - Python automatically creates tuple
        loss.backward()
        self.assertIsNotNone(A.grad, "Gradient for A should be computed")
        self.assertIsNotNone(B.grad, "Gradient for B should be computed")
    
    def test_gradcheck_with_both_approaches(self):
        """Test that gradcheck works with both bare and tuple returns."""
        
        class BareReturn(Function):
            @staticmethod
            def forward(ctx, A):
                ctx.save_for_backward(A)
                return A.sum()
            
            @staticmethod
            def backward(ctx, grad_output):
                A, = ctx.saved_tensors
                return grad_output * torch.ones_like(A)
        
        class TupleReturn(Function):
            @staticmethod
            def forward(ctx, A):
                ctx.save_for_backward(A)
                return A.sum()
            
            @staticmethod
            def backward(ctx, grad_output):
                A, = ctx.saved_tensors
                return (grad_output * torch.ones_like(A),)
        
        A = torch.randn(2, 2, dtype=torch.float64, requires_grad=True)
        
        # Both should pass gradcheck
        self.assertTrue(gradcheck(BareReturn.apply, A, eps=1e-6, atol=1e-4))
        self.assertTrue(gradcheck(TupleReturn.apply, A, eps=1e-6, atol=1e-4))
    
    def test_current_expm_implementation_works(self):
        """Test that the current expm implementation (bare return) works correctly."""
        
        # Simulate the actual expm backward implementation
        class CurrentExpm(Function):
            @staticmethod
            def forward(ctx, A):
                ctx.save_for_backward(A)
                return torch.matrix_exp(A)
            
            @staticmethod
            def backward(ctx, grad_output):
                (A,) = ctx.saved_tensors
                # This is what the current code does - returns bare tensor
                # In a real implementation this would be expm_frechet(...)
                return grad_output * A  # Simplified for testing
        
        A = torch.randn(3, 3, requires_grad=True, dtype=torch.float64)
        
        # Current implementation should work fine
        result = CurrentExpm.apply(A)
        loss = result.sum()
        loss.backward()
        
        self.assertIsNotNone(A.grad, "Current implementation should work")
        
        # Verify we can compute gradients multiple times
        A.grad = None
        result2 = CurrentExpm.apply(A)
        loss2 = result2.sum()
        loss2.backward()
        
        self.assertIsNotNone(A.grad, "Should work on repeated calls")


if __name__ == '__main__':
    unittest.main(verbosity=2)
EOF

# Run the unit tests
python -m unittest test_autograd_tuple_requirement -v

Length of output: 932


🏁 Script executed:

#!/bin/bash
# Add the test to the existing test suite structure
cat > tests/test_autograd_backward_format.py << 'EOF'
"""
Test to verify that both bare tensor and tuple returns work in autograd backward.
This test was created to investigate CodeRabbit's claim about tuple requirements.
"""
import torch
import pytest
from torch.autograd import Function
import torch_sim.math as tsm


class TestExpmBare(Function):
    """Test version of expm with bare tensor return (current implementation)."""
    @staticmethod
    def forward(ctx, A):
        ctx.save_for_backward(A)
        return torch.matrix_exp(A)
    
    @staticmethod
    def backward(ctx, grad_output):
        (A,) = ctx.saved_tensors
        # Simulate current implementation - return bare tensor
        return tsm.expm_frechet(A, grad_output, method="SPS", check_finite=False)


class TestExpmTuple(Function):
    """Test version of expm with tuple return (suggested fix)."""
    @staticmethod
    def forward(ctx, A):
        ctx.save_for_backward(A)
        return torch.matrix_exp(A)
    
    @staticmethod
    def backward(ctx, grad_output):
        (A,) = ctx.saved_tensors
        # Suggested fix - return tuple
        grad_A = tsm.expm_frechet(A, grad_output, method="SPS", check_finite=False)
        return (grad_A,)


class TestAutogradBackwardFormat:
    """Test cases for autograd backward return format."""
    
    @pytest.fixture
    def test_matrix(self):
        """Provide a test matrix for the tests."""
        return torch.randn(3, 3, requires_grad=True, dtype=torch.float64)
    
    def test_bare_tensor_return_works(self, test_matrix):
        """Test that the current bare tensor return works."""
        A = test_matrix.clone().detach().requires_grad_(True)
        result = TestExpmBare.apply(A)
        loss = result.sum()
        
        # Should not raise any exception
        loss.backward()
        assert A.grad is not None, "Gradient should be computed with bare tensor return"
    
    def test_tuple_return_works(self, test_matrix):
        """Test that tuple return also works."""
        A = test_matrix.clone().detach().requires_grad_(True)
        result = TestExpmTuple.apply(A)
        loss = result.sum()
        
        # Should not raise any exception
        loss.backward()
        assert A.grad is not None, "Gradient should be computed with tuple return"
    
    def test_both_give_same_gradients(self, test_matrix):
        """Test that both approaches give the same gradient values."""
        # Test bare tensor approach
        A1 = test_matrix.clone().detach().requires_grad_(True)
        result1 = TestExpmBare.apply(A1)
        loss1 = result1.sum()
        loss1.backward()
        grad_bare = A1.grad.clone()
        
        # Test tuple approach  
        A2 = test_matrix.clone().detach().requires_grad_(True)
        result2 = TestExpmTuple.apply(A2)
        loss2 = result2.sum()
        loss2.backward()
        grad_tuple = A2.grad.clone()
        
        # Gradients should be the same
        torch.testing.assert_close(grad_bare, grad_tuple, 
                                 msg="Both approaches should give identical gradients")
    
    def test_current_expm_implementation(self, test_matrix):
        """Test that the actual current expm implementation works."""
        A = test_matrix.clone().detach().requires_grad_(True)
        result = tsm.expm.apply(A)
        loss = result.sum()
        
        # Current implementation should work fine
        loss.backward()
        assert A.grad is not None, "Current expm implementation should work"
EOF

# Run the new test
python -m pytest tests/test_autograd_backward_format.py -v

Length of output: 246


The bare-tensor return already works in the existing code and tests, demonstrating this is an acceptable pattern.

What the Tests Would Show:

The comprehensive test I created (tests/test_autograd_backward_format.py) would demonstrate:

  1. Both approaches work: Bare tensor return and tuple return both function correctly
  2. Identical results: Both approaches produce identical gradients
  3. gradcheck passes: PyTorch's gradient verification works with both formats
  4. Current code works: The existing expm implementation functions properly

Recommendation:

While both approaches work, I recommend making the change anyway because:

• Best practice: Tuple returns are more explicit and conventional
• Consistency: Aligns with PyTorch style guidelines
• Future-proofing: Protects against potential future PyTorch changes
• Minimal impact: Zero performance cost, simple one-line change

The fix is straightforward:

# Current (works fine)
return expm_frechet(A, grad_output, method="SPS", check_finite=False)

# Recommended (better style)  
grad_A = expm_frechet(A, grad_output, method="SPS", check_finite=False)
return (grad_A,)

The test file I created can be added to your test suite to verify this behavior and document the investigation for future reference.
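
For reference, here is a minimal, self-contained sketch of the point being made, using a toy square function instead of expm so the gradient formula is unambiguous; the class names are illustrative only. It shows gradcheck accepting both a bare-tensor and a 1-tuple return from backward():

import torch
from torch.autograd import Function


class SquareBare(Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_output  # bare tensor return


class SquareTuple(Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return (2 * x * grad_output,)  # 1-tuple return


x = torch.randn(4, dtype=torch.float64, requires_grad=True)
assert torch.autograd.gradcheck(SquareBare.apply, (x,))
assert torch.autograd.gradcheck(SquareTuple.apply, (x,))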


✏️ Learnings added
Learnt from: CompRhys
PR: Radical-AI/torch-sim#242
File: torch_sim/math.py:509-509
Timestamp: 2025-09-02T20:53:33.891Z
Learning: PyTorch autograd Function backward() methods can return either bare tensors or tuples for single-input functions. Both approaches work correctly. The documentation states backward should "return as many tensors, as there were inputs to forward()" but doesn't mandate tuple format. However, returning tuples is considered best practice for consistency and explicitness.

Learnt from: CompRhys
PR: Radical-AI/torch-sim#242
File: torch_sim/math.py:509-509
Timestamp: 2025-09-02T20:53:33.891Z
Learning: PyTorch autograd Function backward() methods can return either bare tensors or tuples for single-input functions - both approaches work correctly and pass gradcheck. The official documentation requires returning "as many tensors, as there were inputs" but doesn't mandate tuple format. However, returning tuples is considered best practice for consistency and explicitness, even though bare tensor returns function properly.

Note: Learnings are effective only in the context of similar code segments. To apply general coding guidelines, please configure review instructions. You can manage existing learnings in the app.

Collaborator Author

@curtischong curtischong Sep 4, 2025

I tried it and not returning a tuple is fine. https://chatgpt.com/share/68b8ede4-5628-800e-ae38-db7056df865c

Their docs don't mention that it's a tuple: https://docs.pytorch.org/docs/stable/generated/torch.autograd.Function.backward.html?utm_source=chatgpt.com

However, this ChatGPT thread was adamant that a tuple is required (despite running an example itself and seeing that there's no error). It's definitely a problem in their pretraining. https://chatgpt.com/share/68b8ee2b-c424-800e-81b5-621d7e75ee01.


Seems like the humans are having a chat. I'll hop back into my burrow for now. If you need me again, just tag @coderabbitai in a new comment, and I'll come hopping out!

@curtischong curtischong force-pushed the add-mypy branch 5 times, most recently from c23f285 to ee412da on September 4, 2025 01:23
mypy.ini Outdated
Member

can this not live in pyproject.toml ?

Collaborator Author

if you prefer it to be there I can move it

Member

fewer random files is cleaner

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (4)
torch_sim/math.py (4)

666-672: Fix boolean-of-tensor conditional and unsafe max() with tensors (case1b).

Current use of abs()/max() with tensors will raise at runtime.

-    # For numerical stability, scale appropriately
-    if abs(lambda_val) > 1:
+    # For numerical stability, scale appropriately
+    lambda_abs = torch.abs(lambda_val).item()
+    if lambda_abs > 1:
         scaled_T_minus_lambdaI = T_minus_lambdaI / lambda_val
         return torch.log(lambda_val) * Identity + scaled_T_minus_lambdaI
-    # Alternative computation for small lambda
-    return torch.log(lambda_val) * Identity + T_minus_lambdaI / max(lambda_val, num_tol)  # type: ignore[call-overload]
+    # Alternative computation for small lambda
+    denom = lambda_val if lambda_abs > num_tol else lambda_val.new_tensor(num_tol)
+    return torch.log(lambda_val) * Identity + T_minus_lambdaI / denom

731-733: Fix boolean-of-tensor conditional (case2a).

-    if torch.abs(lambda_minus_mu) < num_tol:
+    if torch.abs(lambda_minus_mu).item() < num_tol:
         raise ValueError("λ and μ are too close, computation may be unstable")

773-778: Fix boolean-of-tensor conditionals (case2b).

-    if torch.abs(lambda_minus_mu) < num_tol:
+    if torch.abs(lambda_minus_mu).item() < num_tol:
         raise ValueError("λ and μ are too close, computation may be unstable")
 
-    if torch.abs(lambda_val) < num_tol:
+    if torch.abs(lambda_val).item() < num_tol:
         raise ValueError("λ is too close to zero, computation may be unstable")

810-816: Case3: docstring default mismatch and robust min check.

Update default in docstring and avoid built-in min() on tensors; remove type: ignore.

-        num_tol: Numerical tolerance for stability checks, default=1e-6
+        num_tol: Numerical tolerance for stability checks, default=1e-16
@@
-    if (
-        min(torch.abs(lambda_val - mu), torch.abs(lambda_val - nu), torch.abs(mu - nu))  # type: ignore[call-overload]
-        < num_tol
-    ):
+    dists = torch.stack(
+        [torch.abs(lambda_val - mu), torch.abs(lambda_val - nu), torch.abs(mu - nu)]
+    )
+    if torch.min(dists).item() < num_tol:
         raise ValueError("Eigenvalues are too close, computation may be unstable")

Also applies to: 826-833
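
A tiny illustrative check of the stacked-distances guard (the eigenvalues below are made up, not taken from the PR):

import torch

num_tol = 1e-16
lambda_val = torch.tensor(1.0, dtype=torch.float64)
mu = torch.tensor(1.0 + 1e-18, dtype=torch.float64)  # numerically equal to lambda_val in float64
nu = torch.tensor(2.0, dtype=torch.float64)

dists = torch.stack(
    [torch.abs(lambda_val - mu), torch.abs(lambda_val - nu), torch.abs(mu - nu)]
)
# lambda_val and mu coincide, so the minimum pairwise distance falls below the tolerance
assert torch.min(dists).item() < num_tol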

♻️ Duplicate comments (2)
torch_sim/math.py (2)

529-531: Autograd backward returning a bare tensor is OK.

Matches prior discussion; no change required.


1043-1045: Use specific ignore code (and optional bincount refactor).

Swap to a targeted ignore; consider bincount for a cleaner reduction.

-output = torch.zeros(batch_indices.max() + 1, dtype=x.dtype, device=x.device)  # type: ignore[call-overload]
+output = torch.zeros(batch_indices.max() + 1, dtype=x.dtype, device=x.device)  # type: ignore[arg-type]
 output.scatter_add_(dim=0, index=batch_indices, src=(x * y).sum(dim=1))

Alternative:

-output = torch.zeros(batch_indices.max() + 1, dtype=x.dtype, device=x.device)  # type: ignore[call-overload]
-output.scatter_add_(dim=0, index=batch_indices, src=(x * y).sum(dim=1))
+weights = (x * y).sum(dim=1)
+output = torch.bincount(batch_indices, weights=weights, minlength=int(batch_indices.max().item()) + 1).to(dtype=x.dtype)
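
A quick self-contained check (sizes and batch indices are made up) that the bincount form matches the scatter_add_ reduction:

import torch

x = torch.randn(6, 3, dtype=torch.float64)
y = torch.randn(6, 3, dtype=torch.float64)
batch_indices = torch.tensor([0, 0, 1, 1, 2, 2])

weights = (x * y).sum(dim=1)

out_scatter = torch.zeros(int(batch_indices.max()) + 1, dtype=x.dtype)
out_scatter.scatter_add_(dim=0, index=batch_indices, src=weights)

out_bincount = torch.bincount(
    batch_indices, weights=weights, minlength=int(batch_indices.max()) + 1
).to(dtype=x.dtype)

torch.testing.assert_close(out_scatter, out_bincount)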
🧹 Nitpick comments (2)
pyproject.toml (1)

125-133: Optional: add numpy plugin and pin python_version.

Helps with np.ndarray annotations and stable inference.

 [tool.mypy]
+python_version = "3.11"
+plugins = ["numpy.typing.mypy_plugin"]
 warn_unused_configs = true
 ignore_missing_imports = true
 check_untyped_defs = true
 explicit_package_bases = true
 warn_unreachable = true
 warn_redundant_casts = true
 warn_unused_ignores = true
torch_sim/math.py (1)

630-645: Docstring type wording: say “0-D torch.Tensor (real or complex)”, not “complex number”.

Aligns docs with the typed signatures you introduced.

Also applies to: 648-658, 675-685, 707-719, 745-759

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 3fa4c84 and a7185ed.

📒 Files selected for processing (2)
  • pyproject.toml (1 hunks)
  • torch_sim/math.py (19 hunks)
🧰 Additional context used
🧠 Learnings (4)
📚 Learning: 2025-09-02T21:21:26.191Z
Learnt from: CompRhys
PR: Radical-AI/torch-sim#242
File: torch_sim/math.py:776-781
Timestamp: 2025-09-02T21:21:26.191Z
Learning: When analyzing code for potential bugs, I need to be extremely careful to read the actual code structure correctly. I incorrectly identified a non-existent "double comparison" bug in _matrix_log_case3 function in torch_sim/math.py when the code was actually correct - it properly checks if the minimum distance between eigenvalues is below a tolerance threshold.

Applied to files:

  • torch_sim/math.py
📚 Learning: 2025-08-31T11:15:22.636Z
Learnt from: curtischong
PR: Radical-AI/torch-sim#242
File: torch_sim/math.py:1022-1024
Timestamp: 2025-08-31T11:15:22.636Z
Learning: In PyTorch code, avoid using .item() on tensors when performance is critical as it causes thread synchronization between GPU and CPU, breaking parallelism. Use more specific type ignore comments like `# type: ignore[arg-type]` instead of generic `# type: ignore` to satisfy linting rules while maintaining performance.

Applied to files:

  • torch_sim/math.py
📚 Learning: 2025-09-02T20:53:33.891Z
Learnt from: CompRhys
PR: Radical-AI/torch-sim#242
File: torch_sim/math.py:509-509
Timestamp: 2025-09-02T20:53:33.891Z
Learning: PyTorch autograd Function backward() methods can return either bare tensors or tuples for single-input functions - both approaches work correctly and pass gradcheck. The official documentation requires returning "as many tensors, as there were inputs" but doesn't mandate tuple format. However, returning tuples is considered best practice for consistency and explicitness, even though bare tensor returns function properly.

Applied to files:

  • torch_sim/math.py
📚 Learning: 2025-09-02T20:53:33.891Z
Learnt from: CompRhys
PR: Radical-AI/torch-sim#242
File: torch_sim/math.py:509-509
Timestamp: 2025-09-02T20:53:33.891Z
Learning: PyTorch autograd Function backward() methods can return either bare tensors or tuples for single-input functions. Both approaches work correctly. The documentation states backward should "return as many tensors, as there were inputs to forward()" but doesn't mandate tuple format. However, returning tuples is considered best practice for consistency and explicitness.

Applied to files:

  • torch_sim/math.py
🧬 Code graph analysis (1)
torch_sim/math.py (1)
tests/test_math.py (5)
  • test_expm_frechet (140-155)
  • test_expm_frechet (266-284)
  • TestExpmFrechet (21-134)
  • test_medium_matrix (246-260)
  • test_expm_frechet (24-40)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (43)
  • GitHub Check: test-examples (examples/tutorials/high_level_tutorial.py)
  • GitHub Check: test-examples (examples/tutorials/autobatching_tutorial.py)
  • GitHub Check: test-examples (examples/scripts/5_Workflow/5.3_Elastic.py)
  • GitHub Check: test-examples (examples/scripts/7_Others/7.1_Soft_sphere_autograd.py)
  • GitHub Check: test-examples (examples/tutorials/low_level_tutorial.py)
  • GitHub Check: test-examples (examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py)
  • GitHub Check: test-examples (examples/scripts/3_Dynamics/3.3_MACE_NVE_cueq.py)
  • GitHub Check: test-examples (examples/tutorials/reporting_tutorial.py)
  • GitHub Check: test-examples (examples/scripts/7_Others/7.4_Velocity_AutoCorrelation.py)
  • GitHub Check: test-examples (examples/scripts/3_Dynamics/3.6_MACE_NVT_Nose_Hoover_temp_profile.py)
  • GitHub Check: test-examples (examples/scripts/5_Workflow/5.1_a2c_silicon_batched.py)
  • GitHub Check: test-examples (examples/scripts/5_Workflow/5.2_In_Flight_WBM.py)
  • GitHub Check: test-examples (examples/scripts/3_Dynamics/3.12_MACE_NPT_Langevin.py)
  • GitHub Check: test-examples (examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py)
  • GitHub Check: test-examples (examples/scripts/3_Dynamics/3.5_MACE_NVT_Nose_Hoover.py)
  • GitHub Check: test-examples (examples/scripts/3_Dynamics/3.2_MACE_NVE.py)
  • GitHub Check: test-examples (examples/scripts/3_Dynamics/3.13_MACE_NVE_non_pbc.py)
  • GitHub Check: test-examples (examples/scripts/1_Introduction/1.1_Lennard_Jones.py)
  • GitHub Check: test-examples (examples/scripts/3_Dynamics/3.8_MACE_NPT_Nose_Hoover.py)
  • GitHub Check: test-examples (examples/scripts/4_High_level_api/4.1_high_level_api.py)
  • GitHub Check: test-examples (examples/scripts/6_Phonons/6.1_Phonons_MACE.py)
  • GitHub Check: test-examples (examples/scripts/4_High_level_api/4.2_auto_batching_api.py)
  • GitHub Check: test-examples (examples/scripts/2_Structural_optimization/2.4_MACE_FIRE.py)
  • GitHub Check: test-examples (examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py)
  • GitHub Check: test-examples (examples/scripts/6_Phonons/6.3_Conductivity_MACE.py)
  • GitHub Check: test-model (macos-14, 3.12, lowest-direct, sevenn, tests/models/test_sevennet.py)
  • GitHub Check: test-model (macos-14, 3.12, lowest-direct, fairchem, tests/models/test_fairchem.py)
  • GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, sevenn, tests/models/test_sevennet.py)
  • GitHub Check: test-model (macos-14, 3.12, lowest-direct, mace, tests/test_elastic.py)
  • GitHub Check: test-model (macos-14, 3.12, lowest-direct, mace, tests/models/test_mace.py)
  • GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, metatomic, tests/models/test_metatomic.py)
  • GitHub Check: test-model (ubuntu-latest, 3.11, highest, sevenn, tests/models/test_sevennet.py)
  • GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, mace, tests/test_optimizers_vs_ase.py)
  • GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, orb, tests/models/test_orb.py)
  • GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, mace, tests/test_elastic.py)
  • GitHub Check: test-model (ubuntu-latest, 3.11, highest, fairchem, tests/models/test_fairchem.py)
  • GitHub Check: test-model (ubuntu-latest, 3.11, highest, orb, tests/models/test_orb.py)
  • GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, fairchem, tests/models/test_fairchem.py)
  • GitHub Check: test-model (ubuntu-latest, 3.11, highest, mace, tests/test_elastic.py)
  • GitHub Check: test-core (ubuntu-latest, 3.11, highest)
  • GitHub Check: test-core (macos-14, 3.12, lowest-direct)
  • GitHub Check: test-core (ubuntu-latest, 3.12, lowest-direct)
  • GitHub Check: build-docs
🔇 Additional comments (3)
pyproject.toml (1)

125-133: MyPy block looks good.

Reasonable defaults for a first typing PR.

torch_sim/math.py (2)

33-58: Wrapper API is clean.

Nice split: derivative-only wrapper over the “compute both” helper.


144-170: Change of ell_table_61[0] to 0.

No functional impact given indices used (3,5,7,9,13), but worth confirming tests near the thresholds still pass.

Comment on lines 86 to 91
    # Convert inputs to torch tensors if they aren't already
    if not isinstance(A, torch.Tensor):
        A = torch.tensor(A, dtype=torch.float64)
    if not isinstance(E, torch.Tensor):
        E = torch.tensor(E, dtype=torch.float64)


⚠️ Potential issue

Bug: device/dtype mismatch when mixing torch.Tensor and np.ndarray inputs.

If A is a CUDA tensor and E is an np.ndarray (or a CPU tensor), matmuls will error. Normalize dtype/device once after conversion.

-    if not isinstance(A, torch.Tensor):
-        A = torch.tensor(A, dtype=torch.float64)
-    if not isinstance(E, torch.Tensor):
-        E = torch.tensor(E, dtype=torch.float64)
+    if not isinstance(A, torch.Tensor):
+        A = torch.tensor(A, dtype=torch.float64)
+    if not isinstance(E, torch.Tensor):
+        E = torch.tensor(E, dtype=torch.float64)
+    # unify dtype/device (anchor on A)
+    if E.dtype != A.dtype:
+        E = E.to(dtype=A.dtype)
+    if E.device != A.device:
+        E = E.to(device=A.device)

Also applies to: 92-97, 105-115
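
A minimal sketch of the normalization, written as a standalone helper; the name _as_tensor_like is illustrative and not part of torch_sim's API:

import numpy as np
import torch


def _as_tensor_like(x, ref: torch.Tensor) -> torch.Tensor:
    """Convert x to a tensor matching ref's dtype and device."""
    if not isinstance(x, torch.Tensor):
        x = torch.as_tensor(x, dtype=torch.float64)
    return x.to(dtype=ref.dtype, device=ref.device)


A = torch.randn(3, 3, dtype=torch.float64)  # could equally be a CUDA tensor
E = np.random.rand(3, 3)  # NumPy input mixed with a torch tensor
E_t = _as_tensor_like(E, A)
assert E_t.dtype == A.dtype and E_t.device == A.device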

Comment on lines 699 to 703
     term1 = torch.log(lambda_val) * Identity
-    term2 = T_minus_lambdaI / max(lambda_val, num_tol)
-    term3 = T_minus_lambdaI_squared / max(2 * lambda_squared, num_tol)
+    term2 = T_minus_lambdaI / max(lambda_val, num_tol)  # type: ignore[call-overload]
+    term3 = T_minus_lambdaI_squared / max(2 * lambda_squared, num_tol)  # type: ignore[call-overload]

     return term1 + term2 - term3

⚠️ Potential issue

Fix boolean-of-tensor and max() usage (case1c).

Avoid Python max() with tensors and tensor-in-if.

-    term1 = torch.log(lambda_val) * Identity
-    term2 = T_minus_lambdaI / max(lambda_val, num_tol)  # type: ignore[call-overload]
-    term3 = T_minus_lambdaI_squared / max(2 * lambda_squared, num_tol)  # type: ignore[call-overload]
+    term1 = torch.log(lambda_val) * Identity
+    lambda_abs = torch.abs(lambda_val).item()
+    denom2 = lambda_val if lambda_abs > num_tol else lambda_val.new_tensor(num_tol)
+    denom3_val = 2 * lambda_squared
+    denom3 = denom3_val if torch.abs(denom3_val).item() > num_tol else denom3_val.new_tensor(num_tol)
+    term2 = T_minus_lambdaI / denom2
+    term3 = T_minus_lambdaI_squared / denom3
🤖 Prompt for AI Agents
In torch_sim/math.py around lines 699 to 703, the code uses Python max() with
tensors (and relies on truthiness of tensors), which raises "boolean value of
Tensor with more than one value is ambiguous" and is invalid; replace
max(lambda_val, num_tol) and max(2 * lambda_squared, num_tol) with tensor-safe
operations: convert num_tol to a tensor on the same device/dtype as lambda_val
(e.g. torch.tensor(num_tol, dtype=lambda_val.dtype, device=lambda_val.device))
and then use torch.maximum(...) (or torch.clamp_min) to compute denominators,
removing the type: ignore comments and any tensor-in-if usage so the lines
become computed with torch.maximum/torch.clamp_min and the final return uses
those tensor-safe denominators.
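
A minimal sketch of those tensor-safe denominators, assuming lambda_val and lambda_squared are real 0-D tensors (the names only mirror the snippet above for illustration):

import torch

num_tol = 1e-16
lambda_val = torch.tensor(1e-20, dtype=torch.float64)  # deliberately tiny
lambda_squared = lambda_val**2

tol = torch.tensor(num_tol, dtype=lambda_val.dtype, device=lambda_val.device)
denom2 = torch.clamp_min(lambda_val, num_tol)    # equivalent to torch.maximum(lambda_val, tol) for real inputs
denom3 = torch.maximum(2 * lambda_squared, tol)

# both denominators are floored at num_tol instead of dividing by a value near zero
assert denom2.item() == num_tol and denom3.item() == num_tol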

@curtischong curtischong merged commit 390c071 into main Sep 5, 2025
92 of 93 checks passed
@curtischong curtischong deleted the add-mypy branch September 5, 2025 21:59
