In [1]:
import triton 
import torch
import triton.language as tl

In [5]:
@triton.jit
def finddiagonal(inptr, outptr, m, n, blocksize: tl.constexpr):
    """Copy the main diagonal of an (m, n) matrix into a 1-D output buffer.

    Each program instance handles ``blocksize`` consecutive diagonal
    elements; launch with ``triton.cdiv(min(m, n), blocksize)`` programs
    on a 1-D grid.

    Args:
        inptr: pointer to the input matrix. The address arithmetic below
            assumes a contiguous row-major layout (row stride ``n``, column
            stride 1) -- callers must pass a contiguous tensor.
        outptr: pointer to an output buffer with at least ``min(m, n)``
            elements; element ``i`` receives ``input[i, i]``.
        m: number of rows of the input matrix.
        n: number of columns of the input matrix.
        blocksize: diagonal elements per program; ``tl.arange`` requires a
            power of two.
    """
    pid = tl.program_id(0)

    # Diagonal indices owned by this program.
    offsets = pid * blocksize + tl.arange(0, blocksize)
    # The diagonal has min(m, n) entries; mask off the tail of the last block.
    mask = (offsets < m) & (offsets < n)

    # Element (i, i) of a contiguous row-major matrix lives at i * n + i.
    # Widen to int64 first: in int32 this product overflows once the flat
    # index passes 2**31 - 1 (e.g. square inputs with n > ~46340).
    rows = offsets.to(tl.int64)
    input_offsets = rows * n + rows
    diagonal_vals = tl.load(inptr + input_offsets, mask=mask)
    tl.store(outptr + offsets, diagonal_vals, mask=mask)
    
    

In [6]:
def test(M=5, N=5, blocksize=64, seed=0):
    """Run `finddiagonal` on a random fp16 CUDA matrix and verify the result.

    Prints the input matrix and the extracted diagonal (as the original demo
    did), then checks the kernel output against ``torch.diagonal``.

    Args:
        M: number of rows of the random input matrix.
        N: number of columns of the random input matrix.
        blocksize: diagonal elements per Triton program; must be a power of
            two (required by ``tl.arange`` inside the kernel).
        seed: value passed to ``torch.manual_seed`` so the run is
            reproducible; pass ``None`` to leave the global RNG untouched.

    Returns:
        The 1-D tensor of ``min(M, N)`` diagonal values written by the kernel.

    Raises:
        AssertionError: if the input is not contiguous or the kernel output
            differs from ``torch.diagonal(a)``.
    """
    if seed is not None:
        torch.manual_seed(seed)

    a = torch.randn((M, N), device='cuda', dtype=torch.float16)
    # The kernel addresses element (i, i) as i * N + i, i.e. it assumes a
    # contiguous row-major layout; a strided view would give wrong values.
    assert a.is_contiguous(), "finddiagonal requires a contiguous input"

    z = min(M, N)
    b = torch.zeros(z, device='cuda', dtype=torch.float16)

    noofblocks = triton.cdiv(z, blocksize)
    # Empty diagonal (M or N == 0): nothing to copy, so skip the launch.
    if noofblocks > 0:
        finddiagonal[(noofblocks,)](a, b, M, N, blocksize)

    print(a)
    print(b)

    # The kernel is a pure copy, so the result must match exactly.
    expected = torch.diagonal(a)
    assert torch.equal(b, expected), f"diagonal mismatch:\n{b}\nvs\n{expected}"
    return b

In [7]:
# Entry point: run the diagonal-extraction demo. Inside a Jupyter kernel
# __name__ is '__main__' as well, so executing this cell calls test().
if __name__ == '__main__':
    test()

tensor([[-1.4082, -2.3047,  1.3223, -0.8174,  0.0643],
        [-0.7202,  0.9331,  0.9731,  0.3411,  0.5620],
        [-0.8545,  0.3401, -0.2803,  0.0705,  0.0190],
        [ 0.4495, -0.1223, -0.2676,  0.1848,  1.2080],
        [-1.0684, -1.1436,  0.7744,  0.9448, -0.5815]], device='cuda:0',
       dtype=torch.float16)
tensor([-1.4082,  0.9331, -0.2803,  0.1848, -0.5815], device='cuda:0',
       dtype=torch.float16)
