In [1]:
%load_ext lab_black

In [2]:
import math
from typing import Union

import torch as th
import torch.nn.functional as F
import triton
import triton.language as tl
from kitsu.utils.utils import cummul
from torch import Tensor, nn
from torch.nn.parameter import Parameter

from fast_gem.functional import triton_utils as tu

In [3]:
# Preferred GPU ordinal for this notebook. The index is clamped to the last
# visible device so the notebook still runs on hosts with fewer GPUs; on a host
# with no CUDA device the argument becomes negative, which `set_device`
# documents as a no-op.
CUDA_DEVICE_INDEX = 6
th.cuda.set_device(min(CUDA_DEVICE_INDEX, th.cuda.device_count() - 1))

In [4]:
def gem_torch_3d(x: Tensor, p: Union[float, Tensor] = 3.0, eps=1e-6, keepdim=True):
    """Reference generalized-mean (GeM) pooling over the last three dims of a 5-D tensor.

    Computes ``mean(clamp_min(x, eps) ** p, dims=(2, 3, 4)) ** (1 / p)``.

    Args:
        x: 5-D input tensor; the reduction runs over dims 2, 3 and 4.
        p: Exponent, either a Python float or a tensor.
        eps: Lower clamp applied to ``x`` before exponentiation.
        keepdim: Whether the three reduced dims are kept with size 1.

    Returns:
        The pooled tensor.

    Raises:
        AssertionError: If ``x`` is not 5-dimensional.
    """
    assert x.ndim == 5, f"Unknown `x` shape: {x.shape}"
    # `clamp_min` allocates a fresh tensor, so the in-place ops below never
    # touch the caller's `x`.
    clamped = x.clamp_min(eps)
    powered = clamped.pow_(p)
    pooled = powered.mean((2, 3, 4), keepdim=keepdim)
    return pooled.pow_(1.0 / p)

In [8]:
@triton.autotune(
    configs=[
        # triton.Config({"BLK_L": 256, "BLK_M": 32, "BLK_N": 32}),
        # triton.Config({"BLK_L": 128, "BLK_M": 64, "BLK_N": 32}),
        # triton.Config({"BLK_L": 128, "BLK_M": 32, "BLK_N": 64}),
        # triton.Config({"BLK_L": 64, "BLK_M": 128, "BLK_N": 32}),
        # triton.Config({"BLK_L": 64, "BLK_M": 64, "BLK_N": 64}),
        # triton.Config({"BLK_L": 64, "BLK_M": 32, "BLK_N": 128}),
        # triton.Config({"BLK_L": 32, "BLK_M": 256, "BLK_N": 32}),
        # triton.Config({"BLK_L": 32, "BLK_M": 128, "BLK_N": 64}),
        # triton.Config({"BLK_L": 32, "BLK_M": 64, "BLK_N": 128}),
        triton.Config({"BLK_L": 32, "BLK_M": 32, "BLK_N": 256}),
    ],
    key=["str_x_L", "str_x_M", "str_x_N"],
)
@triton.jit
def gem_forward_3d_kernel(
    x_ptr,
    y_ptr,
    p,
    eps,
    str_x_B,
    str_x_C,
    str_x_L,
    str_x_M,
    str_x_N,
    str_y_B,
    str_y_C,
    L,
    M,
    N,
    IS_P_TENSOR: tl.constexpr,
    BLK_L: tl.constexpr,
    BLK_M: tl.constexpr,
    BLK_N: tl.constexpr,
):
    """GeM forward: one program reduces the (L, M, N) volume of one (b, c) pair.

    Writes ``(sum(max(x, eps) ** p) / (L * M * N)) ** (1 / p)`` to
    ``y_ptr[b, c]``. ``p`` is a pointer to the exponent when ``IS_P_TENSOR``
    is true, otherwise the exponent value itself. The volume is walked in
    ``BLK_L x BLK_M x BLK_N`` tiles.
    """
    pid_b = tl.program_id(0)
    pid_c = tl.program_id(1)
    offs_l = tl.arange(0, BLK_L)  # l
    offs_m = tl.arange(0, BLK_M)  # m
    offs_n = tl.arange(0, BLK_N)  # n
    x_base = x_ptr + pid_b * str_x_B + pid_c * str_x_C

    if IS_P_TENSOR:
        p = tl.load(p)

    y = 0.0
    for idx_l in range(tl.cdiv(L, BLK_L)):
        cur_l = idx_l * BLK_L + offs_l[:, None, None]
        mask_l = cur_l < L
        for idx_m in range(tl.cdiv(M, BLK_M)):
            cur_m = idx_m * BLK_M + offs_m[None, :, None]
            mask_m = cur_m < M
            for idx_n in range(tl.cdiv(N, BLK_N)):
                cur_n = idx_n * BLK_N + offs_n[None, None, :]
                mask = mask_l & mask_m & (cur_n < N)

                # Pointers are rebuilt from the tile indices on every
                # iteration. Incrementally advancing one pointer through the
                # nested loops without rewinding it after each inner loop
                # makes the addresses drift away from the index-based masks
                # as soon as an outer loop runs a second time.
                x_ptrs = x_base + cur_l * str_x_L + cur_m * str_x_M + cur_n * str_x_N
                x = tl.load(x_ptrs, mask=mask, other=0.0)  # l m n

                # clamp to eps, raise to p, and accumulate the tile sum
                x = tl.where((x < eps) & mask, eps, x)
                # Zero the padded lanes explicitly instead of relying on
                # `tu.pow(0, p)` evaluating to 0.
                x = tl.where(mask, tu.pow(x, p), 0.0)
                y += tl.sum(x)  # 1

    y /= L * M * N
    y = tu.pow(y, 1 / p)

    y_ptrs = y_ptr + pid_b * str_y_B + pid_c * str_y_C
    tl.store(y_ptrs, y)

In [9]:
@triton.autotune(
    configs=[
        # triton.Config({"BLK_L": 256, "BLK_M": 32, "BLK_N": 32}),
        # triton.Config({"BLK_L": 128, "BLK_M": 64, "BLK_N": 32}),
        # triton.Config({"BLK_L": 128, "BLK_M": 32, "BLK_N": 64}),
        # triton.Config({"BLK_L": 64, "BLK_M": 128, "BLK_N": 32}),
        # triton.Config({"BLK_L": 64, "BLK_M": 64, "BLK_N": 64}),
        # triton.Config({"BLK_L": 64, "BLK_M": 32, "BLK_N": 128}),
        # triton.Config({"BLK_L": 32, "BLK_M": 256, "BLK_N": 32}),
        # triton.Config({"BLK_L": 32, "BLK_M": 128, "BLK_N": 64}),
        # triton.Config({"BLK_L": 32, "BLK_M": 64, "BLK_N": 128}),
        triton.Config({"BLK_L": 32, "BLK_M": 32, "BLK_N": 256}),
    ],
    key=["str_x_L", "str_x_M", "str_x_N"],
    reset_to_zero=["dp_ptr"],
)
@triton.jit
def gem_backward_3d_kernel(
    x_ptr,
    y_ptr,
    p,
    dx_ptr,
    dy_ptr,
    dp_ptr,
    eps,
    str_x_B,
    str_x_C,
    str_x_L,
    str_x_M,
    str_x_N,
    str_y_B,
    str_y_C,
    L,
    M,
    N,
    IS_P_TENSOR: tl.constexpr,
    BLK_L: tl.constexpr,
    BLK_M: tl.constexpr,
    BLK_N: tl.constexpr,
):
    """GeM backward: one program handles the (L, M, N) volume of one (b, c) pair.

    Re-reads ``x`` tile by tile, writes ``dx`` using the same strides as ``x``
    (so ``dx`` must share ``x``'s layout), and, when ``IS_P_TENSOR`` is true,
    atomically adds this program's contribution to the exponent gradient at
    ``dp_ptr``. ``y`` and ``dy`` are both addressed with ``str_y_B`` /
    ``str_y_C``, so they must share a layout as well.
    """
    pid_b = tl.program_id(0)
    pid_c = tl.program_id(1)
    offs_l = tl.arange(0, BLK_L)  # l
    offs_m = tl.arange(0, BLK_M)  # m
    offs_n = tl.arange(0, BLK_N)  # n

    base_x = pid_b * str_x_B + pid_c * str_x_C
    y_ptrs = y_ptr + pid_b * str_y_B + pid_c * str_y_C
    dy_ptrs = dy_ptr + pid_b * str_y_B + pid_c * str_y_C

    if IS_P_TENSOR:
        p = tl.load(p)

    # calculate y-level grad
    y = tl.load(y_ptrs)
    dy = tl.load(dy_ptrs)

    if IS_P_TENSOR:
        # Part of d(y)/d(p) that does not depend on the individual x values;
        # the per-element part is accumulated in the loop below.
        dp = -tl.log(y) / p * y * dy
    # Common factor shared by every element: dy * y / (p * y**p * L*M*N).
    dy = dy / p * y / (tu.pow(y, p) * L * M * N)

    # re-calculate x for x-level grad
    for idx_l in range(tl.cdiv(L, BLK_L)):
        cur_l = idx_l * BLK_L + offs_l[:, None, None]
        mask_l = cur_l < L
        for idx_m in range(tl.cdiv(M, BLK_M)):
            cur_m = idx_m * BLK_M + offs_m[None, :, None]
            mask_m = cur_m < M
            for idx_n in range(tl.cdiv(N, BLK_N)):
                cur_n = idx_n * BLK_N + offs_n[None, None, :]
                mask = mask_l & mask_m & (cur_n < N)

                # One offset tensor, rebuilt from the tile indices, addresses
                # both the load from x and the store to dx, so the two can
                # never get out of step with each other or with the mask.
                offs_x = base_x + cur_l * str_x_L + cur_m * str_x_M + cur_n * str_x_N
                x = tl.load(x_ptr + offs_x, mask=mask, other=0.0)  # l m n

                x_ = tl.where(mask & (x < eps), eps, x)
                x_p1 = tu.pow(x_, p - 1)
                dx = tl.zeros((BLK_L, BLK_M, BLK_N), dtype=x.dtype) + dy  # l m n

                if IS_P_TENSOR:
                    dp_tmp = tl.where(mask, x_p1 * x_ * tl.log(x_) * dx, 0.0)
                    dp += tl.sum(dp_tmp)

                dx *= p * x_p1
                # elements that were clamped to eps receive no gradient
                dx = tl.where((x < eps) & mask, 0.0, dx)

                tl.store(dx_ptr + offs_x, dx, mask=mask)

    if IS_P_TENSOR:
        tl.atomic_add(dp_ptr, dp)

In [10]:
class GeMOps3d(th.autograd.Function):
    """Autograd wrapper around the Triton GeM 3-D forward/backward kernels.

    Pools a 5-D tensor over its last three dims. ``p`` may be a Python float
    (no gradient) or a tensor (gradient returned by ``backward``).
    """

    @staticmethod
    def forward(ctx, x: Tensor, p: Union[float, Tensor] = 3.0, eps: float = 1e-6, keepdim=True):
        """Launch the forward kernel on a ``(B, C)`` grid and return the pooled tensor.

        Args:
            x: 5-D input; dims 2, 3, 4 are reduced.
            p: Exponent, float or tensor.
            eps: Lower clamp applied inside the kernel.
            keepdim: If true the output has shape ``(B, C, 1, 1, 1)``,
                otherwise ``(B, C)``.

        Raises:
            AssertionError: If ``x`` is not 5-dimensional.
        """
        ctx.is_p_tensor = isinstance(p, Tensor)
        assert x.ndim == 5, f"Unknown shape of `x`: {x.shape}"

        B, C, L, M, N = x.shape
        y = x.new_empty(list(x.shape[:2]) + ([1, 1, 1] if keepdim else []))  # b c
        # five strides of x followed by the batch/channel strides of y
        strides = (*x.stride(), *y.stride()[:2])

        grid = lambda meta: (B, C)
        gem_forward_3d_kernel[grid](x, y, p, eps, *strides, L, M, N, IS_P_TENSOR=ctx.is_p_tensor)

        # A tensor `p` must go through save_for_backward; a float `p` is
        # stored as a plain attribute next to `eps`.
        if ctx.is_p_tensor:
            ctx.save_for_backward(x, p, y)
            ctx.params = (eps,)
        else:
            ctx.save_for_backward(x, y)
            ctx.params = p, eps
        return y

    @staticmethod
    def backward(ctx, dy: Tensor):
        """Launch the backward kernel and return one gradient per forward input.

        Returns:
            ``(dx, dp, None, None)``, where ``dp`` is ``None`` when ``p`` was
            a float.
        """
        if ctx.is_p_tensor:
            x, p, y = ctx.saved_tensors
            (eps,) = ctx.params
        else:
            x, y = ctx.saved_tensors
            p, eps = ctx.params

        # The kernel addresses `dy` with the strides of `y`. `y` comes from
        # `new_empty` in forward and is therefore contiguous, but the incoming
        # gradient may be expanded or otherwise strided; normalise it.
        dy = dy.contiguous()

        # The kernel also addresses `dx` with the strides of `x`. `empty_like`
        # only reproduces those strides when `x` is non-overlapping and dense;
        # otherwise fall back to a contiguous copy so both layouts agree.
        dx = th.empty_like(x)
        if dx.stride() != x.stride():
            x = x.contiguous()
            dx = th.empty_like(x)

        B, C, L, M, N = x.shape
        strides = (*x.stride(), *y.stride()[:2])

        dp = None
        if ctx.is_p_tensor:
            # zero-initialised because the kernel accumulates with atomic_add
            dp = th.zeros_like(p)

        grid = lambda meta: (B, C)
        gem_backward_3d_kernel[grid](x, y, p, dx, dy, dp, eps, *strides, L, M, N, IS_P_TENSOR=ctx.is_p_tensor)
        # one entry per forward input: x, p, eps, keepdim
        return dx, dp, None, None


def gem_ops3d(x: Tensor, p: Union[float, Tensor] = 3.0, eps: float = 1e-6, keepdim=True):
    """Functional entry point for `GeMOps3d`.

    Forwards ``x``, ``p``, ``eps`` and ``keepdim`` positionally to
    ``GeMOps3d.apply`` and returns its result.
    """
    pooled = GeMOps3d.apply(x, p, eps, keepdim)
    return pooled

In [11]:
# Fix the RNG so the inputs (and therefore the comparisons below) are
# reproducible across kernel restarts.
th.manual_seed(0)

B, C, L, M, N = 5, 6, 7, 8, 9
x = th.rand(B, C, L, M, N, device="cuda", requires_grad=True)
# Swap dims 1 and 2 so the ops are exercised on a non-contiguous input;
# from here on x has shape (B, L, C, M, N).
x = x.permute(0, 2, 1, 3, 4)
# x is no longer a leaf after the permute, so its .grad must be retained explicitly.
x.retain_grad()
p = th.full((1,), 2.0, device="cuda", requires_grad=True)
# Upstream gradient fed to backward(); its shape matches the pooled output of
# the permuted x with keepdim=False, i.e. (B, L).
y_gt = th.randn(B, L, device="cuda")

In [None]:
# Triton path: forward pass, then backpropagate y_gt as the upstream gradient.
y_tri = gem_ops3d(x, p, keepdim=False)
y_tri.backward(y_gt, retain_graph=True)

# Stash the gradients and clear them so the next backward starts from scratch.
x_grad_tri = x.grad
x.grad = None
p_grad_tri = p.grad
p.grad = None

In [None]:
# PyTorch reference path on the same inputs. This notebook defines the 3-D
# reference `gem_torch_3d`; there is no `gem_torch_2d` here.
y_pth = gem_torch_3d(x, p, keepdim=False)
y_pth.backward(y_gt, retain_graph=True)
# Stash the gradients and clear them for any later backward pass.
x_grad_pth, x.grad = x.grad, None
p_grad_pth, p.grad = p.grad, None

In [None]:
# Forward check: True if the Triton output matches the PyTorch reference
# within allclose's default tolerances.
th.allclose(y_pth, y_tri)

In [None]:
# Backward check for the input gradient (reference vs. Triton).
th.allclose(x_grad_pth, x_grad_tri)

In [None]:
# Backward check for the gradient of the exponent p (reference vs. Triton).
th.allclose(p_grad_pth, p_grad_tri)

In [None]:
# Raw difference between the two p-gradients (reference minus Triton).
p_grad_pth - p_grad_tri

# Benchmark

In [None]:
def _gem_torch_old_3d(x, p, eps=1e-6):
    """Average-pool formulation of GeM, used as the "torch_old" baseline.

    NOTE(review): this notebook never defines `gem_torch_old`; this helper is
    a stand-in using `F.avg_pool3d` over the full spatial extent. Confirm it
    matches the baseline that was intended.
    """
    return F.avg_pool3d(x.clamp(min=eps).pow(p), tuple(x.shape[-3:])).pow(1.0 / p)


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["L", "M", "N"],
        # Cubic volumes from 8^3 to 128^3; with B=2, C=64 the largest fp32
        # input is about 1 GiB.
        x_vals=[2**i for i in range(3, 8)],
        x_log=False,
        line_arg="provider",
        line_vals=["triton", "torch", "torch_old"],
        line_names=["Triton", "Torch", "Torch_old"],
        styles=[("blue", "-"), ("green", "-"), ("red", "-")],
        ylabel="GB/s",
        plot_name="a",
        args={"B": 2, "C": 64, "p": 2.0},
    )
)
def benchmark(B, C, L, M, p, provider, N=None):
    """Time one forward GeM-3D implementation and report throughput.

    Args:
        B, C, L, M, N: Input shape; ``N`` defaults to ``M`` when omitted.
        p: GeM exponent.
        provider: One of ``"triton"``, ``"torch"``, ``"torch_old"``.

    Returns:
        ``(median, low, high)`` throughput in GB/s, counting the bytes of the
        input tensor only.

    Raises:
        ValueError: If ``provider`` is not one of the known names.
    """
    if N is None:
        N = M
    x = th.rand(B, C, L, M, N, device="cuda")

    quantiles = [0.5, 0.2, 0.8]
    if provider == "torch":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: gem_torch_3d(x, p, keepdim=False),
            quantiles=quantiles,
        )
    elif provider == "torch_old":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: _gem_torch_old_3d(x, p),
            quantiles=quantiles,
        )
    elif provider == "triton":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: gem_ops3d(x, p, keepdim=False),
            quantiles=quantiles,
        )
    else:
        raise ValueError(f"Unknown provider: {provider}")

    gbps = lambda ms: (x.nelement() * x.element_size()) * 1e-9 / (ms * 1e-3)
    # the slowest quantile gives the lowest throughput, hence the swapped order
    return gbps(ms), gbps(max_ms), gbps(min_ms)


benchmark.run(show_plots=True, print_data=True)