diff --git a/README.md b/README.md
index 0514b819..83da9708 100644
--- a/README.md
+++ b/README.md
@@ -6,7 +6,7 @@ JetStream Engine implementation in PyTorch
 
 ### 1. Get the jetstream-pytorch code
 ```bash
-git clone https://github.com/pytorch-tpu/jetstream-pytorch.git
+git clone https://github.com/google/jetstream-pytorch.git
 ```
 
 1.1 (optional) Create a virtual env using `venv` or `conda` and activate it.
@@ -14,7 +14,8 @@ git clone https://github.com/pytorch-tpu/jetstream-pytorch.git
 
 ### 2. Run installation script:
 
 ```bash
-sh install_everything.sh
+cd jetstream-pytorch
+source install_everything.sh
 ```
 
@@ -38,35 +39,38 @@ python -m convert_checkpoints --input_checkpoint_dir=$input_ckpt_dir --output_ch
 
 # Local run
 
-## Llama 7b
+Set the tokenizer path:
+```bash
+export tokenizer_path=<path to the tokenizer.model file from the meta-llama download>
 ```
-python benchmarks/run_offline.py --size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir/model.safetensors --tokenizer_path=tokenizer.model
+
+## Llama 7b
+```bash
+python run_interactive.py --size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir/model.safetensors --tokenizer_path=$tokenizer_path
 ```
 
 ## Llama 13b
+```bash
+python run_interactive.py --size=13b --batch_size=64 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir/model.safetensors --tokenizer_path=$tokenizer_path
 ```
-python benchmarks/run_offline.py --size=13b --batch_size=96 --max_cache_length=1280 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir/model.safetensors --tokenizer_path=tokenizer.model
-```
-NOTE: for 13b model we recommend to use `--max_cache_length=1280`, i.e. this implements sliding window attention.
 
 # Run the server
 NOTE: the `--platform=tpu=8` need to specify number of tpu devices (which is 4 for v4-8 and 8 for v5light-8`)
 
 ```bash
-python run_server.py --param_size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir/model.safetensors --tokenizer_path=tokenizer.model --platform=tpu=8
+python run_server.py --param_size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir/model.safetensors --tokenizer_path=$tokenizer_path --platform=tpu=8
 ```
 
 Now you can fire gRPC to it
 
 # Run benchmark
 go to the deps/JetStream folder (downloaded during `install_everything.sh`)
-```bash
-cd deps/JetStream
-python benchmark_serving.py --tokenizer /home/hanq/jetstream-pytorch/tokenizer.model --num-prompts 2000 --dataset ~/data/ShareGPT_V3_unfiltered_cleaned_split.json --warmup-first=1 --save-request-outputs
-```
-The ShareGPT dataset can be downloaded at
 ```bash
+cd deps/JetStream
 wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
+export dataset_path=ShareGPT_V3_unfiltered_cleaned_split.json
+pip install -e .
+python benchmarks/benchmark_serving.py --tokenizer $tokenizer_path --num-prompts 2000 --dataset-path $dataset_path --dataset sharegpt --save-request-outputs
 ```
 
 Please look at `deps/JetStream/benchmarks/README.md` for more information.
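The commands in this patch reference several shell variables (`$input_ckpt_dir`, `$output_ckpt_dir`, `$tokenizer_path`, `$quantize`) without defining them all in one place. A minimal setup sketch follows; every path is a hypothetical example, and `True` as the value of `$quantize` is an assumption about how the boolean `--quantize_*` flags are parsed:

```bash
# Hypothetical paths -- point these at your own meta-llama download.
# Absolute paths are safest, since the benchmark step later does `cd deps/JetStream`.
export input_ckpt_dir=/data/llama-2-7b             # original meta-llama checkpoint
export output_ckpt_dir=/data/llama-2-7b-converted  # destination for convert_checkpoints
export tokenizer_path=/data/llama-2-7b/tokenizer.model
export quantize=True   # assumption: --quantize_weights/--quantize_kv_cache accept True/False
```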
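For the `--platform=tpu=8` note: the number after `tpu=` must match the host's device count (4 on v4-8, 8 on v5e-8). Since the stack runs on JAX, one quick way to check, assuming the environment set up by `install_everything.sh`, is:

```bash
# Print how many TPU devices JAX sees; use that N in --platform=tpu=N.
python -c 'import jax; print(jax.device_count())'
```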
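To drive the server and the benchmark from a single shell, a sketch that just sequences the two commands shown in the patch; the fixed `sleep` is a crude stand-in for waiting until the server logs that it is ready:

```bash
# Launch the server in the background, benchmark it, then shut it down.
python run_server.py --param_size=7b --batch_size=128 --max_cache_length=2048 \
  --quantize_weights=$quantize --quantize_kv_cache=$quantize \
  --checkpoint_path=$output_ckpt_dir/model.safetensors \
  --tokenizer_path=$tokenizer_path --platform=tpu=8 &
server_pid=$!
sleep 120   # assumption: enough time to load the checkpoint; watch the logs in practice
cd deps/JetStream
python benchmarks/benchmark_serving.py --tokenizer $tokenizer_path \
  --num-prompts 2000 --dataset-path $dataset_path --dataset sharegpt --save-request-outputs
kill $server_pid
```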