
Fix GPU memory leak from missing .detach() in model wrappers#491

Merged
CompRhys merged 4 commits into TorchSim:main from reillyosadchey:fix/detach-model-outputs-memory-leak
Mar 4, 2026

Conversation

@reillyosadchey
Contributor

Summary

  • Several model wrappers return tensors still attached to the autograd computation graph, causing the entire forward-pass graph to be retained in GPU memory across simulation steps
  • Added .detach() calls to model outputs in fairchem, orb, metatomic, fairchem_legacy, and graphpes_framework
  • The fairchem_legacy wrapper also used .clone() without .detach() on inputs, retaining graph references via self.data_object

Details

| Model | Issue | Severity |
| --- | --- | --- |
| fairchem.py | energy/forces/stress not detached | High |
| orb.py | predictions + conservative forces/stress not detached | High |
| metatomic.py | forces/stress not detached (energy was correct) | Medium |
| fairchem_legacy.py | `.clone()` without `.detach()` on inputs | Medium |
| graphpes_framework.py | external library predictions not detached | Low |

Without .detach(), .to(dtype=...) and .view() create new graph nodes that still reference the full forward pass. Integrators store these into state.energy/state.forces/state.stress, keeping the graph alive until the next step overwrites them — and longer if trajectory loggers or MC routines .clone() the state.
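A minimal sketch of this mechanism (hypothetical, not TorchSim code): differentiable ops such as `.to(dtype=...)` extend the autograd graph, while `.detach()` severs it.

```python
import torch

# Toy "forward pass": energy is attached to the autograd graph.
x = torch.randn(4, requires_grad=True)
energy = (x ** 2).sum()

# .to() creates a *new graph node*, not a graph-free copy,
# so the whole forward pass stays reachable through it.
still_attached = energy.to(dtype=torch.float64)
assert still_attached.grad_fn is not None

# Detaching first yields a tensor with no link back to the graph.
safe = energy.detach().to(dtype=torch.float64)
assert safe.grad_fn is None and not safe.requires_grad
```

Storing `still_attached` into long-lived state keeps every intermediate of the forward pass alive; storing `safe` does not.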

Test plan

  • Existing tests pass (no functional change — detach only affects graph retention)
  • Verify GPU memory stays flat over long simulations with FairChem/Orb models
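Graph retention can also be checked without a GPU. A hypothetical CPU-only sketch (not part of the test suite, names invented): hold a `weakref` to a forward-pass intermediate and see whether storing the output keeps it alive.

```python
import gc
import weakref

import torch


def step(detach: bool):
    """One toy 'simulation step': forward pass, then store the output."""
    x = torch.randn(100, requires_grad=True)
    hidden = x * 2                      # forward-pass intermediate
    ref = weakref.ref(hidden)
    # (hidden * hidden) saves `hidden` for backward, so the graph
    # holds a reference to it until the graph itself is freed.
    forces = (hidden * hidden).sum()
    state = forces.detach() if detach else forces  # what an integrator stores
    del x, hidden, forces
    gc.collect()
    return state, ref


state, ref = step(detach=False)
assert ref() is not None   # stored output keeps the graph (and intermediates) alive

state, ref = step(detach=True)
assert ref() is None       # detached output lets the graph be freed
```

This relies on PyTorch keeping the Python tensor object alive while autograd's saved-tensor machinery references it (the case on recent PyTorch versions).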

Several model wrappers were returning tensors still attached to the
computation graph, causing the entire forward-pass graph to be retained
in memory across simulation steps.

- fairchem: detach energy, forces, stress predictions
- orb: detach prediction outputs and conservative forces/stress
- metatomic: detach forces and stress (energy was already detached)
- fairchem_legacy: use detach().clone() on inputs to prevent graph
  retention via self.data_object
- graphpes_framework: detach predictions from external library
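Taken together, the per-wrapper bullets reduce to one pattern. A hypothetical minimal wrapper (names invented, not the actual TorchSim classes) showing both fixes:

```python
import torch


class WrapperSketch:
    """Illustrative model wrapper, not TorchSim code."""

    def __init__(self, model):
        self.model = model
        self.data_object = None

    def forward(self, positions: torch.Tensor) -> dict[str, torch.Tensor]:
        # fairchem_legacy-style fix: cache inputs with detach().clone()
        # so self.data_object never pins the autograd graph.
        self.data_object = positions.detach().clone()

        # Model returns e.g. {"energy": ..., "forces": ...} attached to the graph.
        preds = self.model(positions)

        # Sever every output from the graph in one place at the return.
        return {k: v.detach() for k, v in preds.items()}
```

Detaching once at the return (rather than inline at each assignment) is the pattern the review below converges on.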
@CompRhys
Member

CompRhys commented Mar 4, 2026

Thank you for the PR; I will merge it once tests and lint pass.

It is worth noting that we are attempting to move all the models you fixed to an external posture, so these fixes may also need to be applied upstream.

In the pair-potential backend refactor I am doing, I also included a retain-graph argument to allow for a differentiable-simulation tutorial. I think that's a very niche use, so I wouldn't add the flag here, but just an FYI.

```diff
  atomic_graph = state_to_atomic_graph(state, cutoff)
- return self._gp_model.predict(atomic_graph, self._properties)  # type: ignore[return-value]
+ preds = self._gp_model.predict(atomic_graph, self._properties)  # ty: ignore[call-non-callable]
+ return {k: v.detach() for k, v in preds.items()}
```
Member

It feels cleaner to me to detach everything in one line at the end like this. Could you update all the models to follow this pattern?

Member

If it's already detached then the op is idempotent, so no harm.

Contributor Author

Good call


Move .detach() calls to a single return statement in each model's
forward method instead of detaching inline at each assignment.
@CompRhys CompRhys enabled auto-merge (squash) March 4, 2026 22:24
@CompRhys
Member

CompRhys commented Mar 4, 2026

Thanks!

@CompRhys CompRhys disabled auto-merge March 4, 2026 22:59
@CompRhys CompRhys merged commit 7f85ec4 into TorchSim:main Mar 4, 2026
64 of 66 checks passed