In [1]:
import jax
from jax import lax
from jax import random
import jax.numpy as jnp
import flax.linen as nn
import flax

import optax

In [2]:
from transformer_attention import MSALayerConfig
from transformer import *

In [3]:
%load_ext autoreload
%autoreload 2

In [4]:
B = 16
H = 1
L = 10
C = 30
key1, key2 = random.split(random.PRNGKey(0), 2)
x = random.uniform(key1, (L, C))
x.shape

(10, 30)

In [5]:
config = MSALayerConfig(n_heads=H, qk_dim=C*H, v_dim=C*9, out_dim=C)
model = TransformerEncoder(n_layer=3, MSAConfig=config, filter_size=C, hidden_size=C, dropout=0.0)

In [6]:
params = model.init(key2, x)
params

FrozenDict({
    params: {
        EncoderBlock_0: {
            LayerNorm_0: {
                scale: DeviceArray([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
                             1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],            dtype=float32),
                bias: DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],            dtype=float32),
            },
            MultiHeadAttention_0: {
                Query linear: {
                    kernel: DeviceArray([[-2.60007977e-01, -8.37481394e-02, -1.35079414e-01,
                                   1.10644929e-01, -4.80784401e-02,  6.80456385e-02,
                                  -2.11267516e-01, -5.54012135e-02, -1.58904016e-01,
                                   2.98679531e-01, -1.61373034e-01,  1.75812006e-01,
                                  -2.76106987e-02,  2.69401409e-

In [7]:
y = model.apply(params, x)

In [8]:
y.shape

(10, 30)

In [9]:
jnp.allclose(x, y)

DeviceArray(False, dtype=bool)

In [10]:
@jax.jit
def cross_entropy(params, x_batched, y_batched):
    # Define the squared loss for a single pair (x,y)
    def loss(x, y):
        pred = model.apply(params, x)
        return jnp.sum(-y*jnp.loss)
    
    # Vectorize the previous to compute the average of the loss on all samples.
    return jnp.mean(jax.vmap(squared_error)(x_batched,y_batched), axis=0)

In [11]:
learning_rate = 0.0003
tx = optax.sgd(learning_rate=learning_rate)
opt_state = tx.init(params)
loss_grad_fn = jax.value_and_grad(mse)

In [12]:
batched_x = random.uniform(key1, (B, L, C))

In [13]:
loss, _ = loss_grad_fn(params, batched_x, batched_x)
loss

DeviceArray(537.1627, dtype=float32)

In [14]:
for i in range(5000):
    loss, grads = loss_grad_fn(params, batched_x, batched_x)
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    if i % 100 == 0:
        print('Loss step {}: '.format(i), loss)

Loss step 0:  537.1627
Loss step 100:  18.904345
Loss step 200:  10.639772
Loss step 300:  6.9435534
Loss step 400:  5.0836244
Loss step 500:  4.240226
Loss step 600:  4.2816935
Loss step 700:  2.9581003
Loss step 800:  2.8170025
Loss step 900:  2.152847
Loss step 1000:  2.0779424
Loss step 1100:  2.633573
Loss step 1200:  1.6582034
Loss step 1300:  1.5454391
Loss step 1400:  1.4353138
Loss step 1500:  1.4591608
Loss step 1600:  1.4317627
Loss step 1700:  1.1593287
Loss step 1800:  1.1122348
Loss step 1900:  1.1646137
Loss step 2000:  1.2310821
Loss step 2100:  1.129478
Loss step 2200:  0.91813827
Loss step 2300:  0.82637453
Loss step 2400:  0.83169127
Loss step 2500:  0.94297504
Loss step 2600:  0.87845486
Loss step 2700:  0.7650429
Loss step 2800:  0.7810136
Loss step 2900:  0.7004932
Loss step 3000:  0.77391696
Loss step 3100:  0.70781016
Loss step 3200:  0.69692147
Loss step 3300:  0.7261093
Loss step 3400:  0.5952754
Loss step 3500:  0.71668327
Loss step 3600:  0.55830526
Loss ste

KeyboardInterrupt: 

In [16]:
pred_batched = jax.vmap(model.apply, (None, 0))(params, batched_x)

In [17]:
jnp.allclose(pred_batched, batched_x, 1, 1)

DeviceArray(True, dtype=bool)