Skip to content

Commit

Permalink
fixing numberproxy trace mode check (#450)
Browse files Browse the repository at this point in the history
  • Loading branch information
jjsjann123 authored May 25, 2024
1 parent 3f9aedd commit 4c9a765
Showing 1 changed file with 16 additions and 6 deletions.
22 changes: 16 additions & 6 deletions thunder/core/proxies.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from thunder.core.trace import VariableInterface, get_tracectx, TraceCtx
from thunder.core.baseutils import ProxyInterface, NumberProxyInterface, TensorProxyInterface
import thunder.core.baseutils as baseutils
from thunder.core.langctxs import resolve_method
from thunder.core.langctxs import resolve_method, get_langctx
import thunder.core.devices as devices
import thunder.core.dtypes as dtypes

Expand Down Expand Up @@ -592,13 +592,18 @@ def known_value(self) -> bool:
# fn is the function to call if executing outside a language context
@staticmethod
def _elementwise_unary_helper(a, name, fn, type_promotion_kind=None):
trace: None | TraceCtx = get_tracectx()

vala = pyval(a)

if trace is None:
# Outside of a trace context, operations on NumberProxies are executed by the
# Python interpreter
trace: None | TraceCtx = get_tracectx()
lang: None | LangCtx = None
try:
lang = get_langctx()
except LookupError:
pass
if trace is None or lang is None:
# Outside of a trace or language context, operations on NumberProxies are
# executed by the Python interpreter
baseutils.check(
vala is not None,
lambda: f"Trying to {name} a number with an unknown value",
Expand Down Expand Up @@ -649,7 +654,12 @@ def _elementwise_binary_helper(a, b, name, fn, type_promotion_kind=None):
valb = pyval(b) if isinstance(b, NumberProxy) else b

trace: None | TraceCtx = get_tracectx()
if trace is None:
lang: None | LangCtx = None
try:
lang = get_langctx()
except LookupError:
pass
if trace is None or lang is None:
# Outside of a trace or language context, binary operations on NumberProxies are
# executed by the Python interpreter
baseutils.check(
Expand Down

0 comments on commit 4c9a765

Please sign in to comment.