diff --git a/default_shardings/gemma.yaml b/default_shardings/gemma.yaml
index da57d36e..4beda7c4 100644
--- a/default_shardings/gemma.yaml
+++ b/default_shardings/gemma.yaml
@@ -3,7 +3,7 @@
 # "replicated" to signify "replicated".
 # Integer signify axis to shard: 0 <= shard axis < rank
 
-freqs_cis : null # torch.complex64 (16384, 128)
+freqs_cis : -1 # torch.complex64 (16384, 128)
 layers.*.self_attn.wo.weight : 1 # 1, -1] # torch.float32 (2048, 2048)
 layers.*.self_attn.wq.weight : 0 # -1, 1] # torch.float32 (2048, 2048)
 layers.*.self_attn.wk.weight : 0 # -1, 1] # torch.float32 (256, 2048)
@@ -13,8 +13,8 @@ layers.*.mlp.gate_proj.bias : 0 # -1] # torch.float32 (16384,)
 layers.*.mlp.up_proj.weight : 0 # -1, 1] # torch.float32 (16384, 2048)
 layers.*.mlp.up_proj.bias : 0 # -1] # torch.float32 (16384,)
 layers.*.mlp.down_proj.weight : 1 # 1, -1] # torch.float32 (2048, 16384)
-layers.*.mlp.down_proj.bias : null # torch.float32 (2048,)
-layers.*.input_layernorm.weight : null # torch.float32 (2048,)
-layers.*.post_attention_layernorm.weight : null # torch.float32 (2048,)
-norm.weight : null # torch.float32 (2048,)
+layers.*.mlp.down_proj.bias : -1 # torch.float32 (2048,)
+layers.*.input_layernorm.weight : -1 # torch.float32 (2048,)
+layers.*.post_attention_layernorm.weight : -1 # torch.float32 (2048,)
+norm.weight : -1 # torch.float32 (2048,)
 embedder.weight : 1 # # 1, -1] # torch.float32 (256000, 2048)
diff --git a/tests/test_llama_e2e.py b/tests/test_llama_e2e.py
index c4eb32b4..0b27b28b 100644
--- a/tests/test_llama_e2e.py
+++ b/tests/test_llama_e2e.py
@@ -22,13 +22,13 @@ import torch
 import torch_xla2
 from torch.utils import _pytree as pytree
 
-from . import helpers
 
 
 from jetstream_pt.engine import PyTorchEngine
 from jetstream_pt.third_party.llama import model_exportable, model_args
 from jetstream_pt.third_party.llama.generation_original import LlamaOriginal
 from jetstream_pt import environment
+from tests import helpers
 
 
 class LlamaE2ETest(unittest.TestCase):
@@ -93,9 +93,8 @@ def test_jetstream_llama2_seed(self):
     jax.config.update("jax_platform_name", "cpu")
     print(f"---------> {jax.devices()}")
 
-    torch.set_default_dtype(torch.bfloat16)
     # pylint: disable-next=all
-    env, model_arg = helpers.make_env_tiny()
+    env, model_arg = helpers.make_env_tiny(bf16_enable=True)
     # pylint: disable-next=all
     tokens = np.arange(10, dtype=np.int32)
     true_length = tokens.shape[-1]
@@ -221,7 +220,6 @@ def test_llama_e2e_float32(self):
     print(f"---------> {jax.devices()}")
 
     env, model_arg = helpers.make_env_tiny(bf16_enable=False)
-    torch.set_default_dtype(torch.float32)
     out_tokens, expected_output_tokens = self._llama_e2e(env, model_arg)
     self.assertEqual(out_tokens, expected_output_tokens)
 
@@ -232,7 +230,6 @@ def test_llama_e2e_bfloat16(self):
     print(f"---------> {jax.devices()}")
 
     env, model_arg = helpers.make_env_tiny(bf16_enable=True)
-    torch.set_default_dtype(torch.bfloat16)
     out_tokens, expected_output_tokens = self._llama_e2e(env, model_arg)
     self.assertNotEqual(out_tokens, expected_output_tokens)
 
@@ -242,9 +239,8 @@ def test_llama_e2e_two_addtional_tokens(self):
     jax.config.update("jax_platform_name", "cpu")
     print(f"---------> {jax.devices()}")
 
-    torch.set_default_dtype(torch.bfloat16)
     # pylint: disable-next=all
-    env, model_arg = helpers.make_env_tiny()
+    env, model_arg = helpers.make_env_tiny(bf16_enable=True)
     # pylint: disable-next=all
     tokens = np.arange(10, dtype=np.int32)
     tokens = np.append(tokens, [15050, 3503], axis=-1)
@@ -315,9 +311,8 @@ def test_llama_e2e_four_addtional_tokens(self):
     jax.config.update("jax_platform_name", "cpu")
     print(f"---------> {jax.devices()}")
 
-    torch.set_default_dtype(torch.bfloat16)
     # pylint: disable-next=all
-    env, model_arg = helpers.make_env_tiny()
+    env, model_arg = helpers.make_env_tiny(bf16_enable=True)
     # pylint: disable-next=all
     tokens = np.arange(10, dtype=np.int32)
     tokens = np.append(tokens, [15050, 3503, 11833, 28551], axis=-1)
@@ -387,7 +382,6 @@ def test_llama_with_original_prefill_decode_32(self):
     print(f"---------> {jax.devices()}")
 
     env, model_arg = helpers.make_env_tiny(bf16_enable=False)
-    torch.set_default_dtype(torch.float32)
     # pylint: disable-next=all
     tokens = np.arange(10, dtype=np.int32)
     true_length = tokens.shape[-1]
@@ -458,12 +452,11 @@
 
   # pylint: disable-next=all
   def test_llama_with_original_prefill_decode(self):
-    """test jetstream llama by comparing original prefill and decode steps with float32"""
+    """test jetstream llama by comparing original prefill and decode steps with bf16"""
     jax.config.update("jax_platform_name", "cpu")
     print(f"---------> {jax.devices()}")
 
-    torch.set_default_dtype(torch.bfloat16)
-    env, model_arg = helpers.make_env_tiny()
+    env, model_arg = helpers.make_env_tiny(bf16_enable=True)
     # pylint: disable-next=all
     tokens = np.arange(10, dtype=np.int32)
     true_length = tokens.shape[-1]
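Per the header comment in `gemma.yaml`, an integer value names the tensor axis to shard (`0 <= shard axis < rank`), and this change standardizes on `-1` rather than `null` to mark replicated tensors. Below is a minimal sketch of how such entries could translate into JAX shardings; the helper `spec_for_axis` and the mesh axis name `"x"` are illustrative assumptions, not jetstream-pt's actual API:

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def spec_for_axis(axis, rank):
  """Map a yaml sharding value to a PartitionSpec.

  Assumption: -1 (like the old null) means "replicated"; any
  0 <= axis < rank shards that dimension across mesh axis "x".
  """
  if axis == -1:
    return PartitionSpec()  # no named axis: fully replicated
  assert 0 <= axis < rank
  dims = [None] * rank
  dims[axis] = "x"  # shard only the requested dimension
  return PartitionSpec(*dims)


mesh = Mesh(np.asarray(jax.devices()), axis_names=("x",))
# layers.*.self_attn.wq.weight : 0 -> shard dim 0 of the (2048, 2048) weight
wq_sharding = NamedSharding(mesh, spec_for_axis(0, rank=2))
# norm.weight : -1 -> replicate the (2048,) layernorm weight
norm_sharding = NamedSharding(mesh, spec_for_axis(-1, rank=1))
```

On the test side, dropping the `torch.set_default_dtype(...)` calls in favor of `helpers.make_env_tiny(bf16_enable=...)` keeps precision selection in one place: each test now configures its dtype through the environment helper alone, rather than through a global torch default that could silently disagree with the environment's bf16 flag.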