In [10]:
import nvfuser
import thunder
import torch
import torch.nn.functional as F
from nvfuser import FusionDefinition, DataType

In [11]:
from thunder.tests.framework import nvFuserExecutor
import thunder.examine as examine
from thunder.examine import get_fusions

In [12]:
def test_sdpa(nv_enable_sdpa):
    """Compile scaled-dot-product attention with thunder and collect its fusions.

    Runs ``F.scaled_dot_product_attention`` forward and backward on random
    float16 CUDA tensors through ``thunder.jit`` and extracts the fusions
    from the last forward and last backward trace.

    Args:
        nv_enable_sdpa: Forwarded to ``thunder.jit`` as the ``nv_enable_sdpa``
            compile option. When truthy, the function additionally asserts
            that each trace holds exactly one fusion and that the fusion
            names identify the SDPA forward / backward ops.

    Returns:
        ``(fwd_fusion, bwd_fusion)``: the results of ``examine.get_fusions``
        on the last forward trace and the last backward trace.

    Raises:
        AssertionError: If ``nv_enable_sdpa`` is truthy and the expected
            fusion layout is not found.
    """
    device = 'cuda'
    executor = nvFuserExecutor
    dropout_p = 0.01
    is_causal = False
    scale = None
    dtype = torch.float16

    def sdpa_fn(q, k, v, dropout_p, is_causal, scale):
        return F.scaled_dot_product_attention(
            q, k, v, dropout_p=dropout_p, is_causal=is_causal, scale=scale
        )

    torch.manual_seed(0)

    N, H, L, S, E = 4, 8, 16, 16, 16
    q = torch.randn((N, H, L, E), device=device, dtype=dtype, requires_grad=True)
    k = torch.randn((N, H, S, E), device=device, dtype=dtype, requires_grad=True)
    v = torch.randn((N, H, S, E), device=device, dtype=dtype, requires_grad=True)
    grad_out = torch.randn((N, H, L, E), device=device, dtype=dtype)

    tensor_inputs = [q, k, v]
    scalar_inputs = [dropout_p, is_causal, scale]

    # Forward the caller's flag: hard-coding it here would make the
    # enabled/disabled runs compile identically.
    compiled_func = thunder.jit(
        sdpa_fn,
        executors=executor.executors_list(),
        nv_enable_sdpa=nv_enable_sdpa,
        nv_enable_matmul=True,
        nv_enable_linear=True,
    )
    # fork_rng restores the RNG state of the listed device on exit, so the
    # compiled call does not leave the global CUDA generator advanced.
    with torch.random.fork_rng(devices=[torch.cuda.current_device()]):
        attn_out = compiled_func(*tensor_inputs, *scalar_inputs)
    attn_out.backward(grad_out)

    fwd_trace = thunder.last_traces(compiled_func)[-1]
    bwd_trace = thunder.last_backward_traces(compiled_func)[-1]
    fwd_fusion = examine.get_fusions(fwd_trace)
    bwd_fusion = examine.get_fusions(bwd_trace)
    print(nv_enable_sdpa)
    if nv_enable_sdpa:
        assert len(fwd_fusion) == 1
        assert len(bwd_fusion) == 1
        assert "nv_sdpfa_fwd" in fwd_fusion[-1][-1].name

        # The backward fusion must contain the SDPA backward op only; finding
        # nv_sdpfa_fwd there would indicate the forward was rematerialized.
        bwd_name = bwd_fusion[-1][-1].name
        assert "nv_sdpfa_bwd" in bwd_name
        assert "nv_sdpfa_fwd" not in bwd_name

    return fwd_fusion, bwd_fusion

In [13]:
# Run the check with nv_enable_sdpa=True and keep the returned forward and
# backward fusion lists for inspection in the following cell.
sdpa_fwd, sdpa_bwd = test_sdpa(True)

True


In [47]:
# Print the `last_used` attribute of the first forward and first backward
# fusion entry (index 0 of each list, last element of that entry).
print(sdpa_fwd[0][-1].last_used)
print(sdpa_bwd[0][-1].last_used)


def nvfuser_fusion_id1(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[-1, -1, -1, -1], contiguity=[True, True, True, True], dtype=DataType.Half, is_cpu=False, stride_order=[3, 2, 1, 0])
    T1 = fd.define_tensor(shape=[-1, -1, -1, -1], contiguity=[True, True, True, True], dtype=DataType.Half, is_cpu=False, stride_order=[3, 2, 1, 0])
    T2 = fd.define_tensor(shape=[-1, -1, -1, -1], contiguity=[True, True, True, True], dtype=DataType.Half, is_cpu=False, stride_order=[3, 2, 1, 0])
    S3 = fd.define_scalar(0.0100000, dtype=DataType.Double)
    S4 = fd.define_scalar(False, dtype=DataType.Bool)
    T5, T6, T7, T8 = fd.ops.sdpfa_fwd(T0, T1, T2, S3, S4, None)
    fd.add_output(T5)
    fd.add_output(T6)
    fd.add_output(T7)
    fd.add_output(T8)



def nvfuser_fusion_id2(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[-1, -1, -1, -1], contiguity=[True, True, True, True], dtype=DataType.Half, is_cpu=False, stride_order=[3, 2, 1, 0])
    T1 = fd.define_ten

In [8]:
# Run the same check with nv_enable_sdpa=False; the SDPA-specific assertions
# inside test_sdpa are skipped for this call.
non_sdpa_fwd, non_sdpa_bwd = test_sdpa(False)

False


In [9]:
# Print the `last_used` attribute of the first forward and first backward
# fusion entry from the nv_enable_sdpa=False run.
print(non_sdpa_fwd[0][-1].last_used)
print(non_sdpa_bwd[0][-1].last_used)


def nvfuser_fusion_id5(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[-1, -1, -1, -1], contiguity=[True, True, True, True], dtype=DataType.Half, is_cpu=False, stride_order=[3, 2, 1, 0])
    T1 = fd.define_tensor(shape=[-1, -1, -1, -1], contiguity=[True, True, True, True], dtype=DataType.Half, is_cpu=False, stride_order=[3, 2, 1, 0])
    T2 = fd.define_tensor(shape=[-1, -1, -1, -1], contiguity=[True, True, True, True], dtype=DataType.Half, is_cpu=False, stride_order=[3, 2, 1, 0])
    S3 = fd.define_scalar(0.0100000, dtype=DataType.Double)
    S4 = fd.define_scalar(False, dtype=DataType.Bool)
    T5, T6, T7, T8 = fd.ops.sdpfa_fwd(T0, T1, T2, S3, S4, None)
    fd.add_output(T5)
    fd.add_output(T6)
    fd.add_output(T7)
    fd.add_output(T8)



def nvfuser_fusion_id6(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[-1, -1, -1, -1], contiguity=[True, True, True, True], dtype=DataType.Half, is_cpu=False, stride_order=[3, 2, 1, 0])
    T1 = fd.define_ten