
Commit 4535bdf

Add support for Llama3-70b (#101)
* Add support for Llama3-70b
* Fix unit tests
* assert model_name is one of llama-2 or llama-3 for weight sharding
* Fix lint
* Revert separate shardings for llama-2 and llama-3
* Fix lint
1 parent e07aee6 commit 4535bdf


4 files changed: +36, -8 lines changed


README.md

Lines changed: 5 additions & 0 deletions
@@ -98,6 +98,11 @@ python run_interactive.py --size=13b --model_name=$model_name --batch_size=64 --
 python run_interactive.py --size=8b --model_name=$model_name --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_type=$quantize_type --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/llama.yaml
 ```

+## Llama-3 70b
+```bash
+python run_interactive.py --size=70b --model_name=$model_name --batch_size=8 --max_cache_length=2048 --quantize_weights=$quantize --quantize_type=$quantize_type --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/llama.yaml
+```
+
 ## Gemma 7b
 ```bash
 python run_interactive.py --model_name=$model_name --size=7b --batch_size=64 --max_cache_length=2048 --quantize_weights=$quantize --quantize_type=$quantize_type --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/$model_name.yaml

convert_checkpoints.py

Lines changed: 6 additions & 5 deletions
@@ -186,20 +186,21 @@ def _merge_llama_weights(
         f"{len(tensors)} shards (shape = {tensors[0].shape}) for {key})"
     )
     state_dict_for_key = {}
-    weight_sharding_type = (
-        llama_model.Transformer.get_weight_sharding_type().items()
-    )
+
+    weight_sharding_type = llama_model.Transformer.get_weight_sharding_type(
+        model_name=FLAGS.model_name
+    ).items()
     for pattern, kind in weight_sharding_type:
       if not key.endswith(pattern):
         continue
       with torch.no_grad():
         if kind in ("ParallelEmbedding", "RowParallelLinear"):
           state_dict_for_key[key] = torch.cat(tensors, 1)
-        elif kind == "ColumnParallelLinear":
+        elif kind in ("ColumnParallelLinear", "VocabParallelEmbedding"):
          state_dict_for_key[key] = torch.cat(tensors, 0)
         else:
           if not all(
-              torch.allclose(tensors[0], tensor, atol=1e-6)
+              torch.allclose(tensors[0], tensor, atol=1e-2)
              for tensor in tensors[1:]
          ):
            raise ValueError(
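The merge rule above picks the concatenation axis from the sharding kind reported by `get_weight_sharding_type`. A minimal sketch of that rule on hypothetical toy shards (not the repository's checkpoint loader), using the same tolerance as above for replicated weights:

```python
import torch

def merge_shards(tensors, kind):
  """Merge checkpoint shards along the axis implied by the sharding kind."""
  if kind in ("ParallelEmbedding", "RowParallelLinear"):
    # Column-partitioned on disk: reassemble along dim 1.
    return torch.cat(tensors, 1)
  if kind in ("ColumnParallelLinear", "VocabParallelEmbedding"):
    # Row-partitioned on disk: reassemble along dim 0.
    return torch.cat(tensors, 0)
  # kind is None: the weight is replicated, so every shard should match.
  assert all(torch.allclose(tensors[0], t, atol=1e-2) for t in tensors[1:])
  return tensors[0]

# Toy example: 8 shards of a row-partitioned vocab embedding.
shards = [torch.randn(16, 4) for _ in range(8)]
print(merge_shards(shards, "VocabParallelEmbedding").shape)  # torch.Size([128, 4])
```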

jetstream_pt/third_party/llama/model_args.py

Lines changed: 13 additions & 0 deletions
@@ -90,6 +90,19 @@ def get_arg(
         "norm_eps": 1e-05,
         "rope_theta": 500000.0,
     }
+  elif model_name == "llama-3-70b":
+    data = {
+        "dim": 8192,
+        "ffn_dim_multiplier": 1.3,
+        "multiple_of": 4096,
+        "n_heads": 64,
+        "n_kv_heads": 8,
+        "n_layers": 80,
+        "norm_eps": 1e-05,
+        "vocab_size": 128256,
+        "rope_theta": 500000.0,
+    }
+
   return ModelArgs(
       max_seq_len=seqlen,
       max_batch_size=batch_size,
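For a quick sanity check, the sizes these llama-3-70b arguments imply can be derived with the standard Llama feed-forward sizing rule; the sketch below assumes that rule rather than quoting this repository's ModelArgs code:

```python
# Derived sizes for the llama-3-70b arguments above, assuming the standard
# Llama FFN sizing rule (4 * dim, scaled by 2/3 and ffn_dim_multiplier, then
# rounded up to a multiple of `multiple_of`).
dim, n_heads, n_kv_heads = 8192, 64, 8
ffn_dim_multiplier, multiple_of = 1.3, 4096

hidden_dim = int(2 * (4 * dim) / 3)                # 21845
hidden_dim = int(ffn_dim_multiplier * hidden_dim)  # 28398
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

print(dim // n_heads)         # 128   per-head dimension
print(n_heads // n_kv_heads)  # 8     query heads per KV head (grouped-query attention)
print(hidden_dim)             # 28672 feed-forward hidden size
```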

jetstream_pt/third_party/llama/model_exportable.py

Lines changed: 12 additions & 3 deletions
@@ -259,13 +259,17 @@ def get_quantized_embedding_weight_to_scaler_map():
     }

   @staticmethod
-  def get_weight_sharding_type():
+  def get_weight_sharding_type(model_name: str = ""):
     # ParallelEmbedding is col partitioned across the shards.
+    # VocabParallelEmbedding is row partitioned across the shards.
     # ColumnParallelLinear is row partitioned across shards due to transpose.
     # RowParallelLinear is col partitioned across shards due to transpose.
     # None is no partitioning and tensor should be identical across shards
-    return {
-        "tok_embeddings.weight": "ParallelEmbedding",
+    expected_model_names = ("llama-2", "llama-3")
+    assert (
+        model_name in expected_model_names
+    ), f"Expected model_name to be one of {expected_model_names}"
+    sharding_dict = {
         "rope.freqs": None,
         "attention.wq.weight": "ColumnParallelLinear",
         "attention.wk.weight": "ColumnParallelLinear",
@@ -279,3 +283,8 @@ def get_weight_sharding_type():
         "norm.weight": None,
         "output.weight": "ColumnParallelLinear",
     }
+    if model_name == "llama-2":
+      sharding_dict["tok_embeddings.weight"] = "ParallelEmbedding"
+    elif model_name == "llama-3":
+      sharding_dict["tok_embeddings.weight"] = "VocabParallelEmbedding"
+    return sharding_dict
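A hypothetical usage sketch of the new signature (the import path is assumed from this repository's file layout): for llama-2 the token embedding stays `ParallelEmbedding`, while for llama-3 it becomes `VocabParallelEmbedding`, which the checkpoint merger above concatenates along dim 0 instead of dim 1.

```python
# Hypothetical usage; the module path is assumed from the file layout above.
from jetstream_pt.third_party.llama import model_exportable as llama_model

llama2 = llama_model.Transformer.get_weight_sharding_type(model_name="llama-2")
llama3 = llama_model.Transformer.get_weight_sharding_type(model_name="llama-3")

print(llama2["tok_embeddings.weight"])  # ParallelEmbedding (col-partitioned)
print(llama3["tok_embeddings.weight"])  # VocabParallelEmbedding (row-partitioned)

# Any other name, e.g. "llama-3-70b", trips the new assertion.
```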
