<a href="https://colab.research.google.com/github/NShravanReddy/DeepLearning/blob/main/triton/Sigmoid_triton.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import triton
import triton.language as tl
import torch.nn as nn
import torch
import time

# Element-wise logistic sigmoid, y = 1 / (1 + exp(-x)), over a flat buffer.
# Program i covers elements [i*BLOCK_SIZE, (i+1)*BLOCK_SIZE); lanes whose index
# is >= N0 are masked off for both the load and the store.
#   x_ptr / y_ptr : input / output element pointers
#   N0            : number of elements to process
#   BLOCK_SIZE    : elements per program (compile-time constant)
# NOTE(review): the math runs in the loaded element dtype; confirm tl.exp
# accepts it before feeding fp16/bf16 tensors.
@triton.jit
def t_s_k(x_ptr,
          y_ptr,
          N0,
          BLOCK_SIZE: tl.constexpr):
    lane = tl.arange(0, BLOCK_SIZE)
    idx = tl.program_id(axis=0) * BLOCK_SIZE + lane
    in_bounds = idx < N0
    vals = tl.load(x_ptr + idx, mask=in_bounds)
    denom = 1 + tl.exp(-vals)
    tl.store(y_ptr + idx, 1 / denom, mask=in_bounds)


def t_s_k_h(x: torch.Tensor, BLOCK_SIZE=1024) -> torch.Tensor:
    """Apply the Triton sigmoid kernel ``t_s_k`` element-wise to ``x``.

    Args:
        x: input tensor; must live on a device Triton can launch on (the
            script in this file uses CUDA).
        BLOCK_SIZE: elements handled per Triton program; must be a power of
            two because the kernel builds its lanes with ``tl.arange``.

    Returns:
        A new contiguous tensor with the same shape and dtype as ``x``.
    """
    # The kernel indexes memory as one densely packed flat buffer, so a
    # strided view (e.g. ``x[::2]``) would otherwise be read from the wrong
    # addresses. ``contiguous()`` returns ``x`` itself when already contiguous.
    x = x.contiguous()
    y = torch.empty_like(x)
    n_elements = x.numel()
    if n_elements == 0:
        # Nothing to compute; avoid launching a zero-sized grid.
        return y
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    t_s_k[grid](x, y, n_elements, BLOCK_SIZE=BLOCK_SIZE)
    return y

def benchmark(func, *args, n_warmup=10, n_iters=100):
    """Return the mean wall-clock time of ``func(*args)`` in milliseconds.

    ``func`` is first called ``n_warmup`` times untimed (so JIT compilation
    and caching are excluded), then timed over ``n_iters`` calls. When CUDA
    is available the device is synchronized before the clock starts and
    before it stops, because GPU launches are asynchronous and would
    otherwise be timed only by their launch overhead.

    Args:
        func: callable to time.
        *args: positional arguments forwarded to ``func`` on every call.
        n_warmup: number of untimed calls made before measuring.
        n_iters: number of timed calls; must be positive.

    Returns:
        Average time per call, in milliseconds.

    Raises:
        ValueError: if ``n_iters`` is not positive.
    """
    if n_iters <= 0:
        raise ValueError(f"n_iters must be positive, got {n_iters}")
    # Synchronize only when a GPU is present so CPU-only callables can be
    # timed too; torch.cuda.synchronize() fails on a CUDA-less build/host.
    sync = torch.cuda.synchronize if torch.cuda.is_available() else (lambda: None)
    for _ in range(n_warmup):
        func(*args)
    sync()
    start = time.perf_counter()
    for _ in range(n_iters):
        func(*args)
    sync()
    end = time.perf_counter()
    return (end - start) / n_iters * 1000


if __name__ == '__main__':
    # Correctness check and forward-pass benchmark of the Triton sigmoid
    # against PyTorch's nn.Sigmoid on a 1M-element fp32 CUDA tensor.
    if not torch.cuda.is_available():
        raise SystemExit("This demo requires a CUDA device.")

    N = 1024 * 1024
    # Same block size for the correctness run and the benchmark. A block of
    # 1024**2 would run the whole tensor in a single program (and trigger a
    # separate JIT compile), which is not a meaningful configuration to time.
    BLOCK_SIZE = 1024
    x = torch.randn(N, device='cuda', dtype=torch.float32)

    sigmoid = nn.Sigmoid()
    y_triton = t_s_k_h(x, BLOCK_SIZE)
    y_torch = sigmoid(x)

    print(y_torch)
    print(y_triton)
    # Report the worst-case deviation rather than dumping the whole diff
    # tensor; small rounding differences between exp implementations are
    # expected.
    max_err = (y_torch - y_triton).abs().max().item()
    print(f"Max abs difference: {max_err:.3e}")
    assert torch.allclose(y_triton, y_torch, atol=1e-6, rtol=1e-5), max_err

    # Benchmarking forward pass
    triton_time = benchmark(lambda: t_s_k_h(x, BLOCK_SIZE))
    torch_time = benchmark(lambda: sigmoid(x))
    print("Average execution time (Forward Pass):")
    print(f"  Triton Sigmoid = {triton_time:.3f} ms")
    print(f"  PyTorch Sigmoid = {torch_time:.3f} ms")

tensor([0.5933, 0.0790, 0.2378,  ..., 0.7168, 0.7342, 0.3745], device='cuda:0')
tensor([0.5933, 0.0790, 0.2378,  ..., 0.7168, 0.7342, 0.3745], device='cuda:0')
tensor([0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
        2.9802e-08], device='cuda:0')
