JetStream Engine implementation in PyTorch

### 1. Get the jetstream-pytorch code
```bash
git clone https://github.com/google/jetstream-pytorch.git
```

1.1 (optional) Create a virtual env using `venv` or `conda` and activate it.
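For example, with `venv` (the environment name here is arbitrary):

```bash
python -m venv .venv
source .venv/bin/activate
```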

### 2. Run installation script:

```bash
cd jetstream-pytorch
source install_everything.sh
```


Convert the checkpoint:

```bash
python -m convert_checkpoints --input_checkpoint_dir=$input_ckpt_dir --output_checkpoint_dir=$output_ckpt_dir
```
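The converter and the run commands below assume `$input_ckpt_dir`, `$output_ckpt_dir`, and `$quantize` have already been exported; a sketch with placeholder values:

```bash
export input_ckpt_dir=<directory holding the original meta-llama checkpoint>  # placeholder
export output_ckpt_dir=<directory to write the converted checkpoint to>       # placeholder
export quantize=True  # assumption: whether the run commands quantize weights and KV cache
```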

# Local run

Set the tokenizer path:
```bash
export tokenizer_path=<path to the tokenizer.model file from meta-llama>
```

## Llama 7b
```bash
python run_interactive.py --size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir/model.safetensors --tokenizer_path=$tokenizer_path
```

## Llama 13b
```bash
python run_interactive.py --size=13b --batch_size=64 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir/model.safetensors --tokenizer_path=$tokenizer_path
```
NOTE: for the 13b model we recommend `--max_cache_length=1280`, which effectively implements sliding-window attention.


# Run the server
NOTE: the number in `--platform=tpu=8` must match the number of TPU devices (4 for v4-8, 8 for v5light-8).

```bash
python run_server.py --param_size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir/model.safetensors --tokenizer_path=$tokenizer_path --platform=tpu=8
```
Now you can send gRPC requests to it.
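One quick way to poke at the running server, assuming it listens on port 9000 and has gRPC reflection enabled (both are assumptions; otherwise point a client at the protos under `deps/JetStream`):

```bash
# Assumptions: port 9000 and server-side gRPC reflection; adjust to your setup.
grpcurl -plaintext localhost:9000 list      # enumerate the exposed services
grpcurl -plaintext localhost:9000 describe  # show their methods and message types
```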

# Run benchmark
Go to the `deps/JetStream` folder (downloaded during `install_everything.sh`), download the ShareGPT dataset, and run the benchmark:

```bash
cd deps/JetStream
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
export dataset_path=ShareGPT_V3_unfiltered_cleaned_split.json
pip install -e .
python benchmarks/benchmark_serving.py --tokenizer $tokenizer_path --num-prompts 2000 --dataset-path $dataset_path --dataset sharegpt --save-request-outputs
```
Please look at `deps/JetStream/benchmarks/README.md` for more information.