
Simplify cudnnex checkers. #57

Merged
merged 2 commits into main from wjy/exception on Mar 23, 2024

Conversation

wujingyue
Collaborator

Currently, checkers create graphs for real and try to catch exceptions to decide whether a config is supported. This hides unintentional failures in graph creation, leading to suboptimal UX.

For example, when I put an assert False at the beginning of _make_cudnn_sdpa_backward_graph, I got the following mysterious error, which doesn't point to the real failure at all.

This PR removes those dry runs and relies on the existing heuristics to reject unsupported SDPA operations. Unfortunately, this makes cudnnex overly aggressive, because the existing heuristics don't reject all unsupported cases; cudnnex will claim more SDPA operations than it can actually support. If support turns out to be missing during execution, it would be on thunder to try another executor from its list. This behaviour is probably fine, as cudnn is not a default executor yet.

_______________________________________________________________________________________________________________________________________________________________________________________________ test_vjp_correctness_sdpa_cudnnex_manual_grad_forward_scaled_dot_product_attention_nvfuser_cuda_float16 _______________________________________________________________________________________________________________________________________________________________________________________________

    def test():
>       result = template(opinfo, device_str, dtype, executor, comp)

thunder/tests/framework.py:285:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
thunder/tests/test_cudnn_executor.py:229: in test_vjp_correctness_sdpa_cudnnex_manual
    actual_out, actual_grad = cfoo(filtered_args, (v,))
thunder/common.py:779: in _fn
    trc_or_result = trace(compile_data=cd)(processed_function, *args, **kwargs)
thunder/core/interpreter.py:1293: in fn_
    return fn(*args, **kwargs)
thunder/common.py:530: in _trace
    result = fn(*proxyargs, **proxykwargs)
thunder/core/transforms.py:3696: in _vjp
    result, vjp_result = vjp_call(flat_args, cotangents, trace=trace)
thunder/core/transforms.py:3670: in vjp_call_metafunc
    result, env = augmented_forward_pass(*primals, trace=trace, **kwargs)
thunder/core/transforms.py:3477: in augmented_forward_pass
    result, env = eval_trace(
thunder/core/transforms.py:1679: in eval_trace
    prim_func = symbol_mapper(symbol)
thunder/core/transforms.py:3393: in vjp_symbol_mapper
    vjp_impl, backward_fn = make_aug_forward_and_backward(symbol)
thunder/core/vjp_utils.py:52: in make_aug_forward_and_backward
    joint_trace = thunder.trace(inline_trace=False, use_dce=False)(joint_forward_backward, *bsym.args, **bsym.kwargs)
thunder/core/interpreter.py:1293: in fn_
    return fn(*args, **kwargs)
thunder/common.py:530: in _trace
    result = fn(*proxyargs, **proxykwargs)
thunder/core/langctxs.py:124: in _fn
    result = fn(*args, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

query = query, key = key, value = value, attn_mask = None, dropout_p = [FloatProxy name=dropout_p, value=0.0], is_causal = [IntegerProxy (bool type) name=is_causal, value=True]

    @langctx("torch")
    def _cudnn_sdpa_grad(
        query: TensorProxy,
        key: TensorProxy,
        value: TensorProxy,
        attn_mask: None | TensorProxy,
        dropout_p: float = 0.0,
        is_causal: bool = False,
        *,
        scale: None | float = None,
    ):
        primal, softmax_stats, seed, offset = cudnn_sdpa_fwd(
            query, key, value, attn_mask, dropout_p, is_causal, scale=scale
        )

        g = get_grad(primal)
>       grad_query, grad_key, grad_val, grad_attn_mask = cudnn_sdpa_bwd(
            g,
            query,
            key,
            value,
            attn_mask,
            dropout_p,
            is_causal,
            primal,
            softmax_stats,
            seed,
            offset,
            scale=scale,
        )
E       ValueError: not enough values to unpack (expected 4, got 3)

thunder/executors/cudnnex.py:730: ValueError
======================================================================================================================================================================================================================================= short test summary info =======================================================================================================================================================================================================================================
FAILED thunder/tests/test_cudnn_executor.py::test_vjp_correctness_sdpa_cudnnex_manual_grad_forward_scaled_dot_product_attention_nvfuser_cuda_bfloat16 - ValueError: not enough values to unpack (expected 4, got 3)
FAILED thunder/tests/test_cudnn_executor.py::test_vjp_correctness_sdpa_cudnnex_manual_grad_forward_scaled_dot_product_attention_nvfuser_cuda_float16 - ValueError: not enough values to unpack (expected 4, got 3)
============================================================================================================================================================================================================================== 2 failed, 9 passed, 6 warnings in 16.55s ===============================================================================================================================================================================================================================
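To make the change concrete, here is a minimal, self-contained sketch of the two checker styles described above. The function names, the builder argument, and the heuristic constraints are illustrative placeholders, not thunder's actual cudnnex API.

    def dry_run_checker(make_backward_graph, *args, **kwargs) -> bool:
        # Old style: build the cuDNN graph for real and treat any exception as
        # "unsupported". A genuine bug in graph construction (for example the
        # assert False experiment above) is swallowed and reported as an
        # unsupported config instead of surfacing as an error.
        try:
            make_backward_graph(*args, **kwargs)
            return True
        except Exception:
            return False

    def heuristic_checker(query_dtype: str, head_dim: int) -> bool:
        # New style: reject only the configs the heuristics know are unsupported.
        # If support turns out to be missing at execution time, the error is
        # visible and thunder can fall back to another executor in its list.
        if query_dtype not in ("float16", "bfloat16"):
            return False  # placeholder dtype constraint
        if head_dim % 8 != 0:
            return False  # placeholder shape constraint
        return True

The tradeoff is the one described above: the heuristic version may accept configs that cuDNN cannot actually run, so the real check is deferred to execution time.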

Currently, checkers create graphs for real and try to catch exceptions
to decide whether a config is supported. This hides unintentional
failures in graph creation. This PR replaces those dry-runs with
heuristics, most of which are already there.
@wujingyue wujingyue requested a review from vedaanta March 22, 2024 23:57
@wujingyue
Collaborator Author

This was already approved in https://github.com/Lightning-AI/lit-thunder-LEGACY/pull/2480. Ready to merge.

Collaborator

@t-vi t-vi left a comment

Thank you @wujingyue

@t-vi t-vi merged commit 8f65c9b into main Mar 23, 2024
36 checks passed
@t-vi t-vi deleted the wjy/exception branch March 23, 2024 12:31
wujingyue added a commit that referenced this pull request Mar 23, 2024
wujingyue added a commit that referenced this pull request Mar 24, 2024