In [1]:
!pip install pycuda -q
import pycuda.autoinit
import pycuda.driver as drv
import numpy as np
from pycuda.compiler import SourceModule

In [2]:
# Problem dimensions used throughout the notebook.
bs = 4            # batch size (B)
n_seq = 512       # sequence length (T); positions 0..T-1 index rows of wpe
n_cxt = 1024      # max context length = number of rows in wpe
n_hidden = 768    # hidden size (C); the kernel reads rows as float4, so C % 4 == 0
n_vocab = 50237   # vocab size = number of rows in wte
                  # NOTE(review): GPT-2 uses 50257 -- confirm 50237 is intended.

# Random float32 token (wte) and position (wpe) embedding tables.
# Seeded so the draws are reproducible; wte is drawn first, then wpe.
np.random.seed(42)
wte, wpe = (
    np.random.randn(n_rows, n_hidden).astype(np.float32)
    for n_rows in (n_vocab, n_cxt)
)

# Compile the CUDA embedding kernel. It computes
#   out[b, t, :] = wte[input[b, t], :] + wpe[t, :]
# with one thread per float4 (4 consecutive floats) of the output. The float32
# host arrays are reinterpreted as float4 on the device, so C must be divisible by 4.
prog = SourceModule("""
// Component-wise sum of two float4 values.
__device__ inline float4 add_float4(float4 a, float4 b) {
  float4 c;
  c.x = a.x + b.x;
  c.y = a.y + b.y;
  c.z = a.z + b.z;
  c.w = a.w + b.w;
  return c;
}

// out[b, t, :] = wte[input[b, t], :] + wpe[t, :], every row viewed as C/4 float4s.
// input is (B, T) row-major token ids; out is (B, T, C) row-major.
// One thread per output float4: launch at least B * T * (C / 4) threads.
// Requires C % 4 == 0.
__global__ void embedding_fwd(float4 *out, const int *input, const float4 *wte, const float4 *wpe, int B, int T, int C) {
  int C4 = C / 4;
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  // Guard on the number of float4 outputs. Bounding by B*T*C (float count)
  // would let surplus threads write past the end of `out` whenever the grid
  // launches more than B*T*C4 threads.
  int N4 = B * T * C4;
  if (idx < N4) {
    int bt = idx / C4;      // flattened (b, t) row index, == b * T + t
    int c = idx % C4;       // float4 column within the row
    int t = bt % T;         // position within the sequence
    int token = input[bt];  // == input[b * T + t]
    // Row-major (B, T, C4) output: b * T * C4 + t * C4 + c == bt * C4 + c == idx.
    out[idx] = add_float4(wte[token * C4 + c], wpe[t * C4 + c]);
  }
}
""")

In [3]:
def ref_embedding_fwd(input, wte, wpe):
  return wte[input] + wpe[range(input.shape[-1])]

In [4]:
# Run the CUDA kernel on random token ids and compare with the NumPy reference.
embedding_fwd = prog.get_function("embedding_fwd")

# Kernel preconditions: rows are reinterpreted as float4, and positions
# 0..n_seq-1 index rows of wpe. Violations corrupt results silently on device.
assert n_hidden % 4 == 0, "n_hidden must be divisible by 4 (rows are read as float4)"
assert n_seq <= n_cxt, "n_seq must not exceed n_cxt (positions index rows of wpe)"

out = np.empty((bs, n_seq, n_hidden), dtype=np.float32)
input = np.random.randint(n_vocab, size=(bs, n_seq), dtype=np.int32)  # token ids in [0, n_vocab)
N = bs * n_seq * n_hidden   # number of float32 outputs
N4 = N // 4                 # number of float4 outputs = one thread each
threads_per_block = 512
n_blocks = (N4 + threads_per_block - 1) // threads_per_block  # integer ceil-division
embedding_fwd(
    drv.Out(out), drv.In(input), drv.In(wte), drv.In(wpe),
    np.int32(bs), np.int32(n_seq), np.int32(n_hidden),
    block=(threads_per_block, 1, 1), grid=(n_blocks, 1, 1),
)
np.allclose(out, ref_embedding_fwd(input, wte, wpe))

True