In [1]:
import jax
from jax import lax
from jax import random
import jax.numpy as jnp
import flax.linen as nn
import flax

import optax

In [2]:
from transformer_attention import MultiHeadAttention, MSALayerConfig

In [3]:
%load_ext autoreload
%autoreload 2

In [4]:
B = 16
H = 1
L = 10
C = 30
key1, key2 = random.split(random.PRNGKey(0), 2)
x = random.uniform(key1, (L, C))
x.shape

(10, 30)

In [5]:
config = MSALayerConfig(n_heads=H, qk_dim=C*H, v_dim=C*9, out_dim=C)
model = MultiHeadAttention(config=config)

In [6]:
params = model.init(key2, x, x)
params

FrozenDict({
    params: {
        Query linear: {
            kernel: DeviceArray([[-6.42497465e-02, -2.36821896e-03, -2.55484164e-01,
                           3.02393377e-01, -2.65125394e-01,  1.69518664e-01,
                           1.76010817e-01,  8.86650756e-02,  8.82424116e-02,
                          -1.06638610e-01,  6.37522936e-02, -1.39350057e-01,
                          -6.07129885e-03, -2.55048394e-01,  2.64953285e-01,
                           2.44156733e-01,  1.37580842e-01,  2.75281638e-01,
                          -1.24030180e-01, -2.97418624e-01,  8.40861350e-02,
                          -6.57675937e-02, -7.03287348e-02, -2.19439149e-01,
                           1.75009206e-01, -2.25424722e-01, -2.66799536e-02,
                           1.88050508e-01,  2.00266913e-01,  1.35190845e-01],
                         [ 2.83158809e-01,  1.33366972e-01,  2.64529467e-01,
                          -2.77500510e-01, -4.56557870e-02,  2.95003206e-01,
                

In [7]:
y = model.apply(params, x, x)

In [8]:
y.shape

(10, 30)

In [9]:
jnp.allclose(x, y)

DeviceArray(False, dtype=bool)

In [10]:
def squared_error(x, y):
    pred = model.apply(params, x, x)
    return jnp.sum(pred**2 - y**2)

squared_error(x, x)

DeviceArray(-76.99953, dtype=float32)

In [11]:
@jax.jit
def mse(params, x_batched, y_batched):
    # Define the squared loss for a single pair (x,y)
    def squared_error(x, y):
        pred = model.apply(params, x, x)
        return jnp.sum((pred - y)**2)
    # Vectorize the previous to compute the average of the loss on all samples.
    return jnp.mean(jax.vmap(squared_error)(x_batched,y_batched), axis=0)

In [12]:
learning_rate = 0.003
tx = optax.sgd(learning_rate=learning_rate)
opt_state = tx.init(params)
loss_grad_fn = jax.value_and_grad(mse)

In [13]:
batched_x = random.uniform(key1, (B, L, C))

In [14]:
loss, _ = loss_grad_fn(params, batched_x, batched_x)
loss

DeviceArray(138.03511, dtype=float32)

In [15]:
for i in range(5000*10):
    loss, grads = loss_grad_fn(params, batched_x, batched_x)
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    if i % 100 == 0:
        print('Loss step {}: '.format(i), loss)

Loss step 0:  138.03511
Loss step 100:  23.783562
Loss step 200:  23.143154
Loss step 300:  22.839537
Loss step 400:  22.657236
Loss step 500:  22.52399
Loss step 600:  22.410793
Loss step 700:  22.30418
Loss step 800:  22.21812
Loss step 900:  22.151695
Loss step 1000:  21.952236
Loss step 1100:  21.804718
Loss step 1200:  21.634573
Loss step 1300:  21.469627
Loss step 1400:  21.231966
Loss step 1500:  20.899384
Loss step 1600:  20.570944
Loss step 1700:  20.155554
Loss step 1800:  20.35525
Loss step 1900:  19.068459
Loss step 2000:  18.412138
Loss step 2100:  18.295933
Loss step 2200:  16.923866
Loss step 2300:  16.576612
Loss step 2400:  15.351936
Loss step 2500:  14.360514
Loss step 2600:  13.3552
Loss step 2700:  13.264492
Loss step 2800:  11.811957
Loss step 2900:  10.72107
Loss step 3000:  9.66177
Loss step 3100:  8.820654
Loss step 3200:  8.052191
Loss step 3300:  7.342058
Loss step 3400:  6.823267
Loss step 3500:  6.135132
Loss step 3600:  5.648237
Loss step 3700:  5.231432
Lo

Loss step 29100:  0.06575925
Loss step 29200:  0.06534882
Loss step 29300:  0.06496289
Loss step 29400:  0.06458356
Loss step 29500:  0.06419725
Loss step 29600:  0.06382203
Loss step 29700:  0.063457176
Loss step 29800:  0.06308693
Loss step 29900:  0.062723026
Loss step 30000:  0.06236265
Loss step 30100:  0.061987653
Loss step 30200:  0.061651684
Loss step 30300:  0.061287355
Loss step 30400:  0.06095719
Loss step 30500:  0.060596384
Loss step 30600:  0.060249977
Loss step 30700:  0.05991205
Loss step 30800:  0.059572183
Loss step 30900:  0.059256677
Loss step 31000:  0.058930382
Loss step 31100:  0.0586063
Loss step 31200:  0.058288578
Loss step 31300:  0.057960138
Loss step 31400:  0.057663612
Loss step 31500:  0.05733236
Loss step 31600:  0.05704167
Loss step 31700:  0.056731645
Loss step 31800:  0.056412503
Loss step 31900:  0.05613278
Loss step 32000:  0.0558224
Loss step 32100:  0.05553657
Loss step 32200:  0.055230577
Loss step 32300:  0.054948535
Loss step 32400:  0.05465872

In [16]:
pred_batched = jax.vmap(model.apply, (None, 0, 0))(params, batched_x, batched_x)

In [17]:
jnp.allclose(pred_batched, batched_x, 0.1, 0.1)

DeviceArray(True, dtype=bool)