In [None]:
import torch

import triton
import triton.language as tl

In [None]:
# Optional notebook nicety: lovely_tensors.monkey_patch() presumably swaps in a
# more compact torch.Tensor repr for interactive inspection (confirm against the
# lovely_tensors docs). Nothing else in this notebook references lovely_tensors,
# so the kernel and benchmark cells do not depend on it.
import lovely_tensors
lovely_tensors.monkey_patch()

In [None]:
@triton.jit
def sigmoid_kernel(x_ptr, y_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Elementwise logistic sigmoid over a flat buffer: y[i] = 1 / (1 + exp(-x[i])).

    Each program instance handles one chunk of BLOCK_SIZE consecutive elements,
    so the launch grid must provide cdiv(n_elements, BLOCK_SIZE) programs on axis 0.

    Args:
        x_ptr: pointer to the input buffer, indexed as flat contiguous memory.
        y_ptr: pointer to the output buffer, indexed as flat contiguous memory.
        n_elements: number of valid elements in both buffers.
        BLOCK_SIZE: elements per program (compile-time constant; tl.arange
            requires a power of two).
    """
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # The last chunk may extend past the end of the buffer; mask those lanes.
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    # tl.exp only accepts fp32/fp64, so half-precision inputs (fp16/bf16) are
    # upcast. The condition is on the compile-time dtype, so it is resolved when
    # the kernel is specialized. fp64 inputs keep full precision.
    if x.dtype != tl.float64:
        x = x.to(tl.float32)
    y = 1.0 / (1.0 + tl.exp(-x))
    # Cast back to the output's element type (e.g. fp16) before storing.
    tl.store(y_ptr + offsets, y.to(y_ptr.dtype.element_ty), mask=mask)

def tri_sigmoid(x: torch.Tensor, block_size: int = 1024) -> torch.Tensor:
    """Compute the elementwise sigmoid of ``x`` with the Triton ``sigmoid_kernel``.

    Args:
        x: input tensor on a device Triton can launch on. Any shape and any
            memory layout; strided/sliced views are copied to a contiguous
            buffer first, because the kernel indexes memory as a flat array.
        block_size: elements processed per Triton program. It must be a power
            of two (``tl.arange`` requirement). Defaults to 1024.

    Returns:
        A new tensor with the same shape, dtype and device as ``x``.
    """
    # The kernel reads x_ptr + flat offsets. That is only correct for a
    # contiguous buffer: a non-dense view (e.g. x[:, ::2]) would otherwise read
    # the wrong elements, and empty_like would give y a different layout.
    x = x.contiguous()
    y = torch.empty_like(x)
    n_elements = x.numel()
    if n_elements == 0:
        # Nothing to compute; skip launching a zero-size grid.
        return y

    def grid(meta):
        # One program per BLOCK_SIZE chunk, rounding up for the tail.
        return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    sigmoid_kernel[grid](x, y, n_elements, BLOCK_SIZE=block_size)
    return y

In [None]:
@torch.compile(mode="max-autotune")
def torch_sigmoid(x):
    """Naive logistic sigmoid, 1 / (1 + e^(-x)), compiled with torch.compile.

    The formula is spelled out (instead of calling torch.sigmoid) so that the
    compiled code evaluates the same expression as ``sigmoid_kernel``, which
    keeps the benchmark comparison apples-to-apples.
    """
    denominator = torch.exp(-x) + 1
    return 1 / denominator

In [None]:
# Benchmark input: 1024 * 16**3 = 4,194,304 values in the default float dtype.
# They are generated directly on the GPU rather than on the host followed by a
# .cuda() copy. The seed makes reruns benchmark identical data.
N_ELEMENTS = 1024 * 16**3
torch.manual_seed(0)
a = torch.randn(N_ELEMENTS, device="cuda")

In [None]:
# Verify the Triton kernel against PyTorch's reference sigmoid before timing it;
# a fast but wrong kernel is not a meaningful benchmark result. This first call
# also triggers the kernel's one-time JIT compilation.
torch.testing.assert_close(tri_sigmoid(a), torch.sigmoid(a))
ms = triton.testing.do_bench(lambda: tri_sigmoid(a))
print(f"triton sigmoid: {ms:.4f} ms")

In [None]:
# Same correctness check for the torch.compile version. The first call also runs
# the (slow, mode="max-autotune") compilation here, before benchmarking starts.
torch.testing.assert_close(torch_sigmoid(a), torch.sigmoid(a))
ms = triton.testing.do_bench(lambda: torch_sigmoid(a))
print(f"torch.compile sigmoid: {ms:.4f} ms")