In [19]:
import sys
import time

import jax
import numpy as np
import jax.numpy as jnp
import orbax
from optax import MaskedNode
from etils import epath

from praxis import base_hyperparams
from praxis import pax_fiddle
from praxis import py_utils
from paxml import checkpoints  # mapped to internal
from paxml import checkpoint_managers
from paxml import train_states
from paxml import trainer_lib
from flax.traverse_util import flatten_dict, unflatten_dict

import os

# Make the paxml source tree importable so `get_experiment` can resolve
# 'tasks.lm.params...' experiment names. The default is a user-specific
# absolute path; override it with the PAXML_SRC_DIR environment variable.
PAXML_SRC_DIR = os.environ.get('PAXML_SRC_DIR', '/home/lishengping/projects/paxml/paxml')
if PAXML_SRC_DIR not in sys.path:  # keep the cell idempotent when re-run
    sys.path.append(PAXML_SRC_DIR)

from paxml.main import get_experiment

# Number of local devices this host is expected to expose when
# jax.distributed.initialize() does not succeed (e.g. re-running this cell
# raises "distributed.initialize should only be called once.").
EXPECTED_LOCAL_DEVICES = 8

try:
    jax.distributed.initialize()
except Exception as error:
    # Best effort: report the failure, then make sure the local devices are
    # still what the rest of the notebook assumes.
    print(f'Error: {error}')
    assert jax.local_device_count() == EXPECTED_LOCAL_DEVICES, (
        f'Expected {EXPECTED_LOCAL_DEVICES} local devices, '
        f'found {jax.local_device_count()}.'
    )
    

# Short aliases for the Pax helpers used throughout the notebook.
NestedMap = py_utils.NestedMap
TrainState = train_states.TrainState
instantiate = base_hyperparams.instantiate
Checkpointer, PaxCheckpointHandler, CheckpointType = (
    checkpoints.Checkpointer,
    checkpoints.PaxCheckpointHandler,
    checkpoints.CheckpointType,
)

# Build the experiment definition, its task hyper-params, and the instantiated task.
experiment_config = get_experiment('tasks.lm.params.c4.C4SpmdGpt37BRoPE')()
task_p = experiment_config.task()
jax_task = instantiate(task_p)

# Output location and format of the converted checkpoint.
job_log_dir = epath.Path('gs://llm_projects/log/lspdebug0804/checkpoints')
# job_log_dir = epath.Path('gs://llm_base_models/baichuan-7B-easylm')
checkpoint_type = CheckpointType.GDA

# Pax checkpointer with restore-time shape enforcement and OCDBT both disabled.
checkpointer = Checkpointer(
    PaxCheckpointHandler(enforce_restore_shape_check=False, use_ocdbt=False)
)

# Keep up to 10 checkpoints; with an interval of 1 every step is eligible for saving.
SAVE_INTERVAL_STEPS = 1
options = checkpoint_managers.CheckpointManagerOptions(
    max_to_keep=10,
    save_interval_steps=SAVE_INTERVAL_STEPS,
    cleanup_tmp_directories=True,
)

checkpoint_manager = checkpoint_managers.OrbaxCheckpointManager(
    job_log_dir,
    checkpointer,
    train_input_checkpointer=False,
    options=options,
    checkpoint_type=checkpoint_type,
    tensorstore_use_ocdbt=False,
)

Error: distributed.initialize should only be called once.


In [4]:
start = time.time()
print(f'Start load pretrained model params....')
# Source ("gold") checkpoint to convert. Previously used alternative:
# epath.Path('gs://llm_base_models/baichuan-7B-easylm')
gold_mngr_dir = epath.Path('gs://llm_base_models/orbax_async_test')
# Only the 'params' item is restored; optimizer state and step are not needed
# for the conversion.
gold_item = {
    'params': orbax.checkpoint.AsyncCheckpointer(orbax.checkpoint.PyTreeCheckpointHandler()),
}
gold_mngr = orbax.checkpoint.CheckpointManager(gold_mngr_dir, gold_item)

gold_step = gold_mngr.latest_step()
if gold_step is None:
    # Fail with a clear message instead of passing None on to restore().
    raise FileNotFoundError(f'No checkpoint step found under {gold_mngr_dir}')

# Restore under a CPU default-device context.
# NOTE(review): confirm orbax places restored arrays on that device.
with jax.default_device(jax.devices("cpu")[0]):
    gold_w = gold_mngr.restore(gold_step)

print(f'Load pretrained model params finished, take time: {time.time() - start}s.')

Start load pretrained model params....


I0000 00:00:1691307112.818550   65699 gcs_resource.cc:97] Using default AdmissionQueue with limit 32
I0000 00:00:1691307112.822667   75620 google_auth_provider.cc:179] Running on GCE, using service account 97048824446-compute@developer.gserviceaccount.com


Load pretrained model params finished, take time: 6.7192864418029785s.


In [5]:
# Target paxml parameter path (flattened tuple key) -> path component that
# identifies the matching tensor(s) in the flattened source checkpoint keys.
paxml_to_mesh_format = {
        ('params', 'lm', 'embedding_lookup', 'emb_var'): 'wte',
        ('params', 'lm', 'transformer', 'repeat', 'sub', 'x_layers_0', 'ff_layer', 'ffn_layer1', 'linear', 'w'): 'w3',
        ('params', 'lm', 'transformer', 'repeat', 'sub', 'x_layers_0', 'ff_layer', 'ffn_layer1_gate', 'linear', 'w'): 'w1',
        ('params', 'lm', 'transformer', 'repeat', 'sub', 'x_layers_0', 'ff_layer', 'ffn_layer2', 'linear', 'w'): 'w2',
        ('params', 'lm', 'transformer', 'repeat', 'sub', 'x_layers_0', 'self_attention', 'query', 'w'): 'wq',
        ('params', 'lm', 'transformer', 'repeat', 'sub', 'x_layers_0', 'self_attention', 'key', 'w'): 'wk',
        ('params', 'lm', 'transformer', 'repeat', 'sub', 'x_layers_0', 'self_attention', 'value', 'w'): 'wv',
        ('params', 'lm', 'transformer', 'repeat', 'sub', 'x_layers_0', 'self_attention', 'post', 'w'): 'wo',
        ('params', 'lm', 'transformer', 'repeat', 'sub', 'x_layers_0', 'layer_norm', 'scale'): 'attention_norm',
        ('params', 'lm', 'transformer', 'repeat', 'sub', 'x_layers_0', 'ff_layer', 'layer_norm', 'scale'): 'ffn_norm',
        ('params', 'lm', 'final_ln', 'scale'): 'ln_f',
        ('params', 'lm', 'softmax', 'logits_ffn', 'linear', 'w'): 'lm_head',
    }

num_heads = experiment_config.NUM_HEADS
model_dims = experiment_config.MODEL_DIMS
head_dim = model_dims // num_heads

# Source names whose matrices are reshaped to (model_dims, num_heads, head_dim).
# An exact tuple replaces a substring test against 'wqwkwvwo', which would also
# accept unrelated fragments such as 'w' or 'kw'.
ATTENTION_PROJ_NAMES = ('wq', 'wk', 'wv', 'wo')


def _natural_sort_key(key_path):
    """Sort key for a flattened param path that orders digit-only parts numerically.

    Presumably the source keys carry the layer index as a string component
    ('0', '1', ...; confirm against the source checkpoint). A plain tuple sort
    would put '10' before '2' and scramble the stacked layer order for models
    with more than 10 layers.
    """
    return tuple(
        (0, int(part), '') if str(part).isdigit() else (1, 0, str(part))
        for part in key_path
    )


# Flatten the source params once instead of once per target key.
flat_gold_params = flatten_dict(gold_w['params'])

trans_result = {}
with jax.default_device(jax.devices("cpu")[0]):
    for k, v in paxml_to_mesh_format.items():
        values = []
        for gold_key, gold_value in flat_gold_params.items():
            # Tuple membership: `v` must equal one whole path component.
            if v in gold_key:
                if v in ATTENTION_PROJ_NAMES:
                    # NOTE(review): for 'wo' (paxml 'post') this is a plain reshape. If the
                    # source stores it as (num_heads * head_dim, model_dims), it needs a
                    # transpose first. The matrix is square here, so the reshape cannot
                    # detect a wrong layout. Confirm against the source format.
                    gold_value = gold_value.reshape(model_dims, num_heads, head_dim)
                values.append((gold_key, gold_value))
        if not values:
            raise KeyError(f'No source parameter matches {v!r} for target {k}.')
        values.sort(key=lambda item: _natural_sort_key(item[0]))
        if 'repeat' in k:
            # Repeated-layer params always get a leading layer axis, even when
            # the model has a single layer.
            stack_values = np.stack([value for _, value in values])
        elif len(values) == 1:
            stack_values = values[0][1]
        else:
            raise ValueError(
                f'Expected exactly one source parameter for {k}, got {len(values)}: '
                f'{[key for key, _ in values]}'
            )
        trans_result[k] = stack_values
    # Zero arrays with the same shapes/dtypes as the converted params.
    opt_state_mv = jax.tree_util.tree_map(lambda x: jnp.zeros_like(x), trans_result)

print('Please simple check model shape and dtype...')
for k, v in trans_result.items():
    print(k, v.shape, v.dtype)


Please simple check model shape and dtype...
('params', 'lm', 'embedding_lookup', 'emb_var') (32000, 4096) float32
('params', 'lm', 'transformer', 'repeat', 'sub', 'x_layers_0', 'ff_layer', 'ffn_layer1', 'linear', 'w') (2, 4096, 11008) float32
('params', 'lm', 'transformer', 'repeat', 'sub', 'x_layers_0', 'ff_layer', 'ffn_layer1_gate', 'linear', 'w') (2, 4096, 11008) float32
('params', 'lm', 'transformer', 'repeat', 'sub', 'x_layers_0', 'ff_layer', 'ffn_layer2', 'linear', 'w') (2, 11008, 4096) float32
('params', 'lm', 'transformer', 'repeat', 'sub', 'x_layers_0', 'self_attention', 'query', 'w') (2, 4096, 32, 128) float32
('params', 'lm', 'transformer', 'repeat', 'sub', 'x_layers_0', 'self_attention', 'key', 'w') (2, 4096, 32, 128) float32
('params', 'lm', 'transformer', 'repeat', 'sub', 'x_layers_0', 'self_attention', 'value', 'w') (2, 4096, 32, 128) float32
('params', 'lm', 'transformer', 'repeat', 'sub', 'x_layers_0', 'self_attention', 'post', 'w') (2, 4096, 32, 128) float32
('params

In [13]:
latest_step = checkpoint_manager.latest_step()
# Save one interval after the newest existing checkpoint, or at the first
# interval when the directory has none.
step = latest_step + SAVE_INTERVAL_STEPS if latest_step is not None else SAVE_INTERVAL_STEPS
print(f'Model save step is {step}')
n_layers = experiment_config.NUM_LAYERS  # number of model layers
check_saved_model_fail_or_success = True
start = time.time()

# The per-layer optimizer counts below are sized by n_layers, so every stacked
# ('repeat') tensor must have exactly n_layers entries along its leading axis.
for key_tuple, param in opt_state_mv.items():
    if 'repeat' in key_tuple and param.shape[0] != n_layers:
        raise ValueError(
            f'{key_tuple} has leading dim {param.shape[0]}, expected NUM_LAYERS={n_layers}.'
        )

# Split the zero moments into two optimizer-state groups. Non-repeated params
# go under 'no_prefix' and stacked params go under the 'p#{n_layers}#i-1'
# prefix. Each group holds MaskedNode() where the other group owns the leaf.
temp_no_prefix, temp_other = {}, {}
for key_tuple, param in opt_state_mv.items():
    is_repeated = 'repeat' in key_tuple
    temp_no_prefix[key_tuple] = MaskedNode() if is_repeated else param
    temp_other[key_tuple] = param if is_repeated else MaskedNode()

temp_no_prefix = unflatten_dict(temp_no_prefix)
temp_other = unflatten_dict(temp_other)

scalar_count = jnp.array(step)
layer_counts = jnp.array([step] * n_layers)
no_prefix = {'count': scalar_count, 'm': temp_no_prefix, 'v': temp_no_prefix}
other = {'count': layer_counts, 'm': temp_other, 'v': temp_other}
# NOTE(review): the 4-entry list (only entry 2 carries m/v) is assumed to mirror
# the target experiment's optimizer chain. Confirm it matches the learner config.
trans_opt_states = [{
    'no_prefix': [{'count': scalar_count}] * 2 + [no_prefix, {'count': scalar_count}],
    f'p#{n_layers}#i-1': [{'count': layer_counts}] * 2 + [other, {'count': layer_counts}],
}]

new_trainstate = TrainState(
    step=jnp.array(step),
    mdl_vars=unflatten_dict(trans_result),
    opt_states=trans_opt_states,
)
# Shape/dtype description of every array leaf; non-array leaves pass through unchanged.
padded_global_shapes = jax.tree_util.tree_map(
    lambda x: jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype) if hasattr(x, 'shape') else x,
    new_trainstate,
)
checkpoint_manager.save(step, new_trainstate, padded_global_shapes, train_input_pipeline=None, force=False)
print(f'Saved model finished. take time: {time.time() - start}s !!!')

if check_saved_model_fail_or_success:
    # Re-open the checkpoint just written, using the task's own train-state
    # metadata. NOTE: this only checks that the restore succeeds with the
    # expected structure/shapes; it does not compare restored values with
    # `trans_result`.
    start = time.time()
    print(f'Args check_saved_model_fail_or_success is {check_saved_model_fail_or_success}, start to check model whether saved successful...')
    # Fake inputs, only used to obtain dtypes and shapes (values are irrelevant).
    seed = 0
    rng = np.random.default_rng(seed)
    low, high = 0, experiment_config.VOCAB_SIZE
    seq_length = 10
    batch_size = 1
    my_sample_input = {}
    my_sample_input['ids'] = rng.integers(low, high, (batch_size, seq_length)).astype(np.int32)
    my_sample_input['labels'] = my_sample_input['ids'].astype(np.int32)
    my_sample_input['weights'] = np.ones((batch_size, seq_length)).astype(np.float32)
    my_sample_input['paddings'] = my_sample_input['weights']
    my_sample_input['segment_ids'] = my_sample_input['weights'].astype(np.int32)
    my_sample_input['segment_pos'] = np.arange(seq_length).reshape(1, -1).repeat(batch_size, axis=0).astype(np.int32)
    my_sample_input['_seqio_provenance/shard_index'] = np.array([-1]).repeat(batch_size).astype(np.int32)
    my_sample_input['_seqio_provenance/num_shards'] = my_sample_input['_seqio_provenance/shard_index']
    my_sample_input['_seqio_provenance/index_within_shard'] = my_sample_input['_seqio_provenance/shard_index'].astype(np.int64)
    my_sample_input['eval_sample_weights'] = my_sample_input['_seqio_provenance/shard_index'].astype(np.float32)
    my_sample_input = NestedMap(my_sample_input)

    inputs_shape_dtype = jax.tree_util.tree_map(
        lambda x: jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype),
        my_sample_input,
    )
    train_state_metadata = trainer_lib.create_train_state_metadata(
        jax_task,
        inputs_shape_dtype,
        discard_opt_states=False,
        do_eval=True,
    )
    print(f'Start load model to check whether saved model is True or False')
    device_mesh = py_utils.create_device_mesh(
        jax_task.model.ici_mesh_shape,
        jax_task.model.dcn_mesh_shape,
        contiguous_submeshes=jax_task.model.contiguous_submeshes,
    )
    global_mesh = jax.sharding.Mesh(device_mesh, jax_task.model.mesh_axis_names)
    restore_kwargs = {
        'version': 1.1,
        'specs': train_state_metadata.partition_specs,  # per-leaf partition specs
        'mesh': global_mesh,
        'transforms': None,
    }
    restore_kwargs = {'state': restore_kwargs}
    items = {'state': train_state_metadata.padded_global_shapes}
    # NOTE(review): `_manager` is a private attribute of OrbaxCheckpointManager
    # and may change between paxml versions.
    restored_model = checkpoint_manager._manager.restore(step, items=items, restore_kwargs=restore_kwargs)
    print(f'Check model finished. model is  saved successfully. take time: {time.time() - start}')

Model save step is 1320
Saved model finished. take time: 26.856893062591553s !!!
Args check_saved_model_fail_or_success is True, start to check model whether saved successful...
Start load model to check whether saved model is True or False
Check model finished. model is  saved successfully. take time: 9.500797033309937
