Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions default_shardings/gemma.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# "replicated" to signify "replicated".
# Integer signify axis to shard: 0 <= shard axis < rank

freqs_cis : null # torch.complex64 (16384, 128)
freqs_cis : -1 # torch.complex64 (16384, 128)
layers.*.self_attn.wo.weight : 1 # 1, -1] # torch.float32 (2048, 2048)
layers.*.self_attn.wq.weight : 0 # -1, 1] # torch.float32 (2048, 2048)
layers.*.self_attn.wk.weight : 0 # -1, 1] # torch.float32 (256, 2048)
Expand All @@ -13,8 +13,8 @@ layers.*.mlp.gate_proj.bias : 0 # -1] # torch.float32 (16384,)
layers.*.mlp.up_proj.weight : 0 # -1, 1] # torch.float32 (16384, 2048)
layers.*.mlp.up_proj.bias : 0 # -1] # torch.float32 (16384,)
layers.*.mlp.down_proj.weight : 1 # 1, -1] # torch.float32 (2048, 16384)
layers.*.mlp.down_proj.bias : null # torch.float32 (2048,)
layers.*.input_layernorm.weight : null # torch.float32 (2048,)
layers.*.post_attention_layernorm.weight : null # torch.float32 (2048,)
norm.weight : null # torch.float32 (2048,)
layers.*.mlp.down_proj.bias : -1 # torch.float32 (2048,)
layers.*.input_layernorm.weight : -1 # torch.float32 (2048,)
layers.*.post_attention_layernorm.weight : -1 # torch.float32 (2048,)
norm.weight : -1 # torch.float32 (2048,)
embedder.weight : 1 # # 1, -1] # torch.float32 (256000, 2048)
19 changes: 6 additions & 13 deletions tests/test_llama_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@
import torch
import torch_xla2
from torch.utils import _pytree as pytree
from . import helpers


from jetstream_pt.engine import PyTorchEngine
from jetstream_pt.third_party.llama import model_exportable, model_args
from jetstream_pt.third_party.llama.generation_original import LlamaOriginal
from jetstream_pt import environment
from tests import helpers


class LlamaE2ETest(unittest.TestCase):
Expand Down Expand Up @@ -93,9 +93,8 @@ def test_jetstream_llama2_seed(self):
jax.config.update("jax_platform_name", "cpu")
print(f"---------> {jax.devices()}")

torch.set_default_dtype(torch.bfloat16)
# pylint: disable-next=all
env, model_arg = helpers.make_env_tiny()
env, model_arg = helpers.make_env_tiny(bf16_enable=True)
# pylint: disable-next=all
tokens = np.arange(10, dtype=np.int32)
true_length = tokens.shape[-1]
Expand Down Expand Up @@ -221,7 +220,6 @@ def test_llama_e2e_float32(self):
print(f"---------> {jax.devices()}")

env, model_arg = helpers.make_env_tiny(bf16_enable=False)
torch.set_default_dtype(torch.float32)
out_tokens, expected_output_tokens = self._llama_e2e(env, model_arg)
self.assertEqual(out_tokens, expected_output_tokens)

Expand All @@ -232,7 +230,6 @@ def test_llama_e2e_bfloat16(self):
print(f"---------> {jax.devices()}")

env, model_arg = helpers.make_env_tiny(bf16_enable=True)
torch.set_default_dtype(torch.bfloat16)
out_tokens, expected_output_tokens = self._llama_e2e(env, model_arg)
self.assertNotEqual(out_tokens, expected_output_tokens)

Expand All @@ -242,9 +239,8 @@ def test_llama_e2e_two_addtional_tokens(self):
jax.config.update("jax_platform_name", "cpu")
print(f"---------> {jax.devices()}")

torch.set_default_dtype(torch.bfloat16)
# pylint: disable-next=all
env, model_arg = helpers.make_env_tiny()
env, model_arg = helpers.make_env_tiny(bf16_enable=True)
# pylint: disable-next=all
tokens = np.arange(10, dtype=np.int32)
tokens = np.append(tokens, [15050, 3503], axis=-1)
Expand Down Expand Up @@ -315,9 +311,8 @@ def test_llama_e2e_four_addtional_tokens(self):
jax.config.update("jax_platform_name", "cpu")
print(f"---------> {jax.devices()}")

torch.set_default_dtype(torch.bfloat16)
# pylint: disable-next=all
env, model_arg = helpers.make_env_tiny()
env, model_arg = helpers.make_env_tiny(bf16_enable=True)
# pylint: disable-next=all
tokens = np.arange(10, dtype=np.int32)
tokens = np.append(tokens, [15050, 3503, 11833, 28551], axis=-1)
Expand Down Expand Up @@ -387,7 +382,6 @@ def test_llama_with_original_prefill_decode_32(self):
print(f"---------> {jax.devices()}")

env, model_arg = helpers.make_env_tiny(bf16_enable=False)
torch.set_default_dtype(torch.float32)
# pylint: disable-next=all
tokens = np.arange(10, dtype=np.int32)
true_length = tokens.shape[-1]
Expand Down Expand Up @@ -458,12 +452,11 @@ def test_llama_with_original_prefill_decode_32(self):

# pylint: disable-next=all
def test_llama_with_original_prefill_decode(self):
"""test jetstream llama by comparing original prefill and decode steps with float32"""
"""test jetstream llama by comparing original prefill and decode steps with bf16"""
jax.config.update("jax_platform_name", "cpu")
print(f"---------> {jax.devices()}")

torch.set_default_dtype(torch.bfloat16)
env, model_arg = helpers.make_env_tiny()
env, model_arg = helpers.make_env_tiny(bf16_enable=True)
# pylint: disable-next=all
tokens = np.arange(10, dtype=np.int32)
true_length = tokens.shape[-1]
Expand Down