From f5de8d4654c9bc24167877e772f8026bb261008c Mon Sep 17 00:00:00 2001 From: Han Qi Date: Fri, 3 May 2024 16:32:55 +0000 Subject: [PATCH] add gemma 1 shard config for llamA gemma 3 gemma4 formatter --- default_shardings/gemma.yaml | 20 + default_shardings/llama-2.yaml | 23 + jetstream_pt/__init__.py | 2 + jetstream_pt/engine.py | 102 ++-- jetstream_pt/environment.py | 45 ++ jetstream_pt/layers.py | 59 +- jetstream_pt/third_party/gemma/__init__.py | 13 + jetstream_pt/third_party/gemma/config.py | 88 +++ jetstream_pt/third_party/gemma/model.py | 205 +++++++ .../third_party/gemma/model_original.py | 574 ++++++++++++++++++ jetstream_pt/third_party/gemma/tokenizer.py | 46 ++ .../third_party/llama/generation_original.py | 2 +- jetstream_pt/third_party/llama/model_args.py | 7 +- .../third_party/llama/model_exportable.py | 9 +- run_interactive.py | 8 + run_server.py | 13 + {tests => scripts}/jax_experiments.py | 68 ++- tests/helpers.py | 1 + tests/test_engine.py | 10 +- tests/test_llama_e2e.py | 10 +- tests/test_model_impl.py | 97 ++- 21 files changed, 1290 insertions(+), 112 deletions(-) create mode 100644 default_shardings/gemma.yaml create mode 100644 default_shardings/llama-2.yaml create mode 100644 jetstream_pt/third_party/gemma/__init__.py create mode 100644 jetstream_pt/third_party/gemma/config.py create mode 100644 jetstream_pt/third_party/gemma/model.py create mode 100644 jetstream_pt/third_party/gemma/model_original.py create mode 100644 jetstream_pt/third_party/gemma/tokenizer.py rename {tests => scripts}/jax_experiments.py (80%) diff --git a/default_shardings/gemma.yaml b/default_shardings/gemma.yaml new file mode 100644 index 00000000..da57d36e --- /dev/null +++ b/default_shardings/gemma.yaml @@ -0,0 +1,20 @@ + +# Sharding config for gemma +# "replicated" to signify "replicated". 
+# Integer signify axis to shard: 0 <= shard axis < rank + +freqs_cis : null # torch.complex64 (16384, 128) +layers.*.self_attn.wo.weight : 1 # 1, -1] # torch.float32 (2048, 2048) +layers.*.self_attn.wq.weight : 0 # -1, 1] # torch.float32 (2048, 2048) +layers.*.self_attn.wk.weight : 0 # -1, 1] # torch.float32 (256, 2048) +layers.*.self_attn.wv.weight : 0 # -1, 1] # torch.float32 (256, 2048) +layers.*.mlp.gate_proj.weight : 0 # -1, 1] # torch.float32 (16384, 2048) +layers.*.mlp.gate_proj.bias : 0 # -1] # torch.float32 (16384,) +layers.*.mlp.up_proj.weight : 0 # -1, 1] # torch.float32 (16384, 2048) +layers.*.mlp.up_proj.bias : 0 # -1] # torch.float32 (16384,) +layers.*.mlp.down_proj.weight : 1 # 1, -1] # torch.float32 (2048, 16384) +layers.*.mlp.down_proj.bias : null # torch.float32 (2048,) +layers.*.input_layernorm.weight : null # torch.float32 (2048,) +layers.*.post_attention_layernorm.weight : null # torch.float32 (2048,) +norm.weight : null # torch.float32 (2048,) +embedder.weight : 1 # # 1, -1] # torch.float32 (256000, 2048) diff --git a/default_shardings/llama-2.yaml b/default_shardings/llama-2.yaml new file mode 100644 index 00000000..35859a21 --- /dev/null +++ b/default_shardings/llama-2.yaml @@ -0,0 +1,23 @@ + +# Sharding config for llama-2 +# Sharding should either be an int between 0 and rank - 1 +# signifying the axis to shard or -1 / null signifying replicated + + +freqs_cis : -1 # torch.complex64 (2048, 64) +tok_embeddings.weight : 1 # torch.float32 (32000, 4096) +layers.*.attention.wo.weight : 1 # torch.int8 (4096, 4096) +layers.*.attention.wo.weight_scaler : 0 # torch.bfloat16 (4096,) +layers.*.attention.wq.weight : 0 # torch.int8 (4096, 4096) +layers.*.attention.wq.weight_scaler : 0 # torch.bfloat16 (4096,) +layers.*.attention.wk.weight : 0 # torch.int8 (4096, 4096) +layers.*.attention.wk.weight_scaler : 0 # torch.bfloat16 (4096,) +layers.*.attention.wv.weight : 0 # torch.int8 (4096, 4096) +layers.*.attention.wv.weight_scaler : 0 # torch.bfloat16 (4096,) +layers.*.feed_forward.w1.weight : 0 # torch.float32 (11008, 4096) +layers.*.feed_forward.w2.weight : 1 # torch.float32 (4096, 11008) +layers.*.feed_forward.w3.weight : 0 # torch.float32 (11008, 4096) +layers.*.attention_norm.weight : -1 # torch.float32 (4096,) +layers.*.ffn_norm.weight : -1 # torch.float32 (4096,) +norm.weight : -1 # torch.float32 (4096,) +output.weight : 0 # torch.float32 (32000, 4096) diff --git a/jetstream_pt/__init__.py b/jetstream_pt/__init__.py index 1a16fa82..6256f1ce 100644 --- a/jetstream_pt/__init__.py +++ b/jetstream_pt/__init__.py @@ -13,3 +13,5 @@ # limitations under the License. 
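Note on the two sharding files above: each key is a parameter-name pattern (layer indices collapsed to "*", matching how the engine normalizes names) and each value is either the axis to shard, an integer in [0, rank), or -1/null for replication. A minimal sketch of how such an entry could be turned into a jax.sharding.NamedSharding is shown below; the 1-D device mesh and the axis name "x" are illustrative assumptions, the real engine builds its own mesh inside JetEngineEnvironment.

import jax
import yaml
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec

def sharding_for(axis, rank, mesh):
    # axis: int in [0, rank) to shard that dimension, or -1/None to replicate.
    if axis is None or axis == -1:
        return NamedSharding(mesh, PartitionSpec())  # fully replicated
    spec = [None] * rank
    spec[axis] = "x"  # place the sharded dimension on the mesh axis
    return NamedSharding(mesh, PartitionSpec(*spec))

mesh = Mesh(mesh_utils.create_device_mesh((jax.device_count(),)), axis_names=("x",))
with open("default_shardings/gemma.yaml", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)
# e.g. layers.*.self_attn.wq.weight -> 0 means "shard dim 0 of the rank-2 weight"
wq_sharding = sharding_for(cfg["layers.*.self_attn.wq.weight"], rank=2, mesh=mesh)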
from jetstream_pt.engine import create_pytorch_engine + +__all__ = ["create_pytorch_engine"] diff --git a/jetstream_pt/engine.py b/jetstream_pt/engine.py index 328f0fb3..ea77342e 100644 --- a/jetstream_pt/engine.py +++ b/jetstream_pt/engine.py @@ -34,6 +34,7 @@ from jetstream_pt import quantize from jetstream_pt.environment import JetEngineEnvironment, JetEngineEnvironmentData from jetstream_pt.third_party.llama import model_exportable, model_args +from jetstream_pt.third_party.gemma import config as gemma_config, model as gemma_model Mesh = jax.sharding.Mesh @@ -108,32 +109,6 @@ def __init__( # out_shardings=self.get_decode_state_sharding()) self._lock = threading.RLock() - # pylint: disable-next=all - def sharding_by_name(self, name): - - # This allows easier way to edit shardings - """ - for key, val in self.env._data.experimental_sharding_axis_override.items(): - if name.endswith(key): - return self.env.sharding_by_axis(val) - """ - - if "weight_scaler" in name: - return self.x_sharding - if "tok_embeddings." in name: - return self.y_sharding - if "attention." in name: - if "wo" in name: - return self.y_sharding - return self.x_sharding - if "feed_forward." in name: - if "w2" in name: - return self.y_sharding - return self.x_sharding - if "output" in name: - return self.x_sharding - return self.replicated - # pylint: disable-next=all def init_decode_state( self, @@ -561,7 +536,7 @@ def _load_from_safetensors(self, path): for key, model_weights in self.pt_model.state_dict().items(): if key == "freqs_cis": continue - arr = jax.device_put(f.get_tensor(key), self.sharding_by_name(key)) + arr = jax.device_put(f.get_tensor(key), self.env.sharding_by_name(key)) assert tuple(model_weights.shape) == tuple( arr.shape ), f"key: {key} error: {model_weights.shape} != {arr.shape}" @@ -587,7 +562,7 @@ def load_params(self) -> Params: else: jax_weights = self._make_state_dict_jax(self.pt_model.state_dict()) jax_weights = { - key: jax.device_put(value, self.sharding_by_name(key)) + key: jax.device_put(value, self.env.sharding_by_name(key)) for key, value in jax_weights.items() } for k, v in jax_weights.items(): @@ -664,6 +639,7 @@ def create_pytorch_engine( quantize_weights=False, quantize_kv=False, max_cache_length=1024, + sharding_config=None, ) -> PyTorchEngine: """Returns: The pytorch engine.""" @@ -706,6 +682,20 @@ def create_pytorch_engine( tokenizer = token_utils.load_vocab(tokenizer_path) pt_model = None + env_data = JetEngineEnvironmentData( + tokenizer_path=tokenizer_path, + checkpoint_path=checkpoint_path, + checkpoint_format=checkpoint_format, + batch_size=batch_size, + max_decode_length=max_decode_length, + max_input_sequence_length=context_length, + enable_weight_quantization=quantize_weights, + enable_kv_quantization=quantize_kv, + cache_sequence_length=max_cache_length, + bf16_enable=bf16_enable, + sharding_config_path=sharding_config, + ) + if model_name.startswith("llama"): args = model_args.get_model_args( @@ -713,35 +703,37 @@ def create_pytorch_engine( ) args.device = "meta" args.quantize = quantize_weights - env_data = JetEngineEnvironmentData( - tokenizer_path=tokenizer_path, - checkpoint_path=checkpoint_path, - checkpoint_format=checkpoint_format, - model_type="llama-2-" + param_size, - batch_size=batch_size, - max_decode_length=max_decode_length, - max_input_sequence_length=context_length, - enable_weight_quantization=quantize_weights, - enable_kv_quantization=quantize_kv, - cache_sequence_length=max_cache_length, - bf16_enable=bf16_enable, - num_layers=args.n_layers, - 
cache_shape=( - batch_size, - args.n_kv_heads, - max_cache_length, - args.dim // args.n_heads, - ), + env_data.cache_shape = ( + batch_size, + args.n_kv_heads, + max_cache_length, + args.dim // args.n_heads, ) + env_data.model_type = "llama-2-" + param_size + env_data.num_layers = args.n_layers env = JetEngineEnvironment(env_data) pt_model = model_exportable.Transformer(args, env) - - num_params_size = 0 - num_params = 0 - for _, v in pt_model.state_dict().items(): - num_params += 1 - num_params_size += np.prod(v.shape) * (1 if v.dtype == torch.int8 else 2) - print("Number of param Gbytes:", num_params_size / (1 << 30)) - print("Number of param: ", num_params) + elif model_name == "gemma": + args = gemma_config.get_model_config(param_size) + env_data.cache_shape = ( + batch_size, + args.num_key_value_heads, + max_cache_length, + args.head_dim, + ) + env_data.model_type = "gemma-" + param_size + env_data.num_layers = args.num_hidden_layers + env = JetEngineEnvironment(env_data) + pt_model = gemma_model.GemmaModel(args, env) + else: + raise RuntimeError(f"Model with name {model_name} not found") + + num_params_size = 0 + num_params = 0 + for _, v in pt_model.state_dict().items(): + num_params += 1 + num_params_size += np.prod(v.shape) * (1 if v.dtype == torch.int8 else 2) + print("Number of param Gbytes:", num_params_size / (1 << 30)) + print("Number of param: ", num_params) return PyTorchEngine(pt_model=pt_model, env=env) diff --git a/jetstream_pt/environment.py b/jetstream_pt/environment.py index f223f837..453172ef 100644 --- a/jetstream_pt/environment.py +++ b/jetstream_pt/environment.py @@ -15,6 +15,8 @@ from typing import Tuple, Dict import dataclasses +import yaml + import jax import jax.sharding as jsharding from jax.experimental import mesh_utils @@ -71,6 +73,8 @@ class JetEngineEnvironmentData: # If Ture, use bfloat16 as dtype. If False, use float32 as dtype bf16_enable: bool = True + sharding_config_path: str = "" + # pylint: disable-next=all class JetEngineEnvironment: @@ -100,6 +104,15 @@ def __init__(self, data: JetEngineEnvironmentData): self.cache_sharding = jsharding.NamedSharding( self._mesh, P(*cache_sharding) ) + self._load_sharding_config() + + def _load_sharding_config(self): + """Load sharding config""" + if self._data.sharding_config_path: + with open(self._data.sharding_config_path, encoding="utf-8") as f: + self._sharding_config = yaml.safe_load(f) + else: + self._sharding_config = {} def __getattr__(self, name): return getattr(self._data, name) @@ -150,3 +163,35 @@ def make_caches_generate(self): ) ) return caches + + def sharding_by_name(self, name): + """Create sharding specified in the config.""" + if name in self._sharding_config: + return self.sharding_by_axis(self._sharding_config[name]) + + name = process_sharding_name(name) + if name in self._sharding_config: + return self.sharding_by_axis(self._sharding_config[name]) + + raise RuntimeError("Sharding for name: ", name, " not specified") + + +def process_sharding_name(name): + """Replace integers in param name with *. + + Presumably all layers should have the same sharding. 
+ """ + + def is_integer(t): + try: + int(t) + return True + # pylint: disable-next=all + except: # noqa: E722 + return False + + tokens = name.split(".") + for i, t in enumerate(tokens): + if is_integer(t): + tokens[i] = "*" + return ".".join(tokens) diff --git a/jetstream_pt/layers.py b/jetstream_pt/layers.py index d413d516..e0b32df9 100644 --- a/jetstream_pt/layers.py +++ b/jetstream_pt/layers.py @@ -132,57 +132,53 @@ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: class Attention(nn.Module): """Attention module.""" - def __init__(self, args, env): + def __init__(self, n_heads, n_kv_heads, head_dim, hidden_size, device, env): super().__init__() - - self.n_kv_heads = ( - args.n_heads if args.n_kv_heads is None else args.n_kv_heads - ) - self.n_local_heads = args.n_heads - self.n_local_kv_heads = self.n_kv_heads - self.n_rep = self.n_local_heads // self.n_local_kv_heads - self.head_dim = args.dim // args.n_heads - self.max_seq_len = args.max_seq_len - self.n_heads = args.n_heads - + self.n_heads = n_heads + self.n_kv_heads = n_kv_heads + self.head_dim = head_dim + self.n_rep = self.n_heads // self.n_kv_heads self.env = env + self.hidden_size = hidden_size - LinearLayer = WeightOnlyInt8Linear if args.quantize else nn.Linear + LinearLayer = ( + WeightOnlyInt8Linear if env.enable_weight_quantization else nn.Linear + ) self.wo = LinearLayer( - args.n_heads * self.head_dim, - args.dim, + n_heads * self.head_dim, + hidden_size, bias=False, - device=args.device, + device=device, ) - self.q_size = args.n_heads * self.head_dim + self.q_size = n_heads * self.head_dim self.kv_size = self.n_kv_heads * self.head_dim if self.env.qkv_fusion: self._register_load_state_dict_pre_hook(self.load_hook) self.wqkv = LinearLayer( - args.dim, - (args.n_heads + 2 * self.n_kv_heads) * self.head_dim, + hidden_size, + (n_heads + 2 * self.n_kv_heads) * self.head_dim, bias=False, - device=args.device, + device=device, ) else: self.wq = LinearLayer( - args.dim, - args.n_heads * self.head_dim, + hidden_size, + n_heads * self.head_dim, bias=False, - device=args.device, + device=device, ) self.wk = LinearLayer( - args.dim, + hidden_size, self.n_kv_heads * self.head_dim, bias=False, - device=args.device, + device=device, ) self.wv = LinearLayer( - args.dim, + hidden_size, self.n_kv_heads * self.head_dim, bias=False, - device=args.device, + device=device, ) def load_hook(self, state_dict, prefix, *args): @@ -210,9 +206,9 @@ def forward( ) else: xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) - xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim) - xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) - xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim) + xk = xk.view(bsz, seqlen, self.n_kv_heads, self.head_dim) + xv = xv.view(bsz, seqlen, self.n_kv_heads, self.head_dim) self.env.apply_sharding(xq, axis=2) self.env.apply_sharding(xk, axis=2) @@ -262,7 +258,8 @@ def forward( self.env.apply_sharding(output, axis=1) output = output.transpose(-3, -2).contiguous().view(bsz, seqlen, -1) self.env.apply_sharding(output, axis=2) - return self.wo(output) + output = self.wo(output) + return output else: with jax.named_scope("attn_insert_cache"): keys, values, k_scaler, v_scaler = cache.update(xk, xv) diff --git a/jetstream_pt/third_party/gemma/__init__.py b/jetstream_pt/third_party/gemma/__init__.py new file mode 100644 index 00000000..6d5e14bc --- /dev/null +++ b/jetstream_pt/third_party/gemma/__init__.py @@ -0,0 +1,13 @@ +# Copyright 
2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/jetstream_pt/third_party/gemma/config.py b/jetstream_pt/third_party/gemma/config.py new file mode 100644 index 00000000..5c1f00a6 --- /dev/null +++ b/jetstream_pt/third_party/gemma/config.py @@ -0,0 +1,88 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Gemma model config.""" + +import dataclasses +import torch +from typing import Optional + + +# Keep a mapping from dtype strings to the supported torch dtypes. +_STR_DTYPE_TO_TORCH_DTYPE = dict( + { + "float16": torch.float16, + "float": torch.float32, + "float32": torch.float32, + "bfloat16": torch.bfloat16, + } +) + + +@dataclasses.dataclass +class GemmaConfig: + # The number of tokens in the vocabulary. + vocab_size: int = 256000 + # The maximum sequence length that this model might ever be used with. + max_position_embeddings: int = 8192 + # The number of blocks in the model. + num_hidden_layers: int = 28 + # The number of attention heads used in the attention layers of the model. + num_attention_heads: int = 16 + # The number of key-value heads for implementing attention. + num_key_value_heads: int = 16 + # The hidden size of the model. + hidden_size: int = 3072 + # The dimension of the MLP representations. + intermediate_size: int = 24576 + # The number of head dimensions. + head_dim: int = 256 + # The epsilon used by the rms normalization layers. + rms_norm_eps: float = 1e-6 + # The dtype of the weights. + dtype: str = "bfloat16" + # Whether a quantized version of the model is used. + quant: bool = False + # The path to the model tokenizer. + tokenizer: Optional[str] = "tokenizer/tokenizer.model" + + device: str = "meta" + + def get_dtype(self) -> Optional[torch.dtype]: + """Gets the torch dtype from the config dtype string.""" + return _STR_DTYPE_TO_TORCH_DTYPE.get(self.dtype, None) + + +def get_config_for_7b() -> GemmaConfig: + return GemmaConfig() + + +def get_config_for_2b() -> GemmaConfig: + return GemmaConfig( + num_hidden_layers=18, + num_attention_heads=8, + num_key_value_heads=1, + hidden_size=2048, + intermediate_size=16384, + ) + + +def get_model_config(variant: str) -> GemmaConfig: + if variant == "7b": + return get_config_for_7b() + elif variant == "2b": + return get_config_for_2b() + return ValueError( + f'Invalid variant {variant}. 
Supported variants are "2b"' 'and "7b"' + ) diff --git a/jetstream_pt/third_party/gemma/model.py b/jetstream_pt/third_party/gemma/model.py new file mode 100644 index 00000000..0d03227f --- /dev/null +++ b/jetstream_pt/third_party/gemma/model.py @@ -0,0 +1,205 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only Gemma model implementation.""" + +import torch +from torch import nn +import torch.nn.functional as F +from typing import Any, List + +from . import config as gemma_config + +from jetstream_pt import layers +import jax + + +def precompute_freqs_cis( + dim: int, end: int, theta: float = 10000.0 +) -> torch.Tensor: + """Precomputes the frequency cis.""" + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device) + freqs = torch.outer(t, freqs).float() + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + return freqs_cis + + +class RMSNorm(torch.nn.Module): + + def __init__( + self, + dim: int, + eps: float = 1e-6, + add_unit_offset: bool = True, + device: str = "meta", + ): + super().__init__() + self.eps = eps + self.add_unit_offset = add_unit_offset + self.weight = nn.Parameter(torch.zeros(dim, device=device)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + x = self._norm(x.float()).type_as(x) + if self.add_unit_offset: + output = x * (1 + self.weight) + else: + output = x * self.weight + return output + + +class GemmaMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + device, + env, + ): + super().__init__() + Linear = ( + layers.WeightOnlyInt8Linear + if env.enable_weight_quantization + else torch.nn.Linear + ) + self.gate_proj = Linear(hidden_size, intermediate_size, device) + self.up_proj = Linear(hidden_size, intermediate_size, device) + self.down_proj = Linear(intermediate_size, hidden_size, device) + + def forward(self, x): + gate = self.gate_proj(x) + gate = F.gelu(gate, approximate="tanh") + up = self.up_proj(x) + fuse = gate * up + outputs = self.down_proj(fuse) + return outputs + + +class GemmaDecoderLayer(nn.Module): + + def __init__(self, config: gemma_config.GemmaConfig, env): + super().__init__() + self.self_attn = layers.Attention( + config.num_attention_heads, + config.num_key_value_heads, + config.head_dim, + config.hidden_size, + config.device, + env, + ) + self.mlp = GemmaMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + env=env, + device=config.device, + ) + self.input_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps, device=config.device + ) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps, device=config.device + ) + + def forward( + self, + hidden_states: torch.Tensor, + freqs_cis: torch.Tensor, + cache: Any, + mask: torch.Tensor, + ) -> torch.Tensor: + # Self Attention + residual = hidden_states + hidden_states = 
self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + hidden_states, + freqs_cis=freqs_cis, + mask=mask, + cache=cache, + ) + hidden_states = residual + hidden_states + + # MLP + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class GemmaModel(nn.Module): + + def __init__(self, config: gemma_config.GemmaConfig, env): + super().__init__() + self.config = config + self.vocab_size = config.vocab_size + self.env = env + + self.layers = nn.ModuleList() + for _ in range(config.num_hidden_layers): + self.layers.append(GemmaDecoderLayer(config, env)) + self.norm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps, device=config.device + ) + Embedding = ( + layers.Int8Embedding + if env.enable_weight_quantization + else torch.nn.Embedding + ) + + self.embedder = Embedding( + config.vocab_size, config.hidden_size, device=config.device + ) + rope_theta = getattr(config, "rope_theta", 10000) + freqs_cis = precompute_freqs_cis( + config.head_dim, config.max_position_embeddings * 2, theta=rope_theta + ) + self.register_buffer("freqs_cis", freqs_cis) + + @torch.no_grad() + def forward( + self, + tokens: torch.Tensor, + input_pos: torch.Tensor, + caches: List[Any], + mask, + ): + with jax.named_scope("transformer_freq"): + bsz, seqlen = tokens.shape + freqs_cis = self.freqs_cis[input_pos] + freqs_cis = freqs_cis.reshape(bsz, seqlen, -1) + + hidden_states = self.embedder(tokens) + hidden_states = hidden_states * (self.config.hidden_size**0.5) + + for i in range(len(self.layers)): + layer = self.layers[i] + hidden_states = layer( + hidden_states=hidden_states, + freqs_cis=freqs_cis, + cache=caches[i], + mask=mask, + ) + hidden_states = self.norm(hidden_states) + + embedder_weight = self.embedder.weight + if self.config.quant: + embedder_weight = embedder_weight * self.embedder.weight_scaler.unsqueeze( + -1 + ) + logits = torch.matmul(hidden_states, embedder_weight.t()) + return logits diff --git a/jetstream_pt/third_party/gemma/model_original.py b/jetstream_pt/third_party/gemma/model_original.py new file mode 100644 index 00000000..d1a9c47f --- /dev/null +++ b/jetstream_pt/third_party/gemma/model_original.py @@ -0,0 +1,574 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only Gemma model implementation.""" + +import torch +from torch import nn +import torch.nn.functional as F +from typing import Any, List, Optional, Sequence, Tuple, Union + +from . import config as gemma_config +from . 
import tokenizer + + +class Sampler(nn.Module): + + def __init__(self, vocab_size: int): + super().__init__() + self.vocab_size = vocab_size + + @torch.no_grad() + def forward( + self, + embedding: torch.Tensor, + hidden_states: torch.Tensor, + output_positions: torch.Tensor, + temperatures: Union[torch.Tensor, None], + top_ps: torch.Tensor, + top_ks: torch.Tensor, + embedding_bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + # Select the last element for each sequence. + # (batch_size, input_len, hidden_size) -> (batch_size, hidden_size) + hidden_states = hidden_states.index_select(1, output_positions).squeeze( + dim=1 + ) + logits = torch.matmul(hidden_states, embedding.t()) + if embedding_bias is not None: + logits += embedding_bias + + if temperatures is None: + return torch.argmax(logits, dim=-1).squeeze(dim=-1) + + # Apply temperature scaling. + logits.div_(temperatures.unsqueeze(dim=1)) + + # Calculate probabilities with softmax. + probs = torch.softmax(logits, dim=-1, dtype=torch.float) + probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) + + # Apply top-p, top-k. + probs_sum = torch.cumsum(probs_sort, dim=-1) + top_ps_mask = (probs_sum - probs_sort) > top_ps.unsqueeze(dim=1) + probs_sort = torch.where(top_ps_mask, 0, probs_sort) + + top_ks_mask = torch.arange(probs_idx.shape[-1], device=probs_idx.device) + top_ks_mask = top_ks_mask.expand(probs_idx.shape[0], -1) + top_ks_mask = top_ks_mask >= top_ks.unsqueeze(dim=1) + probs_sort = torch.where(top_ks_mask, 0, probs_sort) + + # Re-normalization. + probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) + probs = torch.gather( + probs_sort, dim=-1, index=torch.argsort(probs_idx, dim=-1) + ) + + next_token_ids = torch.multinomial( + probs, num_samples=1, replacement=True + ).squeeze(dim=-1) + return next_token_ids + + +def precompute_freqs_cis( + dim: int, end: int, theta: float = 10000.0 +) -> torch.Tensor: + """Precomputes the frequency cis.""" + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device) + freqs = torch.outer(t, freqs).float() + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + return freqs_cis + + +def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: + """Applies the rotary embedding to the query and key tensors.""" + x_ = torch.view_as_complex( + torch.stack(torch.chunk(x.transpose(1, 2).float(), 2, dim=-1), dim=-1) + ) + x_out = torch.view_as_real(x_ * freqs_cis).type_as(x) + x_out = torch.cat(torch.chunk(x_out, 2, dim=-1), dim=-2) + x_out = x_out.reshape( + x_out.shape[0], x_out.shape[1], x_out.shape[2], -1 + ).transpose(1, 2) + return x_out + + +class Linear(nn.Module): + + def __init__(self, in_features: int, out_features: int, quant: bool): + super().__init__() + if quant: + self.weight = nn.Parameter( + torch.empty((out_features, in_features), dtype=torch.int8), + requires_grad=False, + ) + self.weight_scaler = nn.Parameter(torch.Tensor(out_features)) + else: + self.weight = nn.Parameter( + torch.empty((out_features, in_features)), + requires_grad=False, + ) + self.quant = quant + + def forward(self, x): + weight = self.weight + if self.quant: + weight = weight * self.weight_scaler.unsqueeze(-1) + output = F.linear(x, weight) + return output + + +class Embedding(nn.Module): + + def __init__(self, num_embeddings: int, embedding_dim: int, quant: bool): + super().__init__() + if quant: + self.weight = nn.Parameter( + torch.empty((num_embeddings, embedding_dim), 
dtype=torch.int8), + requires_grad=False, + ) + self.weight_scaler = nn.Parameter(torch.Tensor(num_embeddings)) + else: + self.weight = nn.Parameter( + torch.empty((num_embeddings, embedding_dim)), + requires_grad=False, + ) + self.quant = quant + + def forward(self, x): + weight = self.weight + if self.quant: + weight = weight * self.weight_scaler.unsqueeze(-1) + output = F.embedding(x, weight) + return output + + +class RMSNorm(torch.nn.Module): + + def __init__( + self, + dim: int, + eps: float = 1e-6, + add_unit_offset: bool = True, + ): + super().__init__() + self.eps = eps + self.add_unit_offset = add_unit_offset + self.weight = nn.Parameter(torch.zeros(dim)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + x = self._norm(x.float()).type_as(x) + if self.add_unit_offset: + output = x * (1 + self.weight) + else: + output = x * self.weight + return output + + +class GemmaMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + quant: bool, + ): + super().__init__() + self.gate_proj = Linear(hidden_size, intermediate_size, quant) + self.up_proj = Linear(hidden_size, intermediate_size, quant) + self.down_proj = Linear(intermediate_size, hidden_size, quant) + + def forward(self, x): + gate = self.gate_proj(x) + gate = F.gelu(gate, approximate="tanh") + up = self.up_proj(x) + fuse = gate * up + outputs = self.down_proj(fuse) + return outputs + + +class GemmaAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, + quant: bool, + ): + super().__init__() + + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + self.hidden_size = hidden_size + self.head_dim = head_dim + + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + + self.scaling = self.head_dim**-0.5 + + self.qkv_proj = Linear( + self.hidden_size, + (self.num_heads + 2 * self.num_kv_heads) * self.head_dim, + quant=quant, + ) + self.o_proj = Linear( + self.num_heads * self.head_dim, self.hidden_size, quant=quant + ) + + def forward( + self, + hidden_states: torch.Tensor, + freqs_cis: torch.Tensor, + kv_write_indices: torch.Tensor, + kv_cache: Tuple[torch.Tensor, torch.Tensor], + mask: torch.Tensor, + ) -> torch.Tensor: + hidden_states_shape = hidden_states.shape + assert len(hidden_states_shape) == 3 + batch_size, input_len, _ = hidden_states_shape + qkv = self.qkv_proj(hidden_states) + xq, xk, xv = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + xq = xq.view(batch_size, -1, self.num_heads, self.head_dim) + xk = xk.view(batch_size, -1, self.num_kv_heads, self.head_dim) + xv = xv.view(batch_size, -1, self.num_kv_heads, self.head_dim) + + # Positional embedding. + xq = apply_rotary_emb(xq, freqs_cis=freqs_cis) + xk = apply_rotary_emb(xk, freqs_cis=freqs_cis) + + # Write new kv cache. 
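# The reference cache layout is (batch, max_seq_len, num_kv_heads, head_dim), so
# index_copy_ along dim 1 scatters the freshly projected keys/values into the
# sequence rows named by kv_write_indices: the prompt positions during prefill,
# a single position during decode.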
+ # [batch_size, input_len, n_local_kv_heads, head_dim] + k_cache, v_cache = kv_cache + k_cache.index_copy_(1, kv_write_indices, xk) + v_cache.index_copy_(1, kv_write_indices, xv) + + key = k_cache + value = v_cache + if self.num_kv_heads != self.num_heads: + # [batch_size, max_seq_len, n_local_heads, head_dim] + key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=2) + value = torch.repeat_interleave(value, self.num_queries_per_kv, dim=2) + + # [batch_size, n_local_heads, input_len, head_dim] + q = xq.transpose(1, 2) + # [batch_size, n_local_heads, max_seq_len, head_dim] + k = key.transpose(1, 2) + v = value.transpose(1, 2) + + # [batch_size, n_local_heads, input_len, max_seq_len] + scores = torch.matmul(q, k.transpose(2, 3)) * self.scaling + scores = scores + mask + scores = F.softmax(scores.float(), dim=-1).type_as(q) + + # [batch_size, n_local_heads, input_len, head_dim] + output = torch.matmul(scores, v) + + # [batch_size, input_len, hidden_dim] + output = output.transpose(1, 2).contiguous().view(batch_size, input_len, -1) + output = self.o_proj(output) + return output + + +class GemmaDecoderLayer(nn.Module): + + def __init__( + self, + config: gemma_config.GemmaConfig, + ): + super().__init__() + self.self_attn = GemmaAttention( + hidden_size=config.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + head_dim=config.head_dim, + quant=config.quant, + ) + self.mlp = GemmaMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + quant=config.quant, + ) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + hidden_states: torch.Tensor, + freqs_cis: torch.Tensor, + kv_write_indices: torch.Tensor, + kv_cache: Tuple[torch.Tensor, torch.Tensor], + mask: torch.Tensor, + ) -> torch.Tensor: + # Self Attention + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + hidden_states=hidden_states, + freqs_cis=freqs_cis, + kv_write_indices=kv_write_indices, + kv_cache=kv_cache, + mask=mask, + ) + hidden_states = residual + hidden_states + + # MLP + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class GemmaModel(nn.Module): + + def __init__(self, config: gemma_config.GemmaConfig): + super().__init__() + self.config = config + self.vocab_size = config.vocab_size + + self.layers = nn.ModuleList() + for _ in range(config.num_hidden_layers): + self.layers.append(GemmaDecoderLayer(config)) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + freqs_cis: torch.Tensor, + kv_write_indices: torch.Tensor, + kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], + mask: torch.Tensor, + ) -> torch.Tensor: + for i in range(len(self.layers)): + layer = self.layers[i] + hidden_states = layer( + hidden_states=hidden_states, + freqs_cis=freqs_cis, + kv_write_indices=kv_write_indices, + kv_cache=kv_caches[i], + mask=mask, + ) + hidden_states = self.norm(hidden_states) + return hidden_states + + +class GemmaForCausalLM(nn.Module): + + def __init__( + self, + config: gemma_config.GemmaConfig, + ): + super().__init__() + self.config = config + assert config.hidden_size % config.num_attention_heads == 0 + + max_seq_len 
= config.max_position_embeddings + head_dim = config.head_dim + vocab_size = config.vocab_size + + self.tokenizer = tokenizer.Tokenizer(config.tokenizer) + self.embedder = Embedding(vocab_size, config.hidden_size, config.quant) + self.model = GemmaModel(config) + self.sampler = Sampler(vocab_size) + + # Pre-compute rotary embedding table. + rope_theta = getattr(config, "rope_theta", 10000) + freqs_cis = precompute_freqs_cis( + head_dim, max_seq_len * 2, theta=rope_theta + ) + self.register_buffer("freqs_cis", freqs_cis) + + @torch.no_grad() + def forward( + self, + input_token_ids: torch.Tensor, + input_positions: torch.Tensor, + kv_write_indices: torch.Tensor, + kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], + mask: torch.Tensor, + output_positions: torch.Tensor, + temperatures: Union[torch.Tensor, None], + top_ps: torch.Tensor, + top_ks: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + freqs_cis = self.freqs_cis.index_select(0, input_positions) + kv_write_indices = input_positions + + # [batch_size, input_len, hidden_size] + hidden_states = self.embedder(input_token_ids) + # Gemma normalizes the embedding by sqrt(hidden_size). + hidden_states = hidden_states * (self.config.hidden_size**0.5) + + hidden_states = self.model( + hidden_states=hidden_states, + freqs_cis=freqs_cis, + kv_write_indices=kv_write_indices, + kv_caches=kv_caches, + mask=mask, + ) + embedder_weight = self.embedder.weight + if self.config.quant: + embedder_weight = embedder_weight * self.embedder.weight_scaler.unsqueeze( + -1 + ) + next_tokens = self.sampler( + embedding=embedder_weight, + hidden_states=hidden_states, + output_positions=output_positions, + temperatures=temperatures, + top_ps=top_ps, + top_ks=top_ks, + ) + return next_tokens + + def generate( + self, + prompts: Union[str, Sequence[str]], + device: Any, + output_len: int = 100, + temperature: Union[float, None] = 0.95, + top_p: float = 1.0, + top_k: int = 100, + ) -> Union[str, Sequence[str]]: + """Generates responses for given prompts using Gemma model.""" + # If a single prompt is provided, treat it as a batch of 1. 
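# Overall flow of generate(): normalize the input to a batch of prompts,
# allocate one (k, v) cache of shape (batch, max_seq_len, num_kv_heads, head_dim)
# per decoder layer, run one prefill pass over the shortest prompt, then emit
# one token per step until max_prompt_len + output_len positions are filled.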
+ is_str_prompt = isinstance(prompts, str) + if is_str_prompt: + prompts = [prompts] + + batch_size = len(prompts) + prompt_tokens = [self.tokenizer.encode(prompt) for prompt in prompts] + min_prompt_len = min(len(p) for p in prompt_tokens) + max_prompt_len = max(len(p) for p in prompt_tokens) + max_seq_len = max_prompt_len + output_len + assert max_seq_len <= self.config.max_position_embeddings + + # build KV caches + kv_caches = [] + for _ in range(self.config.num_hidden_layers): + size = ( + batch_size, + max_seq_len, + self.config.num_key_value_heads, + self.config.head_dim, + ) + dtype = self.config.get_dtype() + k_cache = torch.zeros(size=size, dtype=dtype, device=device) + v_cache = torch.zeros(size=size, dtype=dtype, device=device) + kv_caches.append((k_cache, v_cache)) + + # prepare inputs + token_ids_tensor = torch.full( + (batch_size, max_seq_len), self.tokenizer.pad_id, dtype=torch.int64 + ) + input_token_ids_tensor = torch.full( + (batch_size, min_prompt_len), self.tokenizer.pad_id, dtype=torch.int64 + ) + for i, p in enumerate(prompt_tokens): + token_ids_tensor[i, : len(p)] = torch.tensor(p) + input_token_ids_tensor[i, :min_prompt_len] = torch.tensor( + p[:min_prompt_len] + ) + token_ids_tensor = token_ids_tensor.to(device) + input_token_ids_tensor = input_token_ids_tensor.to(device) + prompt_mask_tensor = token_ids_tensor != self.tokenizer.pad_id + input_positions_tensor = torch.arange( + 0, min_prompt_len, dtype=torch.int64 + ).to(device) + mask_tensor = torch.full( + (1, 1, max_seq_len, max_seq_len), -2.3819763e38 + ).to(torch.float) + mask_tensor = torch.triu(mask_tensor, diagonal=1).to(device) + curr_mask_tensor = mask_tensor.index_select(2, input_positions_tensor) + output_positions_tensor = torch.LongTensor([min_prompt_len - 1]).to(device) + temperatures_tensor = ( + None + if not temperature + else torch.FloatTensor([temperature] * batch_size).to(device) + ) + top_ps_tensor = torch.FloatTensor([top_p] * batch_size).to(device) + top_ks_tensor = torch.LongTensor([top_k] * batch_size).to(device) + output_index = torch.tensor(min_prompt_len, dtype=torch.int64).to(device) + + # Prefill up to min_prompt_len tokens, then treat other prefill as + # decode and ignore output. + for i in range(max_seq_len - min_prompt_len): + next_token_ids = self( + input_token_ids=input_token_ids_tensor, + input_positions=input_positions_tensor, + kv_write_indices=None, + kv_caches=kv_caches, + mask=curr_mask_tensor, + output_positions=output_positions_tensor, + temperatures=temperatures_tensor, + top_ps=top_ps_tensor, + top_ks=top_ks_tensor, + ) + + curr_prompt_mask = prompt_mask_tensor.index_select( + 1, output_index + ).squeeze(dim=1) + curr_token_ids = token_ids_tensor.index_select(1, output_index).squeeze( + dim=1 + ) + output_token_ids = torch.where( + curr_prompt_mask, curr_token_ids, next_token_ids + ).unsqueeze(dim=1) + token_ids_tensor.index_copy_(1, output_index, output_token_ids) + + input_token_ids_tensor = output_token_ids + input_positions_tensor = output_index.unsqueeze(dim=-1) + curr_mask_tensor = mask_tensor.index_select(2, input_positions_tensor) + output_positions_tensor = torch.tensor(0, dtype=torch.int64).to(device) + output_index = output_index + 1 + + # Detokenization. 
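# Trim each sequence to the generated region after its prompt (at most
# output_len tokens), cut at the first EOS if present, then decode back to text.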
+ token_ids = token_ids_tensor.tolist() + results = [] + for i, tokens in enumerate(token_ids): + trimmed_output = tokens[ + len(prompt_tokens[i]) : len(prompt_tokens[i]) + output_len + ] + if self.tokenizer.eos_id in trimmed_output: + eos_index = trimmed_output.index(self.tokenizer.eos_id) + trimmed_output = trimmed_output[:eos_index] + results.append(self.tokenizer.decode(trimmed_output)) + + # If a string was provided as input, return a string as output. + return results[0] if is_str_prompt else results + + def load_weights(self, model_path: str): + self.load_state_dict( + torch.load( + model_path, + mmap=True, + weights_only=True, + )["model_state_dict"], + strict=False, + ) diff --git a/jetstream_pt/third_party/gemma/tokenizer.py b/jetstream_pt/third_party/gemma/tokenizer.py new file mode 100644 index 00000000..38591ead --- /dev/null +++ b/jetstream_pt/third_party/gemma/tokenizer.py @@ -0,0 +1,46 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from typing import List, Optional + +from sentencepiece import SentencePieceProcessor + + +class Tokenizer: + + def __init__(self, model_path: Optional[str]): + # Reload tokenizer. + assert os.path.isfile(model_path), model_path + self.sp_model = SentencePieceProcessor(model_file=model_path) + + # BOS / EOS token IDs. + self.n_words: int = self.sp_model.vocab_size() + self.bos_id: int = self.sp_model.bos_id() + self.eos_id: int = self.sp_model.eos_id() + self.pad_id: int = self.sp_model.pad_id() + assert self.sp_model.vocab_size() == self.sp_model.get_piece_size() + + def encode(self, s: str, bos: bool = True, eos: bool = False) -> List[int]: + """Converts a string into a list of tokens.""" + assert isinstance(s, str) + t = self.sp_model.encode(s) + if bos: + t = [self.bos_id] + t + if eos: + t = t + [self.eos_id] + return t + + def decode(self, t: List[int]) -> str: + """Converts a list of tokens into a string.""" + return self.sp_model.decode(t) diff --git a/jetstream_pt/third_party/llama/generation_original.py b/jetstream_pt/third_party/llama/generation_original.py index dd4339ee..6f1af3ee 100644 --- a/jetstream_pt/third_party/llama/generation_original.py +++ b/jetstream_pt/third_party/llama/generation_original.py @@ -2,7 +2,7 @@ # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 
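For reference, the new Gemma Tokenizer added above is a thin SentencePiece wrapper; a minimal usage sketch follows (the model path is an illustrative placeholder for any Gemma SentencePiece model file):

from jetstream_pt.third_party.gemma.tokenizer import Tokenizer

tok = Tokenizer("/path/to/gemma/tokenizer.model")  # illustrative path
ids = tok.encode("Hello Gemma", bos=True, eos=False)  # [tok.bos_id, ...token ids]
text = tok.decode(ids)  # back to a string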
import os -from typing import List, Literal, Optional, Tuple, TypedDict +from typing import List, Literal, Optional, TypedDict import torch from jetstream_pt.third_party.llama import model_original diff --git a/jetstream_pt/third_party/llama/model_args.py b/jetstream_pt/third_party/llama/model_args.py index b9143384..bb04371b 100755 --- a/jetstream_pt/third_party/llama/model_args.py +++ b/jetstream_pt/third_party/llama/model_args.py @@ -2,12 +2,7 @@ """The original Llama2 model.""" import dataclasses -import math -from typing import Optional, Tuple - -import torch -from torch import nn -import torch.nn.functional as F +from typing import Optional @dataclasses.dataclass diff --git a/jetstream_pt/third_party/llama/model_exportable.py b/jetstream_pt/third_party/llama/model_exportable.py index 106e1f0b..07559a8f 100644 --- a/jetstream_pt/third_party/llama/model_exportable.py +++ b/jetstream_pt/third_party/llama/model_exportable.py @@ -75,7 +75,14 @@ def __init__( self.dim = args.dim self.head_dim = args.dim // args.n_heads - self.attention = Attention(args, env) + self.attention = Attention( + args.n_heads, + args.n_kv_heads or args.n_heads, + args.dim // args.n_heads, + args.dim, + env=env, + device=args.device, + ) self.feed_forward = FeedForward( dim=args.dim, hidden_dim=4 * args.dim, diff --git a/run_interactive.py b/run_interactive.py index f338beb0..0dd16bb9 100644 --- a/run_interactive.py +++ b/run_interactive.py @@ -39,6 +39,9 @@ "The tokenizer model path", required=False, ) +_MODEL_NAME = flags.DEFINE_string( + "model_name", None, "model type", required=False +) _CKPT_PATH = flags.DEFINE_string( "checkpoint_path", None, "Directory for .pth checkpoints", required=False ) @@ -74,6 +77,9 @@ "llama-2", "name of the model. Supported options are llama-2 and llama-3", ) +_SHARDING_CONFIG = flags.DEFINE_string( + "sharding_config", "", "config file for sharding" +) def create_engine(): @@ -84,6 +90,7 @@ def create_engine(): devices = jax.devices() start = time.perf_counter() engine = je.create_pytorch_engine( + model_name=_MODEL_NAME.value, devices=devices, tokenizer_path=_TOKENIZER_PATH.value, ckpt_path=_CKPT_PATH.value, @@ -95,6 +102,7 @@ def create_engine(): quantize_weights=_QUANTIZE_WEIGHTS.value, quantize_kv=_QUANTIZE_KV_CACHE.value, max_cache_length=_MAX_CACHE_LENGTH.value, + sharding_config=_SHARDING_CONFIG.value, ) print("Initialize engine", time.perf_counter() - start) diff --git a/run_server.py b/run_server.py index 77d7f173..c3603fba 100644 --- a/run_server.py +++ b/run_server.py @@ -86,6 +86,12 @@ _MAX_CACHE_LENGTH = flags.DEFINE_integer( "max_cache_length", 1024, "kv_cache_quantize" ) +_SHARDING_CONFIG = flags.DEFINE_string( + "sharding_config", "", "config file for sharding" +) +_MODEL_NAME = flags.DEFINE_string( + "model_name", "llama-2", "model name, defaults to llama-2" +) # pylint: disable-next=all @@ -95,6 +101,11 @@ def main(argv: Sequence[str]): # No devices for local cpu test. A None for prefill and a None for generate. 
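# When --sharding_config is left empty, the server falls back to the bundled
# default under default_shardings/<model_name>.yaml (see the block below).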
devices = server_lib.get_devices() print(f"devices: {devices}") + sharding_config_path = _SHARDING_CONFIG.value + if not sharding_config_path: + sharding_config_path = os.path.join( + "default_shardings", _MODEL_NAME.value + ".yaml" + ) engine = jetstream_pt.create_pytorch_engine( devices=devices, tokenizer_path=_TOKENIZER_PATH.value, @@ -107,6 +118,8 @@ def main(argv: Sequence[str]): quantize_weights=_QUANTIZE_WEIGHTS.value, quantize_kv=_QUANTIZE_KV_CACHE.value, max_cache_length=_MAX_CACHE_LENGTH.value, + sharding_config=sharding_config_path, + model_name=_MODEL_NAME.value, ) server_config = ServerConfig( interleaved_slices=(_PLATFORM.value,), diff --git a/tests/jax_experiments.py b/scripts/jax_experiments.py similarity index 80% rename from tests/jax_experiments.py rename to scripts/jax_experiments.py index e11a317c..3d38094a 100644 --- a/tests/jax_experiments.py +++ b/scripts/jax_experiments.py @@ -331,4 +331,70 @@ def test6(): print(x[:, :, 0:1, :]) -test6() +# pylint: disable-next=all +def test7(): + """insert cache test""" + batch, seq, heads, dim = 96, 2048, 40, 128 + sharding = PositionalSharding(mesh_utils.create_device_mesh((8,))) + sharding = sharding.reshape((1, 8, 1, 1)) + val_sharding = sharding.reshape((1, 8, 1, 1)) + caches_k = jnp.zeros( + (batch, heads, seq, dim), device=sharding, dtype=jnp.bfloat16 + ) + jnp.zeros((batch, heads, seq, dim), device=sharding, dtype=jnp.bfloat16) + + def insert_cache(cache, index, new_entry): + res = cache.at[:, :, index, :].set(new_entry) + res = jax.lax.with_sharding_constraint(res, sharding) + return res + + # pylint: disable-next=all + def insert_cache2(cache, index, new_entry): + res = cache.at[jnp.arange(batch), :, index, :].set(new_entry) + res = jax.lax.with_sharding_constraint(res, sharding) + return res + + insert_cache = jax.jit(insert_cache, donate_argnums=(0, 1)) + insert_cache2 = jax.jit(insert_cache2, donate_argnums=(0, 1)) + insert_seqlen = 1024 + + subkey = jax.random.PRNGKey(234) + to_insert = jax.device_put( + jax.random.normal( + subkey, (1, heads, insert_seqlen, dim), dtype=jnp.bfloat16 + ), + device=val_sharding, + ).block_until_ready() + # pylint: disable-next=all + j = jnp.int32(7).block_until_ready() + + update_indexes = (jnp.arange(-insert_seqlen, 0) + 7) % 1024 + head_indexes = jnp.arange(heads).reshape(1, -1, 1) + + rng = jax.random.PRNGKey(0) + + jax.profiler.start_trace("/tmp/insert_trace") + for func in (insert_cache, insert_cache2): + for _ in range(10): + all_times = 0 + for j in range(40): + rng, subkey = jax.random.split(rng) + val = jax.device_put( + jax.random.normal(subkey, (batch, heads, dim), dtype=jnp.bfloat16), + device=sharding.reshape((1, 8, 1)), + ).block_until_ready() + # pylint: disable-next=all + j = jnp.int32(j).block_until_ready() + if func == insert_cache2: + j = jnp.broadcast_to(j, (batch,)).block_until_ready() + start = time.perf_counter() + # pylint: disable-next=all + caches_k = func(caches_k, j, val) + caches_k.block_until_ready() + end = time.perf_counter() + all_times += end - start + print(func.__name__, "time is", all_times) + jax.profiler.stop_trace() + + +test7() diff --git a/tests/helpers.py b/tests/helpers.py index dd0c7c50..32b8652c 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -1,3 +1,4 @@ +import jax import torch import jax from jetstream_pt.third_party.llama import model_args diff --git a/tests/test_engine.py b/tests/test_engine.py index a24e0d8e..286e9b31 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -13,15 +13,7 @@ # limitations under the 
License. # pylint: disable=all -from typing import List, Any -import unittest -import torch -import torch_xla2 -import jax.numpy as jnp - -from jetstream_pt.environment import JetEngineEnvironment, JetEngineEnvironmentData -from jetstream_pt.engine import PyTorchEngine, Prefix, DecodeState -from jetstream_pt.third_party.llama import model_exportable, model_original + # This model will output tokens with value of 2 # and will update caches with value of 1.0 diff --git a/tests/test_llama_e2e.py b/tests/test_llama_e2e.py index 10e698d6..c4eb32b4 100644 --- a/tests/test_llama_e2e.py +++ b/tests/test_llama_e2e.py @@ -34,10 +34,6 @@ class LlamaE2ETest(unittest.TestCase): """This test class includes all E2E test for llama2""" - def setup(self): - """setup torch env""" - torch.set_default_dtype(torch.bfloat16) - def _to_jax(self, tree): return pytree.tree_map_only(torch.Tensor, torch_xla2.tensor.t2j, tree) @@ -225,6 +221,7 @@ def test_llama_e2e_float32(self): print(f"---------> {jax.devices()}") env, model_arg = helpers.make_env_tiny(bf16_enable=False) + torch.set_default_dtype(torch.float32) out_tokens, expected_output_tokens = self._llama_e2e(env, model_arg) self.assertEqual(out_tokens, expected_output_tokens) @@ -235,6 +232,7 @@ def test_llama_e2e_bfloat16(self): print(f"---------> {jax.devices()}") env, model_arg = helpers.make_env_tiny(bf16_enable=True) + torch.set_default_dtype(torch.bfloat16) out_tokens, expected_output_tokens = self._llama_e2e(env, model_arg) self.assertNotEqual(out_tokens, expected_output_tokens) @@ -388,8 +386,8 @@ def test_llama_with_original_prefill_decode_32(self): jax.config.update("jax_platform_name", "cpu") print(f"---------> {jax.devices()}") - torch.set_default_dtype(torch.float32) env, model_arg = helpers.make_env_tiny(bf16_enable=False) + torch.set_default_dtype(torch.float32) # pylint: disable-next=all tokens = np.arange(10, dtype=np.int32) true_length = tokens.shape[-1] @@ -464,7 +462,7 @@ def test_llama_with_original_prefill_decode(self): jax.config.update("jax_platform_name", "cpu") print(f"---------> {jax.devices()}") - torch.set_default_dtype(torch.float32) + torch.set_default_dtype(torch.bfloat16) env, model_arg = helpers.make_env_tiny() # pylint: disable-next=all tokens = np.arange(10, dtype=np.int32) diff --git a/tests/test_model_impl.py b/tests/test_model_impl.py index e8c0f375..1a280563 100644 --- a/tests/test_model_impl.py +++ b/tests/test_model_impl.py @@ -13,6 +13,7 @@ # limitations under the License. 
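The insert_cache benchmark moved to scripts/jax_experiments.py above times two ways of writing one decode step into a (batch, heads, seq, dim) cache. A minimal, CPU-sized sketch of the two update styles it compares is below; the shapes are small illustrative values, and the benchmark's sharding constraints, jit, and donate_argnums are omitted.

import jax.numpy as jnp

batch, heads, seq, dim = 2, 4, 8, 16
cache = jnp.zeros((batch, heads, seq, dim), dtype=jnp.bfloat16)
new = jnp.ones((batch, heads, dim), dtype=jnp.bfloat16)

# insert_cache: every batch row writes the same sequence position.
pos = jnp.int32(3)
updated = cache.at[:, :, pos, :].set(new)

# insert_cache2: each batch row writes its own position via advanced indexing.
positions = jnp.arange(batch) % seq  # shape (batch,)
updated2 = cache.at[jnp.arange(batch), :, positions, :].set(new)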
import unittest +import random import jax import jax.numpy as jnp import torch @@ -22,8 +23,8 @@ from jetstream_pt.third_party.llama import model_exportable from jetstream_pt.third_party.llama import model_original +from jetstream_pt.third_party.gemma import model_original as gemma_orig from jetstream_pt import layers -from jetstream_pt import environment from jetstream_pt import cache_manager @@ -101,7 +102,14 @@ def test_attention(self): env, model_arg = helpers.make_env_tiny(False) attention_orig = model_original.Attention(model_arg) - attention_ours = layers.Attention(model_arg, env) + attention_ours = layers.Attention( + n_heads=model_arg.n_heads, + n_kv_heads=model_arg.n_kv_heads, + head_dim=model_arg.dim // model_arg.n_heads, + hidden_size=model_arg.dim, + device="cpu", + env=env, + ) seqlen = 32 batch = 1 @@ -169,6 +177,91 @@ def test_attention(self): ) self.assertTrue(torch.allclose(result_torch, expected_out, atol=1e-4)) + def test_gemma_attention(self): + env, model_arg = helpers.make_env_tiny(False) + + hidden_size = model_arg.dim + num_heads = model_arg.n_heads + num_kv_heads = model_arg.n_kv_heads + head_dim = model_arg.dim // model_arg.n_heads + # env._data.qkv_fusion = True + + def init_weights(model): + state_dict = model.state_dict() + res = {} + for k, v in state_dict.items(): + x = random.randint(1, 10) + res[k] = torch.ones(v.shape) * x + model.load_state_dict(res, assign=True) + + attention_orig = gemma_orig.GemmaAttention( + hidden_size=hidden_size, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + quant=False, + ) + init_weights(attention_orig) + + attention_ours = layers.Attention( + n_heads=num_heads, + n_kv_heads=num_kv_heads, + head_dim=head_dim, + hidden_size=hidden_size, + device="meta", + env=env, + ) + + def load_hook(state_dict, prefix, *args): + wo = state_dict.pop(prefix + "o_proj.weight") + state_dict[prefix + "wo.weight"] = wo + qkv = state_dict.pop(prefix + "qkv_proj.weight") + q, k, v = qkv.split( + [ + attention_orig.q_size, + attention_orig.kv_size, + attention_orig.kv_size, + ], + dim=0, + ) + state_dict[prefix + "wq.weight"] = q + state_dict[prefix + "wk.weight"] = k + state_dict[prefix + "wv.weight"] = v + + seqlen = 32 + batch = 1 + x = torch.randn( + (batch, seqlen, hidden_size) + ) # (batch, seqlen, embedding dim) + start_pos = 0 + freqs_cis = self._make_freqs_cis(model_arg, seqlen, start_pos) + mask = self._prefill_mask(seqlen, start_pos) + kv_write_indexes = torch.arange(0, seqlen) + cache_k = torch.zeros((batch, seqlen, num_heads, head_dim)) + cache_v = torch.zeros((batch, seqlen, num_heads, head_dim)) + inputs_orig = (x, freqs_cis, kv_write_indexes, (cache_k, cache_v), mask) + + expected_out = attention_orig(*inputs_orig) + + cache = cache_manager.KVCachePrefill() + freqs_cis = freqs_cis.reshape(batch, seqlen, -1) + input_ours = ( + x, + freqs_cis, + mask, + cache, + ) + + state_dict = dict(attention_orig.state_dict()) + load_hook(state_dict, "") + result_torch = self._call_xla_model(attention_ours, state_dict, input_ours) + + print( + "Single Gemma Attention: Diff norm", + (result_torch - expected_out).norm(), + ) + self.assertTrue(torch.allclose(result_torch, expected_out, atol=1e-4)) + # pylint: disable-next=all def test_transformer_block(self): env, model_arg = helpers.make_env_tiny(False)
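Taken together, these changes let the engine be constructed for Gemma with an explicit sharding file. A minimal sketch of the new call surface follows; the checkpoint and tokenizer paths are illustrative placeholders.

import jax
from jetstream_pt import create_pytorch_engine

engine = create_pytorch_engine(
    devices=jax.devices(),
    model_name="gemma",
    param_size="2b",
    tokenizer_path="/path/to/gemma/tokenizer.model",  # illustrative
    ckpt_path="/path/to/gemma/checkpoint",            # illustrative
    bf16_enable=True,
    context_length=1024,
    batch_size=1,
    max_cache_length=1024,
    quantize_weights=False,
    quantize_kv=False,
    sharding_config="default_shardings/gemma.yaml",
)

run_server.py wires the same arguments from its flags and, when --sharding_config is empty, falls back to default_shardings/<model_name>.yaml.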