<a href="https://colab.research.google.com/github/NShravanReddy/DeepLearning/blob/main/triton/Relu_benchmark.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import triton
import triton.language as tl
import torch
import torch.nn as nn
import time
DEVICE = 'cuda'


@triton.jit
def l_r_k(x_ptr,
          y_ptr,
          alpha,
          N0,
          BLOCK_SIZE: tl.constexpr):
  """Element-wise LeakyReLU: y = x where x > 0, otherwise alpha * x.

  Each program instance handles one chunk of BLOCK_SIZE consecutive
  elements, addressed as linear offsets from x_ptr / y_ptr.

  Args:
    x_ptr: pointer to the first input element.
    y_ptr: pointer to the first output element.
    alpha: slope applied to non-positive inputs.
    N0: total number of elements to process.
    BLOCK_SIZE: number of elements handled per program (compile-time).
  """
  pid = tl.program_id(axis=0)
  block_start = BLOCK_SIZE * pid
  offsets = block_start + tl.arange(0, BLOCK_SIZE)
  # The last program may extend past the end of the data; mask those lanes.
  mask = offsets < N0
  x = tl.load(x_ptr + offsets, mask=mask)
  # Select explicitly on the sign of x. max(x, alpha * x) is only equal to
  # LeakyReLU when alpha <= 1; for alpha > 1 it picks the wrong branch.
  y = tl.where(x > 0, x, alpha * x)
  tl.store(y_ptr + offsets, y, mask=mask)



def l_r_k_h(x: torch.Tensor, alpha: float = 1, BLOCK_SIZE=1024**2) -> "tuple[torch.Tensor, float]":
  """Launch the `l_r_k` Triton kernel on `x` and return the result.

  Args:
    x: input tensor; must live on a CUDA device.
    alpha: slope forwarded to the kernel for non-positive inputs.
    BLOCK_SIZE: number of elements each kernel program processes.

  Returns:
    A `(y, alpha)` tuple: the output tensor (same shape and dtype as `x`)
    and the slope that was used.

  Raises:
    AssertionError: if `x` is not a CUDA tensor.
  """
  # Validate before allocating anything or touching the kernel.
  assert x.is_cuda, "l_r_k_h expects a CUDA tensor"
  # The kernel addresses memory with linear offsets, so both buffers must be
  # densely packed in the same order; a strided view would be read wrongly.
  x = x.contiguous()
  y = torch.empty_like(x)
  N0 = x.numel()
  if N0 == 0:
    # Nothing to compute; skip the launch entirely.
    return y, alpha
  grid = lambda meta: (triton.cdiv(N0, meta['BLOCK_SIZE']),)
  l_r_k[grid](x, y, alpha, N0, BLOCK_SIZE=BLOCK_SIZE)
  return y, alpha

def benchmark(func, *args, n_warmup=10, n_iters=100):
    """Return the average wall-clock time of ``func(*args)`` in milliseconds.

    ``func`` is first called ``n_warmup`` times untimed, then ``n_iters``
    times inside the timed region. When CUDA is available the device is
    synchronized before the timer starts and before it stops, so queued GPU
    work is included in the measurement. On hosts without CUDA the
    synchronization is skipped and plain CPU callables can be timed.

    Args:
        func: callable to time.
        *args: positional arguments forwarded to ``func`` on every call.
        n_warmup: number of untimed warm-up calls.
        n_iters: number of timed calls; must be positive.

    Returns:
        Mean time per call in milliseconds.

    Raises:
        ValueError: if ``n_iters`` is not positive.
    """
    if n_iters <= 0:
        raise ValueError(f"n_iters must be positive, got {n_iters}")

    use_cuda = torch.cuda.is_available()

    for _ in range(n_warmup):
        func(*args)
    if use_cuda:
        torch.cuda.synchronize()

    start = time.perf_counter()
    for _ in range(n_iters):
        func(*args)
    if use_cuda:
        torch.cuda.synchronize()
    end = time.perf_counter()

    # Seconds for the whole loop -> milliseconds per call.
    return (end - start) / n_iters * 1000

if __name__=='__main__':
  # Compare the Triton LeakyReLU against torch.nn.LeakyReLU, then time both.
  N = 1024 * 1024
  BLOCK_SIZE = 1024
  # Use a slope other than 1: with slope 1 LeakyReLU is the identity and the
  # comparison below would not exercise the negative branch at all.
  NEGATIVE_SLOPE = 0.01
  x = torch.randn(N, device=DEVICE, dtype=torch.float32)

  # Correctness check.
  y_triton, alpha = l_r_k_h(x, NEGATIVE_SLOPE, BLOCK_SIZE=BLOCK_SIZE)
  leaky_relu = nn.LeakyReLU(negative_slope=alpha)
  y_torch = leaky_relu(x)

  print(y_torch)
  print(y_triton)
  print(f"max abs diff = {(y_torch - y_triton).abs().max().item():.3e}")
  print(f"allclose = {torch.allclose(y_torch, y_triton)}")

  #Benchmarking forward pass
  # Pass alpha explicitly and BLOCK_SIZE by keyword: the launcher's second
  # positional parameter is alpha, so a positional block size would be
  # taken as the slope.
  triton_time = benchmark(lambda: l_r_k_h(x, alpha, BLOCK_SIZE=BLOCK_SIZE))
  torch_time = benchmark(lambda: leaky_relu(x))
  print("Average execution time (Forward Pass):")
  print(f"  Triton LeakyReLU = {triton_time:.3f} ms")
  print(f"  PyTorch LeakyReLU = {torch_time:.3f} ms")

RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
