Commit d72a6da

mixtral working, gemma and llama also work
1 parent b59c542 commit d72a6da

3 files changed: +81 −37 lines

jetstream_pt/fetch_models.py

Lines changed: 7 additions & 15 deletions
@@ -166,18 +166,10 @@ def instantiate_model_from_repo_id(
   env.device = "meta"
   model = model_info.model_class.from_hf_model_id(repo_id, env)
   weights = _load_weights(model_dir)
-  updated_keys = model.get_hf_names_to_real_name()
-  for name, updated in updated_keys.items():
-    if name in weights:
-      val = weights.pop(name)
-      weights[updated] = val
+  weights = model.convert_hf_weights(weights)
 
 
-  for name in list(weights.keys()):
-    if 'inv_freq' in name:
-      weights.pop(name)
-  weights['freqs_cis'] = model.freqs_cis
-  model.load_state_dict(weights, assign=True, strict=True)
+  model.load_state_dict(weights, assign=True, strict=False)
 
   return model
   ## QQ do i need to set the weights onto the model?
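
Two flags do the heavy lifting here: `assign=True` re-points the meta-device parameters at the checkpoint tensors instead of copying into them, and `strict=False` tolerates keys that `convert_hf_weights` legitimately leaves out (for example, `freqs_cis` is only injected when the model defines it). A minimal standalone sketch of those semantics, assuming PyTorch >= 2.1 for the `assign` keyword; the toy `Tiny` module below is illustrative, not part of the repo:

import torch

class Tiny(torch.nn.Module):
  def __init__(self):
    super().__init__()
    self.linear = torch.nn.Linear(4, 4, bias=False)

with torch.device("meta"):          # parameters are allocated as meta tensors
  m = Tiny()

ckpt = {"linear.weight": torch.randn(4, 4)}  # pretend this came from safetensors
# assign=True swaps the meta parameters for the checkpoint tensors (no copy);
# strict=False tolerates missing or unexpected keys instead of raising.
missing, unexpected = m.load_state_dict(ckpt, assign=True, strict=False)
print(missing, unexpected, m.linear.weight.device)  # -> [] [] cpu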
@@ -198,11 +190,11 @@ def _hf_download(
       local_dir=dest_directory,
       local_dir_use_symlinks=False,
       token=hf_token,
-      allow_patterns=[
-          "model-?????-of-?????.safetensors",
-          "*.json",
-          "*.model",
-      ],
+      # allow_patterns=[
+      #     "model-?????-of-?????.safetensors",
+      #     "*.json",
+      #     "*.model",
+      # ],
   )
 except HTTPError as e:
   if e.response.status_code == 401:

jetstream_pt/model_base.py

Lines changed: 27 additions & 4 deletions
@@ -46,7 +46,18 @@ class AttrProperty:
 
 
 class ModuleBase(torch.nn.Module, metaclass=abc.ABCMeta):
-  """nn Module that allows attaching properties"""
+  """nn Module that allows attaching properties.
+
+  This class currently serves two goals:
+  1. Allow a model to specify alternative names for submodules / weights;
+     this is needed so that it can *also* load HuggingFace checkpoints
+     without needing massive rewrites.
+
+  2. Allow a model to attach information to weights, such as sharding config.
+
+  Quantization config could be another thing to attach, but right now it's
+  not used this way.
+  """
 
   attr_to_property: Dict[str, Any]

@@ -74,6 +85,18 @@ def annotate_sharding(self, name, axis):
     """Set sharding name for a attribute or submodule."""
     self.attr_to_property[name].sharding_axis = axis
 
-  def drop_weight(self, key):
-    """list out names to discard."""
-    return False
+  def convert_hf_weights(self, hf_weights: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+    """Convert a HuggingFace state_dict to the names this model expects."""
+    weights = {}
+    updated_keys = self.get_hf_names_to_real_name()
+    for name, updated in updated_keys.items():
+      if name in hf_weights:
+        weights[updated] = hf_weights[name]
+
+    for name in list(weights.keys()):
+      if 'inv_freq' in name:
+        weights.pop(name)
+    if hasattr(self, 'freqs_cis'):
+      weights['freqs_cis'] = self.freqs_cis
+    return weights
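
To illustrate the intent (a hedged sketch, not repo code): `get_hf_names_to_real_name()` yields a mapping from HuggingFace checkpoint keys to this model's parameter names, and `convert_hf_weights` re-keys the checkpoint, drops rotary `inv_freq` buffers, and injects the precomputed `freqs_cis`. The standalone function below mirrors that logic with an explicit mapping dict; the key names in the toy mapping are illustrative only:

import torch

def convert_hf_weights_sketch(hf_weights, hf_to_real, freqs_cis=None):
  """Re-key a HF checkpoint, drop inv_freq buffers, optionally add freqs_cis."""
  weights = {hf_to_real[k]: v for k, v in hf_weights.items() if k in hf_to_real}
  weights = {k: v for k, v in weights.items() if 'inv_freq' not in k}
  if freqs_cis is not None:
    weights['freqs_cis'] = freqs_cis
  return weights

hf_ckpt = {
    "model.layers.0.self_attn.q_proj.weight": torch.zeros(8, 8),
    "model.layers.0.self_attn.rotary_emb.inv_freq": torch.zeros(4),
}
mapping = {
    "model.layers.0.self_attn.q_proj.weight": "layers.0.attention.wq.weight",
    "model.layers.0.self_attn.rotary_emb.inv_freq": "layers.0.attention.inv_freq",
}
print(convert_hf_weights_sketch(hf_ckpt, mapping).keys())
# dict_keys(['layers.0.attention.wq.weight'])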

jetstream_pt/third_party/mixtral/model.py

Lines changed: 47 additions & 18 deletions
@@ -12,7 +12,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import collections
+import copy
 from dataclasses import dataclass
 from typing import Optional, List, Any

@@ -163,6 +164,32 @@ def from_hf_model_id(cls, model_id, env):
     model = cls(args, env)
     return model
 
+  def convert_hf_weights(self, hf_weights):
+    updated_weights = super().convert_hf_weights(hf_weights)
+    # key is (layer id, weight name)
+    groupped_by_experts = collections.defaultdict(lambda: [None] * 8)
+
+    updated = copy.copy(hf_weights)
+    for key, value in hf_weights.items():
+      if 'block_sparse_moe.experts' in key:
+        #  0     1      2  3                4       5 6  7
+        # "model.layers.0.block_sparse_moe.experts.0.w1.weight"
+        updated.pop(key)
+        name_pieces = key.split('.')
+        assert len(name_pieces) == 8
+        layer_id = int(name_pieces[2])
+        expert_id = int(name_pieces[5])
+        weight_name = name_pieces[6]
+        groupped_by_experts[(layer_id, weight_name)][expert_id] = value
+
+    for (layer_id, weight_name), ws in groupped_by_experts.items():
+      name = f"model.layers.{layer_id}.block_sparse_moe.cond_ffn.{weight_name}"
+      updated[name] = torch.stack(ws)
+    res = super().convert_hf_weights(updated)
+    return res
+
 
 class TransformerBlock(ModuleBase):
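
The Mixtral override regroups the per-expert HF weights (one tensor per expert per layer) into the single stacked tensor that `cond_ffn` expects; note it stacks rather than concatenates, so the result gains a leading expert dimension of 8. A self-contained sketch of that regrouping with dummy tensors, keyed the way the comment in the diff shows (shapes here are arbitrary):

import collections
import torch

# Dummy per-expert weights, named like a HF Mixtral checkpoint.
hf_weights = {
    f"model.layers.0.block_sparse_moe.experts.{e}.w1.weight": torch.full((3, 2), float(e))
    for e in range(8)
}

grouped = collections.defaultdict(lambda: [None] * 8)
for key, value in hf_weights.items():
  pieces = key.split('.')  # ['model','layers','0','block_sparse_moe','experts','0','w1','weight']
  layer_id, expert_id, weight_name = int(pieces[2]), int(pieces[5]), pieces[6]
  grouped[(layer_id, weight_name)][expert_id] = value

stacked = {
    f"model.layers.{lid}.block_sparse_moe.cond_ffn.{wname}": torch.stack(ws)
    for (lid, wname), ws in grouped.items()
}
print(list(stacked))  # ['model.layers.0.block_sparse_moe.cond_ffn.w1']
print(stacked["model.layers.0.block_sparse_moe.cond_ffn.w1"].shape)  # torch.Size([8, 3, 2])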

@@ -177,6 +204,7 @@ def __init__(self, config: ModelArgs, env, layer_id) -> None:
         device=config.device,
         layer_id=layer_id,
     )
+    self.config = config
     self.hf_name("attention", "self_attn")
     self.attention.hf_name("wq", "q_proj")
     self.attention.hf_name("wk", "k_proj")
@@ -194,19 +222,20 @@ def __init__(self, config: ModelArgs, env, layer_id) -> None:
 
     self.hf_name("attention_norm", "input_layernorm")
     self.hf_name("ffn_norm", "post_attention_layernorm")
-    self._register_load_state_dict_pre_hook(self.load_hook)
-
-  def load_hook(self, state_dict, prefix, *args):
-    if prefix + "block_sparse_moe.experts" in state_dict:
-      w1s, w2s, w3s = [], [], []
-      for i in range(8):
-        exp_prefix = f"{prefix}block_sparse_moe.experts.{i}."
-        w1s.append(state_dict.pop(exp_prefix + ".w1"))
-        w2s.append(state_dict.pop(exp_prefix + ".w2"))
-        w3s.append(state_dict.pop(exp_prefix + ".w3"))
-      state_dict[prefix + "block_sparse_moe.cond_ffn.w1"] = torch.cat(w1s)
-      state_dict[prefix + "block_sparse_moe.cond_ffn.w2"] = torch.cat(w2s)
-      state_dict[prefix + "block_sparse_moe.cond_ffn.w3"] = torch.cat(w3s)
+
+    self.attention._register_load_state_dict_pre_hook(
+        self._load_attention_hf_weights)
+
+  def _load_attention_hf_weights(self, state_dict, prefix, *args):
+    def transform(val, n_heads):
+      dim1, dim2 = val.shape
+      return val.reshape(n_heads, 2, dim1 // n_heads // 2, dim2).transpose(1, 2).reshape(dim1, dim2)
+    qname = prefix + "wq.weight"
+    kname = prefix + "wk.weight"
+    if qname in state_dict:
+      state_dict[prefix + 'wq.weight'] = transform(state_dict[qname], self.config.n_head)
+    if kname in state_dict:
+      state_dict[prefix + 'wk.weight'] = transform(state_dict[kname], self.config.n_local_heads or self.config.n_head)
 
   def forward(
       self,
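
HF's conversion scripts for LLaMA-family checkpoints store `q_proj` / `k_proj` with the rows of each head permuted to suit the "rotate-half" rotary convention; the pre-hook added here appears to apply the inverse permutation so the weights match the interleaved rotary layout this model uses. A small sketch checking that the two permutations are inverses (`hf_permute` below is written from the known HF convention, not taken from this repo):

import torch

def hf_permute(w, n_heads):
  # Permutation applied by HF's conversion script to q_proj / k_proj rows.
  dim1, dim2 = w.shape
  return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)

def transform(val, n_heads):
  # The inverse permutation used in the commit: restores the interleaved layout.
  dim1, dim2 = val.shape
  return val.reshape(n_heads, 2, dim1 // n_heads // 2, dim2).transpose(1, 2).reshape(dim1, dim2)

w = torch.arange(32 * 16, dtype=torch.float32).reshape(32, 16)  # n_heads=4, head_dim=8
assert torch.equal(transform(hf_permute(w, 4), 4), w)            # round-trips exactly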
@@ -383,14 +412,14 @@ def get_quantized_version(self):
     """Return quantized version of this class."""
     quant_version = Int8ConditionalFeedForward(self.config)
     w1, w1_scaler, _ = quantize.quantize_tensor(self.w1, 2)
-    w2, w2_scaler, _ = quantize.quantize_tensor(self.w2, 1)
+    w2, w2_scaler, _ = quantize.quantize_tensor(self.w2, 2)
     w3, w3_scaler, _ = quantize.quantize_tensor(self.w3, 2)
     quant_version.w1 = w1
     quant_version.w2 = w2
     quant_version.w3 = w3
-    quant_version.w1_scaler = w1_scaler
-    quant_version.w2_scaler = w2_scaler
-    quant_version.w3_scaler = w3_scaler
+    quant_version.w1_scaler = w1_scaler.squeeze(2)
+    quant_version.w2_scaler = w2_scaler.squeeze(2)
+    quant_version.w3_scaler = w3_scaler.squeeze(2)
     return quant_version
 
 
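
Here `w2` is now quantized along the same axis as `w1` and `w3`, and the `.squeeze(2)` calls suggest that `quantize_tensor` reduces over the given axis with a kept singleton dimension that `Int8ConditionalFeedForward` does not want. That is an inference from the diff, not a documented contract of the `quantize` module. A generic per-axis symmetric int8 quantization in plain torch (not the repo's implementation) showing the shapes involved; the dummy weight stands in for an expert-stacked `cond_ffn` tensor, which for Mixtral-8x7B would be (8, 14336, 4096):

import torch

def int8_quantize(w, axis):
  """Symmetric per-channel int8 quantization, reducing over `axis` (keepdim)."""
  scaler = w.abs().amax(dim=axis, keepdim=True).clamp(min=1e-8) / 127.0
  q = torch.clamp(torch.round(w / scaler), -128, 127).to(torch.int8)
  return q, scaler

w1 = torch.randn(8, 128, 64)  # (num_experts, intermediate, hidden), reduced sizes
q, scaler = int8_quantize(w1, axis=2)
print(q.shape, scaler.shape, scaler.squeeze(2).shape)
# torch.Size([8, 128, 64]) torch.Size([8, 128, 1]) torch.Size([8, 128])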
