In [1]:
import jax
from jax import numpy as jnp
from jax import random, vmap, jit, grad
from einstein import jax_codegen, build_jax_module
import pyperclip



In [2]:
# Index-notation spec for a linear layer (z[o] = w[o,i] * x[i] + b[o]); this text is
# the input handed to jax_codegen below.
s = """
function Linear(x)
    z[o] = w[o,i] * x[i] + b[o]
    return z[o]
end
"""
# Run the spec through the code generator and print whatever it returns.
# Presumably `c` is Python source text, since it is exec'd / copied below — TODO confirm.
c = jax_codegen(s)
print(c)
# Toggle: True sends the generated text to the clipboard, False executes it here.
copy = False
if copy:
    pyperclip.copy(c)
else:
    # NOTE(review): exec runs the generated text in this namespace. If it binds a name
    # `Linear`, the hand-written `class Linear` that follows in this cell rebinds it.
    exec(c)

class Linear:
    """Affine layer ``z[o] = w[o,i] * x[i] + b[o]`` with the output width fixed per instance.

    Constructing ``Linear(o)`` shadows the class-level ``init`` on the instance with a
    wrapper that always supplies ``o``, so callers pass only ``key`` and ``i``.
    ``apply`` is a static, jitted function vmapped over the leading axis of ``x``.
    """

    def __init__(self, o):
        # Grab the class-level initialiser before it is shadowed on the instance.
        self.old_init = self.init

        def init_with_bound_width(*args, **kwargs):
            # `o` is always injected; passing `o` again raises TypeError.
            return self.old_init(*args, o=o, **kwargs)

        self.init = init_with_bound_width

    @staticmethod
    def init(key, i, o, **kwargs):
        """Draw standard-normal parameters.

        Args:
            key: PRNG key; split into one sub-key per parameter.
            i: input width.
            o: output width.
            **kwargs: accepted and ignored.

        Returns:
            dict with ``"w_o_i"`` of shape ``[o, i]`` and ``"b_o"`` of shape ``[o]``.
        """
        weight_key, bias_key = jax.random.split(key, 2)
        weight = jax.random.normal(weight_key, shape=[o, i])
        bias = jax.random.normal(bias_key, shape=[o])
        return {"w_o_i": weight, "b_o": bias}

    @staticmethod
    @jit
    @lambda apply: vmap(apply, in_axes=(0, None))
    def apply(x_i, params):
        """Compute ``w @ x + b`` for one example.

        The decorator stack maps this over axis 0 of the first argument
        (``params`` is shared across the batch) and jit-compiles the result.
        """
        weight = params['w_o_i']
        bias = params['b_o']
        projected_o = jnp.einsum('oi,i->o', weight, x_i)
        return projected_o + bias


In [3]:
def activation(x):
    """Return ``x`` raised to the power 2 (elementwise for array inputs)."""
    return pow(x, 2)

In [5]:
# Index-notation spec for the attention block; this text is the input to
# build_jax_module and jax_codegen below. It refers to `activation` and `jnp.exp`
# by name — presumably those must be resolvable where the generated code runs (TODO confirm).
# NOTE(review): the two normalisation lines are not symmetric. The x line divides by
# (x[t,i_2] * 1[i_2]) and takes ^0.5 with renamed indices i_1/i_2; the z line divides
# by (z[t,i] + 1[i]), has no ^0.5 and reuses index i throughout. Confirm this is intended.
s = """
function MultiheadSelfAttention(x)
    x[t,i] = x[t,i] * gx[i] * (x[t,i_1]^2 * 1[i_1] / (x[t,i_2] * 1[i_2]))^0.5
    q[h,t,j] = q[h,j,i] * x[t,i]
    k[h,t,j] = k[h,j,i] * x[t,i]
    v[h,t,j] = v[h,j,i] * x[t,i]
    a[h,t_1,t_2] = jnp.exp(q[h,t_1,j] * k[h,t_2,j])
    u[t,k] = activation(wu[h,j,k] * a[h,t,t_2] * v[h,t_2,j] + bu[k])
    z[t,i] = wz[i,k] * u[t,k] + bz[i]
    z[t,i] = z[t,i] * gz[i] * (z[t,i]^2 * 1[i] / (z[t,i] + 1[i]))
    return z[t,i]
end
"""
# Bind whatever build_jax_module returns for this spec.
# NOTE(review): a later cell defines `class MultiheadSelfAttention`, which rebinds this name.
MultiheadSelfAttention = build_jax_module(s)

# Also run the spec through jax_codegen and execute the result in this namespace.
c = jax_codegen(s)
exec(c)
# Disabled alternative: print the generated text and/or copy it to the clipboard
# instead of executing it.
# print(c)
# copy = False
# pyperclip.copy(c)
# if copy:
#     pyperclip.copy(c)
# else:
#     exec(c)

<function __main__.MultiheadSelfAttention.__init__.<locals>.<lambda>(*args, **kwargs)>

In [6]:
class MultiheadSelfAttention:
    """Hand-written JAX module with per-head q/k/v projections and exp-weighted mixing.

    ``MultiheadSelfAttention(j, k, h)`` shadows the class-level ``init`` on the
    instance with a wrapper that always supplies ``j``, ``k`` and ``h``; callers
    then pass ``key``, ``t`` and ``i``. ``apply`` is a static, jitted function
    vmapped over the leading axis of its first argument.
    """

    def __init__(self, j, k, h):
        # Grab the class-level initialiser before it is shadowed on the instance.
        self.old_init = self.init

        def init_with_bound_dims(*args, **kwargs):
            # j, k, h are always injected; passing any of them again raises TypeError.
            return self.old_init(*args, j=j, k=k, h=h, **kwargs)

        self.init = init_with_bound_dims

    @staticmethod
    def init(key, t, i, j, k, h, **kwargs):
        """Draw standard-normal parameters, one PRNG sub-key per tensor.

        Args:
            key: PRNG key, split into nine sub-keys.
            t: accepted but not used for any parameter shape.
            i, j, k, h: sizes used in the parameter shapes listed below.
            **kwargs: accepted and ignored.

        Returns:
            dict mapping parameter name to array, in the order of ``layout``.
        """
        # (name, shape) in sub-key order; the position selects which sub-key is used.
        layout = (
            ("gx_i", [i]),
            ("q_h_j_i", [h, j, i]),
            ("k_h_j_i", [h, j, i]),
            ("v_h_j_i", [h, j, i]),
            ("wu_h_j_k", [h, j, k]),
            ("bu_k", [k]),
            ("wz_i_k", [i, k]),
            ("bz_i", [i]),
            ("gz_i", [i]),
        )
        keys = jax.random.split(key, len(layout))
        return {
            name: jax.random.normal(keys[idx], shape=shape)
            for idx, (name, shape) in enumerate(layout)
        }

    @staticmethod
    @jit
    @lambda apply: vmap(apply, in_axes=(0, None))
    def apply(x_t_i, params):
        """Forward pass for one ``[t, i]`` example.

        The decorator stack maps this over axis 0 of the first argument
        (``params`` is shared) and jit-compiles the result.
        """
        # Input rescale: x * gx * sqrt(sum_i x^2 / sum_i x), the factor being per row t.
        inv_row_sum_t = jnp.einsum('ti->t', x_t_i) ** -1
        row_factor_t = jnp.einsum('ti,t->t', x_t_i ** 2, inv_row_sum_t) ** 0.5
        x_t_i = jnp.einsum('ti,i,t->ti', x_t_i, params['gx_i'], row_factor_t)

        def project(w_h_j_i):
            # Shared contraction for the q, k and v projections.
            return jnp.einsum('hji,ti->htj', w_h_j_i, x_t_i)

        q_h_t_j = project(params['q_h_j_i'])
        k_h_t_j = project(params['k_h_j_i'])
        v_h_t_j = project(params['v_h_j_i'])

        # Pairwise scores: exp of the q·k contraction over j (no further normalisation here).
        scores_h_t_a = jnp.einsum('htj,haj->hta', q_h_t_j, k_h_t_j)
        a_h_t_t = jnp.exp(scores_h_t_a)

        # Mix values with the scores, contract heads and j through wu, add bias, activate.
        mixed_t_k = jnp.einsum('hjk,hta,haj->tk', params['wu_h_j_k'], a_h_t_t, v_h_t_j)
        u_t_k = activation(mixed_t_k + params['bu_k'][None, :])

        # Output projection back to width i.
        z_t_i = jnp.einsum('ik,tk->ti', params['wz_i_k'], u_t_k) + params['bz_i'][None, :]

        # Output rescale: z * gz * (z**2 * z**-1), all elementwise.
        z_ratio_t_i = jnp.einsum('ti,ti->ti', z_t_i ** 2, z_t_i ** -1)
        return jnp.einsum('ti,i,ti->ti', z_t_i, params['gz_i'], z_ratio_t_i)


In [9]:
# Smoke test of the hand-written class: h, j, k are bound at construction, so init
# only needs key, t and i. init does not use t for any parameter shape, hence None.
attention = MultiheadSelfAttention(h=2,k=3,j=5)
params = attention.init(key=random.PRNGKey(0), t=None, i=7)
# apply is vmapped over axis 0, so x is (batch=1, t=3, i=7) with i matching init above.
x = jnp.ones((1,3,7))
y = attention.apply(x, params)

In [33]:
# NOTE(review): `makeclass` is not defined in any cell visible here and this cell's
# execution count (33) is far ahead of its neighbours, so it relies on leftover
# kernel state and will raise NameError under Restart & Run All unless the
# definition is restored above.
tc = makeclass()
# The saved output shows this call producing an instance of a `testclass` defined
# inside makeclass.
tc(j=2, k=3, h=5)

<__main__.makeclass.<locals>.testclass at 0x28c939940>

In [1]:
# Scratch check: evaluates the literal expression string "1+2" and prints 3.
# NOTE(review): execution count is 1 although this is the last cell — looks like
# leftover debugging; consider deleting it.
print(eval("1+2"))

3
