Commit 9c555a8

Add mixtral support to new CLI (#174)
* Add mixtral support to new CLI
* lint
* add new deps
1 parent 7307541 commit 9c555a8

File tree

11 files changed: +190 −22 lines


install_everything.sh

Lines changed: 1 addition & 0 deletions

@@ -28,6 +28,7 @@ pip install flax
 pip install tensorflow-text
 pip install tensorflow
 pip install huggingface_hub
+pip install transformers
 
 pip install ray[default]==2.33.0
 # torch cpu

install_everything_gpu.sh

Lines changed: 1 addition & 0 deletions

@@ -27,6 +27,7 @@ pip show torch_xla2 && pip uninstall -y torch_xla2
 pip install flax==0.8.4
 pip install tensorflow-text
 pip install tensorflow
+pip install transformers
 
 pip install ray[default]==2.22.0
 # torch cpu

jetstream_pt/cli.py

Lines changed: 7 additions & 7 deletions

@@ -7,6 +7,7 @@
 from jetstream.core import server_lib
 from jetstream.core.config_lib import ServerConfig, MetricsServerConfig
 import torch
+from transformers import AutoTokenizer
 
 from jetstream_pt import fetch_models
 from jetstream_pt import environment, engine, quantize_model, torchjax
@@ -25,13 +26,13 @@
 
 def shard_weights(env, weights, weight_shardings):
   """Shard weights according to weight_shardings"""
-  for k, v in weight_shardings.items():
-    print("SHARDING", k, v)
   sharded = {}
   for key, val in weights.items():
     sharding = env.sharding_by_axis(weight_shardings.get(key, -1))
     with jax.default_device(jax.devices("cpu")[0]):
       arr = torch_xla2.tensor.t2j(val)
+
+    print("SHARDING", key, sharding)
     arr = jax.device_put(arr, sharding)
     sharded[key] = torchjax.to_torch(arr)
   return sharded
@@ -48,17 +49,16 @@ def create_engine(devices):
       FLAGS.max_output_length,
       quant_config.enable_weight_quantization,
   )
+  tokenizer = AutoTokenizer.from_pretrained(FLAGS.model_id)
   env = environment.JetEngineEnvironment(env_data)
+  env.hf_tokenizer = tokenizer
   model = fetch_models.instantiate_model_from_repo_id(FLAGS.model_id, env)
+  if quant_config.enable_weight_quantization:
+    quantize_model.quantize_model(model, quant_config)
 
   weight_shardings = model.get_sharding_annotations()
   sharded_weights = shard_weights(env, model.state_dict(), weight_shardings)
 
-  if quant_config.enable_weight_quantization:
-    model.load_state_dict(sharded_weights, assign=True, strict=False)
-    quantize_model.quantize_model(model, quant_config)
-    sharded_weights = model.state_dict()
-
   return engine.PyTorchEngine(
       pt_model=model,
       env=env,
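The reordered create_engine now quantizes the model first and only then shards its state_dict: shard_weights converts each torch tensor to a JAX array on the CPU and places it on the devices with a per-tensor sharding. Below is a minimal, self-contained sketch of that placement step, assuming a one-axis mesh; sharding_by_axis here is a stand-in for env.sharding_by_axis, not the repo's implementation.

import jax
import jax.numpy as jnp
import numpy as np
import torch
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One-axis device mesh. On a CPU-only host this is a single device, so every
# sharding degenerates to replication, but the placement code still runs.
mesh = Mesh(np.array(jax.devices()), axis_names=("x",))

def sharding_by_axis(axis):
  """Shard the given tensor axis across the mesh; -1 means replicate."""
  if axis == -1:
    return NamedSharding(mesh, P())
  spec = [None] * (axis + 1)
  spec[axis] = "x"
  return NamedSharding(mesh, P(*spec))

weights = {"tok_embeddings.weight": torch.randn(32, 16),
           "norm.weight": torch.randn(16)}
weight_shardings = {"tok_embeddings.weight": 0}  # shard rows; rest replicated

sharded = {}
for key, val in weights.items():
  sharding = sharding_by_axis(weight_shardings.get(key, -1))
  arr = jnp.asarray(val.numpy())      # torch -> jax, analogous to t2j
  print("SHARDING", key, sharding)
  sharded[key] = jax.device_put(arr, sharding)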

jetstream_pt/engine.py

Lines changed: 3 additions & 0 deletions

@@ -36,6 +36,7 @@
 from jetstream_pt import cache_manager
 from jetstream_pt import quantize
 from jetstream_pt import torchjax
+from jetstream_pt.hf_tokenizer import HFTokenizerAdapter
 from jetstream_pt.environment import JetEngineEnvironment, JetEngineEnvironmentData, QuantizationConfig
 from jetstream_pt.third_party.llama import model_exportable as llama_model, model_args
 from jetstream_pt.third_party.gemma import config as gemma_config, model as gemma_model
@@ -705,6 +706,8 @@ def get_tokenizer(self) -> tokenizer_pb2.TokenizerParameters:
   def build_tokenizer(
       self, metadata: tokenizer_pb2.TokenizerParameters  # pylint: disable=all
   ) -> tokenizer_api.Tokenizer:
+    if self.env.hf_tokenizer is not None:
+      return HFTokenizerAdapter(self.env.hf_tokenizer)
     if "llama-3" in self.env.model_type:
       return token_utils.TikToken(metadata)

jetstream_pt/environment.py

Lines changed: 4 additions & 0 deletions

@@ -141,6 +141,10 @@ def __init__(self, data: JetEngineEnvironmentData):
     self.testing_seed = self._data.testing_seed
     self.ring_buffer = self._data.ring_buffer
 
+    # If not None, then use this tokenizer without
+    # trying to create new ones.
+    self.hf_tokenizer = None
+
     if not self.ring_buffer:
       self.lazy_cache_update = True
       self.ragged_mha = True

jetstream_pt/fetch_models.py

Lines changed: 12 additions & 7 deletions

@@ -12,6 +12,7 @@
   QuantizationConfig,
 )
 from jetstream_pt.third_party.llama import model_exportable as llama_model
+from jetstream_pt.third_party.mixtral import model as mixtral_model
 
 FLAGS = flags.FLAGS
 
@@ -38,12 +39,15 @@ class ModelInfo:
   num_layers: int
   num_heads: int
   head_dim: int
+  n_reps: int  # repetition for GQA
 
 
-_llama2_7 = ModelInfo(llama_model.Transformer, 32, 32, 128)
-_llama2_13 = ModelInfo(llama_model.Transformer, 40, 40, 128)
-_llama2_70 = ModelInfo(llama_model.Transformer, 80, 8, 128)
-_llama3_8 = ModelInfo(llama_model.Transformer, 32, 8, 128)
+_llama2_7 = ModelInfo(llama_model.Transformer, 32, 32, 128, 1)
+_llama2_13 = ModelInfo(llama_model.Transformer, 40, 40, 128, 1)
+_llama2_70 = ModelInfo(llama_model.Transformer, 80, 8, 128, 4)
+_llama3_8 = ModelInfo(llama_model.Transformer, 32, 8, 128, 4)
+
+_mixtral_87 = ModelInfo(mixtral_model.Transformer, 32, 8, 128, 4)
 
 
 model_id_to_class = {
@@ -57,8 +61,8 @@ class ModelInfo:
     "google/gemma-2b-it": None,
     "google/gemma-7b": None,
     "google/gemma-7b-it": None,
-    "mistralai/Mixtral-8x7B-v0.1": None,
-    "mistralai/Mixtral-8x7B-Instruct-v0.1": None,
+    "mistralai/Mixtral-8x7B-v0.1": _mixtral_87,
+    "mistralai/Mixtral-8x7B-Instruct-v0.1": _mixtral_87,
 }
 
 
@@ -107,6 +111,7 @@ def construct_env_data_from_model_id(
       else input_length + output_length
   )
 
+  model_info = model_id_to_class.get(repo_id)
   env_data = JetEngineEnvironmentData(
       tokenizer_path=tokenizer_path,
       checkpoint_path=checkpoint_path,
@@ -119,8 +124,8 @@
       bf16_enable=True,
       sharding_config_path="",
       shard_on_batch=shard_on_batch,
+      n_reps=model_info.n_reps,
   )
-  model_info = model_id_to_class.get(repo_id)
   env_data.cache_shape = (
       batch_size,
       model_info.num_heads,
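The new n_reps field records how many query heads share each key/value head under grouped-query attention (GQA). The sketch below is a hypothetical illustration, not code from the repo, using the 32-query-head / 8-KV-head shape that matches the n_reps=4 entries above.

import torch

num_query_heads, num_kv_heads, head_dim, seq_len = 32, 8, 128, 4
n_reps = num_query_heads // num_kv_heads   # 4, as in _llama3_8 and _mixtral_87

# One cached key tensor laid out as (batch, kv_heads, seq, head_dim).
k = torch.randn(1, num_kv_heads, seq_len, head_dim)

# Repeat each KV head n_reps times so it lines up with the query heads.
k_expanded = torch.repeat_interleave(k, n_reps, dim=1)
assert k_expanded.shape == (1, num_query_heads, seq_len, head_dim)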

jetstream_pt/hf_tokenizer.py

Lines changed: 50 additions & 0 deletions

@@ -0,0 +1,50 @@
+from jetstream.engine import tokenizer_api
+
+
+class HFTokenizerAdapter(tokenizer_api.Tokenizer):
+  """Implementation of Tokenizer interface backed by HF tokenizer."""
+
+  def __init__(self, tokenizer):
+    self.tokenizer = tokenizer
+
+  def encode(self, s: str, **kwargs):
+    """Tokenize a string.
+    Args:
+      s: String to tokenize.
+      **kwargs: Additional keyword arguments.
+    Returns:
+      tokens: Tokenized into integers.
+      true_length: Actual length of the non-padded sequence
+        if padding is used.
+    """
+    return self(s)
+
+  def decode(self, token_ids: list[int], **kwargs) -> str:
+    """Processes input token ids to generate a string.
+    Args:
+      token_ids: List of token ids.
+      **kwargs: Additional keyword arguments.
+    Returns:
+      str: String generated from the token ids.
+    """
+    return self.tokenizer.decode(token_ids)
+
+  @property
+  def pad_id(self) -> int:
+    """ID of the pad token."""
+    return self.tokenizer.pad_token_id if self.tokenizer.pad_token_id else 0
+
+  @property
+  def eos_id(self) -> int:
+    """ID of EOS token."""
+    return self.tokenizer.eos_token_id
+
+  @property
+  def bos_id(self) -> int:
+    """ID of BOS token."""
+    return self.tokenizer.bos_token_id
+
+  @property
+  def stop_tokens(self) -> set[int]:
+    """IDs of the stop tokens."""
+    return {self.eos_id, self.pad_id}
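A hedged usage sketch of the new adapter: build_tokenizer returns it whenever env.hf_tokenizer is set, but it can also be exercised on its own. The model id is only illustrative and assumes the tokenizer files are already cached or downloadable.

from transformers import AutoTokenizer

from jetstream_pt.hf_tokenizer import HFTokenizerAdapter

# Any HF repo with tokenizer files would do; Mixtral-8x7B is just the example.
hf_tok = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
adapter = HFTokenizerAdapter(hf_tok)

ids = hf_tok.encode("Mixtral on JetStream")   # plain HF encode for the demo
print(adapter.decode(ids))                    # round-trip via the adapter
print(adapter.bos_id, adapter.eos_id, adapter.pad_id, adapter.stop_tokens)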

jetstream_pt/layers.py

Lines changed: 14 additions & 0 deletions

@@ -320,6 +320,20 @@ def get_quantized_embedding_layer(config: "QuantizationConfig"):
   return Int8Embedding
 
 
+def create_quantized_from_nn_embedding(
+    float_embedding: nn.Embedding, config: "QuantizationConfig"
+):
+  clazz_ = get_quantized_embedding_layer(config)
+  obj = clazz_(
+      float_embedding.num_embeddings,
+      float_embedding.embedding_dim,
+  )
+  weights, scaler, _ = quantize_tensor(float_embedding.weight, 1)
+  obj.weight = weights
+  obj.scaler = scaler
+  return obj
+
+
 class RMSNorm(torch.nn.Module):
   """RMSNorm module."""
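What create_quantized_from_nn_embedding boils down to, as a standalone sketch rather than the repo's Int8Embedding class: quantize the embedding table to int8 with a per-row scaler, then rescale on lookup. The int8_quantize helper below is hypothetical and only mirrors the idea of quantize_tensor.

import torch
import torch.nn as nn

def int8_quantize(w, axis):
  """Symmetric per-axis int8 quantization; returns (int8 weights, scaler)."""
  scaler = w.abs().amax(dim=axis, keepdim=True) / 127.0
  q = torch.clamp(torch.round(w / scaler), -128, 127).to(torch.int8)
  return q, scaler

emb = nn.Embedding(1000, 64)
q_weight, scaler = int8_quantize(emb.weight.detach(), axis=1)  # per-row scaler

ids = torch.tensor([[1, 5, 9]])
dequant = q_weight[ids].to(torch.float32) * scaler[ids]  # lookup, then rescale
print((dequant - emb(ids)).abs().max())                  # small quantization error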

jetstream_pt/quantize_model.py

Lines changed: 12 additions & 2 deletions

@@ -1,14 +1,24 @@
 import torch
-from .layers import create_quantized_from_nn_linear
+from .layers import (
+    create_quantized_from_nn_linear,
+    create_quantized_from_nn_embedding,
+)
 
 
 def quantize_model(float_model, config):
   """Apply quantization to linear layers."""
 
   def quantize_nn_mod(float_model):
     for name, mod in float_model.named_modules():
-      if isinstance(mod, torch.nn.Linear):
+      new_mod = None
+      if hasattr(mod, "get_quantized_version"):
+        new_mod = mod.get_quantized_version()
+      elif isinstance(mod, torch.nn.Linear):
         new_mod = create_quantized_from_nn_linear(mod, config)
+      elif isinstance(mod, torch.nn.Embedding):
+        new_mod = create_quantized_from_nn_embedding(mod, config)
+
+      if new_mod:
+        setattr(float_model, name, new_mod)
 
   float_model.apply(quantize_nn_mod)
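The rewritten dispatch lets any module supply its own quantized replacement through a get_quantized_version() method, which gives modules that are not plain nn.Linear or nn.Embedding layers a way to opt in, with the generic helpers as the fallback. A toy module showing the hook; the class and its fields are hypothetical, not from the repo.

import torch

class ToyExperts(torch.nn.Module):
  """Hypothetical module that knows how to quantize itself."""

  def __init__(self):
    super().__init__()
    # A stacked (num_experts, out_features, in_features) weight, not an
    # nn.Linear, so the generic helpers would skip it.
    self.w = torch.nn.Parameter(torch.randn(4, 8, 8))

  def get_quantized_version(self):
    q = ToyExperts()
    w = self.w.detach()
    scaler = w.abs().amax(dim=-1, keepdim=True) / 127.0
    q.w_int8 = (w / scaler).round().clamp(-128, 127).to(torch.int8)
    q.scaler = scaler
    return q

mod = ToyExperts()
# Mirrors the dispatch: prefer the module's own hook when it exists.
new_mod = mod.get_quantized_version() if hasattr(mod, "get_quantized_version") else None
print(type(new_mod).__name__, new_mod.w_int8.dtype)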

jetstream_pt/third_party/llama/model_exportable.py

Lines changed: 4 additions & 0 deletions

@@ -75,6 +75,10 @@ def __init__(
     self.annotate_sharding("w1.weight", 0)
     self.annotate_sharding("w2.weight", 1)
     self.annotate_sharding("w3.weight", 0)
+    if LinearLayer != torch.nn.Linear:
+      self.annotate_sharding("w1.weight_scaler", 0)
+      self.annotate_sharding("w2.weight_scaler", 0)
+      self.annotate_sharding("w3.weight_scaler", 0)
 
   def forward(self, x):
     result = self.w2(F.silu(self.w1(x)) * self.w3(x))
