In [1]:
import triton 
import torch
import triton.language as tl

## Goal: return a matrix of the same shape that keeps the main diagonal of the input and sets every other element to 0

In [6]:
@triton.jit
def finddiagonal(inptr,outptr,m,n,blocksizem:tl.constexpr,blocksizen:tl.constexpr):
    """Write the main diagonal of an m x n matrix to the output, zeros elsewhere.

    Each program instance handles one (blocksizem x blocksizen) tile: axis 0
    of the launch grid walks tiles along the rows, axis 1 along the columns.

    Element (row, col) is addressed at flat offset ``row * n + col`` in both
    buffers, so input and output are expected to share that layout.

    Args:
        inptr: pointer to the first element of the input matrix.
        outptr: pointer to the first element of the output matrix.
        m: number of rows.
        n: number of columns (also used as the row pitch when addressing).
        blocksizem: tile height, compile-time constant.
        blocksizen: tile width, compile-time constant.
    """
    tile_row = tl.program_id(0)
    tile_col = tl.program_id(1)

    # Absolute row / column indices covered by this tile.
    offsets_m = tile_row * blocksizem + tl.arange(0, blocksizem)
    offsets_n = tile_col * blocksizen + tl.arange(0, blocksizen)

    # Shape them as a column vector and a row vector so that every
    # combination below broadcasts to a (blocksizem, blocksizen) tile.
    m_idx = offsets_m[:, None]
    n_idx = offsets_n[None, :]

    # Edge tiles may extend past the matrix; only in-bounds lanes are
    # loaded and stored.
    mask = (m_idx < m) & (n_idx < n)

    # Promote to 64-bit before multiplying: a 32-bit ``row * n`` wraps around
    # once the matrix holds more than 2**31 - 1 elements, which would make
    # the kernel touch the wrong addresses.
    input_offsets = m_idx.to(tl.int64) * n + n_idx
    is_diagonal = m_idx == n_idx

    val = tl.load(inptr + input_offsets, mask=mask)
    # Keep the loaded value on the diagonal, write 0 everywhere else.
    newval = tl.where(is_diagonal, val, 0)
    tl.store(outptr + input_offsets, newval, mask=mask)
    

In [9]:
def test():
    """Run ``finddiagonal`` on a random 1000 x 120 float16 CUDA matrix and check it.

    The kernel output is compared element-for-element against a reference
    built with plain PyTorch indexing; the input and output are then printed.

    Raises:
        AssertionError: if the kernel output differs from the reference.
    """
    M, N = 1000, 120
    a = torch.randn((M, N), device='cuda', dtype=torch.float16)
    b = torch.zeros_like(a)

    blocksize = 128
    # One program per (blocksize x blocksize) tile; cdiv rounds up so that
    # the partial tiles at the edges are covered as well.
    grid = (triton.cdiv(M, blocksize), triton.cdiv(N, blocksize))

    finddiagonal[grid](a, b, M, N, blocksize, blocksize)

    # Reference: zeros everywhere except the main diagonal, which for a
    # non-square matrix has min(M, N) entries.
    expected = torch.zeros_like(a)
    diag_idx = torch.arange(min(M, N), device=a.device)
    expected[diag_idx, diag_idx] = a[diag_idx, diag_idx]
    # Values are copied, not computed, so exact equality is the right check.
    assert torch.equal(b, expected), "finddiagonal output does not match the reference diagonal matrix"

    print(a)
    print(b)

In [10]:
# Entry point: run the demo when this cell / script is executed directly.
if __name__ == "__main__":
    test()

tensor([[-1.9668,  0.3950,  0.0696,  ..., -0.2600,  1.4121, -0.8857],
        [-1.2461,  0.5625, -1.3623,  ...,  0.1709, -0.3994,  0.7012],
        [ 0.5762, -0.5010,  1.2363,  ...,  0.3660, -1.5127,  0.2468],
        ...,
        [ 1.6846, -0.4146,  0.9194,  ...,  0.3699, -2.1094,  2.6504],
        [-1.0557,  0.2275, -0.1411,  ...,  1.4375,  0.0275, -1.6982],
        [ 1.4678, -1.8877,  0.2900,  ...,  0.8540, -0.1973, -1.4443]],
       device='cuda:0', dtype=torch.float16)
tensor([[-1.9668,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.5625,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  1.2363,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0', dtype=torch.float16)
