In [None]:
import os
from itertools import chain

import jax
import jax.numpy as jnp
import jax.random as random
import optax
from datasets import load_dataset
from flax import nnx
from flax.training import checkpoints
from flax.training import common_utils
from jax import random
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P
from tqdm import tqdm
from transformers import PreTrainedTokenizerFast

from src.confings import TransformerConfig, TrainerConfig
from src.model import TransformerLM
from src.steps import train_step, eval_step
from src.trainer import (
    create_learning_rate_schedule,
    setup_initial_state
)



In [11]:
def const(config, key):
    """Model constructor: build a ``TransformerLM`` for ``config``.

    ``key`` seeds the ``params`` stream of the ``nnx.Rngs`` handed to the model.
    """
    model_rngs = nnx.Rngs(params=key)
    return TransformerLM(config, rngs=model_rngs)

In [None]:
# Trainer settings: every field keeps its TrainerConfig default.
trainer_config = TrainerConfig()

# Model hyper-parameters for this run.
transformer_config = TransformerConfig(
    vocab_size=30_000,
    emb_dim=256,
    num_heads=4,
    num_layers=3,
    qkv_dim=256,
    mlp_dim=1024,
)

In [13]:
# A 1x1x1 device grid (a single device) whose three axes are named below.
mesh_devices = mesh_utils.create_device_mesh([1, 1, 1])
mesh = Mesh(
    mesh_devices,
    ('data', 'fsdp', 'tensor'),
)

In [14]:
# Learning-rate schedule parameterised by the trainer settings.
learning_rate_fn = create_learning_rate_schedule(
    learning_rate=trainer_config.learning_rate,
    warmup_steps=trainer_config.warmup_steps,
)

# AdamW optimiser driven by that schedule.
tx = optax.adamw(
    learning_rate_fn,
    b1=0.9,
    b2=0.98,
    eps=1e-9,
    weight_decay=trainer_config.weight_decay,
)

In [15]:
start_step = 0

# Seed once, then peel off one key for initialisation and one for inference.
rng = jax.random.PRNGKey(42)
rng, init_rng = jax.random.split(rng)
rng, inference_rng = jax.random.split(rng)

# The key left over after both splits is reused as the dropout key
# (``dropout_rngs`` is the very same object as ``rng``).
dropout_rngs = rng

In [16]:
# Build the initial training state together with its sharding description.
state, state_sharding = setup_initial_state(
    const,
    tx,
    transformer_config,
    init_rng,
    mesh,
)

# Input batches: the leading dimension is partitioned over the 'data' mesh axis.
data_sharding = NamedSharding(mesh, P(('data',)))

In [None]:
# Compiled training step. Positional arguments 2 and 3 are static (they are
# baked into the compiled function), and argument 0 is donated.
jit_train_step = jax.jit(
    train_step,
    in_shardings=(state_sharding, data_sharding, None),
    out_shardings=(state_sharding, None),
    static_argnums=(2, 3),
    donate_argnums=0,
)

# Compiled evaluation step. It is given shardings for the parameters and the
# batch only; positional arguments 2 and 3 are static here as well.
jit_eval_step = jax.jit(
    eval_step,
    in_shardings=(state_sharding.params, data_sharding),  # type: ignore
    out_shardings=None,  # type: ignore
    static_argnums=(2, 3),
)

In [None]:
# Load the pretrained BPE tokenizer from the Hugging Face Hub.
# The access token is read from the HF_TOKEN environment variable rather than
# being written into the notebook, so the credential never ends up in the saved
# file or in version control. When HF_TOKEN is unset, ``None`` is passed and the
# library falls back to its default credential lookup.
tokenizer_hf = PreTrainedTokenizerFast.from_pretrained(
    'ZurabDz/bpe_tokenizer_tmp',
    token=os.environ.get('HF_TOKEN'),
)

In [None]:
# Fetch the training split, using 6 worker processes.
dataset = load_dataset(
    'DKYoon/SlimPajama-6B',
    split='train',
    num_proc=6,
)

Resolving data files:   0%|          | 0/48 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/48 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/48 [00:00<?, ?it/s]

In [12]:
# Work on the first `size` rows only.
size = 500_000
small_dataset = dataset.select(range(size))

In [13]:
def tokenize_and_pad(batch):
    """Tokenize the ``'text'`` column of ``batch`` with ``tokenizer_hf``.

    Returns a dict with a single ``'ids'`` entry holding the tokenizer's
    ``'input_ids'``. Note that no padding options are passed to the tokenizer
    in this function, despite its name.
    """
    encoded = tokenizer_hf(batch['text'])
    return {'ids': encoded['input_ids']}

def group_texts(examples, block_size=128):
    """Concatenate every column of a batch and re-split it into fixed-size blocks.

    Args:
        examples: Mapping from column name to a list of sequences (for example
            lists of token ids), as handed over by a batched ``map`` call.
        block_size: Length of every emitted chunk; must be a positive integer.

    Returns:
        A dict with the same keys as ``examples``. Each value is a list of
        chunks containing exactly ``block_size`` items. The trailing remainder
        that does not fill a whole block is always dropped, so a batch holding
        fewer than ``block_size`` items yields empty lists instead of a single
        short chunk (short chunks cannot be stacked into a rectangular array).
        An empty ``examples`` mapping yields an empty dict.

    Raises:
        ValueError: If ``block_size`` is not positive.
    """
    if block_size <= 0:
        raise ValueError(f"block_size must be a positive integer, got {block_size}")

    # Flatten each column's list of sequences into one long sequence.
    concatenated_examples = {
        key: list(chain.from_iterable(examples[key])) for key in examples.keys()
    }
    if not concatenated_examples:
        return {}

    # The first column's length decides where every column is cut.
    first_column = next(iter(concatenated_examples.values()))
    # Round down to a whole number of blocks; this is 0 when there are fewer
    # than ``block_size`` items, so no ragged chunk is ever produced.
    usable_length = (len(first_column) // block_size) * block_size

    return {
        key: [
            tokens[start:start + block_size]
            for start in range(0, usable_length, block_size)
        ]
        for key, tokens in concatenated_examples.items()
    }

In [14]:
# Tokenize in batches across 8 processes, discarding all original columns so
# that only the output of `tokenize_and_pad` remains.
mapped_dataset = small_dataset.map(
    tokenize_and_pad,
    batched=True,
    num_proc=8,
    remove_columns=small_dataset.column_names,
)

In [15]:
# Re-chunk the tokenized rows with `group_texts`, batch-wise across 8 processes.
grouped = mapped_dataset.map(
    group_texts,
    batched=True,
    num_proc=8,
)

In [None]:
# Accumulator for per-step metrics; the training loop appends to it.
train_metrics = []

# Batch iterator over the grouped dataset, yielding `batch_size` rows at a time.
batch_size = 80
iter_ds = grouped.iter(batch_size)

In [None]:
# Main training loop: one optimisation step per full batch, for
# len(grouped) // batch_size steps starting at `start_step`.
# NOTE(review): `iter_ds` is created in an earlier cell and is consumed here, so
# re-running this cell after an interruption continues from wherever the
# iterator stopped while `step` restarts at `start_step` — confirm this is
# intended, or recreate the iterator before re-running.
for step in tqdm(range(start_step, len(grouped) // batch_size)):
    batch = next(iter_ds)
    # Turn the batch's 'ids' column into a JAX array for the compiled step.
    jaxed_batch = jnp.array(batch['ids'])
    # Arguments 2 and 3 (`learning_rate_fn`, 0.0) are declared static in the
    # jit wrapper. The 0.0 is presumably a label-smoothing value — TODO confirm
    # against src.steps.train_step.
    state, metrics = jit_train_step(
        state, jaxed_batch, learning_rate_fn, 0.0, dropout_rngs
    )
    train_metrics.append(metrics)

    # Periodic summary. This also fires at step 0 (0 % 500 == 0), where it
    # covers just the single step run so far.
    if step % 500 == 0:
        # Combine the list of per-step metric trees into one stacked tree.
        train_metrics = common_utils.stack_forest(train_metrics)
        lr = train_metrics.pop('learning_rate').mean()
        # Sum every remaining metric over the accumulated steps, then normalise
        # by the summed 'denominator' entry.
        metrics_sums = jax.tree.map(jnp.sum, train_metrics)
        denominator = metrics_sums.pop('denominator')
        summary = jax.tree.map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
        summary['learning_rate'] = lr
        # Perplexity is exp(loss), capped at 1e4.
        summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), max=1.0e4)
        summary = {'train_' + k: v for k, v in summary.items()}
        print(summary)
        # Start a fresh accumulation window.
        train_metrics = []

    # Periodic checkpoint. This also fires at step 0.
    if step % 2000 == 0:
        # Replace the stored key with its raw key data before saving —
        # presumably because the checkpoint writer cannot handle typed PRNG
        # keys; TODO confirm.
        # NOTE(review): this overwrites the value inside the live `state` that
        # is passed to the next training step, and only the path under
        # 'encoderdecoderblock_0' is converted — verify both are intended.
        rng_params = state.rest['decoder']['encoderdecoderblock_0']['attention']['rngs']['params']
        rng_params['key'].value = random.key_data(rng_params['key'].value)
        checkpoints.save_checkpoint_multiprocess(trainer_config.output_dir, state, step)

  0%|          | 48/51228 [00:16<4:58:00,  2.86it/s] 


KeyboardInterrupt: 