In [1]:
from src.model import TransformerLM
from src.trainer import (
    create_learning_rate_schedule,
    setup_initial_state
)
from src.steps import train_step, eval_step
import optax
import jax
from flax import nnx
from jax import random
from flax.training import common_utils
from tqdm.auto import tqdm
from src.confings import TransformerConfig, TrainerConfig
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P
from datasets import load_dataset
import jax.numpy as jnp
from transformers import PreTrainedTokenizerFast
import jax.random as random
from src.data import process_dataset
import orbax.checkpoint as ocp
import numpy as np

In [2]:
def const(config, key):
    return TransformerLM(config, rngs=nnx.Rngs(params=key))

In [3]:
trainer_config = TrainerConfig()
transformer_config = TransformerConfig(vocab_size=30_000, emb_dim=512, num_heads=8,
                                        num_layers=3, qkv_dim=512, mlp_dim=2048)

In [4]:
mesh_devices = mesh_utils.create_device_mesh([1, 1, 1])
mesh = Mesh(mesh_devices, ('data', 'fsdp', 'tensor'))

learning_rate_fn = create_learning_rate_schedule(
    learning_rate=trainer_config.learning_rate, warmup_steps=trainer_config.warmup_steps
)
tx = optax.adamw(learning_rate_fn, b1=0.9, b2=0.98, eps=1e-9,
                  weight_decay=trainer_config.weight_decay)

In [14]:
start_step = 1
rng = jax.random.PRNGKey(42)
rng, init_rng = jax.random.split(rng)
rng, inference_rng = random.split(rng)
dropout_rngs = rng

In [6]:
state, state_sharding = setup_initial_state(
    const, tx, transformer_config, init_rng, mesh
)

data_sharding = NamedSharding(mesh, P(('data',)))

In [7]:
jit_train_step = jax.jit(
    train_step,
    in_shardings=(state_sharding, data_sharding, None),
    out_shardings=(state_sharding, None),
    static_argnums=(2, 3),
    donate_argnums=0
)

jit_eval_step = jax.jit(
    eval_step,
    in_shardings=(
      state_sharding,
      data_sharding,
    ),  # type: ignore
    out_shardings=None,  # type: ignore
    static_argnums=(2,),
  )

In [8]:
tokenizer_hf = PreTrainedTokenizerFast.from_pretrained('ZurabDz/bpe_tokenizer_tmp', token='<>')
# dataset = load_dataset('DKYoon/SlimPajama-6B', num_proc=6, split='train')
dataset = load_dataset('roneneldan/TinyStories', num_proc=6)

In [9]:
train_dataset = process_dataset(dataset['train'], tokenizer_hf, 14)
eval_dataset = process_dataset(dataset['validation'], tokenizer_hf, 14)

In [10]:
train_iter_ds = train_dataset.iter(trainer_config.train_batch_size)
eval_iter_ds = eval_dataset.iter(trainer_config.eval_batch_size)

train_metrics = []
eval_metrics = []

In [16]:
checkpoint_dir = "/home/penguin/Desktop/microlm/output"

In [17]:
checkpoint_manager = ocp.CheckpointManager(
    ocp.test_utils.erase_and_create_empty(checkpoint_dir),
    options=ocp.CheckpointManagerOptions(
        max_to_keep=5,
        keep_checkpoints_without_metrics=False,
        create=True,
    ),
)


In [18]:
for step in tqdm(range(start_step, min(50_000, len(train_dataset) // trainer_config.train_batch_size))):
    batch = next(train_iter_ds)
    jaxed_batch = jnp.array(batch['ids'])
    state, metrics = jit_train_step(
        state, jaxed_batch, learning_rate_fn, 0.0, dropout_rngs
    )
    train_metrics.append(metrics)

    if step % 1000 == 0:
        train_metrics = common_utils.stack_forest(train_metrics)
        lr = train_metrics.pop('learning_rate').mean()
        metrics_sums = jax.tree.map(jnp.sum, train_metrics)
        denominator = metrics_sums.pop('denominator')
        summary = jax.tree.map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
        summary['learning_rate'] = lr
        summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), max=1.0e4)
        summary = {'train_' + k: v for k, v in summary.items()}

        train_metrics = []
        print("summary: ", summary)

        checkpoint_manager.save(
            step, args=ocp.args.Composite(state=ocp.args.PyTreeSave(state))
        )

        break

    # if (step + 1) % 100 == 0:
    #     eval_iter_ds = eval_dataset.select(range(0, 500)).iter(trainer_config.eval_batch_size)
    #     for batch in tqdm(eval_iter_ds, total=len(eval_dataset), leave=False):
    #         metrics = jit_eval_step(state, jnp.array(batch['ids']))
    #         eval_metrics.append(metrics)


    #     eval_metrics = common_utils.stack_forest(eval_metrics)
    #     eval_metrics_sums = jax.tree.map(jnp.sum, eval_metrics)
    #     eval_denominator = eval_metrics_sums.pop('denominator')
    #     eval_summary = jax.tree.map(
    #         lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
    #         eval_metrics_sums,
    #     )
    #     eval_summary['perplexity'] = jnp.clip(
    #         jnp.exp(eval_summary['loss']), max=1.0e4
    #       )
        
    #     print("eval_summary: ", {'eval_' + k: v for k, v in eval_summary.items()})
    #     eval_metrics = []

    # break
    
checkpoint_manager.close()


  0%|          | 0/49999 [00:00<?, ?it/s]

summary:  {'train_accuracy': Array(0.35568678, dtype=float32), 'train_loss': Array(3.3990588, dtype=float32), 'train_learning_rate': np.float32(0.0013243135), 'train_perplexity': Array(29.935911, dtype=float32)}




In [20]:
with ocp.CheckpointManager(
    checkpoint_dir, options=ocp.CheckpointManagerOptions(read_only=True)
) as read_mgr:
  restored = read_mgr.restore(
      1000,
      # pass in the model_state to restore the exact same State type
      args=ocp.args.Composite(state=ocp.args.PyTreeRestore(item=state))
  )



In [21]:
from typing_extensions import Protocol, runtime_checkable
from src import temperature_sampler
import importlib


@runtime_checkable
class HasCache(Protocol):
  def init_cache(self, input_shape, dtype = jnp.float32): ...

In [23]:
new_state = restored['state']

In [25]:
module = nnx.merge(new_state.graphdef, new_state.params, new_state.keys, new_state.rest)
inputs = jnp.array(tokenizer_hf(["Little cat"])['input_ids'])


In [26]:
importlib.reload(temperature_sampler)

<module 'src.temperature_sampler' from '/home/penguin/Desktop/microlm/src/temperature_sampler.py'>

In [27]:
# TODO(cgarciae): check how pytorch does this.
for _path, m in module.iter_modules():
    if isinstance(m, HasCache):
        input_shape = (inputs.shape[0], transformer_config.max_len, transformer_config.emb_dim)
        m.init_cache(input_shape, dtype=transformer_config.dtype)

graphdef, params, cache, keys, rest = nnx.split(module, nnx.Param, nnx.Cache, nnx.RngKey, ...)

def tokens_ids_to_logits(flat_ids, cache: nnx.State):
    """Token slice to logits from decoder model."""
    # --> [batch * beam, 1, vocab]
    module = nnx.merge(graphdef, params, cache, keys, rest)
    module.set_attributes(deterministic=True, decode=True)
    logits = module(flat_ids, nnx.Rngs(0))
    cache = nnx.state(module, nnx.Cache)
    # Remove singleton sequence-length dimension:
    # [batch, 1, vocab] --> [batch, vocab]
    logits = logits.squeeze(axis=1)
    return logits, cache

# Using the above-defined single-step decoder function, run a
# beam search over possible sequences given input encoding.
seqs = temperature_sampler.temperature_sample(
    jnp.pad(inputs, ((0, 0), (0, 64))),
    cache,
    tokens_ids_to_logits,
    rng,
)

In [28]:
tokenizer_hf.decode(seqs[0])

'Little cat ran around and smiled. The dog was sad.\n\n\n\nHe smiled and said to the rabbit was so excited to help the dog was happy to help and said, so excited. He said, "I want to have a big boy. He can play with the dog\'s mommy. He was proud of'