In [9]:
import jax
import optax
import jax.numpy as jnp
import matplotlib.pyplot as plt

from meta_icl.transformer import Transformer
from meta_icl.data import sample_regression_dataset

# Single root key; split so parameter init and data sampling never consume the same key.
main_rng = jax.random.key(42)
main_rng, init_rng = jax.random.split(main_rng)

lr = 3e-3
set_size = 10
input_size = 10
num_epochs = 2000

# Each token presumably holds `input_size` features plus one target column
# (TODO confirm against sample_regression_dataset), so the model width is input_size + 1 (= 11 here).
token_size = input_size + 1

model = Transformer(
    num_heads=1,
    embedding_size=token_size,
    key_size=token_size,  # kept equal to the token width, as in the original configuration
    num_layers=1,
    output_size=1,
    use_layer_norm=False,
    only_attention=False
)

# Dummy batch used only to initialise parameters: (batch=1, seq=set_size + 1, token_size).
# NOTE(review): seq length assumes set_size context tokens + 1 query token — confirm.
x = jnp.ones((1, set_size + 1, token_size))
params = model.init(init_rng, x, is_training=False)
# Smoke test that a forward pass runs; the output is not used afterwards.
smoke_out = model.apply(params, x, is_training=False)

opt_adamw = optax.adamw(learning_rate=lr)
# Optimizer state for AdamW (not model parameters), threaded through every update step.
opt_params = opt_adamw.init(params)


def compute_mse_loss(params, model, X, y):
    """Mean squared error of the model's prediction at the final token.

    Args:
        params: Model parameters, passed straight to ``model.apply``.
        model: Object exposing ``apply(params, X, is_training=...)``.
        X: Batch of input sequences. The prediction is read from the model output
            at ``[:, -1, -1]`` (last output channel of the last token).
        y: Targets; ``y[:, -1]`` is the target for each batch element.

    Returns:
        Scalar mean squared error over the batch.

    Raises:
        TypeError: if ``y[:, -1]`` does not hold exactly one value per prediction.
    """
    # TODO: Compare slicing with original codebase
    y_hat = model.apply(params, X, is_training=True)[:, -1, -1]  # prediction at the last token
    # Force the target to the prediction's shape: a (batch, 1) target would otherwise silently
    # broadcast against the (batch,) prediction into a (batch, batch) matrix and corrupt the loss.
    y_target = jnp.reshape(y[:, -1], y_hat.shape)
    return jnp.mean((y_hat - y_target) ** 2)  # MSE loss


# Build the gradient function once; it differentiates w.r.t. the first argument (params) only.
loss_grad_fn = jax.value_and_grad(compute_mse_loss)


@jax.jit
def train_step(params, opt_params, X, y):
    """Run one AdamW update on a batch.

    `model` and `opt_adamw` are closed over, so jit treats them as compile-time constants.

    Returns:
        (new_params, new_opt_params, loss) where loss is the pre-update batch MSE.
    """
    loss, grads = loss_grad_fn(params, model, X, y)
    updates, opt_params = opt_adamw.update(grads, opt_params, params)
    params = optax.apply_updates(params, updates)
    return params, opt_params, loss


for epoch in range(num_epochs):
    # Draw a fresh key every epoch: passing the same main_rng each time would resample
    # the exact same dataset, so the model would only ever see one batch.
    main_rng, data_rng = jax.random.split(main_rng)
    X, y, _ = sample_regression_dataset(data_rng, input_size, set_size=set_size, input_range=1.0)

    params, opt_params, loss = train_step(params, opt_params, X, y)

    if epoch % 100 == 0:
        print(epoch, loss)

In [None]:
# Inspect one example from the sampled regression data

In [11]:
indice = 4  # which batch element to visualise
X, y, _ = sample_regression_dataset(jax.random.key(42), input_size=1, input_range=1.0)
# A bare `y.shape, X.shape` mid-cell is never displayed; print the shapes explicitly.
print(f"X: {X.shape}, y: {y.shape}")

# With input_size=1, column 0 of each token is presumably the input and column 1 its target.
# The last token is the one the loss predicts on; its target entry is expected to be 0.0
# (the value the transformer must predict) — TODO confirm against sample_regression_dataset.
example = X[indice]
fig, ax = plt.subplots(figsize=(8, 6))
ax.scatter(example[:-1, 0], example[:-1, 1], label='Context points', marker='o')
ax.scatter(example[-1, 0], example[-1, 1], label='Query (last token)', marker='x', color='red')
ax.set_title('Regression Data Visualization')
ax.set_xlabel('Input (column 0)')
ax.set_ylabel('Target (column 1)')
ax.legend()
ax.grid(True)
plt.show()

NameError: name 'plt' is not defined