In [2]:
import torch
import thunder



In [3]:
from thunder.executors.nvfuserex import nvfuserex
from thunder.benchmarks import NanoGPTBlockBenchmark

In [4]:
# Benchmark fixture: NanoGPTBlockBenchmark with the "gpt2" config on device
# "cuda:0" in bfloat16. requires_grad=True because a backward pass is run on
# the output further down.
bench = NanoGPTBlockBenchmark(
    config="gpt2",
    device="cuda:0",
    dtype=thunder.bfloat16,
    requires_grad=True,
)
args, kwargs = bench.make_batch()

# Extra keyword options forwarded verbatim to thunder.jit.
# NOTE(review): judging by their names, the nv_enable_* flags let nvFuser claim
# sdpa / matmul / linear, and disable_replace_uniform switches off a
# uniform-RNG rewrite -- confirm against the thunder nvFuser executor docs.
nvfuser_options = dict(
    nv_enable_sdpa=True,
    nv_enable_matmul=True,
    nv_enable_linear=True,
    disable_replace_uniform=True,
)

# nvfuserex is the only executor passed explicitly.
jfn = thunder.jit(bench.fn(), executors=[nvfuserex], **nvfuser_options)

In [5]:
# Run the jitted function once on the benchmark batch. `out` is kept because
# the backward pass below calls out.backward(...), and the trace-inspection
# cells query thunder.last_traces(jfn) after this call.
# NOTE(review): presumably this first call is what triggers tracing and
# nvFuser compilation -- confirm.
out = jfn(*args, **kwargs)

In [6]:
# Take the last entry of the traces recorded for jfn and keep its Python
# context, which is looked up by string key (e.g. 'nvFusion0') afterwards.
final_fwd_trace = thunder.last_traces(jfn)[-1]
fwd_traces = final_fwd_trace.python_ctx()

In [7]:
# Look up the 'nvFusion0' entry of the forward context and display its
# `last_used` attribute; the cell output renders it as an nvFuser
# FusionDefinition function.
fwd_fusion0 = fwd_traces['nvFusion0']
fwd_fusion0.last_used


def nvfuser_fusion_id0(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[16, 128, 768], contiguity=[True, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0])
    T1 = fd.define_tensor(shape=[768], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T2 = fd.define_tensor(shape=[768], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T3 = fd.define_tensor(shape=[2304, 768], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0])
    T4 = fd.define_tensor(shape=[2304], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T5 = fd.define_tensor(shape=[768, 768], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0])
    T6 = fd.define_tensor(shape=[768], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T7 = fd.define_tensor(shape=[768], contiguity=[True], dtype=DataType.BFloat16

In [8]:
# Upstream gradient for the backward pass: standard-normal noise with the
# same shape as `out`. randn_like inherits dtype and device from its input,
# so they are not repeated here.
grads = torch.randn_like(out)
out.backward(grads)


In [11]:
# Same inspection as for the forward pass, but on the backward traces:
# last recorded backward trace -> its Python context -> the 'nvFusion0'
# entry, whose `last_used` attribute is displayed as the cell result.
final_bwd_trace = thunder.last_backward_traces(jfn)[-1]
bwd_ctx = final_bwd_trace.python_ctx()
bwd_ctx['nvFusion0'].last_used


def nvfuser_fusion_id1(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[16, 128, 3072], contiguity=[True, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0])
    T1 = fd.define_tensor(shape=[16, 128, 768], contiguity=[True, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0])
    T2 = fd.define_tensor(shape=[16, 128, 768], contiguity=[True, True, True], dtype=DataType.Bool, is_cpu=False, stride_order=[2, 1, 0])
    T3 = fd.define_tensor(shape=[768, 3072], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0])
    T4 = fd.define_tensor(shape=[16, 128, 768], contiguity=[True, True, True], dtype=DataType.Bool, is_cpu=False, stride_order=[2, 1, 0])
    T5 = fd.define_tensor(shape=[16, 128, 768], contiguity=[True, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0])
    T6 = fd.define_tensor(shape=[3072, 768], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, s