In [1]:
import numpy as np
import nvfuser
import thunder
from thunder.examine import examine, get_fusions
from thunder.executors.nvfuserex import nvfuser_version, nvfuserex
import torch
import torch.nn.functional as F

In [7]:
def fn(xy):
    """SwiGLU-style gating: ``silu(gate) * value``.

    ``xy`` is split along its last dimension into two equal halves: the
    first half is the gate (passed through SiLU), the second half is the
    value it multiplies. The result has the same leading dimensions as
    ``xy`` and half its last dimension.

    The last dimension must be even: ``torch.split`` treats the size as a
    chunk length, so an odd width yields three chunks and the two-name
    unpacking below fails.
    """
    half_width = xy.shape[-1] // 2
    gate, value = torch.split(xy, half_width, dim=-1)
    return F.silu(gate) * value

In [38]:
# Build a (BATCH, FEATURES) input for `fn` plus an upstream gradient shaped like
# fn's output (BATCH, FEATURES // 2), then run forward + backward through thunder.
torch.manual_seed(0)  # seed so every Restart & Run All sees the same random inputs
dtype = torch.float32
device = 'cuda'
BATCH, FEATURES = 4, 16  # the fusion dumps below are specialized to these sizes
xy = torch.randn(BATCH, FEATURES, dtype=dtype, device=device, requires_grad=True)
grads = torch.randn(xy.shape[0], xy.shape[-1] // 2, dtype=dtype, device=device)
# executors=[nvfuserex] hands execution to nvFuser; the traces inspected below
# show the resulting nvFusion0 regions.
jfunc = thunder.jit(fn, executors=[nvfuserex])
out = jfunc(xy)
out.backward(grads)

# Sanity check: thunder + nvFuser must agree with eager PyTorch in both the
# forward output and the gradient w.r.t. the input. A detached clone keeps the
# reference graph separate from `xy.grad`.
xy_ref = xy.detach().clone().requires_grad_(True)
out_ref = fn(xy_ref)
out_ref.backward(grads)
torch.testing.assert_close(out, out_ref)
torch.testing.assert_close(xy.grad, xy_ref.grad)

In [39]:
# last_traces / last_backward_traces return a sequence of traces for the most
# recent call; [-1] selects the last one. For the backward it is the version
# headed "Constructed by Delete Last Used" and calls nvFusion0.
fwd_trace, bwd_trace = (
    thunder.last_traces(jfunc)[-1],
    thunder.last_backward_traces(jfunc)[-1],
)

In [40]:
# Look up the nvFusion0 callable in the forward trace's Python context and display
# its `.last_used` FusionDefinition (presumably the one most recently executed).
# The slice bounds in the output are concrete ([4, 8], [4, 16]), i.e. the fusion is
# specialized to the input shape, and SiLU appears decomposed as x * 1/(1 + exp(-x)).
fwd_trace.python_ctx()['nvFusion0'].last_used


def nvfuser_fusion_id0(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[-1, -1], contiguity=[True, True], dtype=DataType.Float, is_cpu=False, stride_order=[1, 0])
    T1 = fd.ops.slice(T0, start_indices=[0, 0], end_indices=[4, 8], strides=[1, 1])
    T2 = fd.ops.slice(T0, start_indices=[0, 8], end_indices=[4, 16], strides=[1, 1])
    T3 = fd.ops.neg(T1)
    T4 = fd.ops.exp(T3)
    S5 = fd.define_scalar(1.00000, dtype=DataType.Double)
    T6 = fd.ops.add(S5, T4)
    T7 = fd.ops.reciprocal(T6)
    T8 = fd.ops.mul(T1, T7)
    T9 = fd.ops.mul(T8, T2)
    fd.add_output(T9)


In [41]:
# Print the final backward trace as Python source. Only the input `xy` is saved for
# backward; the commented prims under nvFusion0 show the slices/neg/exp being
# recomputed from it rather than stored from the forward pass.
print(bwd_trace)

# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def backward_fn(saved_for_backward, cotangents):
  # saved_for_backward: "Collection"
  # cotangents: "Collection"
  C0, _, = saved_for_backward
  clear_mutable_collection(saved_for_backward)
  del saved_for_backward
  t8, = cotangents
  clear_mutable_collection(cotangents)
  del cotangents
  xy, = C0
  clear_mutable_collection(C0)
  del C0
  [t25] = nvFusion0(xy, t8)
    # t0 = prims.slice_prim(xy, [0, 0], [4, 8], [1, 1])  # t0: "cuda:0 f32[4, 8]"
    # t2 = prims.neg(t0)  # t2: "cuda:0 f32[4, 8]"
    # t1 = prims.slice_prim(xy, [0, 8], [4, 16], [1, 1])  # t1: "cuda:0 f32[4, 8]"
    # t3 = prims.exp(t2)  # t3: "cuda:0 f32[4, 8]"
    # t15 = prims.mul(t1, t8)  # t15: "cuda:0 f32[4, 8]"
    # t4 = prims.add(1.0, t3)  # t4: "cuda:0 f32[4, 8]"
    # t18 = prims.mul(t0, t15)  # t18: "cuda:0 f32[4, 8]"
    # t5 = prims.reciprocal(t4)  # t5: "c

In [42]:
# Display the backward nvFusion0 FusionDefinition. It takes two tensors (the saved
# input `xy` and the cotangent `t8`) and finishes with a `cat` along the last dim,
# apparently reassembling the gate-half and value-half gradients into one tensor
# shaped like `xy`.
bwd_trace.python_ctx()['nvFusion0'].last_used


def nvfuser_fusion_id1(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[-1, -1], contiguity=[True, True], dtype=DataType.Float, is_cpu=False, stride_order=[1, 0])
    T1 = fd.define_tensor(shape=[-1, -1], contiguity=[True, True], dtype=DataType.Float, is_cpu=False, stride_order=[1, 0])
    T2 = fd.ops.slice(T0, start_indices=[0, 0], end_indices=[4, 8], strides=[1, 1])
    T3 = fd.ops.neg(T2)
    T4 = fd.ops.slice(T0, start_indices=[0, 8], end_indices=[4, 16], strides=[1, 1])
    T5 = fd.ops.exp(T3)
    T6 = fd.ops.mul(T4, T1)
    S7 = fd.define_scalar(1.00000, dtype=DataType.Double)
    T8 = fd.ops.add(S7, T5)
    T9 = fd.ops.mul(T2, T6)
    T10 = fd.ops.reciprocal(T8)
    T11 = fd.ops.neg(T9)
    T12 = fd.ops.mul(T11, T10)
    T13 = fd.ops.mul(T12, T10)
    T14 = fd.ops.mul(T13, T5)
    T15 = fd.ops.mul(T2, T10)
    T16 = fd.ops.neg(T14)
    T17 = fd.ops.mul(T10, T6)
    T18 = fd.ops.mul(T15, T1)
    T19 = fd.ops.add(T17, T16)
    T20 = fd.ops.cat([T19, T18], dim=-1)