Fix GPU memory leak from missing .detach() in model wrappers #491
Conversation
Several model wrappers were returning tensors still attached to the computation graph, causing the entire forward-pass graph to be retained in memory across simulation steps.

- fairchem: detach energy, forces, and stress predictions
- orb: detach prediction outputs and conservative forces/stress
- metatomic: detach forces and stress (energy was already detached)
- fairchem_legacy: use detach().clone() on inputs to prevent graph retention via self.data_object
- graphpes_framework: detach predictions from the external library
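The retention mechanism behind all five fixes can be sketched without torch: an output that still references its "graph" pins the whole graph in memory for as long as the output is stored, while a detached copy lets it be collected. The `Graph`/`Tensor` classes below are toy stand-ins for illustration, not the real wrappers or torch types.

```python
import gc
import weakref

class Graph:
    """Stand-in for an autograd graph: large, and only needed for backward."""
    def __init__(self):
        self.buffers = [0.0] * 100_000  # pretend activation storage

class Tensor:
    """Toy tensor that, like an attached torch output, references its graph."""
    def __init__(self, value, graph=None):
        self.value = value
        self.grad_fn = graph  # non-None => keeps the graph alive

    def detach(self):
        return Tensor(self.value, graph=None)  # same data, no graph reference

def forward():
    g = Graph()
    return Tensor(1.0, graph=g), weakref.ref(g)

# Leaky: storing the attached output pins the whole graph in memory.
out, graph_ref = forward()
state_forces = out
gc.collect()
assert graph_ref() is not None  # graph retained across the "step"

# Fixed: detach before storing; the graph becomes collectable.
state_forces = out.detach()
del out
gc.collect()
assert graph_ref() is None  # graph freed
```

In the real wrappers the pinned object is the autograd graph of the entire forward pass, which is why the leak grows with model size rather than output size.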
Thank you for the PR; I will merge it once tests and lint pass. It is worth noting that we are attempting to move all the models you fixed this for to an external posture, so these fixes may also need to be applied upstream. In the pair-potential backend refactor I am doing, I also included a retain-graph argument to allow for the diff-sim tutorial. I think that's a very niche use, so I wouldn't add the flag here, but just an FYI.
```diff
  atomic_graph = state_to_atomic_graph(state, cutoff)
- return self._gp_model.predict(atomic_graph, self._properties)  # type: ignore[return-value]
+ preds = self._gp_model.predict(atomic_graph, self._properties)  # ty: ignore[call-non-callable]
+ return {k: v.detach() for k, v in preds.items()}
```
It feels cleaner to me to detach everything like this, all in one line at the end. Could you update all the models to follow this pattern?
If it's already detached, then the op is idempotent and so there is no harm.
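The idempotence point can be seen with a toy stand-in (a sketched `Tensor` class, not torch): detaching an already-detached tensor just yields another graph-free view of the same data.

```python
class Tensor:
    """Toy stand-in: grad_fn is None once detached."""
    def __init__(self, value, grad_fn=None):
        self.value = value
        self.grad_fn = grad_fn

    def detach(self):
        # Returns a graph-free view; calling it again changes nothing.
        return Tensor(self.value, grad_fn=None)

t = Tensor(2.5, grad_fn=object())  # pretend this is attached to a graph
once = t.detach()
twice = once.detach()
assert once.grad_fn is None and twice.grad_fn is None
assert once.value == twice.value == 2.5
```

torch behaves the same way: `Tensor.detach()` on an already-detached tensor returns a view that shares storage and carries no graph, so blanket-detaching every output at the return is safe.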
Move .detach() calls to a single return statement in each model's forward method instead of detaching inline at each assignment.
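The requested pattern, sketched with hypothetical names (`FakeModel.predict` stands in for any backend call; none of these classes are the real wrappers):

```python
class FakeTensor:
    """Toy tensor: grad_fn is None once detached."""
    def __init__(self, value, grad_fn=None):
        self.value, self.grad_fn = value, grad_fn

    def detach(self):
        return FakeTensor(self.value, None)

class FakeModel:
    """Pretend backend whose predictions arrive attached to a graph."""
    def predict(self, atomic_graph, properties):
        graph = object()  # pretend autograd graph
        return {p: FakeTensor(0.0, graph) for p in properties}

class WrapperSketch:
    """Illustrative wrapper: detach everything once, at the return."""
    def __init__(self, model):
        self._model = model

    def forward(self, atomic_graph, properties):
        preds = self._model.predict(atomic_graph, properties)
        # Single detach point, instead of inline .detach() per assignment.
        return {k: v.detach() for k, v in preds.items()}

out = WrapperSketch(FakeModel()).forward(None, ["energy", "forces"])
assert all(v.grad_fn is None for v in out.values())
```

Keeping one detach point at the return also makes it harder for a newly added output key to slip through un-detached.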
Thanks!
Summary
- Added `.detach()` calls to model outputs in `fairchem`, `orb`, `metatomic`, `fairchem_legacy`, and `graphpes_framework`
- The `fairchem_legacy` wrapper also used `.clone()` without `.detach()` on inputs, retaining graph references via `self.data_object`

Details
- `fairchem.py`
- `orb.py`
- `metatomic.py`
- `fairchem_legacy.py`: `.clone()` without `.detach()` on inputs
- `graphpes_framework.py`

Without `.detach()`, `.to(dtype=...)` and `.view()` create new graph nodes that still reference the full forward pass. Integrators store these into `state.energy`/`state.forces`/`state.stress`, keeping the graph alive until the next step overwrites them, and longer if trajectory loggers or MC routines `.clone()` the state.

Test plan
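One way a leak like this can be regression-tested, sketched with a toy stand-in (a real test would assert `grad_fn is None` on actual torch outputs, or watch GPU memory across steps; `check_outputs_detached` and `FakeTensor` are hypothetical names):

```python
class FakeTensor:
    """Toy stand-in: grad_fn is non-None while attached to a graph."""
    def __init__(self, grad_fn=None):
        self.grad_fn = grad_fn

def check_outputs_detached(preds):
    """Fail if any prediction is still attached to a computation graph."""
    attached = [k for k, v in preds.items() if v.grad_fn is not None]
    assert not attached, f"still attached to graph: {attached}"

# Detached outputs pass the check.
check_outputs_detached({"energy": FakeTensor(), "forces": FakeTensor()})
```

Such a check would catch a regression immediately if a wrapper started returning attached tensors again.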