In [7]:
import torch
import torch.nn.functional as F

# Compare Python Impl with PyTorch

In [8]:
def gelu(x: torch.Tensor):
    """
    Tanh-approximated GELU written as a chain of separate PyTorch operations:

        0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))

    This version is slow on GPU because of the repeated data movement between
    HBM and SRAM: every operation applied to the input "x" loads its operand
    from global memory into SRAM and writes its result back to global memory
    once the computation is done, back and forth.
    """
    # Plain Python float, computed once per call.
    scale = (2 / torch.pi) ** 0.5
    # Cubic polynomial fed to tanh.
    inner = x + 0.044715 * x**3
    gate = 1 + torch.tanh(scale * inner)
    return 0.5 * x * gate

In [8]:
# Fix the RNG seed so the benchmark input is the same on every run.
torch.manual_seed(42)
# 1024 x 1024 tensor of standard-normal samples, allocated on the GPU.
x = torch.randn(1024, 1024, device="cuda")

In [9]:
# Element-wise difference between the hand-written gelu and PyTorch's
# tanh-approximated GELU; zeros mean the two agree exactly on those entries.
# Note the displayed tensor is truncated, so only its corner entries are shown.
gelu(x) - F.gelu(x, approximate="tanh")

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0')

As the timings below show, this implementation is about **5x slower** than PyTorch's built-in `F.gelu` (118 µs vs. 23.5 µs).

In [7]:
# torch.cuda.synchronize() is part of the timed statement so the measurement
# waits for the GPU work to finish instead of only timing the kernel launches.
%timeit gelu(x); torch.cuda.synchronize()

118 µs ± 97.9 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [8]:
# Baseline: PyTorch's built-in tanh-approximated GELU. The synchronize() call
# inside the timed statement makes the timing include completion of the GPU work.
%timeit F.gelu(x, approximate="tanh"); torch.cuda.synchronize()

23.5 µs ± 235 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


# Compare Fused Kernel with PyTorch

In [2]:
import os
from pathlib import Path

# CUDA source for the extension, read from gelu.cu in the working directory.
cuda_source = Path("gelu.cu").read_text()
# C++ declaration of the `gelu` entry point; its definition is expected to
# come from gelu.cu.
cpp_source = "torch::Tensor gelu(const torch::Tensor& input);"
# Location of the CUDA toolkit used to build the extension. A CUDA_HOME that
# is already set in the environment is kept; the path below is only a
# site-specific fallback, so adjust it for your machine if needed.
# Keep this before the torch.utils.cpp_extension import so the setting is seen.
os.environ.setdefault("CUDA_HOME", "/public/apps/cuda/12.1")

In [3]:
from torch.utils.cpp_extension import load_inline

# JIT-build the extension from the source strings prepared in the previous
# cell and expose its `gelu` function on the returned Python module.
module = load_inline(
    name="gelu",
    functions=["gelu"],
    cpp_sources=cpp_source,
    cuda_sources=cuda_source,
    with_cuda=True,
    extra_cuda_cflags=["-O2"],
    # Optional: build_directory="./cuda_build" selects the build workspace.
)

If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].


In [5]:
# Element-wise difference between the fused kernel and PyTorch's
# tanh-approximated GELU. The outputs need not match bit for bit; entries are
# expected to be zero or very small.
module.gelu(x) - F.gelu(x, approximate="tanh")

tensor([[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
          0.0000e+00, -1.4901e-08],
        [ 0.0000e+00, -5.9605e-08,  0.0000e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        ...,
        [ 0.0000e+00, -2.9802e-08, -5.2154e-08,  ...,  0.0000e+00,
         -5.9605e-08,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00, -2.9802e-08,  ..., -4.4703e-08,
          0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00, -5.9605e-08,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00]], device='cuda:0')

In [10]:
# Largest absolute element-wise difference between the extension's gelu and
# the pure-Python gelu defined earlier in this notebook.
(module.gelu(x) - gelu(x)).abs().max()

tensor(2.3842e-07, device='cuda:0')

The timing below shows that our custom kernel is close to the PyTorch implementation in terms of latency (27.3 µs vs. 23.5 µs).

In [11]:
# Time the fused CUDA extension. `_ =` discards the returned tensor, and the
# synchronize() call makes the timing include completion of the GPU work.
%timeit _ = module.gelu(x); torch.cuda.synchronize()

27.3 µs ± 81.6 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


We fused all of the computation on the input into a single kernel, which reduces the data movement between global memory and SRAM and therefore reduces latency.