In [7]:
import torch
import triton
import triton.language as tl

# Device shared by the cells below: the CUDA device PyTorch currently has selected.
DEVICE = torch.device(f'cuda:{torch.cuda.current_device()}')


In [8]:
@triton.jit
def _add_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Element-wise add of two flat buffers, one BLOCK_SIZE-wide slice per program.

    Args:
        x_ptr: Pointer to the first input buffer.
        y_ptr: Pointer to the second input buffer.
        output_ptr: Pointer to the buffer that receives ``x + y``.
        n_elements: Number of valid elements; offsets at or beyond it are masked off.
        BLOCK_SIZE: Compile-time number of elements each program instance covers.
    """
    pid = tl.program_id(axis=0)
    # Flat element indices owned by this program instance:
    # [pid * BLOCK_SIZE, (pid + 1) * BLOCK_SIZE).
    idx = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # The last program may run past the end of the data; disable those lanes.
    in_bounds = idx < n_elements
    # Masked loads: only in-bounds lanes touch memory.
    lhs = tl.load(x_ptr + idx, mask=in_bounds)
    rhs = tl.load(y_ptr + idx, mask=in_bounds)
    # Masked store of the sum back to the output buffer.
    tl.store(output_ptr + idx, lhs + rhs, mask=in_bounds)



In [9]:
def add_tri(x: torch.Tensor, y: torch.Tensor):
    """Add two tensors element-wise by launching the Triton ``_add_kernel``.

    Args:
        x: First operand; must be on ``DEVICE``.
        y: Second operand; must be on ``DEVICE`` and have the same shape as ``x``.

    Returns:
        A newly allocated, contiguous tensor with the shape and dtype of ``x``
        holding ``x + y``.

    Raises:
        AssertionError: If the tensors are not both on ``DEVICE`` or their
            shapes differ.
    """
    # Validate the inputs before allocating anything.
    assert x.device == y.device == DEVICE
    # The kernel indexes both inputs with the same flat offsets, bounded only by
    # x.numel(); a y with fewer elements would be read out of bounds.
    assert x.shape == y.shape, f"shape mismatch: {tuple(x.shape)} vs {tuple(y.shape)}"

    # The kernel does plain `ptr + offset` addressing, so both operands must be
    # laid out contiguously for offset i to mean the same element in each.
    # `.contiguous()` returns the tensor itself when it already is.
    x = x.contiguous()
    y = y.contiguous()
    output = torch.empty_like(x)

    n_elements = x.numel()
    if n_elements == 0:
        # Nothing to compute; skip the launch entirely.
        return output

    # 1D launch grid with one program per BLOCK_SIZE chunk,
    # e.g. 4096 elements with BLOCK_SIZE=1024 -> grid of (4,).
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)

    _add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
    return output



In [10]:
def test_add_kernel(size: int, atol=1e-3, rtol=1e-3, device: torch.device = DEVICE):
    """Check ``add_tri`` against PyTorch's own addition on random inputs.

    Args:
        size: Number of elements in each random 1D input.
        atol: Absolute tolerance passed to ``torch.testing.assert_close``.
        rtol: Relative tolerance passed to ``torch.testing.assert_close``.
        device: Device on which the random inputs are created.

    Raises:
        AssertionError: If the Triton result differs from the PyTorch result
            beyond the given tolerances.
    """
    # Fixed seed so every run draws the same inputs.
    torch.manual_seed(0)
    lhs = torch.randn(size, device=device)
    rhs = torch.randn(size, device=device)

    # PyTorch reference first, then the Triton result under test.
    expected = lhs + rhs
    actual = add_tri(lhs, rhs)

    torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol)
    print("Triton add kernel passed")


In [11]:
# Launch the kernel once by hand so the value returned by the launch can be inspected.
NUM_ELEMENTS = 1024
# Grid size depends on the BLOCK_SIZE passed at launch: cdiv(NUM_ELEMENTS, BLOCK_SIZE) programs.
grid = lambda meta: (triton.cdiv(NUM_ELEMENTS, meta['BLOCK_SIZE']),)
x = torch.randn(NUM_ELEMENTS, device=DEVICE)
y = torch.randn(NUM_ELEMENTS, device=DEVICE)
# Uninitialized buffer that the kernel writes the sum into.
z = torch.empty_like(x)
# Keep the launch's return value; later cells read its `.asm` dict.
kernel = _add_kernel[grid](x, y, z, NUM_ELEMENTS, BLOCK_SIZE=1024)

In [12]:
# Show every key available in kernel.asm.
list(kernel.asm.keys())

['ttir', 'ttgir', 'llir', 'ptx', 'cubin']

In [13]:
# Dump the 'ptx' entry: the PTX text generated for _add_kernel.
print(kernel.asm['ptx'])


//
// Generated by LLVM NVPTX Back-End
//

.version 8.4
.target sm_75
.address_size 64

	// .globl	_add_kernel             // -- Begin function _add_kernel
                                        // @_add_kernel
.visible .entry _add_kernel(
	.param .u64 .ptr .global .align 1 _add_kernel_param_0,
	.param .u64 .ptr .global .align 1 _add_kernel_param_1,
	.param .u64 .ptr .global .align 1 _add_kernel_param_2,
	.param .u32 _add_kernel_param_3,
	.param .u64 .ptr .global .align 1 _add_kernel_param_4
)
.reqntid 128, 1, 1
{
	.reg .pred 	%p<7>;
	.reg .b32 	%r<33>;
	.reg .f32 	%f<25>;
	.reg .b64 	%rd<11>;
	.loc	1 2 0                           // 2601426447.py:2:0
$L__func_begin0:
	.loc	1 2 0                           // 2601426447.py:2:0

// %bb.0:
	ld.param.u64 	%rd7, [_add_kernel_param_0];
	ld.param.u64 	%rd8, [_add_kernel_param_1];
$L__tmp0:
	.loc	1 3 24                          // 2601426447.py:3:24
	mov.u32 	%r25, %ctaid.x;
	.loc	1 4 24                          // 2601426447.py:4:24
	shl.b32

In [14]:
# Dump the 'ttir' entry: the kernel expressed as `tt.*` / `arith.*` ops.
print(kernel.asm['ttir'])


#loc = loc("/tmp/ipykernel_3388653/2601426447.py":2:0)
module {
  tt.func public @_add_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/tmp/ipykernel_3388653/2601426447.py":2:0), %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/tmp/ipykernel_3388653/2601426447.py":2:0), %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/tmp/ipykernel_3388653/2601426447.py":2:0), %arg3: i32 {tt.divisibility = 16 : i32} loc("/tmp/ipykernel_3388653/2601426447.py":2:0)) attributes {noinline = false} {
    %c1024_i32 = arith.constant 1024 : i32 loc(#loc1)
    %0 = tt.get_program_id x : i32 loc(#loc2)
    %1 = arith.muli %0, %c1024_i32 : i32 loc(#loc3)
    %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> loc(#loc4)
    %3 = tt.splat %1 : i32 -> tensor<1024xi32> loc(#loc5)
    %4 = arith.addi %3, %2 : tensor<1024xi32> loc(#loc5)
    %5 = tt.splat %arg3 : i32 -> tensor<1024xi32> loc(#loc6)
    %6 = arith.cmpi slt, %4, %5 : tensor<1024xi32> loc(#loc6)
    %7 = tt.

In [None]:
# Dump the 'ttgir' entry: the same function with `#ttg` layout attributes on its tensor types.
print(kernel.asm['ttgir'])

#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
#loc = loc("/tmp/ipykernel_3388653/2601426447.py":2:0)
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:75", "ttg.threads-per-warp" = 32 : i32} {
  tt.func public @_add_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/tmp/ipykernel_3388653/2601426447.py":2:0), %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/tmp/ipykernel_3388653/2601426447.py":2:0), %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/tmp/ipykernel_3388653/2601426447.py":2:0), %arg3: i32 {tt.divisibility = 16 : i32} loc("/tmp/ipykernel_3388653/2601426447.py":2:0)) attributes {noinline = false} {
    %c1024_i32 = arith.constant 1024 : i32 loc(#loc1)
    %0 = tt.get_program_id x : i32 loc(#loc2)
    %1 = arith.muli %0, %c1024_i32 : i32 loc(#loc3)
    %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> loc(#loc4)
    %3 = tt.sp