In [1]:
import jax
import jax.numpy as jnp
import optax
import treescope
from IPython.display import display as ipython_display
from penzai.core.named_axes import NamedArray
from penzai.models.transformer.model_parts import TransformerLM

from simplexity.generative_processes.hidden_markov_model import HiddenMarkovModel
from simplexity.generative_processes.state_sampler import StateSampler
from simplexity.load_objects import load_config, load_objects


In [2]:
# Configure treescope for interactive use in this notebook, with automatic
# array visualization switched on.
treescope.basic_interactive_setup(autovisualize_arrays=True)

In [3]:
from typing import Any

from penzai import pz


@pz.pytree_dataclass
class SaveActivations(pz.nn.Layer):
    """Pass-through layer that records every value it is called with.

    Each call stores a new list in ``saved_activations`` consisting of the
    previous entries followed by the incoming value; the input itself is
    returned untouched.
    """

    # State variable holding the recorded values, in call order.
    saved_activations: pz.StateVariable[list[Any]]

    def __call__(self, activations: Any, **unused_side_inputs) -> Any:
        """Record ``activations`` and hand it back unchanged.

        Args:
            activations: The value flowing through this layer.
            **unused_side_inputs: Accepted and ignored.

        Returns:
            The same ``activations`` object that was passed in.
        """
        recorded_so_far = self.saved_activations.value
        # Rebind to a fresh list rather than mutating the stored one in place.
        self.saved_activations.value = [*recorded_so_far, activations]
        return activations

In [4]:
from simplexity.persistence.s3_persister import S3Persister
from simplexity.predictive_models.types import ModelFramework

# Persister used further down to load model weights for a given step.
persister = S3Persister.from_config("/workspaces/simplexity/config.ini", ModelFramework.Penzai)

# Instantiate the objects described by load.yaml and pull out the ones used here.
cfg = load_config("/workspaces/simplexity/simplexity/configs", "load.yaml")
d = load_objects(cfg)

generative_process: HiddenMarkovModel = d["generative_process"]
state_sampler: StateSampler = d["state_sampler"]
unbound_model: TransformerLM = d["model"]

# Insert a SaveActivations layer after every pz.nn.Residual instance selected
# in `unbound_model`; all inserted layers share the one `activations` variable.
# NOTE(review): in the visible cells `saving_model` is only displayed, while the
# forward pass runs on weights loaded into `unbound_model` -- confirm that
# `activations` is meant to stay empty.
activations = pz.StateVariable(value=[], label="activations")
saving_model = (
    pz.select(unbound_model)
    .at_instances_of(pz.nn.Residual)
    .insert_after(SaveActivations(activations))
)
ipython_display(saving_model)

In [5]:
import equinox as eqx

from simplexity.training.train_penzai_model import generate_data_batch

# Size of the evaluation batch.
batch_size = 2048
sequence_len = 100

# Sampler mapped over a batch of PRNG keys and JIT-compiled.
sample_states = eqx.filter_jit(eqx.filter_vmap(state_sampler.sample))

# Fixed seed: one sub-key is fanned out to sample the initial states, the
# remaining key is passed on to data generation.
key = jax.random.PRNGKey(0)
state_key, key = jax.random.split(key)
state_keys = jax.random.split(state_key, batch_size)
states = sample_states(state_keys)

_, inputs, labels = generate_data_batch(
    states,
    generative_process,
    batch_size=batch_size,
    sequence_len=sequence_len,
    key=key,
    bos_token=5,
)

# Boolean mask marking input positions whose token is 3 or 4.
name = (inputs == 3) | (inputs == 4)
named_name = pz.nx.wrap(name, "batch", "seq")
ipython_display(named_name)


In [6]:
# Accumulator for per-step results; a later cell stores `mean_losses` under the
# current `step`. Re-running this cell resets it to empty.
results = {}

In [7]:
# Checkpoint step to evaluate; also used later as the key into `results`.
step = 10
# Ask the persister for the weights saved at `step`, using `unbound_model` as the model argument.
model: TransformerLM = persister.load_weights(unbound_model, step)  # type: ignore

# Label the input axes so the model can be called on a named array.
named_inputs = pz.nx.wrap(inputs, "batch", "seq")
named_logits = model(named_inputs)
assert isinstance(named_logits, NamedArray)
# Back to a positional array with axes ordered (batch, seq, vocabulary).
logits = named_logits.unwrap("batch", "seq", "vocabulary")
# Cross-entropy between the logits and the integer labels.
losses = optax.softmax_cross_entropy_with_integer_labels(logits, labels)
named_losses = pz.nx.wrap(losses, "batch", "seq")

# ipython_display(named_losses)



In [None]:
# Copy the mask and the losses to host memory once. Indexing JAX arrays element
# by element inside the Python loops below would otherwise issue a separate
# device operation (and sync, via .item()) for every (b, s) pair.
name_host, losses_host = jax.device_get((name, losses))

# name_losses[b] holds, in order of appearance, one loss per name token found
# in sequence b.
name_losses = [[] for _ in range(batch_size)]
for b in range(batch_size):
    # NOTE(review): the loop stops at sequence_len - 2, so the last input
    # position is never inspected -- confirm this is intended.
    for s in range(sequence_len - 1):
        if name_host[b, s]:
            # A name at input position s is paired with the loss at s - 1
            # (presumably because inputs are the labels shifted right by one --
            # confirm). Position 0 has no preceding loss, so it gets the
            # sentinel -1, which is filtered out of bucket 0 below.
            loss = losses_host[b, s - 1].item() if s > 0 else -1
            name_losses[b].append(loss)

# Largest number of name tokens found in any single sequence (0 if none).
max_len = max(map(len, name_losses), default=0)

# Regroup by occurrence index: bucket i collects the loss of the (i+1)-th name
# token of every sequence that has at least i + 1 of them.
# (Variable name kept as-is in case other cells refer to it.)
name_pos_posses = [[] for _ in range(max_len)]
for seq in name_losses:
    for i, loss in enumerate(seq):
        name_pos_posses[i].append(loss)
# The sentinel can only come from position 0, i.e. a sequence's first name
# token, so it can only ever land in bucket 0.
if name_pos_posses:
    name_pos_posses[0] = [v for v in name_pos_posses[0] if v != -1]

# Mean loss per occurrence index; a bucket left empty yields NaN.
mean_losses = jnp.array([jnp.mean(jnp.array(v)) for v in name_pos_posses])
ipython_display(mean_losses)


In [None]:
import matplotlib.pyplot as plt

# Bar chart of the mean loss per occurrence index, with a log-scaled y axis.
seen_counts = range(len(mean_losses))
plt.bar(seen_counts, mean_losses)
plt.yscale("log")
# One tick per bar, labelled with its index.
plt.xticks(seen_counts, [str(count) for count in seen_counts])
plt.xlabel("Times seen name previously")
plt.ylabel("Loss predicting name")
plt.show()


In [None]:
# Store this step's mean losses (overwriting any earlier entry for the same
# step), then display everything collected so far.
results[step] = mean_losses
results