Update "phantom grad" to use the same forward-backward transformation code as "vjp" #364

Merged 13 commits on May 13, 2024
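In short: the public `grad` in `thunder.core.transforms` becomes a thin wrapper over `value_and_grad` (per the PR title, the same forward-backward transformation code as `vjp`), the old specifier-based implementation is kept as `__grad`, and callers no longer pass a `grad_specifier`. A minimal usage sketch of the resulting API; `thunder.jit` as the compile entry point and the toy function `f` are assumptions for illustration, not part of this diff:

```python
# Minimal sketch of the reworked public API. thunder.jit and the toy function f
# are assumptions for illustration; grad() is the transform changed in this PR.
import torch
import thunder
from thunder.core.transforms import grad

def f(x, w):
    return (x @ w).sum()

x = torch.randn(4, 4)
w = torch.randn(4, 4)

cfn = thunder.jit(f)
grad_fn = grad(cfn)    # no grad_specifier / grad_out_specifier arguments anymore
grads = grad_fn(x, w)  # flat list of gradients, with None entries filtered out
```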
8 changes: 1 addition & 7 deletions thunder/benchmarks/__init__.py
@@ -501,14 +501,8 @@ def _run_benchmark(

assert not use_grad_transform or not compile_backward, "Can't set both use_grad_transform and compile_backward!"
if use_grad_transform:
from thunder.core.transforms import _grad_specifier_default

def grad_specifier(outs) -> None:
grad_tensor = benchmark.postprocess_for_backward(outs)
_grad_specifier_default(grad_tensor)

benchmark_callable = constructor(benchmark_fn)
benchmark_callable = grad(benchmark_callable, grad_specifier=grad_specifier)
benchmark_callable = grad(benchmark_callable)
elif compile_backward:

def _fn(*args, **kwargs):
79 changes: 1 addition & 78 deletions thunder/benchmarks/targets.py
@@ -7,7 +7,7 @@
import pytest
import torch
import thunder
from thunder.core.transforms import grad, grad_v1, clear_grads, populate_grads, get_grad, put_grad, put_grads
from thunder.core.transforms import grad, clear_grads, populate_grads, get_grad, put_grad, put_grads
from thunder.core.interpreter import interpret

from thunder.benchmarks import (
@@ -178,42 +178,6 @@ def wrapper(*args, **kwargs):
return wrapper


# TODO Actually return the fwd, currently just requires computation
# by making the fwd equal to the grad
def thunder_value_and_grad_transform(b: Benchmark, compile_fn: Callable):
module: torch.nn.Module = b.fn()
cfn = compile_fn(module)

# Note on grad_specifier:
# requires the function output actually be computed to compute the grad
def grad_specifier(outs):
if not isinstance(outs, Sequence):
outs = (outs,)

for out in outs:
put_grad(out, out)

cfn_grad = grad(cfn, grad_specifier=grad_specifier)

if isinstance(module, torch.nn.Sequential):

@wraps(cfn_grad)
def wrapper(*args):
clear_grads(cfn)
grads = cfn_grad(args)
populate_grads(grads, cfn, args=args)

return wrapper

@wraps(cfn_grad)
def wrapper(*args, **kwargs):
clear_grads(cfn)
grads = cfn_grad(*args, **kwargs)
populate_grads(grads, cfn, args=args, kwargs=kwargs)

return wrapper


def thunder_grad_transform(b: Benchmark, compile_fn: Callable):
module: torch.nn.Module = b.fn()
cfn = compile_fn(module)
@@ -238,30 +202,6 @@ def wrapper(*args, **kwargs):
return wrapper


def thunder_grad_transform_v1(b: Benchmark, compile_fn: Callable):
module: torch.nn.Module = b.fn()
cfn = compile_fn(module)
cfn_grad = grad_v1(cfn)

if isinstance(module, torch.nn.Sequential):

@wraps(cfn_grad)
def wrapper(*args):
clear_grads(cfn)
grads = cfn_grad(args)
populate_grads(grads, cfn, args=args)

return wrapper

@wraps(cfn_grad)
def wrapper(*args, **kwargs):
clear_grads(cfn)
grads = cfn_grad(*args, **kwargs)
populate_grads(grads, cfn, args=args, kwargs=kwargs)

return wrapper


def thunder_fwd_bwd(b: Benchmark, compile_fn: Callable):
module: torch.nn.Module = b.fn()
cfn = compile_fn(module)
@@ -296,29 +236,12 @@ def wrapper(*args, **kwargs):
torch_fwd_bwd = partial(thunder_fwd_bwd, compile_fn=torch_executor)
torchcompile_fwd_bwd = partial(thunder_fwd_bwd, compile_fn=torch_compile_executor)

# Executing with just PyTorch
thunder_torch_grad = partial(thunder_grad_transform, compile_fn=thunder_torch_executor)
thunder_torch_gradv1 = partial(thunder_grad_transform_v1, compile_fn=thunder_torch_executor)
thunder_torch_value_and_grad = partial(thunder_value_and_grad_transform, compile_fn=thunder_torch_executor)

# Default thunder configs
thunder_fwd = partial(thunder_fwd, compile_fn=thunder_executor)
thunder_fwd_bwd = partial(thunder_fwd_bwd, compile_fn=thunder_executor)
thunder_grad = partial(thunder_grad_transform, compile_fn=thunder_executor)
thunder_gradv1 = partial(thunder_grad_transform_v1, compile_fn=thunder_executor)
thunder_value_and_grad = partial(thunder_value_and_grad_transform, compile_fn=thunder_executor)

# Executing with torchcompile as a Thunder executor
thunder_torchcompile_fwd = partial(thunder_fwd, compile_fn=thunder_torch_compile_executor)
thunder_torchcompile_grad = partial(thunder_grad_transform, compile_fn=thunder_torch_compile_executor)
thunder_torchcompile_gradv1 = partial(thunder_grad_transform_v1, compile_fn=thunder_torch_compile_executor)
thunder_torchcompile_value_and_grad = partial(
thunder_value_and_grad_transform, compile_fn=thunder_torch_compile_executor
)

# Executing with just the sdpa executor
thunder_sdpa_grad = partial(thunder_grad_transform, compile_fn=thunder_sdpa_executor)
thunder_sdpa_gradv1 = partial(thunder_grad_transform_v1, compile_fn=thunder_sdpa_executor)

# Executing with just the apex executor
# NOTE apex may or may not be available
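For reference, the benchmark wrapper pattern that survives this cleanup (`thunder_grad_transform`) is: clear stale grads, run the grad-transformed callable, then write the returned gradients back onto the module. A runnable sketch with a stand-in module; `thunder.jit` as the compile entry point is an assumption, not part of this diff:

```python
# Sketch of the surviving thunder_grad_transform pattern, using a tiny module
# as a stand-in for the real benchmark models. thunder.jit is assumed here.
import torch
import thunder
from thunder.core.transforms import grad, clear_grads, populate_grads

module = torch.nn.Linear(8, 8)
cfn = thunder.jit(module)
cfn_grad = grad(cfn)

x = torch.randn(2, 8)

clear_grads(cfn)                       # reset any .grad left on the parameters
grads = cfn_grad(x)                    # forward + backward, returns the gradients
populate_grads(grads, cfn, args=(x,))  # assign them back as .grad on the module
```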
18 changes: 14 additions & 4 deletions thunder/core/transforms.py
@@ -1337,7 +1337,7 @@ def _grad_out_specifier_default(pytree: Any) -> list[TensorProxy]:
# The algorithm for modifying the program has the following steps:
# 1) Flattens the original trace for the grad transform -- ensuring that all top-level symbols have
# a grad function.
def grad(
def __grad(
cfn, grad_specifier: Callable = _grad_specifier_default, grad_out_specifier: Callable = _grad_out_specifier_default
) -> Callable:
# Creates a custom transform callable that binds the additional arguments to the grad transform
@@ -1637,19 +1637,29 @@ def _selector(eligible_nodes: list[Node]) -> int:
return add_transform(cfn, _grad_transform)


def grad_v1(
def grad(
cfn,
) -> Callable:
def grad(func):

@wraps(func)
def grad_func(*args, **kwargs):
_, grads = value_and_grad(func)(*args, **kwargs)
grads = tree_flatten(grads)[0]
grads = [g for g in grads if g is not None]
return grads

return grad_func

def _grad_transform(trc: Trace, *, executors_list: Sequence[Any]) -> Trace:
gradtrc = construct_trace()(grad(trc.python_callable()), *trc.args, **trc.kwargs)
# Using trc.python_callable() makes it impossible to retrace the
# function because the python_callable uses python_ctx which replaces
# symbol occurrences with its symbol._call_ctx function
@wraps(trc.python_callable())
def python_callable(*args, **kwargs):
return eval_trace(trc, *args, **kwargs)

gradtrc = construct_trace()(grad(python_callable), *trc.args, **trc.kwargs)
return gradtrc

cfn._using_grad_transform = True
@@ -3680,7 +3690,7 @@ def ones_like(x):
elif isinstance(x, NumberProxy):
return type(x.value)(1)
else:
raise ValueError(f"ones_like inside value_and_grad got an unsupported type {type(x)}")
return None

def _value_and_grad(*args, **kwargs):
trace = construct_trace()(func, *args, **kwargs)
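The `ones_like` tweak above pairs with the `None`-filtering in the new `grad_func`: outputs that cannot carry a gradient are simply skipped instead of raising. A toy, PyTorch-only illustration of that pattern; `seed_like` is a hypothetical helper, not a thunder API:

```python
# Toy illustration of the seeding pattern: return ones for tensors, 1 for plain
# numbers, and None for everything else, then let the caller drop the Nones
# (mirroring ones_like() and the `[g for g in grads if g is not None]` filter).
import torch

def seed_like(out):
    if isinstance(out, torch.Tensor):
        return torch.ones_like(out)
    if isinstance(out, (int, float)):
        return type(out)(1)
    return None  # strings, None, and other non-differentiable outputs

outs = (torch.randn(3), 2.5, "aux-output")
seeds = [seed_like(o) for o in outs]
seeds = [s for s in seeds if s is not None]  # the string output is dropped
```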
1 change: 1 addition & 0 deletions thunder/core/utils.py
@@ -1048,6 +1048,7 @@ def find_producer_symbols(trace: TraceCtx, proxies: Sequence[Proxy], stop_proxie
(__b = ltorch.sub(x, y)
# __b = prims.sub(x, y),)
"""
stop_proxies = tuple(filter(lambda x: isinstance(x, Proxy), stop_proxies))
trace_producers = producers(trace)
result = set()
queue = list(proxies)
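The one-line guard added to `find_producer_symbols` simply ignores stop entries that are not proxies (e.g. plain Python numbers that end up in a stop list). A generic illustration of the guard; `Proxy` below is a placeholder class, not the real `thunder.core.proxies.Proxy`:

```python
# Generic illustration of the added guard: keep only proxy-like entries before
# walking producers. Proxy is a placeholder class for this sketch.
class Proxy:
    pass

stop_proxies = (Proxy(), 3.14, None, Proxy())
stop_proxies = tuple(filter(lambda x: isinstance(x, Proxy), stop_proxies))
assert len(stop_proxies) == 2  # the float and None were dropped
```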
8 changes: 2 additions & 6 deletions thunder/tests/test_grad.py
@@ -1405,18 +1405,14 @@ def test_populate_grads_nanogpt(executor, device, dtype):
(x, targets), kwargs = bench.make_batch()

logits, loss = model(x, targets)
loss.backward()
torch.autograd.backward((logits, loss), (torch.ones_like(logits), torch.ones_like(loss)))
torch_grads = extract_grads(model)

clear_grads(model)

tom = executor.make_callable(model)

def grad_specifier(out) -> None:
logits, loss = out
put_grad(loss, ltorch.ones_like(loss))

tom_grad = grad(tom, grad_specifier=grad_specifier)
tom_grad = grad(tom)
thunder_grads = tom_grad(x, targets)

populate_grads(thunder_grads, tom, args=[x, targets])
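The reference in `test_populate_grads_nanogpt` changes because the new `grad` path (via `value_and_grad`) seeds every output with ones, so the matching eager baseline is `torch.autograd.backward` over all outputs rather than `loss.backward()` alone. A self-contained sketch of that baseline, with a toy model standing in for nanoGPT:

```python
# Toy version of the updated eager reference: seed every output with ones and
# backpropagate through all of them, matching the new grad transform's seeding.
import torch

model = torch.nn.Linear(4, 2)
x = torch.randn(3, 4)

logits = model(x)
loss = logits.mean()

torch.autograd.backward((logits, loss), (torch.ones_like(logits), torch.ones_like(loss)))
reference_grads = {name: p.grad.clone() for name, p in model.named_parameters()}
```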