diff --git a/README.md b/README.md
index bb7cd63c..58b57d7f 100644
--- a/README.md
+++ b/README.md
@@ -98,6 +98,11 @@ python run_interactive.py --size=13b --model_name=$model_name --batch_size=64 --
 python run_interactive.py --size=8b --model_name=$model_name --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_type=$quantize_type --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/llama.yaml
 ```
 
+## Llama-3 70b
+```bash
+python run_interactive.py --size=70b --model_name=$model_name --batch_size=8 --max_cache_length=2048 --quantize_weights=$quantize --quantize_type=$quantize_type --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/llama.yaml
+```
+
 ## Gemma 7b
 ```bash
 python run_interactive.py --model_name=$model_name --size=7b --batch_size=64 --max_cache_length=2048 --quantize_weights=$quantize --quantize_type=$quantize_type --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/$model_name.yaml
diff --git a/convert_checkpoints.py b/convert_checkpoints.py
index 80c4705c..1b3af726 100644
--- a/convert_checkpoints.py
+++ b/convert_checkpoints.py
@@ -186,20 +186,21 @@ def _merge_llama_weights(
         f"{len(tensors)} shards (shape = {tensors[0].shape}) for {key})"
     )
     state_dict_for_key = {}
-    weight_sharding_type = (
-        llama_model.Transformer.get_weight_sharding_type().items()
-    )
+
+    weight_sharding_type = llama_model.Transformer.get_weight_sharding_type(
+        model_name=FLAGS.model_name
+    ).items()
     for pattern, kind in weight_sharding_type:
       if not key.endswith(pattern):
         continue
       with torch.no_grad():
         if kind in ("ParallelEmbedding", "RowParallelLinear"):
           state_dict_for_key[key] = torch.cat(tensors, 1)
-        elif kind == "ColumnParallelLinear":
+        elif kind in ("ColumnParallelLinear", "VocabParallelEmbedding"):
           state_dict_for_key[key] = torch.cat(tensors, 0)
         else:
           if not all(
-              torch.allclose(tensors[0], tensor, atol=1e-6)
+              torch.allclose(tensors[0], tensor, atol=1e-2)
               for tensor in tensors[1:]
           ):
             raise ValueError(
diff --git a/jetstream_pt/third_party/llama/model_args.py b/jetstream_pt/third_party/llama/model_args.py
index bcebfe69..7956667d 100755
--- a/jetstream_pt/third_party/llama/model_args.py
+++ b/jetstream_pt/third_party/llama/model_args.py
@@ -90,6 +90,19 @@ def get_arg(
         "norm_eps": 1e-05,
         "rope_theta": 500000.0,
     }
+  elif model_name == "llama-3-70b":
+    data = {
+        "dim": 8192,
+        "ffn_dim_multiplier": 1.3,
+        "multiple_of": 4096,
+        "n_heads": 64,
+        "n_kv_heads": 8,
+        "n_layers": 80,
+        "norm_eps": 1e-05,
+        "vocab_size": 128256,
+        "rope_theta": 500000.0,
+    }
+
   return ModelArgs(
       max_seq_len=seqlen,
       max_batch_size=batch_size,
diff --git a/jetstream_pt/third_party/llama/model_exportable.py b/jetstream_pt/third_party/llama/model_exportable.py
index 2385839e..47d4a697 100644
--- a/jetstream_pt/third_party/llama/model_exportable.py
+++ b/jetstream_pt/third_party/llama/model_exportable.py
@@ -259,13 +259,17 @@ def get_quantized_embedding_weight_to_scaler_map():
     }
 
   @staticmethod
-  def get_weight_sharding_type():
+  def get_weight_sharding_type(model_name: str = ""):
     # ParallelEmbedding is col partitioned across the shards.
+    # VocabParallelEmbedding is row partitioned across the shards.
     # ColumnParallelLinear is row partitioned across shards due to transpose.
     # RowParallelLinear is col partitioned across shards due to transpose.
     # None is no partitioning and tensor should be identical across shards
-    return {
-        "tok_embeddings.weight": "ParallelEmbedding",
+    expected_model_names = ("llama-2", "llama-3")
+    assert (
+        model_name in expected_model_names
+    ), f"Expected model_name to be one of {expected_model_names}"
+    sharding_dict = {
         "rope.freqs": None,
         "attention.wq.weight": "ColumnParallelLinear",
         "attention.wk.weight": "ColumnParallelLinear",
@@ -279,3 +283,8 @@ def get_weight_sharding_type():
         "norm.weight": None,
         "output.weight": "ColumnParallelLinear",
     }
+    if model_name == "llama-2":
+      sharding_dict["tok_embeddings.weight"] = "ParallelEmbedding"
+    elif model_name == "llama-3":
+      sharding_dict["tok_embeddings.weight"] = "VocabParallelEmbedding"
+    return sharding_dict
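
The sharding-kind comments in `get_weight_sharding_type` pair with the merge branches in `_merge_llama_weights`: `ParallelEmbedding`/`RowParallelLinear` shards are concatenated along dim 1, `ColumnParallelLinear`/`VocabParallelEmbedding` shards along dim 0, and unsharded weights are required to match across shards. Below is a minimal standalone sketch of that rule, assuming only `torch` is available; the helper name `merge_shards` and the toy shapes are illustrative and not part of this PR or the repository.

```python
import torch


def merge_shards(kind, tensors, atol=1e-2):
  """Merge one weight's per-shard tensors according to its sharding kind."""
  if kind in ("ParallelEmbedding", "RowParallelLinear"):
    # Column-partitioned in the checkpoint: stitch shards back along dim 1.
    return torch.cat(tensors, 1)
  if kind in ("ColumnParallelLinear", "VocabParallelEmbedding"):
    # Row-partitioned in the checkpoint: stitch shards back along dim 0.
    return torch.cat(tensors, 0)
  # Unsharded (kind is None): every shard should carry the same tensor.
  if not all(torch.allclose(tensors[0], t, atol=atol) for t in tensors[1:]):
    raise ValueError("replicated weight differs across shards")
  return tensors[0]


# Toy check: a vocab-parallel embedding split row-wise across two shards.
shards = [torch.randn(4, 8), torch.randn(4, 8)]
assert merge_shards("VocabParallelEmbedding", shards).shape == (8, 8)
# The same shards treated as a llama-2 ParallelEmbedding merge column-wise.
assert merge_shards("ParallelEmbedding", shards).shape == (4, 16)
```

Under this rule, the llama-3 `tok_embeddings.weight` now merges row-wise, so, for example, a 128256-row vocabulary split across 8 checkpoint shards would be rebuilt from eight 16032-row slices.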