In [None]:
!pip install pycuda -q

import pycuda.autoinit
import pycuda.driver as drv
import numpy as np
from pycuda.compiler import SourceModule

LayerNorm 将张量在 channel 维度（最后一维）上的值归一化为均值为 0、方差为 1，核心计算是求均值 $\mu$ 和方差 $\sigma^2$。$\gamma$ 和 $\beta$ 是可学习的参数，$\epsilon$ 是保证数值稳定的小常数（本文取 $10^{-5}$）：
$$\text{LayerNorm}(x) = \gamma \cdot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$$


Layernorm kernel（layernorm_fwd）采用比较直观的实现：一个 warp 负责计算一行，即一个 token 位置上长度为 C 的 channel 向量。将 thread block 划分成若干个 32 线程的 warp，每个 warp 映射到输入输出张量（形状 B×T×C）的一行上。

In [None]:
# Tensor dimensions used by the cells below.
bs, n_seq = 4, 512   # batch size, sequence len
n_cxt = 1024         # max context len
n_hidden = 768       # hidden size
n_vocab = 50237      # vocab_size

# Seed the legacy NumPy RNG, then draw the affine parameters.
# gamma is drawn before beta so the random stream is consumed in a fixed order.
np.random.seed(42)
gamma, beta = (
    np.random.randn(n_hidden).astype(np.float32) for _ in range(2)
)

# Compile the LayerNorm forward kernel with PyCUDA.
# Each 32-thread warp normalizes one row of C contiguous floats; there are
# B * T such rows. no_extern_c=True is required because the source declares
# its own extern "C" and includes C++ headers.
prog = SourceModule("""
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>

namespace cg = cooperative_groups;

// LayerNorm forward pass over B*T rows of C floats each.
//   out   : output buffer, B*T*C floats
//   inp   : input buffer,  B*T*C floats
//   gamma : per-channel scale, C floats
//   beta  : per-channel shift, C floats
// The warp with global index w processes row w; warps with w >= B*T do nothing.
extern "C"
__global__ void layernorm_fwd(float *out, const float *inp, const float *gamma, const float *beta, int B, int T, int C) {
  cg::thread_block tb = cg::this_thread_block();
  cg::thread_block_tile<32> warp = cg::tiled_partition<32>(tb);
  int warp_idx = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank();
  int N = B * T;

  // warp_idx is the same for every lane of a warp, so all lanes take the same
  // branch here and the warp-wide reductions below stay convergent.
  if (warp_idx >= N) {
    return;
  }

  // Widen before multiplying: warp_idx * C as a 32-bit int overflows once the
  // tensor holds 2^31 or more elements.
  const size_t row_offset = static_cast<size_t>(warp_idx) * C;
  const float *x = inp + row_offset;
  float *y = out + row_offset;

  // mean: each lane sums a strided slice of the row, then the warp reduces.
  float sum = 0.0f;
  for (int i = warp.thread_rank(); i < C; i += warp.num_threads()) {
    sum += x[i];
  }
  const float mean = cg::reduce(warp, sum, cg::plus<float>{}) / C;

  // variance (biased, divided by C)
  float sum_sq = 0.0f;
  for (int i = warp.thread_rank(); i < C; i += warp.num_threads()) {
    float diff = x[i] - mean;
    sum_sq += diff * diff;
  }
  const float var = cg::reduce(warp, sum_sq, cg::plus<float>{}) / C;

  // reciprocal standard deviation; 1e-5 keeps the argument away from zero
  const float rstd = rsqrtf(var + 1e-5f);

  // normalize, then scale and shift per channel.
  // x[i] is not read again after this pass, hence the streaming load __ldcs.
  for (int i = warp.thread_rank(); i < C; i += warp.num_threads()) {
    float s = (__ldcs(x + i) - mean) * rstd;
    y[i] = gamma[i] * s + beta[i];
  }
}
""", no_extern_c=True)

block size = 512，一个 block 含 512/32 = 16 个 warp，每个 warp 处理一行（一个 token 的 n_hidden 维向量），因此一个 block 计算 16 行；整个 grid 需要 ⌈(bs * n_seq) / (block_size / 32)⌉ 个 blocks。

In [None]:
# Look up the compiled kernel and run it on a random input.
layernorm_fwd = prog.get_function("layernorm_fwd")

# Host buffers. `input` shadows the builtin, but the verification cell reads
# the array by this name, so it is kept.
out = np.empty((bs, n_seq, n_hidden), dtype=np.float32)
input = np.random.randn(bs, n_seq, n_hidden).astype(np.float32)

# One warp (32 threads) per row of n_hidden values, so the launch needs
# bs * n_seq * 32 threads in total, rounded up to whole blocks.
block_size = 512
grid_size = int(np.ceil(bs * n_seq * 32 / block_size))

# drv.In / drv.Out wrap the host arrays so PyCUDA copies them to the device
# before the launch and back to the host afterwards.
layernorm_fwd(
    drv.Out(out), drv.In(input), drv.In(gamma), drv.In(beta),
    np.int32(bs), np.int32(n_seq), np.int32(n_hidden),
    block=(block_size, 1, 1), grid=(grid_size, 1, 1),
)


验证计算结果：

In [None]:
def ref_layernorm_fwd(input, gamma, beta):
  """NumPy reference for the LayerNorm forward pass.

  Normalizes `input` along its last axis using the mean and the biased
  variance (np.var default) with eps = 1e-5 added under the square root,
  then applies the affine transform `gamma * x + beta`.

  Args:
    input: array to normalize along its last axis.
    gamma: scale, broadcast against the normalized array.
    beta: shift, broadcast against the scaled array.

  Returns:
    The normalized, scaled and shifted array.
  """
  eps = 1e-5
  mu = np.mean(input, axis=-1, keepdims=True)
  sigma2 = np.var(input, axis=-1, keepdims=True)
  centered = input - mu
  std = np.sqrt(sigma2 + eps)
  return gamma * (centered / std) + beta

# Compare the kernel output with the NumPy reference; the tolerances are
# np.allclose's defaults, written out explicitly.
np.allclose(out, ref_layernorm_fwd(input, gamma, beta), rtol=1e-05, atol=1e-08)

True