In [1]:
# code for paper 'Transformer tricks: flash normalization'

import numpy as np

# reciprocal of RMS and activation functions
def r_rms(x):
  """Return the reciprocal of the root-mean-square of x.

  The mean is taken over all elements of x (np.mean without an axis),
  so the result is a single scalar.
  """
  mean_square = np.mean(x**2)
  return 1 / np.sqrt(mean_square)
def r_ms(x):
  """Return the reciprocal of the mean square of x.

  The mean is taken over all elements of x, so the result is a single
  scalar (the square of the reciprocal RMS, up to rounding).
  """
  mean_square = np.mean(x**2)
  return 1 / mean_square
def relu(x):
  """Rectified linear unit: elementwise maximum of 0 and x."""
  lower_bound = 0
  return np.maximum(lower_bound, x)
def sigmoid(x):
  """Logistic sigmoid 1 / (1 + exp(-x)), evaluated without overflow.

  The direct formula overflows in exp(-x) for large negative x, which
  triggers a RuntimeWarning (or an error under np.errstate(over='raise')).
  Here exp is only ever applied to a non-positive argument.

  Args:
    x: scalar or array-like of real numbers.

  Returns:
    Elementwise sigmoid of x; a NumPy scalar for scalar input, otherwise
    an array with the shape of x.
  """
  # exp(-|x|) lies in [0, 1] for every x, so it cannot overflow
  e = np.exp(-np.abs(x))
  # x >= 0: 1 / (1 + e^-x)   (identical to the direct formula)
  # x <  0: e^x / (1 + e^x)  (the same value, rewritten to avoid exp of a large number)
  result = np.where(np.asarray(x) >= 0, 1 / (1 + e), e / (1 + e))
  # indexing with () turns a 0-d array back into a scalar and leaves n-d arrays unchanged
  return result[()]
def silu(x):
  """SiLU activation (often known as swish): x times sigmoid(x), elementwise."""
  gate = sigmoid(x)
  return x * gate

# merge normalization weights g into weight matrix W
def flashify(g, W):
  """Merge the normalization weights g into the weight matrix W.

  Row i of the result is g[i] * W[i, ...], i.e. diag(g) @ W for a 2-D W.

  Args:
    g: normalization weights with one entry per row of W. It is flattened,
      so shapes (n,), (n, 1) and (1, n) are all accepted.
    W: array of shape (n, ...) with at least two dimensions.

  Returns:
    A new array with the shape of W. The dtype is at least float64
    (a float64 buffer was always returned before) and is widened further
    only if the inputs require it, e.g. for complex weights.

  Raises:
    ValueError: if W has fewer than two dimensions or the number of entries
      in g differs from the number of rows of W. A shorter g previously
      left the remaining rows of the output uninitialized.
  """
  g = np.ravel(g)
  W = np.asarray(W)
  if W.ndim < 2 or g.shape[0] != W.shape[0]:
    raise ValueError(
      'flashify expects one entry of g per row of W; '
      f'got g with {g.shape[0]} entries and W with shape {W.shape}')
  # one trailing singleton axis per remaining axis of W, so that g[i] scales row i
  scale = g.reshape((-1,) + (1,) * (W.ndim - 1))
  out_dtype = np.result_type(g.dtype, W.dtype, np.float64)
  # multiply in the native dtype first, then widen: same values as the row-by-row loop
  return (scale * W).astype(out_dtype, copy=False)

# alternative flashify (same as above but fewer lines)
#def flashify_alt(g, W):
#  G = np.repeat(g, W.shape[1]).reshape(W.shape)
#  return G * W  # elementwise multiply

In [2]:
# variables
n = 32   # length of the activation vector a (and of g); W is n x n
f = 128  # inner dimension between UP/GATE (n x f) and DOWN (f x n)

# fixed seed so that "Restart Kernel -> Run All" reproduces the same arrays
SEED = 0
rng = np.random.default_rng(SEED)

# all inputs are drawn uniformly from [0, 1)
a = rng.random(n)  # row-vector
g = rng.random(n)  # row-vector
W = rng.random((n, n))
UP = rng.random((n, f))
GATE = rng.random((n, f))
DOWN = rng.random((f, n))

# derived variables
s = r_rms(a)  # scaling factor
# weight matrices with the normalization weights g merged in
Wstar = flashify(g, W)
UPstar = flashify(g, UP)
GATEstar = flashify(g, GATE)

In [3]:
# code for section 1 of paper

inv_rms = r_rms(a)  # reciprocal RMS of the input, computed once for all three variants

# figure 1(a): normalize a, scale by g, then multiply by the original W
z_fig1a = np.matmul(inv_rms * a * g, W)
# figure 1(b): g is already merged into Wstar, so only the normalization precedes the matmul
z_fig1b = np.matmul(inv_rms * a, Wstar)
# figure 1(c): the normalization is applied after the matmul instead of before it
z_fig1c = np.matmul(a, Wstar) * inv_rms

# compare each variant against z_fig1a
for label, z in (('fig1b', z_fig1b), ('fig1c', z_fig1c)):
  print(np.allclose(z, z_fig1a), f'  ({label} is close to fig1a if True)')

True   (fig1b is close to fig1a if True)
True   (fig1c is close to fig1a if True)


In [4]:
# code for section 2.1 of paper

# reference: normalize a, scale by g, then UP -> relu -> DOWN
y_ref2 = np.matmul(relu(np.matmul(s * a * g, UP)), DOWN)
# figure 2(a): g is merged into UPstar; the scaling by s follows the first matmul
y_fig2a = np.matmul(relu(np.matmul(a, UPstar) * s), DOWN)
# figure 2(b): the scaling by s is applied last, after relu and DOWN
y_fig2b = np.matmul(relu(np.matmul(a, UPstar)), DOWN) * s

# compare each variant against y_ref2
for label, y in (('fig2a', y_fig2a), ('fig2b', y_fig2b)):
  print(np.allclose(y, y_ref2), f'  ({label} is close to reference if True)')

True   (fig2a is close to reference if True)
True   (fig2b is close to reference if True)


In [5]:
# code for section 2.2 of paper

# intermediates shared by figures 3 and 4
a_norm = s * a * g               # normalized input scaled by g (used by the references)
a_gate = np.matmul(a, GATEstar)  # unnormalized input through the merged gate weights
a_up = np.matmul(a, UPstar)      # unnormalized input through the merged up weights

# figure 3 (gated path with silu): reference and figures 3(a) and 3(b) of paper
y_ref3 = np.matmul(np.matmul(a_norm, GATE) * silu(np.matmul(a_norm, UP)), DOWN)
# 3(a): both branches are scaled by s before the elementwise product
y_fig3a = np.matmul(a_gate * s * silu(a_up * s), DOWN)
# 3(b): the scaling of the gate branch is applied after DOWN instead
y_fig3b = np.matmul(a_gate * silu(a_up * s), DOWN) * s

# compare each variant against y_ref3
for label, y in (('fig3a', y_fig3a), ('fig3b', y_fig3b)):
  print(np.allclose(y, y_ref3), f'  ({label} is close to reference if True)')

# figure 4 (gated path with relu): reference and figures 4(a) and 4(b) of paper
y_ref4 = np.matmul(np.matmul(a_norm, GATE) * relu(np.matmul(a_norm, UP)), DOWN)
# 4(a): both branches are scaled by s before the elementwise product
y_fig4a = np.matmul(a_gate * s * relu(a_up * s), DOWN)
# 4(b): relu(c * x) == c * relu(x) for c > 0, so both scalings are applied
# at the very end as the single factor r_ms(a), the reciprocal mean square of a
y_fig4b = np.matmul(a_gate * relu(a_up), DOWN) * r_ms(a)

# compare each variant against y_ref4
for label, y in (('fig4a', y_fig4a), ('fig4b', y_fig4b)):
  print(np.allclose(y, y_ref4), f'  ({label} is close to reference if True)')

True   (fig3a is close to reference if True)
True   (fig3b is close to reference if True)
True   (fig4a is close to reference if True)
True   (fig4b is close to reference if True)


In [None]:
# code for section 3 of paper

# TODO