M A S K I N G 

In [38]:
import torch
import triton
import triton.language as tl


@triton.jit
def kernel(xptr, yptr, outptr, nelements, BLOCKSIZE: tl.constexpr):
    pid = tl.program_id(axis=0)

    offsets = pid * BLOCKSIZE + tl.arange(0, BLOCKSIZE)
    mask = offsets < nelements

    x = tl.load(xptr + offsets, mask=mask, other=0.0)
    y = tl.load(yptr + offsets, mask=mask, other=0.0)

    acc = x + y
    acc = tl.sum(acc, axis=0)

    tl.store(outptr + pid, acc)


def test(x, y):
    nelements = x.numel()
    BLOCKSIZE = 1024

    grid = lambda meta: (triton.cdiv(nelements, meta["BLOCKSIZE"]),)

    out = torch.empty((grid({"BLOCKSIZE": BLOCKSIZE})[0],), device="cuda")

    kernel[grid](x, y, out, nelements, BLOCKSIZE)

    return out.sum()


x = torch.tensor([1.0, 2.0, 3.0, 4.0], device="cuda")
y = torch.tensor([2.0, 3.0, 4.0, 5.0], device="cuda")

res = test(x, y)
print(res)

tensor(24., device='cuda:0')


In [50]:
import torch , triton , math
import triton.language as tl

@triton.jit
def kernel(xptr , outptr , threshold , nelements , blocksize : tl.constexpr):
    pid = tl.program_id(axis = 0)

    offsets = pid * blocksize + tl.arange(0 , blocksize)
    mask = offsets < nelements

    x = tl.load(xptr + offsets , mask = mask , other = 0.0)
    sqmask = (x > threshold)

    acc = tl.where(sqmask , x * x , 0)
    acc = tl.sum(acc , axis = 0)

    tl.store(outptr + pid , acc)

def test(x):
    nelements = x.numel()
    blocksize = 2

    out = torch.zeros(math.ceil(nelements / blocksize) , device = 'cuda')
    grid = grid = lambda meta: (triton.cdiv(nelements , meta['blocksize']), )

    kernel[grid](x , out , 3 , nelements , blocksize)

    return out

x = torch.tensor([1.0 , 5.0 , 2.0 , 7.0 , 4.0 , 9.0], device='cuda')
res = test(x)
res

tensor([25., 49., 97.], device='cuda:0')

In [None]:
@triton.jit
def kernel(xptr, outptr, nelements, threshold, counter, blocksize: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * blocksize + tl.arange(0, blocksize)
    mask = offsets < nelements

    x = tl.load(xptr + offsets, mask=mask, other=0.0)

    
    x_mask = x > threshold
    combined_mask = mask & x_mask

    valid_ind = tl.where(combined_mask , 1 , 0)
    block_count = tl.sum(valid_ind, axis=0)
    start_idx = tl.atomic_add(counter, block_count)

    
    prefix = tl.cumsum(valid_ind, axis=0) - valid_ind
    out_pos = start_idx + prefix

    tl.store(outptr + out_pos, x, mask=combined_mask)

def test(x: torch.Tensor, threshold: float):
    x = x.contiguous()
    out = torch.empty_like(x)
    out_counter = torch.zeros(1, dtype=torch.int32, device=x.device)

    BLOCK_SIZE = 128
    grid = (triton.cdiv(x.numel(), BLOCK_SIZE),)

    kernel[grid](
        x,
        out,
        x.numel(),
        threshold,
        out_counter,
        blocksize=BLOCK_SIZE   
    )

    return out[:out_counter.item()]

x = torch.tensor([1.0 , 5.0 , 2.0 , 7.0 , 5.0 , 9.0], device='cuda')
res = test(x , 4)
res


tensor([5., 7., 5., 9.], device='cuda:0')

S T R I D E S 

In [53]:
import torch , triton 
import triton.language as tl

@triton.jit
def transposekernel(xptr , outptr , xstrm , xstrn , outstrm , outstrn , M , N , MBLOCK: tl.constexpr , NBLOCK: tl.constexpr):
    pid = tl.program_id(axis = 0)
    ypid = tl.program_id(axis = 1)

    row = pid * MBLOCK + tl.arange(0 , MBLOCK)
    col = ypid * NBLOCK + tl.arange(0 , NBLOCK)

    ptrx = xptr + row[: , None] * xstrm + col[None , :] * xstrn
    mask = (row[: , None] < M) & (col[None , :] < N)

    x = tl.load(ptrx , mask = mask , other = 0)
    out = outptr + col[: , None] * outstrm + row[None , :] * outstrn
    maskout = (col[None , :] < N) & (row[: , None] < M)

    x_transposed = tl.trans(x)
    tl.store(out , x_transposed , mask = maskout)

def test(x):
    M , N = x.shape
    mblock , nblock = 8 , 8

    out = torch.empty((N , M) , device = 'cuda')
    grid = (
        (M + mblock - 1) // mblock,
        (N + nblock - 1) // nblock,
    )

    transposekernel[grid](x , out , x.stride()[0] , x.stride()[1] , out.stride()[0] , out.stride()[1] , M , N , mblock , nblock)
    return out

x = torch.randn(128 , 64 , device = 'cuda')

res = test(x)
res 

tensor([[-0.1696, -0.5476,  1.2832,  ...,  1.4566, -0.6243, -0.3408],
        [-0.6959,  0.7655, -0.8117,  ...,  1.3591, -1.0552, -1.1866],
        [-1.0471,  0.9832,  1.0137,  ..., -0.4676, -0.2475,  1.6505],
        ...,
        [-1.6181, -1.0190,  0.0295,  ...,  0.0285, -0.2558,  0.0209],
        [-0.0846, -0.6835, -0.2950,  ...,  0.6142,  0.3847, -0.7668],
        [-0.5564,  1.6412,  1.6846,  ..., -1.9642,  0.7100,  1.4516]],
       device='cuda:0')

In [None]:
@triton.jit
def indexkernel(xptr , outptr , index , xstrx , xstry):

    x , y = index

    ptr = xptr + x[: , None] * xstrx + y[None , :] * xstry

    tl.store(outptr , tl.load(ptr))

def test(x , index):
    out = torch.empty((1,) , device = 'cuda')
    indexkernel[(1,)](x , out , index , x.stride()[0] , x.stride()[1])
    return out

x = torch.tensor([[1.0 , 2.0 , 3.0]] , device = 'cuda')
index = (0 , 1)

res = test(x , index)

res

In [None]:
x = torch.arange(12).reshape(3,4) # 4 , 1
a = x.t()# 1 , 4
b = x[:, ::2] # 4 , 2
c = x[::2, :] # 8 , 1
d = x.permute(1,0) # 1 , 4

In [83]:
x = torch.arange(16).reshape(4 , 4)
y = x.t()
y.is_contiguous()

False

In [84]:
# most likely the reason is as memory was never changed just the layout , but view depends on layout , but the actual tensor size didn't matched 
y , x

(tensor([[ 0,  4,  8, 12],
         [ 1,  5,  9, 13],
         [ 2,  6, 10, 14],
         [ 3,  7, 11, 15]]),
 tensor([[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11],
         [12, 13, 14, 15]]))

In [85]:
x.view(-1)

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15])

In [87]:
y.reshape(-1)

tensor([ 0,  4,  8, 12,  1,  5,  9, 13,  2,  6, 10, 14,  3,  7, 11, 15])

In [89]:
y_flat = y.contiguous().view(-1)
print(y_flat)

tensor([ 0,  4,  8, 12,  1,  5,  9, 13,  2,  6, 10, 14,  3,  7, 11, 15])


In [44]:
x = torch.arange(12).reshape(3,4) 
transposed = torch.as_strided(x , (4 , 3) , (1 , 4))
x.t() , transposed

(tensor([[ 0,  4,  8],
         [ 1,  5,  9],
         [ 2,  6, 10],
         [ 3,  7, 11]]),
 tensor([[ 0,  4,  8],
         [ 1,  5,  9],
         [ 2,  6, 10],
         [ 3,  7, 11]]))

In [50]:
x = torch.arange(16).reshape(4,4) # 4 , 1
diagnol = torch.as_strided(x , (4 , 1) , (5 , 1))
diagnol 

tensor([[ 0],
        [ 5],
        [10],
        [15]])

In [59]:
@triton.jit
def flattenkernel(
    xptr, outptr, xstrm, xstrn, M, N, Mblock: tl.constexpr, Nblock: tl.constexpr
):
    pid = tl.program_id(0)
    ypid = tl.program_id(1)

    row = pid * Mblock + tl.arange(0, Mblock)
    col = ypid * Nblock + tl.arange(0, Nblock)

    ptr = xptr + row[:, None] * xstrm + col[None, :] * xstrn
    mask = (row[:, None] < M) & (col[None, :] < N)

    out = outptr + row[:, None] * N + col[None, :]

    vals = tl.load(ptr, mask=mask)

    tl.store(out, vals, mask=mask)

def test(x):
    m , n = x.shape
    out = torch.empty((m * n) , device = 'cuda')
    mblock , nblock = 32 , 32

    grid = (
        (m + mblock - 1) // mblock,
        (n + nblock - 1) // nblock,
    )

    flattenkernel[grid](x , out , x.stride()[0] , x.stride()[1] , m , n , mblock , nblock)
    return out

x = torch.randn(2 , 3 , device = 'cuda')
res = test(x)
res


tensor([0.2547, 1.0185, 0.9231, 1.3405, 0.8570, 1.1698], device='cuda:0')

In [70]:
x = torch.tensor([[2.0 , 3.0] , [1.0 , 4.0]] , device = 'cuda')
y = torch.as_strided(x , (2 , 2) , (2 , 1) )
y

tensor([[2., 3.],
        [1., 4.]], device='cuda:0')

In [1]:
import torch
x = torch.tensor([[1.0 , 2.0 , 3.0] , [4.0 , 5.0 , 6.0]] , device = 'cuda')
x.permute(1 , 0) , print(" \n"), x


 



(tensor([[1., 4.],
         [2., 5.],
         [3., 6.]], device='cuda:0'),
 None,
 tensor([[1., 2., 3.],
         [4., 5., 6.]], device='cuda:0'))

In [80]:
permute = torch.as_strided(x , (3 , 2) , (1 , 3))
permute

tensor([[1., 4.],
        [2., 5.],
        [3., 6.]], device='cuda:0')

In [90]:
x = torch.arange(6)
y = x.as_strided((3,2),(1,2))
y

tensor([[0, 2],
        [1, 3],
        [2, 4]])

In [3]:
x = torch.arange(16).reshape(4,4)
y = x[:, ::2].contiguous()
y.is_contiguous()

True