31 changes: 25 additions & 6 deletions README.md
@@ -43,46 +43,65 @@ NOTE: the above script will export PYTHONPATH, so sourcing will make it to take

# Download and convert weights

## Get official llama weights from meta-llama
## LLaMA
### Get official llama weights from meta-llama

Follow the instructions here: https://github.com/meta-llama/llama#download

The weights download also includes a `tokenizer.model` file, which is the tokenizer we will use.

## Gemma
### Get Gemma Checkpoint from HuggingFace

Please sign the agreement on the Hugging Face website to access the Gemma checkpoints, then download the Gemma PyTorch checkpoint using `huggingface-cli`. The Gemma tokenizer is included in the checkpoint.

```bash
huggingface-cli download google/gemma-7b-pytorch --local-dir $input_ckpt_dir
```

You need to manually edit the `config.json` in the checkpoint folder so that it becomes valid JSON: replace `'` with `"` and remove the trailing `,` after the last item in the JSON object.
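
If you prefer not to hand-edit the file, here is a minimal sketch of that fix in Python (this script is my own illustration, not part of the checkpoint tooling; it assumes the only problems are single quotes and trailing commas, and that no value contains an apostrophe):

```python
import json
import re
from pathlib import Path

config_path = Path("config.json")  # run inside $input_ckpt_dir
text = config_path.read_text()
text = text.replace("'", '"')                # JSON requires double quotes
text = re.sub(r",\s*([}\]])", r"\1", text)   # drop trailing commas before } or ]
json.loads(text)                             # raises if the result is still invalid JSON
config_path.write_text(text)
```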

## Run weight safetensor convert

```bash
export input_ckpt_dir=<original llama weights directory>
export output_ckpt_dir=<the output directory>
export quantize=True # whether to quantize
python -m convert_checkpoints --input_checkpoint_dir=$input_ckpt_dir --output_checkpoint_dir=$output_ckpt_dir --quantize=$quantize
export model_name="llama-2" # or "gemma"
python -m convert_checkpoints --model_name=$model_name --input_checkpoint_dir=$input_ckpt_dir --output_checkpoint_dir=$output_ckpt_dir --quantize=$quantize
```


# Local run

Set the tokenizer path:
```bash
export tokenizer_path=tokenizer model file path from meta-llama
export tokenizer_path=<tokenizer model file path>
```

## Llama 7b
```bash
python run_interactive.py --size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path
python run_interactive.py --size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/$model_name.yaml
```

## Llama 13b
```bash
python run_interactive.py --size=13b --batch_size=64 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path
python run_interactive.py --size=13b --batch_size=64 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/$model_name.yaml
```


## Gemma 7b
```bash
python run_interactive.py --model_name=$model_name --size=7b --batch_size=64 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/$model_name.yaml
```


# Run the server
NOTE: in `--platform=tpu=8`, the `8` must be the number of TPU devices (4 for v4-8 and 8 for v5light-8).

```bash
python run_server.py --param_size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --platform=tpu=8
python run_server.py --param_size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --platform=tpu=8 --model=$model_name
```
Now you can send gRPC requests to it.

68 changes: 62 additions & 6 deletions convert_checkpoints.py
@@ -72,6 +72,7 @@
_QUANTIZE = flags.DEFINE_bool(
"quantize", False, "When set to true, produces quantized weights"
)
_MODEL_TYPE = flags.DEFINE_string("model_name", "llama", "Type of the model.")

# ParallelEmbedding is col partitioned across the shards.
# ColumnParallelLinear is row partitioned across shards due to transpose.
@@ -403,16 +404,71 @@ def merge_weights(
print(f"Export outputs takes {end - start} seconds")


def convert_hf_gemma_weights(
input_ckpt_dir: epath.Path, output_ckpt_dir: epath.Path
):
"""Convert gemma weights from Huggingface to be compatible with JetStream
1. Map attention weights to new names.
2. Split qkv fusion.
"""
ckpt_file = list(input_ckpt_dir.glob("*.ckpt"))
assert len(ckpt_file) == 1, "only expect 1 ckpt file for Gemma model."
ckpt_file = ckpt_file[0]
state_dict = torch.load(ckpt_file, map_location=torch.device("cpu"))[
"model_state_dict"
]
model_config = json.loads((input_ckpt_dir / "config.json").read_text())
for key in list(state_dict.keys()):
if state_dict[key].dtype.is_complex and _OUTPUT_SAFETENSORS.value:
assert (
key == "freqs_cis"
), "Only expect key 'freqs_cis' in the state_dict has complex dtype."
# Remove "freqs_cis" since it has complex dtype, and safetensor doesn't support it.
# The "freqs_cis" will be reconstructed when it's loaded by inference engine.
state_dict.pop(key)
continue
prefix_to_remove = "model."
new_key = key
if key.startswith(prefix_to_remove):
new_key = new_key.removeprefix(prefix_to_remove)
if "qkv_proj" in key:
q_dim = model_config["num_attention_heads"] * model_config["head_dim"]
kv_dim = model_config["num_key_value_heads"] * model_config["head_dim"]
qkv = state_dict.pop(key)
q, k, v = qkv.split(
[
q_dim,
kv_dim,
kv_dim,
],
dim=0,
)
state_dict[new_key.replace("qkv_proj", "wq")] = q
state_dict[new_key.replace("qkv_proj", "wk")] = k
state_dict[new_key.replace("qkv_proj", "wv")] = v
continue
if "o_proj" in key:
new_key = new_key.replace("o_proj", "wo")
if new_key != key:
state_dict[new_key] = state_dict.pop(key)
_export_to_local(output_ckpt_dir, model_config, state_dict)


def main(argv: Sequence[str]) -> None:
"""convert checkpoint main function"""
if len(argv) > 1:
raise app.UsageError("Too many command-line arguments.")
merge_weights(
_INPUT_CHECKPOINT_DIR.value,
_OUTPUT_CHECKPOINT_DIR.value,
_MINIMIZE_MEMORY_FOOTPRINT.value,
_ENABLE_FLOAT32.value,
)
if "gemma" in _MODEL_TYPE.value:
convert_hf_gemma_weights(
_INPUT_CHECKPOINT_DIR.value, _OUTPUT_CHECKPOINT_DIR.value
)
else:
merge_weights(
_INPUT_CHECKPOINT_DIR.value,
_OUTPUT_CHECKPOINT_DIR.value,
_MINIMIZE_MEMORY_FOOTPRINT.value,
_ENABLE_FLOAT32.value,
)


if __name__ == "__main__":
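As background for the qkv handling in `convert_hf_gemma_weights` above, here is a small standalone sketch of splitting a fused `qkv_proj` weight with `torch.split` (the sizes are illustrative Gemma-7B-like values assumed by me, not read from a real `config.json`):

```python
import torch

# Illustrative configuration values (assumed, not loaded from config.json).
num_attention_heads, num_key_value_heads, head_dim, hidden_size = 16, 16, 256, 3072
q_dim = num_attention_heads * head_dim    # 4096
kv_dim = num_key_value_heads * head_dim   # 4096

# A fused qkv_proj weight stacks q, k and v along dim 0: [q_dim + 2 * kv_dim, hidden_size].
qkv = torch.randn(q_dim + 2 * kv_dim, hidden_size)

# Split along dim 0 into the three unfused projections, as the converter does.
q, k, v = qkv.split([q_dim, kv_dim, kv_dim], dim=0)
assert q.shape == (q_dim, hidden_size)
assert k.shape == (kv_dim, hidden_size) and v.shape == (kv_dim, hidden_size)
```

The converter then stores the pieces under the `wq`, `wk` and `wv` names expected by the JetStream model.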
31 changes: 26 additions & 5 deletions jetstream_pt/engine.py
@@ -552,13 +552,35 @@ def _load_from_safetensors(self, path):

return weights

def _load_from_state_dict(self, path):
state_dict = torch.load(path, map_location=torch.device("cpu"))
weights = {}
for key, model_weights in self.pt_model.state_dict().items():
assert key in state_dict, f"key: {key} not found"
arr = jax.device_put(
torch_xla2.tensor.t2j(state_dict[key]), self.env.sharding_by_name(key)
)
assert tuple(model_weights.shape) == tuple(
arr.shape
), f"key: {key} error: {model_weights.shape} != {arr.shape}"
weights[key] = arr

for k, v in weights.items():
if k.startswith("layers") and not k.startswith("layers.0"):
continue
print(f"Name: {k}, shape: {v.shape} x {v.dtype}")

return weights

# pylint: disable-next=all
def load_params(self) -> Params:
# We want to fix this: load from files
with jax.default_device(self.colocated_cpus):
if self.env.checkpoint_path:
if self.env.checkpoint_format == "safetensors":
return self._load_from_safetensors(self.env.checkpoint_path)
elif self.env.checkpoint_format == "state_dict":
return self._load_from_state_dict(self.env.checkpoint_path)
else:
jax_weights = self._make_state_dict_jax(self.pt_model.state_dict())
jax_weights = {
@@ -643,7 +665,7 @@ def create_pytorch_engine(
) -> PyTorchEngine:
"""Returns: The pytorch engine."""

supported_models = ["llama-2", "llama-3"]
supported_models = ["llama-2", "llama-3", "gemma"]
if model_name not in supported_models:
raise NotImplementedError(
f"Model name should be one of{','.join(supported_models)}"
@@ -664,10 +686,9 @@
elif ".safetensors" in ckpt_path:
checkpoint_format = "safetensors"
checkpoint_path = ckpt_path
elif ".pth" in ckpt_path:
raise NotImplementedError(
"Loading from Pytorch raw checkpoint is not supported!"
)
elif ".pth" in ckpt_path or ".ckpt" in ckpt_path:
checkpoint_format = "state_dict"
checkpoint_path = ckpt_path
else:
path = epath.Path(ckpt_path) if ckpt_path and ckpt_path is not None else ""
if not path.exists():
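For readers unfamiliar with the conversion used in `_load_from_state_dict` above, the sketch below shows the basic torch-to-JAX path in isolation (a simplification of mine: it skips the shape checks and the per-key `env.sharding_by_name` sharding that the engine applies):

```python
import jax
import torch
import torch_xla2.tensor


def load_torch_checkpoint_as_jax(path):
    """Loads a PyTorch state_dict and converts every tensor to a JAX array."""
    state_dict = torch.load(path, map_location=torch.device("cpu"))
    weights = {}
    for key, tensor in state_dict.items():
        # t2j converts a torch.Tensor to a jax.Array; device_put places it on
        # the default device (the engine instead passes an explicit sharding).
        weights[key] = jax.device_put(torch_xla2.tensor.t2j(tensor))
    return weights
```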
2 changes: 1 addition & 1 deletion jetstream_pt/layers.py
@@ -45,7 +45,7 @@ def forward(self, input):

class WeightOnlyInt8Linear(torch.nn.Module):

def __init__(self, in_features, out_features, bias, device):
def __init__(self, in_features, out_features, bias=None, device=None):
super().__init__()
self.in_features = in_features
self.out_features = out_features
12 changes: 9 additions & 3 deletions jetstream_pt/third_party/gemma/model.py
@@ -76,9 +76,15 @@ def __init__(
if env.enable_weight_quantization
else torch.nn.Linear
)
self.gate_proj = Linear(hidden_size, intermediate_size, device)
self.up_proj = Linear(hidden_size, intermediate_size, device)
self.down_proj = Linear(intermediate_size, hidden_size, device)
self.gate_proj = Linear(
hidden_size, intermediate_size, bias=False, device=device
)
self.up_proj = Linear(
hidden_size, intermediate_size, bias=False, device=device
)
self.down_proj = Linear(
intermediate_size, hidden_size, bias=False, device=device
)

def forward(self, x):
gate = self.gate_proj(x)
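The `layers.py` and `gemma/model.py` changes above fit together: with `bias` and `device` given defaults, `WeightOnlyInt8Linear` accepts the same keyword call as `torch.nn.Linear`, so the Gemma MLP can bind either class to `Linear` and construct it uniformly. A small hypothetical sketch of that shared call shape (the sizes are illustrative):

```python
import torch

from jetstream_pt.layers import WeightOnlyInt8Linear

for linear_cls in (torch.nn.Linear, WeightOnlyInt8Linear):
    # Both constructors now accept the keyword form used in the Gemma MLP.
    layer = linear_cls(16, 64, bias=False, device="cpu")
    print(type(layer).__name__, layer.in_features, layer.out_features)
```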
6 changes: 0 additions & 6 deletions run_interactive.py
@@ -72,11 +72,6 @@
_MAX_CACHE_LENGTH = flags.DEFINE_integer(
"max_cache_length", 1024, "kv_cache_quantize"
)
_MODEL_NAME = flags.DEFINE_string(
"model",
"llama-2",
"name of the model. Supported options are llama-2 and llama-3",
)
_SHARDING_CONFIG = flags.DEFINE_string(
"sharding_config", "", "config file for sharding"
)
@@ -98,7 +93,6 @@ def create_engine():
param_size=_SIZE.value,
context_length=_CONTEXT_LENGTH.value,
batch_size=_BATCH_SIZE.value,
model_name=_MODEL_NAME.value,
quantize_weights=_QUANTIZE_WEIGHTS.value,
quantize_kv=_QUANTIZE_KV_CACHE.value,
max_cache_length=_MAX_CACHE_LENGTH.value,
4 changes: 0 additions & 4 deletions run_server.py
@@ -89,9 +89,6 @@
_SHARDING_CONFIG = flags.DEFINE_string(
"sharding_config", "", "config file for sharding"
)
_MODEL_NAME = flags.DEFINE_string(
"model_name", "llama-2", "model name, defaults to llama-2"
)


# pylint: disable-next=all
@@ -119,7 +116,6 @@ def main(argv: Sequence[str]):
quantize_kv=_QUANTIZE_KV_CACHE.value,
max_cache_length=_MAX_CACHE_LENGTH.value,
sharding_config=sharding_config_path,
model_name=_MODEL_NAME.value,
)
server_config = ServerConfig(
interleaved_slices=(_PLATFORM.value,),