<a href="https://colab.research.google.com/github/NShravanReddy/DeepLearning/blob/main/triton/GELU_triton.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
import triton
import triton.language as tl
import torch
import time
import torch.nn as nn

@triton.jit
def t_g_k(x_ptr, y_ptr, N0, BLOCK_SIZE: tl.constexpr):
    """Element-wise tanh-approximate GELU kernel.

    Each program instance handles one run of ``BLOCK_SIZE`` consecutive
    elements, addressed linearly from ``x_ptr`` / ``y_ptr``.

    Args:
        x_ptr: Pointer to the input elements.
        y_ptr: Pointer to the output elements.
        N0: Total number of elements to process.
        BLOCK_SIZE: Compile-time number of elements per program instance.
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # The last block may run past the end of the data; mask those lanes off.
    mask = offsets < N0
    x = tl.load(x_ptr + offsets, mask=mask)

    # GELU (tanh approximation):
    #   y = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    # where 0.7978845608028654 == sqrt(2/pi).
    a = 0.7978845608028654 * (x + 0.044715 * x * x * x)
    exp2a = tl.exp(2 * a)
    # tanh(a) written as 1 - 2 / (exp(2a) + 1) rather than
    # (exp(2a) - 1) / (exp(2a) + 1): when exp(2a) overflows to +inf the
    # latter evaluates inf / inf = NaN, whereas this form saturates to +1
    # (and to -1 when exp(2a) underflows to 0).
    tanha = 1.0 - 2.0 / (exp2a + 1.0)
    y = 0.5 * x * (1 + tanha)

    tl.store(y_ptr + offsets, y, mask=mask)

def t_g_k_h(x: torch.Tensor, BLOCK_SIZE=1024):
    """Apply the Triton tanh-approximate GELU kernel to ``x``.

    Args:
        x: Input tensor; must live on a CUDA device.
        BLOCK_SIZE: Number of elements handled by each kernel program.

    Returns:
        A new tensor with the same shape and dtype as ``x`` holding the
        kernel's output.

    Raises:
        AssertionError: If ``x`` is not a CUDA tensor.
    """
    assert x.is_cuda, "t_g_k_h expects a CUDA tensor"
    # The kernel addresses memory as base pointer + linear offset, so the
    # data must be laid out contiguously; a strided view (e.g. x[::2]) would
    # otherwise have the wrong elements read. No copy is made when x is
    # already contiguous.
    x = x.contiguous()
    y = torch.empty_like(x)
    N0 = x.numel()
    if N0 == 0:
        # Nothing to compute; also avoids launching a zero-sized grid.
        return y
    # One program per BLOCK_SIZE-sized chunk, rounding up for the tail.
    grid = lambda meta: (triton.cdiv(N0, meta['BLOCK_SIZE']),)
    t_g_k[grid](x, y, N0, BLOCK_SIZE=BLOCK_SIZE)
    return y

def benchmark(description: str, run, num_warmups: int = 1, num_trials: int = 3):
    """Time ``run`` over several trials and report the results.

    ``run`` is first called ``num_warmups`` times untimed, then
    ``num_trials`` times with wall-clock timing. When CUDA is available the
    device is synchronized after the warm-ups and after every timed call so
    that asynchronously launched GPU work is included in the measurement.

    Args:
        description: Label printed in front of the timings.
        run: Zero-argument callable to benchmark.
        num_warmups: Number of untimed calls made before measuring.
        num_trials: Number of timed calls; must be at least 1.

    Returns:
        Mean elapsed time per trial, in milliseconds.

    Raises:
        ValueError: If ``num_trials`` is less than 1.
    """
    if num_trials < 1:
        # Fail early with a clear message instead of a ZeroDivisionError
        # when computing the mean below.
        raise ValueError(f"num_trials must be >= 1, got {num_trials}")

    # Query once, outside the timed region, so the check itself is not
    # part of what is being measured.
    use_cuda = torch.cuda.is_available()

    # Warm-up runs
    for _ in range(num_warmups):
        run()
    if use_cuda:
        torch.cuda.synchronize()

    # Timing
    times = []
    for _ in range(num_trials):
        start_time = time.perf_counter()
        run()
        if use_cuda:
            torch.cuda.synchronize()
        end_time = time.perf_counter()

        times.append((end_time - start_time) * 1000)

    mean_time = sum(times) / len(times)

    print(f"{description}: {[round(t, 1) for t in sorted(times)]} (mean {round(mean_time, 3)} ms)")

    return mean_time

if __name__ == '__main__':
    # Compare the Triton GELU kernel against PyTorch's tanh-approximate GELU,
    # first for correctness and then for forward-pass speed.
    N = 1024 * 1024
    BLOCK_SIZE = 1024
    print(torch.cuda.memory_summary())
    torch.cuda.empty_cache()
    # Input tensor
    x = torch.randn(N, device='cuda', dtype=torch.float32)

    # Triton GELU
    GELU_triton = t_g_k_h(x, BLOCK_SIZE)
    # PyTorch GELU (tanh approximation, i.e. the same formula as the kernel)
    gelu = nn.GELU(approximate='tanh')

    GELU_pytorch = gelu(x)

    # Verify correctness
    print("Triton GELU output:", GELU_triton)
    print("PyTorch GELU output:", GELU_pytorch)
    print("Are they close?", torch.allclose(GELU_triton, GELU_pytorch, atol=1e-5))

    # Benchmarking forward pass
    # NOTE(review): x_torch is not used by the forward benchmarks below;
    # kept in case code outside this view relies on it.
    x_torch = x.detach().clone().requires_grad_()

    # Time the same tanh-approximate module that was used for the
    # correctness check, so both sides of the comparison compute the same
    # function (a default nn.GELU() would be a different formulation).
    triton_time = benchmark("Triton GELU", lambda: t_g_k_h(x, BLOCK_SIZE))
    torch_time = benchmark("PyTorch GELU", lambda: gelu(x))
    print(torch.cuda.memory_summary())

    print(f"\nAverage execution time (Forward Pass):")
    print(f"  Triton GELU = {triton_time:.3f} ms")
    print(f"  PyTorch GELU = {torch_time:.3f} ms")


|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 2            |        cudaMalloc retries: 2         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |  12288 MiB |  13312 MiB |  13312 MiB |   1024 MiB |
|       from large pool |  12288 MiB |  13312 MiB |  13312 MiB |   1024 MiB |
|       from small pool |      0 MiB |      0 MiB |      0 MiB |      0 MiB |
|---------------------------------------------------------------------------|
| Active memory         |  12288 MiB |  13312 MiB |  13312 MiB |   1024 MiB |
|       from large pool |  12288 MiB |  13312 MiB |  13312 MiB |   1024 MiB |
|       from small pool |      0 MiB |      0 MiB |      0 MiB |      0 MiB |
|---------------------------------------------------------------