diff --git a/README.md b/README.md
index aa985c3a..d442d58c 100644
--- a/README.md
+++ b/README.md
@@ -1,34 +1,57 @@
 # Jetstream-PyTorch
 JetStream Engine implementation in PyTorch

+# Outline

-# Install
+1. SSH to Cloud TPU VM (using a v5e-8 TPU VM)
+   a. Create a Cloud TPU VM if you don't already have one
+2. Download the jetstream-pytorch GitHub repo
+3. Clone the repo and install dependencies
+4. Download and convert weights
+5. Run the checkpoint converter (quantizer)
+6. Local run
+7. Run the server
+8. Run benchmarks
+9. Typical errors

-### 1. Get the jetstream-pytorch code
+# SSH to Cloud TPU VM (using a v5e-8 TPU VM)
+
+```bash
+gcloud compute config-ssh
+gcloud compute tpus tpu-vm ssh "your-tpu-vm" --project "your-project" --zone "your-project-zone"
+```
+
+## Create a Cloud TPU VM in a GCP project if you don't already have one
+Follow steps 1-9 in the following guide:
+* https://cloud.google.com/tpu/docs/v5e-inference#prepare-a-project
+
+# Clone the repo and install dependencies
+
+## Get the jetstream-pytorch code
 ```bash
 git clone https://github.com/google/jetstream-pytorch.git
 ```

-1.1 (optional) Create a virtual env using `venv` or `conda` and activate it.
+(Optional) Create a virtual env using `venv` or `conda` and activate it.

-### 2. Run installation script:
+## Run the installation script
 ```bash
 cd jetstream-pytorch
 source install_everything.sh
 ```
+NOTE: the script above exports PYTHONPATH, so it must be sourced (not executed) for the change to take effect in the current shell.

+# Download and convert weights

-# Get weights
-
-### First get official llama weights from meta-llama
+## Get official llama weights from meta-llama
 Follow the instructions here:
 https://github.com/meta-llama/llama#download

 Downloading the weights also fetches a `tokenizer.model` file, which is the tokenizer we will use.
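The PYTHONPATH note above is worth illustrating: an `export` performed inside a child shell does not survive in the parent, which is why `install_everything.sh` must be sourced. A minimal sketch, using a stand-in script and a hypothetical path (not part of the repo):

```shell
#!/bin/sh
# Sketch: why `source` (or `.`) is required for a script that exports a variable.
# A script run as a child process sets the variable only in that child;
# sourcing runs the same commands in the current shell, so the export persists.

# Stand-in for install_everything.sh (hypothetical path on the right).
cat > /tmp/set_pythonpath.sh <<'EOF'
export PYTHONPATH="/tmp/jetstream-pytorch/deps"
EOF

unset PYTHONPATH
sh /tmp/set_pythonpath.sh              # runs in a child shell...
echo "after sh:     PYTHONPATH=$PYTHONPATH"   # ...so PYTHONPATH is still empty here

. /tmp/set_pythonpath.sh               # sourcing runs it in this shell
echo "after source: PYTHONPATH=$PYTHONPATH"   # now the export is visible
```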
-### Run weight merger to convert (and )
+## Run the weight safetensors converter
+
 ```bash
 export input_ckpt_dir=Original llama weights directory
 export output_ckpt_dir=The output directory
@@ -73,3 +96,20 @@ export dataset_path=ShareGPT_V3_unfiltered_cleaned_split.json
 python benchmarks/benchmark_serving.py --tokenizer $tokenizer_path --num-prompts 2000 --dataset-path $dataset_path --dataset sharegpt --save-request-outputs
 ```
 Please look at `deps/JetStream/benchmarks/README.md` for more information.
+
+
+# Typical Errors
+
+## Unexpected keyword argument 'device'
+
+Fix:
+* Uninstall the jax and jaxlib dependencies
+* Reinstall using `source install_everything.sh`
+
+## Out of memory
+
+Fix:
+* Use a smaller batch size
+* Use quantization
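The commands in this README rely on several environment variables (`input_ckpt_dir`, `output_ckpt_dir`, `tokenizer_path`, `dataset_path`); a forgotten variable is an easy source of confusing failures. A minimal guard you could paste before running the converter, server, or benchmark — this helper is not part of jetstream-pytorch, only the variable names come from the commands above:

```shell
#!/bin/sh
# Sketch: fail fast if any required environment variable is unset or empty.
check_env() {
  missing=0
  for name in "$@"; do
    # Indirect expansion: read the value of the variable whose name is in $name.
    eval "value=\${$name}"
    if [ -z "$value" ]; then
      echo "ERROR: $name is not set" >&2
      missing=1
    fi
  done
  return $missing
}

# Example: check the variables the checkpoint converter expects.
check_env input_ckpt_dir output_ckpt_dir tokenizer_path \
  || echo "set the variables above before running the converter" >&2
```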