From 5d7b97092eec663b605494c3a440a6b71b84a84a Mon Sep 17 00:00:00 2001
From: Siyuan Liu
Date: Wed, 8 May 2024 18:07:44 +0000
Subject: [PATCH 1/4] add convert hf gemma weights

---
 convert_checkpoints.py                  | 53 ++++++++++++++++++++++++-
 jetstream_pt/engine.py                  | 26 +++++++++++-
 jetstream_pt/layers.py                  |  2 +-
 jetstream_pt/third_party/gemma/model.py |  6 +--
 run_interactive.py                      |  6 ---
 5 files changed, 81 insertions(+), 12 deletions(-)

diff --git a/convert_checkpoints.py b/convert_checkpoints.py
index 1a29cacc..031e9e0e 100644
--- a/convert_checkpoints.py
+++ b/convert_checkpoints.py
@@ -72,6 +72,11 @@
 _QUANTIZE = flags.DEFINE_bool(
     "quantize", False, "When set to true, produces quantized weights"
 )
+_MODEL_TYPE = flags.DEFINE_string(
+    "model_type",
+    "llama",
+    "Type of the model."
+)
 
 # ParallelEmbedding is col partitioned across the shards.
 # ColumnParallelLinear is row partitioned across shards due to transpose.
@@ -403,11 +408,57 @@ def merge_weights(
   print(f"Export outputs takes {end - start} seconds")
 
 
+def convert_hf_gemma_weights(input_ckpt_dir: epath.Path,
+                             output_ckpt_dir: epath.Path):
+  """Convert gemma weights from Huggingface to be compatible with JetStream
+  1. Map attention weights to new names.
+  2. Split qkv fusion.
+  """
+  ckpt_file = list(input_ckpt_dir.glob("*.ckpt"))
+  assert len(ckpt_file) == 1
+  ckpt_file = ckpt_file[0]
+  state_dict = torch.load(ckpt_file, map_location=torch.device("cpu"))['model_state_dict']
+  model_config = json.loads((input_ckpt_dir / "config.json").read_text())
+  for key in list(state_dict.keys()):
+    prefix_to_remove = 'model.'
+    new_key = key
+    if key.startswith(prefix_to_remove):
+      new_key = new_key.removeprefix(prefix_to_remove)
+    if 'qkv_proj' in key:
+      q_dim = model_config['num_attention_heads'] * model_config['head_dim']
+      kv_dim = model_config['num_key_value_heads'] * model_config['head_dim']
+      qkv = state_dict.pop(key)
+      q, k, v = qkv.split(
+          [
+              q_dim,
+              kv_dim,
+              kv_dim,
+          ],
+          dim=0,
+      )
+      state_dict[new_key.replace('qkv_proj', 'wq')] = q
+      state_dict[new_key.replace('qkv_proj', 'wk')] = k
+      state_dict[new_key.replace('qkv_proj', 'wv')] = v
+      continue
+    if 'o_proj' in key:
+      new_key = new_key.replace('o_proj', 'wo')
+    if new_key != key:
+      state_dict[new_key] = state_dict.pop(key)
+
+  ckpt_basename = os.path.basename(ckpt_file)
+  output_ckpt_dir.mkdir(parents=True, exist_ok=True)
+  torch.save({'model_state_dict':state_dict},
+             os.fspath(output_ckpt_dir / ckpt_basename))
+  (output_ckpt_dir / "config.json").write_text(json.dumps(model_config))
+
 def main(argv: Sequence[str]) -> None:
   """convert checkpoint main function"""
   if len(argv) > 1:
     raise app.UsageError("Too many command-line arguments.")
-  merge_weights(
+  if "gemma" in _MODEL_TYPE.value:
+    convert_hf_gemma_weights(_INPUT_CHECKPOINT_DIR.value, _OUTPUT_CHECKPOINT_DIR.value)
+  else:
+    merge_weights(
       _INPUT_CHECKPOINT_DIR.value,
       _OUTPUT_CHECKPOINT_DIR.value,
       _MINIMIZE_MEMORY_FOOTPRINT.value,
diff --git a/jetstream_pt/engine.py b/jetstream_pt/engine.py
index ea77342e..3e1d7d9b 100644
--- a/jetstream_pt/engine.py
+++ b/jetstream_pt/engine.py
@@ -552,6 +552,25 @@ def _load_from_safetensors(self, path):
 
     return weights
 
+  def _load_from_state_dict(self, path):
+
+    state_dict = torch.load(path, map_location=torch.device("cpu"))['model_state_dict']
+    weights = {}
+    for key, model_weights in self.pt_model.state_dict().items():
+      assert key in state_dict , f"key: {key} not found"
+      arr = jax.device_put(torch_xla2.tensor.t2j(state_dict[key]), self.env.sharding_by_name(key))
+      assert tuple(model_weights.shape) == tuple(
+          arr.shape
+      ), f"key: {key} error: {model_weights.shape} != {arr.shape}"
+      weights[key] = arr
+
+    for k, v in weights.items():
+      if k.startswith("layers") and not k.startswith("layers.0"):
+        continue
+      print(f"Name: {k}, shape: {v.shape} x {v.dtype}")
+
+    return weights
+
   # pylint: disable-next=all
   def load_params(self) -> Params:
     # We want to fix this: load from files
@@ -559,6 +578,8 @@
     if self.env.checkpoint_path:
       if self.env.checkpoint_format == "safetensors":
         return self._load_from_safetensors(self.env.checkpoint_path)
+      elif self.env.checkpoint_format == "state_dict":
+        return self._load_from_state_dict(self.env.checkpoint_path)
     else:
       jax_weights = self._make_state_dict_jax(self.pt_model.state_dict())
       jax_weights = {
@@ -643,7 +664,7 @@ def create_pytorch_engine(
 ) -> PyTorchEngine:
   """Returns: The pytorch engine."""
 
-  supported_models = ["llama-2", "llama-3"]
+  supported_models = ["llama-2", "llama-3", "gemma"]
   if model_name not in supported_models:
     raise NotImplementedError(
         f"Model name should be one of{','.join(supported_models)}"
@@ -668,6 +689,9 @@
     raise NotImplementedError(
         "Loading from Pytorch raw checkpoint is not supported!"
     )
+  elif ".ckpt" in ckpt_path:
+    checkpoint_format = "state_dict"
+    checkpoint_path = ckpt_path
   else:
     path = epath.Path(ckpt_path) if ckpt_path and ckpt_path is not None else ""
     if not path.exists():
diff --git a/jetstream_pt/layers.py b/jetstream_pt/layers.py
index e0b32df9..a58fcf6b 100644
--- a/jetstream_pt/layers.py
+++ b/jetstream_pt/layers.py
@@ -45,7 +45,7 @@ def forward(self, input):
 
 class WeightOnlyInt8Linear(torch.nn.Module):
 
-  def __init__(self, in_features, out_features, bias, device):
+  def __init__(self, in_features, out_features, bias=None, device=None):
     super().__init__()
     self.in_features = in_features
     self.out_features = out_features
diff --git a/jetstream_pt/third_party/gemma/model.py b/jetstream_pt/third_party/gemma/model.py
index 0d03227f..a415988a 100644
--- a/jetstream_pt/third_party/gemma/model.py
+++ b/jetstream_pt/third_party/gemma/model.py
@@ -76,9 +76,9 @@ def __init__(
         if env.enable_weight_quantization
         else torch.nn.Linear
     )
-    self.gate_proj = Linear(hidden_size, intermediate_size, device)
-    self.up_proj = Linear(hidden_size, intermediate_size, device)
-    self.down_proj = Linear(intermediate_size, hidden_size, device)
+    self.gate_proj = Linear(hidden_size, intermediate_size, bias=False, device=device)
+    self.up_proj = Linear(hidden_size, intermediate_size, bias=False, device=device)
+    self.down_proj = Linear(intermediate_size, hidden_size, bias=False, device=device)
 
   def forward(self, x):
     gate = self.gate_proj(x)
diff --git a/run_interactive.py b/run_interactive.py
index 0dd16bb9..d0d3b21f 100644
--- a/run_interactive.py
+++ b/run_interactive.py
@@ -72,11 +72,6 @@
 _MAX_CACHE_LENGTH = flags.DEFINE_integer(
     "max_cache_length", 1024, "kv_cache_quantize"
 )
-_MODEL_NAME = flags.DEFINE_string(
-    "model",
-    "llama-2",
-    "name of the model. Supported options are llama-2 and llama-3",
-)
 _SHARDING_CONFIG = flags.DEFINE_string(
     "sharding_config", "", "config file for sharding"
 )
@@ -98,7 +93,6 @@ def create_engine():
       param_size=_SIZE.value,
       context_length=_CONTEXT_LENGTH.value,
       batch_size=_BATCH_SIZE.value,
-      model_name=_MODEL_NAME.value,
       quantize_weights=_QUANTIZE_WEIGHTS.value,
       quantize_kv=_QUANTIZE_KV_CACHE.value,
       max_cache_length=_MAX_CACHE_LENGTH.value,

From e7037d8e1c03398ac42cd2ce37876882187445c0 Mon Sep 17 00:00:00 2001
From: Siyuan Liu
Date: Wed, 8 May 2024 22:54:35 +0000
Subject: [PATCH 2/4] format

---
 convert_checkpoints.py                  | 56 +++++++++++++------------
 jetstream_pt/engine.py                  | 12 ++++--
 jetstream_pt/third_party/gemma/model.py | 12 ++++--
 3 files changed, 47 insertions(+), 33 deletions(-)

diff --git a/convert_checkpoints.py b/convert_checkpoints.py
index 031e9e0e..d12941a6 100644
--- a/convert_checkpoints.py
+++ b/convert_checkpoints.py
@@ -72,11 +72,7 @@
 _QUANTIZE = flags.DEFINE_bool(
     "quantize", False, "When set to true, produces quantized weights"
 )
-_MODEL_TYPE = flags.DEFINE_string(
-    "model_type",
-    "llama",
-    "Type of the model."
-)
+_MODEL_TYPE = flags.DEFINE_string("model_type", "llama", "Type of the model.")
 
 # ParallelEmbedding is col partitioned across the shards.
 # ColumnParallelLinear is row partitioned across shards due to transpose.
@@ -408,8 +404,9 @@ def merge_weights(
   print(f"Export outputs takes {end - start} seconds")
 
 
-def convert_hf_gemma_weights(input_ckpt_dir: epath.Path,
-                             output_ckpt_dir: epath.Path):
+def convert_hf_gemma_weights(
+    input_ckpt_dir: epath.Path, output_ckpt_dir: epath.Path
+):
   """Convert gemma weights from Huggingface to be compatible with JetStream
   1. Map attention weights to new names.
   2. Split qkv fusion.
   """
   ckpt_file = list(input_ckpt_dir.glob("*.ckpt"))
   assert len(ckpt_file) == 1
   ckpt_file = ckpt_file[0]
-  state_dict = torch.load(ckpt_file, map_location=torch.device("cpu"))['model_state_dict']
+  state_dict = torch.load(ckpt_file, map_location=torch.device("cpu"))[
+      "model_state_dict"
+  ]
   model_config = json.loads((input_ckpt_dir / "config.json").read_text())
   for key in list(state_dict.keys()):
-    prefix_to_remove = 'model.'
+    prefix_to_remove = "model."
     new_key = key
     if key.startswith(prefix_to_remove):
       new_key = new_key.removeprefix(prefix_to_remove)
-    if 'qkv_proj' in key:
-      q_dim = model_config['num_attention_heads'] * model_config['head_dim']
-      kv_dim = model_config['num_key_value_heads'] * model_config['head_dim']
+    if "qkv_proj" in key:
+      q_dim = model_config["num_attention_heads"] * model_config["head_dim"]
+      kv_dim = model_config["num_key_value_heads"] * model_config["head_dim"]
       qkv = state_dict.pop(key)
       q, k, v = qkv.split(
           [
@@ -436,34 +435,39 @@ def convert_hf_gemma_weights(
           ],
           dim=0,
       )
-      state_dict[new_key.replace('qkv_proj', 'wq')] = q
-      state_dict[new_key.replace('qkv_proj', 'wk')] = k
-      state_dict[new_key.replace('qkv_proj', 'wv')] = v
+      state_dict[new_key.replace("qkv_proj", "wq")] = q
+      state_dict[new_key.replace("qkv_proj", "wk")] = k
+      state_dict[new_key.replace("qkv_proj", "wv")] = v
       continue
-    if 'o_proj' in key:
-      new_key = new_key.replace('o_proj', 'wo')
+    if "o_proj" in key:
+      new_key = new_key.replace("o_proj", "wo")
     if new_key != key:
       state_dict[new_key] = state_dict.pop(key)
-
+
   ckpt_basename = os.path.basename(ckpt_file)
   output_ckpt_dir.mkdir(parents=True, exist_ok=True)
-  torch.save({'model_state_dict':state_dict},
-             os.fspath(output_ckpt_dir / ckpt_basename))
+  torch.save(
+      {"model_state_dict": state_dict},
+      os.fspath(output_ckpt_dir / ckpt_basename),
+  )
   (output_ckpt_dir / "config.json").write_text(json.dumps(model_config))
 
+
 def main(argv: Sequence[str]) -> None:
   """convert checkpoint main function"""
   if len(argv) > 1:
     raise app.UsageError("Too many command-line arguments.")
   if "gemma" in _MODEL_TYPE.value:
-    convert_hf_gemma_weights(_INPUT_CHECKPOINT_DIR.value, _OUTPUT_CHECKPOINT_DIR.value)
+    convert_hf_gemma_weights(
+        _INPUT_CHECKPOINT_DIR.value, _OUTPUT_CHECKPOINT_DIR.value
+    )
   else:
     merge_weights(
-      _INPUT_CHECKPOINT_DIR.value,
-      _OUTPUT_CHECKPOINT_DIR.value,
-      _MINIMIZE_MEMORY_FOOTPRINT.value,
-      _ENABLE_FLOAT32.value,
-    )
+        _INPUT_CHECKPOINT_DIR.value,
+        _OUTPUT_CHECKPOINT_DIR.value,
+        _MINIMIZE_MEMORY_FOOTPRINT.value,
+        _ENABLE_FLOAT32.value,
+    )
 
 
 if __name__ == "__main__":
diff --git a/jetstream_pt/engine.py b/jetstream_pt/engine.py
index 3e1d7d9b..d627a19c 100644
--- a/jetstream_pt/engine.py
+++ b/jetstream_pt/engine.py
@@ -553,12 +553,16 @@ def _load_from_safetensors(self, path):
     return weights
 
   def _load_from_state_dict(self, path):
-
-    state_dict = torch.load(path, map_location=torch.device("cpu"))['model_state_dict']
+
+    state_dict = torch.load(path, map_location=torch.device("cpu"))[
+        "model_state_dict"
+    ]
     weights = {}
     for key, model_weights in self.pt_model.state_dict().items():
-      assert key in state_dict , f"key: {key} not found"
-      arr = jax.device_put(torch_xla2.tensor.t2j(state_dict[key]), self.env.sharding_by_name(key))
+      assert key in state_dict, f"key: {key} not found"
+      arr = jax.device_put(
+          torch_xla2.tensor.t2j(state_dict[key]), self.env.sharding_by_name(key)
+      )
       assert tuple(model_weights.shape) == tuple(
           arr.shape
       ), f"key: {key} error: {model_weights.shape} != {arr.shape}"
diff --git a/jetstream_pt/third_party/gemma/model.py b/jetstream_pt/third_party/gemma/model.py
index a415988a..0cfc9b15 100644
--- a/jetstream_pt/third_party/gemma/model.py
+++ b/jetstream_pt/third_party/gemma/model.py
@@ -76,9 +76,15 @@ def __init__(
         if env.enable_weight_quantization
         else torch.nn.Linear
     )
-    self.gate_proj = Linear(hidden_size, intermediate_size, bias=False, device=device)
-    self.up_proj = Linear(hidden_size, intermediate_size, bias=False, device=device)
-    self.down_proj = Linear(intermediate_size, hidden_size, bias=False, device=device)
+    self.gate_proj = Linear(
+        hidden_size, intermediate_size, bias=False, device=device
+    )
+    self.up_proj = Linear(
+        hidden_size, intermediate_size, bias=False, device=device
+    )
+    self.down_proj = Linear(
+        intermediate_size, hidden_size, bias=False, device=device
+    )
 
   def forward(self, x):
     gate = self.gate_proj(x)

From 7d4900d67db283e4400b6d60893f9a5a17fdfc58 Mon Sep 17 00:00:00 2001
From: Siyuan Liu
Date: Thu, 9 May 2024 02:29:05 +0000
Subject: [PATCH 3/4] update convert script

---
 README.md              | 31 +++++++++++++++++++++++++------
 convert_checkpoints.py | 21 +++++++++++----------
 jetstream_pt/engine.py | 11 ++---------
 run_server.py          |  4 ----
 4 files changed, 38 insertions(+), 29 deletions(-)

diff --git a/README.md b/README.md
index 9734901d..42952a80 100644
--- a/README.md
+++ b/README.md
@@ -43,20 +43,33 @@ NOTE: the above script will export PYTHONPATH, so sourcing will make it to take
 
 # Download and convert weights
 
-## Get official llama weights from meta-llama
+## LLaMA
+### Get official llama weights from meta-llama
 
 Following instructions here: https://github.com/meta-llama/llama#download
 
 After you have downloaded the weights, it will also download a `tokenizer.model` file that is
 the tokenizer that we will use.
 
+## Gemma
+### Get Gemma Checkpoint from HuggingFace
+
+Download the Gemma PyTorch checkpoint using `huggingface-cli`. The Gemma tokenizer is included in the checkpoint.
+
+```bash
+huggingface-cli download google/gemma-7b-pytorch --local-dir $input_ckpt_dir
+```
+
+You need to manually edit the `config.json` in the checkpoint folder to make it valid JSON: replace `'` with `"` and remove the trailing `,` after the last item in the JSON object.
+
 ## Run weight safetensor convert
 
 ```bash
 export input_ckpt_dir=Original llama weights directory
 export output_ckpt_dir=The output directory
 export quantize=True #whether to quantize
-python -m convert_checkpoints --input_checkpoint_dir=$input_ckpt_dir --output_checkpoint_dir=$output_ckpt_dir --quantize=$quantize
+export model_name="llama-2" # or "gemma"
+python -m convert_checkpoints --model_name=$model_name --input_checkpoint_dir=$input_ckpt_dir --output_checkpoint_dir=$output_ckpt_dir --quantize=$quantize
 ```
 
 # Testing run in interactive mode
 
 Set tokenizer path
 ```bash
-export tokenizer_path=tokenizer model file path from meta-llama
+export tokenizer_path=tokenizer model file path
 ```
 
 ## Llama 7b
 ```bash
-python run_interactive.py --size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path
+python run_interactive.py --size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/$model_name.yaml
 ```
 
 ## Llama 13b
 ```bash
-python run_interactive.py --size=13b --batch_size=64 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path
+python run_interactive.py --size=13b --batch_size=64 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/$model_name.yaml
 ```
 
+
+## Gemma 7b
+```bash
+python run_interactive.py --model_name=$model_name --size=7b --batch_size=64 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/$model_name.yaml
+```
 
 # Run the server
 NOTE: the `--platform=tpu=8` need to specify number of tpu devices (which is 4 for v4-8 and 8 for v5light-8`)
 
 ```bash
-python run_server.py --param_size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --platform=tpu=8
+python run_server.py --param_size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --platform=tpu=8 --model=$model_name
 ```
 
 Now you can fire gRPC to it
diff --git a/convert_checkpoints.py b/convert_checkpoints.py
index d12941a6..bda847b1 100644
--- a/convert_checkpoints.py
+++ b/convert_checkpoints.py
@@ -72,7 +72,7 @@
 _QUANTIZE = flags.DEFINE_bool(
     "quantize", False, "When set to true, produces quantized weights"
 )
-_MODEL_TYPE = flags.DEFINE_string("model_type", "llama", "Type of the model.")
+_MODEL_TYPE = flags.DEFINE_string("model_name", "llama", "Type of the model.")
 
 # ParallelEmbedding is col partitioned across the shards.
 # ColumnParallelLinear is row partitioned across shards due to transpose.
@@ -412,13 +412,21 @@
   2. Split qkv fusion.
   """
   ckpt_file = list(input_ckpt_dir.glob("*.ckpt"))
-  assert len(ckpt_file) == 1
+  assert len(ckpt_file) == 1, "only expect 1 ckpt file for Gemma model."
   ckpt_file = ckpt_file[0]
   state_dict = torch.load(ckpt_file, map_location=torch.device("cpu"))[
       "model_state_dict"
   ]
   model_config = json.loads((input_ckpt_dir / "config.json").read_text())
   for key in list(state_dict.keys()):
+    if state_dict[key].dtype.is_complex and _OUTPUT_SAFETENSORS.value:
+      assert (
+          key == "freqs_cis"
+      ), "Only expect key 'freqs_cis' in the state_dict to have complex dtype."
+      # Remove "freqs_cis" since it has a complex dtype, which safetensors doesn't support.
+      # "freqs_cis" will be reconstructed when the checkpoint is loaded by the inference engine.
+      state_dict.pop(key)
+      continue
     prefix_to_remove = "model."
     new_key = key
     if key.startswith(prefix_to_remove):
@@ -443,14 +451,7 @@
       new_key = new_key.replace("o_proj", "wo")
     if new_key != key:
       state_dict[new_key] = state_dict.pop(key)
-
-  ckpt_basename = os.path.basename(ckpt_file)
-  output_ckpt_dir.mkdir(parents=True, exist_ok=True)
-  torch.save(
-      {"model_state_dict": state_dict},
-      os.fspath(output_ckpt_dir / ckpt_basename),
-  )
-  (output_ckpt_dir / "config.json").write_text(json.dumps(model_config))
+  _export_to_local(output_ckpt_dir, model_config, state_dict)
 
 
 def main(argv: Sequence[str]) -> None:
diff --git a/jetstream_pt/engine.py b/jetstream_pt/engine.py
index d627a19c..f95d8cbd 100644
--- a/jetstream_pt/engine.py
+++ b/jetstream_pt/engine.py
@@ -553,10 +553,7 @@ def _load_from_safetensors(self, path):
     return weights
 
   def _load_from_state_dict(self, path):
-
-    state_dict = torch.load(path, map_location=torch.device("cpu"))[
-        "model_state_dict"
-    ]
+    state_dict = torch.load(path, map_location=torch.device("cpu"))
     weights = {}
     for key, model_weights in self.pt_model.state_dict().items():
       assert key in state_dict, f"key: {key} not found"
@@ -689,11 +686,7 @@
   elif ".safetensors" in ckpt_path:
     checkpoint_format = "safetensors"
     checkpoint_path = ckpt_path
-  elif ".pth" in ckpt_path:
-    raise NotImplementedError(
-        "Loading from Pytorch raw checkpoint is not supported!"
-    )
-  elif ".ckpt" in ckpt_path:
+  elif ".pth" in ckpt_path or ".ckpt" in ckpt_path:
     checkpoint_format = "state_dict"
     checkpoint_path = ckpt_path
   else:
diff --git a/run_server.py b/run_server.py
index c3603fba..161af9bd 100644
--- a/run_server.py
+++ b/run_server.py
@@ -89,9 +89,6 @@
 _SHARDING_CONFIG = flags.DEFINE_string(
     "sharding_config", "", "config file for sharding"
 )
-_MODEL_NAME = flags.DEFINE_string(
-    "model_name", "llama-2", "model name, defaults to llama-2"
-)
 
 
 # pylint: disable-next=all
@@ -119,7 +116,6 @@ def main(argv: Sequence[str]):
       quantize_kv=_QUANTIZE_KV_CACHE.value,
       max_cache_length=_MAX_CACHE_LENGTH.value,
       sharding_config=sharding_config_path,
-      model_name=_MODEL_NAME.value,
   )
   server_config = ServerConfig(
       interleaved_slices=(_PLATFORM.value,),

From 0ba908d468be3fc4c9099a5cd579e75677354ae2 Mon Sep 17 00:00:00 2001
From: Siyuan Liu
Date: Thu, 9 May 2024 03:35:38 +0000
Subject: [PATCH 4/4] add sign agreement

---
 README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/README.md b/README.md
index 42952a80..3b14f6c0 100644
--- a/README.md
+++ b/README.md
@@ -54,7 +54,7 @@
 ## Gemma
 ### Get Gemma Checkpoint from HuggingFace
 
-Download the Gemma PyTorch checkpoint using `huggingface-cli`. The Gemma tokenizer is included in the checkpoint.
+Please sign the agreement on the Hugging Face website to access the Gemma checkpoints. Download the Gemma PyTorch checkpoint using `huggingface-cli`. The Gemma tokenizer is included in the checkpoint.
 
 ```bash
 huggingface-cli download google/gemma-7b-pytorch --local-dir $input_ckpt_dir
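
A note on the qkv split in `convert_hf_gemma_weights` (PATCH 1, reformatted in PATCH 2): the HuggingFace Gemma checkpoint stores the attention projections as one fused `qkv_proj` weight with Q, K, and V stacked row-wise, so the converter slices it back into `wq`, `wk`, and `wv` along dim 0, using head counts read from `config.json`. Below is a minimal, self-contained sketch of that split. The dimensions are illustrative toy values, not the real Gemma shapes:

```python
import torch

# Toy values for illustration only; the converter reads the real values from
# config.json ("num_attention_heads", "num_key_value_heads", "head_dim").
num_attention_heads = 4
num_key_value_heads = 1  # fewer KV heads than Q heads, MQA/GQA style
head_dim = 8
hidden_size = 32

q_dim = num_attention_heads * head_dim   # rows belonging to wq
kv_dim = num_key_value_heads * head_dim  # rows belonging to wk and to wv

# A fused qkv_proj weight stacks Q, K, V row-wise: (q_dim + 2 * kv_dim, hidden_size).
qkv = torch.randn(q_dim + 2 * kv_dim, hidden_size)

# The same slicing convert_hf_gemma_weights performs: split along dim 0.
q, k, v = qkv.split([q_dim, kv_dim, kv_dim], dim=0)

assert q.shape == (q_dim, hidden_size)
assert k.shape == (kv_dim, hidden_size)
assert v.shape == (kv_dim, hidden_size)
```

Because the split sizes come straight from `config.json`, the same code path works for both the multi-head (7B) and the fewer-KV-head Gemma variants without special-casing.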
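The README change in PATCH 3 asks the user to hand-edit the downloaded `config.json` into valid JSON. A small hypothetical helper can apply the same normalization automatically; it assumes the only defects are single quotes and a trailing comma, exactly as the README describes, and would need adjusting for anything else:

```python
import json
import pathlib
import re

# Hypothetical path; in the workflow above this is $input_ckpt_dir/config.json.
path = pathlib.Path("config.json")

text = path.read_text()
text = text.replace("'", '"')               # JSON requires double quotes
text = re.sub(r",(\s*[}\]])", r"\1", text)  # drop a trailing comma before } or ]
json.loads(text)                            # raises if the file is still invalid
path.write_text(text)
```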