# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import collections
+import copy
from dataclasses import dataclass
from typing import Optional, List, Any

@@ -163,6 +164,32 @@ def from_hf_model_id(cls, model_id, env):
    model = cls(args, env)
    return model

+  def convert_hf_weights(self, hf_weights):
+    # Group each layer's per-expert weights by (layer id, weight name) so the
+    # 8 experts can be stacked into a single cond_ffn tensor.
+    grouped_by_experts = collections.defaultdict(lambda: [None] * 8)
+
+    updated = copy.copy(hf_weights)
+    for key, value in hf_weights.items():
+      if 'block_sparse_moe.experts' in key:
+        #  0     1      2 3                4       5 6  7
+        # "model.layers.0.block_sparse_moe.experts.0.w1.weight"
+        updated.pop(key)
+        name_pieces = key.split('.')
+        assert len(name_pieces) == 8
+        layer_id = int(name_pieces[2])
+        expert_id = int(name_pieces[5])
+        weight_name = name_pieces[6]
+        grouped_by_experts[(layer_id, weight_name)][expert_id] = value
+
+    for (layer_id, weight_name), ws in grouped_by_experts.items():
+      name = f"model.layers.{layer_id}.block_sparse_moe.cond_ffn.{weight_name}"
+      updated[name] = torch.stack(ws)
+    return super().convert_hf_weights(updated)
+
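For reference, a minimal sketch of the key regrouping this method performs, assuming Mixtral's usual HF key layout (`model.layers.<L>.block_sparse_moe.experts.<E>.{w1,w2,w3}.weight`); the tensor shapes are toy values chosen only for illustration:

```python
import collections

import torch

# Toy stand-in for an HF Mixtral state dict: 8 experts, each with w1/w2/w3.
hf_weights = {
    f"model.layers.0.block_sparse_moe.experts.{e}.{w}.weight": torch.randn(4, 6)
    for e in range(8)
    for w in ("w1", "w2", "w3")
}

grouped = collections.defaultdict(lambda: [None] * 8)
for key in list(hf_weights):
  if "block_sparse_moe.experts" in key:
    pieces = key.split(".")  # model.layers.<L>.block_sparse_moe.experts.<E>.<w>.weight
    layer_id, expert_id, weight_name = int(pieces[2]), int(pieces[5]), pieces[6]
    grouped[(layer_id, weight_name)][expert_id] = hf_weights.pop(key)

# Each group of 8 per-expert matrices becomes one tensor with a leading expert dim.
for (layer_id, weight_name), ws in grouped.items():
  new_key = f"model.layers.{layer_id}.block_sparse_moe.cond_ffn.{weight_name}"
  hf_weights[new_key] = torch.stack(ws)

print(hf_weights["model.layers.0.block_sparse_moe.cond_ffn.w1"].shape)
# torch.Size([8, 4, 6])
```

Note that `torch.stack` keeps an explicit expert dimension, whereas the removed per-block `load_hook` further down concatenated the expert weights along dim 0 with `torch.cat`.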

class TransformerBlock(ModuleBase):

@@ -177,6 +204,7 @@ def __init__(self, config: ModelArgs, env, layer_id) -> None:
        device=config.device,
        layer_id=layer_id,
    )
+    self.config = config
    self.hf_name("attention", "self_attn")
    self.attention.hf_name("wq", "q_proj")
    self.attention.hf_name("wk", "k_proj")
@@ -194,19 +222,20 @@ def __init__(self, config: ModelArgs, env, layer_id) -> None:

    self.hf_name("attention_norm", "input_layernorm")
    self.hf_name("ffn_norm", "post_attention_layernorm")
-    self._register_load_state_dict_pre_hook(self.load_hook)
-
-  def load_hook(self, state_dict, prefix, *args):
-    if prefix + "block_sparse_moe.experts" in state_dict:
-      w1s, w2s, w3s = [], [], []
-      for i in range(8):
-        exp_prefix = f"{prefix}block_sparse_moe.experts.{i}."
-        w1s.append(state_dict.pop(exp_prefix + ".w1"))
-        w2s.append(state_dict.pop(exp_prefix + ".w2"))
-        w3s.append(state_dict.pop(exp_prefix + ".w3"))
-      state_dict[prefix + "block_sparse_moe.cond_ffn.w1"] = torch.cat(w1s)
-      state_dict[prefix + "block_sparse_moe.cond_ffn.w2"] = torch.cat(w2s)
-      state_dict[prefix + "block_sparse_moe.cond_ffn.w3"] = torch.cat(w3s)
+
+    self.attention._register_load_state_dict_pre_hook(
+        self._load_attention_hf_weights)
+
+  def _load_attention_hf_weights(self, state_dict, prefix, *args):
+    def transform(val, n_heads):
+      # Regroup each head's rows from HF's [first half | second half] rotary
+      # layout back into interleaved pairs, without changing the shape.
+      dim1, dim2 = val.shape
+      return (
+          val.reshape(n_heads, 2, dim1 // n_heads // 2, dim2)
+          .transpose(1, 2)
+          .reshape(dim1, dim2)
+      )
+
+    qname = prefix + "wq.weight"
+    kname = prefix + "wk.weight"
+    if qname in state_dict:
+      state_dict[qname] = transform(state_dict[qname], self.config.n_head)
+    if kname in state_dict:
+      state_dict[kname] = transform(
+          state_dict[kname], self.config.n_local_heads or self.config.n_head
+      )
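The `transform` helper above is the usual inverse of the per-head permutation that Hugging Face's Llama/Mixtral conversion applies to `q_proj`/`k_proj` for its rotary-embedding layout. A self-contained check of the index mapping, with made-up head counts and dimensions:

```python
import torch

def transform(val, n_heads):
  # Reshape to (heads, 2, head_dim // 2, in_dim), swap the "2" and
  # "head_dim // 2" axes, and flatten back: within each head, output row
  # 2*k + j comes from input row j * (head_dim // 2) + k, so the two halves
  # of the head are re-interleaved.
  dim1, dim2 = val.shape
  return (
      val.reshape(n_heads, 2, dim1 // n_heads // 2, dim2)
      .transpose(1, 2)
      .reshape(dim1, dim2)
  )

n_heads, head_dim, in_dim = 4, 8, 16           # hypothetical small config
w = torch.randn(n_heads * head_dim, in_dim)    # shaped like a wq / wk weight
out = transform(w, n_heads)

assert out.shape == w.shape
# Within head 0: output row 1 is input row head_dim // 2 (start of the 2nd half),
assert torch.equal(out[1], w[head_dim // 2])
# and output row 2 is input row 1 (second row of the first half).
assert torch.equal(out[2], w[1])
```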
  def forward(
      self,
@@ -383,14 +412,14 @@ def get_quantized_version(self):
383412 """Return quantized version of this class."""
384413 quant_version = Int8ConditionalFeedForward (self .config )
385414 w1 , w1_scaler , _ = quantize .quantize_tensor (self .w1 , 2 )
386- w2 , w2_scaler , _ = quantize .quantize_tensor (self .w2 , 1 )
415+ w2 , w2_scaler , _ = quantize .quantize_tensor (self .w2 , 2 )
387416 w3 , w3_scaler , _ = quantize .quantize_tensor (self .w3 , 2 )
388417 quant_version .w1 = w1
389418 quant_version .w2 = w2
390419 quant_version .w3 = w3
391- quant_version .w1_scaler = w1_scaler
392- quant_version .w2_scaler = w2_scaler
393- quant_version .w3_scaler = w3_scaler
420+ quant_version .w1_scaler = w1_scaler . squeeze ( 2 )
421+ quant_version .w2_scaler = w2_scaler . squeeze ( 2 )
422+ quant_version .w3_scaler = w3_scaler . squeeze ( 2 )
394423 return quant_version
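The implementation of `quantize.quantize_tensor` isn't shown here, but assuming its second argument is the reduction axis of a per-channel int8 scale computed with `keepdim=True`, the new `.squeeze(2)` simply drops the size-1 axis left over from that reduction. A rough sketch of the shape bookkeeping, with a stand-in quantizer and toy expert-weight shapes:

```python
import torch

def quantize_per_channel(w, axis):
  # Stand-in for quantize.quantize_tensor: symmetric int8 quantization with a
  # per-channel scale along `axis`, kept as a size-1 dim (the real helper may differ).
  scaler = w.abs().amax(dim=axis, keepdim=True) / 127.0
  q = torch.clamp(torch.round(w / scaler), -128, 127).to(torch.int8)
  return q, scaler

# Toy MoE weight: (num_experts, rows, cols).
w2 = torch.randn(8, 16, 32)
q2, w2_scaler = quantize_per_channel(w2, axis=2)

print(w2_scaler.shape)             # torch.Size([8, 16, 1])
print(w2_scaler.squeeze(2).shape)  # torch.Size([8, 16]) -- shape stored on the module
```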