In [1]:
import re
import sys
import os
import jax
import time
from functools import partial

import mlxu
from google.cloud import storage
from jax.experimental.pjit import pjit
from jax.sharding import PartitionSpec as PS
import flax
from flax import linen as nn
from flax.jax_utils import prefetch_to_device
from transformers import GenerationConfig, FlaxLogitsProcessorList

sys.path.append('/home/lishengping/projects/mesh_easy_jax')

from easylm.checkpoint import StreamingCheckpointer
from easylm.llama_model import LLaMAConfig, LLaMAConfig2, FlaxLLaMAForCausalLM, LLaMATokenizer
from easylm.jax_utils import (
    JaxRNG, next_rng, match_partition_rules, tree_apply,
    set_random_seed, get_float_dtype_by_name, make_shard_and_gather_fns,
    with_sharding_constraint, FlaxTemperatureLogitsWarper
)

In [3]:
# Locate the newest checkpoint under gs://<bucket_name>/<directory_path>.
bucket_name = 'llm_base_models'
directory_path = 'easylm/linli_chinese_llama'
# Alternative checkpoint directory:
# directory_path = 'llama7b_finetune_mesh_jax_flax'

client = storage.Client()

# Raw string: '\d' inside a normal string literal is an invalid escape sequence.
step_pattern = re.compile(r'step_(\d+)')

# Maps training step -> blob name. When several blobs share a step, the one
# listed last wins (same as plain dict assignment in a loop).
model_dirs = {}
for blob in client.list_blobs(bucket_name, prefix=directory_path):
    found = step_pattern.search(blob.name)
    if found is None:
        # Not a checkpoint blob (or 'step_' without a number): ignore it
        # instead of failing on an empty match list.
        continue
    model_dirs[int(found.group(1))] = blob.name

if not model_dirs:
    raise FileNotFoundError(
        f'no step_<N> checkpoint found under gs://{bucket_name}/{directory_path}'
    )

# Sorted list of (step, blob_name) pairs; the last entry is the highest step.
model_dirs = sorted(model_dirs.items(), key=lambda x: x[0])
step, model_path = model_dirs[-1]
print(f'model_path: {model_path}')

model_path: easylm/linli_chinese_llama/step_4000/streaming_train_state


In [26]:
# Generation settings shared by the cells below.
FLAGS_DEF = dict(
    # Not read by any cell visible in this notebook.
    seed=42,
    initialize_jax_distributed=False,
    mesh_dim='1,-1,1',
    dtype='bf16',
    # Prompts are padded/truncated to `input_length` tokens; the number of
    # newly generated tokens is `seq_length - input_length`.
    input_length=1024,
    seq_length=2048,
    # Forwarded to GenerationConfig inside forward_generate.
    top_k=10,
    top_p=0.85,
    do_sample=True,
    num_beams=1,
    # When True, generate() overwrites column 0 of the prompt with the BOS id.
    add_bos_token=False,
    # Not read by any cell visible in this notebook.
    load_llama_config='',
    load_checkpoint='',
)

In [8]:
vocab_file = '/home/lishengping/linli_llama_chinese0606/tokenizer.model'

# Used for decoding generated ids and for the bos/eos/pad token ids.
tokenizer = LLaMATokenizer(
    vocab_file=vocab_file, truncation_side='right', padding_side='left'
)

# Used for encoding prompts in generate(): truncates from the left, pads on
# the right.
# NOTE(review): right padding on the prompt side puts pad tokens between the
# prompt and the first generated token; decoder-only generation normally uses
# left padding. Confirm this is intended for this model/fork.
prefix_tokenizer = LLaMATokenizer(
    vocab_file=vocab_file, truncation_side='left', padding_side='right'
)

In [9]:
# Build the model wrapper on the host CPU; `_do_init=False` means no
# parameters are initialised here (they come from the checkpoint later).
cpu_device = jax.devices("cpu")[0]
with jax.default_device(cpu_device):
    llama_config = LLaMAConfig2.get_default_config()
    # Overrides applied on top of the default config, in this order.
    config_overrides = {
        'add_cross_attention': False,
        'is_encoder_decoder': False,
        'cache': True,
        # 'num_hidden_layers': 2,  (disabled override)
        'gradient_checkpointing': '',
        'output_attentions': False,
        'output_hidden_states': False,
        'return_dict': True,
    }
    for field_name, field_value in config_overrides.items():
        setattr(llama_config, field_name, field_value)
    hf_model = FlaxLLaMAForCausalLM(
        llama_config,
        input_shape=(1, 2048),
        seed=42,
        _do_init=False,
    )

do_init: False


In [10]:
# --- Load the checkpoint onto the host CPU and time it ---
start = time.time()
cpu_device = jax.devices("cpu")[0]
with jax.default_device(cpu_device):
    checkpoint_config = StreamingCheckpointer.get_default_config(
        {'save_optimizer_state': True}
    )
    checkpointer = StreamingCheckpointer(checkpoint_config, 'output')
    load_checkpoint = f'params::gs://{bucket_name}/{model_path}'
    # Only the second element of the returned pair is used.
    _, train_state = checkpointer.load_trainstate_checkpoint(
        load_checkpoint,
        disallow_trainstate=True,
    )
load_seconds = time.time() - start
print(f'load weight time: {load_seconds}')

# If the loaded tree has a doubly nested 'params' entry, unwrap to the inner
# tree; otherwise keep the loaded tree as-is.
nested_params = train_state['params'].get('params', None)
params = nested_params if nested_params is not None else train_state

# --- Derive partition specs and per-leaf shard functions and time it ---
# NOTE(review): rules come from LLaMAConfig while the model config is built
# from LLaMAConfig2 - confirm both expose the same partition rules.
start = time.time()
partition_rules = LLaMAConfig.get_partition_rules()
model_ps = match_partition_rules(partition_rules, params)
shard_fns, _ = make_shard_and_gather_fns(
    model_ps, get_float_dtype_by_name('bf16')
)
shard_seconds = time.time() - start

print(f'shard weight time: {shard_seconds}')

load weight time: 1534.1503233909607
shard weight time: 0.059808969497680664


In [11]:
@partial(
    pjit,
    in_shardings=(model_ps, PS(), PS(), PS()),
    out_shardings=(PS(), PS()),
)
def forward_generate(params, rng, batch, temperature):
    """Run `hf_model.generate` on one batch under pjit.

    Args:
        params: parameter tree partitioned according to `model_ps`; its
            'params' entry is what gets passed to the model.
        rng: random key used to seed a `JaxRNG` generator.
        batch: mapping with 'input_tokens' and 'attention_mask'.
        temperature: value handed to `FlaxTemperatureLogitsWarper`.

    Returns:
        A pair `(new_tokens, next_rng)`: the generated sequences with the
        prompt columns sliced off, and a further key drawn from the generator
        for the caller to use on the next call.
    """
    batch = with_sharding_constraint(batch, PS('dp'))
    key_stream = JaxRNG(rng)

    prompt_tokens = batch['input_tokens']
    prompt_length = prompt_tokens.shape[1]

    # First draw feeds sampling; a second draw is returned to the caller.
    sampling_key = key_stream()
    warpers = FlaxLogitsProcessorList(
        [FlaxTemperatureLogitsWarper(temperature)]
    )
    generation_config = GenerationConfig(
        max_new_tokens=FLAGS_DEF['seq_length'] - FLAGS_DEF['input_length'],
        # EOS doubles as the pad id.
        pad_token_id=tokenizer.eos_token_id,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        do_sample=FLAGS_DEF['do_sample'],
        num_beams=FLAGS_DEF['num_beams'],
        top_k=FLAGS_DEF['top_k'],
        top_p=FLAGS_DEF['top_p'],
    )

    generated = hf_model.generate(
        prompt_tokens,
        attention_mask=batch['attention_mask'],
        params=params['params'],
        prng_key=sampling_key,
        logits_processor=warpers,
        generation_config=generation_config,
    )
    # Keep only the columns after the prompt.
    new_tokens = generated.sequences[:, prompt_length:]
    return new_tokens, key_stream()

In [12]:
# Build the device mesh from the layout string below.
# NOTE(review): FLAGS_DEF['mesh_dim'] holds a different value ('1,-1,1') that
# this cell does not read - confirm which layout is intended.
mesh_dim = '1, 8'
mesh = LLaMAConfig2.get_jax_mesh(mesh_dim)
# Seed is set before the key is drawn with next_rng() below.
set_random_seed(42)
with mesh:
    # Rebinds `params` to the result of applying `shard_fns` to it
    # (presumably placing each leaf on the mesh - TODO confirm).
    # NOTE(review): self-referential assignment; re-running this cell feeds
    # the already-processed params through shard_fns again.
    params = tree_apply(shard_fns, params)
    # Key that generate() passes to forward_generate() and then replaces with
    # the key it gets back.
    sharded_rng = next_rng()

In [32]:
def generate(text, temperature):
    """Generate a completion for `text` with the sharded model.

    Args:
        text: prompt string (or batch of strings) passed to
            `prefix_tokenizer`.
        temperature: sampling temperature forwarded to `forward_generate`.

    Returns:
        A pair `(output_text, output)`: the decoded strings, each cut at the
        first EOS token if one is present, and the raw generated token ids
        fetched from device.

    Side effects:
        Replaces the module-level `sharded_rng` with the key returned by
        `forward_generate`.
    """
    global sharded_rng

    # Pad/truncate every prompt to exactly `input_length` tokens.
    encoded = prefix_tokenizer(
        text,
        padding='max_length',
        truncation=True,
        max_length=FLAGS_DEF['input_length'],
        return_tensors='np',
    )
    input_tokens, input_mask = encoded.input_ids, encoded.attention_mask
    if FLAGS_DEF['add_bos_token']:
        # Force BOS into the first column and make sure it is attended to.
        input_tokens[:, 0] = tokenizer.bos_token_id
        input_mask[:, 0] = 1

    batch = {'input_tokens': input_tokens, 'attention_mask': input_mask}
    with mesh:
        output, sharded_rng = forward_generate(
            params, sharded_rng, batch, temperature
        )
        output = jax.device_get(output)

    eos = tokenizer.eos_token
    output_text = [
        decoded.split(eos, maxsplit=1)[0] if eos in decoded else decoded
        for decoded in tokenizer.batch_decode(output)
    ]
    return output_text, output

In [68]:
# Prompts tried so far, oldest first. Only the last entry is sent to the model.
prompts = [
    'human:中国的首都在哪？\nassistant：',
    '新北市府南大門今天上午9點10分被人潑白漆，警方調查，一名男子持一桶白色油漆，步行到市府大門口潑漆後，立即騎機車逃逸，',
    '周杰伦是谁',
    'Today, I want to',
    'human:请用python实现冒泡排序算法\n\nassistant：',
    'human:北京有哪些好玩的地方？\n\nassistant：',
    'human:23 + 56 = ？\n\nassistant：',
    'human:\n写一篇300字的小说\n\nassistant:\n',
    'human:\n不积跬步无以至千里，扩写\n\nassistant:\n',
    'Human:\n美国总统是谁\n\nAssistant:\n',
    'Human:\n中国主席是？\n\nAssistant:\n',
    'Human:\n2023年中国主席是？\n\nAssistant:\n',
    'Human:\n提取实体：我爱中国天安门\n\nAssistant:\n',
    'Human:\n你能做什么\n\nAssistant:\n',
    'Human:\n帮我写一份旅游青岛的7天计划\n\nAssistant:\n',
    'Human:\n怎么追女生\n\nAssistant:\n',
    'Human:\n211 + 1912 = ？\n\nAssistant:\n',
    'Human:\nx +y = 4\n x - y =3,求x, y\n\nAssistant:\n',
    'Human:\n用英文写一首歌\n\nAssistant:\n',
    'Human:\n怎么做红烧肉\n\nAssistant:\n',
    'Human:\n四大名著\n\nAssistant:\n',
    'Human:\n周杰伦是谁，详细点介绍\n\nAssistant:\n',
]
text = prompts[-1]

# r is (decoded_texts, token_ids); show the first decoded completion.
r = generate(text, temperature=0.5)
print(f'output:\n{r[0][0]}')

output:
文字：
周杰伦，1979年1月18日出生于台湾台北市，台湾著名歌手、音乐制作人、演员、作家、商人。

音乐方面，周杰伦是华语流行乐坛的代表性人物之一，曾获得过多项金曲奖、MTV音乐录影带大奖等奖项。他的音乐风格多样，包括流行、摇滚、电子、嘻哈等，并在音乐制作方面有着极高的造诣。他的歌曲在华语乐坛有着深远的影响力，被誉为“华语乐坛的顶级创作人”。

除了音乐方面，周杰伦还涉足演艺界，曾出演过多部电影和电视剧，并获得过金马奖等奖项。他还是一位作家，出版了多部小说和散文集。

除此之外，周杰伦还是一位成功的商人，他创立了自己的品牌Jaywalk，并在全球范围内拥有多家餐厅、酒吧和咖啡店。

总之，周杰伦是一位多才多艺的艺人，他在音乐、演艺、商业等领域都有着极高的成就和影响力。
