diff --git a/jetstream_pt/cli.py b/jetstream_pt/cli.py
index 193a1d4c..76dcace4 100644
--- a/jetstream_pt/cli.py
+++ b/jetstream_pt/cli.py
@@ -1,12 +1,16 @@
+from typing import List
+import random
 import sys
-
+import time
 # import torch_xla2 first!
 import torch_xla2  # pylint: disable
 import jax
 from absl import app, flags
+from jetstream.engine import token_utils
 from jetstream.core import server_lib
 from jetstream.core.config_lib import ServerConfig, MetricsServerConfig
 import torch
+import numpy as np
 from transformers import AutoTokenizer
 
 from jetstream_pt import fetch_models
@@ -55,9 +59,12 @@ def create_engine(devices):
   model = fetch_models.instantiate_model_from_repo_id(FLAGS.model_id, env)
   if quant_config.enable_weight_quantization:
     quantize_model.quantize_model(model, quant_config)
+  print("====== model =======")
+  print(model)
 
   weight_shardings = model.get_sharding_annotations()
   sharded_weights = shard_weights(env, model.state_dict(), weight_shardings)
+  env_data.quant_config = quant_config
 
   return engine.PyTorchEngine(
       pt_model=model,
@@ -74,11 +81,7 @@ def list_model():
 
 def serve():
   """Run gRPC server."""
-  if FLAGS.model_id == "":
-    print("Please specify model_id with --model_id")
-    print("valid model ids are:")
-    list_model()
-    sys.exit(1)
+  _check_model_id()
   devices = server_lib.get_devices()
   print(f"devices: {devices}")
 
@@ -103,9 +106,105 @@ def serve():
   jetstream_server.wait_for_termination()
 
 
+def _check_model_id():
+  if FLAGS.model_id == "":
+    print("Please specify model_id with --model_id")
+    print("valid model ids are:")
+    list_model()
+    sys.exit(1)
+
+
 def interactive():
   """Run interactive"""
-  raise RuntimeError("Not implemented")
+  _check_model_id()
+  devices = server_lib.get_devices()
+  print(f"devices: {devices}")
+  pt_engine = create_engine(devices)
+
+  start = time.perf_counter()
+  params = pt_engine.load_params()
+  print("Load params ", time.perf_counter() - start)
+
+  metadata = pt_engine.get_tokenizer()
+  tokenizer = pt_engine.build_tokenizer(metadata)
+  max_output_length = 1024
+
+  profiling_output = FLAGS.profiling_output
+  profiling_prefill = (
+      FLAGS.profiling_prefill
+      and profiling_output is not None
+      and profiling_output != ""
+  )
+
+  if profiling_prefill:
+    jax.profiler.start_trace(profiling_output)
+
+  decode_state = pt_engine.init_decode_state()
+
+  if profiling_prefill:
+    jax.profiler.stop_trace()
+
+  prompts: List[str] = [
+      # pylint: disable-next=all
+      "I believe the meaning of life is",
+      # pylint: disable-next=all
+      "To add an element to an ArrayList of a specific class type in Java, you can follow the following steps:\n\n1. Create an instance of the class to be added.\n2. Get a reference to the ArrayList.\n3. Call the `add()` method on the ArrayList, passing the instance of the class as the argument.\n\nHere's an example of how to add an object of type `Person` to an ArrayList of type `ArrayList`:\n```csharp\n// Create a new instance of the Person class\nPerson person = new Person(\"John\", 25);\n\n// Get a reference to the ArrayList\nArrayList peopleList = new ArrayList<>();\n\n// Add the person object to the ArrayList\npeopleList.add(person);\n```\nIn this example, the `Person` class is assumed to have a constructor that takes two arguments: a String for the person's name, and an int for their age. You can substitute your own class and constructor as necessary.",
+      # pylint: disable-next=all
+      "[INST] <>\nYou are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.\n<>\n\nQuestion 1: What is commercial real estate finance?\nQuestion 2: What are Commercial Real Estate services?\nOptions are:\n[a]. no.\n[b]. yes.\nWould the answer to these two questions be the same? [/INST]",
+      # pylint: disable-next=all
+      "[INST] <>\nYou are an AI assistant that helps people find information. Provide a detailed answer so user don\u2019t need to search outside to understand the answer.\n<>\n\nUse reasoning to lead to the answer of the following question:\nWhere are you likely to find water underneath?\nOptions:\n- toilet\n- sink\n- jar\n- bridge\n- house\n Reasoning process: [/INST",
+      # pylint: disable-next=all
+      "[INST] <>\nYou are an AI assistant. You will be given a task. You must generate a detailed and long answer.\n<>\n\nContinue the following story.\n\nKay didn't have shoes that fit her feet properly. She only wore sneakers, because the \nChoose from: [I] shoes fitted badly. [II] sneakers fitted badly. [/INST]",
+  ]
+  for prompt in prompts:
+    slot = random.randint(0, FLAGS.batch_size - 1)
+    tokens, true_length = tokenizer.encode(prompt)
+
+    print(f"---- Input prompts are: {prompt}")
+    print(f"---- Encoded tokens are: {tokens}")
+
+    # pylint: disable-next=all
+    if profiling_prefill:
+      jax.profiler.start_trace(profiling_output)
+
+    prefill_result, _ = pt_engine.prefill(
+        params=params, padded_tokens=tokens, true_length=true_length
+    )
+    # pylint: disable-next=all
+    decode_state = pt_engine.insert(prefill_result, decode_state, slot=slot)
+
+    if profiling_prefill:
+      jax.profiler.stop_trace()
+
+    sampled_tokens_list = []
+    print(f"---- Streaming decode started on #slot{slot}.")
+    complete = np.zeros((1,), dtype=np.bool_)
+    while True:
+      if profiling_output:
+        jax.profiler.start_trace(profiling_output)
+
+      decode_state, result_tokens = pt_engine.generate(params, decode_state)
+      result_tokens = result_tokens.convert_to_numpy()
+
+      if profiling_output:
+        jax.profiler.stop_trace()
+
+      output, complete = token_utils.process_result_tokens(
+          tokenizer=tokenizer,
+          slot=slot,
+          slot_max_length=max_output_length,
+          result_tokens=result_tokens,
+          complete=complete,
+      )
+      if complete[0]:
+        break
+      token_ids = output[0].token_ids
+      sampled_tokens_list.extend(token_ids)
+
+    print("---- All output tokens.")
+    print(sampled_tokens_list)
+    print("---- All output text.")
+    print(tokenizer.decode(sampled_tokens_list))
 
 
 def main(argv):
@@ -115,15 +214,14 @@ def main(argv):
 
   if argv[1] == "list":
     list_model()
-    return
-
-  if argv[1] == "serve":
+  elif argv[1] == "serve":
     serve()
-    return
-
-  if argv[1] == "interative":
+  elif argv[1] == "interactive":
     interactive()
-    return
+  else:
+    print(
+        "Invalid arguments. Please specify 'list', 'serve', or 'interactive'."
+    )
 
 
 if __name__ == "__main__":
diff --git a/jetstream_pt/fetch_models.py b/jetstream_pt/fetch_models.py
index b7e60922..6786b512 100644
--- a/jetstream_pt/fetch_models.py
+++ b/jetstream_pt/fetch_models.py
@@ -13,6 +13,7 @@
 )
 from jetstream_pt.third_party.llama import model_exportable as llama_model
 from jetstream_pt.third_party.mixtral import model as mixtral_model
+from jetstream_pt.third_party.gemma import model as gemma_model
 
 
 FLAGS = flags.FLAGS
@@ -49,6 +50,9 @@ class ModelInfo:
 
 _mixtral_87 = ModelInfo(mixtral_model.Transformer, 32, 8, 128, 4)
 
+_gemma_2b = ModelInfo(gemma_model.GemmaModel, 18, 1, 256, 8)
+_gemma_7b = ModelInfo(gemma_model.GemmaModel, 28, 16, 256, 1)
+
 
 model_id_to_class = {
     "meta-llama/Llama-2-7b-chat-hf": _llama2_7,
@@ -57,10 +61,10 @@ class ModelInfo:
     "meta-llama/Llama-2-13b-hf": _llama2_13,
     "meta-llama/Meta-Llama-3-8B": _llama3_8,
     "meta-llama/Meta-Llama-3-8B-Instruct": _llama3_8,
-    "google/gemma-2b": None,
-    "google/gemma-2b-it": None,
-    "google/gemma-7b": None,
-    "google/gemma-7b-it": None,
+    "google/gemma-2b": _gemma_2b,
+    "google/gemma-2b-it": _gemma_2b,
+    "google/gemma-7b": _gemma_7b,
+    "google/gemma-7b-it": _gemma_7b,
     "mistralai/Mixtral-8x7B-v0.1": _mixtral_87,
     "mistralai/Mixtral-8x7B-Instruct-v0.1": _mixtral_87,
 }
@@ -162,11 +166,7 @@ def instantiate_model_from_repo_id(
   env.device = "meta"
   model = model_info.model_class.from_hf_model_id(repo_id, env)
   weights = _load_weights(model_dir)
-  updated_keys = model.get_hf_names_to_real_name()
-  for name, updated in updated_keys.items():
-    if name in weights:
-      val = weights.pop(name)
-      weights[updated] = val
+  weights = model.convert_hf_weights(weights)
 
   model.load_state_dict(weights, assign=True, strict=False)
 
diff --git a/jetstream_pt/hf_tokenizer.py b/jetstream_pt/hf_tokenizer.py
index a02148d4..358cefa7 100644
--- a/jetstream_pt/hf_tokenizer.py
+++ b/jetstream_pt/hf_tokenizer.py
@@ -1,4 +1,4 @@
-from jetstream.engine import tokenizer_api
+from jetstream.engine import tokenizer_api, token_utils
 
 
 class HFTokenizerAdapter(tokenizer_api.Tokenizer):
@@ -17,7 +17,10 @@ def encode(self, s: str, **kwargs):
       true_length: Actual length of the non-padded sequence
         if padding is used.
     """
-    return self(s)
+    res = self.tokenizer.encode(s, add_special_tokens=False)
+    return token_utils.pad_tokens(
+        res, self.bos_id, self.pad_id, jax_padding=True
+    )
 
   def decode(self, token_ids: list[int], **kwargs) -> str:
     """Processess input token ids to generate a string.
@@ -27,7 +30,7 @@ def decode(self, token_ids: list[int], **kwargs) -> str:
     Returns:
       str: String generated from the token ids.
     """
-    return self.decode(token_ids)
+    return self.tokenizer.decode(token_ids)
 
   @property
   def pad_id(self) -> int:
@@ -47,4 +50,4 @@ def bos_id(self) -> int:
   @property
   def stop_tokens(self) -> set[int]:
     """ID of the stop token."""
-    return {self.eos_id, self.pad_id}
+    return {self.eos_id}
diff --git a/jetstream_pt/layers.py b/jetstream_pt/layers.py
index d06b8d87..c41fe76a 100644
--- a/jetstream_pt/layers.py
+++ b/jetstream_pt/layers.py
@@ -328,9 +328,9 @@ def create_quantized_from_nn_embedding(
       float_embedding.num_embeddings,
       float_embedding.embedding_dim,
   )
-  weights, scaler, _ = quantize_tensor(float_embedding.weight, 1)
+  weights, scaler, _ = quantize_tensor(float_embedding.weight, 0)
  obj.weight = weights
-  obj.scaler = scaler
+  obj.weight_scaler = scaler
   return obj
 
 
diff --git a/jetstream_pt/model_base.py b/jetstream_pt/model_base.py
index e7609891..660a6fec 100644
--- a/jetstream_pt/model_base.py
+++ b/jetstream_pt/model_base.py
@@ -46,7 +46,18 @@ class AttrProperty:
 
 
 class ModuleBase(torch.nn.Module, metaclass=abc.ABCMeta):
-  """nn Module that allows attaching properties"""
+  """nn Module that allows attaching properties.
+
+  This class currently serves 2 goals:
+  1. Allow model to specify alternative names for submodules / weights;
+     this is needed so that it can *also* load HuggingFace checkpoints
+     without needing to do massive rewrites.
+
+  2. Allow model to attach information to weights, such as sharding config.
+
+  Quantization config could be another thing to attach, but right now it's
+  not used this way.
+  """
 
   attr_to_property: Dict[str, Any]
 
@@ -74,6 +85,19 @@ def annotate_sharding(self, name, axis):
     """Set sharding name for a attribute or submodule."""
     self.attr_to_property[name].sharding_axis = axis
 
-  def drop_weight(self, key):
-    """list out names to discard."""
-    return False
+  def convert_hf_weights(
+      self, hf_weights: Dict[str, torch.Tensor]
+  ) -> Dict[str, torch.Tensor]:
+    """Convert HuggingFace weights into a state_dict for this model."""
+    weights = {}
+    updated_keys = self.get_hf_names_to_real_name()
+    for name, updated in updated_keys.items():
+      if name in hf_weights:
+        weights[updated] = hf_weights[name]
+
+    for name in list(weights.keys()):
+      if "inv_freq" in name:
+        weights.pop(name)
+    if hasattr(self, "freqs_cis"):
+      weights["freqs_cis"] = self.freqs_cis
+    return weights
diff --git a/jetstream_pt/third_party/gemma/model.py b/jetstream_pt/third_party/gemma/model.py
index 5773b8bd..9d597551 100644
--- a/jetstream_pt/third_party/gemma/model.py
+++ b/jetstream_pt/third_party/gemma/model.py
@@ -21,6 +21,7 @@
 from . import config as gemma_config
 from jetstream_pt import layers
+from jetstream_pt.model_base import ModuleBase
 
 import jax
 
 
@@ -63,7 +64,7 @@ def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
   return x_out
 
 
-class GemmaAttention(nn.Module):
+class GemmaAttention(ModuleBase):
 
   def __init__(
       self,
@@ -131,6 +132,19 @@ def __init__(
         **linear_kwargs,
     )
 
+    self.hf_name("wk", "k_proj")
+    self.hf_name("wv", "v_proj")
+    self.hf_name("wq", "q_proj")
+    self.annotate_sharding("wk.weight", 0)
+    self.annotate_sharding("wv.weight", 0)
+    self.annotate_sharding("wq.weight", 0)
+    self.annotate_sharding("o_proj.weight", 1)
+    if Linear != torch.nn.Linear:
+      self.annotate_sharding("wk.weight_scaler", 0)
+      self.annotate_sharding("wv.weight_scaler", 0)
+      self.annotate_sharding("wq.weight_scaler", 0)
+      self.annotate_sharding("o_proj.weight_scaler", -1)
+
     Kernel = (
         layers.Int8KVAttentionKernel
         if env.quant_config.enable_kv_quantization
@@ -195,7 +209,7 @@ def forward(
     return output
 
 
-class RMSNorm(torch.nn.Module):
+class RMSNorm(ModuleBase):
 
   def __init__(
       self,
@@ -221,7 +235,7 @@ def forward(self, x):
     return output
 
 
-class GemmaMLP(nn.Module):
+class GemmaMLP(ModuleBase):
 
   def __init__(
       self,
@@ -262,6 +276,17 @@ def __init__(
         **linear_kwargs,
     )
 
+    self.annotate_sharding("gate_proj.weight", 0)
+    self.annotate_sharding("up_proj.weight", 0)
+    self.annotate_sharding("down_proj.weight", 1)
+    self.annotate_sharding("gate_proj.bias", 0)
+    self.annotate_sharding("up_proj.bias", 0)
+    self.annotate_sharding("down_proj.bias", -1)
+    if Linear != torch.nn.Linear:
+      self.annotate_sharding("gate_proj.weight_scaler", 0)
+      self.annotate_sharding("up_proj.weight_scaler", 0)
+      self.annotate_sharding("down_proj.weight_scaler", -1)
+
   def forward(self, x):
     gate = self.gate_proj(x)
     gate = F.gelu(gate, approximate="tanh")
@@ -271,7 +296,7 @@ def forward(self, x):
     return outputs
 
 
-class GemmaDecoderLayer(nn.Module):
+class GemmaDecoderLayer(ModuleBase):
 
   def __init__(self, config: gemma_config.GemmaConfig, env, layer_id):
     super().__init__()
@@ -333,7 +358,7 @@ def forward(
     return hidden_states
 
 
-class GemmaModel(nn.Module):
+class GemmaModel(ModuleBase):
 
   def __init__(self, config: gemma_config.GemmaConfig, env):
     super().__init__()
@@ -356,6 +381,10 @@ def __init__(self, config: gemma_config.GemmaConfig, env):
     self.embedder = Embedding(
         config.vocab_size, config.hidden_size, device=config.device
     )
+    self.hf_name("embedder", "model.embed_tokens")
+    self.hf_name("layers", "model.layers")
+    self.hf_name("norm", "model.norm")
+
     rope_theta = getattr(config, "rope_theta", 10000)
     freqs_cis = precompute_freqs_cis(
         config.head_dim, config.max_position_embeddings * 2, theta=rope_theta
@@ -430,3 +459,17 @@ def get_quantized_embedding_weight_to_scaler_map():
     return {
         "embedder.weight": "embedder.weight_scaler",
     }
+
+  @classmethod
+  def from_hf_model_id(cls, model_id, env):
+    name = {
+        "google/gemma-2b": "2b",
+        "google/gemma-2b-it": "2b",
+        "google/gemma-7b": "7b",
+        "google/gemma-7b-it": "7b",
+    }.get(model_id)
+    assert name
+    args = gemma_config.get_model_config(name)
+    args.device = "meta"
+    model = cls(args, env)
+    return model
diff --git a/jetstream_pt/third_party/llama/model_exportable.py b/jetstream_pt/third_party/llama/model_exportable.py
index cd8a3e82..791ff7a5 100644
--- a/jetstream_pt/third_party/llama/model_exportable.py
+++ b/jetstream_pt/third_party/llama/model_exportable.py
@@ -2,7 +2,7 @@
 """This version contains modification to make it easier to trace and support batch."""
 
 from typing import Any, List, Optional
-
+import copy
 import jax
 import torch
 import torch.nn.functional as F
@@ -99,6 +99,7 @@ def __init__(
     self.n_heads = args.n_heads
     self.dim = args.dim
     self.head_dim = args.dim // args.n_heads
+    self.args = args
 
     self.attention = Attention(
         args.n_heads,
@@ -360,8 +361,23 @@ def from_hf_model_id(cls, model_id, env):
   def drop_weight(self, key):
     return key.startswith("model")
 
-  def shard_weights(self, weights_dict):
-    """Shards the weights
+  def convert_hf_weights(self, hf_weights):
 
-    Assumes the weights_dict is a list of XLATensor2
-    """
+    def transform(val, n_heads):
+      dim1, dim2 = val.shape
+      return (
+          val.reshape(n_heads, 2, dim1 // n_heads // 2, dim2)
+          .transpose(1, 2)
+          .reshape(dim1, dim2)
+      )
+
+    updated = copy.copy(hf_weights)
+
+    for key, value in hf_weights.items():
+      if "q_proj" in key:
+        updated[key] = transform(value, self.params.n_heads)
+      if "k_proj" in key:
+        updated[key] = transform(
+            value, self.params.n_kv_heads or self.params.n_heads
+        )
+    return super().convert_hf_weights(updated)
diff --git a/jetstream_pt/third_party/mixtral/model.py b/jetstream_pt/third_party/mixtral/model.py
index 7cef7dff..276d7f80 100644
--- a/jetstream_pt/third_party/mixtral/model.py
+++ b/jetstream_pt/third_party/mixtral/model.py
@@ -12,7 +12,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import collections
+import copy
 from dataclasses import dataclass
 from typing import Optional, List, Any
 
@@ -163,6 +164,43 @@ def from_hf_model_id(cls, model_id, env):
     model = cls(args, env)
     return model
 
+  def convert_hf_weights(self, hf_weights):
+
+    def transform(val, n_heads):
+      dim1, dim2 = val.shape
+      return (
+          val.reshape(n_heads, 2, dim1 // n_heads // 2, dim2)
+          .transpose(1, 2)
+          .reshape(dim1, dim2)
+      )
+
+    groupped_by_experts = collections.defaultdict(lambda: [None] * 8)
+    updated = copy.copy(hf_weights)
+    for key, value in hf_weights.items():
+      if "block_sparse_moe.experts" in key:
+        #   0     1    2        3            4     5  6    7
+        # "model.layers.0.block_sparse_moe.experts.0.w1.weight"
+        updated.pop(key)
+        name_pieces = key.split(".")
+        assert len(name_pieces) == 8
+        layer_id = int(name_pieces[2])
+        expert_id = int(name_pieces[5])
+        weight_name = name_pieces[6]
+        groupped_by_experts[(layer_id, weight_name)][expert_id] = value
+
+      if "q_proj" in key:
+        updated[key] = transform(value, self.config.n_head)
+      if "k_proj" in key:
+        updated[key] = transform(
+            value, self.config.n_local_heads or self.config.n_head
+        )
+
+    for (layer_id, weight_name), ws in groupped_by_experts.items():
+      name = f"model.layers.{layer_id}.block_sparse_moe.cond_ffn.{weight_name}"
+      updated[name] = torch.stack(ws)
+    res = super().convert_hf_weights(updated)
+    return res
+
 
 class TransformerBlock(ModuleBase):
 
@@ -177,6 +215,7 @@ def __init__(self, config: ModelArgs, env, layer_id) -> None:
         device=config.device,
         layer_id=layer_id,
     )
+    self.config = config
     self.hf_name("attention", "self_attn")
     self.attention.hf_name("wq", "q_proj")
     self.attention.hf_name("wk", "k_proj")
@@ -194,19 +233,6 @@ def __init__(self, config: ModelArgs, env, layer_id) -> None:
     self.hf_name("attention_norm", "input_layernorm")
     self.hf_name("ffn_norm", "post_attention_layernorm")
 
-    self._register_load_state_dict_pre_hook(self.load_hook)
-
-  def load_hook(self, state_dict, prefix, *args):
-    if prefix + "block_sparse_moe.experts" in state_dict:
-      w1s, w2s, w3s = [], [], []
-      for i in range(8):
-        exp_prefix = f"{prefix}block_sparse_moe.experts.{i}."
-        w1s.append(state_dict.pop(exp_prefix + ".w1"))
-        w2s.append(state_dict.pop(exp_prefix + ".w2"))
-        w3s.append(state_dict.pop(exp_prefix + ".w3"))
-      state_dict[prefix + "block_sparse_moe.cond_ffn.w1"] = torch.cat(w1s)
-      state_dict[prefix + "block_sparse_moe.cond_ffn.w2"] = torch.cat(w2s)
-      state_dict[prefix + "block_sparse_moe.cond_ffn.w3"] = torch.cat(w3s)
 
   def forward(
       self,
@@ -383,14 +409,14 @@ def get_quantized_version(self):
     """Return quantized version of this class."""
     quant_version = Int8ConditionalFeedForward(self.config)
     w1, w1_scaler, _ = quantize.quantize_tensor(self.w1, 2)
-    w2, w2_scaler, _ = quantize.quantize_tensor(self.w2, 1)
+    w2, w2_scaler, _ = quantize.quantize_tensor(self.w2, 2)
     w3, w3_scaler, _ = quantize.quantize_tensor(self.w3, 2)
     quant_version.w1 = w1
     quant_version.w2 = w2
     quant_version.w3 = w3
-    quant_version.w1_scaler = w1_scaler
-    quant_version.w2_scaler = w2_scaler
-    quant_version.w3_scaler = w3_scaler
+    quant_version.w1_scaler = w1_scaler.squeeze(2)
+    quant_version.w2_scaler = w2_scaler.squeeze(2)
+    quant_version.w3_scaler = w3_scaler.squeeze(2)
     return quant_version
 
 
diff --git a/tests/test_quantization.py b/tests/test_quantization.py
index d150c67b..087c340a 100644
--- a/tests/test_quantization.py
+++ b/tests/test_quantization.py
@@ -516,6 +516,18 @@ def forward(self, x):
     res = helpers.call_xla_model(qm, qm.state_dict(), arg)
     self.assertGreater(self._calc_cosine_dist(res, torch_res), 0.9999)
 
+  def test_embedding(self):
+    m = torch.nn.Embedding(1000, 100)
+    arg = torch.randint(0, 1000, [2]).to(torch.int32)
+    torch_res = m(arg)
+    quant_config = QuantizationConfig(
+        enable_weight_quantization=True,
+        enable_activation_quantization=False,
+    )
+    qm = quantize_model(m, quant_config)
+    res = helpers.call_xla_model(qm, qm.state_dict(), arg)
+    self.assertGreater(self._calc_cosine_dist(res, torch_res), 0.997)
+
 
 if __name__ == "__main__":
   unittest.main()