In [2]:
import jax
from jax import numpy as jnp
from jax import random, vmap, jit, grad
from einstein import jax_codegen
import pyperclip

ModuleNotFoundError: No module named 'einstein_tensors'

In [7]:
s = """
function Linear(x)
    z[o] = w[o,i] * x[i] + b[o] + 1[o] + 1[i] + 1[o_1,i] + x[i]*1[o_2,i] + x[i]*1[o_3]
    return z[o]
end
"""
c = jax_codegen(s)
print(c)
pyperclip.copy(c)
exec(c)

class Linear:
    def __init__(self, o):
        self.old_init = self.init
        self.init = lambda *args, **kwargs: self.old_init(*args, o=o, **kwargs)

    @staticmethod
    def init(key, i, o, **kwargs):
        keys = jax.random.split(key, 2)
        return {
            "w_o_i": jax.random.normal(keys[0], shape=[o, i]),
            "b_o":   jax.random.normal(keys[1], shape=[o])
        }
    
    @staticmethod
    @jit
    @lambda apply: vmap(apply, in_axes=(0, None))
    def apply(x_i, params):
        o, i = params['w_o_i'].shape[0], x_i.shape[0]
        z_o = jnp.einsum('oi,i->o', params['w_o_i'], x_i) + params['b_o'] + i + o*i + (o)*jnp.einsum('i->', x_i) + (o)*jnp.einsum('i->', x_i)
        return z_o


In [22]:
class Linear:
    def __init__(self, o):
        self.old_init = self.init
        self.init = lambda *args, **kwargs: self.old_init(*args, o=o, **kwargs)

    @staticmethod
    def init(key, i, o, **kwargs):
        keys = jax.random.split(key, 2)
        return {
            "w_o_i": jax.random.normal(keys[0], shape=[o, i]),
            "b_o":   jax.random.normal(keys[1], shape=[o])
        }
    
    @staticmethod
    @jit
    @lambda apply: vmap(apply, in_axes=(0, None))
    def apply(x_i, params):
        o, i = params['w_o_i'].shape[0], x_i.shape[0]
        z_o = jnp.einsum('oi,i->o', params['w_o_i'], x_i) + params['b_o'] + i + o*i + (o)*jnp.einsum('i->', x_i) + (o)*jnp.einsum('i->', x_i)
        return z_o

In [23]:
@jit
def linear(x, W, b):
    return jnp.dot(W, x) + b + 500 + 500*200 + 200*x.sum() + 200*x.sum()

linear = jit(vmap(linear, in_axes=(0, None, None)))

lin = Linear(o=200)
params = lin.init(jax.random.PRNGKey(0), i=500)
# print(params)
x = random.normal(jax.random.PRNGKey(0), shape=[1,500])

out1 = lin.apply(x, params).block_until_ready()
out2 = linear(x, params['w_o_i'], params['b_o']).block_until_ready()

print(f"out1.shape: {out1.shape}")
print(f"out2.shape: {out2.shape}")

print("Trials 1: generated vs hand-written")
%timeit -n 100 out1 = lin.apply(x, params).block_until_ready()
%timeit -n 100 out2 = linear(x, params['w_o_i'], params['b_o']).block_until_ready()
print("Trials 2: hand-written vs generated")
%timeit -n 100 out2 = linear(x, params['w_o_i'], params['b_o']).block_until_ready()
%timeit -n 100 out1 = lin.apply(x, params).block_until_ready()


print(f"jnp.isclose(out1, out2).mean(): {jnp.isclose(out1, out2).mean()}")
print(f"jnp.abs(lin.apply(x, params) - linear(x, params['w_o_i'], params['b_o'])).mean(): {jnp.abs(lin.apply(x, params) - linear(x, params['w_o_i'], params['b_o'])).mean()}")

out1.shape: (1, 200)
out2.shape: (1, 200)
Trials 1: generated vs hand-written
10.7 µs ± 2.42 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
9.86 µs ± 794 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)
Trials 2: hand-written vs generated
9.81 µs ± 1.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
10.3 µs ± 718 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)
jnp.isclose(out1, out2).mean(): 1.0
jnp.abs(lin.apply(x, params) - linear(x, params['w_o_i'], params['b_o'])).mean(): 0.0


In [28]:
class Linear2:
    def __init__(self, o):
        self.old_init = self.init
        self.init = lambda *args, **kwargs: self.old_init(*args, o=o, **kwargs)

    @staticmethod
    def init(key, i, o, **kwargs):
        keys = jax.random.split(key, 2)
        return {
            "w_o_i": jax.random.normal(keys[0], shape=[o, i]),
            "b_o":   jax.random.normal(keys[1], shape=[o])
        }
    
    @staticmethod
    @jit
    @lambda apply: vmap(apply, in_axes=(0, None))
    def apply(x_i, params):
        z_o = x_i.sum() + x_i.shape[0]
        return z_o

class Linear3:
    def __init__(self, o):
        self.old_init = self.init
        self.init = lambda *args, **kwargs: self.old_init(*args, o=o, **kwargs)

    @staticmethod
    def init(key, i, o, **kwargs):
        keys = jax.random.split(key, 2)
        return {
            "w_o_i": jax.random.normal(keys[0], shape=[o, i]),
            "b_o":   jax.random.normal(keys[1], shape=[o])
        }
    
    @staticmethod
    @jit
    @lambda apply: vmap(apply, in_axes=(0, None))
    def apply(x_i, params):
        z_o = jnp.einsum('i,i', x_i, jnp.ones(x_i.shape[0]),)
        return z_o

input_size = 50000
key = jax.random.PRNGKey(0)
x = random.normal(key, shape=[1,input_size])
lin2 = Linear2(o=200)
params2 = lin2.init(jax.random.PRNGKey(0), i=input_size)
lin3 = Linear3(o=200)
params3 = lin3.init(jax.random.PRNGKey(0), i=input_size)

out2 = lin2.apply(x, params2).block_until_ready()
out3 = lin2.apply(x, params3).block_until_ready()
print(f"out2.shape: {out2.shape}")
print(f"out3.shape: {out3.shape}")

print("Trials 1: generated vs hand-written")
key, subkey = jax.random.split(key, 2)
x = random.normal(subkey, shape=[1,input_size])
%timeit -n 100 out2 = lin2.apply(x, params2).block_until_ready()
%timeit -n 100 out3 = lin3.apply(x, params3).block_until_ready()
print("Trials 2: hand-written vs generated")
key, subkey = jax.random.split(key, 2)
x = random.normal(subkey, shape=[1,input_size])
%timeit -n 100 out3 = lin3.apply(x, params2).block_until_ready()
%timeit -n 100 out2 = lin2.apply(x, params3).block_until_ready()
print("Trials 3: hand-written vs generated")
key, subkey = jax.random.split(key, 2)
x = random.normal(subkey, shape=[1,input_size])
%timeit -n 100 out2 = lin2.apply(x, params2).block_until_ready()
%timeit -n 100 out3 = lin3.apply(x, params3).block_until_ready()


print(f"jnp.isclose(out2, out3).mean(): {jnp.isclose(out2, out3).mean()}")


out2.shape: (1,)
out3.shape: (1,)
Trials 1: generated vs hand-written
16.6 µs ± 862 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)
83.1 µs ± 33.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Trials 2: hand-written vs generated
68.3 µs ± 682 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)
16.8 µs ± 637 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)
Trials 3: hand-written vs generated
16.5 µs ± 536 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)
68.6 µs ± 1.23 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
jnp.isclose(out2, out3).mean(): 1.0


In [None]:

# s = """
# function MultiheadSelfAttention(x)
#     x[t,i] = x[t,i] * gx[i] * (x[t,i_1]^2 * 1[i_1] / (x[t,i_2] * 1[i_2]))^0.5
#     q[h,t,j] = q[h,j,i] * x[t,i]
#     k[h,t,j] = k[h,j,i] * x[t,i]
#     v[h,t,j] = v[h,j,i] * x[t,i]
#     a[h,t_1,t_2] = jnp.exp(q[h,t_1,j] * k[h,t_2,j])
#     u[t,k] = activation(wu[h,j,k] * a[h,t,t_2] * v[h,t_2,j] + bu[k])
#     z[t,i] = wz[i,k] * u[t,k] + bz[i]
#     z[t,i] = z[t,i] * gz[i] * (z[t,i]^2 * 1[i] / (z[t,i] + 1[i]))
#     return z[t,i]
# end


In [21]:
make_jax_module("""
function MultiheadSelfAttention(x)
    x[t,i] = x[t,i] * gx[i]
    q[h,t,j] = q[h,j,i] * x[t,i]
    k[h,t,j] = k[h,j,i] * x[t,i]
    v[h,t,j] = v[h,j,i] * x[t,i]
    a[h,t_1,t_2] = jnp.exp(q[h,t_1,j] * k[h,t_2,j])
    u[t,k] = activation(wu[h,j,k] * a[h,t,t_2] * v[h,t_2,j] + bu[k])
    z[t,i] = wz[t,k] * u[t,k] + bz[i]
    return z[t,i]
end
""")
print(jax_codegen(s))

In [22]:
def activation(x):
    # Squared activation function
    return x**2

In [31]:
class MultiheadSelfAttention:
    def __init__(self, k, j, h):
        self.old_init = self.init
        self.init = lambda *args, **kwargs: self.old_init(*args, k=k, j=j, h=h, **kwargs)

    @staticmethod
    def init(key, i, t, k, j, h, **kwargs):
        keys = jax.random.split(key, 8)
        return {
            "gx_i": jax.random.normal(keys[0], shape=[i]),
            "q_h_j_i": jax.random.normal(keys[1], shape=[h, j, i]),
            "k_h_j_i": jax.random.normal(keys[2], shape=[h, j, i]),
            "v_h_j_i": jax.random.normal(keys[3], shape=[h, j, i]),
            "wu_h_j_k": jax.random.normal(keys[4], shape=[h, j, k]),
            "bu_k": jax.random.normal(keys[5], shape=[k]),
            "wz_t_k": jax.random.normal(keys[6], shape=[t, k]),
            "bz_i": jax.random.normal(keys[7], shape=[i])
        }
    
    @staticmethod
    @lambda apply: vmap(apply, in_axes=(0, None))
    def apply(x_t_i, params):
        x_t_i = ((x_t_i*params['gx_i'][None, :])).sum(axis=[])
        q_h_t_j = ((params['q_h_j_i'][:, None, :, :]*x_t_i[None, :, None, :])).sum(axis=[3])
        k_h_t_j = ((params['k_h_j_i'][:, None, :, :]*x_t_i[None, :, None, :])).sum(axis=[3])
        v_h_t_j = ((params['v_h_j_i'][:, None, :, :]*x_t_i[None, :, None, :])).sum(axis=[3])
        a_h_t_t = (jnp.exp((q_h_t_j[:, :, None, :]*k_h_t_j[:, None, :, :]))).sum(axis=[3])
        u_t_k = (activation(((params['wu_h_j_k'].transpose((2, 0, 1))[None, :, :, :, None]*a_h_t_t.transpose((1, 0, 2))[:, None, :, None, :]*v_h_t_j.transpose((0, 2, 1))[None, None, :, :, :]).sum(axis=[2, 3, 4], keepdims=True)+(params['bu_k'][None, :, None, None, None]).sum(axis=[2, 3, 4], keepdims=True)))).sum(axis=[2, 3, 4])
        z_t_i = (((params['wz_t_k'][:, None, :]*u_t_k[:, None, :]).sum(axis=[2], keepdims=True)+(params['bz_i'][None, :, None]).sum(axis=[2], keepdims=True))).sum(axis=[2])
        return z_t_i


attention = MultiheadSelfAttention(k=2, j=2, h=2)
params = attention.init(jax.random.PRNGKey(0), i=2, t=2)
# print(params)
x = jnp.array([[[0,1], [2,3]]])
out = attention.apply(x, params)
print(out)

[[[-2.3988705e+02 -2.3823891e+02]
  [-1.1472459e+07 -1.1472458e+07]]]
