1 change: 1 addition & 0 deletions install_everything.sh
@@ -28,6 +28,7 @@ pip install flax
pip install tensorflow-text
pip install tensorflow
pip install huggingface_hub
pip install transformers

pip install ray[default]==2.33.0
# torch cpu
1 change: 1 addition & 0 deletions install_everything_gpu.sh
@@ -27,6 +27,7 @@ pip show torch_xla2 && pip uninstall -y torch_xla2
pip install flax==0.8.4
pip install tensorflow-text
pip install tensorflow
pip install transformers

pip install ray[default]==2.22.0
# torch cpu
14 changes: 7 additions & 7 deletions jetstream_pt/cli.py
@@ -7,6 +7,7 @@
from jetstream.core import server_lib
from jetstream.core.config_lib import ServerConfig, MetricsServerConfig
import torch
from transformers import AutoTokenizer

from jetstream_pt import fetch_models
from jetstream_pt import environment, engine, quantize_model, torchjax
@@ -25,13 +26,13 @@

def shard_weights(env, weights, weight_shardings):
  """Shard weights according to weight_shardings"""
  for k, v in weight_shardings.items():
    print("SHARDING", k, v)
  sharded = {}
  for key, val in weights.items():
    sharding = env.sharding_by_axis(weight_shardings.get(key, -1))
    with jax.default_device(jax.devices("cpu")[0]):
      arr = torch_xla2.tensor.t2j(val)

    print("SHARDING", key, sharding)
    arr = jax.device_put(arr, sharding)
    sharded[key] = torchjax.to_torch(arr)
  return sharded
@@ -48,17 +49,16 @@ def create_engine(devices):
      FLAGS.max_output_length,
      quant_config.enable_weight_quantization,
  )
  tokenizer = AutoTokenizer.from_pretrained(FLAGS.model_id)
  env = environment.JetEngineEnvironment(env_data)
  env.hf_tokenizer = tokenizer
  model = fetch_models.instantiate_model_from_repo_id(FLAGS.model_id, env)
  if quant_config.enable_weight_quantization:
    quantize_model.quantize_model(model, quant_config)

  weight_shardings = model.get_sharding_annotations()
  sharded_weights = shard_weights(env, model.state_dict(), weight_shardings)

  if quant_config.enable_weight_quantization:
    model.load_state_dict(sharded_weights, assign=True, strict=False)
    quantize_model.quantize_model(model, quant_config)
    sharded_weights = model.state_dict()

  return engine.PyTorchEngine(
      pt_model=model,
      env=env,
3 changes: 3 additions & 0 deletions jetstream_pt/engine.py
@@ -36,6 +36,7 @@
from jetstream_pt import cache_manager
from jetstream_pt import quantize
from jetstream_pt import torchjax
from jetstream_pt.hf_tokenizer import HFTokenizerAdapter
from jetstream_pt.environment import JetEngineEnvironment, JetEngineEnvironmentData, QuantizationConfig
from jetstream_pt.third_party.llama import model_exportable as llama_model, model_args
from jetstream_pt.third_party.gemma import config as gemma_config, model as gemma_model
@@ -705,6 +706,8 @@ def get_tokenizer(self) -> tokenizer_pb2.TokenizerParameters:
  def build_tokenizer(
      self, metadata: tokenizer_pb2.TokenizerParameters  # pylint: disable=all
  ) -> tokenizer_api.Tokenizer:
    if self.env.hf_tokenizer is not None:
      return HFTokenizerAdapter(self.env.hf_tokenizer)
    if "llama-3" in self.env.model_type:
      return token_utils.TikToken(metadata)

4 changes: 4 additions & 0 deletions jetstream_pt/environment.py
@@ -141,6 +141,10 @@ def __init__(self, data: JetEngineEnvironmentData):
    self.testing_seed = self._data.testing_seed
    self.ring_buffer = self._data.ring_buffer

    # If not None, use this tokenizer instead of creating a new one.
    self.hf_tokenizer = None

    if not self.ring_buffer:
      self.lazy_cache_update = True
      self.ragged_mha = True
19 changes: 12 additions & 7 deletions jetstream_pt/fetch_models.py
@@ -12,6 +12,7 @@
    QuantizationConfig,
)
from jetstream_pt.third_party.llama import model_exportable as llama_model
from jetstream_pt.third_party.mixtral import model as mixtral_model

FLAGS = flags.FLAGS

@@ -38,12 +39,15 @@ class ModelInfo:
  num_layers: int
  num_heads: int
  head_dim: int
  n_reps: int  # repetition factor for GQA


_llama2_7 = ModelInfo(llama_model.Transformer, 32, 32, 128)
_llama2_13 = ModelInfo(llama_model.Transformer, 40, 40, 128)
_llama2_70 = ModelInfo(llama_model.Transformer, 80, 8, 128)
_llama3_8 = ModelInfo(llama_model.Transformer, 32, 8, 128)
_llama2_7 = ModelInfo(llama_model.Transformer, 32, 32, 128, 1)
_llama2_13 = ModelInfo(llama_model.Transformer, 40, 40, 128, 1)
_llama2_70 = ModelInfo(llama_model.Transformer, 80, 8, 128, 4)
_llama3_8 = ModelInfo(llama_model.Transformer, 32, 8, 128, 4)

_mixtral_87 = ModelInfo(mixtral_model.Transformer, 32, 8, 128, 4)


model_id_to_class = {
@@ -57,8 +61,8 @@ class ModelInfo:
"google/gemma-2b-it": None,
"google/gemma-7b": None,
"google/gemma-7b-it": None,
"mistralai/Mixtral-8x7B-v0.1": None,
"mistralai/Mixtral-8x7B-Instruct-v0.1": None,
"mistralai/Mixtral-8x7B-v0.1": _mixtral_87,
"mistralai/Mixtral-8x7B-Instruct-v0.1": _mixtral_87,
}


@@ -107,6 +111,7 @@ def construct_env_data_from_model_id(
      else input_length + output_length
  )

  model_info = model_id_to_class.get(repo_id)
  env_data = JetEngineEnvironmentData(
      tokenizer_path=tokenizer_path,
      checkpoint_path=checkpoint_path,
@@ -119,8 +124,8 @@
      bf16_enable=True,
      sharding_config_path="",
      shard_on_batch=shard_on_batch,
      n_reps=model_info.n_reps,
  )
  model_info = model_id_to_class.get(repo_id)
  env_data.cache_shape = (
      batch_size,
      model_info.num_heads,
50 changes: 50 additions & 0 deletions jetstream_pt/hf_tokenizer.py
@@ -0,0 +1,50 @@
from jetstream.engine import tokenizer_api


class HFTokenizerAdapter(tokenizer_api.Tokenizer):
  """Implementation of Tokenizer interface backed by HF tokenizer."""

  def __init__(self, tokenizer):
    self.tokenizer = tokenizer

  def encode(self, s: str, **kwargs):
    """Tokenize a string.

    Args:
      s: String to tokenize.
      **kwargs: Additional keyword arguments.

    Returns:
      tokens: Tokenized into integers.
      true_length: Actual length of the non-padded sequence
        if padding is used.
    """
    return self(s)

  def decode(self, token_ids: list[int], **kwargs) -> str:
    """Processes input token ids to generate a string.

    Args:
      token_ids: List of token ids.
      **kwargs: Additional keyword arguments.

    Returns:
      str: String generated from the token ids.
    """
    # Delegate to the wrapped HF tokenizer (calling self.decode would recurse).
    return self.tokenizer.decode(token_ids)

  @property
  def pad_id(self) -> int:
    """ID of the pad token."""
    return self.tokenizer.pad_token_id if self.tokenizer.pad_token_id else 0

  @property
  def eos_id(self) -> int:
    """ID of EOS token."""
    return self.tokenizer.eos_token_id

  @property
  def bos_id(self) -> int:
    """ID of BOS token."""
    return self.tokenizer.bos_token_id

  @property
  def stop_tokens(self) -> set[int]:
    """IDs of the stop tokens."""
    return {self.eos_id, self.pad_id}
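
A minimal usage sketch of the adapter (not part of the diff), assuming a Hugging Face AutoTokenizer; the model id is a placeholder, and the wiring mirrors how create_engine in cli.py hands the tokenizer to the environment via env.hf_tokenizer:

# Sketch only: "meta-llama/Llama-2-7b-hf" is a placeholder model id.
from transformers import AutoTokenizer
from jetstream_pt.hf_tokenizer import HFTokenizerAdapter

hf_tok = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
adapter = HFTokenizerAdapter(hf_tok)

ids = hf_tok.encode("hello world")   # plain HF token ids
text = adapter.decode(ids)           # decodes through the wrapped tokenizer
stops = adapter.stop_tokens          # {eos_id, pad_id}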
14 changes: 14 additions & 0 deletions jetstream_pt/layers.py
@@ -320,6 +320,20 @@ def get_quantized_embedding_layer(config: "QuantizationConfig"):
  return Int8Embedding


def create_quantized_from_nn_embedding(
    float_embedding: nn.Embedding, config: "QuantizationConfig"
):
  clazz_ = get_quantized_embedding_layer(config)
  obj = clazz_(
      float_embedding.num_embeddings,
      float_embedding.embedding_dim,
  )
  weights, scaler, _ = quantize_tensor(float_embedding.weight, 1)
  obj.weight = weights
  obj.scaler = scaler
  return obj


class RMSNorm(torch.nn.Module):
"""RMSNorm module."""

14 changes: 12 additions & 2 deletions jetstream_pt/quantize_model.py
@@ -1,14 +1,24 @@
import torch
from .layers import create_quantized_from_nn_linear
from .layers import (
    create_quantized_from_nn_linear,
    create_quantized_from_nn_embedding,
)


def quantize_model(float_model, config):
  """Apply quantization to linear and embedding layers."""

  def quantize_nn_mod(float_model):
    for name, mod in float_model.named_modules():
      if isinstance(mod, torch.nn.Linear):
      new_mod = None
      if hasattr(mod, "get_quantized_version"):
        new_mod = mod.get_quantized_version()
      elif isinstance(mod, torch.nn.Linear):
        new_mod = create_quantized_from_nn_linear(mod, config)
      elif isinstance(mod, torch.nn.Embedding):
        new_mod = create_quantized_from_nn_embedding(mod, config)

      if new_mod:
        setattr(float_model, name, new_mod)

  float_model.apply(quantize_nn_mod)
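
The hasattr(mod, "get_quantized_version") branch lets a module supply its own quantized replacement instead of the generic nn.Linear / nn.Embedding handling. A minimal sketch of such a module, under stated assumptions: MyExpertLayer and MyQuantizedExpertLayer are hypothetical names, not part of this PR, and the per-tensor int8 scheme is only illustrative.

# Sketch only: hypothetical module opting into the get_quantized_version() hook.
import torch


class MyQuantizedExpertLayer(torch.nn.Module):

  def __init__(self, weight_int8, scaler):
    super().__init__()
    self.register_buffer("weight", weight_int8)
    self.scaler = scaler

  def forward(self, x):
    # Dequantize on the fly with a single per-tensor scale.
    return (x @ self.weight.t().to(x.dtype)) * self.scaler


class MyExpertLayer(torch.nn.Module):

  def __init__(self, in_features, out_features):
    super().__init__()
    self.weight = torch.nn.Parameter(torch.randn(out_features, in_features))

  def get_quantized_version(self):
    # Per-tensor int8 quantization of the float weight.
    scaler = self.weight.detach().abs().max() / 127.0
    weight_int8 = torch.round(self.weight.detach() / scaler).to(torch.int8)
    return MyQuantizedExpertLayer(weight_int8, scaler.item())

  def forward(self, x):
    return x @ self.weight.t()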
4 changes: 4 additions & 0 deletions jetstream_pt/third_party/llama/model_exportable.py
@@ -75,6 +75,10 @@ def __init__(
self.annotate_sharding("w1.weight", 0)
self.annotate_sharding("w2.weight", 1)
self.annotate_sharding("w3.weight", 0)
if LinearLayer != torch.nn.Linear:
self.annotate_sharding("w1.weight_scaler", 0)
self.annotate_sharding("w2.weight_scaler", 0)
self.annotate_sharding("w3.weight_scaler", 0)

def forward(self, x):
result = self.w2(F.silu(self.w1(x)) * self.w3(x))