# Jetstream-PyTorch
JetStream Engine implementation in PyTorch

# Outline

1. Ssh to Cloud TPU VM (using a v5e-8 TPU VM); create one if you haven't
2. Clone the jetstream-pytorch repo and install dependencies
3. Download and convert weights
4. Run the checkpoint converter (quantizer)
5. Local run
6. Run the server
7. Run benchmarks
8. Typical Errors

# Ssh to Cloud TPU VM (using v5e-8 TPU VM)

```bash
gcloud compute config-ssh
gcloud compute tpus tpu-vm ssh "your-tpu-vm" --project "your-project" --zone "your-project-zone"
```
## Create a Cloud TPU VM in a GCP project if you haven't
Follow steps 1-9 in the following guide:
* https://cloud.google.com/tpu/docs/v5e-inference#prepare-a-project
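Once the project is prepared, creating the VM can be sketched as below. This is an illustrative example, not part of the repo: the VM name, project, zone, and runtime version are placeholders; use the values from the guide above that match your project.

```bash
# Hypothetical values -- substitute your own VM name, project, zone, and version
gcloud compute tpus tpu-vm create my-tpu-vm \
  --project=your-project \
  --zone=us-west4-a \
  --accelerator-type=v5litepod-8 \
  --version=v2-alpha-tpuv5-lite
```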

# Clone repo and install dependencies

## Get the jetstream-pytorch code
```bash
git clone https://github.com/google/jetstream-pytorch.git
```

(optional) Create a virtual env using `venv` or `conda` and activate it.
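The optional step above can be sketched with `venv` (the directory name `.venv` is just a convention; conda users would use `conda create` / `conda activate` instead):

```bash
# Create and activate an isolated environment (the name .venv is illustrative)
python3 -m venv .venv
source .venv/bin/activate
```

Run `deactivate` later to leave the environment.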

## Run installation script

```bash
cd jetstream-pytorch
source install_everything.sh
```
NOTE: the script exports PYTHONPATH, so it must be sourced (not executed) for the change to take effect in the current shell.
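As a quick illustration of why sourcing matters (a toy script, not part of the repo): exports made in a subshell vanish when it exits, while sourced exports persist in the current shell.

```bash
# Toy demo: demo.sh stands in for install_everything.sh
echo 'export DEMO_VAR=hello' > demo.sh

bash demo.sh              # runs in a subshell; DEMO_VAR is lost afterwards
echo "${DEMO_VAR:-unset}" # prints "unset"

source demo.sh            # runs in the current shell; DEMO_VAR persists
echo "$DEMO_VAR"          # prints "hello"
```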

# Download and convert weights

## Get official llama weights from meta-llama

Follow the instructions here: https://github.com/meta-llama/llama#download

The download also includes a `tokenizer.model` file, which is the tokenizer we will use.

## Run weight safetensor convert

```bash
export input_ckpt_dir=<directory of the original llama weights>
export output_ckpt_dir=<the output directory>
export dataset_path=ShareGPT_V3_unfiltered_cleaned_split.json
python benchmarks/benchmark_serving.py --tokenizer $tokenizer_path --num-prompts 2000 --dataset-path $dataset_path --dataset sharegpt --save-request-outputs
```
Please look at `deps/JetStream/benchmarks/README.md` for more information.


# Typical Errors

## Unexpected keyword argument 'device'

Fix:
* Uninstall the jax and jaxlib dependencies
* Reinstall using `source install_everything.sh`

## Out of memory

Fix:
* Use a smaller batch size
* Use quantization