In [1]:
import jax
import treescope
from IPython.display import display as ipython_display
from penzai import pz
from penzai.models import simple_mlp
from penzai.models.transformer.variants import llamalike_common as llamalike_transformer
from simplexity.predictive_models.save_activations import SaveActivations

In [2]:
# Configure treescope for interactive notebook use, with array values
# visualized automatically when models/arrays are displayed.
treescope.basic_interactive_setup(
    autovisualize_arrays=True,
)

In [3]:
# Small MLP whose layer widths, from input to output, are 8 -> 32 -> 32 -> 8.
# The first entry must match the "features" size of any input fed to it.
mlp_feature_sizes = [8, 32, 32, 8]
mlp = simple_mlp.MLP.from_config(
    name="mlp",
    init_base_rng=jax.random.key(0),
    feature_sizes=mlp_feature_sizes,
)
ipython_display(mlp)

In [4]:
# State variable handed to SaveActivations, initialised to an empty list.
# NOTE(review): presumably SaveActivations appends each recorded value to this
# list — confirm against its implementation.
mlp_activations = pz.StateVariable(value=[], label="activations")

# Build an instrumented copy of the MLP with a SaveActivations layer inserted
# immediately after every pz.nn.Elementwise layer.
saving_model = (
    pz.select(mlp)
    .at_instances_of(pz.nn.Elementwise)
    .insert_after(SaveActivations(mlp_activations))
)
ipython_display(saving_model)

In [5]:
# Run the instrumented MLP on an all-ones input and show the saved activations.
# The "features" size (8) must match the first entry of the MLP's feature_sizes.
#
# Clear the state variable first so that re-running this cell shows only the
# activations from this run.
# NOTE(review): assumes SaveActivations appends to the list held in the state
# variable — confirm against its implementation.
mlp_activations.value = []
mlp_input = pz.nx.ones({"features": 8})
mlp_output = saving_model(mlp_input)
ipython_display(mlp_activations)

In [6]:
# A deliberately tiny Llama-like transformer.
vocab_size = 8
config = llamalike_transformer.LlamalikeTransformerConfig(
    # Vocabulary and embedding table (shared with the output logits layer).
    vocab_size=vocab_size,
    embedding_dim=16,
    tie_embedder_and_logits=True,
    # Attention head layout and per-head projection size.
    num_kv_heads=2,
    query_head_multiplier=2,
    projection_dim=16,
    # Feed-forward sublayers.
    mlp_hidden_dim=16,
    mlp_variant="geglu_approx",
    # Depth.
    num_decoder_blocks=2,
)
transformer = llamalike_transformer.build_llamalike_transformer(
    config,
    name="llama_like",
    init_base_rng=jax.random.key(0),
)
ipython_display(transformer)

In [7]:
# State variable handed to SaveActivations, initialised to an empty list.
transformer_activations = pz.StateVariable(value=[], label="transformer_activations")

# Select every pz.nn.Residual in the transformer and insert a SaveActivations
# layer right after each one, producing an instrumented copy of the model.
residual_selection = pz.select(transformer).at_instances_of(pz.nn.Residual)
saving_transformer = residual_selection.insert_after(SaveActivations(transformer_activations))
ipython_display(saving_transformer)

In [8]:
# Random batch of integer token sequences to feed the transformer.
# Use the typed-key API (`jax.random.key`) for consistency with the model
# initialisation cells; `jax.random.PRNGKey` is the legacy raw-key form.
key = jax.random.key(0)
batch_size = 4
sequence_length = 16

# Token ids drawn uniformly from the half-open range [0, vocab_size).
sequences = jax.random.randint(key, (batch_size, sequence_length), 0, vocab_size)

# Name the axes so penzai layers can address them; note this rebinds `x`.
x = pz.nx.wrap(sequences, "batch", "seq")
assert dict(x.named_shape) == {"batch": batch_size, "seq": sequence_length}
ipython_display(x)

In [9]:
# One-hot encode the token ids along a new "features" axis of size vocab_size.
# NOTE(review): `onehot_sequences` is not consumed by any later visible cell
# (the transformer below is called on the integer tokens `x`); confirm this
# cell is still needed.
onehot_sequences = pz.nx.wrap(
    jax.nn.one_hot(sequences, num_classes=vocab_size),
    "batch",
    "seq",
    "features",
)
ipython_display(onehot_sequences)

In [10]:
# Run the instrumented transformer on the token batch and show what was saved.
#
# Clear the state variable first so that re-running this cell shows only the
# activations from this run.
# NOTE(review): assumes SaveActivations appends to the list held in the state
# variable — confirm against its implementation.
transformer_activations.value = []
output = saving_transformer(x)
ipython_display(transformer_activations)