In [1]:
# %env JAX_CHECK_TRACER_LEAKS = 1
from models.transformer import Transformer, TransformerConfig
from models.mlp import MLP, MLPConfig
from flax import nnx
import optax
import jax.numpy as jnp
import jax
from functools import partial
from tqdm.auto import tqdm
from typing import Tuple, List, Dict, Any, Callable
import dataclasses



In [2]:

@dataclasses.dataclass(unsafe_hash=True)
class TrainConfig:
  lrs: jax.Array
  num_seeds: int
  criterion: Callable
  model: nnx.Module
  model_config: Any = False

  def replace(self, **kwargs):
    return dataclasses.replace(self, **kwargs)


def init_model(config: TrainConfig):
  rngs = nnx.Rngs(params=0, dropout=1)
  if not config.model_config:
    config.model_config = config.model.config()
  @nnx.jit
  @nnx.split_rngs(splits=(config.num_seeds,))
  @nnx.vmap(in_axes=(nnx.StateAxes({(nnx.Param,'params', 'dropout'): 0, ...: None}),None))
  @nnx.vmap(in_axes=(None,0,))
  def _init_model(rngs, lr):
      model = config.model(config.model_config, rngs)
      optimizer = nnx.Optimizer(model, optax.adam(lr), wrt=nnx.Param)

      metrics = nnx.MultiMetric(
        accuracy=nnx.metrics.Accuracy(threshold=None),
        loss=nnx.metrics.Average('loss'),
      )
      return model, nnx.state(optimizer), metrics
  return _init_model(rngs, config.lrs)

In [3]:


@partial(jax.jit, static_argnames=('criterion'))
@nnx.vmap(in_axes=(0,0,0, 0, 0, None, None)) # data seeds #in_axes=(state_axes, nnx.StateAxes({nnx.Variable:0, ...: None}), 0)
@nnx.vmap(in_axes=(0,0,0, None, 0,0, None)) # model params
def _train_step(graphdef, state, metric_split, data, optimizer_states, lr,criterion):
  model = nnx.merge(graphdef, state)
  metrics = nnx.merge(*metric_split)
  X, y = data
  def loss_fn(model):
    y_pred = model(X)[..., -1, :]
    return criterion(y_pred, y), y_pred
  (loss,logits), grads = nnx.value_and_grad(loss_fn, has_aux=True)(model)
  metrics.update(logits=logits, loss=loss, labels=y)
  temp = nnx.Optimizer(model, optax.adamw(lr), wrt=nnx.Param)
  optimizer = nnx.merge(nnx.graphdef(temp), optimizer_states)
  optimizer.update(model, grads)
  return model, nnx.state(optimizer), metrics

def train_step(models, data, optimizer_states, config, metrics):
  models.train()
  graphdef, state = nnx.split(models)
  models, optimizer_states, metrics = _train_step(graphdef, state, nnx.split(metrics)  , data, optimizer_states, config.lrs, config.criterion)
  # models = nnx.merge(graphdef, state)
  # metrics.update(logits=logits, loss=loss, labels=y)
  return models, optimizer_states, metrics

@partial(jax.jit, static_argnames=('criterion'))
@nnx.vmap(in_axes=(0,0, None, None)) # data seeds #in_axes=(state_axes, nnx.StateAxes({nnx.Variable:0, ...: None}), 0)
@nnx.vmap(in_axes=(0,0, None, None)) # model params
# @partial(jax.jit, static_argnames=('criterion', 'data'))
def _eval_step(model, metric, data, criterion):
  # model, metrics = nnx.merge(graphdef, state)
  model = nnx.merge(*model)
  metric = nnx.merge(*metric)
  X, y = data
  logits = model(X)[..., -1, :]
  loss = criterion(logits, y)
  metric.update(logits=logits, loss=loss, labels=y)
  # (logits=logits, loss=loss, labels=y)
  return metric

def eval_step(models, data, config, metrics):
  models.eval()
  # graphdef, state = 
  metrics = _eval_step(nnx.split(models), nnx.split(metrics), data, config.criterion)
  # metrics = nnx.merge(graphdef, state)
  return metrics

@nnx.jit
@nnx.vmap
@nnx.vmap
def reset(metrics):
    metrics.reset()

In [4]:
def create_batch(rng, n: int, d: int, k: int) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    Returns:
      X_train : (n, d)
      y_train : (n,)
    """
    # rng = rng()
    X   = jax.random.bernoulli(rng, p=0.5, shape=(n, d)).astype(jnp.int32)
    # X = ((jax.random.choice(rng, 2**d, (n,), replace=False)[:, None] >> jnp.arange(d, dtype=jnp.uint32)) & 1).astype(jnp.float16)
    # x = jax.random.choice(rng, 2**d, (n,), replace=False)
    # shifts = jnp.arange(d - 1, -1, -1)
    
    # Add a new axis to input for broadcasting
    # x_expanded = jnp.expand_dims(x, axis=-1)
    
    # Broadcast and perform bit operations in parallel
    # X =  (x_expanded >> shifts) & 1
    
    # y   = - X[:, 0] * X[:, 1]
    y = evaluate_parity(X, k)
    return X, y

def evaluate_parity(x: jnp.ndarray, k: int = 2) -> jnp.ndarray:
    # return jnp.prod(x[..., :k], axis=-1, dtype=jnp.float32)
    return jnp.sum(x[..., :k], axis=-1)% 2
    # return nnx.one_hot(jnp.sum(x[..., :k], axis=-1)% 2, 2, dtype=jnp.int32)
@partial(jax.jit, static_argnums=(1,2,3))
# @nnx.split_rngs(splits=config.num_seeds)
@nnx.vmap(in_axes=(0,None, None, None))
def create_batches(rng, n: int, d: int, k: int) -> Tuple[jnp.ndarray, jnp.ndarray]:
    return create_batch(rng(), n, d, k)



In [5]:

batch_size = 32
d = 20
k = 6


model_config = TransformerConfig(
    vocab_size=2,
    max_len=d,
    embd_dim=256,
    num_heads=8,
    mlp_dim=1024,
    qkv_dim=256,
    num_layers = 2
)
# model_config = MLPConfig(
#     in_dim=d,
#     out_dim=1,
#     hidden_dim=32,
#     hidden_layers=4
# )

config = TrainConfig(
    lrs = jnp.geomspace(1e-4,1e-1,2),
    num_seeds=20,
    # criterion = lambda y_pred, y: optax.squared_error(y_pred.squeeze(-1), y).mean(),
    criterion = lambda y_pred, y: optax.softmax_cross_entropy_with_integer_labels(y_pred, y).mean(),
    model = Transformer,
    model_config = model_config
)

In [6]:
models, optimizer_states, metrics = init_model(config)

In [None]:
data_rng = nnx.Rngs(0)
backup = nnx.split_rngs(data_rng, splits=config.num_seeds)
data = create_batches(data_rng, batch_size, d, k)
models, optimizer_states, metrics = train_step(models, data, optimizer_states, config, metrics)

In [8]:
from IPython.display import clear_output
import matplotlib.pyplot as plt
metrics_history = {
  'train_loss': [],
  'train_accuracy': [],
  'test_loss': [],
  'test_accuracy': [],
}
@nnx.jit
@nnx.vmap
def iterate_rngs(rngs):
  rngs()
num_steps = 2000
data_rng = nnx.Rngs(0)
backup = nnx.split_rngs(data_rng, splits=config.num_seeds)
iterate_rngs(data_rng)
# print(models.lm_head.kernel.value[1][0])
test_data = create_batch(nnx.Rngs(-2)(), min(100, 2**d), d,   k)
for step in tqdm(range(num_steps)):
    data = create_batches(data_rng, batch_size, d, k)
    iterate_rngs(data_rng)
    models, optimizer_states, metrics = train_step(models, data, optimizer_states, config, metrics)
    acc, loss = metrics.compute().values()
    metrics_history[f'train_loss'].append(loss.mean(axis=0).min()) # Record the metrics.
    metrics_history[f'train_accuracy'].append(acc.mean(axis=0).max()) # Record the metrics.
    # Reset the metrics for the test set.
    reset(metrics)
    metrics = eval_step(models, test_data, config, metrics)
    acc, loss = metrics.compute().values()
    metrics_history[f'test_loss'].append(loss.mean(axis=0).min()) # Record the metrics.
    metrics_history[f'test_accuracy'].append(jnp.mean(acc,axis=0).max()) # Record the metrics.
      # Reset the metrics for the test set.
    reset(metrics)
    # for metric, value in metrics.compute().items():  # Compute the metrics.
    #   metrics_history[f'test_{metric}'].append(value)  # Record the metrics.
    # metrics.reset()  # Reset the metrics for the test set.
    # if (step+1) % 30 == 0:
    #   clear_output(wait=True)
    #   fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
    #   ax1.set_title('Loss')
    #   ax2.set_title('Accuracy')
    #   for dataset in ('train', 'test'):
    #     ax1.plot(metrics_history[f'{dataset}_loss'], label=f'{dataset}_loss')
    #     ax2.plot(metrics_history[f'{dataset}_accuracy'], label=f'{dataset}_accuracy')
    #   ax1.legend()
    #   ax2.legend()
    #   plt.show()

  0%|          | 0/2000 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [183]:
import equinox as eqx
from jax import lax
from jax import numpy as jnp
import jax
from typing import Sequence


def test(temp):
    def f(_x, _y):
        return None,None
    return lax.scan(f, None, temp)


linear = lambda i: eqx.nn.Linear(1,1, key=jax.random.PRNGKey(0))
linears = jax.vmap(linear)(jnp.array([0]))
dynamic_layers, static_layers = eqx.partition(linears, eqx.is_array)
# print(dynamic_layers)
jax.vmap(test)(dynamic_layers)
# no crash

drop = eqx.nn.MultiheadAttention(1,1,1,1,1, key=jax.random.PRNGKey(0))
def drops(i):
    return drop
dropouts = jax.vmap(lambda i: drop)(jnp.array([0]))
def filt(x): # array and not eqx.nn.Dropout
    return eqx.is_array(x) and (not isinstance(x, bool))  and (not x.weak_type)
dynamic_layers, static_layers = eqx.partition(dropouts, filt)#, is_leaf=lambda x: isinstance(x, eqx.nn.Dropout))
print(dynamic_layers)

def check(n):
    return isinstance(n, eqx.nn.Dropout)
print(jax.vmap(check)(dropouts))
print(check(linear))
# print(isinstance(dropouts.inference, Sequence))
print(eqx.filter(dropouts, filt))
eqx.filter_vmap(test)(dynamic_layers)
# no crash


# def dropout(i):
#     temp = eqx.nn.Dropout(0.5)
#     return temp
# dropouts = jax.vmap(dropout)(jnp.array([0]))

# def part(x):
#     return eqx.partition(x, eqx.is_array)
# # dynamic_layers, static_layers = jax.vmap(part)(dropouts)
# # print(dynamic_layers)
# eqx.filter_vmap(test)(dynamic_layers)
# # crash

MultiheadAttention(
  query_proj=Linear(
    weight=f32[1,1,1], bias=None, in_features=1, out_features=1, use_bias=False
  ),
  key_proj=Linear(
    weight=f32[1,1,1], bias=None, in_features=1, out_features=1, use_bias=False
  ),
  value_proj=Linear(
    weight=f32[1,1,1], bias=None, in_features=1, out_features=1, use_bias=False
  ),
  output_proj=Linear(
    weight=f32[1,1,1], bias=None, in_features=1, out_features=1, use_bias=False
  ),
  dropout=Dropout(p=None, inference=bool[1]),
  num_heads=1,
  query_size=1,
  key_size=1,
  value_size=1,
  output_size=1,
  qk_size=1,
  vo_size=1,
  use_query_bias=False,
  use_key_bias=False,
  use_value_bias=False,
  use_output_bias=False
)
[False]
False
MultiheadAttention(
  query_proj=Linear(
    weight=f32[1,1,1], bias=None, in_features=1, out_features=1, use_bias=False
  ),
  key_proj=Linear(
    weight=f32[1,1,1], bias=None, in_features=1, out_features=1, use_bias=False
  ),
  value_proj=Linear(
    weight=f32[1,1,1], bias=None, in_features=

IndexError: tuple index out of range

In [None]:
@nnx.vmap(in_axes=(0,None))
@nnx.vmap(in_axes=(0, None))
@nnx.jit
def test(model, data):
    return model(data)
print(data[0].shape)
data = create_batch(nnx.Rngs(-3)(), 10, d, k)
print(test(models, data[0]).shape)
print(jnp.argmax(test(models, data[0])[0,0,:,-1], axis=-1))
print(data[1])

(20, 32, 20)
(20, 2, 10, 20, 2)
[0 0 0 0 0 0 0 0 0 0]
[1 0 1 1 1 1 0 0 0 0]


In [20]:

def loss_fn(model, X,y):
  y_pred = model(X)[..., -1, :]
  return config.criterion(y_pred, y), y_pred
@jax.jit
@nnx.vmap(in_axes=(0,0,0, 0, 0, None)) # data seeds #in_axes=(state_axes, nnx.StateAxes({nnx.Variable:0, ...: None}), 0)
@nnx.vmap(in_axes=(0,0,0, None, 0,0,)) # model params
def _train_step2(graphdef, state, metric_split, data, optimizer_states, lr):
  model = nnx.merge(graphdef, state)
  # metrics = nnx.merge(*metric_split)
  X, y = data
  (loss,logits), grads = nnx.value_and_grad(loss_fn, has_aux=True)(model, X,y)
  # metrics.update(logits=logits, loss=loss, labels=y)
  temp = nnx.Optimizer(model, optax.adamw(lr), wrt=nnx.Param)
  optimizer = nnx.merge(nnx.graphdef(temp), optimizer_states)
  optimizer.update(model, grads)
  return nnx.state(model), nnx.state(optimizer), None#nnx.state(metrics)
metrics_history = {
  'train_loss': [],
  'train_accuracy': [],
  'test_loss': [],
  'test_accuracy': [],
}
num_steps = 2000
data_rng = nnx.Rngs(0)
# backup = nnx.split_rngs(data_rng, splits=config.num_seeds)
# print(models.lm_head.kernel.value[1][0])
test_data = create_batch(nnx.Rngs(-2), min(1000, 2**d), d,   k)
graphdef, state = nnx.split(models)
metrics_graphdef, metrics_state = nnx.split(metrics)
for step in tqdm(range(num_steps)):
    data = create_batches(data_rng, batch_size, d, k)
    # state, optimizer_states, metrics_state = _train_step2(graphdef, state, None, data, optimizer_states, config.lrs)
    
    # acc, loss = metrics.compute().values()
    # metrics_history[f'train_loss'].append(loss.mean(axis=0).min()) # Record the metrics.
    # metrics_history[f'train_accuracy'].append(acc.mean(axis=0).max()) # Record the metrics.
    # # Reset the metrics for the test set.
    # reset(metrics)
    # metrics = eval_step(models, test_data, config, metrics)
    # acc, loss = metrics.compute().values()
    # metrics_history[f'test_loss'].append(loss.mean(axis=0).min()) # Record the metrics.
    # metrics_history[f'test_accuracy'].append(jnp.mean(acc,axis=0).max()) # Record the metrics.
    #   # Reset the metrics for the test set.
    # reset(metrics)
    # for metric, value in metrics.compute().items():  # Compute the metrics.
    #   metrics_history[f'test_{metric}'].append(value)  # Record the metrics.
    # metrics.reset()  # Reset the metrics for the test set.
    # if (step+1) % 30 == 0:
    #   clear_output(wait=True)
    #   fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
    #   ax1.set_title('Loss')
    #   ax2.set_title('Accuracy')
    #   for dataset in ('train'):
    #     ax1.plot(metrics_history[f'{dataset}_loss'], label=f'{dataset}_loss')
    #     ax2.plot(metrics_history[f'{dataset}_accuracy'], label=f'{dataset}_accuracy')
    #   ax1.legend()
    #   ax2.legend()
    #   plt.show()

  0%|          | 0/2000 [00:00<?, ?it/s]

In [None]:
metrics.update(train_loss=train_loss, )

TypeError: Accuracy.update() missing 2 required keyword-only arguments: 'logits' and 'labels'

In [57]:
rngs = nnx.Rngs(params=0, dropout=1)
temp = nnx.split_rngs(rngs, splits=2, only='params')
print(rngs.params.key)
nnx.restore_rngs(rngs)

temp = nnx.split_rngs(rngs, splits=2, only='params')
print(rngs.params.key)
# print(temp[1].params)


[38;2;79;201;177mRngKey[0m[38;2;255;213;3m([0m[38;2;105;105;105m # 1 (8 B)[0m
  [38;2;156;220;254mvalue[0m[38;2;212;212;212m=[0mArray((), dtype=key<fry>) overlaying:
  [0 0],
  [38;2;156;220;254mtag[0m[38;2;212;212;212m=[0m[38;2;207;144;120m'params'[0m
[38;2;255;213;3m)[0m
[38;2;79;201;177mRngKey[0m[38;2;255;213;3m([0m[38;2;105;105;105m # 2 (16 B)[0m
  [38;2;156;220;254mvalue[0m[38;2;212;212;212m=[0mArray((2,), dtype=key<fry>) overlaying:
  [[4165894930  804218099]
   [1353695780 2116000888]],
  [38;2;156;220;254mtag[0m[38;2;212;212;212m=[0m[38;2;207;144;120m'params'[0m
[38;2;255;213;3m)[0m
[38;2;79;201;177mRngKey[0m[38;2;255;213;3m([0m[38;2;105;105;105m # 1 (8 B)[0m
  [38;2;156;220;254mvalue[0m[38;2;212;212;212m=[0mArray((), dtype=key<fry>) overlaying:
  [0 0],
  [38;2;156;220;254mtag[0m[38;2;212;212;212m=[0m[38;2;207;144;120m'params'[0m
[38;2;255;213;3m)[0m
[38;2;79;201;177mRngKey[0m[38;2;255;213;3m([0m[38;2;105;105;105m # 2 