20 changes: 20 additions & 0 deletions default_shardings/gemma.yaml
@@ -0,0 +1,20 @@

# Sharding config for gemma
# "replicated" to signify "replicated".
# Integer signify axis to shard: 0 <= shard axis < rank

freqs_cis : null # torch.complex64 (16384, 128)
Collaborator: Can we keep the replicated sharding consistent? If either null or -1 is fine, shall we just keep -1 throughout the code base (gemma uses null, but llama uses -1)?

layers.*.self_attn.wo.weight : 1 # 1, -1] # torch.float32 (2048, 2048)
layers.*.self_attn.wq.weight : 0 # -1, 1] # torch.float32 (2048, 2048)
layers.*.self_attn.wk.weight : 0 # -1, 1] # torch.float32 (256, 2048)
layers.*.self_attn.wv.weight : 0 # -1, 1] # torch.float32 (256, 2048)
layers.*.mlp.gate_proj.weight : 0 # -1, 1] # torch.float32 (16384, 2048)
layers.*.mlp.gate_proj.bias : 0 # -1] # torch.float32 (16384,)
layers.*.mlp.up_proj.weight : 0 # -1, 1] # torch.float32 (16384, 2048)
layers.*.mlp.up_proj.bias : 0 # -1] # torch.float32 (16384,)
layers.*.mlp.down_proj.weight : 1 # 1, -1] # torch.float32 (2048, 16384)
layers.*.mlp.down_proj.bias : null # torch.float32 (2048,)
layers.*.input_layernorm.weight : null # torch.float32 (2048,)
layers.*.post_attention_layernorm.weight : null # torch.float32 (2048,)
norm.weight : null # torch.float32 (2048,)
embedder.weight : 1 # # 1, -1] # torch.float32 (256000, 2048)
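
To make the convention above concrete, here is a minimal sketch (editor-added, not part of this PR) of how a config entry could be turned into a JAX sharding. It assumes a 1-D device mesh with a single axis name ("x"), similar to the mesh JetEngineEnvironment builds; the helper name sharding_for is hypothetical, standing in for the repo's sharding_by_axis.

import jax
import jax.sharding as jsharding
from jax.experimental import mesh_utils
from jax.sharding import PartitionSpec as P

# 1-D mesh over all local devices; the axis name "x" is an assumption.
mesh = jsharding.Mesh(
    mesh_utils.create_device_mesh((jax.device_count(),)), axis_names=("x",)
)

def sharding_for(axis, rank):
  """null / -1 means replicated; 0 <= axis < rank shards that axis over "x"."""
  if axis is None or axis == -1:
    return jsharding.NamedSharding(mesh, P())  # fully replicated
  spec = [None] * rank
  spec[axis] = "x"  # partition this dimension across devices
  return jsharding.NamedSharding(mesh, P(*spec))

# e.g. "layers.*.self_attn.wq.weight : 0" shards dim 0 of the (2048, 2048)
# weight across devices, while "freqs_cis : null" keeps it replicated.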
23 changes: 23 additions & 0 deletions default_shardings/llama-2.yaml
@@ -0,0 +1,23 @@

# Sharding config for llama-2
# Sharding should be either an int between 0 and rank - 1,
# signifying the axis to shard, or -1 / null, signifying replicated.


freqs_cis : -1 # torch.complex64 (2048, 64)
tok_embeddings.weight : 1 # torch.float32 (32000, 4096)
layers.*.attention.wo.weight : 1 # torch.int8 (4096, 4096)
layers.*.attention.wo.weight_scaler : 0 # torch.bfloat16 (4096,)
layers.*.attention.wq.weight : 0 # torch.int8 (4096, 4096)
layers.*.attention.wq.weight_scaler : 0 # torch.bfloat16 (4096,)
layers.*.attention.wk.weight : 0 # torch.int8 (4096, 4096)
layers.*.attention.wk.weight_scaler : 0 # torch.bfloat16 (4096,)
layers.*.attention.wv.weight : 0 # torch.int8 (4096, 4096)
layers.*.attention.wv.weight_scaler : 0 # torch.bfloat16 (4096,)
layers.*.feed_forward.w1.weight : 0 # torch.float32 (11008, 4096)
layers.*.feed_forward.w2.weight : 1 # torch.float32 (4096, 11008)
layers.*.feed_forward.w3.weight : 0 # torch.float32 (11008, 4096)
layers.*.attention_norm.weight : -1 # torch.float32 (4096,)
layers.*.ffn_norm.weight : -1 # torch.float32 (4096,)
norm.weight : -1 # torch.float32 (4096,)
output.weight : 0 # torch.float32 (32000, 4096)
2 changes: 2 additions & 0 deletions jetstream_pt/__init__.py
@@ -13,3 +13,5 @@
# limitations under the License.

from jetstream_pt.engine import create_pytorch_engine

__all__ = ["create_pytorch_engine"]
102 changes: 47 additions & 55 deletions jetstream_pt/engine.py
@@ -34,6 +34,7 @@
from jetstream_pt import quantize
from jetstream_pt.environment import JetEngineEnvironment, JetEngineEnvironmentData
from jetstream_pt.third_party.llama import model_exportable, model_args
from jetstream_pt.third_party.gemma import config as gemma_config, model as gemma_model


Mesh = jax.sharding.Mesh
@@ -108,32 +109,6 @@ def __init__(
# out_shardings=self.get_decode_state_sharding())
self._lock = threading.RLock()

# pylint: disable-next=all
def sharding_by_name(self, name):

# This allows easier way to edit shardings
"""
for key, val in self.env._data.experimental_sharding_axis_override.items():
if name.endswith(key):
return self.env.sharding_by_axis(val)
"""

if "weight_scaler" in name:
return self.x_sharding
if "tok_embeddings." in name:
return self.y_sharding
if "attention." in name:
if "wo" in name:
return self.y_sharding
return self.x_sharding
if "feed_forward." in name:
if "w2" in name:
return self.y_sharding
return self.x_sharding
if "output" in name:
return self.x_sharding
return self.replicated

# pylint: disable-next=all
def init_decode_state(
self,
@@ -561,7 +536,7 @@ def _load_from_safetensors(self, path):
for key, model_weights in self.pt_model.state_dict().items():
if key == "freqs_cis":
continue
arr = jax.device_put(f.get_tensor(key), self.sharding_by_name(key))
arr = jax.device_put(f.get_tensor(key), self.env.sharding_by_name(key))
assert tuple(model_weights.shape) == tuple(
arr.shape
), f"key: {key} error: {model_weights.shape} != {arr.shape}"
@@ -587,7 +562,7 @@ def load_params(self) -> Params:
else:
jax_weights = self._make_state_dict_jax(self.pt_model.state_dict())
jax_weights = {
key: jax.device_put(value, self.sharding_by_name(key))
key: jax.device_put(value, self.env.sharding_by_name(key))
for key, value in jax_weights.items()
}
for k, v in jax_weights.items():
@@ -664,6 +639,7 @@ def create_pytorch_engine(
quantize_weights=False,
quantize_kv=False,
max_cache_length=1024,
sharding_config=None,
) -> PyTorchEngine:
"""Returns: The pytorch engine."""

@@ -706,42 +682,58 @@ def create_pytorch_engine(
tokenizer = token_utils.load_vocab(tokenizer_path)
pt_model = None

env_data = JetEngineEnvironmentData(
tokenizer_path=tokenizer_path,
checkpoint_path=checkpoint_path,
checkpoint_format=checkpoint_format,
batch_size=batch_size,
max_decode_length=max_decode_length,
max_input_sequence_length=context_length,
enable_weight_quantization=quantize_weights,
enable_kv_quantization=quantize_kv,
cache_sequence_length=max_cache_length,
bf16_enable=bf16_enable,
sharding_config_path=sharding_config,
)

if model_name.startswith("llama"):

args = model_args.get_model_args(
model_name + "-" + param_size, context_length, batch_size, bf16_enable
)
args.device = "meta"
args.quantize = quantize_weights
env_data = JetEngineEnvironmentData(
tokenizer_path=tokenizer_path,
checkpoint_path=checkpoint_path,
checkpoint_format=checkpoint_format,
model_type="llama-2-" + param_size,
batch_size=batch_size,
max_decode_length=max_decode_length,
max_input_sequence_length=context_length,
enable_weight_quantization=quantize_weights,
enable_kv_quantization=quantize_kv,
cache_sequence_length=max_cache_length,
bf16_enable=bf16_enable,
num_layers=args.n_layers,
cache_shape=(
batch_size,
args.n_kv_heads,
max_cache_length,
args.dim // args.n_heads,
),
env_data.cache_shape = (
batch_size,
args.n_kv_heads,
max_cache_length,
args.dim // args.n_heads,
)
env_data.model_type = "llama-2-" + param_size
env_data.num_layers = args.n_layers
env = JetEngineEnvironment(env_data)
pt_model = model_exportable.Transformer(args, env)

num_params_size = 0
num_params = 0
for _, v in pt_model.state_dict().items():
num_params += 1
num_params_size += np.prod(v.shape) * (1 if v.dtype == torch.int8 else 2)
print("Number of param Gbytes:", num_params_size / (1 << 30))
print("Number of param: ", num_params)
elif model_name == "gemma":
args = gemma_config.get_model_config(param_size)
env_data.cache_shape = (
batch_size,
args.num_key_value_heads,
max_cache_length,
args.head_dim,
)
env_data.model_type = "gemma-" + param_size
env_data.num_layers = args.num_hidden_layers
env = JetEngineEnvironment(env_data)
pt_model = gemma_model.GemmaModel(args, env)
else:
raise RuntimeError(f"Model with name {model_name} not found")

num_params_size = 0
num_params = 0
for _, v in pt_model.state_dict().items():
num_params += 1
num_params_size += np.prod(v.shape) * (1 if v.dtype == torch.int8 else 2)
print("Number of param Gbytes:", num_params_size / (1 << 30))
print("Number of param: ", num_params)

return PyTorchEngine(pt_model=pt_model, env=env)
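
As a usage note (editor-added, not from the PR): the new sharding_config argument threads a YAML path into JetEngineEnvironmentData, so callers can point the engine at one of the files added above. A hedged sketch follows — the paths are placeholders, and any arguments not visible in this diff may be named differently or have other defaults.

from jetstream_pt import create_pytorch_engine

# Illustrative only: paths are placeholders; arguments not shown in this
# diff may be required or named differently.
engine = create_pytorch_engine(
    model_name="gemma",
    param_size="2b",
    tokenizer_path="/path/to/tokenizer.model",
    checkpoint_path="/path/to/checkpoint",
    batch_size=32,
    max_cache_length=1024,
    sharding_config="default_shardings/gemma.yaml",  # new in this PR
)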
45 changes: 45 additions & 0 deletions jetstream_pt/environment.py
@@ -15,6 +15,8 @@
from typing import Tuple, Dict

import dataclasses
import yaml

import jax
import jax.sharding as jsharding
from jax.experimental import mesh_utils
@@ -71,6 +73,8 @@ class JetEngineEnvironmentData:
# If True, use bfloat16 as dtype. If False, use float32 as dtype
bf16_enable: bool = True

sharding_config_path: str = ""


# pylint: disable-next=all
class JetEngineEnvironment:
@@ -100,6 +104,15 @@ def __init__(self, data: JetEngineEnvironmentData):
self.cache_sharding = jsharding.NamedSharding(
self._mesh, P(*cache_sharding)
)
self._load_sharding_config()

def _load_sharding_config(self):
"""Load sharding config"""
if self._data.sharding_config_path:
with open(self._data.sharding_config_path, encoding="utf-8") as f:
self._sharding_config = yaml.safe_load(f)
else:
self._sharding_config = {}

def __getattr__(self, name):
return getattr(self._data, name)
@@ -150,3 +163,35 @@ def make_caches_generate(self):
)
)
return caches

def sharding_by_name(self, name):
Collaborator: Great, it's clearer than the previous hardcoded one.
"""Create sharding specified in the config."""
if name in self._sharding_config:
return self.sharding_by_axis(self._sharding_config[name])

name = process_sharding_name(name)
if name in self._sharding_config:
return self.sharding_by_axis(self._sharding_config[name])

raise RuntimeError("Sharding for name: ", name, " not specified")


def process_sharding_name(name):
"""Replace integers in param name with *.

Presumably all layers should have the same sharding.
"""

def is_integer(t):
try:
int(t)
return True
# pylint: disable-next=all
except: # noqa: E722
return False

tokens = name.split(".")
for i, t in enumerate(tokens):
if is_integer(t):
tokens[i] = "*"
return ".".join(tokens)
59 changes: 28 additions & 31 deletions jetstream_pt/layers.py
@@ -132,57 +132,53 @@ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
class Attention(nn.Module):
"""Attention module."""

def __init__(self, args, env):
def __init__(self, n_heads, n_kv_heads, head_dim, hidden_size, device, env):
Collaborator: Great, it's nice to see the layers decoupled from args.
super().__init__()

self.n_kv_heads = (
args.n_heads if args.n_kv_heads is None else args.n_kv_heads
)
self.n_local_heads = args.n_heads
self.n_local_kv_heads = self.n_kv_heads
self.n_rep = self.n_local_heads // self.n_local_kv_heads
self.head_dim = args.dim // args.n_heads
self.max_seq_len = args.max_seq_len
self.n_heads = args.n_heads

self.n_heads = n_heads
self.n_kv_heads = n_kv_heads
self.head_dim = head_dim
self.n_rep = self.n_heads // self.n_kv_heads
self.env = env
self.hidden_size = hidden_size

LinearLayer = WeightOnlyInt8Linear if args.quantize else nn.Linear
LinearLayer = (
WeightOnlyInt8Linear if env.enable_weight_quantization else nn.Linear
)

self.wo = LinearLayer(
args.n_heads * self.head_dim,
args.dim,
n_heads * self.head_dim,
hidden_size,
bias=False,
device=args.device,
device=device,
)
self.q_size = args.n_heads * self.head_dim
self.q_size = n_heads * self.head_dim
self.kv_size = self.n_kv_heads * self.head_dim
if self.env.qkv_fusion:
self._register_load_state_dict_pre_hook(self.load_hook)
self.wqkv = LinearLayer(
args.dim,
(args.n_heads + 2 * self.n_kv_heads) * self.head_dim,
hidden_size,
(n_heads + 2 * self.n_kv_heads) * self.head_dim,
bias=False,
device=args.device,
device=device,
)
else:
self.wq = LinearLayer(
args.dim,
args.n_heads * self.head_dim,
hidden_size,
n_heads * self.head_dim,
bias=False,
device=args.device,
device=device,
)
self.wk = LinearLayer(
args.dim,
hidden_size,
self.n_kv_heads * self.head_dim,
bias=False,
device=args.device,
device=device,
)
self.wv = LinearLayer(
args.dim,
hidden_size,
self.n_kv_heads * self.head_dim,
bias=False,
device=args.device,
device=device,
)

def load_hook(self, state_dict, prefix, *args):
@@ -210,9 +206,9 @@ def forward(
)
else:
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim)
xk = xk.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
xv = xv.view(bsz, seqlen, self.n_kv_heads, self.head_dim)

self.env.apply_sharding(xq, axis=2)
self.env.apply_sharding(xk, axis=2)
@@ -262,7 +258,8 @@ def forward(
self.env.apply_sharding(output, axis=1)
output = output.transpose(-3, -2).contiguous().view(bsz, seqlen, -1)
self.env.apply_sharding(output, axis=2)
return self.wo(output)
output = self.wo(output)
return output
else:
with jax.named_scope("attn_insert_cache"):
keys, values, k_scaler, v_scaler = cache.update(xk, xv)
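
One consequence of the refactor (editor-added sketch, not from the PR): Attention can now be built directly from plain dimensions, without a llama-specific args object. The SimpleNamespace below stands in for a real JetEngineEnvironment and carries only the two attributes the constructor reads in this diff; the shapes are llama-2-7B-like and purely illustrative.

from types import SimpleNamespace

from jetstream_pt.layers import Attention

# Minimal stand-in for JetEngineEnvironment: only the attributes that
# Attention.__init__ reads in this diff.
env = SimpleNamespace(enable_weight_quantization=False, qkv_fusion=True)

# Llama-2-7B-like dimensions; "meta" avoids allocating real weight memory.
attn = Attention(
    n_heads=32,
    n_kv_heads=32,
    head_dim=128,
    hidden_size=4096,
    device="meta",
    env=env,
)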
13 changes: 13 additions & 0 deletions jetstream_pt/third_party/gemma/__init__.py
@@ -0,0 +1,13 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.