diff --git a/README.md b/README.md
index 9734901d..3b14f6c0 100644
--- a/README.md
+++ b/README.md
@@ -43,20 +43,33 @@ NOTE: the above script will export PYTHONPATH, so sourcing will make it to take
 
 
 # Download and convert weights
 
-## Get official llama weights from meta-llama
+## LLaMA
+### Get official llama weights from meta-llama
 
 Following instructions here: https://github.com/meta-llama/llama#download
 
 After you have downloaded the weights, it will also download a `tokenizer.model` file that is the tokenizer that we will use.
 
+## Gemma
+### Get Gemma Checkpoint from HuggingFace
+
+Please sign the agreement on the Hugging Face website to access the Gemma checkpoints. Download the Gemma PyTorch checkpoint using huggingface-cli. The Gemma tokenizer is included in the checkpoint.
+
+```bash
+huggingface-cli download google/gemma-7b-pytorch --local-dir $input_ckpt_dir
+```
+
+You need to manually modify the `config.json` in the checkpoint folder to make it valid JSON (replace `'` with `"` and remove the trailing `,` after the last item in the JSON object).
+
 ## Run weight safetensor convert
 
 ```bash
 export input_ckpt_dir=Original llama weights directory
 export output_ckpt_dir=The output directory
 export quantize=True #whether to quantize
-python -m convert_checkpoints --input_checkpoint_dir=$input_ckpt_dir --output_checkpoint_dir=$output_ckpt_dir --quantize=$quantize
+export model_name="llama-2" # or "gemma"
+python -m convert_checkpoints --model_name=$model_name --input_checkpoint_dir=$input_ckpt_dir --output_checkpoint_dir=$output_ckpt_dir --quantize=$quantize
 ```
 
 
@@ -64,17 +77,23 @@ python -m convert_checkpoints --input_checkpoint_dir=$input_ckpt_dir --output_ch
 Set tokenizer path
 
 ```bash
-export tokenizer_path=tokenizer model file path from meta-llama
+export tokenizer_path=tokenizer model file path
 ```
 
 ## Llama 7b
 ```bash
-python run_interactive.py --size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path
+python run_interactive.py --size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/$model_name.yaml
 ```
 
 ## Llama 13b
 ```bash
-python run_interactive.py --size=13b --batch_size=64 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path
+python run_interactive.py --size=13b --batch_size=64 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/$model_name.yaml
+```
+
+
+## Gemma 7b
+```bash
+python run_interactive.py --model_name=$model_name --size=7b --batch_size=64 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/$model_name.yaml
 ```
 
 
@@ -82,7 +101,7 @@ python run_interactive.py --size=13b --batch_size=64 --max_cache_length=2048 --q
 NOTE: the `--platform=tpu=8` need to specify number of tpu devices (which is 4 for v4-8 and 8 for v5light-8`)
 
 ```bash
-python run_server.py --param_size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --platform=tpu=8
+python run_server.py --param_size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --platform=tpu=8 --model=$model_name
 ```
 
 Now you can fire gRPC to it
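The README hunk above asks for a manual cleanup of Gemma's `config.json`. As an illustration only (this helper is not part of the PR, and it assumes the only problems are the single quotes and the trailing comma, as the README notes), the same cleanup can be scripted via a Python-literal parse:

```python
# fix_gemma_config.py -- hypothetical helper, not included in this PR.
# Rewrites the Gemma checkpoint's config.json (single quotes, trailing comma)
# as strict JSON so json.loads() in convert_checkpoints.py can read it.
import ast
import json
import sys


def fix_config(path):
  # Python literal syntax tolerates single-quoted strings and trailing
  # commas, so the lenient file parses as a plain dict.
  with open(path, "r") as f:
    config = ast.literal_eval(f.read())
  with open(path, "w") as f:
    json.dump(config, f, indent=2)


if __name__ == "__main__":
  fix_config(sys.argv[1])  # e.g. the path to $input_ckpt_dir/config.json
```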
diff --git a/convert_checkpoints.py b/convert_checkpoints.py
index 1a29cacc..bda847b1 100644
--- a/convert_checkpoints.py
+++ b/convert_checkpoints.py
@@ -72,6 +72,7 @@ _QUANTIZE = flags.DEFINE_bool(
     "quantize", False, "When set to true, produces quantized weights"
 )
 
+_MODEL_TYPE = flags.DEFINE_string("model_name", "llama", "Type of the model.")
 
 # ParallelEmbedding is col partitioned across the shards.
 # ColumnParallelLinear is row partitioned across shards due to transpose.
@@ -403,16 +404,71 @@ def merge_weights(
   print(f"Export outputs takes {end - start} seconds")
 
 
+def convert_hf_gemma_weights(
+    input_ckpt_dir: epath.Path, output_ckpt_dir: epath.Path
+):
+  """Convert gemma weights from Huggingface to be compatible with JetStream
+  1. Map attention weights to new names.
+  2. Split qkv fusion.
+  """
+  ckpt_file = list(input_ckpt_dir.glob("*.ckpt"))
+  assert len(ckpt_file) == 1, "only expect 1 ckpt file for Gemma model."
+  ckpt_file = ckpt_file[0]
+  state_dict = torch.load(ckpt_file, map_location=torch.device("cpu"))[
+      "model_state_dict"
+  ]
+  model_config = json.loads((input_ckpt_dir / "config.json").read_text())
+  for key in list(state_dict.keys()):
+    if state_dict[key].dtype.is_complex and _OUTPUT_SAFETENSORS.value:
+      assert (
+          key == "freqs_cis"
+      ), "Only expect key 'freqs_cis' in the state_dict has complex dtype."
+      # Remove "freqs_cis" since it has complex dtype, and safetensor doesn't support it.
+      # The "freqs_cis" will be reconstructed when it's loaded by inference engine.
+      state_dict.pop(key)
+      continue
+    prefix_to_remove = "model."
+    new_key = key
+    if key.startswith(prefix_to_remove):
+      new_key = new_key.removeprefix(prefix_to_remove)
+    if "qkv_proj" in key:
+      q_dim = model_config["num_attention_heads"] * model_config["head_dim"]
+      kv_dim = model_config["num_key_value_heads"] * model_config["head_dim"]
+      qkv = state_dict.pop(key)
+      q, k, v = qkv.split(
+          [
+              q_dim,
+              kv_dim,
+              kv_dim,
+          ],
+          dim=0,
+      )
+      state_dict[new_key.replace("qkv_proj", "wq")] = q
+      state_dict[new_key.replace("qkv_proj", "wk")] = k
+      state_dict[new_key.replace("qkv_proj", "wv")] = v
+      continue
+    if "o_proj" in key:
+      new_key = new_key.replace("o_proj", "wo")
+    if new_key != key:
+      state_dict[new_key] = state_dict.pop(key)
+  _export_to_local(output_ckpt_dir, model_config, state_dict)
+
+
 def main(argv: Sequence[str]) -> None:
   """convert checkpoint main function"""
   if len(argv) > 1:
     raise app.UsageError("Too many command-line arguments.")
-  merge_weights(
-      _INPUT_CHECKPOINT_DIR.value,
-      _OUTPUT_CHECKPOINT_DIR.value,
-      _MINIMIZE_MEMORY_FOOTPRINT.value,
-      _ENABLE_FLOAT32.value,
-  )
+  if "gemma" in _MODEL_TYPE.value:
+    convert_hf_gemma_weights(
+        _INPUT_CHECKPOINT_DIR.value, _OUTPUT_CHECKPOINT_DIR.value
+    )
+  else:
+    merge_weights(
+        _INPUT_CHECKPOINT_DIR.value,
+        _OUTPUT_CHECKPOINT_DIR.value,
+        _MINIMIZE_MEMORY_FOOTPRINT.value,
+        _ENABLE_FLOAT32.value,
+    )
 
 
 if __name__ == "__main__":
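To make the qkv un-fusion in `convert_hf_gemma_weights` concrete, here is a small self-contained sketch on a dummy tensor. The dimensions are assumed Gemma 7B values (16 query heads, 16 KV heads, head_dim 256, hidden_size 3072); in the converter they come from `config.json`.

```python
# Standalone demonstration of the qkv split performed above, on random data.
import torch

num_attention_heads = 16   # assumed Gemma 7B values; the converter reads
num_key_value_heads = 16   # these from the checkpoint's config.json
head_dim = 256
hidden_size = 3072

q_dim = num_attention_heads * head_dim    # 4096
kv_dim = num_key_value_heads * head_dim   # 4096

# The PyTorch checkpoint stores one fused projection weight with the
# q, k and v rows stacked along dim 0.
qkv = torch.randn(q_dim + 2 * kv_dim, hidden_size)

q, k, v = qkv.split([q_dim, kv_dim, kv_dim], dim=0)
assert q.shape == (q_dim, hidden_size)
assert k.shape == (kv_dim, hidden_size)
assert v.shape == (kv_dim, hidden_size)
```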
diff --git a/jetstream_pt/engine.py b/jetstream_pt/engine.py
index ea77342e..f95d8cbd 100644
--- a/jetstream_pt/engine.py
+++ b/jetstream_pt/engine.py
@@ -552,6 +552,26 @@ def _load_from_safetensors(self, path):
 
     return weights
 
+  def _load_from_state_dict(self, path):
+    state_dict = torch.load(path, map_location=torch.device("cpu"))
+    weights = {}
+    for key, model_weights in self.pt_model.state_dict().items():
+      assert key in state_dict, f"key: {key} not found"
+      arr = jax.device_put(
+          torch_xla2.tensor.t2j(state_dict[key]), self.env.sharding_by_name(key)
+      )
+      assert tuple(model_weights.shape) == tuple(
+          arr.shape
+      ), f"key: {key} error: {model_weights.shape} != {arr.shape}"
+      weights[key] = arr
+
+    for k, v in weights.items():
+      if k.startswith("layers") and not k.startswith("layers.0"):
+        continue
+      print(f"Name: {k}, shape: {v.shape} x {v.dtype}")
+
+    return weights
+
   # pylint: disable-next=all
   def load_params(self) -> Params:
     # We want to fix this: load from files
@@ -559,6 +579,8 @@ def load_params(self) -> Params:
     if self.env.checkpoint_path:
       if self.env.checkpoint_format == "safetensors":
         return self._load_from_safetensors(self.env.checkpoint_path)
+      elif self.env.checkpoint_format == "state_dict":
+        return self._load_from_state_dict(self.env.checkpoint_path)
     else:
       jax_weights = self._make_state_dict_jax(self.pt_model.state_dict())
       jax_weights = {
@@ -643,7 +665,7 @@ def create_pytorch_engine(
 ) -> PyTorchEngine:
   """Returns: The pytorch engine."""
 
-  supported_models = ["llama-2", "llama-3"]
+  supported_models = ["llama-2", "llama-3", "gemma"]
   if model_name not in supported_models:
     raise NotImplementedError(
         f"Model name should be one of{','.join(supported_models)}"
@@ -664,10 +686,9 @@ def create_pytorch_engine(
   elif ".safetensors" in ckpt_path:
     checkpoint_format = "safetensors"
     checkpoint_path = ckpt_path
-  elif ".pth" in ckpt_path:
-    raise NotImplementedError(
-        "Loading from Pytorch raw checkpoint is not supported!"
-    )
+  elif ".pth" in ckpt_path or ".ckpt" in ckpt_path:
+    checkpoint_format = "state_dict"
+    checkpoint_path = ckpt_path
   else:
     path = epath.Path(ckpt_path) if ckpt_path and ckpt_path is not None else ""
     if not path.exists():
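The new `state_dict` checkpoint format boils down to a per-tensor torch-to-JAX hop. A minimal sketch of that idea follows, leaving out the sharding (`self.env.sharding_by_name`) and shape assertions that `_load_from_state_dict` applies; it is an illustration, not the engine's code path.

```python
# Minimal sketch: load a raw PyTorch checkpoint (a flat state_dict) and turn
# each tensor into a JAX array. Sharding and shape checks are omitted.
import jax
import torch
import torch_xla2.tensor


def load_state_dict_as_jax(path):
  state_dict = torch.load(path, map_location=torch.device("cpu"))
  return {
      key: jax.device_put(torch_xla2.tensor.t2j(tensor))
      for key, tensor in state_dict.items()
  }
```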
diff --git a/jetstream_pt/layers.py b/jetstream_pt/layers.py
index e0b32df9..a58fcf6b 100644
--- a/jetstream_pt/layers.py
+++ b/jetstream_pt/layers.py
@@ -45,7 +45,7 @@ def forward(self, input):
 
 class WeightOnlyInt8Linear(torch.nn.Module):
 
-  def __init__(self, in_features, out_features, bias, device):
+  def __init__(self, in_features, out_features, bias=None, device=None):
     super().__init__()
     self.in_features = in_features
     self.out_features = out_features
diff --git a/jetstream_pt/third_party/gemma/model.py b/jetstream_pt/third_party/gemma/model.py
index 0d03227f..0cfc9b15 100644
--- a/jetstream_pt/third_party/gemma/model.py
+++ b/jetstream_pt/third_party/gemma/model.py
@@ -76,9 +76,15 @@ def __init__(
         if env.enable_weight_quantization
         else torch.nn.Linear
     )
-    self.gate_proj = Linear(hidden_size, intermediate_size, device)
-    self.up_proj = Linear(hidden_size, intermediate_size, device)
-    self.down_proj = Linear(intermediate_size, hidden_size, device)
+    self.gate_proj = Linear(
+        hidden_size, intermediate_size, bias=False, device=device
+    )
+    self.up_proj = Linear(
+        hidden_size, intermediate_size, bias=False, device=device
+    )
+    self.down_proj = Linear(
+        intermediate_size, hidden_size, bias=False, device=device
+    )
 
   def forward(self, x):
     gate = self.gate_proj(x)
diff --git a/run_interactive.py b/run_interactive.py
index 0dd16bb9..d0d3b21f 100644
--- a/run_interactive.py
+++ b/run_interactive.py
@@ -72,11 +72,6 @@
 _MAX_CACHE_LENGTH = flags.DEFINE_integer(
     "max_cache_length", 1024, "kv_cache_quantize"
 )
-_MODEL_NAME = flags.DEFINE_string(
-    "model",
-    "llama-2",
-    "name of the model. Supported options are llama-2 and llama-3",
-)
 _SHARDING_CONFIG = flags.DEFINE_string(
     "sharding_config", "", "config file for sharding"
 )
@@ -98,7 +93,6 @@ def create_engine():
       param_size=_SIZE.value,
      context_length=_CONTEXT_LENGTH.value,
       batch_size=_BATCH_SIZE.value,
-      model_name=_MODEL_NAME.value,
       quantize_weights=_QUANTIZE_WEIGHTS.value,
       quantize_kv=_QUANTIZE_KV_CACHE.value,
       max_cache_length=_MAX_CACHE_LENGTH.value,
diff --git a/run_server.py b/run_server.py
index c3603fba..161af9bd 100644
--- a/run_server.py
+++ b/run_server.py
@@ -89,9 +89,6 @@
 _SHARDING_CONFIG = flags.DEFINE_string(
     "sharding_config", "", "config file for sharding"
 )
-_MODEL_NAME = flags.DEFINE_string(
-    "model_name", "llama-2", "model name, defaults to llama-2"
-)
 
 
 # pylint: disable-next=all
@@ -119,7 +116,6 @@ def main(argv: Sequence[str]):
       quantize_kv=_QUANTIZE_KV_CACHE.value,
       max_cache_length=_MAX_CACHE_LENGTH.value,
       sharding_config=sharding_config_path,
-      model_name=_MODEL_NAME.value,
   )
   server_config = ServerConfig(
       interleaved_slices=(_PLATFORM.value,),
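The `layers.py` and `gemma/model.py` changes above go together: `GemmaMLP` picks its `Linear` class at runtime, so `WeightOnlyInt8Linear` and `torch.nn.Linear` need call-compatible constructors, which is why `bias` and `device` become keyword arguments with defaults. A small sketch of that selection pattern; `make_linear` and the sizes are illustrative, not code from this PR.

```python
# Illustration of the runtime Linear selection that motivates the new
# keyword-argument defaults. make_linear is a hypothetical helper.
import torch

from jetstream_pt.layers import WeightOnlyInt8Linear


def make_linear(in_features, out_features, quantized, device=None):
  # Both classes now accept the same keyword arguments, so the call site
  # does not need to know which implementation it got.
  linear_cls = WeightOnlyInt8Linear if quantized else torch.nn.Linear
  return linear_cls(in_features, out_features, bias=False, device=device)


# e.g. a Gemma 7B-sized MLP projection (hidden 3072 -> intermediate 24576)
gate_proj = make_linear(3072, 24576, quantized=False)
```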