In [2]:
import triton 
import torch
import triton.language as tl

In [3]:
@triton.jit
def vectorcumsum(inptr, ans, m, blocksize: tl.constexpr):
    """Inclusive prefix sum (cumulative sum) of a 1-D vector.

    Launch with a 1-D grid of ``triton.cdiv(m, blocksize)`` programs; program
    ``pid`` writes ``ans[pid*blocksize : (pid+1)*blocksize]``.

    Each program first sums every element that precedes its tile (the
    "carry") and adds it to the in-tile cumulative sum, so the result is a
    true global prefix sum even when more than one program is launched (a
    plain per-tile ``tl.cumsum`` restarts from zero at every tile boundary).
    The carry is recomputed by every program, so total work grows as
    O(m**2 / blocksize); a two-pass scan (per-tile sums, then a scan of those
    sums) is the scalable alternative for very long vectors.

    Accumulation is done in float32 and cast back to the output element type
    on store. This is intended for floating-point inputs (fp16/bf16/fp32);
    NOTE(review): integer or fp64 inputs would also be summed in fp32.

    Args:
        inptr: pointer to the input vector (``m`` elements).
        ans: pointer to the output vector (``m`` elements).
        m: number of valid elements.
        blocksize: elements handled per program; must be a power of two
            (required by ``tl.arange``).
    """
    pid = tl.program_id(0)
    block_start = pid * blocksize
    lanes = tl.arange(0, blocksize)

    # Sum of everything before this tile. Accumulate lane-wise in fp32 and
    # reduce once; the mask keeps an over-sized grid from reading past m.
    partial = tl.zeros([blocksize], dtype=tl.float32)
    for prev_start in range(0, block_start, blocksize):
        prev_offsets = prev_start + lanes
        prev_vals = tl.load(inptr + prev_offsets, mask=prev_offsets < m, other=0.0)
        partial += prev_vals.to(tl.float32)
    carry = tl.sum(partial, axis=0)

    offsets = block_start + lanes
    mask = offsets < m
    # other=0.0 keeps padding lanes defined; fp32 avoids fp16 rounding drift
    # as the running sum grows.
    vals = tl.load(inptr + offsets, mask=mask, other=0.0).to(tl.float32)
    cumsum_vals = tl.cumsum(vals, axis=0) + carry
    tl.store(ans + offsets, cumsum_vals.to(ans.dtype.element_ty), mask=mask)
    


In [6]:
def test(M: int = 10, blocksize: int = 128, seed: int = 0):
    """Run ``vectorcumsum`` on a random fp16 CUDA vector and verify it.

    Prints the input and the kernel output, as the original demo did. It then
    asserts that the output matches ``torch.cumsum`` computed in float32 and
    rounded to fp16.

    Args:
        M: vector length. Pass ``M > blocksize`` to exercise a multi-program
            launch, where the carry between tiles matters.
        blocksize: elements per Triton program; must be a power of two.
        seed: RNG seed so the printed values are reproducible across runs.

    Returns:
        The fp16 output tensor written by the kernel.
    """
    torch.manual_seed(seed)
    a = torch.randn(M, device='cuda', dtype=torch.float16)
    ans = torch.zeros_like(a)
    noofblocks = triton.cdiv(M, blocksize)

    vectorcumsum[(noofblocks,)](a, ans, M, blocksize)
    print(a)
    print(ans)

    # Reference computed in fp32 then rounded to fp16; the tolerance absorbs
    # the final fp16 rounding and summation-order differences.
    expected = torch.cumsum(a.float(), dim=0).to(torch.float16)
    torch.testing.assert_close(ans, expected, rtol=1e-2, atol=1e-2)
    return ans

In [7]:
# Inside a Jupyter kernel __name__ is '__main__' as well, so running this cell
# executes the demo; importing the code as a module would skip it.
if '__main__' == __name__:
    test()

tensor([-0.3650,  1.1035, -0.6265,  1.4199,  0.0507,  0.9785,  0.8892,  0.6997,
         0.6196, -0.7861], device='cuda:0', dtype=torch.float16)
tensor([-0.3650,  0.7383,  0.1121,  1.5312,  1.5820,  2.5605,  3.4492,  4.1484,
         4.7695,  3.9824], device='cuda:0', dtype=torch.float16)
