From 2a8e71a3a297dfcd7cb81c6930a9803c6d4dda65 Mon Sep 17 00:00:00 2001
From: Lance Wang
Date: Wed, 19 Jun 2024 23:27:10 +0000
Subject: [PATCH 1/2] Add test for Mixtral model.

---
 jetstream_pt/third_party/mixtral/config.py    | 11 ++++
 .../third_party/mixtral/model_original.py     | 23 ++++---
 tests/helpers.py                              | 35 ++++++++--
 tests/test_model_impl.py                      | 66 +++++++++++++++++++
 4 files changed, 122 insertions(+), 13 deletions(-)

diff --git a/jetstream_pt/third_party/mixtral/config.py b/jetstream_pt/third_party/mixtral/config.py
index cf6ab3d1..8966d737 100644
--- a/jetstream_pt/third_party/mixtral/config.py
+++ b/jetstream_pt/third_party/mixtral/config.py
@@ -75,4 +75,15 @@ def from_name(cls, name: str):
         num_experts=8,
         num_activated_experts=2,
     ),
+    "Mixtral-tiny": dict(
+        block_size=128,
+        n_layer=3,
+        n_head=32,
+        n_local_heads=8,
+        dim=128,
+        intermediate_size=None,
+        rope_base=1000000.0,
+        num_experts=8,
+        num_activated_experts=2,
+    ),
 }
diff --git a/jetstream_pt/third_party/mixtral/model_original.py b/jetstream_pt/third_party/mixtral/model_original.py
index 5087d35a..add9ad83 100644
--- a/jetstream_pt/third_party/mixtral/model_original.py
+++ b/jetstream_pt/third_party/mixtral/model_original.py
@@ -23,6 +23,12 @@
 from .config import ModelArgs, find_multiple
 
 
+def find_multiple(n: int, k: int) -> int:
+  if n % k == 0:
+    return n
+  return n + k - (n % k)
+
+
 class KVCache(nn.Module):
 
   def __init__(
@@ -31,7 +37,8 @@ def __init__(
       self,
       max_batch_size,
       max_seq_length,
       n_heads,
       head_dim,
-      dtype=torch.bfloat16,
+      # dtype=torch.bfloat16,
+      dtype=torch.float32,
   ):
     super().__init__()
     cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
@@ -191,22 +198,21 @@ class ConditionalFeedForward(nn.Module):
 
   def __init__(self, config):
     super().__init__()
+    # Replace the weight init of torch.empty with torch.rand for testing purpose
     self.w1 = nn.Parameter(
-        torch.empty(config.num_experts, config.intermediate_size, config.dim)
+        torch.rand(config.num_experts, config.intermediate_size, config.dim)
     )
     self.w2 = nn.Parameter(
-        torch.empty(config.num_experts, config.dim, config.intermediate_size)
+        torch.rand(config.num_experts, config.dim, config.intermediate_size)
     )
     self.w3 = nn.Parameter(
-        torch.empty(config.num_experts, config.intermediate_size, config.dim)
+        torch.rand(config.num_experts, config.intermediate_size, config.dim)
     )
 
   def forward(self, x: Tensor, expert_indices: Tensor) -> Tensor:
-    # T = num_tokens, I = intermediate size, D = hidden dim, A = activated experts
     w1_weights = self.w1[expert_indices]  # [T, A, D, D]
     w3_weights = self.w3[expert_indices]  # [T, A, D, D]
     w2_weights = self.w2[expert_indices]  # [T, A, D, D]
-    # x: [T, D]
     x1 = F.silu(torch.einsum("ti,taoi -> tao", x, w1_weights))
     x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
     expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
@@ -215,7 +221,7 @@ def forward(self, x: Tensor, expert_indices: Tensor) -> Tensor:
 
 class MOEFeedForward(nn.Module):
 
-  def __init__(self, config, env=None) -> None:
+  def __init__(self, config) -> None:
     super().__init__()
     self.gate = nn.Linear(config.dim, config.num_experts, bias=False)
     self.cond_ffn = ConditionalFeedForward(config)
@@ -261,7 +267,8 @@ def precompute_freqs_cis(
   freqs = torch.outer(t, freqs)
   freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
   cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
-  return cache.to(dtype=torch.bfloat16)
+  # return cache.to(dtype=torch.bfloat16)
+  return cache
 
 
 def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
diff --git a/tests/helpers.py b/tests/helpers.py
index 5886e7bc..82ace510 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -1,8 +1,8 @@
 import jax
 import torch
 import torch_xla2
-import jax
 from jetstream_pt.third_party.llama import model_args
+from jetstream_pt.third_party.mixtral import config as mixtral_config
 from jetstream_pt import environment
 
 
@@ -31,13 +31,38 @@ def make_env_tiny(bf16_enable=True):
   return env, config
 
 
+def make_mixtral_env(bf16_enable=True):
+  torch_dtype = torch.bfloat16 if bf16_enable else torch.float32
+  torch.set_default_dtype(torch_dtype)
+  jax.config.update("jax_dynamic_shapes", False)
+  jax.config.update("jax_traceback_filtering", "off")
+  config = mixtral_config.ModelArgs.from_name("Mixtral-tiny")
+  environment_data = environment.JetEngineEnvironmentData()
+  environment_data.max_input_sequence_length = 128
+  environment_data.cache_sequence_length = 128
+  environment_data.bf16_enable = bf16_enable
+  environment_data.model_type = "mixtral"
+  environment_data.batch_size = 1
+  environment_data.num_layers = config.n_layer
+  environment_data.cache_shape = (
+      1,
+      config.n_local_heads,
+      environment_data.cache_sequence_length,
+      config.dim // config.n_head,
+  )
+  env = environment.JetEngineEnvironment(environment_data)
+  env.apply_sharding = lambda *args, **kwargs: None  # don't shard on cpu
+  return env, config
+
+
 def to_xla_tensor(tree):
   return torch_xla2.default_env().to_xla(tree)
 
 
 def call_xla_model(model, weights, args):
   with jax.default_device(jax.devices("cpu")[0]):
-    xla_weights, xla_inputs = to_xla_tensor((weights, args))
-    result = torch.func.functional_call(model, xla_weights, xla_inputs)
-    result_torch = torch_xla2.tensor.j2t(result._elem)
-    return result_torch
+    with torch_xla2.default_env():
+      xla_weights, xla_inputs = to_xla_tensor((weights, args))
+      result = torch.func.functional_call(model, xla_weights, xla_inputs)
+      result_torch = torch_xla2.tensor.j2t(result._elem)
+      return result_torch
diff --git a/tests/test_model_impl.py b/tests/test_model_impl.py
index 65ac8913..c072d9d0 100644
--- a/tests/test_model_impl.py
+++ b/tests/test_model_impl.py
@@ -23,6 +23,7 @@
 from jetstream_pt.third_party.llama import model_original
 from jetstream_pt.third_party.gemma import model_original as gemma_orig
 from jetstream_pt.third_party.gemma import model as gemma
+from jetstream_pt.third_party.mixtral import model_original as mixtral_orig
 from jetstream_pt.third_party.mixtral import model as mixtral
 from jetstream_pt.third_party.mixtral import config as mixtral_config
 from jetstream_pt import torchjax
@@ -362,6 +363,71 @@ def test_transformer(self):
     print("Transformer: Diff norm", (result_torch - expected_out).norm())
     self.assertTrue(torch.allclose(result_torch, expected_out, atol=1e-4))
 
+  # pylint: disable-next=all
+  def test_mixtral_transformer(self):
+    """test transformer diff between original model vs xla_model"""
+    env, model_arg = helpers.make_mixtral_env(False)
+
+    model_orig = mixtral_orig.Transformer(model_arg)
+    model_orig.setup_caches(max_batch_size=1, max_seq_length=env.cache_len)
+
+    state_dict = dict(model_orig.state_dict())
+    state_dict["freqs_cis"] = model_orig.freqs_cis
+    new_dict = {}
+
+    for k, v in state_dict.items():
+      if "kv_cache" in k:
+        continue
+      if "wqkv" in k:
+        wq = k.replace("wqkv", "wq")
+        wk = k.replace("wqkv", "wk")
+        wv = k.replace("wqkv", "wv")
+        kv_size = model_arg.n_local_heads * model_arg.head_dim
+        wq_t, wk_t, wv_t = v.split([model_arg.dim, kv_size, kv_size], dim=0)
+
+        new_dict[wq] = wq_t
+        new_dict[wk] = wk_t
+        new_dict[wv] = wv_t
+        continue
+      # "Freqs_cis" for exported model is calculated differently, by complex data type
+      if "freqs_cis" in k:
+        new_dict[k] = mixtral.precompute_freqs_cis(
+            model_arg.block_size,
+            model_arg.dim // model_arg.n_head,
+            model_arg.rope_base,
+        )
+        continue
+      new_dict[k] = v
+
+    import pdb
+
+    pdb.set_trace()
+    model_ours = mixtral.Transformer(model_arg, env)
+
+    # Invoke original model
+    seqlen = 32
+    x = torch.randint(0, 32000, (1, seqlen))  # (batch, seqlen, embedding dim)
+    start_pos = 0
+    mask = self._prefill_mask(seqlen, start_pos)
+    input_pos = torch.arange(0, seqlen)
+    inputs_orig = (x, input_pos)
+
+    expected_out = model_orig(*inputs_orig)
+
+    # Invoke the exported model
+    caches = env.make_caches_prefill()
+
+    input_ours = (
+        x,
+        input_pos,
+        caches,
+        mask,
+    )
+    result_torch = helpers.call_xla_model(model_ours, new_dict, input_ours)
+
+    print("Transformer: Diff norm", (result_torch - expected_out).norm())
+    self.assertTrue(torch.allclose(result_torch, expected_out, atol=1e-4))
+
   def test_mixtral_moe(self):
     config = mixtral_config.ModelArgs()
     config.intermediate_size = 16

From c16432d4d579a9a92858b18521f34422dcf3d86b Mon Sep 17 00:00:00 2001
From: Lance Wang
Date: Thu, 20 Jun 2024 17:31:10 +0000
Subject: [PATCH 2/2] Fix per comments.

---
 tests/helpers.py         | 6 +++---
 tests/test_model_impl.py | 3 ---
 2 files changed, 3 insertions(+), 6 deletions(-)

diff --git a/tests/helpers.py b/tests/helpers.py
index 82ace510..00442517 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -61,8 +61,8 @@ def to_xla_tensor(tree):
 
 def call_xla_model(model, weights, args):
   with jax.default_device(jax.devices("cpu")[0]):
+    xla_weights, xla_inputs = to_xla_tensor((weights, args))
     with torch_xla2.default_env():
-      xla_weights, xla_inputs = to_xla_tensor((weights, args))
       result = torch.func.functional_call(model, xla_weights, xla_inputs)
-      result_torch = torch_xla2.tensor.j2t(result._elem)
-      return result_torch
+    result_torch = torch_xla2.tensor.j2t(result._elem)
+    return result_torch
diff --git a/tests/test_model_impl.py b/tests/test_model_impl.py
index c072d9d0..b0dfc151 100644
--- a/tests/test_model_impl.py
+++ b/tests/test_model_impl.py
@@ -399,9 +399,6 @@ def test_mixtral_transformer(self):
         continue
       new_dict[k] = v
 
-    import pdb
-
-    pdb.set_trace()
     model_ours = mixtral.Transformer(model_arg, env)
 
     # Invoke original model