4 files changed: +36 -8 lines changed (file tree: jetstream_pt/third_party/llama)

@@ -98,6 +98,11 @@ python run_interactive.py --size=13b --model_name=$model_name --batch_size=64 --
 python run_interactive.py --size=8b --model_name=$model_name --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_type=$quantize_type --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/llama.yaml
 ```
 
+## Llama-3 70b
+```bash
+python run_interactive.py --size=70b --model_name=$model_name --batch_size=8 --max_cache_length=2048 --quantize_weights=$quantize --quantize_type=$quantize_type --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/llama.yaml
+```
+
 ## Gemma 7b
 ```bash
 python run_interactive.py --model_name=$model_name --size=7b --batch_size=64 --max_cache_length=2048 --quantize_weights=$quantize --quantize_type=$quantize_type --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/$model_name.yaml

@@ -186,20 +186,21 @@ def _merge_llama_weights(
         f"{len(tensors)} shards (shape = {tensors[0].shape}) for {key})"
     )
     state_dict_for_key = {}
-    weight_sharding_type = (
-        llama_model.Transformer.get_weight_sharding_type().items()
-    )
+
+    weight_sharding_type = llama_model.Transformer.get_weight_sharding_type(
+        model_name=FLAGS.model_name
+    ).items()
     for pattern, kind in weight_sharding_type:
       if not key.endswith(pattern):
         continue
       with torch.no_grad():
         if kind in ("ParallelEmbedding", "RowParallelLinear"):
           state_dict_for_key[key] = torch.cat(tensors, 1)
-        elif kind == "ColumnParallelLinear":
+        elif kind in ("ColumnParallelLinear", "VocabParallelEmbedding"):
           state_dict_for_key[key] = torch.cat(tensors, 0)
         else:
           if not all(
-              torch.allclose(tensors[0], tensor, atol=1e-6)
+              torch.allclose(tensors[0], tensor, atol=1e-2)
               for tensor in tensors[1:]
           ):
             raise ValueError(
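
For readers skimming the hunk above, here is a self-contained sketch of the merge rule it implements; the `merge_shards` helper and the toy shard shapes are illustrative, not code from this PR. Column-partitioned weights are concatenated along dim 1, row-partitioned ones (now including `VocabParallelEmbedding`) along dim 0, and replicated tensors are only sanity-checked against each other, now with the relaxed `atol=1e-2`.

```python
import torch

def merge_shards(key, tensors, sharding_types):
  """Illustrative merge of per-shard tensors into one full weight."""
  for pattern, kind in sharding_types.items():
    if not key.endswith(pattern):
      continue
    if kind in ("ParallelEmbedding", "RowParallelLinear"):
      return torch.cat(tensors, 1)  # column-partitioned: join along dim 1
    if kind in ("ColumnParallelLinear", "VocabParallelEmbedding"):
      return torch.cat(tensors, 0)  # row-partitioned: join along dim 0
    # Replicated tensors: every shard should hold (nearly) identical values.
    if not all(torch.allclose(tensors[0], t, atol=1e-2) for t in tensors[1:]):
      raise ValueError(f"Replicated shards for {key} disagree")
    return tensors[0]
  raise KeyError(f"No sharding rule matches {key}")

# Two toy shards of a vocab-parallel embedding: the vocab rows are re-joined.
shards = [torch.zeros(4, 8), torch.ones(4, 8)]
full = merge_shards(
    "tok_embeddings.weight",
    shards,
    {"tok_embeddings.weight": "VocabParallelEmbedding"},
)
print(full.shape)  # torch.Size([8, 8])
```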

@@ -90,6 +90,19 @@ def get_arg(
       "norm_eps": 1e-05,
       "rope_theta": 500000.0,
   }
+  elif model_name == "llama-3-70b":
+    data = {
+        "dim": 8192,
+        "ffn_dim_multiplier": 1.3,
+        "multiple_of": 4096,
+        "n_heads": 64,
+        "n_kv_heads": 8,
+        "n_layers": 80,
+        "norm_eps": 1e-05,
+        "vocab_size": 128256,
+        "rope_theta": 500000.0,
+    }
+
   return ModelArgs(
       max_seq_len=seqlen,
       max_batch_size=batch_size,
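
The new `llama-3-70b` entry stores only raw hyperparameters; the feed-forward width is derived elsewhere from `dim`, `ffn_dim_multiplier`, and `multiple_of`. As a rough sanity check, assuming the usual Llama sizing rule (the rule itself is not shown in this diff), those values give a 28672-wide FFN:

```python
# Assumed Llama FFN sizing rule, applied to the llama-3-70b entry above.
def ffn_hidden_dim(dim, multiple_of, ffn_dim_multiplier=None):
  hidden = int(2 * (4 * dim) / 3)
  if ffn_dim_multiplier is not None:
    hidden = int(ffn_dim_multiplier * hidden)
  # Round up to the nearest multiple of `multiple_of`.
  return multiple_of * ((hidden + multiple_of - 1) // multiple_of)

print(ffn_hidden_dim(dim=8192, multiple_of=4096, ffn_dim_multiplier=1.3))  # 28672
```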

@@ -259,13 +259,17 @@ def get_quantized_embedding_weight_to_scaler_map():
     }
 
   @staticmethod
-  def get_weight_sharding_type():
+  def get_weight_sharding_type(model_name: str = ""):
     # ParallelEmbedding is col partitioned across the shards.
+    # VocabParallelEmbedding is row partitioned across the shards.
     # ColumnParallelLinear is row partitioned across shards due to transpose.
     # RowParallelLinear is col partitioned across shards due to transpose.
     # None is no partitioning and tensor should be identical across shards
-    return {
-        "tok_embeddings.weight": "ParallelEmbedding",
+    expected_model_names = ("llama-2", "llama-3")
+    assert (
+        model_name in expected_model_names
+    ), f"Expected model_name to be one of {expected_model_names}"
+    sharding_dict = {
         "rope.freqs": None,
         "attention.wq.weight": "ColumnParallelLinear",
         "attention.wk.weight": "ColumnParallelLinear",
@@ -279,3 +283,8 @@ def get_weight_sharding_type():
         "norm.weight": None,
         "output.weight": "ColumnParallelLinear",
     }
+    if model_name == "llama-2":
+      sharding_dict["tok_embeddings.weight"] = "ParallelEmbedding"
+    elif model_name == "llama-3":
+      sharding_dict["tok_embeddings.weight"] = "VocabParallelEmbedding"
+    return sharding_dict
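
To make the new dispatch easier to follow, here is a condensed, self-contained stand-in for the updated method; it is not the PR's code, and the real table also carries the attention and feed-forward entries shown above. It demonstrates the assert on `model_name` and the per-model embedding kind:

```python
def get_weight_sharding_type(model_name: str = ""):
  # Condensed stand-in: only the dispatch-relevant entries are kept.
  expected_model_names = ("llama-2", "llama-3")
  assert (
      model_name in expected_model_names
  ), f"Expected model_name to be one of {expected_model_names}"
  sharding_dict = {"norm.weight": None, "output.weight": "ColumnParallelLinear"}
  if model_name == "llama-2":
    sharding_dict["tok_embeddings.weight"] = "ParallelEmbedding"  # col split
  else:  # "llama-3"
    sharding_dict["tok_embeddings.weight"] = "VocabParallelEmbedding"  # row split
  return sharding_dict

print(get_weight_sharding_type("llama-3")["tok_embeddings.weight"])
# -> VocabParallelEmbedding; calling with the "" default now raises AssertionError.
```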