In [1]:
import triton 
import torch
import triton.language as tl

GELU(x) ≈ 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))) (tanh approximation of GELU) <br/>
Swish(x) = x * sigmoid(x) = x / (1 + exp(-x)) <br/>

tanh(x) = (exp(2x) - 1) / (exp(2x) + 1)
  

In [7]:
@triton.jit
def gelu_swish_fused_kernel(inptr1,outptr1,outptr2,m,blocksize:tl.constexpr):
    """Compute tanh-approximated GELU and Swish of one input vector in a single pass.

    Each program handles one contiguous chunk of `blocksize` elements: it loads
    the chunk once and writes both activations, so the input is read only once.

    Args:
        inptr1: pointer to the input vector.
        outptr1: pointer to the output vector that receives GELU(x).
        outptr2: pointer to the output vector that receives Swish(x).
        m: number of valid elements; offsets >= m are masked out.
        blocksize: compile-time number of elements handled per program.
    """
    p_id = tl.program_id(0)
    block_start = blocksize * p_id
    offsets = block_start + tl.arange(0, blocksize)
    # The last program may run past the end of the vector; mask those lanes.
    mask = offsets < m
    # `other=0.0` gives masked-off lanes a defined value, so the math below
    # never operates on unspecified data (those lanes are not stored anyway).
    a = tl.load(inptr1 + offsets, mask=mask, other=0.0)

    # GELU, tanh approximation: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    # 0.7978845608028654 == sqrt(2/pi)
    inner = 0.7978845608028654 * (a + 0.044715 * (a * a * a))
    # tanh is computed by hand from exp. The textbook form
    # (exp(2z) - 1) / (exp(2z) + 1) becomes inf / inf = NaN once exp(2z)
    # overflows (inputs of roughly x > 10 in float32). The algebraically
    # identical form 1 - 2 / (exp(2z) + 1) saturates cleanly instead:
    #   exp(2z) -> inf  gives 1 - 0 = +1,   exp(2z) -> 0  gives 1 - 2 = -1.
    exp_2x = tl.math.exp(2.0 * inner)
    tanh = 1.0 - 2.0 / (exp_2x + 1.0)
    gelu = 0.5 * a * (1.0 + tanh)

    # Swish(x) = x * sigmoid(x) = x / (1 + exp(-x))
    sigmoid = 1.0 / (1.0 + tl.math.exp(-a))
    swish = a * sigmoid

    tl.store(outptr1 + offsets, gelu, mask=mask)
    tl.store(outptr2 + offsets, swish, mask=mask)
    
    
    

In [9]:
def test():
    """Run the fused GELU/Swish kernel on a random CUDA vector and verify it.

    Launches `gelu_swish_fused_kernel` on 500 float32 values, prints the input
    and both outputs, then compares the outputs against the same formulas
    evaluated with PyTorch ops.

    Raises:
        AssertionError: if either kernel output differs from the PyTorch
            reference beyond the given tolerances.
    """
    m = 500
    vector_a = torch.randn(m, device='cuda', dtype=torch.float32)
    vector_g = torch.zeros_like(vector_a)
    vector_s = torch.zeros_like(vector_a)
    blocksize = 128
    # 500 is not a multiple of 128, so the last program runs with a partly
    # false mask and the tail handling gets exercised.
    noofblock = triton.cdiv(m, blocksize)
    gelu_swish_fused_kernel[(noofblock,)](vector_a, vector_g, vector_s, m, blocksize=blocksize)
    print(vector_a)
    print(vector_g)
    print(vector_s)

    # Reference results from PyTorch, using the same tanh-approximated GELU
    # and the same Swish definition as the kernel.
    inner = 0.7978845608028654 * (vector_a + 0.044715 * vector_a ** 3)
    expected_g = 0.5 * vector_a * (1.0 + torch.tanh(inner))
    expected_s = vector_a * torch.sigmoid(vector_a)
    # Checked after printing so the tensors are visible even when a check fails.
    torch.testing.assert_close(vector_g, expected_g, rtol=1e-4, atol=1e-5)
    torch.testing.assert_close(vector_s, expected_s, rtol=1e-4, atol=1e-5)
