Support nvfuser fd.add_output(output, alias_input) (#50)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Ivan Yashchuk <IvanYashchuk@users.noreply.github.com>
3 people committed Apr 9, 2024
1 parent 393828f commit 3a91b01
Showing 7 changed files with 165 additions and 11 deletions.
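For orientation, here is a minimal sketch of what this commit enables at the user level: an in-place copy prim that the nvFuser executor lowers to fd.add_output(output, alias_input=input). The sketch mirrors the new test below and is not part of the commit itself; it assumes the thunder.jit entry point and a CUDA device for the nvFuser path (on CPU the torch executor performs the same copy).

import torch
import thunder


def add_then_overwrite(x, y):
    z = x * y
    z = z + z
    z = x + z
    # copy_(copy_from, copy_to): write z back into x's storage
    return thunder.core.prims.copy_(z, x)


jfn = thunder.jit(add_then_overwrite)
a = torch.randn(4, 4, device="cuda")
b = torch.randn(4, 4, device="cuda")
out = jfn(a, b)  # `a` now holds x + 2 * x * y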
16 changes: 16 additions & 0 deletions thunder/core/prims.py
@@ -250,6 +250,7 @@ class PrimIDs(Enum):
BATCH_NORM = auto()
# Memory access methods
ITEM = auto()
COPY_ = auto()


class OpTags(Enum):
@@ -3607,3 +3608,18 @@ def check_type_device_shape(param, param_name):


batch_norm = make_prim(PrimIDs.BATCH_NORM, "batch_norm", meta=batch_norm_meta, tags=(OpTags.REDUCTION_OP,))


def copy__meta(
copy_from: TensorProxy,
copy_to: TensorProxy,
):
utils.check_type(copy_from, TensorProxy)
utils.check_type(copy_to, TensorProxy)
utils.check_same_device(copy_from, copy_to)
utils.check_same_shape(copy_from, copy_to)
utils.check_same_dtype(copy_from, copy_to)
return TensorProxy(like=copy_to)


copy_ = make_prim(PrimIDs.COPY_, "copy_", meta=copy__meta, tags=(OpTags.DONT_DCE,))
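Note that, unlike torch.Tensor.copy_ (which broadcasts and converts dtypes), copy__meta requires copy_from and copy_to to already agree on device, shape, and dtype, so mismatches are rejected at trace time. A rough sketch of that contract, assuming thunder.jit; the exact exception types raised by the checks are not shown here.

import torch
import thunder
from thunder.core import prims


def strict_copy(dst, src):
    return prims.copy_(src, dst)  # copy src into dst


jfn = thunder.jit(strict_copy)
dst = torch.zeros(4, 4)
jfn(dst, torch.ones(4, 4))             # ok: same device, shape, and dtype
# jfn(dst, torch.ones(2, 2))           # rejected by utils.check_same_shape
# jfn(dst, torch.ones(4, 4).double())  # rejected by utils.check_same_dtype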
10 changes: 5 additions & 5 deletions thunder/core/rematerialization.py
@@ -87,7 +87,7 @@ def find_external_consumer_inputs(
Tuple[ProxyInterface, ...]: Consumer's inputs that must be included in
the input of the consumer.
"""
all_produced_vars = tuple(chain.from_iterable((y for y in x.flat_outs) for x in producer.subsymbols))
all_produced_vars = tuple(chain.from_iterable((y for y in x.flat_proxy_outs) for x in producer.subsymbols))
external_consumer_inputs_names = tuple(
sorted(
{x.name for x in consumer.args}
@@ -123,7 +123,7 @@ def apply_rematerialization_for_producer(
new_producer_output_names = tuple(
x for x in new_producer_output_names if x not in (y.name for y in producer.flat_args)
)
all_produced_vars = tuple(chain.from_iterable((y for y in x.flat_outs) for x in producer.subsymbols))
all_produced_vars = tuple(chain.from_iterable((y for y in x.flat_proxy_outs) for x in producer.subsymbols))
# Choose the new producer's output from all the produced variables.
new_producer_output = tuple(x for x in all_produced_vars if x.name in new_producer_output_names)
new_producer_output = tuple(sorted(new_producer_output, key=lambda x: x.name))
@@ -151,7 +151,7 @@ def apply_rematerialization_for_consumer(
# We need to keep consumer's inputs that are not in the cut and are not
# produced by the producer. We call these inputs "external inputs".
external_inputs = find_external_consumer_inputs(producer, consumer)
all_produced_vars = tuple(chain.from_iterable((y for y in x.flat_outs) for x in producer.subsymbols))
all_produced_vars = tuple(chain.from_iterable((y for y in x.flat_proxy_outs) for x in producer.subsymbols))
cut_names = tuple(map(lambda x: x.name, cut)) if isinstance(cut[0], ProxyInterface) else tuple(cut)
cut_inputs = tuple(filter(lambda x: x.name in cut_names, (*all_produced_vars, *producer.args)))
new_consumer_args = cut_inputs + external_inputs
@@ -342,7 +342,7 @@ def add_edges(var):
for user in combined_consumers._dict.get(var_name, tuple()):
if user.sym.id in sym_skip_list:
continue
for out in user.flat_outs:
for out in user.flat_proxy_outs:
user_name = out.name
add_edge(var_name + "_out", user_name + "_in", capacity=float("inf"))

@@ -356,7 +356,7 @@ def add_edges(var):
add_edges(var)

for symbol in chain(producer.subsymbols, consumer.subsymbols):
for var in symbol.flat_outs:
for var in symbol.flat_proxy_outs:
add_edges(var)

g = Graph(
9 changes: 6 additions & 3 deletions thunder/core/transforms.py
@@ -2467,6 +2467,7 @@ def zeros_like(x):
prims.PrimIDs.LOG2: lambda x: (prims.log2(x), (x,)),
prims.PrimIDs.ZETA: lambda x, y: (prims.zeta(x, y), (x, y)),
prims.PrimIDs.FMOD: lambda x, y: (prims.fmod(x, y), (x, y)),
prims.PrimIDs.COPY_: lambda x, y: (prims.copy_(x, y), tuple()),
}


@@ -2497,6 +2498,8 @@ def zeros_like(x):
prims.PrimIDs.LOG1P: lambda x, g: g / (x + 1),
prims.PrimIDs.LOG2: lambda x, g: g / (x * 0.6931471805599453),
prims.PrimIDs.FMOD: lambda x, y, g: (g, -g * prims.trunc(x / y)),
# The copy should not be differentiable. We return None for both arguments so that the backward graph can still be generated through copies.
prims.PrimIDs.COPY_: lambda g: (None, None),
}


@@ -3164,6 +3167,8 @@ def iter_bound_symbols(bound_symbols):
for symbol in bound_symbols:
if symbol.sym.id in transform_skip_list:
continue
elif symbol.output is None:
continue
else:
yield symbol

@@ -3515,9 +3520,7 @@ def put_grad(v: Variable, val: Any) -> None:
init_cotangents = init_cotangents[0]
safe_map_flat(put_grad, trace.output, init_cotangents)

for symbol in reversed(trace.bound_symbols):
if symbol.sym.id in transform_skip_list:
continue
for symbol in reversed(list(iter_bound_symbols(trace.bound_symbols))):
symbol_output = sequencify(symbol.output)

cotangents = tree_map(get_grad, symbol_output)
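Taken together, these entries mean a traced copy_ is carried through differentiation but contributes no cotangents of its own; gradients of the surrounding operations are unaffected. A small sketch of the behavior the new backward test exercises, again assuming thunder.jit:

import torch
import thunder
from thunder.core import prims


def mutate_x_return_y_sq(x, y):
    z = x * y * x
    prims.copy_(z, x)  # in-place write to x; its VJP is (None, None)
    return y * y


jfn = thunder.jit(mutate_x_return_y_sq)
x = torch.randn(4, 4)
y = torch.randn(4, 4, requires_grad=True)
out = jfn(x, y)
out.backward(torch.ones_like(out))
# y.grad == 2 * y: the copy does not alter the gradient of the returned value.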
39 changes: 36 additions & 3 deletions thunder/executors/nvfuserex_impl.py
@@ -511,9 +511,6 @@ def get_fd(input_descriptors) -> FusionDefinition:
return create_fd(bsyms, input_descriptors, sorted_unique_inputs, sorted_unique_outputs)

fdw = FusionDefinitionWrapper(get_fd, name, get_fd.cache_info, get_fd.cache_clear)
# Avoid hitting nvFuser error when there is no output
if not sorted_unique_outputs:
return lambda *args: tuple()
return fdw


@@ -835,6 +832,19 @@ def _can_fuse_node(n: Node):
fused_bsyms.extend(fusion.bound_symbols)
fused_bsyms.extend(epilogue)

# Force return operator to be the last one in the fused_bsyms
if fused_bsyms[-1].sym.id != PrimIDs.RETURN:
return_idx: int = -1
for i, fused_bsym in enumerate(fused_bsyms):
if fused_bsym.sym.id == PrimIDs.RETURN:
return_idx = i
break
utils.check(
return_idx != -1,
lambda: f"Return operator does not exist in bound symbols",
)
fused_bsyms.append(fused_bsyms.pop(return_idx))

fusedtrace.bound_symbols = fused_bsyms

# Some of the operations might be better placed with its consumers (for
@@ -2040,6 +2050,29 @@ def batch_norm(
register_supported(PrimIDs.BATCH_NORM, batch_norm, _batch_norm_check)


def _copy__check(
copy_from: TensorProxy,
copy_to: TensorProxy,
) -> bool:
return are_supported_tensors(copy_from, copy_to)


def copy_(
copy_from: TensorProxy,
copy_to: TensorProxy,
*,
fd: FusionDefinition,
lc_to_nv_map: dict,
) -> Any:
nvcopy_from = getnv(copy_from, fd, lc_to_nv_map)
nvcopy_to = getnv(copy_to, fd, lc_to_nv_map)
fd.add_output(nvcopy_from, alias_input=nvcopy_to)
return nvcopy_to


register_supported(PrimIDs.COPY_, copy_, _copy__check)


# Removes excessive float casts, like those that occur when autocasting
# NOTE This pass actually changes a program's semantics, because it will take a sequence like
# fp32 -> fp16 -> fp32 and remove all the operations, but casting fp32 values to fp16 can
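When the nvFuser executor claims the region, the traced copy_ ends up inside the generated fusion as fd.add_output(nvcopy_from, alias_input=nvcopy_to), so nvFuser writes the result directly into the aliased input buffer. One way to check where the copy landed, reusing jfn, a, and b from the first sketch near the top of this page and assuming thunder.last_traces is available:

import thunder

out = jfn(a, b)
exec_trace = thunder.last_traces(jfn)[-1]  # last trace is the execution trace
print(exec_trace)  # the fused region should end with the copy_ / aliased output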
9 changes: 9 additions & 0 deletions thunder/executors/torchex.py
@@ -1806,3 +1806,12 @@ def is_float_type(self, input):
# We force the registration of the backend here to not use
# the torch backend when diverting isinstance
einops._backends._type2backend[TensorProxy] = EinopsThunderBackend()


def _copy__impl(copy_from, copy_to):
copy_to.copy_(copy_from)
return copy_to


copy_ = ex.register_operator("copy_", meta=prims.copy_, tags=(prims.OpTags.DONT_DCE,), fn=_copy__impl)
_register_implementation(prims.copy_, copy_, checker=_always_executable)
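For the PyTorch executor the fallback is simply Tensor.copy_: _copy__impl writes copy_from into copy_to's storage and returns copy_to. A one-line eager equivalent:

import torch

copy_to = torch.zeros(4)
copy_from = torch.arange(4.0)
result = copy_to.copy_(copy_from)  # same object as copy_to, now holding 0., 1., 2., 3.
assert result is copy_to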
4 changes: 4 additions & 0 deletions thunder/extend/__init__.py
@@ -206,6 +206,7 @@ def register_operator(
*,
like: None | Callable = None,
meta: None | Callable = None,
tags: None | list[Any] = None,
module: None | type | ModuleType = None,
fn: None | Callable = None,
bind_postprocess: None | Callable = None,
@@ -219,6 +220,8 @@

# NOTE Directly specifying a meta function makes the operation a prim
is_prim = meta is not None
# Set tags to be the same as 'like' if 'tags' is not specified
tags = like.tags if (tags is None and like is not None and hasattr(like, "tags")) else tags
meta = meta if meta is not None else like
call_ctx: None | dict[str, Callable] = None if fn is None else {name: fn}

@@ -236,6 +239,7 @@ def _bind_postprocess(bsym: BoundSymbol) -> None:
executor=self,
_bind_postprocess=_bind_postprocess,
python_printer=python_printer,
tags=tags,
)
self.opmap[name] = sym

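The tags plumbing matters for operators registered via like= rather than meta=: without it, an executor-level copy_ would drop the prim's DONT_DCE tag and the call could be dead-code-eliminated when its result is unused. A hedged sketch with a hypothetical custom executor; the names myex and my_copy_impl are illustrative and not part of this commit.

from thunder.core import prims
from thunder.extend import OperatorExecutor

myex = OperatorExecutor("myex", version="0.1")  # hypothetical executor


def my_copy_impl(copy_from, copy_to):
    return copy_to.copy_(copy_from)


# No explicit tags=: register_operator now falls back to prims.copy_.tags,
# so the new symbol keeps OpTags.DONT_DCE.
my_copy = myex.register_operator("my_copy_", like=prims.copy_, fn=my_copy_impl)
assert my_copy.tags == prims.copy_.tags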
89 changes: 89 additions & 0 deletions thunder/tests/test_inplace_copy.py
@@ -0,0 +1,89 @@
from functools import partial

import torch
from torch.testing import assert_close, make_tensor

import thunder
import thunder.core.dtypes as datatypes
import thunder.torch as ttorch
from thunder.tests.framework import instantiate


@instantiate()
def test_prim_inplace_copy_fwd(executor, device, dtype):
def torch_foo(x, y):
z = x * y
z = z + z
z = x + z
o = x.copy_(z)
return o

def foo(x, y):
z = x * y
z = z + z
z = x + z
# NOTE: nvfuserex doesn't support `return z`, i.e. the copy_from argument
o = thunder.core.prims.copy_(z, x)
return o

traced_nvfuser_foo = executor.make_callable(foo)

tdtype = ttorch.to_torch_dtype(dtype)
a = make_tensor((4, 4), device=device, dtype=tdtype, requires_grad=False)
b = make_tensor((4, 4), device=device, dtype=tdtype, requires_grad=False)
a1 = a.detach().clone()
b1 = b.detach().clone()
thunder_result = traced_nvfuser_foo(a, b)
torch_result = torch_foo(a1, b1)

custom_comparator = (
partial(assert_close, atol=1e-2, rtol=1e-2)
if dtype in (datatypes.bfloat16, datatypes.float16)
else assert_close
)
custom_comparator(thunder_result, torch_result)
custom_comparator(a, a1)


@instantiate(dtypes=(datatypes.floating,))
def test_prim_inplace_copy_bwd(executor, device, dtype):
def torch_foo(x, y):
z = x * y
z = z * x
o = x.copy_(z)
p = y * y
return p

def foo(x, y):
z = x * y
z = z * x
o = thunder.core.prims.copy_(z, x)
p = y * y
return p

traced_nvfuser_foo = executor.make_callable(foo)

tdtype = ttorch.to_torch_dtype(dtype)
a = make_tensor((4, 4), device=device, dtype=tdtype, requires_grad=False)
b = make_tensor((4, 4), device=device, dtype=tdtype, requires_grad=True)
a1 = a.detach().clone()
b1 = b.detach().clone()
b1.requires_grad_()

thunder_result = traced_nvfuser_foo(a, b)
torch_result = torch_foo(a1, b1)
assert_close(thunder_result, torch_result)
custom_comparator = (
partial(assert_close, atol=1e-2, rtol=1e-2)
if dtype in (datatypes.bfloat16, datatypes.float16)
else assert_close
)
custom_comparator(a, a1)

g = torch.ones_like(thunder_result)
thunder_result.backward(g)

g1 = torch.ones_like(torch_result)
torch_result.backward(g1)
assert_close(g, g1)
assert_close(b.grad, b1.grad)
