In [2]:
import triton
import triton.language as tl
import torch

@triton.jit
def softmax_kernel(
    input_ptr, output_ptr,
    N,
    BLOCK_SIZE: tl.constexpr
):
    """Numerically stable softmax over a 1-D float32 array of N elements.

    Each program writes one BLOCK_SIZE-wide slice of the output, but the
    normalisation statistics (max and sum of exponentials) are reduced over
    the WHOLE array, so the result is a single global softmax no matter how
    many programs are launched. Reducing only over the program's own slice
    would normalise every slice independently and give wrong results as soon
    as N > BLOCK_SIZE.

    Args:
        input_ptr: address of the input buffer, passed as an integer and
            reinterpreted here as a float32 pointer.
        output_ptr: address of the output buffer, same convention. It must
            not alias the input: every program re-reads the full input while
            other programs may already be storing their results.
        N: number of elements in each buffer.
        BLOCK_SIZE: compile-time number of elements handled per step.

    Note:
        Every program repeats the two global reductions, trading redundant
        reads for not needing any cross-program synchronisation or scratch
        memory.
    """
    input_ptr = input_ptr.to(tl.pointer_type(tl.float32))
    output_ptr = output_ptr.to(tl.pointer_type(tl.float32))
    pid = tl.program_id(0)
    lane = tl.arange(0, BLOCK_SIZE)

    # Pass 1: global maximum. Keep one running max per lane and reduce once
    # at the end; out-of-range lanes load -inf so they never win the max.
    lane_max = tl.full([BLOCK_SIZE], float('-inf'), dtype=tl.float32)
    for start in range(0, N, BLOCK_SIZE):
        idx = start + lane
        chunk = tl.load(input_ptr + idx, mask=idx < N, other=float('-inf'))
        lane_max = tl.maximum(lane_max, chunk)
    global_max = tl.max(lane_max, axis=0)

    # Pass 2: global sum of exp(x - max). Subtracting the max keeps every
    # exponent <= 0, which avoids overflow.
    lane_sum = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for start in range(0, N, BLOCK_SIZE):
        idx = start + lane
        in_bounds = idx < N
        chunk = tl.load(input_ptr + idx, mask=in_bounds, other=float('-inf'))
        # Explicitly zero the padded lanes so they cannot contribute.
        lane_sum += tl.where(in_bounds, tl.exp(chunk - global_max), 0.0)
    sum_exp = tl.sum(lane_sum, axis=0)

    # Pass 3: normalise and store only this program's slice.
    offsets = pid * BLOCK_SIZE + lane
    mask = offsets < N
    x = tl.load(input_ptr + offsets, mask=mask, other=float('-inf'))
    softmax = tl.exp(x - global_max) / sum_exp
    tl.store(output_ptr + offsets, softmax, mask=mask)

def solve(input_ptr: int, output_ptr: int, N: int):
    """Launch ``softmax_kernel`` over an N-element buffer.

    Args:
        input_ptr: integer address of the input buffer; the kernel
            reinterprets it as a float32 pointer.
        output_ptr: integer address of the output buffer, same convention.
        N: number of elements to process.
    """
    block_size = 1024

    # One program per block_size-wide slice; the last slice may be partial.
    num_programs = triton.cdiv(N, block_size)
    launch = softmax_kernel[(num_programs,)]

    launch(input_ptr, output_ptr, N, BLOCK_SIZE=block_size)

if __name__ == "__main__":
    # Smoke test: run the kernel on a three-element vector and print both
    # the input and the result.
    values = [1.0, 2.0, 3.0]
    input_data = torch.tensor(values, dtype=torch.float32, device='cuda')
    output_data = torch.zeros_like(input_data)

    # solve() takes raw device addresses rather than tensor objects.
    solve(
        input_data.data_ptr(),
        output_data.data_ptr(),
        input_data.numel(),
    )

    for tensor in (input_data, output_data):
        print(tensor)
    


tensor([1., 2., 3.], device='cuda:0')
tensor([0.0900, 0.2447, 0.6652], device='cuda:0')
