     QuantizationConfig,
 )
 from jetstream_pt.third_party.llama import model_exportable as llama_model
+from jetstream_pt.third_party.mixtral import model as mixtral_model
 
 FLAGS = flags.FLAGS
 
@@ -38,12 +39,15 @@ class ModelInfo:
   num_layers: int
   num_heads: int
   head_dim: int
+  n_reps: int  # repetition factor for GQA
 
 
-_llama2_7 = ModelInfo(llama_model.Transformer, 32, 32, 128)
-_llama2_13 = ModelInfo(llama_model.Transformer, 40, 40, 128)
-_llama2_70 = ModelInfo(llama_model.Transformer, 80, 8, 128)
-_llama3_8 = ModelInfo(llama_model.Transformer, 32, 8, 128)
+_llama2_7 = ModelInfo(llama_model.Transformer, 32, 32, 128, 1)
+_llama2_13 = ModelInfo(llama_model.Transformer, 40, 40, 128, 1)
+_llama2_70 = ModelInfo(llama_model.Transformer, 80, 8, 128, 4)
+_llama3_8 = ModelInfo(llama_model.Transformer, 32, 8, 128, 4)
+
+_mixtral_87 = ModelInfo(mixtral_model.Transformer, 32, 8, 128, 4)
 
 
 model_id_to_class = {
@@ -57,8 +61,8 @@ class ModelInfo:
     "google/gemma-2b-it": None,
     "google/gemma-7b": None,
     "google/gemma-7b-it": None,
-    "mistralai/Mixtral-8x7B-v0.1": None,
-    "mistralai/Mixtral-8x7B-Instruct-v0.1": None,
+    "mistralai/Mixtral-8x7B-v0.1": _mixtral_87,
+    "mistralai/Mixtral-8x7B-Instruct-v0.1": _mixtral_87,
 }
 
 
@@ -107,6 +111,7 @@ def construct_env_data_from_model_id(
       else input_length + output_length
   )
 
+  model_info = model_id_to_class.get(repo_id)
   env_data = JetEngineEnvironmentData(
       tokenizer_path=tokenizer_path,
       checkpoint_path=checkpoint_path,
@@ -119,8 +124,8 @@ def construct_env_data_from_model_id(
       bf16_enable=True,
       sharding_config_path="",
       shard_on_batch=shard_on_batch,
+      n_reps=model_info.n_reps,
   )
-  model_info = model_id_to_class.get(repo_id)
   env_data.cache_shape = (
       batch_size,
       model_info.num_heads,
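
For context on the new field: `n_reps` is the grouped-query attention (GQA) repetition factor, i.e. how many query heads share each key/value head (a value of 1 is plain multi-head attention). The sketch below shows what such a factor is typically used for, tiling cached KV heads up to the query-head count; the function name and shapes are illustrative only, not the jetstream_pt implementation.

```python
import torch


def repeat_kv(kv: torch.Tensor, n_reps: int) -> torch.Tensor:
  """Tile KV heads along the head axis so they line up with the query heads (GQA)."""
  if n_reps == 1:  # plain multi-head attention, nothing to repeat
    return kv
  # kv: (batch, num_kv_heads, seq_len, head_dim)
  return torch.repeat_interleave(kv, repeats=n_reps, dim=1)


# Mixtral-8x7B uses 32 query heads over 8 KV heads, hence n_reps = 4 above.
kv = torch.zeros(1, 8, 16, 128)  # (batch, kv_heads, seq, head_dim)
assert repeat_kv(kv, n_reps=4).shape == (1, 32, 16, 128)
```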
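Note also that the `model_info` lookup now happens before `JetEngineEnvironmentData` is constructed, since `model_info.n_reps` is passed as a constructor argument. The `ModelInfo` table hard-codes the factor per checkpoint; under standard GQA it is simply the ratio of attention heads to KV heads, so a hedged way to derive it from a HuggingFace-style config (illustrative field names, not necessarily what jetstream_pt reads) would be:

```python
def infer_n_reps(num_attention_heads: int, num_key_value_heads: int) -> int:
  # GQA repetition factor: each KV head serves this many query heads.
  assert num_attention_heads % num_key_value_heads == 0
  return num_attention_heads // num_key_value_heads


# e.g. Mixtral-8x7B ships 32 attention heads and 8 KV heads -> n_reps == 4
assert infer_n_reps(32, 8) == 4
```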