In [1]:
import triton
import triton.language as tl
import torch
import numpy as np

In [2]:
triton.__version__  # record the Triton version: compiler output (e.g. the PTX inspected below) can differ between releases

'3.4.0'

In [3]:
## Add 2 vectors - 1D - c = a + b

In [4]:
@triton.jit
def add_two_vectors(ip1_ptr, ip2_ptr, op_ptr, n_elements, BLOCK_SIZE:tl.constexpr):
    """Element-wise ``op = ip1 + ip2`` over a flat buffer of ``n_elements`` values.

    Each program instance handles one BLOCK_SIZE-wide tile of indices; lanes
    whose index is ``>= n_elements`` are masked off for both loads and the store.
    """
    # 1-D launch: program i owns indices [i * BLOCK_SIZE, (i + 1) * BLOCK_SIZE).
    block_start = tl.program_id(axis=0) * BLOCK_SIZE
    idx = block_start + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements

    a = tl.load(ip1_ptr + idx, mask=in_bounds)
    b = tl.load(ip2_ptr + idx, mask=in_bounds)

    tl.store(op_ptr + idx, a + b, mask=in_bounds)

In [5]:
def launch_add_kernel(x, y, block_size=1024):
    """Return ``x + y`` computed element-wise by the ``add_two_vectors`` Triton kernel.

    Args:
        x, y: tensors of identical shape. They are expected to live on the GPU,
            because the kernel reads them through raw device pointers.
        block_size: elements handled per Triton program (the kernel's
            BLOCK_SIZE). It must be a power of two, as required by
            ``tl.arange``. The default 1024 matches the original launch.

    Returns:
        A new contiguous tensor with the shape and dtype of ``x``.

    Raises:
        AssertionError: if ``x`` and ``y`` have different shapes.
    """
    assert x.size() == y.size()

    # The kernel treats each input as a flat, dense buffer (ptr + offset).
    # A strided view (e.g. x[::2] or x.t()) would make it read the wrong
    # elements, so densify first. This is a no-op for already-contiguous tensors.
    x = x.contiguous()
    y = y.contiguous()

    # The masked store writes every in-range element, so zero-filling is
    # unnecessary. empty_like of a contiguous tensor is itself contiguous,
    # which matches the flat indexing used by the kernel.
    out = torch.empty_like(x)
    n_elements = out.numel()
    if n_elements == 0:
        # Nothing to compute; skip the launch entirely.
        return out

    # Launch one program per BLOCK_SIZE tile. Triton parallelises at the grid
    # level; the grid must be a tuple, hence the trailing comma.
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    add_two_vectors[grid](x, y, out, n_elements, BLOCK_SIZE=block_size)
    return out

In [6]:
# Fixed seed so the random inputs (and the printed tensor) are reproducible.
torch.manual_seed(8)
SIZE = 10_000

x = torch.randn(SIZE, device='cuda', dtype=torch.float32)
y = torch.randn(SIZE, device='cuda', dtype=torch.float32)

out = launch_add_kernel(x, y)
# Verify the Triton result against PyTorch's own addition before trusting it.
torch.testing.assert_close(out, x + y)
print(out)

tensor([-0.4691, -1.2322,  1.2438,  ..., -0.4972,  0.1273,  0.0804],
       device='cuda:0')


In [7]:
## Square and add (element-wise) - out = x**2 + y

@triton.jit
def square_and_add(ip1_ptr, ip2_ptr, out_ptr, n_elements, BLOCK_SIZE:tl.constexpr):
    """Element-wise ``out = x * x + y`` over a flat buffer of ``n_elements`` values.

    x is read from ``ip1_ptr`` and y from ``ip2_ptr``. Lanes whose index is
    ``>= n_elements`` are masked off for both loads and the store.
    """
    # 1-D launch: program i owns indices [i * BLOCK_SIZE, (i + 1) * BLOCK_SIZE).
    first = tl.program_id(axis=0) * BLOCK_SIZE
    idx = first + tl.arange(0, BLOCK_SIZE)
    valid = idx < n_elements

    x_vals = tl.load(ip1_ptr + idx, mask=valid)
    y_vals = tl.load(ip2_ptr + idx, mask=valid)

    tl.store(out_ptr + idx, x_vals * x_vals + y_vals, mask=valid)
    

In [8]:
def launch_square_and_add():
    """Run ``square_and_add`` on a fixed 4-element example and print the result."""
    inputs_x = torch.tensor([1, 2, 3, 4], device="cuda", dtype=torch.float32)
    inputs_y = torch.tensor([10, 20, 30, 40], device="cuda", dtype=torch.float32)
    result = torch.zeros_like(inputs_x)

    n = inputs_x.numel()

    def grid(meta):
        # One program per BLOCK_SIZE-wide tile on a 1-D grid.
        return (triton.cdiv(n, meta["BLOCK_SIZE"]),)

    square_and_add[grid](inputs_x, inputs_y, result, n, BLOCK_SIZE=1024)

    print(result)

In [9]:
launch_square_and_add()  # expected: x**2 + y = [11., 24., 39., 56.]

tensor([11., 24., 39., 56.], device='cuda:0')


In [10]:
## Let's inspect the PTX that Triton generates for square_and_add.
## (Reading it back from Triton's kernel cache did not work for me, so the
## kernel is compiled explicitly with warmup() instead.)

from triton.compiler import CompiledKernel

# warmup() compiles a specialisation of the kernel without launching it and
# returns the compiled kernel. Pointer arguments are described only by their
# element dtype. n_elements gets an example value; Triton may specialise on
# properties of integer arguments, so pick a representative one.
compiled = square_and_add.warmup(
    torch.float32, torch.float32, torch.float32,  # ip1_ptr, ip2_ptr, out_ptr (dtypes)
    4,                                            # n_elements (example value)
    grid=(1,),
    BLOCK_SIZE=1024,
)

# The compiled kernel keeps the output of each compilation stage in its .asm dict.
ptx = compiled.asm['ptx']
print(ptx)


//
// Generated by LLVM NVPTX Back-End
//

.version 8.7
.target sm_89
.address_size 64

	// .globl	square_and_add          // -- Begin function square_and_add
                                        // @square_and_add
.visible .entry square_and_add(
	.param .u64 .ptr .global .align 1 square_and_add_param_0,
	.param .u64 .ptr .global .align 1 square_and_add_param_1,
	.param .u64 .ptr .global .align 1 square_and_add_param_2,
	.param .u32 square_and_add_param_3,
	.param .u64 .ptr .global .align 1 square_and_add_param_4
)
.reqntid 128
{
	.reg .pred 	%p<25>;
	.reg .b32 	%r<39>;
	.reg .b64 	%rd<29>;
	.loc	1 4 0                           // 593028606.py:4:0
$L__func_begin0:
	.loc	1 4 0                           // 593028606.py:4:0

// %bb.0:
	ld.param.b64 	%rd25, [square_and_add_param_0];
	ld.param.b64 	%rd26, [square_and_add_param_1];
$L__tmp0:
	.loc	1 5 24                          // 593028606.py:5:24
	mov.u32 	%r25, %ctaid.x;
	.loc	1 6 20                          // 593028606.py:6:20
	shl.

In [11]:
## Key insights - really blown away by how many optimizations Triton applies
# automatically, as opposed to writing the kernel by hand (e.g. with Numba).

#1) FMA: x * x + y becomes a single fused multiply-add instruction instead of
#   a separate multiply followed by an add.
#2) Predicated masking (setp.lt.s32 ...): in Numba, bounds checks are written as
#   explicit `if` conditions. Triton avoids branching completely by turning the
#   mask into predicates on the loads and stores, which is quite smart.
#   (In Numba you can unroll loops, but you still have to write the `if`s.)

#3) Memory access - contiguous across neighbouring threads, strided within each thread.
#   Evidence - adjacent threads get adjacent offsets: a shift left by 2 bits
#   (shl.b32 %r29, %r28, 2) applied to %tid.x.
#   The stride between a thread's accesses is 2048 bytes (512 float32 elements): add.s64 %rd5, %rd1, 2048
#   Coalesced access - all threads execute the same ld.global.b32 ..., which is cool.

#4) Vectorization / work per thread - each thread processes 8 elements
#   (a 1024-element block shared by 128 threads, see `.reqntid 128` in the PTX).

In [12]:
## Element wise maximum with threshold - out = max(x, threshold)

@triton.jit
def max_with_threshold(x_ptr, threshold, out_ptr, n_elements, BLOCK_SIZE:tl.constexpr): 
    """Element-wise ``out = max(x, threshold)`` over ``n_elements`` values.

    ``threshold`` is a scalar kernel argument, not a pointer. ``tl.maximum``
    broadcasts it against the BLOCK_SIZE-wide tile, so no per-element copy
    of the threshold is ever read through a pointer the way ``x`` is.
    (NumPy/PyTorch broadcasting does not copy data either; it uses stride-0
    views.)
    """
    idx = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_range = idx < n_elements

    vals = tl.load(x_ptr + idx, mask=in_range)

    tl.store(out_ptr + idx, tl.maximum(vals, threshold), mask=in_range)
    


In [13]:
def launch_max_with_threshold():
    """Run ``max_with_threshold`` with threshold 0.5 on a fixed 4-element example and print the result."""
    values = torch.tensor([-1, 0.3, 0.8, 1.5], device="cuda", dtype=torch.float32)
    floor_value = 0.5
    result = torch.zeros_like(values)

    total = values.numel()

    def grid(meta):
        # One program per BLOCK_SIZE-wide tile on a 1-D grid.
        return (triton.cdiv(total, meta["BLOCK_SIZE"]),)

    max_with_threshold[grid](values, floor_value, result, total, BLOCK_SIZE=1024)

    print(result)

In [14]:
launch_max_with_threshold()  # expected: max(x, 0.5) = [0.5, 0.5, 0.8, 1.5]

tensor([0.5000, 0.5000, 0.8000, 1.5000], device='cuda:0')
