From 9b8990453fc2b552a4b4c2dc8c80b3ad5ce78de1 Mon Sep 17 00:00:00 2001 From: Han Qi Date: Thu, 9 May 2024 12:33:09 +0000 Subject: [PATCH 1/4] Checkpoint on gemma --- convert_checkpoints.py | 24 ++--- default_shardings/gemma.yaml | 4 +- jetstream_pt/engine.py | 4 + jetstream_pt/third_party/gemma/model.py | 136 +++++++++++++++++++++++- run_interactive.py | 16 +-- run_server.py | 1 - tests/test_model_impl.py | 5 +- 7 files changed, 158 insertions(+), 32 deletions(-) diff --git a/convert_checkpoints.py b/convert_checkpoints.py index bda847b1..6a1a84f0 100644 --- a/convert_checkpoints.py +++ b/convert_checkpoints.py @@ -431,26 +431,14 @@ def convert_hf_gemma_weights( new_key = key if key.startswith(prefix_to_remove): new_key = new_key.removeprefix(prefix_to_remove) - if "qkv_proj" in key: - q_dim = model_config["num_attention_heads"] * model_config["head_dim"] - kv_dim = model_config["num_key_value_heads"] * model_config["head_dim"] - qkv = state_dict.pop(key) - q, k, v = qkv.split( - [ - q_dim, - kv_dim, - kv_dim, - ], - dim=0, - ) - state_dict[new_key.replace("qkv_proj", "wq")] = q - state_dict[new_key.replace("qkv_proj", "wk")] = k - state_dict[new_key.replace("qkv_proj", "wv")] = v - continue - if "o_proj" in key: - new_key = new_key.replace("o_proj", "wo") if new_key != key: state_dict[new_key] = state_dict.pop(key) + + ckpt_basename = os.path.basename(ckpt_file) + output_ckpt_dir.mkdir(parents=True, exist_ok=True) + + del state_dict['freqs_cis'] + _export_to_local(output_ckpt_dir, model_config, state_dict) diff --git a/default_shardings/gemma.yaml b/default_shardings/gemma.yaml index 4beda7c4..9eb1d1a9 100644 --- a/default_shardings/gemma.yaml +++ b/default_shardings/gemma.yaml @@ -3,7 +3,9 @@ # "replicated" to signify "replicated". # Integer signify axis to shard: 0 <= shard axis < rank -freqs_cis : -1 # torch.complex64 (16384, 128) +freqs_cis : null # torch.complex64 (16384, 128) +layers.*.self_attn.qkv_proj.weight: 0 +layers.*.self_attn.o_proj.weight: 1 layers.*.self_attn.wo.weight : 1 # 1, -1] # torch.float32 (2048, 2048) layers.*.self_attn.wq.weight : 0 # -1, 1] # torch.float32 (2048, 2048) layers.*.self_attn.wk.weight : 0 # -1, 1] # torch.float32 (256, 2048) diff --git a/jetstream_pt/engine.py b/jetstream_pt/engine.py index f95d8cbd..378d338b 100644 --- a/jetstream_pt/engine.py +++ b/jetstream_pt/engine.py @@ -17,6 +17,7 @@ from typing import Any, List, Optional, Tuple, Union import threading import functools +import os from etils import epath from flax import struct @@ -703,6 +704,9 @@ def create_pytorch_engine( tokenizer = token_utils.load_vocab(tokenizer_path) pt_model = None + if not sharding_config: + sharding_config = os.path.join('default_shardings', model_name + '.yaml') + env_data = JetEngineEnvironmentData( tokenizer_path=tokenizer_path, checkpoint_path=checkpoint_path, diff --git a/jetstream_pt/third_party/gemma/model.py b/jetstream_pt/third_party/gemma/model.py index 0cfc9b15..1e0b40b3 100644 --- a/jetstream_pt/third_party/gemma/model.py +++ b/jetstream_pt/third_party/gemma/model.py @@ -35,6 +35,137 @@ def precompute_freqs_cis( return freqs_cis +def reshape_for_broadcast( + freqs_cis: torch.Tensor, x: torch.Tensor +) -> torch.Tensor: + ndim = x.ndim + assert 1 < ndim + assert freqs_cis.shape == ( + x.shape[0], + x.shape[2], + x.shape[3], + ), f"freqs_cis: {freqs_cis.shape }, x: {x.shape}" + shape = [d if i != 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: + 
"""Applies the rotary embedding to the query and key tensors.""" + x_ = torch.view_as_complex( + torch.stack(torch.chunk(x.transpose(1, 2).float(), 2, dim=-1), + dim=-1)) + freqs_cis = reshape_for_broadcast(freqs_cis, x_) + x_out = torch.view_as_real(x_ * freqs_cis).type_as(x) + x_out = torch.cat(torch.chunk(x_out, 2, dim=-1), dim=-2) + x_out = x_out.reshape(x_out.shape[0], x_out.shape[1], x_out.shape[2], + -1).transpose(1, 2) + return x_out + + + + +class GemmaAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, + device, + env, + ): + super().__init__() + + self.env = env + + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + self.hidden_size = hidden_size + self.head_dim = head_dim + + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + + self.scaling = self.head_dim**-0.5 + + Linear = ( + layers.WeightOnlyInt8Linear + if env.enable_weight_quantization + else torch.nn.Linear + ) + self.qkv_proj = Linear( + self.hidden_size, + (self.num_heads + 2 * self.num_kv_heads) * self.head_dim, + bias=False, + device=device, + ) + self.o_proj = Linear( + self.num_heads * self.head_dim, self.hidden_size, bias=False, device=device + ) + + def forward( + self, + hidden_states, + freqs_cis, + mask, + cache, + ) -> torch.Tensor: + hidden_states_shape = hidden_states.shape + assert len(hidden_states_shape) == 3 + batch_size, input_len, _ = hidden_states_shape + qkv = self.qkv_proj(hidden_states) + xq, xk, xv = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + xq = xq.view(batch_size, -1, self.num_heads, self.head_dim) + xk = xk.view(batch_size, -1, self.num_kv_heads, self.head_dim) + xv = xv.view(batch_size, -1, self.num_kv_heads, self.head_dim) + + # Positional embedding. + xq = apply_rotary_emb(xq, freqs_cis=freqs_cis) + xk = apply_rotary_emb(xk, freqs_cis=freqs_cis) + + # Write new kv cache. 
+ # [batch_size, input_len, n_local_kv_heads, head_dim] + + xk = xk.transpose(1, 2) + xv = xv.transpose(1, 2) + + # [batch_size, n_local_kv_heads, seq len, head_dim] + key, value = cache.update(xk, xv) + + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + if self.num_kv_heads != self.num_heads: + # [batch_size, max_seq_len, n_local_heads, head_dim] + key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=2) + value = torch.repeat_interleave(value, self.num_queries_per_kv, dim=2) + + # [batch_size, n_local_heads, input_len, head_dim] + q = xq.transpose(1, 2) + # [batch_size, n_local_heads, max_seq_len, head_dim] + k = key.transpose(1, 2) + v = value.transpose(1, 2) + + # [batch_size, n_local_heads, input_len, max_seq_len] + scores = torch.matmul(q, k.transpose(2, 3)) * self.scaling + scores = scores + mask + scores = F.softmax(scores.float(), dim=-1).type_as(q) + + # [batch_size, n_local_heads, input_len, head_dim] + output = torch.matmul(scores, v) + + # [batch_size, input_len, hidden_dim] + output = output.transpose(1, 2).contiguous().view(batch_size, input_len, -1) + output = self.o_proj(output) + return output + + class RMSNorm(torch.nn.Module): def __init__( @@ -99,14 +230,15 @@ class GemmaDecoderLayer(nn.Module): def __init__(self, config: gemma_config.GemmaConfig, env): super().__init__() - self.self_attn = layers.Attention( + self.self_attn = GemmaAttention( + config.hidden_size, config.num_attention_heads, config.num_key_value_heads, config.head_dim, - config.hidden_size, config.device, env, ) + self.mlp = GemmaMLP( hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, diff --git a/run_interactive.py b/run_interactive.py index d0d3b21f..5a332710 100644 --- a/run_interactive.py +++ b/run_interactive.py @@ -121,15 +121,15 @@ def main(argv): decode_state = engine.init_decode_state() prompts: List[str] = [ - "I believe the meaning of life is", + "The meaning of life is", # pylint: disable-next=all - "To add an element to an ArrayList of a specific class type in Java, you can follow the following steps:\n\n1. Create an instance of the class to be added.\n2. Get a reference to the ArrayList.\n3. Call the `add()` method on the ArrayList, passing the instance of the class as the argument.\n\nHere's an example of how to add an object of type `Person` to an ArrayList of type `ArrayList`:\n```csharp\n// Create a new instance of the Person class\nPerson person = new Person(\"John\", 25);\n\n// Get a reference to the ArrayList\nArrayList peopleList = new ArrayList<>();\n\n// Add the person object to the ArrayList\npeopleList.add(person);\n```\nIn this example, the `Person` class is assumed to have a constructor that takes two arguments: a String for the person's name, and an int for their age. You can substitute your own class and constructor as necessary.", - # pylint: disable-next=all - "[INST] <>\nYou are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.\n<>\n\nQuestion 1: What is commercial real estate finance?\nQuestion 2: What are Commercial Real Estate services?\nOptions are:\n[a]. no.\n[b]. yes.\nWould the answer to these two questions be the same? [/INST]", - # pylint: disable-next=all - "[INST] <>\nYou are an AI assistant that helps people find information. 
Provide a detailed answer so user don\u2019t need to search outside to understand the answer.\n<>\n\nUse reasoning to lead to the answer of the following question:\nWhere are you likely to find water underneath?\nOptions:\n- toilet\n- sink\n- jar\n- bridge\n- house\n Reasoning process: [/INST", - # pylint: disable-next=all - "[INST] <>\nYou are an AI assistant. You will be given a task. You must generate a detailed and long answer.\n<>\n\nContinue the following story.\n\nKay didn't have shoes that fit her feet properly. She only wore sneakers, because the \nChoose from: [I] shoes fitted badly. [II] sneakers fitted badly. [/INST]", + # "To add an element to an ArrayList of a specific class type in Java, you can follow the following steps:\n\n1. Create an instance of the class to be added.\n2. Get a reference to the ArrayList.\n3. Call the `add()` method on the ArrayList, passing the instance of the class as the argument.\n\nHere's an example of how to add an object of type `Person` to an ArrayList of type `ArrayList`:\n```csharp\n// Create a new instance of the Person class\nPerson person = new Person(\"John\", 25);\n\n// Get a reference to the ArrayList\nArrayList peopleList = new ArrayList<>();\n\n// Add the person object to the ArrayList\npeopleList.add(person);\n```\nIn this example, the `Person` class is assumed to have a constructor that takes two arguments: a String for the person's name, and an int for their age. You can substitute your own class and constructor as necessary.", + # # pylint: disable-next=all + # "[INST] <>\nYou are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.\n<>\n\nQuestion 1: What is commercial real estate finance?\nQuestion 2: What are Commercial Real Estate services?\nOptions are:\n[a]. no.\n[b]. yes.\nWould the answer to these two questions be the same? [/INST]", + # # pylint: disable-next=all + # "[INST] <>\nYou are an AI assistant that helps people find information. Provide a detailed answer so user don\u2019t need to search outside to understand the answer.\n<>\n\nUse reasoning to lead to the answer of the following question:\nWhere are you likely to find water underneath?\nOptions:\n- toilet\n- sink\n- jar\n- bridge\n- house\n Reasoning process: [/INST", + # # pylint: disable-next=all + # "[INST] <>\nYou are an AI assistant. You will be given a task. You must generate a detailed and long answer.\n<>\n\nContinue the following story.\n\nKay didn't have shoes that fit her feet properly. She only wore sneakers, because the \nChoose from: [I] shoes fitted badly. [II] sneakers fitted badly. 
[/INST]", ] for prompt in prompts: slot = random.randint(0, _BATCH_SIZE.value - 1) diff --git a/run_server.py b/run_server.py index 161af9bd..8fececac 100644 --- a/run_server.py +++ b/run_server.py @@ -111,7 +111,6 @@ def main(argv: Sequence[str]): param_size=_PARAM_SIZE.value, context_length=_CONTEXT_LENGTH.value, batch_size=_BATCH_SIZE.value, - model_name=_MODEL_NAME.value, quantize_weights=_QUANTIZE_WEIGHTS.value, quantize_kv=_QUANTIZE_KV_CACHE.value, max_cache_length=_MAX_CACHE_LENGTH.value, diff --git a/tests/test_model_impl.py b/tests/test_model_impl.py index 1a280563..971dd638 100644 --- a/tests/test_model_impl.py +++ b/tests/test_model_impl.py @@ -190,8 +190,8 @@ def init_weights(model): state_dict = model.state_dict() res = {} for k, v in state_dict.items(): - x = random.randint(1, 10) - res[k] = torch.ones(v.shape) * x + #x = random.randint(1, 10) + res[k] = torch.randn(v.shape) #* x model.load_state_dict(res, assign=True) attention_orig = gemma_orig.GemmaAttention( @@ -260,6 +260,7 @@ def load_hook(state_dict, prefix, *args): "Single Gemma Attention: Diff norm", (result_torch - expected_out).norm(), ) + self.assertTrue(torch.allclose(result_torch, expected_out, atol=1e-4)) # pylint: disable-next=all From 3fcd49d25467f51984035d8223a11b89c2125c7c Mon Sep 17 00:00:00 2001 From: Han Qi Date: Thu, 9 May 2024 13:46:46 +0000 Subject: [PATCH 2/4] Change Gemma to use Gemma Attention from model_original This way it produces more accurate results --- convert_checkpoints.py | 25 ++- install_everything.sh | 14 +- jetstream_pt/engine.py | 2 +- jetstream_pt/layers.py | 198 +++++++++++++++--------- jetstream_pt/third_party/gemma/model.py | 91 ++++++----- run_interactive.py | 16 +- run_server.py | 1 + scripts/create_empty_sharding_map.py | 93 +++++++++++ tests/test_model_impl.py | 166 ++++++++++---------- 9 files changed, 384 insertions(+), 222 deletions(-) create mode 100644 scripts/create_empty_sharding_map.py diff --git a/convert_checkpoints.py b/convert_checkpoints.py index 6a1a84f0..aa5fc916 100644 --- a/convert_checkpoints.py +++ b/convert_checkpoints.py @@ -414,11 +414,12 @@ def convert_hf_gemma_weights( ckpt_file = list(input_ckpt_dir.glob("*.ckpt")) assert len(ckpt_file) == 1, "only expect 1 ckpt file for Gemma model." 
ckpt_file = ckpt_file[0] - state_dict = torch.load(ckpt_file, map_location=torch.device("cpu"))[ + state_dict = torch.load(str(ckpt_file), map_location=torch.device("cpu"))[ "model_state_dict" ] model_config = json.loads((input_ckpt_dir / "config.json").read_text()) for key in list(state_dict.keys()): + print(key) if state_dict[key].dtype.is_complex and _OUTPUT_SAFETENSORS.value: assert ( key == "freqs_cis" @@ -431,14 +432,26 @@ def convert_hf_gemma_weights( new_key = key if key.startswith(prefix_to_remove): new_key = new_key.removeprefix(prefix_to_remove) + if "qkv_proj" in key: + q_dim = model_config["num_attention_heads"] * model_config["head_dim"] + kv_dim = model_config["num_key_value_heads"] * model_config["head_dim"] + qkv = state_dict.pop(key) + q, k, v = qkv.split( + [ + q_dim, + kv_dim, + kv_dim, + ], + dim=0, + ) + state_dict[new_key.replace("qkv_proj", "wq")] = q + state_dict[new_key.replace("qkv_proj", "wk")] = k + state_dict[new_key.replace("qkv_proj", "wv")] = v + continue + if new_key != key: state_dict[new_key] = state_dict.pop(key) - - ckpt_basename = os.path.basename(ckpt_file) output_ckpt_dir.mkdir(parents=True, exist_ok=True) - - del state_dict['freqs_cis'] - _export_to_local(output_ckpt_dir, model_config, state_dict) diff --git a/install_everything.sh b/install_everything.sh index f9838a45..46482302 100644 --- a/install_everything.sh +++ b/install_everything.sh @@ -16,15 +16,15 @@ TORCHXLA_TAG=jetstream-pytorch JETSTREAM_TAG=v0.2.1 # Uninstall existing jax -pip3 show jax && pip3 uninstall -y jax -pip3 show jaxlib && pip3 uninstall -y jaxlib -pip3 show libtpu-nightly && pip3 uninstall -y libtpu-nightly +pip show jax && pip uninstall -y jax +pip show jaxlib && pip uninstall -y jaxlib +pip show libtpu-nightly && pip uninstall -y libtpu-nightly -pip3 install pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html +pip install pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html # torch cpu -pip3 install torch --index-url https://download.pytorch.org/whl/cpu -pip3 install tensorflow flatbuffers absl-py flax sentencepiece seqio google-cloud-storage -pip3 install safetensors colorama coverage ray[default] humanize +pip install torch --index-url https://download.pytorch.org/whl/cpu +pip install tensorflow flatbuffers absl-py flax sentencepiece seqio google-cloud-storage +pip install safetensors colorama coverage ray[default] humanize mkdir -p deps pushd deps diff --git a/jetstream_pt/engine.py b/jetstream_pt/engine.py index 378d338b..7b234782 100644 --- a/jetstream_pt/engine.py +++ b/jetstream_pt/engine.py @@ -705,7 +705,7 @@ def create_pytorch_engine( pt_model = None if not sharding_config: - sharding_config = os.path.join('default_shardings', model_name + '.yaml') + sharding_config = os.path.join("default_shardings", model_name + ".yaml") env_data = JetEngineEnvironmentData( tokenizer_path=tokenizer_path, diff --git a/jetstream_pt/layers.py b/jetstream_pt/layers.py index a58fcf6b..d85c98e2 100644 --- a/jetstream_pt/layers.py +++ b/jetstream_pt/layers.py @@ -129,6 +129,117 @@ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: ) +class AttentionKernel: + + def __init__(self, env): + self.env = env + + def __call__(self, xq, xk, xv, mask, cache): + """ + Args: + xq: torch.Tensor of (batch size, num_heads, seqlen, head_dim) + xk: torch.Tensor of (batch size, num_heads, seqlen, head_dim) + xv: torch.Tensor of (batch size, num_heads, seqlen, head_dim) + mask: mask with 0 and -inf, or None + cache: 
CacheManagerInterface object + """ + bsz, num_heads, seqlen, head_dim = xq.shape + _, _, _, kv_head_dim = xk.shape + n_rep = head_dim // kv_head_dim + if seqlen == 1: + xq = torch.broadcast_to(xq, (xq.shape[0], xq.shape[1], 2, xq.shape[3])) + + with jax.named_scope("attn_insert_cache"): + keys, values = cache.update(xk, xv) + self.env.apply_sharding(keys, axis=1) + self.env.apply_sharding(values, axis=1) + keys = repeat_kv(keys, n_rep) + values = repeat_kv(values, n_rep) + with jax.named_scope("attn_mat1"): + ## Attention start + # scores = torch.einsum(jnp.einsum, "ijkl,ikml->ikjm", xq, keys) / math.sqrt(self.head_dim) + scores = torch_xla2.extra.call_jax( + jnp.einsum, "ikjl,ikml->ikjm", xq, keys + ) / math.sqrt(head_dim) + self.env.apply_sharding(scores, axis=1) + if mask is not None: + # if mask.shape != (1,1,16,16): + # breakpoint() + scores = scores + mask # (bs, n_local_heads, seqlen, max_seqlen) + with jax.named_scope("attn_soft"): + scores = F.softmax(scores.float(), dim=-1).type_as(xq) + + with jax.named_scope("attn_mat2"): + # output = torch.einsum( + # "ikjm,ikml->ikjl", scores, values + # ) # (bs, n_local_heads, seqlen, head_dim) + output = torch_xla2.extra.call_jax( + jnp.einsum, "ikjm,ikml->ikjl", scores, values + ) + if seqlen == 1: + output = output[:, :, 0:1, :] + # For XLA matmul performance boost + # output = torch.matmul(scores, values) + self.env.apply_sharding(output, axis=1) + return output + + +class Int8KVAttentionKernel: + + def __init__(self, env): + self.env = env + + def __call__(self, xq, xk, xv, mask, cache): + """ + Args: + xq: torch.Tensor of (batch size, num_heads, seqlen, head_dim) + xk: torch.Tensor of (batch size, num_heads, seqlen, head_dim) + xv: torch.Tensor of (batch size, num_heads, seqlen, head_dim) + mask: mask with 0 and -inf, or None + cache: CacheManagerInterface object + """ + bsz, num_heads, seqlen, head_dim = xq.shape + _, _, _, kv_head_dim = xk.shape + n_rep = head_dim // kv_head_dim + if seqlen == 1: + xq = torch.broadcast_to(xq, (xq.shape[0], xq.shape[1], 2, xq.shape[3])) + + with jax.named_scope("attn_insert_cache"): + keys, values, k_scaler, v_scaler = cache.update(xk, xv) + self.env.apply_sharding(keys, axis=1) + self.env.apply_sharding(values, axis=1) + keys = repeat_kv(keys, n_rep) + values = repeat_kv(values, n_rep) + with jax.named_scope("attn_mat1"): + ## Attention start + # scores = torch.einsum(jnp.einsum, "ijkl,ikml->ikjm", xq, keys) / math.sqrt(self.head_dim) + scores = ( + torch_xla2.extra.call_jax(jnp.einsum, "ikjl,ikml->ikjm", xq, keys) + / math.sqrt(head_dim) + * (k_scaler.reshape(bsz, 1, 1, keys.shape[2])) + ) + self.env.apply_sharding(scores, axis=1) + if mask is not None: + scores = scores + mask # (bs, n_local_heads, seqlen, max_seqlen) + with jax.named_scope("attn_soft"): + scores = F.softmax(scores.float(), dim=-1).type_as(xq) + scores = scores * v_scaler.reshape((bsz, 1, 1, keys.shape[2])) + self.env.apply_sharding(scores, axis=1) + + with jax.named_scope("attn_mat2"): + # output = torch.einsum( + # "ikjm,ikml->ikjl", scores, values + # ) # (bs, n_local_heads, seqlen, head_dim) + output = torch_xla2.extra.call_jax( + jnp.einsum, "ikjm,ikml->ikjl", scores, values + ) + if seqlen == 1: + output = output[:, :, 0:1, :] + # output = torch.matmul(scores, values) + self.env.apply_sharding(output, axis=1) + return output + + class Attention(nn.Module): """Attention module.""" @@ -151,6 +262,12 @@ def __init__(self, n_heads, n_kv_heads, head_dim, hidden_size, device, env): bias=False, device=device, ) + + Kernel = ( + 
Int8KVAttentionKernel if env.enable_kv_quantization else AttentionKernel + ) + self.attention_kernel = Kernel(env) + self.q_size = n_heads * self.head_dim self.kv_size = self.n_kv_heads * self.head_dim if self.env.qkv_fusion: @@ -219,81 +336,8 @@ def forward( xk = xk.transpose(1, 2) xv = xv.transpose(1, 2) + xq = xq.transpose(1, 2) - if seqlen == 1: - xq = torch.broadcast_to(xq, (xq.shape[0], 2, xq.shape[2], xq.shape[3])) - - if not self.env.enable_kv_quantization: - with jax.named_scope("attn_insert_cache"): - keys, values = cache.update(xk, xv) - self.env.apply_sharding(keys, axis=1) - self.env.apply_sharding(values, axis=1) - keys = repeat_kv(keys, self.n_rep) - values = repeat_kv(values, self.n_rep) - with jax.named_scope("attn_mat1"): - ## Attention start - # scores = torch.einsum(jnp.einsum, "ijkl,ikml->ikjm", xq, keys) / math.sqrt(self.head_dim) - scores = torch_xla2.extra.call_jax( - jnp.einsum, "ijkl,ikml->ikjm", xq, keys - ) / math.sqrt(self.head_dim) - self.env.apply_sharding(scores, axis=1) - if mask is not None: - # if mask.shape != (1,1,16,16): - # breakpoint() - scores = scores + mask # (bs, n_local_heads, seqlen, max_seqlen) - with jax.named_scope("attn_soft"): - scores = F.softmax(scores.float(), dim=-1).type_as(xq) - - with jax.named_scope("attn_mat2"): - # output = torch.einsum( - # "ikjm,ikml->ikjl", scores, values - # ) # (bs, n_local_heads, seqlen, head_dim) - output = torch_xla2.extra.call_jax( - jnp.einsum, "ikjm,ikml->ikjl", scores, values - ) - if seqlen == 1: - output = output[:, :, 0:1, :] - # For XLA matmul performance boost - # output = torch.matmul(scores, values) - self.env.apply_sharding(output, axis=1) - output = output.transpose(-3, -2).contiguous().view(bsz, seqlen, -1) - self.env.apply_sharding(output, axis=2) - output = self.wo(output) - return output - else: - with jax.named_scope("attn_insert_cache"): - keys, values, k_scaler, v_scaler = cache.update(xk, xv) - self.env.apply_sharding(keys, axis=1) - self.env.apply_sharding(values, axis=1) - keys = repeat_kv(keys, self.n_rep) - values = repeat_kv(values, self.n_rep) - with jax.named_scope("attn_mat1"): - ## Attention start - # scores = torch.einsum(jnp.einsum, "ijkl,ikml->ikjm", xq, keys) / math.sqrt(self.head_dim) - scores = ( - torch_xla2.extra.call_jax(jnp.einsum, "ijkl,ikml->ikjm", xq, keys) - / math.sqrt(self.head_dim) - * (k_scaler.reshape(bsz, 1, 1, keys.shape[2])) - ) - self.env.apply_sharding(scores, axis=1) - if mask is not None: - scores = scores + mask # (bs, n_local_heads, seqlen, max_seqlen) - with jax.named_scope("attn_soft"): - scores = F.softmax(scores.float(), dim=-1).type_as(xq) - scores = scores * v_scaler.reshape((bsz, 1, 1, keys.shape[2])) - self.env.apply_sharding(scores, axis=1) - - with jax.named_scope("attn_mat2"): - # output = torch.einsum( - # "ikjm,ikml->ikjl", scores, values - # ) # (bs, n_local_heads, seqlen, head_dim) - output = torch_xla2.extra.call_jax( - jnp.einsum, "ikjm,ikml->ikjl", scores, values - ) - if seqlen == 1: - output = output[:, :, 0:1, :] - # output = torch.matmul(scores, values) - self.env.apply_sharding(output, axis=1) - output = output.transpose(-3, -2).contiguous().view(bsz, seqlen, -1) - self.env.apply_sharding(output, axis=2) - return self.wo(output) + output = self.attention_kernel(xq, xk, xv, mask, cache) + output = output.transpose(-3, -2).contiguous().view(bsz, seqlen, -1) + return self.wo(output) diff --git a/jetstream_pt/third_party/gemma/model.py b/jetstream_pt/third_party/gemma/model.py index 1e0b40b3..a8573519 100644 --- 
a/jetstream_pt/third_party/gemma/model.py +++ b/jetstream_pt/third_party/gemma/model.py @@ -50,18 +50,17 @@ def reshape_for_broadcast( def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: - """Applies the rotary embedding to the query and key tensors.""" - x_ = torch.view_as_complex( - torch.stack(torch.chunk(x.transpose(1, 2).float(), 2, dim=-1), - dim=-1)) - freqs_cis = reshape_for_broadcast(freqs_cis, x_) - x_out = torch.view_as_real(x_ * freqs_cis).type_as(x) - x_out = torch.cat(torch.chunk(x_out, 2, dim=-1), dim=-2) - x_out = x_out.reshape(x_out.shape[0], x_out.shape[1], x_out.shape[2], - -1).transpose(1, 2) - return x_out - - + """Applies the rotary embedding to the query and key tensors.""" + x_ = torch.view_as_complex( + torch.stack(torch.chunk(x.transpose(1, 2).float(), 2, dim=-1), dim=-1) + ) + freqs_cis = reshape_for_broadcast(freqs_cis, x_) + x_out = torch.view_as_real(x_ * freqs_cis).type_as(x) + x_out = torch.cat(torch.chunk(x_out, 2, dim=-1), dim=-2) + x_out = x_out.reshape( + x_out.shape[0], x_out.shape[1], x_out.shape[2], -1 + ).transpose(1, 2) + return x_out class GemmaAttention(nn.Module): @@ -98,15 +97,37 @@ def __init__( if env.enable_weight_quantization else torch.nn.Linear ) - self.qkv_proj = Linear( - self.hidden_size, - (self.num_heads + 2 * self.num_kv_heads) * self.head_dim, + self.wq = Linear( + hidden_size, + num_heads * self.head_dim, + bias=False, + device=device, + ) + self.wk = Linear( + hidden_size, + self.num_kv_heads * self.head_dim, + bias=False, + device=device, + ) + self.wv = Linear( + hidden_size, + self.num_kv_heads * self.head_dim, bias=False, device=device, ) self.o_proj = Linear( - self.num_heads * self.head_dim, self.hidden_size, bias=False, device=device + self.num_heads * self.head_dim, + self.hidden_size, + bias=False, + device=device, + ) + + Kernel = ( + layers.Int8KVAttentionKernel + if env.enable_kv_quantization + else layers.AttentionKernel ) + self.attention_kernel = Kernel(env) def forward( self, @@ -118,13 +139,19 @@ def forward( hidden_states_shape = hidden_states.shape assert len(hidden_states_shape) == 3 batch_size, input_len, _ = hidden_states_shape - qkv = self.qkv_proj(hidden_states) - xq, xk, xv = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + xq = self.wq(hidden_states) + xk = self.wk(hidden_states) + xv = self.wv(hidden_states) xq = xq.view(batch_size, -1, self.num_heads, self.head_dim) xk = xk.view(batch_size, -1, self.num_kv_heads, self.head_dim) xv = xv.view(batch_size, -1, self.num_kv_heads, self.head_dim) + self.env.apply_sharding(xq, axis=2) + self.env.apply_sharding(xk, axis=2) + self.env.apply_sharding(xv, axis=2) + # Positional embedding. 
xq = apply_rotary_emb(xq, freqs_cis=freqs_cis) xk = apply_rotary_emb(xk, freqs_cis=freqs_cis) @@ -134,31 +161,9 @@ def forward( xk = xk.transpose(1, 2) xv = xv.transpose(1, 2) + xq = xq.transpose(1, 2) - # [batch_size, n_local_kv_heads, seq len, head_dim] - key, value = cache.update(xk, xv) - - key = key.transpose(1, 2) - value = value.transpose(1, 2) - - if self.num_kv_heads != self.num_heads: - # [batch_size, max_seq_len, n_local_heads, head_dim] - key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=2) - value = torch.repeat_interleave(value, self.num_queries_per_kv, dim=2) - - # [batch_size, n_local_heads, input_len, head_dim] - q = xq.transpose(1, 2) - # [batch_size, n_local_heads, max_seq_len, head_dim] - k = key.transpose(1, 2) - v = value.transpose(1, 2) - - # [batch_size, n_local_heads, input_len, max_seq_len] - scores = torch.matmul(q, k.transpose(2, 3)) * self.scaling - scores = scores + mask - scores = F.softmax(scores.float(), dim=-1).type_as(q) - - # [batch_size, n_local_heads, input_len, head_dim] - output = torch.matmul(scores, v) + output = self.attention_kernel(xq, xk, xv, mask, cache) # [batch_size, input_len, hidden_dim] output = output.transpose(1, 2).contiguous().view(batch_size, input_len, -1) @@ -238,7 +243,7 @@ def __init__(self, config: gemma_config.GemmaConfig, env): config.device, env, ) - + self.mlp = GemmaMLP( hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, diff --git a/run_interactive.py b/run_interactive.py index 5a332710..d0d3b21f 100644 --- a/run_interactive.py +++ b/run_interactive.py @@ -121,15 +121,15 @@ def main(argv): decode_state = engine.init_decode_state() prompts: List[str] = [ - "The meaning of life is", + "I believe the meaning of life is", # pylint: disable-next=all - # "To add an element to an ArrayList of a specific class type in Java, you can follow the following steps:\n\n1. Create an instance of the class to be added.\n2. Get a reference to the ArrayList.\n3. Call the `add()` method on the ArrayList, passing the instance of the class as the argument.\n\nHere's an example of how to add an object of type `Person` to an ArrayList of type `ArrayList`:\n```csharp\n// Create a new instance of the Person class\nPerson person = new Person(\"John\", 25);\n\n// Get a reference to the ArrayList\nArrayList peopleList = new ArrayList<>();\n\n// Add the person object to the ArrayList\npeopleList.add(person);\n```\nIn this example, the `Person` class is assumed to have a constructor that takes two arguments: a String for the person's name, and an int for their age. You can substitute your own class and constructor as necessary.", - # # pylint: disable-next=all - # "[INST] <>\nYou are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.\n<>\n\nQuestion 1: What is commercial real estate finance?\nQuestion 2: What are Commercial Real Estate services?\nOptions are:\n[a]. no.\n[b]. yes.\nWould the answer to these two questions be the same? [/INST]", - # # pylint: disable-next=all - # "[INST] <>\nYou are an AI assistant that helps people find information. 
Provide a detailed answer so user don\u2019t need to search outside to understand the answer.\n<>\n\nUse reasoning to lead to the answer of the following question:\nWhere are you likely to find water underneath?\nOptions:\n- toilet\n- sink\n- jar\n- bridge\n- house\n Reasoning process: [/INST", - # # pylint: disable-next=all - # "[INST] <>\nYou are an AI assistant. You will be given a task. You must generate a detailed and long answer.\n<>\n\nContinue the following story.\n\nKay didn't have shoes that fit her feet properly. She only wore sneakers, because the \nChoose from: [I] shoes fitted badly. [II] sneakers fitted badly. [/INST]", + "To add an element to an ArrayList of a specific class type in Java, you can follow the following steps:\n\n1. Create an instance of the class to be added.\n2. Get a reference to the ArrayList.\n3. Call the `add()` method on the ArrayList, passing the instance of the class as the argument.\n\nHere's an example of how to add an object of type `Person` to an ArrayList of type `ArrayList`:\n```csharp\n// Create a new instance of the Person class\nPerson person = new Person(\"John\", 25);\n\n// Get a reference to the ArrayList\nArrayList peopleList = new ArrayList<>();\n\n// Add the person object to the ArrayList\npeopleList.add(person);\n```\nIn this example, the `Person` class is assumed to have a constructor that takes two arguments: a String for the person's name, and an int for their age. You can substitute your own class and constructor as necessary.", + # pylint: disable-next=all + "[INST] <>\nYou are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.\n<>\n\nQuestion 1: What is commercial real estate finance?\nQuestion 2: What are Commercial Real Estate services?\nOptions are:\n[a]. no.\n[b]. yes.\nWould the answer to these two questions be the same? [/INST]", + # pylint: disable-next=all + "[INST] <>\nYou are an AI assistant that helps people find information. Provide a detailed answer so user don\u2019t need to search outside to understand the answer.\n<>\n\nUse reasoning to lead to the answer of the following question:\nWhere are you likely to find water underneath?\nOptions:\n- toilet\n- sink\n- jar\n- bridge\n- house\n Reasoning process: [/INST", + # pylint: disable-next=all + "[INST] <>\nYou are an AI assistant. You will be given a task. You must generate a detailed and long answer.\n<>\n\nContinue the following story.\n\nKay didn't have shoes that fit her feet properly. She only wore sneakers, because the \nChoose from: [I] shoes fitted badly. [II] sneakers fitted badly. 
[/INST]", ] for prompt in prompts: slot = random.randint(0, _BATCH_SIZE.value - 1) diff --git a/run_server.py b/run_server.py index 8fececac..161af9bd 100644 --- a/run_server.py +++ b/run_server.py @@ -111,6 +111,7 @@ def main(argv: Sequence[str]): param_size=_PARAM_SIZE.value, context_length=_CONTEXT_LENGTH.value, batch_size=_BATCH_SIZE.value, + model_name=_MODEL_NAME.value, quantize_weights=_QUANTIZE_WEIGHTS.value, quantize_kv=_QUANTIZE_KV_CACHE.value, max_cache_length=_MAX_CACHE_LENGTH.value, diff --git a/scripts/create_empty_sharding_map.py b/scripts/create_empty_sharding_map.py new file mode 100644 index 00000000..32d82b99 --- /dev/null +++ b/scripts/create_empty_sharding_map.py @@ -0,0 +1,93 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from absl import app +from absl import flags + +from jetstream_pt.environment import JetEngineEnvironment, JetEngineEnvironmentData, process_sharding_name +from jetstream_pt.third_party.llama2 import model_exportable, model_args +from jetstream_pt.third_party.gemma import config as gemma_config, model as gemma_model + +FLAGS = flags.FLAGS + +_MODEL_NAME = flags.DEFINE_string( + "model_name", None, "model type", required=False +) + +_SIZE = flags.DEFINE_string("size", "tiny", "size of model") + +_COLLAPSE_SAME_LAYERS = flags.DEFINE_bool("collapse_same_layers", True, "") + + +def create_model(): + batch_size = 3 + env_data = JetEngineEnvironmentData( + batch_size=3, + max_decode_length=1024, + max_input_sequence_length=1024, + enable_weight_quantization=True, + enable_kv_quantization=True, + cache_sequence_length=1024, + bf16_enable=True, + ) + model_name = _MODEL_NAME.value + param_size = _SIZE.value + if model_name.startswith("llama"): + + args = model_args.get_model_args( + param_size, + 1024, + batch_size, + vocab_size=32000, + bf16_enable=True, + ) + args.device = "meta" + args.quantize = False + env = JetEngineEnvironment(env_data) + return model_exportable.Transformer(args, env) + elif model_name == "gemma": + args = gemma_config.get_model_config(param_size) + args.device = "meta" + env_data.model_type = "gemma-" + param_size + env_data.num_layers = args.num_hidden_layers + env = JetEngineEnvironment(env_data) + pt_model = gemma_model.GemmaModel(args, env) + return pt_model + + +# pylint: disable-next=all +def main(argv): + model = create_model() + res = {} + for k, v in model.state_dict().items(): + res[process_sharding_name(k)] = v + + print( + f""" +# Sharding config for {_MODEL_NAME.value} +# Sharding should either be an int between 0 and rank - 1 +# signifying the axis to shard or -1 / null signifying replicated + +""" + ) + + for k, v in res.items(): + print(k, ":", -1, "# ", str(v.dtype), tuple(v.shape)) + + +if __name__ == "__main__": + os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" + app.run(main) diff --git a/tests/test_model_impl.py b/tests/test_model_impl.py index 971dd638..775c780e 100644 --- a/tests/test_model_impl.py +++ b/tests/test_model_impl.py @@ -24,6 +24,7 @@ from 
jetstream_pt.third_party.llama import model_exportable from jetstream_pt.third_party.llama import model_original from jetstream_pt.third_party.gemma import model_original as gemma_orig +from jetstream_pt.third_party.gemma import model as gemma from jetstream_pt import layers from jetstream_pt import cache_manager @@ -32,7 +33,7 @@ class ModelComponentTest(unittest.TestCase): """Test diff between original model and xla model for transformer, transformer block, attention and other component in model""" - def setup(self): + def setUp(self): """setup torch env""" jax.config.update("jax_platform_name", "cpu") torch.set_default_dtype(torch.float32) @@ -178,90 +179,95 @@ def test_attention(self): self.assertTrue(torch.allclose(result_torch, expected_out, atol=1e-4)) def test_gemma_attention(self): - env, model_arg = helpers.make_env_tiny(False) - - hidden_size = model_arg.dim - num_heads = model_arg.n_heads - num_kv_heads = model_arg.n_kv_heads - head_dim = model_arg.dim // model_arg.n_heads - # env._data.qkv_fusion = True - - def init_weights(model): - state_dict = model.state_dict() - res = {} - for k, v in state_dict.items(): - #x = random.randint(1, 10) - res[k] = torch.randn(v.shape) #* x - model.load_state_dict(res, assign=True) - - attention_orig = gemma_orig.GemmaAttention( - hidden_size=hidden_size, - num_heads=num_heads, - num_kv_heads=num_kv_heads, - head_dim=head_dim, - quant=False, - ) - init_weights(attention_orig) - - attention_ours = layers.Attention( - n_heads=num_heads, - n_kv_heads=num_kv_heads, - head_dim=head_dim, - hidden_size=hidden_size, - device="meta", - env=env, - ) - - def load_hook(state_dict, prefix, *args): - wo = state_dict.pop(prefix + "o_proj.weight") - state_dict[prefix + "wo.weight"] = wo - qkv = state_dict.pop(prefix + "qkv_proj.weight") - q, k, v = qkv.split( - [ - attention_orig.q_size, - attention_orig.kv_size, - attention_orig.kv_size, - ], - dim=0, + with jax.default_matmul_precision("float32"): + env, model_arg = helpers.make_env_tiny(False) + + hidden_size = model_arg.dim + num_heads = model_arg.n_heads + num_kv_heads = model_arg.n_kv_heads + head_dim = model_arg.dim // model_arg.n_heads + # env._data.qkv_fusion = True + + def init_weights(model): + state_dict = model.state_dict() + res = {} + for k, v in state_dict.items(): + # x = random.randint(1, 10) + res[k] = torch.randn(v.shape) # * x + model.load_state_dict(res, assign=True) + + attention_orig = gemma_orig.GemmaAttention( + hidden_size=hidden_size, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + quant=False, + ) + init_weights(attention_orig) + + attention_ours = gemma.GemmaAttention( + hidden_size=hidden_size, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + device="meta", + env=env, ) - state_dict[prefix + "wq.weight"] = q - state_dict[prefix + "wk.weight"] = k - state_dict[prefix + "wv.weight"] = v - - seqlen = 32 - batch = 1 - x = torch.randn( - (batch, seqlen, hidden_size) - ) # (batch, seqlen, embedding dim) - start_pos = 0 - freqs_cis = self._make_freqs_cis(model_arg, seqlen, start_pos) - mask = self._prefill_mask(seqlen, start_pos) - kv_write_indexes = torch.arange(0, seqlen) - cache_k = torch.zeros((batch, seqlen, num_heads, head_dim)) - cache_v = torch.zeros((batch, seqlen, num_heads, head_dim)) - inputs_orig = (x, freqs_cis, kv_write_indexes, (cache_k, cache_v), mask) - - expected_out = attention_orig(*inputs_orig) - cache = cache_manager.KVCachePrefill() - freqs_cis = freqs_cis.reshape(batch, seqlen, -1) - input_ours = ( - x, - 
freqs_cis, - mask, - cache, - ) + def load_hook(state_dict, prefix, *args): + qkv = state_dict.pop(prefix + "qkv_proj.weight") + q, k, v = qkv.split( + [ + attention_orig.q_size, + attention_orig.kv_size, + attention_orig.kv_size, + ], + dim=0, + ) + state_dict[prefix + "wq.weight"] = q + state_dict[prefix + "wk.weight"] = k + state_dict[prefix + "wv.weight"] = v + + seqlen = 32 + batch = 1 + x = torch.randn( + (batch, seqlen, hidden_size) + ) # (batch, seqlen, embedding dim) + start_pos = 0 + freqs_cis = self._make_freqs_cis(model_arg, seqlen, start_pos) + mask = self._prefill_mask(seqlen, start_pos) + kv_write_indexes = torch.arange(0, seqlen) + cache_k = torch.zeros((batch, seqlen, num_heads, head_dim)) + cache_v = torch.zeros((batch, seqlen, num_heads, head_dim)) + inputs_orig = (x, freqs_cis, kv_write_indexes, (cache_k, cache_v), mask) + + expected_out = attention_orig(*inputs_orig) + + cache = cache_manager.KVCachePrefill() + freqs_cis = freqs_cis.reshape(batch, seqlen, -1) + input_ours = ( + x, + freqs_cis, + mask, + cache, + ) - state_dict = dict(attention_orig.state_dict()) - load_hook(state_dict, "") - result_torch = self._call_xla_model(attention_ours, state_dict, input_ours) + state_dict = dict(attention_orig.state_dict()) + load_hook(state_dict, "") + result_torch = self._call_xla_model( + attention_ours, state_dict, input_ours + ) - print( - "Single Gemma Attention: Diff norm", - (result_torch - expected_out).norm(), - ) + print( + "Single Gemma Attention: Diff norm", + (result_torch - expected_out).norm(), + ) + print( + "Single Gemma Attention: Diff max", + torch.max((result_torch - expected_out).abs()), + ) - self.assertTrue(torch.allclose(result_torch, expected_out, atol=1e-4)) + self.assertTrue(torch.allclose(result_torch, expected_out, atol=1e-3)) # pylint: disable-next=all def test_transformer_block(self): From a9fe13e9bdf5174006049b2938a0b4d408323abc Mon Sep 17 00:00:00 2001 From: Han Qi Date: Thu, 9 May 2024 18:07:25 +0000 Subject: [PATCH 3/4] comments --- convert_checkpoints.py | 2 -- default_shardings/gemma.yaml | 4 +--- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/convert_checkpoints.py b/convert_checkpoints.py index aa5fc916..65488f74 100644 --- a/convert_checkpoints.py +++ b/convert_checkpoints.py @@ -419,7 +419,6 @@ def convert_hf_gemma_weights( ] model_config = json.loads((input_ckpt_dir / "config.json").read_text()) for key in list(state_dict.keys()): - print(key) if state_dict[key].dtype.is_complex and _OUTPUT_SAFETENSORS.value: assert ( key == "freqs_cis" @@ -451,7 +450,6 @@ def convert_hf_gemma_weights( if new_key != key: state_dict[new_key] = state_dict.pop(key) - output_ckpt_dir.mkdir(parents=True, exist_ok=True) _export_to_local(output_ckpt_dir, model_config, state_dict) diff --git a/default_shardings/gemma.yaml b/default_shardings/gemma.yaml index 9eb1d1a9..03cc392b 100644 --- a/default_shardings/gemma.yaml +++ b/default_shardings/gemma.yaml @@ -3,10 +3,8 @@ # "replicated" to signify "replicated". 
# Integer signify axis to shard: 0 <= shard axis < rank -freqs_cis : null # torch.complex64 (16384, 128) -layers.*.self_attn.qkv_proj.weight: 0 +freqs_cis : -1 # torch.complex64 (16384, 128) layers.*.self_attn.o_proj.weight: 1 -layers.*.self_attn.wo.weight : 1 # 1, -1] # torch.float32 (2048, 2048) layers.*.self_attn.wq.weight : 0 # -1, 1] # torch.float32 (2048, 2048) layers.*.self_attn.wk.weight : 0 # -1, 1] # torch.float32 (256, 2048) layers.*.self_attn.wv.weight : 0 # -1, 1] # torch.float32 (256, 2048) From 827d4644a2adf8b94e1d10bfd18da4d9a324f3eb Mon Sep 17 00:00:00 2001 From: Han Qi Date: Thu, 9 May 2024 18:19:42 +0000 Subject: [PATCH 4/4] Add doc --- docs/add_a_new_model.md | 286 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 286 insertions(+) create mode 100644 docs/add_a_new_model.md diff --git a/docs/add_a_new_model.md b/docs/add_a_new_model.md new file mode 100644 index 00000000..42245bb8 --- /dev/null +++ b/docs/add_a_new_model.md @@ -0,0 +1,286 @@ +Add a new Model +=============== + +This doc is a detailed guide on how to add a new model to jetstream-pytorch. +The complexity of adding a new model depends highly on the model architecture itself, +and right now is a manual process. + +NOTE: Only LLMs that employ autoregressive decoding that utilizes a KV cache are suitable +for serving with Jetstream. Other models such as Stable Diffusion are NOT suitable +for the optimization techniques used in Jetstream. + +The core part of adding a model is to let the Jetstream serving engine manage +the KV cache. This management is abstracted by the [`class CacheInterface`](jetstream_pt/cache_manager.py). This interface has a single `update` method that abstracts +the act of inserting into and then reading from the cache. + +We will walk through this process using the [Gemma model](https://github.com/google/gemma_pytorch) as an example. + +# Step 0: Get the model code + +Jetstream pytorch stores its models in the jetstream_pt/third_party directory. + +The usual convention is: + +1. Make a verbatim copy of the model code and supporting files + (such as the args class, tokenizers, etc.) in a separate directory. In our case + it would be [jetstream_pt/third_party/gemma](jetstream_pt/third_party/gemma) + +2. Make a copy of `model.py` as `model_original.py`, because we will be modifying + it to follow the conventions of Jetstream; keeping the original can help with + debugging accuracy (and unit tests). + +*Optional:* Clean up the model implementation: The easiest models to port are + "reference implementations". Models that already contain optimizations and/or custom + CUDA kernels would need to have those changes removed. + +In our case, we choose to use the reference Gemma model from google's github instead +of the HuggingFace version, because the HuggingFace version also has training code that +would need to be removed. + +# Step 1: Modify the model to fit the calling conventions expected by Jetstream. + +The model that Jetstream expects and calls follows this calling convention: + +```python +class Model(torch.nn.Module): + + def forward( + self, + tokens: torch.Tensor, + input_pos: torch.Tensor, + caches: List[CacheInterface], + mask: torch.Tensor, + ) -> torch.Tensor: + +``` + +The arguments are: + +* `tokens`: An int tensor with shape (batch_size, sequence_length). These are the token ids + before embedding. + +* `input_pos`: The position of the tokens in the overall sentence. This is an int + tensor of shape (batch_size, sequence_length).
Note: due to continuous batching, + not all batches have the same sequence length. + +* `caches`: A list of objects implementing the `CacheInterface`. CacheInterface has a + single `update` method. + +* `mask`: Mask used in causal attention. + +The return value should be a tensor of shape (batch_size, sequence_length, vocab_size) +of **logits** (not probabilities) for the next token. + +### Gemma example: + +Now, looking back at our Gemma model reference, there are 2 classes in the original +model that are suitable to be our model: [GemmaModel](https://github.com/google/gemma_pytorch/blob/main/gemma/model.py#L353) and [GemmaForCausalLM](https://github.com/google/gemma_pytorch/blob/main/gemma/model.py#L386). Looking at their forward method signatures: + +```python + +class GemmaModel(nn.Module): + def forward( + self, + hidden_states: torch.Tensor, + freqs_cis: torch.Tensor, + kv_write_indices: torch.Tensor, + kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], + mask: torch.Tensor, + ) -> torch.Tensor: + +class GemmaForCausalLM(nn.Module): + @torch.no_grad() + def forward( + self, + input_token_ids: torch.Tensor, + input_positions: torch.Tensor, + kv_write_indices: torch.Tensor, + kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], + mask: torch.Tensor, + output_positions: torch.Tensor, + temperatures: Union[torch.Tensor, None], + top_ps: torch.Tensor, + top_ks: torch.Tensor, + **kwargs, + ) -> torch.Tensor: +``` + +We can see that `GemmaModel` is probably the closest to port, so we choose that one. +However, there are a few issues: + +1. GemmaModel takes `hidden_states` instead of tokens. +2. GemmaModel returns `hidden_states` after the layers and not logits. + +Let's fix those first. + +Looking at where `GemmaModel` is called in `model.py`, we find: + +``` + # [batch_size, input_len, hidden_size] + hidden_states = self.embedder(input_token_ids) + # Gemma normalizes the embedding by sqrt(hidden_size). + hidden_states = hidden_states * (self.config.hidden_size**0.5) +``` + +So the input tokens are embedded with `self.embedder` and processed before calling +`GemmaModel`. Let's move these bits inside GemmaModel. + +Now, looking at where the output of `GemmaModel` is consumed, we see it is fed to `self.sampler`. + +`self.sampler` is of class `Sampler`, and its forward method contains: + +```python + hidden_states = hidden_states.index_select( + 1, output_positions).squeeze(dim=1) + logits = torch.matmul(hidden_states, embedding.t()) + if embedding_bias is not None: + logits += embedding_bias + + if temperatures is None: + return torch.argmax(logits, dim=-1).squeeze(dim=-1) + ... +``` + +We see it performs some math on the hidden states to produce logits, which is what +GemmaModel should return. Now, let's move these bits into `GemmaModel` as well. + +Lastly, GemmaModel takes a list of tuples of torch.Tensor as input for the caches; +we need to replace it with our cache object. + +This cache is plumbed through all the way to `GemmaAttention`, and the [following lines](https://github.com/google/gemma_pytorch/blob/main/gemma/model.py#L264C1-L268C53): + +```python + # Write new kv cache. + # [batch_size, input_len, n_local_kv_heads, head_dim] + k_cache, v_cache = kv_cache + k_cache.index_copy_(1, kv_write_indices, xk) + v_cache.index_copy_(1, kv_write_indices, xv) +``` + +are precisely the cache update.
+So we need to replace those lines with + +``` +xk = xk.transpose(1, 2) +xv = xv.transpose(1, 2) +k_cache, v_cache = cache.update(xk, xv) +``` + +The transpose is needed because the cache interface's `update` method expects +a shape of (batch, num_heads, sequence length, head dim) instead of +the (batch, sequence length, num_heads, head dim) that GemmaAttention produces. + + +In our case, because the attention math is the standard one, we can just call out +to `AttentionKernel` defined in [layers.py](jetstream_pt/layers.py). `AttentionKernel` +also handles reading and writing of `cache` automatically. + +At this point, the model should be runnable. However, to run it on a realistic TPU, +we need to add model parallelism. + +# Step 2: Add model parallelism + +Model parallelism is often necessary to run on TPUs. The typical setup for running +inference workloads uses a TPU `v5light-8`, which has 8 TPU chips with 16GB of +high bandwidth memory (HBM) each. The typical `7B` model won't fit on a single chip. + +So we need to add model parallelism so the model weights are sharded among the 8 devices. +This is necessary for larger models, such as 70B, even on high-memory chips (v5p). +So it's a good practice to do it right away. + +Jetstream uses GSPMD for tensor parallelism; the only information we need to +give it is, for every weight tensor, which axis to shard. We do so by writing +a sharding config file. + +## Generate a sharding config: + +The keys of the sharding file are the names of the weights (with numeric layer indices replaced with *), +and the values are the axes to shard. +For Gemma, we can generate such a file by printing out the keys in its `state_dict`. +See [create_empty_sharding_map.py](scripts/create_empty_sharding_map.py) for an example +(a minimal sketch of this is shown at the end of this step). + +The resulting config is shown below: + +```yaml +freqs_cis : -1 # torch.complex64 (16384, 128) +layers.*.self_attn.qkv_proj.weight: 0 +layers.*.self_attn.o_proj.weight: 1 +layers.*.self_attn.wo.weight : 1 # 1, -1] # torch.float32 (2048, 2048) +layers.*.self_attn.wq.weight : 0 # -1, 1] # torch.float32 (2048, 2048) +layers.*.self_attn.wk.weight : 0 # -1, 1] # torch.float32 (256, 2048) +layers.*.self_attn.wv.weight : 0 # -1, 1] # torch.float32 (256, 2048) +layers.*.mlp.gate_proj.weight : 0 # -1, 1] # torch.float32 (16384, 2048) +layers.*.mlp.gate_proj.bias : 0 # -1] # torch.float32 (16384,) +layers.*.mlp.up_proj.weight : 0 # -1, 1] # torch.float32 (16384, 2048) +layers.*.mlp.up_proj.bias : 0 # -1] # torch.float32 (16384,) +layers.*.mlp.down_proj.weight : 1 # 1, -1] # torch.float32 (2048, 16384) +layers.*.mlp.down_proj.bias : -1 # torch.float32 (2048,) +layers.*.input_layernorm.weight : -1 # torch.float32 (2048,) +layers.*.post_attention_layernorm.weight : -1 # torch.float32 (2048,) +norm.weight : -1 # torch.float32 (2048,) +embedder.weight : 1 # # 1, -1] # torch.float32 (256000, 2048) +``` + +The weights `layers.*.self_attn.qkv_proj.weight`, where * goes over 1..28, are sharded +along axis 0 (0-based indexing), etc., and -1 means "replicated". + +Theoretically, any valid sharding would work. To find a sharding that performs well, one +can usually get some hints from the original model implementation. + +For example, in the case of Gemma, the authors also provided a TPU version: https://github.com/google/gemma_pytorch/blob/main/gemma/model_xla.py + +In that file, layers using `ColumnParallelLinear` should be sharded on dimension 0, +those using `RowParallelLinear` should be sharded on dimension 1, and the others should be +replicated.
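+As a rough illustration of the "print the state_dict keys" approach mentioned above, here is a
+minimal sketch. It mirrors the idea behind `scripts/create_empty_sharding_map.py`; the local
+`collapse_layer_name` helper is a simplified stand-in for `process_sharding_name` from
+`jetstream_pt.environment`, and `model` is assumed to be an already-constructed
+`torch.nn.Module` (for example built on the meta device):
+
+```python
+import re
+
+import torch
+
+
+def collapse_layer_name(name: str) -> str:
+  # Collapse per-layer names (layers.0., layers.1., ...) into a single layers.*. entry.
+  return re.sub(r"\.\d+\.", ".*.", name)
+
+
+def print_empty_sharding_map(model: torch.nn.Module):
+  entries = {}
+  for name, weight in model.state_dict().items():
+    entries[collapse_layer_name(name)] = weight
+  for name, weight in entries.items():
+    # -1 means "replicated"; edit individual entries afterwards to pick a shard axis.
+    print(name, ":", -1, "# ", str(weight.dtype), tuple(weight.shape))
+```
+
+The printed lines can be pasted into `default_shardings/<model_name>.yaml` and then edited
+entry by entry to choose the shard axes.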
+ +# Step 3: Activation Sharding and quantization + +Sometimes we would like to specify shardings for the activations because GSPMD cannot +fully infer all the shardings. + +The typical example of such a case happens after a reshape. For example: if I have a matrix +of shape [A, B * C] and the second dim is sharded, then after reshaping it to shape [A, B, C] +the compiler knows that one of the dims B or C is sharded, but cannot know which one. +In this case, it is helpful to specify a sharding constraint. + +This is done by calling `env.apply_sharding(tensor, axis=1)` on the tensor. + +The `env` object is an instance of the `Environment` class that is passed to the +model constructor. It also contains some common configurations (such as whether the user wants quantization) that are useful for the models. + +For this reason, we store that object in `self.env` and use it when needed. + +For quantization, it suffices to swap `nn.Linear` layers with `WeightOnlyInt8Linear` defined in +`layers.py`. + +# Step 4: Wiring everything up. + +The last step is to modify [engine.py](https://github.com/google/jetstream-pytorch/blob/main/jetstream_pt/engine.py#L738) +and add an if branch in that function. + +This function receives information about the model name and size, and +here it should instantiate the model object itself. It also needs to tell the environment +what cache to allocate: notably, how many layers and the shape of the +cache. The shape is expected to be (batch size, num_kv_heads, sequence_length, head_dim). + +## Test it out + +After these steps you should be able to run your model using + +```bash +python run_interactive.py --size=7b --batch_size=128 --max_cache_length=2048 --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/$model_name.yaml --model=gemma +``` + +If you run it without checkpoint_path it will use random weights, so you can +verify that the code actually runs. + +# Step 5: Weight conversion + +Because we modified the model, the names of variables in the model might have +changed. If so, we also need to modify the `convert_weights.py` script to map +the original weights to the modified names. + +For example: I split the qkv projection into 3 separate projections, which helps with +performance in a sharded environment. So I need to make the `convert_weights` script +split the weights as well, as sketched below. +
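+As an illustration of such a conversion step, here is a minimal sketch of the qkv split,
+adapted from the `convert_hf_gemma_weights` handling in `convert_checkpoints.py` in this patch
+series. The `split_qkv_weights` helper name and the standalone-function packaging are just for
+illustration; `state_dict` is the checkpoint's state dict and `model_config` is the dict loaded
+from the model's `config.json`:
+
+```python
+def split_qkv_weights(state_dict, model_config):
+  q_dim = model_config["num_attention_heads"] * model_config["head_dim"]
+  kv_dim = model_config["num_key_value_heads"] * model_config["head_dim"]
+  for key in list(state_dict.keys()):
+    if "qkv_proj" not in key:
+      continue
+    qkv = state_dict.pop(key)
+    # The fused weight is laid out as [q; k; v] along dim 0 (the output features).
+    q, k, v = qkv.split([q_dim, kv_dim, kv_dim], dim=0)
+    state_dict[key.replace("qkv_proj", "wq")] = q
+    state_dict[key.replace("qkv_proj", "wk")] = k
+    state_dict[key.replace("qkv_proj", "wv")] = v
+  return state_dict
+```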