
Commit fab667f

cosmetic fixes
1 parent 1a8b746 commit fab667f

File tree

9 files changed, +61 -67 lines changed


jetstream_pt/cli.py

Lines changed: 2 additions & 2 deletions
@@ -59,7 +59,7 @@ def create_engine(devices):
  model = fetch_models.instantiate_model_from_repo_id(FLAGS.model_id, env)
  if quant_config.enable_weight_quantization:
    quantize_model.quantize_model(model, quant_config)
-  print('====== model =======')
+  print("====== model =======")
  print(model)

  weight_shardings = model.get_sharding_annotations()

@@ -225,7 +225,7 @@ def main(argv):
    return
  else:
    print(
-        "Invalid arguments. please specify 'list', 'serve', or 'interactive'."
+        "Invalid arguments. please specify 'list', 'serve', or 'interactive'."
    )

jetstream_pt/engine.py

Lines changed: 0 additions & 2 deletions
@@ -230,7 +230,6 @@ def _call_model_prefill(self, weights, tokens, input_indexes):
    with self._lock:
      with torch_xla2.default_env():
        res = torch.func.functional_call(self.pt_model, paramst, argst)[0]
-    jax.debug.print('Prefill result {}', res._elem)
    caches_res = [c.state() for c in caches]
    return torchjax.from_torch((res, caches_res))

@@ -283,7 +282,6 @@ def prefill(
      self.env.temperature,
    )
    token_out = jnp.reshape(token, (1, 1))
-    jax.debug.print('TOKEN is {}', token_out)
    data = jnp.concatenate(
        [
            token_out, # First token
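
The two deleted lines were jax.debug.print debug statements. For context, a minimal standalone sketch (not from this repo) of how jax.debug.print behaves: unlike a plain Python print, it emits the runtime values on every invocation of a jit-compiled function, which is what made it handy for inspecting prefill results.

import jax
import jax.numpy as jnp


@jax.jit
def prefill_like(tokens):
  logits = jnp.cumsum(tokens, axis=-1)  # stand-in for a real model call
  # Prints the runtime value on every call, not just once at trace time.
  jax.debug.print("Prefill result {}", logits)
  return logits


prefill_like(jnp.ones((1, 4)))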

jetstream_pt/fetch_models.py

Lines changed: 6 additions & 7 deletions
@@ -13,7 +13,7 @@
 )
 from jetstream_pt.third_party.llama import model_exportable as llama_model
 from jetstream_pt.third_party.mixtral import model as mixtral_model
-from jetstream_pt.third_party.gemma import model as gemma_model
+from jetstream_pt.third_party.gemma import model as gemma_model

 FLAGS = flags.FLAGS

@@ -168,7 +168,6 @@ def instantiate_model_from_repo_id(
  weights = _load_weights(model_dir)
  weights = model.convert_hf_weights(weights)

-
  model.load_state_dict(weights, assign=True, strict=False)

  return model

@@ -190,11 +189,11 @@ def _hf_download(
        local_dir=dest_directory,
        local_dir_use_symlinks=False,
        token=hf_token,
-        # allow_patterns=[
-        #     "model-?????-of-?????.safetensors",
-        #     "*.json",
-        #     "*.model",
-        # ],
+        allow_patterns=[
+            "model-?????-of-?????.safetensors",
+            "*.json",
+            "*.model",
+        ],
    )
  except HTTPError as e:
    if e.response.status_code == 401:
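
The previously commented-out allow_patterns filter is re-enabled, so the download is limited to sharded safetensors weights plus JSON and sentencepiece files. A minimal sketch of the same filtering, assuming the helper wraps huggingface_hub's snapshot_download (the keyword arguments shown above match that API); the repo id and token below are placeholders, not values from this repo.

from huggingface_hub import snapshot_download

snapshot_download(
    "meta-llama/Llama-2-7b-hf",  # hypothetical repo id, for illustration only
    local_dir="/tmp/llama-2-7b",
    token="hf_xxx",  # placeholder token
    allow_patterns=[
        "model-?????-of-?????.safetensors",  # sharded safetensors weights only
        "*.json",  # config and tokenizer JSON files
        "*.model",  # sentencepiece tokenizer.model
    ],
)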

jetstream_pt/hf_tokenizer.py

Lines changed: 3 additions & 1 deletion
@@ -18,7 +18,9 @@ def encode(self, s: str, **kwargs):
    if padding is used.
    """
    res = self.tokenizer.encode(s, add_special_tokens=False)
-    return token_utils.pad_tokens(res, self.bos_id, self.pad_id, jax_padding=True)
+    return token_utils.pad_tokens(
+        res, self.bos_id, self.pad_id, jax_padding=True
+    )

  def decode(self, token_ids: list[int], **kwargs) -> str:
    """Processess input token ids to generate a string.

jetstream_pt/layers.py

Lines changed: 1 addition & 1 deletion
@@ -330,7 +330,7 @@ def create_quantized_from_nn_embedding(
  )
  weights, scaler, _ = quantize_tensor(float_embedding.weight, 0)
  obj.weight = weights
-  obj.weight_scaler = scaler
+  obj.weight_scaler = scaler
  return obj

jetstream_pt/model_base.py

Lines changed: 7 additions & 6 deletions
@@ -47,7 +47,7 @@ class AttrProperty:

 class ModuleBase(torch.nn.Module, metaclass=abc.ABCMeta):
   """nn Module that allows attaching properties.
-
+
   This class currently serves 2 goals:
   1. Allow model to specify alternative names for submodules / weights
      this is needed so that it can *also* load HuggingFace checkpoints

@@ -85,7 +85,9 @@ def annotate_sharding(self, name, axis):
    """Set sharding name for a attribute or submodule."""
    self.attr_to_property[name].sharding_axis = axis

-  def convert_hf_weights(self, hf_weights: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+  def convert_hf_weights(
+      self, hf_weights: Dict[str, torch.Tensor]
+  ) -> Dict[str, torch.Tensor]:
    """Load state_dict with hg weights."""
    weights = {}
    updated_keys = self.get_hf_names_to_real_name()

@@ -94,9 +96,8 @@ def convert_hf_weights(self, hf_weights: Dict[str, torch.Tensor]) -> Dict[str, t
      weights[updated] = hf_weights[name]

    for name in list(weights.keys()):
-      if 'inv_freq' in name:
+      if "inv_freq" in name:
        weights.pop(name)
-    if hasattr(self, 'freqs_cis'):
-      weights['freqs_cis'] = self.freqs_cis
+    if hasattr(self, "freqs_cis"):
+      weights["freqs_cis"] = self.freqs_cis
    return weights
-
jetstream_pt/third_party/gemma/model.py

Lines changed: 4 additions & 6 deletions
@@ -277,11 +277,11 @@ def __init__(
    )

    self.annotate_sharding("gate_proj.weight", 0)
-    self.annotate_sharding('up_proj.weight', 0)
-    self.annotate_sharding('down_proj.weight', 1)
+    self.annotate_sharding("up_proj.weight", 0)
+    self.annotate_sharding("down_proj.weight", 1)
    self.annotate_sharding("gate_proj.bias", 0)
-    self.annotate_sharding('up_proj.bias', 0)
-    self.annotate_sharding('down_proj.bias', -1)
+    self.annotate_sharding("up_proj.bias", 0)
+    self.annotate_sharding("down_proj.bias", -1)
    if Linear != torch.nn.Linear:
      self.annotate_sharding("gate_proj.weight_scaler", 0)
      self.annotate_sharding("up_proj.weight_scaler", 0)

@@ -418,7 +418,6 @@ def forward(
    freqs_cis = freqs_cis.reshape(bsz, seqlen, -1)

    hidden_states = self.embedder(tokens)
-    #jax.debug.print('after embedding {}', hidden_states[-1]._elem)
    hidden_states = hidden_states * (self.config.hidden_size**0.5)

    end = None if start is None else (start + input_pos) % self.env.cache_len

@@ -435,7 +434,6 @@ def forward(
      ragged_batch_index=ragged_batch_index,
      ragged_block_index=ragged_block_index,
    )
-    #jax.debug.print('hidden after layer {}: {}', i, hidden_states[-1]._elem)
    hidden_states = self.norm(hidden_states)

    embedder_weight = self.embedder.weight

jetstream_pt/third_party/llama/model_exportable.py

Lines changed: 20 additions & 21 deletions
@@ -2,7 +2,7 @@
 """This version contains modification to make it easier to trace and support batch."""

 from typing import Any, List, Optional
-
+import copy
 import jax
 import torch
 import torch.nn.functional as F

@@ -125,8 +125,6 @@ def __init__(
    self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps, device=args.device)

    self.hf_name("attention", "self_attn")
-    # We dont want to rename q_proj and k_proj; this is done in
-    # _load_attention_hf_weights
    self.attention.hf_name("wq", "q_proj")
    self.attention.hf_name("wk", "k_proj")
    self.attention.hf_name("wv", "v_proj")

@@ -140,20 +138,6 @@ def __init__(
    self.hf_name("feed_forward", "mlp")
    self.hf_name("attention_norm", "input_layernorm")
    self.hf_name("ffn_norm", "post_attention_layernorm")
-    self.attention._register_load_state_dict_pre_hook(
-        self._load_attention_hf_weights)
-
-  def _load_attention_hf_weights(self, state_dict, prefix, *args):
-    def transform(val, n_heads):
-      dim1, dim2 = val.shape
-      return val.reshape(n_heads, 2, dim1 // n_heads // 2, dim2).transpose(1, 2).reshape(dim1, dim2)
-    qname = prefix + "wq.weight"
-    kname = prefix + "wk.weight"
-    if qname in state_dict:
-      state_dict[prefix + 'wq.weight'] = transform(state_dict[qname], self.n_heads)
-    if kname in state_dict:
-      state_dict[prefix + 'wk.weight'] = transform(state_dict[kname], self.args.n_kv_heads or self.n_heads)
-

  def forward(
    self,

@@ -377,8 +361,23 @@ def from_hf_model_id(cls, model_id, env):
  def drop_weight(self, key):
    return key.startswith("model")

-  def shard_weights(self, weights_dict):
-    """Shards the weights
+  def convert_hf_weights(self, hf_weights):

-    Assumes the weights_dict is a list of XLATensor2
-    """
+    def transform(val, n_heads):
+      dim1, dim2 = val.shape
+      return (
+          val.reshape(n_heads, 2, dim1 // n_heads // 2, dim2)
+          .transpose(1, 2)
+          .reshape(dim1, dim2)
+      )
+
+    updated = copy.copy(hf_weights)
+
+    for key, value in hf_weights.items():
+      if "q_proj" in key:
+        updated[key] = transform(value, self.params.n_heads)
+      if "k_proj" in key:
+        updated[key] = transform(
+            value, self.params.n_kv_heads or self.params.n_heads
+        )
+    return super().convert_hf_weights(updated)
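
The load_state_dict pre-hook that permuted wq/wk weights is folded into convert_hf_weights, which now applies the same transform while the tensors still carry their HuggingFace q_proj/k_proj names. A standalone illustration with a dummy tensor (made-up sizes) showing that transform is a shape-preserving row permutation within each attention head, used to line the weights up with the rotary-embedding layout expected here:

import torch

n_heads, head_dim, dim2 = 4, 8, 32
dim1 = n_heads * head_dim
val = torch.arange(dim1 * dim2, dtype=torch.float32).reshape(dim1, dim2)

permuted = (
    val.reshape(n_heads, 2, dim1 // n_heads // 2, dim2)
    .transpose(1, 2)
    .reshape(dim1, dim2)
)

assert permuted.shape == val.shape  # same shape, rows reordered per head
assert set(permuted.flatten().tolist()) == set(val.flatten().tolist())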

jetstream_pt/third_party/mixtral/model.py

Lines changed: 18 additions & 21 deletions
@@ -165,24 +165,35 @@ def from_hf_model_id(cls, model_id, env):
    return model

  def convert_hf_weights(self, hf_weights):
-    updated_weights = super().convert_hf_weights(hf_weights)
-    # key is layer id, weight name
-    groupped_by_experts = collections.defaultdict(lambda: [None] * 8)
-

+    def transform(val, n_heads):
+      dim1, dim2 = val.shape
+      return (
+          val.reshape(n_heads, 2, dim1 // n_heads // 2, dim2)
+          .transpose(1, 2)
+          .reshape(dim1, dim2)
+      )
+
+    groupped_by_experts = collections.defaultdict(lambda: [None] * 8)
    updated = copy.copy(hf_weights)
    for key, value in hf_weights.items():
-      if 'block_sparse_moe.experts' in key:
+      if "block_sparse_moe.experts" in key:
        # 0 1 2 3 4 5 6 7
-        #"model.layers.0.block_sparse_moe.experts.0.w1.weight"
+        # "model.layers.0.block_sparse_moe.experts.0.w1.weight"
        updated.pop(key)
-        name_pieces = key.split('.')
+        name_pieces = key.split(".")
        assert len(name_pieces) == 8
        layer_id = int(name_pieces[2])
        expert_id = int(name_pieces[5])
        weight_name = name_pieces[6]
        groupped_by_experts[(layer_id, weight_name)][expert_id] = value

+      if "q_proj" in key:
+        updated[key] = transform(value, self.config.n_head)
+      if "k_proj" in key:
+        updated[key] = transform(
+            value, self.config.n_local_heads or self.config.n_head
+        )

    for (layer_id, weight_name), ws in groupped_by_experts.items():
      name = f"model.layers.{layer_id}.block_sparse_moe.cond_ffn.{weight_name}"

@@ -222,20 +233,6 @@ def __init__(self, config: ModelArgs, env, layer_id) -> None:

    self.hf_name("attention_norm", "input_layernorm")
    self.hf_name("ffn_norm", "post_attention_layernorm")
-
-    self.attention._register_load_state_dict_pre_hook(
-        self._load_attention_hf_weights)
-
-  def _load_attention_hf_weights(self, state_dict, prefix, *args):
-    def transform(val, n_heads):
-      dim1, dim2 = val.shape
-      return val.reshape(n_heads, 2, dim1 // n_heads // 2, dim2).transpose(1, 2).reshape(dim1, dim2)
-    qname = prefix + "wq.weight"
-    kname = prefix + "wk.weight"
-    if qname in state_dict:
-      state_dict[prefix + 'wq.weight'] = transform(state_dict[qname], self.config.n_head)
-    if kname in state_dict:
-      state_dict[prefix + 'wk.weight'] = transform(state_dict[kname], self.config.n_local_heads or self.config.n_head)

  def forward(
    self,
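
Besides applying the same q_proj/k_proj transform, Mixtral's convert_hf_weights groups per-expert MoE tensors by (layer_id, weight_name) before merging them under a single cond_ffn key. A standalone sketch of the key parsing and grouping with dummy tensors; the final stacking step is an assumption for illustration, since the merge itself is outside the hunk shown above.

import collections

import torch

hf_weights = {
    f"model.layers.0.block_sparse_moe.experts.{e}.w1.weight": torch.zeros(2, 2)
    for e in range(8)
}

grouped = collections.defaultdict(lambda: [None] * 8)
for key, value in hf_weights.items():
  # e.g. "model.layers.0.block_sparse_moe.experts.0.w1.weight"
  name_pieces = key.split(".")
  layer_id = int(name_pieces[2])
  expert_id = int(name_pieces[5])
  weight_name = name_pieces[6]
  grouped[(layer_id, weight_name)][expert_id] = value

for (layer_id, weight_name), ws in grouped.items():
  name = f"model.layers.{layer_id}.block_sparse_moe.cond_ffn.{weight_name}"
  stacked = torch.stack(ws)  # assumed: 8 experts stacked along a new leading dim
  print(name, stacked.shape)  # -> (8, 2, 2) per grouped weight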
