diff --git a/thunder/core/rematerialization.py b/thunder/core/rematerialization.py index 9333117ba..56a1638b2 100644 --- a/thunder/core/rematerialization.py +++ b/thunder/core/rematerialization.py @@ -12,7 +12,7 @@ from thunder.core import prims, utils from thunder.core.baseutils import BoundSymbolInterface, ProxyInterface from thunder.core.prims import PrimIDs -from thunder.core.proxies import TensorProxy, variableify +from thunder.core.proxies import TensorProxy, variableify, NumberProxy from thunder.core.pytree import tree_flatten, tree_unflatten from thunder.core.symbol import has_tags from thunder.core.trace import from_trace, TraceCtx, TraceProvenance @@ -332,6 +332,8 @@ def add_edge(src, dst, capacity): def get_weight(var): if isinstance(var, TensorProxy): return WEIGHT * var.dtype.bytes + elif isinstance(var, NumberProxy): + return 0.0 return WEIGHT def add_edges(var):