In [1]:
import torch
from nvfuser import (
    FusionDefinition,
    DataType,
    ParallelType,
    MemoryType,
    LoadStoreOpType,
)

In [2]:
# Problem size: `batch_size` rows, each normalized over `tensor_size` elements.
batch_size = 1024
tensor_size = 4096

# Seed the generator so the random input (and every result derived from it)
# is the same after Restart Kernel -> Run All.
torch.manual_seed(0)
inputs = [
    torch.randn(batch_size, tensor_size, dtype=torch.bfloat16, device="cuda"),
]

In [3]:
class LayerNorm(FusionDefinition):
    """Fusion that normalizes `inputs[0]` along its last dimension.

    Computes ``(x - mean) / sqrt(var + eps)`` in float with ``eps = 1e-6`` and
    casts the result to bfloat16. No scale or shift is applied. The variance
    is the sum of squared differences divided by `tensor_size`.

    Reads the notebook-level `inputs` and `tensor_size`. Intermediate tensors
    are stored as attributes; later cells use `t0`, `mean` and `t0_norm` when
    scheduling.
    """

    def definition(self):
        """Record the layer-norm operations into this fusion definition."""
        # Fusion input, defined from the notebook-level sample tensor.
        self.t0 = self.from_pytorch(inputs[0])
        # Epsilon added to the variance before the square root.
        self.s0 = self.define_scalar(1e-6, dtype=DataType.Double)
        # Divisor that turns the row sums below into the mean and the variance.
        self.norm_const = self.define_scalar(tensor_size, dtype=DataType.Int)

        # --- mean: sum over the last dimension, divided by norm_const ---
        self.mean_cast = self.ops.cast(self.t0, dtype=DataType.Float)
        self.sum0 = self.ops.sum(self.mean_cast, dims=[-1])
        # NOTE Manually broadcast because fusion definition cannot access hidden reduction tensor view.
        self.bcast_sum0 = self.ops.broadcast(self.sum0, [False, True])
        self.mean = self.ops.div(self.bcast_sum0, self.norm_const)

        # --- variance: sum of squared differences from the mean, divided by norm_const ---
        # t0 is cast to float again here rather than reusing mean_cast.
        self.var_cast = self.ops.cast(self.t0, dtype=DataType.Float)
        self.diff = self.ops.sub(self.var_cast, self.mean)
        self.diff_sq = self.ops.mul(self.diff, self.diff)
        self.sum1 = self.ops.sum(self.diff_sq, dims=[-1])
        # NOTE Manually broadcast because fusion definition cannot access hidden reduction tensor view.
        self.bcast_sum1 = self.ops.broadcast(self.sum1, [False, True])
        self.var = self.ops.div(self.bcast_sum1, self.norm_const)

        # --- normalize: (x - mean) / sqrt(var + eps), from a third cast of t0 ---
        self.t0_cast = self.ops.cast(self.t0, dtype=DataType.Float)
        self.t0_diff = self.ops.sub(self.t0_cast, self.mean)
        # Despite the name, var_eps holds sqrt(var + eps), i.e. the denominator.
        self.var_eps = self.ops.sqrt(self.ops.add(self.var, self.s0))
        self.t0_norm = self.ops.div(self.t0_diff, self.var_eps)

        # The fusion output is the bfloat16 cast, not the float t0_norm.
        self.t0_norm_cast = self.ops.cast(self.t0_norm, dtype=DataType.BFloat16)
        self.add_output(self.t0_norm_cast)

In [4]:
# Build FusionDefinition
# The definition is driven step by step: the underscore-prefixed setup and
# finalize hooks are called explicitly around `definition()`.
# NOTE(review): these are private nvfuser methods — confirm they exist with
# this behavior in the installed version.
fn = LayerNorm()
fn._setup_definition()
fn.definition()
fn._finalize_definition()

In [5]:
# Create user schedule for this input
# NOTE: Schedules defined by the user for specific input sizes.
# The IR is printed here, before any scheduling call is made on `fn.sched`,
# so it can be compared with the IR printed after inlining.
fn._setup_schedule(inputs)
print(fn._user_schedule_ir())


%kernel {
T11_l[ iS22{i0}, iS23{i1} ]
   = __bfloat2float(T0_g[ iS0{i0}, iS1{i1} ]);
T1_l[ iS2{i0}, iS3{i1} ]
   = __bfloat2float(T0_g[ iS0{i0}, iS1{i1} ]);
T2_l[ iS4{i0}, rS5{i1} ]
   = reduction( T1_l[ iS2{i0}, iS3{i1} ], op = add, initial value = float(0), allreduce = false )
T3_l[ iS6{i0}, bS7{1} ]
   = broadcast( T2_l[ iS4{i0}, rS5{i1} ] )
f16 = (float)(4096);
T4_l[ iS8{i0}, bS9{1} ]
   = T3_l[ iS6{i0}, bS7{1} ]
   / f16;
T12_l[ iS24{i0}, iS25{i1} ]
   = T11_l[ iS22{i0}, iS23{i1} ]
   - T4_l[ iS8{i0}, bS9{1} ];
T5_l[ iS10{i0}, iS11{i1} ]
   = __bfloat2float(T0_g[ iS0{i0}, iS1{i1} ]);
T6_l[ iS12{i0}, iS13{i1} ]
   = T5_l[ iS10{i0}, iS11{i1} ]
   - T4_l[ iS8{i0}, bS9{1} ];
T7_l[ iS14{i0}, iS15{i1} ]
   = T6_l[ iS12{i0}, iS13{i1} ]
   * T6_l[ iS12{i0}, iS13{i1} ];
T8_l[ iS16{i0}, rS17{i1} ]
   = reduction( T7_l[ iS14{i0}, iS15{i1} ], op = add, initial value = float(0), allreduce = false )
T9_l[ iS18{i0}, bS19{1} ]
   = broadcast( T8_l[ iS16{i0}, rS17{i1} ] )
f41 = (float)(4096);
T10

In [6]:
# Create cache tensors.
# Cache placed after the fusion input; its memory type is set to shared.
cache_after_t0 = fn.sched.cache_after(fn.t0)
fn.sched.set_memory_type(cache_after_t0, MemoryType.shared)

# Cache placed before the normalized float result.
# NOTE(review): the tensor registered with add_output is fn.t0_norm_cast,
# not fn.t0_norm — confirm this is the tensor meant to be cached here.
cache_before_t0_norm = fn.sched.cache_before(fn.t0_norm)

cache_tvs = [cache_after_t0, cache_before_t0_norm]
print([fn.sched.to_string(tv) for tv in cache_tvs])

['T17_s[ iS34{i0}, iS35{i1} ]', 'T18_l[ iS30{i0}, iS31{i1} ]']


In [7]:
# Schedule the reference tensor, then replay its transforms on the others.
reference_tv = fn.mean

# The 4-wide innermost axis is the one vectorized later on; the 256-wide axis
# next to it is the one bound to threadIdx.x later on.
vectorize_factor = 4
threads_per_block = 256

# Last dimension -> [outer, threads_per_block, vectorize_factor]
fn.sched.split(reference_tv, dim=-1, factor=threads_per_block * vectorize_factor)
fn.sched.split(reference_tv, dim=-1, factor=vectorize_factor)
fn.sched.transform_like(reference_tv)
print(fn.sched.to_string(reference_tv))

T4_l[ iS8{i0}, bS38{( ceilDiv(1, 1024) )}, bS40{( ceilDiv(1024, 4) )}, bS41{4} ]


In [8]:
# Add rfactor TensorViews: one per reduction tensor, factoring out the
# innermost axis.
reduction_tvs = [tv for tv in fn.sched.tensors() if fn.sched.is_reduction(tv)]
# The definition has exactly two sums (for the mean and for the variance).
assert len(reduction_tvs) == 2
rfactor_tvs = [fn.sched.rfactor(tv, dims=[-1]) for tv in reduction_tvs]
print([fn.sched.to_string(tv) for tv in rfactor_tvs])

['T19_l[ iS114{i0}, iS116{( ceilDiv(i1, 1024) )}rf, iS118{( ceilDiv(1024, 4) )}rf, rS119{4}rf ]', 'T20_l[ iS123{i0}, iS125{( ceilDiv(i1, 1024) )}rf, iS127{( ceilDiv(1024, 4) )}rf, rS128{4}rf ]']


In [9]:
# Add common parallelization on the reference tensor, then copy it to the rest.
# Outermost axis -> blockIdx.x
axis = 0
fn.sched.parallelize(reference_tv, axis, ParallelType.grid_x)
# Second-to-last axis -> threadIdx.x
axis = -2
fn.sched.parallelize(reference_tv, axis, ParallelType.block_x)
fn.sched.parallelize_like(reference_tv)
print(fn.sched.to_string(reference_tv))

T4_l[ iblockIdx.x8{i0}, bS38{( ceilDiv(1, 1024) )}, bthreadIdx.x40{( ceilDiv(1024, 4) )}, bS41{4} ]


In [10]:
# Vectorize input load and output store: mark the innermost axis of both
# tensors as vectorized.
# NOTE(review): fn.t0_norm prints as a local tensor (T15_l), while the fusion
# output is fn.t0_norm_cast — confirm that the store to global memory is the
# access actually being vectorized.
axis = -1
fn.sched.parallelize(cache_after_t0, axis, ParallelType.vectorize)
fn.sched.parallelize(fn.t0_norm, axis, ParallelType.vectorize)
print(fn.sched.to_string(fn.t0_norm))

T15_l[ iblockIdx.x36{i0}, iS106{( ceilDiv(i1, 1024) )}, ithreadIdx.x108{( ceilDiv(1024, 4) )}, iV109{4} ]


In [11]:
# Add computeAt positions (the `ca_pos(...)` annotations in the IR printed below).
# inline_most automatically skips vectorized iterDomains.
fn.sched.inline_most()
print(fn._user_schedule_ir())


%kernel {
T17_s[ iblockIdx.x34{i0}, iS54{( ceilDiv(i1, 1024) )}, ithreadIdx.x56{( ceilDiv(1024, 4) )}, iV57{4} ] ca_pos( 1 )
   = Set( T0_g[ iS0{i0}, iS58{( ceilDiv(i1, 1024) )}, iS60{( ceilDiv(1024, 4) )}, iS61{4} ], cache_op=Streaming )
T11_l[ iblockIdx.x22{i0}, iS70{( ceilDiv(i1, 1024) )}, ithreadIdx.x72{( ceilDiv(1024, 4) )}, iS73{4} ] ca_pos( 4 ) produce_pos( 1 )
   = __bfloat2float(T17_s[ iblockIdx.x34{i0}, iS54{( ceilDiv(i1, 1024) )}, ithreadIdx.x56{( ceilDiv(1024, 4) )}, iV57{4} ] ca_pos( 1 ));
T1_l[ iblockIdx.x2{i0}, iS62{( ceilDiv(i1, 1024) )}, ithreadIdx.x64{( ceilDiv(1024, 4) )}, iS65{4} ] ca_pos( 4 ) produce_pos( 1 )
   = __bfloat2float(T17_s[ iblockIdx.x34{i0}, iS54{( ceilDiv(i1, 1024) )}, ithreadIdx.x56{( ceilDiv(1024, 4) )}, iV57{4} ] ca_pos( 1 ));
T19_l[ iblockIdx.x114{i0}, iS116{( ceilDiv(i1, 1024) )}rf, ithreadIdx.x118{( ceilDiv(1024, 4) )}rf, rS119{4}rf ] ca_pos( 3 ) produce_pos( 4 )
   = reduction( T1_l[ iblockIdx.x2{i0}, iS62{( ceilDiv(i1, 1024) )}, ithreadIdx.x6

In [12]:
# Compile Fusion
# Closes the user schedule that `_setup_schedule(inputs)` opened, using the
# same inputs.
# NOTE(review): private nvfuser method — confirm against the installed version.
fn._finalize_schedule(inputs)

In [13]:
# Run the fusion on the sample inputs with profiling enabled; the profile is
# read back later through `fn.profile()`.
# The result prints as a list of tensors, so the output is `nvf_out[0]`.
nvf_out = fn.execute(inputs, profile=True)
print(nvf_out)

[tensor([[ 2.4844, -0.8281, -0.3281,  ..., -0.4902,  1.3906,  0.2598],
        [ 1.9297, -0.9648,  0.6992,  ...,  0.7695, -0.1060,  0.9023],
        [-1.0859,  0.4551, -1.2500,  ..., -0.5625,  0.8242, -1.7109],
        ...,
        [ 0.7734, -0.6328, -1.9062,  ..., -0.6914,  1.0625, -0.2480],
        [ 0.6250, -0.7188,  1.6875,  ..., -0.2168, -1.7188,  0.0840],
        [-0.4746, -0.2656,  0.2402,  ...,  0.3086,  2.5938,  0.4863]],
       device='cuda:0', dtype=torch.bfloat16)]


In [14]:
# Reference result from PyTorch, normalizing over the same last dimension.
# `eps` is passed explicitly because the fusion adds 1e-6 to the variance,
# whereas `layer_norm` defaults to eps=1e-5.
torch_out = torch.nn.functional.layer_norm(
    inputs[0], normalized_shape=inputs[0].shape[1:], eps=1e-6
)
print(torch_out)

tensor([[ 2.4844, -0.8281, -0.3281,  ..., -0.4902,  1.3906,  0.2598],
        [ 1.9297, -0.9648,  0.6992,  ...,  0.7695, -0.1060,  0.9023],
        [-1.0859,  0.4551, -1.2500,  ..., -0.5625,  0.8242, -1.7109],
        ...,
        [ 0.7734, -0.6328, -1.9062,  ..., -0.6914,  1.0625, -0.2480],
        [ 0.6250, -0.7188,  1.6875,  ..., -0.2168, -1.7188,  0.0840],
        [-0.4746, -0.2656,  0.2402,  ...,  0.3086,  2.5938,  0.4863]],
       device='cuda:0', dtype=torch.bfloat16)


In [15]:
# Check the fusion output against the PyTorch reference.
# Both tensors are bfloat16; the 1e-2 tolerances are presumably chosen for
# that precision — TODO confirm.
torch.allclose(nvf_out[0], torch_out, rtol=1e-2, atol=1e-2)

True

In [16]:
def print_kernel_profile(kp):
    """Print a four-line, human-readable summary of one kernel profile.

    Args:
        kp: Object exposing the attributes read below: `name`, `scheduler`,
            `segment_id`, `device`, `stream`, `compile_time_ms`, `grid_str`,
            `block_str`, `registers`, `input_bytes`, `output_bytes`,
            `time_ms`, `effective_bandwidth_gbs`, `percentage_peak_bandwidth`.

    Returns:
        None. The summary is written to stdout.
    """
    basic_information = (
        f"name: {kp.name}, schedule: {kp.scheduler}, segment_id: {kp.segment_id}, "
        f"device: {kp.device}, stream: {kp.stream}"
    )
    print(basic_information)

    kernel_information = (
        f"compile time: {kp.compile_time_ms:.2f} ms, grid: {kp.grid_str}, "
        f"block: {kp.block_str}, registers: {kp.registers}"
    )
    print(kernel_information)

    # `.6f` spells out the precision the old `:2f` (width 2, default precision)
    # produced; kernel times are small fractions of a millisecond, so two
    # decimals would discard most of the signal.
    runtime_information = (
        f"input size: {kp.input_bytes} bytes, output size: {kp.output_bytes} bytes, "
        f"time: {kp.time_ms:.6f} ms"
    )
    print(runtime_information)

    # The percentage uses two decimals, matching the bandwidth figure beside it
    # (the old `:2f` was a width spec, not a precision).
    bandwidth_information = (
        f"Effective Bandwidth: {kp.effective_bandwidth_gbs:.2f} GB/s, "
        f"Peak Bandwidth: {kp.percentage_peak_bandwidth:.2f}%"
    )
    print(bandwidth_information)

In [17]:
# Print a summary for every kernel profile recorded by the profiled
# `fn.execute(..., profile=True)` call.
kps = fn.profile().kernel_profiles
for kp in kps:
    print_kernel_profile(kp)

name: nvfuser_none_f37_c0_r0_g0, schedule: user, segment_id: 0, device: 0, stream: 7
compile time: 88.50 ms, grid: [1024, 1, 1], block: [256, 1, 1], registers: 30
input size: 8388608 bytes, output size: 8388608 bytes, time: 0.035584 ms
Effective Bandwidth: 471.48 GB/s, Peak Bandwidth: 50.366844%
