From ec5d84261796065249150de1a377e08ce6e7f3b0 Mon Sep 17 00:00:00 2001 From: Bhavya Bahl Date: Fri, 24 May 2024 01:38:11 +0000 Subject: [PATCH 1/6] Add support for Llama3-70b --- README.md | 11 ++++++-- convert_checkpoints.py | 8 ++++-- .../{llama.yaml => llama-2.yaml} | 0 default_shardings/llama-3.yaml | 28 +++++++++++++++++++ jetstream_pt/third_party/llama/model_args.py | 13 +++++++++ .../third_party/llama/model_exportable.py | 10 +++++-- 6 files changed, 61 insertions(+), 9 deletions(-) rename default_shardings/{llama.yaml => llama-2.yaml} (100%) create mode 100644 default_shardings/llama-3.yaml diff --git a/README.md b/README.md index 1c72eca3..21a13377 100644 --- a/README.md +++ b/README.md @@ -84,17 +84,22 @@ export tokenizer_path=tokenizer model file path ## Llama-2 7b ```bash -python run_interactive.py --size=7b --model_name=$model_name --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_type=$quantize_type --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/llama.yaml +python run_interactive.py --size=7b --model_name=$model_name --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_type=$quantize_type --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/$model_name.yaml ``` ## Llama-2 13b ```bash -python run_interactive.py --size=13b --model_name=$model_name --batch_size=64 --max_cache_length=2048 --quantize_weights=$quantize --quantize_type=$quantize_type --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/llama.yaml +python run_interactive.py --size=13b --model_name=$model_name --batch_size=64 --max_cache_length=2048 --quantize_weights=$quantize --quantize_type=$quantize_type --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/$model_name.yaml ``` ## Llama-3 8b ```bash -python run_interactive.py --size=8b --model_name=$model_name --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_type=$quantize_type --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/llama.yaml +python run_interactive.py --size=8b --model_name=$model_name --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_type=$quantize_type --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/$model_name.yaml +``` + +## Llama-3 70b +```bash +python run_interactive.py --size=70b --model_name=$model_name --batch_size=8 --max_cache_length=2048 --quantize_weights=$quantize --quantize_type=$quantize_type --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/$model_name.yaml ``` ## Gemma 7b diff --git a/convert_checkpoints.py b/convert_checkpoints.py index 96501ec4..99691302 100644 --- a/convert_checkpoints.py +++ b/convert_checkpoints.py @@ -179,17 +179,19 @@ def _merge_llama_weights( f"{len(tensors)} shards (shape = {tensors[0].shape}) for {key})" ) state_dict_for_key = {} - for pattern, kind in llama_model.get_weight_sharding_type.items(): + for pattern, kind in llama_model.Transformer.get_weight_sharding_type( + model_name=FLAGS.model_name + ).items(): if 
not key.endswith(pattern): continue with torch.no_grad(): if kind in ("ParallelEmbedding", "RowParallelLinear"): state_dict_for_key[key] = torch.cat(tensors, 1) - elif kind == "ColumnParallelLinear": + elif kind in ("ColumnParallelLinear", "VocabParallelEmbedding"): state_dict_for_key[key] = torch.cat(tensors, 0) else: if not all( - torch.allclose(tensors[0], tensor, atol=1e-6) + torch.allclose(tensors[0], tensor, atol=1e-2) for tensor in tensors[1:] ): raise ValueError( diff --git a/default_shardings/llama.yaml b/default_shardings/llama-2.yaml similarity index 100% rename from default_shardings/llama.yaml rename to default_shardings/llama-2.yaml diff --git a/default_shardings/llama-3.yaml b/default_shardings/llama-3.yaml new file mode 100644 index 00000000..38ede225 --- /dev/null +++ b/default_shardings/llama-3.yaml @@ -0,0 +1,28 @@ + +# Sharding config for llama-3 +# Sharding should either be an int between 0 and rank - 1 +# signifying the axis to shard or -1 / null signifying replicated + + +freqs_cis : -1 # torch.complex64 (2048, 64) +tok_embeddings.weight : 0 # torch.float32 (vocab_size, 4096) +tok_embeddings.weight_scaler : 0 # torch.bfloat16 (4096,) +layers.*.attention.wo.weight : 1 # torch.int8 (4096, 4096) +layers.*.attention.wo.weight_scaler : 0 # torch.bfloat16 (4096,) +layers.*.attention.wq.weight : 0 # torch.int8 (4096, 4096) +layers.*.attention.wq.weight_scaler : 0 # torch.bfloat16 (4096,) +layers.*.attention.wk.weight : 0 # torch.int8 (4096, 4096) +layers.*.attention.wk.weight_scaler : 0 # torch.bfloat16 (4096,) +layers.*.attention.wv.weight : 0 # torch.int8 (4096, 4096) +layers.*.attention.wv.weight_scaler : 0 # torch.bfloat16 (4096,) +layers.*.feed_forward.w1.weight : 0 # torch.float32 (11008, 4096) +layers.*.feed_forward.w1.weight_scaler : 0 # torch.bfloat16 (4096,) +layers.*.feed_forward.w2.weight : 1 # torch.float32 (4096, 11008) +layers.*.feed_forward.w2.weight_scaler : 0 # torch.bfloat16 (11008,) +layers.*.feed_forward.w3.weight : 0 # torch.float32 (11008, 4096) +layers.*.feed_forward.w3.weight_scaler : 0 # torch.bfloat16 (4096,) +layers.*.attention_norm.weight : -1 # torch.float32 (4096,) +layers.*.ffn_norm.weight : -1 # torch.float32 (4096,) +norm.weight : -1 # torch.float32 (4096,) +output.weight : 0 # torch.float32 (vocab_size, 4096) +output.weight_scaler : 0 # torch.float32 (4096,) diff --git a/jetstream_pt/third_party/llama/model_args.py b/jetstream_pt/third_party/llama/model_args.py index bcebfe69..7956667d 100755 --- a/jetstream_pt/third_party/llama/model_args.py +++ b/jetstream_pt/third_party/llama/model_args.py @@ -90,6 +90,19 @@ def get_arg( "norm_eps": 1e-05, "rope_theta": 500000.0, } + elif model_name == "llama-3-70b": + data = { + "dim": 8192, + "ffn_dim_multiplier": 1.3, + "multiple_of": 4096, + "n_heads": 64, + "n_kv_heads": 8, + "n_layers": 80, + "norm_eps": 1e-05, + "vocab_size": 128256, + "rope_theta": 500000.0, + } + return ModelArgs( max_seq_len=seqlen, max_batch_size=batch_size, diff --git a/jetstream_pt/third_party/llama/model_exportable.py b/jetstream_pt/third_party/llama/model_exportable.py index 7c692b22..4a15846d 100644 --- a/jetstream_pt/third_party/llama/model_exportable.py +++ b/jetstream_pt/third_party/llama/model_exportable.py @@ -225,13 +225,12 @@ def get_quantized_embedding_weight_to_scaler_map(): } @staticmethod - def get_weight_sharding_type(): + def get_weight_sharding_type(model_name: str = "llama-3"): # ParallelEmbedding is col partitioned across the shards. 
# ColumnParallelLinear is row partitioned across shards due to transpose.
     # RowParallelLinear is col partitioned across shards due to transpose.
     # None is no partitioning and tensor should be identical across shards
-    return {
-        "tok_embeddings.weight": "ParallelEmbedding",
+    sharding_dict = {
         "rope.freqs": None,
         "attention.wq.weight": "ColumnParallelLinear",
         "attention.wk.weight": "ColumnParallelLinear",
@@ -245,3 +244,8 @@ def get_quantized_embedding_weight_to_scaler_map():
         "norm.weight": None,
         "output.weight": "ColumnParallelLinear",
     }
+    if model_name == "llama-2":
+      sharding_dict["tok_embeddings.weight"] = "ParallelEmbedding"
+    elif model_name == "llama-3":
+      sharding_dict["tok_embeddings.weight"] = "VocabParallelEmbedding"
+    return sharding_dict

From c64a822143d78e258ae34f10c2d59a16825f3423 Mon Sep 17 00:00:00 2001
From: Bhavya Bahl
Date: Fri, 24 May 2024 01:49:59 +0000
Subject: [PATCH 2/6] Fix unit tests

---
 .github/workflows/unit_tests.yaml | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/.github/workflows/unit_tests.yaml b/.github/workflows/unit_tests.yaml
index 50b805b1..f2f3903f 100644
--- a/.github/workflows/unit_tests.yaml
+++ b/.github/workflows/unit_tests.yaml
@@ -100,7 +100,7 @@ jobs:
         source install_everything.sh
     - name: Run interactive (bf16)
       run: |
-        JAX_PLATFORMS=cpu python run_interactive.py --size=tiny --batch_size=1 --max_cache_length=2048 --tokenizer_path=jetstream_pt/third_party/llama/tokenizer.model --model_name=llama-2 --sharding_config=default_shardings/llama.yaml --quantize_weights=0 --quantize_kv_cache=0
+        JAX_PLATFORMS=cpu python run_interactive.py --size=tiny --batch_size=1 --max_cache_length=2048 --tokenizer_path=jetstream_pt/third_party/llama/tokenizer.model --model_name=llama-2 --sharding_config=default_shardings/llama-2.yaml --quantize_weights=0 --quantize_kv_cache=0
     - name: Run interactive (int8)
       run: |
-        JAX_PLATFORMS=cpu python run_interactive.py --size=tiny --batch_size=1 --max_cache_length=2048 --tokenizer_path=jetstream_pt/third_party/llama/tokenizer.model --model_name=llama-2 --sharding_config=default_shardings/llama.yaml --quantize_weights=1 --quantize_type="int8_per_channel" --quantize_kv_cache=1
+        JAX_PLATFORMS=cpu python run_interactive.py --size=tiny --batch_size=1 --max_cache_length=2048 --tokenizer_path=jetstream_pt/third_party/llama/tokenizer.model --model_name=llama-2 --sharding_config=default_shardings/llama-2.yaml --quantize_weights=1 --quantize_type="int8_per_channel" --quantize_kv_cache=1

From 0eeaa5ebc938250a0164361d9437ab6da1bf25d4 Mon Sep 17 00:00:00 2001
From: Bhavya Bahl
Date: Sat, 25 May 2024 00:11:06 +0000
Subject: [PATCH 3/6] assert model_name is one of llama-2 or llama-3 for weight sharding

---
 jetstream_pt/third_party/llama/model_exportable.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/jetstream_pt/third_party/llama/model_exportable.py b/jetstream_pt/third_party/llama/model_exportable.py
index 4a15846d..3526aedf 100644
--- a/jetstream_pt/third_party/llama/model_exportable.py
+++ b/jetstream_pt/third_party/llama/model_exportable.py
@@ -225,11 +225,16 @@ def get_quantized_embedding_weight_to_scaler_map():
     }
 
   @staticmethod
-  def get_weight_sharding_type(model_name: str = "llama-3"):
+  def get_weight_sharding_type(model_name: str = ""):
     # ParallelEmbedding is col partitioned across the shards.
+    # VocabParallelEmbedding is row partitioned across the shards.
    # ColumnParallelLinear is row partitioned across shards due to transpose.
    # RowParallelLinear is col partitioned across shards due to transpose. 
# None is no partitioning and tensor should be identical across shards
+    expected_model_names = ('llama-2', 'llama-3')
+    assert (
+        model_name in expected_model_names
+    ), f"Expected model_name to be one of {expected_model_names}"
     sharding_dict = {
       "rope.freqs": None,
       "attention.wq.weight": "ColumnParallelLinear",

From c3368da490d940eda837e62c389fc8cdac17fac1 Mon Sep 17 00:00:00 2001
From: Bhavya Bahl
Date: Sat, 25 May 2024 00:17:06 +0000
Subject: [PATCH 4/6] Fix lint

---
 jetstream_pt/third_party/llama/model_exportable.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/jetstream_pt/third_party/llama/model_exportable.py b/jetstream_pt/third_party/llama/model_exportable.py
index 3526aedf..cf1198dd 100644
--- a/jetstream_pt/third_party/llama/model_exportable.py
+++ b/jetstream_pt/third_party/llama/model_exportable.py
@@ -231,9 +231,9 @@ def get_weight_sharding_type(model_name: str = ""):
     # ColumnParallelLinear is row partitioned across shards due to transpose.
     # RowParallelLinear is col partitioned across shards due to transpose.
     # None is no partitioning and tensor should be identical across shards
-    expected_model_names = ('llama-2', 'llama-3')
+    expected_model_names = ("llama-2", "llama-3")
     assert (
-        model_name in expected_model_names
+        model_name in expected_model_names
     ), f"Expected model_name to be one of {expected_model_names}"
     sharding_dict = {
       "rope.freqs": None,

From a40f0b0e779871fc694d104b5f6b53c64f52e8c3 Mon Sep 17 00:00:00 2001
From: Bhavya Bahl
Date: Mon, 10 Jun 2024 18:24:37 +0000
Subject: [PATCH 5/6] Revert separate shardings for llama-2 and llama-3

---
 README.md                        |  8 +++---
 default_shardings/llama-3.yaml   | 28 -------------------
 .../{llama-2.yaml => llama.yaml} |  0
 3 files changed, 4 insertions(+), 32 deletions(-)
 delete mode 100644 default_shardings/llama-3.yaml
 rename default_shardings/{llama-2.yaml => llama.yaml} (100%)

diff --git a/README.md b/README.md
index 21a13377..49047075 100644
--- a/README.md
+++ b/README.md
@@ -84,22 +84,22 @@ export tokenizer_path=tokenizer model file path

 ## Llama-2 7b
 ```bash
-python run_interactive.py --size=7b --model_name=$model_name --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_type=$quantize_type --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/$model_name.yaml
+python run_interactive.py --size=7b --model_name=$model_name --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_type=$quantize_type --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/llama.yaml
 ```

 ## Llama-2 13b
 ```bash
-python run_interactive.py --size=13b --model_name=$model_name --batch_size=64 --max_cache_length=2048 --quantize_weights=$quantize --quantize_type=$quantize_type --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/$model_name.yaml
+python run_interactive.py --size=13b --model_name=$model_name --batch_size=64 --max_cache_length=2048 --quantize_weights=$quantize --quantize_type=$quantize_type --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/llama.yaml
 ```

 ## Llama-3 8b
 ```bash
-python run_interactive.py --size=8b --model_name=$model_name --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_type=$quantize_type 
--quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/$model_name.yaml +python run_interactive.py --size=8b --model_name=$model_name --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_type=$quantize_type --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/llama.yaml ``` ## Llama-3 70b ```bash -python run_interactive.py --size=70b --model_name=$model_name --batch_size=8 --max_cache_length=2048 --quantize_weights=$quantize --quantize_type=$quantize_type --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/$model_name.yaml +python run_interactive.py --size=70b --model_name=$model_name --batch_size=8 --max_cache_length=2048 --quantize_weights=$quantize --quantize_type=$quantize_type --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/llama.yaml ``` ## Gemma 7b diff --git a/default_shardings/llama-3.yaml b/default_shardings/llama-3.yaml deleted file mode 100644 index 38ede225..00000000 --- a/default_shardings/llama-3.yaml +++ /dev/null @@ -1,28 +0,0 @@ - -# Sharding config for llama-3 -# Sharding should either be an int between 0 and rank - 1 -# signifying the axis to shard or -1 / null signifying replicated - - -freqs_cis : -1 # torch.complex64 (2048, 64) -tok_embeddings.weight : 0 # torch.float32 (vocab_size, 4096) -tok_embeddings.weight_scaler : 0 # torch.bfloat16 (4096,) -layers.*.attention.wo.weight : 1 # torch.int8 (4096, 4096) -layers.*.attention.wo.weight_scaler : 0 # torch.bfloat16 (4096,) -layers.*.attention.wq.weight : 0 # torch.int8 (4096, 4096) -layers.*.attention.wq.weight_scaler : 0 # torch.bfloat16 (4096,) -layers.*.attention.wk.weight : 0 # torch.int8 (4096, 4096) -layers.*.attention.wk.weight_scaler : 0 # torch.bfloat16 (4096,) -layers.*.attention.wv.weight : 0 # torch.int8 (4096, 4096) -layers.*.attention.wv.weight_scaler : 0 # torch.bfloat16 (4096,) -layers.*.feed_forward.w1.weight : 0 # torch.float32 (11008, 4096) -layers.*.feed_forward.w1.weight_scaler : 0 # torch.bfloat16 (4096,) -layers.*.feed_forward.w2.weight : 1 # torch.float32 (4096, 11008) -layers.*.feed_forward.w2.weight_scaler : 0 # torch.bfloat16 (11008,) -layers.*.feed_forward.w3.weight : 0 # torch.float32 (11008, 4096) -layers.*.feed_forward.w3.weight_scaler : 0 # torch.bfloat16 (4096,) -layers.*.attention_norm.weight : -1 # torch.float32 (4096,) -layers.*.ffn_norm.weight : -1 # torch.float32 (4096,) -norm.weight : -1 # torch.float32 (4096,) -output.weight : 0 # torch.float32 (vocab_size, 4096) -output.weight_scaler : 0 # torch.float32 (4096,) diff --git a/default_shardings/llama-2.yaml b/default_shardings/llama.yaml similarity index 100% rename from default_shardings/llama-2.yaml rename to default_shardings/llama.yaml From 2bfbb310272ef05d21202d3bb7d1b17991523f27 Mon Sep 17 00:00:00 2001 From: Bhavya Bahl Date: Mon, 10 Jun 2024 18:47:55 +0000 Subject: [PATCH 6/6] Fix lint --- .github/workflows/unit_tests.yaml | 4 ++-- convert_checkpoints.py | 8 +++----- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/.github/workflows/unit_tests.yaml b/.github/workflows/unit_tests.yaml index f2f3903f..50b805b1 100644 --- a/.github/workflows/unit_tests.yaml +++ b/.github/workflows/unit_tests.yaml @@ -100,7 +100,7 @@ jobs: source install_everything.sh - name: Run 
interactive (bf16) run: | - JAX_PLATFORMS=cpu python run_interactive.py --size=tiny --batch_size=1 --max_cache_length=2048 --tokenizer_path=jetstream_pt/third_party/llama/tokenizer.model --model_name=llama-2 --sharding_config=default_shardings/llama-2.yaml --quantize_weights=0 --quantize_kv_cache=0 + JAX_PLATFORMS=cpu python run_interactive.py --size=tiny --batch_size=1 --max_cache_length=2048 --tokenizer_path=jetstream_pt/third_party/llama/tokenizer.model --model_name=llama-2 --sharding_config=default_shardings/llama.yaml --quantize_weights=0 --quantize_kv_cache=0 - name: Run interactive (int8) run: | - JAX_PLATFORMS=cpu python run_interactive.py --size=tiny --batch_size=1 --max_cache_length=2048 --tokenizer_path=jetstream_pt/third_party/llama/tokenizer.model --model_name=llama-2 --sharding_config=default_shardings/llama-2.yaml --quantize_weights=1 --quantize_type="int8_per_channel" --quantize_kv_cache=1 + JAX_PLATFORMS=cpu python run_interactive.py --size=tiny --batch_size=1 --max_cache_length=2048 --tokenizer_path=jetstream_pt/third_party/llama/tokenizer.model --model_name=llama-2 --sharding_config=default_shardings/llama.yaml --quantize_weights=1 --quantize_type="int8_per_channel" --quantize_kv_cache=1 diff --git a/convert_checkpoints.py b/convert_checkpoints.py index d84c8ea3..1b3af726 100644 --- a/convert_checkpoints.py +++ b/convert_checkpoints.py @@ -187,11 +187,9 @@ def _merge_llama_weights( ) state_dict_for_key = {} - weight_sharding_type = ( - llama_model.Transformer.get_weight_sharding_type( - model_name=FLAGS.model_name - ).items() - ) + weight_sharding_type = llama_model.Transformer.get_weight_sharding_type( + model_name=FLAGS.model_name + ).items() for pattern, kind in weight_sharding_type: if not key.endswith(pattern): continue
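
The `convert_checkpoints.py` hunks above consume `Transformer.get_weight_sharding_type(model_name=...)` to decide how per-device checkpoint shards are recombined. Below is a minimal standalone sketch of that rule, assuming PyTorch; `merge_shards`, the example sharding map, and the tensor shapes are illustrative and are not part of this patch series — only the concatenation/replication logic mirrors the diffs above.

```python
# Standalone sketch (not from this patch series): how a per-model sharding map
# drives shard merging. merge_shards and the example map are illustrative.
import torch


def merge_shards(key, tensors, sharding_type):
  """Merge per-device shards of one weight according to its sharding kind."""
  for pattern, kind in sharding_type.items():
    if not key.endswith(pattern):
      continue
    if kind in ("ParallelEmbedding", "RowParallelLinear"):
      return torch.cat(tensors, dim=1)  # shards are column-partitioned
    if kind in ("ColumnParallelLinear", "VocabParallelEmbedding"):
      return torch.cat(tensors, dim=0)  # shards are row-partitioned
    # Replicated weight: every shard must hold (nearly) the same tensor.
    if not all(torch.allclose(tensors[0], t, atol=1e-2) for t in tensors[1:]):
      raise ValueError(f"Replicated weight {key} differs across shards")
    return tensors[0]
  raise KeyError(f"No sharding rule matches {key}")


# Example: a llama-3-style map shards the token embedding over the vocab axis,
# so two (vocab/2, dim) shards merge back into one (vocab, dim) table.
sharding = {"tok_embeddings.weight": "VocabParallelEmbedding", "norm.weight": None}
shards = [torch.randn(4, 8), torch.randn(4, 8)]
merged = merge_shards("tok_embeddings.weight", shards, sharding)
print(merged.shape)  # torch.Size([8, 8])
```

Column- and row-parallel kinds concatenate along opposite axes because the exported weights are transposed, while replicated weights are only sanity-checked for agreement across shards.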