Accelerated ts.optimize by batching Frechet Cell Filter #439
falletta wants to merge 9 commits into TorchSim:main
Conversation
Could we get some tests verifying identical numerical behavior of the old and new versions? They can be deleted before merging, once we get rid of the unbatched version.
@orionarcher I added the requested tests.
torch_sim/math.py (outdated)

```python
num_tol = 1e-16 if dtype == torch.float64 else 1e-8
batched = T.dim() == 3

if batched:
```
Why support both batched and unbatched versions?
I now removed all unbatched code
tests/test_math_frechet.py (outdated)

```python
class TestExpmFrechet:
    """Tests for expm_frechet against scipy.linalg.expm_frechet."""

    def test_small_matrix(self):
```
Are we testing the batched or unbatched versions here?
I incorporated only the batched tests, in `test_math.py`.
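A hedged sketch of what such a parity test can look like, assuming `torch_sim.math.expm_frechet` accepts `(B, 3, 3)` batches and returns `(expm(A), L(A, E))`; the exact signature in the PR may differ:

```python
import numpy as np
import torch
from scipy.linalg import expm_frechet as scipy_expm_frechet

from torch_sim.math import expm_frechet  # assumed batched signature


def test_expm_frechet_matches_scipy():
    rng = np.random.default_rng(0)
    A = rng.standard_normal((4, 3, 3))
    E = rng.standard_normal((4, 3, 3))

    expm_t, L_t = expm_frechet(torch.from_numpy(A), torch.from_numpy(E))

    # Compare each batch element against scipy's unbatched reference.
    for i in range(4):
        expm_ref, L_ref = scipy_expm_frechet(A[i], E[i])
        np.testing.assert_allclose(expm_t[i].numpy(), expm_ref, atol=1e-12)
        np.testing.assert_allclose(L_t[i].numpy(), L_ref, atol=1e-12)
```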
Haven't looked carefully at the implementation, but I would potentially support having two separate functions for the batched (B, 3, 3) and unbatched (3, 3) algorithms. This would also prevent graph breaks in the future and be easier to read. In practice, a state.cell is always (B, 3, 3), potentially with B=1, so we would always use the batched version anyway.
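A minimal sketch of the dispatch pattern being discussed, assuming a batched-only kernel; `as_batched` is a hypothetical helper, not part of the PR:

```python
import torch


def as_batched(cell: torch.Tensor) -> torch.Tensor:
    """Promote an unbatched (3, 3) matrix to a (1, 3, 3) batch.

    With this at the call site, a single batched code path covers the
    unbatched case as B = 1, avoiding a dim()-dependent branch inside
    the kernel (and the graph breaks such a branch can cause).
    """
    return cell.unsqueeze(0) if cell.dim() == 2 else cell
```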
@orionarcher I removed all unbatched and unused code while preserving the new performance speedups; see the PR description for a detailed list of changes. @thomasloux It's indeed a good point, but for now it's probably better to keep things clean and stick to the batched implementation only. Keeping only the batched implementation also lets us remove quite a few lines of dead code.
Update Summary
1. torch_sim/math.py
Removed unbatched/legacy code:
- `expm_frechet_block_enlarge` (helper function for the block-enlargement method)
- `_diff_pade3`, `_diff_pade5`, `_diff_pade7`, `_diff_pade9` (Padé approximation helpers)
- `expm_frechet_algo_64` (original algorithm implementation)
- `matrix_exp` (custom matrix exponential function)
- `vec`, `expm_frechet_kronform` (Kronecker form helpers)
- `expm_cond` (condition number estimation)
- `class expm` (autograd Function class)
- `_is_valid_matrix`, `_determine_eigenvalue_case` (unbatched helpers)

Refactored `expm_frechet`: single batched implementation; the method selection ("SPS" or "blockEnlarge") was removed (a reference sketch follows).

Refactored `matrix_log_33`: split into `_ensure_batched`, `_determine_matrix_log_cases`, and `_process_matrix_log_case` helpers.
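For reference, a minimal batched sketch of the quantity `expm_frechet` computes, via the classic block-enlargement identity and `torch.linalg.matrix_exp`. This is not the PR's kernel (which, per the summary above, drops the "blockEnlarge" path), just a compact reference implementation to check against:

```python
import torch


def expm_frechet_reference(A: torch.Tensor, E: torch.Tensor):
    """Reference Fréchet derivative of expm for (B, 3, 3) batches.

    Uses the identity expm([[A, E], [0, A]]) = [[expm(A), L(A, E)],
    [0, expm(A)]], where L(A, E) is the Fréchet derivative of the
    matrix exponential at A in direction E.
    """
    B = A.shape[0]
    M = torch.zeros(B, 6, 6, dtype=A.dtype, device=A.device)
    M[:, :3, :3] = A
    M[:, :3, 3:] = E
    M[:, 3:, 3:] = A
    expM = torch.linalg.matrix_exp(M)
    return expM[:, :3, :3], expM[:, :3, 3:]
```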
2. torch_sim/optimizers/cell_filters.py

Vectorized `compute_cell_forces`: `expm_frechet(A_batch, E_batch)` is now called once, with all `n_systems * 9` matrices batched together (see the sketch below).
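A minimal sketch of how such a layout can be built, reusing the `expm_frechet_reference` sketch above; the name `A` and the use of the 9 elementary direction matrices illustrate the batching pattern, not the PR's exact code:

```python
import torch

n_systems = 32
A = torch.randn(n_systems, 3, 3, dtype=torch.float64)

# The 9 elementary (3, 3) direction matrices: each row of eye(9),
# reshaped, has a single nonzero entry.
E_dirs = torch.eye(9, dtype=torch.float64).reshape(9, 3, 3)

# One (n_systems * 9, 3, 3) batch instead of a double Python loop.
A_batch = A.repeat_interleave(9, dim=0)   # A_0 x9, A_1 x9, ...
E_batch = E_dirs.repeat(n_systems, 1, 1)  # E_0..E_8 per system

expm_A, L = expm_frechet_reference(A_batch, E_batch)  # single batched call
L = L.reshape(n_systems, 9, 3, 3)         # per-system derivatives
```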
3. tests/test_math.py

Refactored tests:
- `TestExpmFrechet`: `test_expm_frechet`, `test_small_norm_expm_frechet`, `test_fuzz`
- `TestExpmFrechetTorch`: `test_expm_frechet`, `test_fuzz`

All updated to use 3x3 matrices and simplified by removing the `method` parameter testing. Fuzz tests were streamlined with fewer iterations.

Removed tests:
- `test_problematic_matrix`, `test_medium_matrix` (both numpy and torch versions)
- `TestExpmFrechetTorchGrad` class

Tests comparing computation methods and large-matrix performance no longer apply to the 3x3-specialized implementation.
Added tests:
- `TestExpmFrechet.test_large_norm_matrices`: tests scaling behavior for larger-norm matrices
- `TestLogM33.test_batched_positive_definite`: tests batched matrix logarithm with round-trip verification
- `TestFrechetCellFilterIntegration`: integration tests for the cell filter pipeline
- `test_wrap_positions_*`: tests for the new `wrap_positions` property

Results
The figure below shows the speedup achieved for a 10-step atomic relaxation. The test is performed on an 8-atom cubic supercell of MgO using the `mace-mpa` model. Prior results are shown in blue, new results in red. The speedup is calculated as `speedup (%) = (baseline_time / current_time − 1) × 100`. We observe speedups of up to 564% for large batches.


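To illustrate the formula with made-up numbers (not measurements): if a batch took 6.64 s at baseline and takes 1.00 s now, the speedup is (6.64 / 1.00 − 1) × 100 = 564%.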