In [1]:
import triton 
import torch
import triton.language as tl

In [23]:
@triton.jit 
def transposefxn(inptr,outputr,m,n,blocksizem:tl.constexpr,blocksizen:tl.constexpr):
    """Transpose one (blocksizem x blocksizen) tile of an m x n matrix.

    Input element (i, j) is read from offset ``i * n + j`` of ``inptr`` and
    written to offset ``j * m + i`` of ``outputr``. Program ids on grid axes
    0 and 1 select the tile along the m and n dimensions respectively.
    """
    tile_m = tl.program_id(0)
    tile_n = tl.program_id(1)

    # Absolute indices covered by this tile along each dimension.
    rows = tile_m * blocksizem + tl.arange(0, blocksizem)
    cols = tile_n * blocksizen + tl.arange(0, blocksizen)

    # Column vector x row vector -> the index expressions broadcast to 2-D.
    rows_2d = rows[:, None]
    cols_2d = cols[None, :]

    # Lanes that fall outside the matrix are disabled for both load and store.
    in_bounds = (rows_2d < m) & (cols_2d < n)

    src = inptr + rows_2d * n + cols_2d
    dst = outputr + cols_2d * m + rows_2d

    tile = tl.load(src, mask=in_bounds)
    tl.store(dst, tile, mask=in_bounds)
    
    

In [26]:
def test(M=5, N=5):
    """Run ``transposefxn`` on a random float16 (M, N) CUDA matrix and verify it.

    Prints the input and the result, then asserts that the result equals
    ``a.t()`` exactly (a transpose only moves values, so no tolerance is
    needed).

    Args:
        M: number of rows of the input matrix.
        N: number of columns of the input matrix. Pass ``M != N`` to also
            exercise the non-square indexing.

    Raises:
        AssertionError: if the kernel output differs from ``a.t()``.
    """
    a = torch.randn((M, N), device='cuda', dtype=torch.float16)
    # The transpose of an (M, N) matrix is (N, M); allocating with the input's
    # shape is only correct when M == N.
    b = torch.zeros((N, M), device=a.device, dtype=a.dtype)
    mblocksize = 64
    nblocksize = 64

    # One program per tile, rounding up so partial edge tiles are covered.
    noofblock_m = triton.cdiv(M, mblocksize)
    noofblock_n = triton.cdiv(N, nblocksize)

    transposefxn[(noofblock_m, noofblock_n)](a, b, M, N, mblocksize, nblocksize)
    print(a)
    print(b)
    assert torch.equal(b, a.t()), "transposefxn output does not match a.t()"
    
    

In [27]:
# Run the transpose demo when this cell executes as the main module.
if __name__ == "__main__":
    test()

tensor([[-0.1385,  0.5586, -0.4517,  0.9023,  0.0376],
        [ 0.6562, -1.2832, -0.1671,  1.3994, -0.0087],
        [-0.2954,  2.5918,  0.0861, -0.0414, -0.2920],
        [-0.2947, -0.3892, -1.1631, -0.9902, -0.1144],
        [-0.7700, -1.1514,  0.9722,  0.8081,  1.6787]], device='cuda:0',
       dtype=torch.float16)
tensor([[-0.1385,  0.6562, -0.2954, -0.2947, -0.7700],
        [ 0.5586, -1.2832,  2.5918, -0.3892, -1.1514],
        [-0.4517, -0.1671,  0.0861, -1.1631,  0.9722],
        [ 0.9023,  1.3994, -0.0414, -0.9902,  0.8081],
        [ 0.0376, -0.0087, -0.2920, -0.1144,  1.6787]], device='cuda:0',
       dtype=torch.float16)


I found this on LeetGPU. <br/>
Triton has its own native transpose function (`tl.trans`), so don't forget to use it. Use tiling and row-wise access, and things should be blazingly fast.

## Matrix transpose variant 2 (https://leetgpu.com/challenges/matrix-transpose): strides are passed in instead of being calculated manually

In [None]:
import triton
import triton.language as tl

@triton.jit
def matrix_transpose_kernel(
    input_ptr, output_ptr,
    rows, cols,
    stride_ir, stride_ic,  
    stride_or, stride_oc,
    BLOCK_ROWS: tl.constexpr = 32,
    BLOCK_COLS: tl.constexpr = 32,
):
    """Transpose one BLOCK_ROWS x BLOCK_COLS tile of a float32 matrix.

    Input element (r, c) is read from ``r * stride_ir + c * stride_ic`` and
    written to ``c * stride_or + r * stride_oc`` (offsets in elements).
    Program ids on grid axes 0 and 1 select the tile along rows and columns.

    Any grid with at least ``cdiv(rows, BLOCK_ROWS)`` x ``cdiv(cols, BLOCK_COLS)``
    programs produces the full result: each tile is written by exactly one
    program and programs past the matrix edge are fully masked, so the older
    one-program-per-element launch ``grid = (rows, cols)`` still works.
    """
    # The incoming addresses are reinterpreted as float32 pointers.
    input_ptr = input_ptr.to(tl.pointer_type(tl.float32))
    output_ptr = output_ptr.to(tl.pointer_type(tl.float32))

    # Absolute row / column indices covered by this program's tile.
    row_offsets = tl.program_id(0) * BLOCK_ROWS + tl.arange(0, BLOCK_ROWS)
    col_offsets = tl.program_id(1) * BLOCK_COLS + tl.arange(0, BLOCK_COLS)

    # Column vector x row vector -> the index expressions broadcast to 2-D.
    row_idx = row_offsets[:, None]
    col_idx = col_offsets[None, :]

    # Disable lanes of partial edge tiles that fall outside the matrix.
    mask = (row_idx < rows) & (col_idx < cols)

    in_index = row_idx * stride_ir + col_idx * stride_ic
    out_index = col_idx * stride_or + row_idx * stride_oc

    val = tl.load(input_ptr + in_index, mask=mask)
    tl.store(output_ptr + out_index, val, mask=mask)



def solve(input_ptr: int, output_ptr: int, rows: int, cols: int):
    """Launch the tiled transpose of a rows x cols matrix into a cols x rows one.

    Args:
        input_ptr: integer address of the input buffer (cast to a float32
            pointer inside the kernel).
        output_ptr: integer address of the output buffer (same cast).
        rows: number of rows of the input matrix.
        cols: number of columns of the input matrix.

    The strides below describe densely packed layouts: consecutive input rows
    are ``cols`` elements apart and consecutive output rows are ``rows``
    elements apart.
    """
    stride_ir, stride_ic = cols, 1  
    stride_or, stride_oc = rows, 1

    # One program per tile instead of one per element: far fewer programs, and
    # grid axis 1 stays well below the CUDA per-axis limit of 65535 for wide
    # matrices.
    block_rows, block_cols = 32, 32
    grid = (triton.cdiv(rows, block_rows), triton.cdiv(cols, block_cols))
    matrix_transpose_kernel[grid](
        input_ptr, output_ptr,
        rows, cols,
        stride_ir, stride_ic,
        stride_or, stride_oc,
        BLOCK_ROWS=block_rows, BLOCK_COLS=block_cols,
    )