22"""This version contains modification to make it easier to trace and support batch."""
33
44from typing import Any , List , Optional
5-
5+ import copy
66import jax
77import torch
88import torch .nn .functional as F
@@ -125,8 +125,6 @@ def __init__(
     self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps, device=args.device)
 
     self.hf_name("attention", "self_attn")
-    # We dont want to rename q_proj and k_proj; this is done in
-    # _load_attention_hf_weights
     self.attention.hf_name("wq", "q_proj")
     self.attention.hf_name("wk", "k_proj")
     self.attention.hf_name("wv", "v_proj")
@@ -140,20 +138,6 @@ def __init__(
     self.hf_name("feed_forward", "mlp")
     self.hf_name("attention_norm", "input_layernorm")
     self.hf_name("ffn_norm", "post_attention_layernorm")
-    self.attention._register_load_state_dict_pre_hook(
-        self._load_attention_hf_weights)
-
-  def _load_attention_hf_weights(self, state_dict, prefix, *args):
-    def transform(val, n_heads):
-      dim1, dim2 = val.shape
-      return val.reshape(n_heads, 2, dim1 // n_heads // 2, dim2).transpose(1, 2).reshape(dim1, dim2)
-    qname = prefix + "wq.weight"
-    kname = prefix + "wk.weight"
-    if qname in state_dict:
-      state_dict[prefix + 'wq.weight'] = transform(state_dict[qname], self.n_heads)
-    if kname in state_dict:
-      state_dict[prefix + 'wk.weight'] = transform(state_dict[kname], self.args.n_kv_heads or self.n_heads)
-
 
   def forward(
       self,
@@ -377,8 +361,23 @@ def from_hf_model_id(cls, model_id, env):
   def drop_weight(self, key):
     return key.startswith("model")
 
-  def shard_weights(self, weights_dict):
-    """Shards the weights
+  def convert_hf_weights(self, hf_weights):
 
-    Assumes the weights_dict is a list of XLATensor2
-    """
+    def transform(val, n_heads):
+      dim1, dim2 = val.shape
+      return (
+          val.reshape(n_heads, 2, dim1 // n_heads // 2, dim2)
+          .transpose(1, 2)
+          .reshape(dim1, dim2)
+      )
+
+    updated = copy.copy(hf_weights)
+
+    for key, value in hf_weights.items():
+      if "q_proj" in key:
+        updated[key] = transform(value, self.params.n_heads)
+      if "k_proj" in key:
+        updated[key] = transform(
+            value, self.params.n_kv_heads or self.params.n_heads
+        )
+    return super().convert_hf_weights(updated)
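For reference, a minimal, self-contained sketch of what `transform` does to a weight matrix, runnable outside the model. Only `transform` itself comes from the diff; the toy shapes and the `inverse` helper are illustrative, under the assumption that `transform` undoes the split-half rotary permutation applied by Hugging Face's llama checkpoint conversion, regrouping each head's rows back into the interleaved pair order that a Meta-style rotary embedding expects:

import torch

def transform(val, n_heads):
  # Same permutation as in convert_hf_weights above.
  dim1, dim2 = val.shape
  return (
      val.reshape(n_heads, 2, dim1 // n_heads // 2, dim2)
      .transpose(1, 2)
      .reshape(dim1, dim2)
  )

# Toy q_proj weight: one head, head_dim 4, one input feature.
# HF orders a head's rows as [first half | second half] of the rotary dims.
w = torch.arange(4.0).reshape(4, 1)
print(transform(w, n_heads=1).flatten().tolist())
# [0.0, 2.0, 1.0, 3.0] -- rows re-paired into interleaved order.

# Sanity check (illustrative): the opposite view/transpose, which is the
# permutation HF's conversion script applies, round-trips back to w.
def inverse(val, n_heads):
  dim1, dim2 = val.shape
  return (
      val.reshape(n_heads, dim1 // n_heads // 2, 2, dim2)
      .transpose(1, 2)
      .reshape(dim1, dim2)
  )

assert torch.equal(inverse(transform(w, 1), 1), w)

Note also that `convert_hf_weights` uses `copy.copy`, a shallow copy: tensors that are not transformed stay shared with the input dict, and only the `q_proj`/`k_proj` entries are rebound to new tensors.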