### 相关接口函数

C++主要是cmath/math.h，见[link](https://legacy.cplusplus.com/reference/cmath/);  
CUDA主要是cuda-math-api，见[link](https://docs.nvidia.com/cuda/cuda-math-api/index.html)。

### python加载cuda kernel示例

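`-U<宏名>`：取消（undefine）指定的宏定义。PyTorch 构建扩展时默认会定义 `__CUDA_NO_HALF_OPERATORS__`、`__CUDA_NO_HALF_CONVERSIONS__` 等宏，用于禁用 half/bfloat16 的运算符重载与隐式类型转换。这里用 `-U` 取消这些宏，使 kernel 中可以直接对 `half` 使用 `+`、`*`、`>` 等运算符。  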
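以 `__functionName()` 命名的函数（如 `__sinf(x)`、`__expf(x)`）直接映射到硬件指令，速度更快，但精度稍低。以 `functionName()` 命名的函数（如 `sinf(x)`、`expf(x)`）较慢，但精度更高。编译选项 `-use_fast_math`（即 `--use_fast_math`）会把有对应内建版本的 `functionName()` 调用替换为等效的 `__functionName()` 调用，并同时隐含 `-ftz=true -prec-div=false -prec-sqrt=false -fmad=true`。  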
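`--expt-relaxed-constexpr`：允许 `__device__` 代码调用 host 侧的 `constexpr` 函数，反之亦然。`--expt-extended-lambda`（新版 nvcc 亦写作 `--extended-lambda`）：允许在 host 代码中定义带 `__device__` 或 `__host__ __device__` 标注的 lambda，并将其作为参数传给 kernel。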


```
# JIT-compile elementwise.cu with PyTorch's C++/CUDA extension loader and bind
# the resulting extension module to `lib`.
lib = load(
    name='elementwise_lib',
    sources=['elementwise.cu'],
    # Flags passed to the host C++ compiler.
    extra_cflags=['-std=c++17'],
    # Flags passed to nvcc.
    extra_cuda_cflags=[
        "-O3",
        # -U undefines these macros so half/half2/bfloat16 operator overloads and
        # implicit conversions can be used inside the kernels (presumably PyTorch's
        # default extension flags define them — confirm).
        "-U__CUDA_NO_HALF_OPERATORS__",
        "-U__CUDA_NO_HALF_CONVERSIONS__",
        "-U__CUDA_NO_HALF2_OPERATORS__",
        "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
        # Allow constexpr functions across the host/device boundary, and
        # __device__-annotated lambdas written in host code.
        "--expt-relaxed-constexpr",
        "--expt-extended-lambda",
        # Map supported math calls (expf, sinf, ...) to faster, less accurate intrinsics.
        "--use_fast_math",
    ],
)
```

### benchmark示例

先执行一段warmup，完成Kernel launch；再通过批量测试统计耗时情况。两步均通过**torch.cuda.synchronize()**完成同步。

```
def run_benchmark(perf_func: callable, a: torch.Tensor, b: torch.Tensor, tag: str, 
                  out: Optional[torch.Tensor] = None, warmup: int = 10, 
                  iters: int = 1000, show_all: bool = False):
    """Time ``perf_func`` on (a, b) and print the first two output values.

    Runs ``warmup`` untimed calls, then ``iters`` timed calls. The CUDA device is
    synchronized before the clock starts and before it stops, so asynchronously
    queued kernel launches are included in the measurement.

    Args:
        perf_func: called as ``perf_func(a, b, out)`` when ``out`` is given
            (expected to write into ``out``), otherwise as ``perf_func(a, b)``
            (its return value becomes the output).
        a, b: input tensors forwarded to ``perf_func``.
        tag: label used in the printed line (``out_{tag}``).
        out: optional preallocated output tensor; zero-filled before warmup.
        warmup: number of untimed calls made first.
        iters: number of timed calls; must be >= 1.
        show_all: also print the whole output tensor.

    Returns:
        ``(out, mean_time)`` where ``mean_time`` is milliseconds per call.

    Raises:
        ValueError: if ``iters < 1``. The mean time would be undefined, and
            without ``out`` there would be no result to report.
    """
    if iters < 1:
        raise ValueError(f"iters must be >= 1, got {iters}")
    if out is not None:
        out.fill_(0)
    # Warmup: keep one-off first-call costs out of the timed region.
    if out is not None:
        for _ in range(warmup):
            perf_func(a, b, out)
    else:
        for _ in range(warmup):
            perf_func(a, b)
    torch.cuda.synchronize()
    # perf_counter is monotonic and high-resolution; time.time() is wall-clock
    # time and can jump (NTP adjustments) or have coarse resolution.
    start = time.perf_counter()
    if out is not None:
        for _ in range(iters):
            perf_func(a, b, out)
    else:
        for _ in range(iters):
            out = perf_func(a, b)
    # Kernel launches return immediately; wait for them before stopping the clock.
    torch.cuda.synchronize()
    end = time.perf_counter()
    total_time = (end - start) * 1000  # ms
    mean_time = total_time / iters
    out_info = f"out_{tag}"
    out_val = out.flatten().detach().cpu().numpy().tolist()[:2]
    out_val = [round(v, 8) for v in out_val]
    print(f"{out_info:>18}: {out_val}, time:{mean_time:.8f}ms")
    if show_all:
        print(out)
    return out, mean_time
```

### cuda kernel实现要点

1. 在 f32、f16 的基础实现中，每个线程对 1 个元素执行 elementwise 加法（两个输入数组 a、b），f16 的加法使用 `__hadd`，参见 [Half Arithmetic Functions](https://docs.nvidia.com/cuda/cuda-math-api/cuda_math_api/group__CUDA__MATH____HALF__ARITHMETIC.html)。

```
// f32: one element per thread; threads whose index is past the end do nothing.
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < N) {
  c[i] = a[i] + b[i];
}

// f16: same one-element-per-thread layout, addition via the __hadd intrinsic.
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < N) {
  c[i] = __hadd(a[i], b[i]);
}
```

2. f32x4、f16x2、f16x8 采用向量化访存，每个线程分别处理 4 个 f32、2 个 f16、8 个 f16 元素。这要求指针按向量宽度对齐；当 N 不是向量宽度的整数倍时，必须单独处理尾部元素，否则会越界读写。
```
// f32x4: each thread handles 4 consecutive floats.
// FLOAT4(...) is assumed to reinterpret the element's address as float4
// (macro defined elsewhere), which requires a, b, c to be 16-byte aligned.
int idx = 4 * (blockIdx.x * blockDim.x + threadIdx.x);
if (idx + 3 < N) {
    // Fast path: all 4 elements are in bounds, so one vector load/store per array.
    float4 reg_a = FLOAT4(a[idx]);
    float4 reg_b = FLOAT4(b[idx]);
    float4 reg_c;
    reg_c.x = reg_a.x + reg_b.x;
    reg_c.y = reg_a.y + reg_b.y;
    reg_c.z = reg_a.z + reg_b.z;
    reg_c.w = reg_a.w + reg_b.w;
    FLOAT4(c[idx]) = reg_c;
} else {
    // Tail (N % 4 != 0): a float4 access here would read/write past the end,
    // so finish the remaining elements one by one.
    for (int i = idx; i < N; ++i) c[i] = a[i] + b[i];
}

// f16x2: each thread handles 2 consecutive halves via one half2 access
// (HALF2(...) assumed to reinterpret as half2; requires 4-byte alignment).
int idx = 2 * (blockIdx.x * blockDim.x + threadIdx.x);
if (idx + 1 < N) {
    half2 reg_a = HALF2(a[idx]);
    half2 reg_b = HALF2(b[idx]);
    // __hadd2 adds both lanes in one instruction.
    HALF2(c[idx]) = __hadd2(reg_a, reg_b);
} else if (idx < N) {
    // Odd N: the last element has no partner, handle it as a scalar.
    c[idx] = __hadd(a[idx], b[idx]);
}

// f16x8: each thread handles 8 consecutive halves as 4 half2 accesses.
int idx = 8 * (blockIdx.x * blockDim.x + threadIdx.x);
if (idx + 7 < N) {
    // All loads happen only once the whole 8-element chunk is known to be in bounds.
    half2 reg_a_0 = HALF2(a[idx + 0]);
    half2 reg_a_1 = HALF2(a[idx + 2]);
    half2 reg_a_2 = HALF2(a[idx + 4]);
    half2 reg_a_3 = HALF2(a[idx + 6]);
    half2 reg_b_0 = HALF2(b[idx + 0]);
    half2 reg_b_1 = HALF2(b[idx + 2]);
    half2 reg_b_2 = HALF2(b[idx + 4]);
    half2 reg_b_3 = HALF2(b[idx + 6]);
    HALF2(c[idx + 0]) = __hadd2(reg_a_0, reg_b_0);
    HALF2(c[idx + 2]) = __hadd2(reg_a_1, reg_b_1);
    HALF2(c[idx + 4]) = __hadd2(reg_a_2, reg_b_2);
    HALF2(c[idx + 6]) = __hadd2(reg_a_3, reg_b_3);
} else {
    // Tail (N % 8 != 0): scalar adds avoid reading/writing past the end.
    for (int i = idx; i < N; ++i) c[i] = __hadd(a[i], b[i]);
}

```

3. f16x8_pack 的特点在于通过 reinterpret 为 float4 **一次加载/存储 128 位**（8 个 half），运算过程与 f16x8 基本等效。

```
int idx = 8 * (blockIdx.x * blockDim.x + threadIdx.x);
if (idx + 7 < N) {
  // Per-thread staging arrays, 8 x 16 bits = 128 bits each. alignas(16) is
  // required because they are accessed through a float4 reinterpretation.
  // After full unrolling every index is a compile-time constant, so the compiler
  // can usually keep them in registers rather than .local memory (verify with
  // ptxas -v / SASS).
  alignas(16) half pack_a[8];
  alignas(16) half pack_b[8];
  alignas(16) half pack_c[8];
  // Reinterpret as float4 and move 128 bits in one memory instruction
  // (assumes a + idx, b + idx and c + idx are 16-byte aligned).
  LDST128BITS(pack_a[0]) = LDST128BITS(a[idx]);
  LDST128BITS(pack_b[0]) = LDST128BITS(b[idx]);

  #pragma unroll
  for (int i = 0; i < 8; i += 2) {
    // __hadd2 for half2 x 4
    HALF2(pack_c[i]) = __hadd2(HALF2(pack_a[i]), HALF2(pack_b[i]));
  }
  // Reinterpret as float4 and store 128 bits in one memory instruction.
  LDST128BITS(c[idx]) = LDST128BITS(pack_c[0]);
} else {
  // Tail (N % 8 != 0): a 128-bit access would run past the end of the arrays,
  // so handle the remaining elements with scalar adds.
  for (int i = idx; i < N; ++i) c[i] = __hadd(a[i], b[i]);
}
```

In [None]:
!TORCH_CUDA_ARCH_LIST=Ampere python3 elementwise.py

## sigmoid/relu/elu/gelu

### sigmoid cuda kernel实现

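当 x 非常大时 expf(x) 会上溢：对于 float 输入，x 大于约 88.72（ln FLT_MAX）时结果为 inf。当 x 非常小时 expf(x) 趋近于 0：x 小于约 -87.34（ln FLT_MIN）时结果进入次正规数范围，小于约 -103.97 时下溢为 0（开启 `--use_fast_math` 时次正规数会被冲刷为 0，即约 -87.34 以下就得到 0）。对于 half，hexp(x) 在 x 大于约 11.09（ln 65504）时即上溢为 inf。  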
对于公式中常量，采用__float2half由f32精度转换至half精度。

```
// -------------------------------------- FP32 --------------------------------------
// Sigmoid, fp32: y[i] = 1 / (1 + exp(-x[i])) for i in [0, N), one element per thread.
// Launch: 1D, e.g. grid((N + 255) / 256), block(256); out-of-range threads exit.
// The input is first clamped to [MIN_EXP_F32, MAX_EXP_F32] (macros defined
// elsewhere; presumably the expf overflow bounds, to keep expf(-v) finite — confirm).
__global__ void sigmoid_f32_kernel(float* x, float* y, int N) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= N) return;
  // fminf/fmaxf/expf here are the CUDA math API device functions.
  const float clamped = fminf(fmaxf(x[i], MIN_EXP_F32), MAX_EXP_F32);
  y[i] = 1.0f / (1.0f + expf(-clamped));
}

// -------------------------------------- FP16 --------------------------------------
// Sigmoid, fp16: y[i] = 1 / (1 + exp(-x[i])) computed with half intrinsics.
// The input is clamped to [MIN_EXP_F16, MAX_EXP_F16] (macros defined elsewhere).
__global__ void sigmoid_f16_kernel(half* x, half* y, int N) {
  const half one = __float2half(1.0f);
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= N) return;
  // __hmin/__hmax/hexp/__hneg/__hadd/__hdiv are CUDA half-precision intrinsics.
  const half clamped = __hmin(__hmax(x[i], MIN_EXP_F16), MAX_EXP_F16);
  y[i] = __hdiv(one, __hadd(one, hexp(__hneg(clamped))));
}
```

### relu cuda kernel实现

```
// -------------------------------------- FP32 --------------------------------------
// ReLU, fp32: y[i] = max(0, x[i]) for i in [0, N), one element per thread.
// Launch: 1D, e.g. grid((N + 255) / 256), block(256); out-of-range threads exit.
__global__ void relu_f32_kernel(float* x, float* y, int N) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= N) return;
  y[i] = fmaxf(0.0f, x[i]);
}

// -------------------------------------- FP16 --------------------------------------
// ReLU, fp16: y[i] = max(0, x[i]) via the __hmax half intrinsic.
__global__ void relu_f16_kernel(half* x, half* y, int N) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= N) return;
  const half zero = __float2half(0.0f);
  y[i] = __hmax(zero, x[i]);
}
```

### elu cuda kernel实现

$$
\text{ELU}(x) = 
\begin{cases} 
x & \text{if } x > 0 \\
\alpha (\exp(x) - 1) & \text{if } x \leq 0 
\end{cases}
$$

```
// ELU device helper (fp32): x for x > 0, ALPHA * (exp(x) - 1) otherwise.
// ALPHA is a macro defined elsewhere in the file.
// `__forceinline__` is a CUDA function qualifier that forces the compiler to
// inline this __device__ function. It is not an MSVC extension (MSVC's keyword
// is `__forceinline`, without trailing underscores); plain `inline` is only a hint.
// -------------------------------------- FP32 --------------------------------------
__device__ __forceinline__ float elu(float x) {
  // expm1f(x) computes exp(x) - 1 accurately near 0. expf(x) - 1.f cancels
  // catastrophically there and returns exactly 0 once expf(x) rounds to 1.
  return x > 0.f ? x : ALPHA * expm1f(x);
}

// -------------------------------------- FP16 --------------------------------------
// ELU device helper (fp16): x for x > 0, ALPHA * (exp(x) - 1) otherwise.
// The negative branch is evaluated in fp32 with expm1f. In half precision,
// hexp(x) rounds to exactly 1 for small |x| (e.g. x = -1e-4), so
// hexp(x) - 1 would give 0 instead of roughly ALPHA * x.
__device__ __forceinline__ half elu_half(half x) {
  if (__hgt(x, __float2half(0.f))) {
    return x;
  }
  return __float2half(ALPHA * expm1f(__half2float(x)));
}

// -------------------------------------- FP32 --------------------------------------
// ELU kernel (fp32): y[i] = elu(x[i]) for i in [0, N), one element per thread.
__global__ void elu_f32_kernel(float* x, float* y, int N) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= N) return;
  y[i] = elu(x[i]);
}

// -------------------------------------- FP16 --------------------------------------
// ELU kernel (fp16): y[i] = elu_half(x[i]) for i in [0, N), one element per thread.
__global__ void elu_f16_kernel(half* x, half* y, int N) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= N) return;
  y[i] = elu_half(x[i]);
}
```

### gelu cuda kernel实现

$$
\text{GELU}(x) = 0.5x \left(1 + \tanh\left(\sqrt{\frac{2}{\pi}} \left(x + 0.044715 x^3\right)\right)\right)
$$

```
// GELU, tanh approximation (fp32):
//   0.5 * x * (1 + tanh(SQRT_2_PI * (x + 0.044715 * x^3)))
// SQRT_2_PI is a macro defined elsewhere (presumably sqrt(2/pi) ~= 0.7978845608).
__inline__ __device__ float gelu_tanh_approximate(float x){
  const float inner = SQRT_2_PI * (x + 0.044715f * x * x * x);
  return 0.5f * x * (1.0f + tanhf(inner));
}

// GELU, tanh approximation, fp16 in/out. Evaluated in fp32 through the float
// overload of gelu_tanh_approximate for two reasons:
// - In half, tanh(u) computed as (exp(2u) - 1) / (exp(2u) + 1) overflows:
//   hexp(2u) becomes +inf once u > ~5.5 (x > ~4), giving inf / inf = NaN.
//   tanhf saturates to +/-1 instead.
// - x^3 in half also overflows for |x| > ~40.
__inline__ __device__ half gelu_tanh_approximate(half x){
  return __float2half(gelu_tanh_approximate(__half2float(x)));
}

// -------------------------------------- FP32 --------------------------------------
// GELU (tanh approximation), fp32, one element per thread:
//   y = 0.5 * x * (1 + tanh(0.7978845608 * (x + 0.044715 * x^3)))
// Launch: 1D, e.g. grid((N + 255) / 256), block(256); out-of-range threads exit.
// NOTE(review): the input is clamped to [MIN_EXP_F32, MAX_EXP_F32] (macros defined
// elsewhere) before GELU_OPS. If those are the expf overflow bounds (~+/-88.7),
// the clamp changes the result for larger |x| (GELU(x) ~= x there). A tanhf-based
// GELU_OPS does not need the clamp — confirm and consider removing it.
__global__ void gelu_f32_kernel(float* x, float* y, int N) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= N) return;
  const float clamped = fminf(fmaxf(x[i], MIN_EXP_F32), MAX_EXP_F32);
  y[i] = GELU_OPS(clamped);
}

// GELU (tanh approximation), fp16, one element per thread.
// NOTE(review): the input is clamped to [MIN_EXP_F16, MAX_EXP_F16] (macros defined
// elsewhere) before HALF_GELU_OPS. If those are the hexp overflow bounds
// (~+/-11.09), the clamp alters the output for larger |x| — confirm whether
// HALF_GELU_OPS actually needs it.
__global__ void gelu_f16_kernel(half* x, half* y, int N) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= N) return;
  const half clamped = __hmin(__hmax(x[i], MIN_EXP_F16), MAX_EXP_F16);
  y[i] = HALF_GELU_OPS(clamped);
}
```



In [None]:
!TORCH_CUDA_ARCH_LIST=Ampere python3 sigmoid.py
!TORCH_CUDA_ARCH_LIST=Ampere python3 relu.py
!TORCH_CUDA_ARCH_LIST=Ampere python3 elu.py
!TORCH_CUDA_ARCH_LIST=Ampere python3 gelu.py

## swish/hardswish/hardshrink

### swish

$$
\text{Swish}(x) = x \cdot \sigma(x) = \frac{x}{1 + e^{-x}}
$$

```
// -------------------------------------- FP32 --------------------------------------
// Swish (a.k.a. SiLU), fp32: x * sigmoid(x) = x / (1 + exp(-x)).
// For very negative x, expf(-x) overflows to +inf and the quotient is -0,
// which matches the limit of the function.
__device__ __forceinline__ float swish(float x) {
  const float denom = 1.0f + expf(-x);
  return x / denom;
}

// Swish kernel (fp32): y[i] = swish(x[i]) for i in [0, N), one element per thread.
__global__ void swish_f32_kernel(float* x, float* y, int N) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= N) return;
  y[i] = swish(x[i]);
}

// -------------------------------------- FP16 --------------------------------------
// Swish, fp16: x * (1 / (1 + exp(-x))), evaluated with half intrinsics.
__device__ __forceinline__ half swish_half(half x) {
  const half one = __float2half(1.0f);
  const half sig = __hdiv(one, __hadd(one, hexp(__hneg(x))));
  return __hmul(x, sig);
}

// Swish kernel (fp16): y[i] = swish_half(x[i]) for i in [0, N), one element per thread.
__global__ void swish_f16_kernel(half* x, half* y, int N) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= N) return;
  y[i] = swish_half(x[i]);
}
```

### hardswish

$$
\text{HardSwish}(x) = 
\begin{cases} 
0, & \text{if } x \leq -3 \\
x \left( \frac{x}{6} + \frac{1}{2} \right), & \text{if } -3 < x < 3 \\
x, & \text{if } x \geq 3 
\end{cases}
$$

```
// -------------------------------------- FP32 --------------------------------------
// HardSwish, fp32:
//   x                  if x >= THRESHOLD_A
//   0                  if x <= THRESHOLD_B
//   x * (x + 3) / 6    otherwise
// THRESHOLD_A/THRESHOLD_B are macros defined elsewhere (presumably 3 and -3 — confirm).
__device__ __forceinline__ float hardswish(float x) {
  if (x >= THRESHOLD_A) return x;
  if (x <= THRESHOLD_B) return 0;
  return x * (x + 3) / 6;
}
// HardSwish kernel (fp32): y[i] = hardswish(x[i]) for i in [0, N), one element per thread.
__global__ void hardswish_f32_kernel(float* x, float* y, int N) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= N) return;
  y[i] = hardswish(x[i]);
}

// -------------------------------------- FP16 --------------------------------------
// HardSwish, fp16:
//   x                  if x >= THRESHOLD_A
//   0                  if x <= THRESHOLD_B
//   x * (x + 3) / 6    otherwise
// The boundary tests use >= / <= to match the fp32 helper and the piecewise
// definition. With strict > / <, an input exactly equal to THRESHOLD_B fell
// into the middle branch, which yields -0 instead of +0 when THRESHOLD_B is -3.
// Half intrinsics are used so this helper does not rely on the half operator
// overloads being enabled.
__device__ __forceinline__ half hardswish_half(half x) {
  if (__hge(x, __float2half(THRESHOLD_A))) {
    return x;
  }
  if (__hle(x, __float2half(THRESHOLD_B))) {
    return __float2half(0.f);
  }
  return __hdiv(__hmul(x, __hadd(x, __float2half(3.f))), __float2half(6.f));
}
// HardSwish kernel (fp16): y[i] = hardswish_half(x[i]) for i in [0, N), one element per thread.
__global__ void hardswish_f16_kernel(half* x, half* y, int N) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= N) return;
  y[i] = hardswish_half(x[i]);
}

```

### hardshrink

$$
\text{HardShrink}(x) = 
\begin{cases} 
x, & \text{if } x > \lambda \\
x, & \text{if } x < -\lambda \\
0, & \text{otherwise} 
\end{cases}
$$

```
// -------------------------------------- FP32 --------------------------------------
// HardShrink, fp32: keep x when x > LAMBD or x < -LAMBD, otherwise return 0.
// LAMBD is a macro defined elsewhere.
__device__ __forceinline__ float hardshrink(float x) {
  const bool keep = (x > LAMBD) || (x < -LAMBD);
  return keep ? x : 0.0f;
}
// HardShrink kernel (fp32): y[i] = hardshrink(x[i]) for i in [0, N), one element per thread.
__global__ void hardshrink_f32_kernel(float* x, float* y, int N) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= N) return;
  y[i] = hardshrink(x[i]);
}

// -------------------------------------- FP16 --------------------------------------
// HardShrink, fp16: keep x when x > LAMBD or x < -LAMBD, otherwise return 0.
// Thresholds are converted to half before comparing.
__device__ __forceinline__ half hardshrink_half(half x) {
  const half upper = __float2half(LAMBD);
  const half lower = __float2half(-LAMBD);
  const bool keep = (x > upper) || (x < lower);
  return keep ? x : __float2half(0.f);
}
// HardShrink kernel (fp16): y[i] = hardshrink_half(x[i]) for i in [0, N), one element per thread.
__global__ void hardshrink_f16_kernel(half* x, half* y, int N) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= N) return;
  y[i] = hardshrink_half(x[i]);
}
```


In [None]:
!TORCH_CUDA_ARCH_LIST=Ampere python3 swish.py
!TORCH_CUDA_ARCH_LIST=Ampere python3 hardswish.py
!TORCH_CUDA_ARCH_LIST=Ampere python3 hardshrink.py