 from torch.nn import functional as F
 from .config import ModelArgs, find_multiple
 from jetstream_pt.layers import Attention, get_quantized_linear_layer, get_quantized_embedding_layer
+from jetstream_pt.model_base import ModuleBase

 import jax


-class Transformer(nn.Module):
+class Transformer(ModuleBase):

   def __init__(self, config: ModelArgs, env) -> None:
     super().__init__()
@@ -37,6 +38,7 @@ def __init__(self, config: ModelArgs, env) -> None:
     self.tok_embeddings = Embedding(
         config.vocab_size, config.dim, device=config.device
     )
+
     self.layers = nn.ModuleList(
         TransformerBlock(config, env, layer_id)
         for layer_id in range(config.n_layer)
@@ -47,6 +49,14 @@ def __init__(self, config: ModelArgs, env) -> None:
         config.dim, config.vocab_size, bias=False, device=config.device
     )

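+    # ModuleBase helpers (assumed semantics): hf_name() maps a local attribute
+    # name to its HuggingFace checkpoint name, and annotate_sharding() records
+    # which axis of a tensor should be sharded across devices.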
+    self.hf_name("norm", "model.norm")
+    self.hf_name("layers", "model.layers")
+    self.hf_name("output", "lm_head")
+    self.hf_name("tok_embeddings", "model.embed_tokens")
+
+    self.annotate_sharding("tok_embeddings.weight", 1)
+    self.annotate_sharding("output.weight", 0)
+
     self.max_batch_size = -1
     self.max_seq_length = -1

@@ -140,8 +150,20 @@ def get_weight_sharding_type():
         "output.weight": "ColumnParallelLinear",
     }

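+  # Builds the model from a HuggingFace model id on the meta device, so no
+  # weights are allocated here; they are presumably loaded and sharded later.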
+  @classmethod
+  def from_hf_model_id(cls, model_id, env):
+    name = {
+        "mistralai/Mixtral-8x7B-v0.1": "Mixtral-8x7B-v0.1",
+        "mistralai/Mixtral-8x7B-Instruct-v0.1": "Mixtral-8x7B-v0.1",
+    }.get(model_id)
+    assert name
+    args = ModelArgs.from_name(name)
+    args.device = "meta"
+    model = cls(args, env)
+    return model

-class TransformerBlock(nn.Module):
+
+class TransformerBlock(ModuleBase):

   def __init__(self, config: ModelArgs, env, layer_id) -> None:
     super().__init__()
@@ -154,10 +176,37 @@ def __init__(self, config: ModelArgs, env, layer_id) -> None:
         device=config.device,
         layer_id=layer_id,
     )
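+    # Attention projections map to their HuggingFace names; wq/wk/wv are
+    # sharded on dim 0 (column-parallel) and wo on dim 1 (row-parallel).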
+    self.hf_name("attention", "self_attn")
+    self.attention.hf_name("wq", "q_proj")
+    self.attention.hf_name("wk", "k_proj")
+    self.attention.hf_name("wv", "v_proj")
+    self.attention.hf_name("wo", "o_proj")
+
+    self.attention.annotate_sharding("wq", 0)
+    self.attention.annotate_sharding("wk", 0)
+    self.attention.annotate_sharding("wv", 0)
+    self.attention.annotate_sharding("wo", 1)
+
     self.block_sparse_moe = MOEFeedForward(config, config.device, env)
     self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
     self.attention_norm = RMSNorm(config.dim, config.norm_eps)

+    self.hf_name("attention_norm", "input_layernorm")
+    self.hf_name("ffn_norm", "post_attention_layernorm")
+    self._register_load_state_dict_pre_hook(self.load_hook)
+
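+  # Pre-hook that folds the eight per-expert w1/w2/w3 matrices of a
+  # HuggingFace-style checkpoint into the stacked cond_ffn tensors
+  # (num_experts, ..., ...) that this module expects.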
+  def load_hook(self, state_dict, prefix, *args):
+    if prefix + "block_sparse_moe.experts.0.w1" in state_dict:
+      w1s, w2s, w3s = [], [], []
+      for i in range(8):  # Mixtral-8x7B uses 8 experts
+        exp_prefix = f"{prefix}block_sparse_moe.experts.{i}."
+        w1s.append(state_dict.pop(exp_prefix + "w1"))
+        w2s.append(state_dict.pop(exp_prefix + "w2"))
+        w3s.append(state_dict.pop(exp_prefix + "w3"))
+      # cond_ffn.w* are 3-D (num_experts, rows, cols), so stack the per-expert
+      # 2-D weights along a new leading axis.
+      state_dict[prefix + "block_sparse_moe.cond_ffn.w1"] = torch.stack(w1s)
+      state_dict[prefix + "block_sparse_moe.cond_ffn.w2"] = torch.stack(w2s)
+      state_dict[prefix + "block_sparse_moe.cond_ffn.w3"] = torch.stack(w3s)
+
   def forward(
       self,
       x: Tensor,
@@ -189,7 +238,7 @@ def forward(
     return out


-class Int8ConditionalFeedForward(nn.Module):
+class Int8ConditionalFeedForward(ModuleBase):

   def __init__(self, config):
     super().__init__()
@@ -215,12 +264,20 @@ def __init__(self, config):
     self.register_buffer("w2", w2)
     self.register_buffer("w3", w3)

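+    # Expert weights are laid out (num_experts, out_features, in_features);
+    # w1/w3 are sharded along the intermediate axis (dim 1) and w2 along its
+    # matching input axis (dim 2), so both sides split the same dimension.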
+    self.annotate_sharding("w1", 1)
+    self.annotate_sharding("w2", 2)
+    self.annotate_sharding("w3", 1)
+
     w1_scaler = torch.empty(config.num_experts, config.intermediate_size)
     w2_scaler = torch.empty(config.num_experts, config.dim)
     w3_scaler = torch.empty(config.num_experts, config.intermediate_size)
+
     self.register_buffer("w1_scaler", w1_scaler)
     self.register_buffer("w2_scaler", w2_scaler)
     self.register_buffer("w3_scaler", w3_scaler)
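+    # Quantization scalers follow the sharding of the weights they rescale;
+    # -1 appears to mean "replicated" (w2's output dimension is not sharded).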
+    self.annotate_sharding("w1_scaler", 1)
+    self.annotate_sharding("w2_scaler", -1)
+    self.annotate_sharding("w3_scaler", 1)

   def forward(self, x: Tensor, expert_indices: Tensor) -> Tensor:
     seq_len = x.shape[0]
@@ -266,7 +323,7 @@ def forward_for_long_seq_len(self, x, expert_indices):
     return expert_outs[seq_indexes, expert_indices]


-class ConditionalFeedForward(nn.Module):
+class ConditionalFeedForward(ModuleBase):

   def __init__(self, config):
     super().__init__()
@@ -280,6 +337,9 @@ def __init__(self, config):
     self.w3 = nn.Parameter(
         torch.empty(config.num_experts, config.intermediate_size, config.dim)
     )
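+    # Same sharding layout as the int8 variant: intermediate axis for w1/w3,
+    # matching input axis for w2.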
+    self.annotate_sharding("w1", 1)
+    self.annotate_sharding("w2", 2)
+    self.annotate_sharding("w3", 1)

   def forward(self, x: Tensor, expert_indices: Tensor) -> Tensor:
     seq_len = x.shape[0]
@@ -318,7 +378,7 @@ def forward_for_long_seq_len(self, x, expert_indices):
     return expert_outs[seq_indexes, expert_indices]


-class MOEFeedForward(nn.Module):
+class MOEFeedForward(ModuleBase):

   def __init__(self, config, device, env) -> None:
     super().__init__()
@@ -352,7 +412,7 @@ def forward(self, x: Tensor) -> Tensor:
     return expert_outs


-class RMSNorm(nn.Module):
+class RMSNorm(ModuleBase):

   def __init__(self, dim: int, eps: float = 1e-5):
     super().__init__()