<a href="https://colab.research.google.com/github/NShravanReddy/DeepLearning/blob/main/triton/torch_complie_add.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [7]:
import triton
import triton.language as tl
import torch
import time
DEVICE='cuda'
@triton.jit
def t_a_k(x_ptr,
          y_ptr,
          output_ptr,
          N0,
          BLOCK_SIZE: tl.constexpr):
    """Element-wise addition kernel: output[i] = x[i] + y[i] for i < N0.

    Each program instance handles the BLOCK_SIZE consecutive offsets that
    start at ``program_id(0) * BLOCK_SIZE``.

    Args:
        x_ptr: Base pointer of the first operand.
        y_ptr: Base pointer of the second operand.
        output_ptr: Base pointer the sums are written to.
        N0: Number of elements to process; offsets >= N0 are masked out.
        BLOCK_SIZE: Compile-time number of offsets handled per program.
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # The last block may run past N0; the mask disables those lanes for
    # both the loads and the store.
    mask = offsets < N0
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    total = x + y
    # tl.store is used purely for its side effect, so its result is not
    # bound to a name.
    tl.store(output_ptr + offsets, total, mask=mask)

@torch.compile(fullgraph=True)
def t_a_k_h(x: torch.Tensor, y: torch.Tensor, BLOCK_SIZE=1024) -> torch.Tensor:
    """Launch the ``t_a_k`` kernel on ``x`` and ``y`` and return the result.

    The output is allocated with ``torch.empty_like(x)`` and the kernel is
    asked to process ``x.numel()`` elements. No check is made here that
    ``y`` matches ``x``.

    Args:
        x: First operand; determines the output allocation and element count.
        y: Second operand.
        BLOCK_SIZE: Number of elements each kernel program handles.

    Returns:
        The tensor the kernel wrote its sums into.
    """
    n_elements = x.numel()
    result = torch.empty_like(x)

    def launch_grid(meta):
        # One program per BLOCK_SIZE-sized chunk, rounding up.
        return (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)

    t_a_k[launch_grid](x, y, result, n_elements, BLOCK_SIZE=BLOCK_SIZE)
    return result

def benchmark(func, *args, n_warmup=10, n_iters=100):
    """Return the mean wall-clock time of ``func(*args)`` in milliseconds.

    ``func`` is first called ``n_warmup`` times untimed, then ``n_iters``
    times between two ``time.perf_counter()`` readings. When a CUDA device
    is available, ``torch.cuda.synchronize()`` is called before each
    reading so the measured interval is bracketed by synchronization
    points; without CUDA the synchronization is skipped so CPU-only
    callables can be timed too.

    Args:
        func: Callable to time.
        *args: Positional arguments forwarded to ``func`` on every call.
        n_warmup: Number of untimed calls made before measuring.
        n_iters: Number of timed calls; must be positive.

    Returns:
        Average time per timed call, in milliseconds.

    Raises:
        ValueError: If ``n_iters`` is not positive.
    """
    if n_iters <= 0:
        raise ValueError(f"n_iters must be positive, got {n_iters}")

    # Unconditional torch.cuda.synchronize() raises on hosts without CUDA.
    use_cuda = torch.cuda.is_available()

    for _ in range(n_warmup):
        func(*args)
    if use_cuda:
        torch.cuda.synchronize()

    start = time.perf_counter()
    for _ in range(n_iters):
        func(*args)
    if use_cuda:
        torch.cuda.synchronize()
    end = time.perf_counter()

    return (end - start) / n_iters * 1000

if __name__ == '__main__':
    # Smoke test: add two 1024-element float32 ranges with the Triton
    # launcher and compare against PyTorch's own addition.
    N0 = 1024
    x = torch.arange(0, N0, device=DEVICE, dtype=torch.float32)
    y = torch.arange(0, N0, device=DEVICE, dtype=torch.float32)

    add = x + y
    y_triton = t_a_k_h(x, y)
    print((add, y_triton))
    # The printed tensors are truncated, so compare every element
    # explicitly rather than relying on visual inspection.
    torch.testing.assert_close(y_triton, add)

    # Time both implementations on the same inputs.
    triton_time = benchmark(lambda: t_a_k_h(x, y))
    torch_time = benchmark(lambda: x + y)
    print("Average execution time (Forward Pass):")
    print(f"  Triton  = {triton_time:.3f} ms")
    print(f"  PyTorch = {torch_time:.3f} ms")

(tensor([0.0000e+00, 2.0000e+00, 4.0000e+00,  ..., 2.0420e+03, 2.0440e+03,
        2.0460e+03], device='cuda:0'), tensor([0.0000e+00, 2.0000e+00, 4.0000e+00,  ..., 2.0420e+03, 2.0440e+03,
        2.0460e+03], device='cuda:0'))
Average execution time (Forward Pass):
  Triton  = 0.049 ms
  PyTorch = 0.008 ms
