jetstream_pt/third_party/mixtral/config.py (11 additions, 0 deletions)

@@ -75,4 +75,15 @@ def from_name(cls, name: str):
         num_experts=8,
         num_activated_experts=2,
     ),
+    "Mixtral-tiny": dict(
+        block_size=128,
+        n_layer=3,
+        n_head=32,
+        n_local_heads=8,
+        dim=128,
+        intermediate_size=None,
+        rope_base=1000000.0,
+        num_experts=8,
+        num_activated_experts=2,
+    ),
 }
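For orientation, the new entry is resolved by name like the existing configs. Since intermediate_size is left as None, it is presumably derived post-init the way gpt-fast does it (scale 4*dim by 2/3, then round up with find_multiple); the sketch below assumes that convention, which this diff does not show:

from jetstream_pt.third_party.mixtral import config as mixtral_config

args = mixtral_config.ModelArgs.from_name("Mixtral-tiny")
print(args.dim // args.n_head)  # head_dim: 128 // 32 == 4
# Assumed gpt-fast-style derivation: find_multiple(int(2 * 4 * 128 / 3), 256) == 512
print(args.intermediate_size)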
jetstream_pt/third_party/mixtral/model_original.py (15 additions, 8 deletions)

@@ -23,6 +23,12 @@
 from .config import ModelArgs, find_multiple

[Collaborator review comment on the function added below: "Is this function actually used?"]

+def find_multiple(n: int, k: int) -> int:
+  if n % k == 0:
+    return n
+  return n + k - (n % k)
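(Annotation: find_multiple rounds n up to the nearest multiple of k. The same name is already imported from .config at the top of this file, so this added definition shadows that import, which is likely what the review comment above is probing. A quick check of the arithmetic:)

assert find_multiple(512, 256) == 512  # already a multiple: returned unchanged
assert find_multiple(341, 256) == 512  # 341 % 256 == 85, and 341 + 256 - 85 == 512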


 class KVCache(nn.Module):

   def __init__(
@@ -31,7 +37,8 @@ def __init__(
       max_seq_length,
       n_heads,
       head_dim,
-      dtype=torch.bfloat16,
+      # dtype=torch.bfloat16,
+      dtype=torch.float32,
   ):
     super().__init__()
     cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
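The float32 default matters for the comparisons later in this PR: the tests assert torch.allclose(..., atol=1e-4), and bfloat16, with its 8-bit significand (steps of about 0.008 near 1.0), cannot represent differences that small. A standalone illustration:

import torch

x = torch.tensor(1.0001, dtype=torch.bfloat16)
print(x.item() == 1.0)  # True: 1.0001 rounds to 1.0 in bfloat16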
@@ -191,22 +198,21 @@ class ConditionalFeedForward(nn.Module):

   def __init__(self, config):
     super().__init__()
+    # Replace the weight init of torch.empty with torch.rand for testing purpose
     self.w1 = nn.Parameter(
-        torch.empty(config.num_experts, config.intermediate_size, config.dim)
+        torch.rand(config.num_experts, config.intermediate_size, config.dim)
     )
     self.w2 = nn.Parameter(
-        torch.empty(config.num_experts, config.dim, config.intermediate_size)
+        torch.rand(config.num_experts, config.dim, config.intermediate_size)
     )
     self.w3 = nn.Parameter(
-        torch.empty(config.num_experts, config.intermediate_size, config.dim)
+        torch.rand(config.num_experts, config.intermediate_size, config.dim)
     )

   def forward(self, x: Tensor, expert_indices: Tensor) -> Tensor:
+    # T = num_tokens, I = intermediate size, D = hidden dim, A = activated experts
     w1_weights = self.w1[expert_indices]  # [T, A, D, D]
     w3_weights = self.w3[expert_indices]  # [T, A, D, D]
     w2_weights = self.w2[expert_indices]  # [T, A, D, D]
+    # x: [T, D]
     x1 = F.silu(torch.einsum("ti,taoi -> tao", x, w1_weights))
     x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
     expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
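For readers tracing the new shape comments: with toy sizes (not taken from this PR), the gather-then-einsum pattern checks out as below. The gathered w1/w3 slices are really [T, A, I, D] and the w2 slices [T, A, D, I]; the output is [T, A, D]:

import torch

T, A, E, D, I = 5, 2, 8, 16, 32  # tokens, activated experts, experts, dim, intermediate
x = torch.rand(T, D)
w1 = torch.rand(E, I, D)
w2 = torch.rand(E, D, I)
expert_indices = torch.randint(0, E, (T, A))

x1 = torch.einsum("ti,taoi -> tao", x, w1[expert_indices])  # [T, A, I]
out = torch.einsum("tao,taio -> tai", x1, w2[expert_indices])  # [T, A, D]
assert out.shape == (T, A, D)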
@@ -215,7 +221,7 @@ def forward(self, x: Tensor, expert_indices: Tensor) -> Tensor:

 class MOEFeedForward(nn.Module):

-  def __init__(self, config, env=None) -> None:
+  def __init__(self, config) -> None:
     super().__init__()
     self.gate = nn.Linear(config.dim, config.num_experts, bias=False)
     self.cond_ffn = ConditionalFeedForward(config)
@@ -261,7 +267,8 @@ def precompute_freqs_cis(
   freqs = torch.outer(t, freqs)
   freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
   cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
-  return cache.to(dtype=torch.bfloat16)
+  # return cache.to(dtype=torch.bfloat16)
+  return cache


 def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
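The stacked cache is just [cos, sin] pairs (torch.polar builds cos(f) + i*sin(f)), so dropping the bfloat16 cast changes only precision, not layout. A quick sanity check:

import torch

freqs = torch.rand(4, 2)
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
assert torch.allclose(cache[..., 0], torch.cos(freqs))
assert torch.allclose(cache[..., 1], torch.sin(freqs))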
tests/helpers.py (27 additions, 2 deletions)

@@ -1,8 +1,8 @@
+import jax
 import torch
 import torch_xla2
-import jax
 from jetstream_pt.third_party.llama import model_args
+from jetstream_pt.third_party.mixtral import config as mixtral_config
 from jetstream_pt import environment


@@ -31,13 +31,38 @@ def make_env_tiny(bf16_enable=True):
   return env, config


+def make_mixtral_env(bf16_enable=True):
+  torch_dtype = torch.bfloat16 if bf16_enable else torch.float32
+  torch.set_default_dtype(torch_dtype)
+  jax.config.update("jax_dynamic_shapes", False)
+  jax.config.update("jax_traceback_filtering", "off")
+  config = mixtral_config.ModelArgs.from_name("Mixtral-tiny")
+  environment_data = environment.JetEngineEnvironmentData()
+  environment_data.max_input_sequence_length = 128
+  environment_data.cache_sequence_length = 128
+  environment_data.bf16_enable = bf16_enable
+  environment_data.model_type = "mixtral"
+  environment_data.batch_size = 1
+  environment_data.num_layers = config.n_layer
+  environment_data.cache_shape = (
+      1,
+      config.n_local_heads,
+      environment_data.cache_sequence_length,
+      config.dim // config.n_head,
+  )
+  env = environment.JetEngineEnvironment(environment_data)
+  env.apply_sharding = lambda *args, **kwargs: None  # don't shard on cpu
+  return env, config


 def to_xla_tensor(tree):
   return torch_xla2.default_env().to_xla(tree)


 def call_xla_model(model, weights, args):
   with jax.default_device(jax.devices("cpu")[0]):
     xla_weights, xla_inputs = to_xla_tensor((weights, args))
-    result = torch.func.functional_call(model, xla_weights, xla_inputs)
+    with torch_xla2.default_env():
+      result = torch.func.functional_call(model, xla_weights, xla_inputs)
     result_torch = torch_xla2.tensor.j2t(result._elem)
   return result_torch
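A minimal sketch of how these helpers compose, mirroring the pattern of the tests in this PR (hypothetical usage, run from the tests/ directory; it assumes ModelArgs resolves intermediate_size when left as None):

import torch
import helpers
from jetstream_pt.third_party.mixtral import model_original as mixtral_orig

env, config = helpers.make_mixtral_env(bf16_enable=False)
ffn = mixtral_orig.ConditionalFeedForward(config)
x = torch.rand(4, config.dim)
idx = torch.randint(0, config.num_experts, (4, config.num_activated_experts))
expected = ffn(x, idx)  # eager CPU reference
result = helpers.call_xla_model(ffn, dict(ffn.state_dict()), (x, idx))
assert torch.allclose(result, expected, atol=1e-4)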
tests/test_model_impl.py (63 additions, 0 deletions)

@@ -23,6 +23,7 @@
 from jetstream_pt.third_party.llama import model_original
 from jetstream_pt.third_party.gemma import model_original as gemma_orig
 from jetstream_pt.third_party.gemma import model as gemma
+from jetstream_pt.third_party.mixtral import model_original as mixtral_orig
 from jetstream_pt.third_party.mixtral import model as mixtral
 from jetstream_pt.third_party.mixtral import config as mixtral_config
 from jetstream_pt import torchjax
@@ -362,6 +363,68 @@ def test_transformer(self):
     print("Transformer: Diff norm", (result_torch - expected_out).norm())
     self.assertTrue(torch.allclose(result_torch, expected_out, atol=1e-4))

+  # pylint: disable-next=all
+  def test_mixtral_transformer(self):
+    """Test transformer diff between original model and xla_model."""
+    env, model_arg = helpers.make_mixtral_env(False)
+
+    model_orig = mixtral_orig.Transformer(model_arg)
+    model_orig.setup_caches(max_batch_size=1, max_seq_length=env.cache_len)
+
+    state_dict = dict(model_orig.state_dict())
+    state_dict["freqs_cis"] = model_orig.freqs_cis
+    new_dict = {}
+
+    for k, v in state_dict.items():
+      if "kv_cache" in k:
+        continue
+      if "wqkv" in k:
+        wq = k.replace("wqkv", "wq")
+        wk = k.replace("wqkv", "wk")
+        wv = k.replace("wqkv", "wv")
+        kv_size = model_arg.n_local_heads * model_arg.head_dim
+        wq_t, wk_t, wv_t = v.split([model_arg.dim, kv_size, kv_size], dim=0)
+
+        new_dict[wq] = wq_t
+        new_dict[wk] = wk_t
+        new_dict[wv] = wv_t
+        continue
+      # "freqs_cis" for the exported model is calculated differently,
+      # using a complex data type
+      if "freqs_cis" in k:
+        new_dict[k] = mixtral.precompute_freqs_cis(
+            model_arg.block_size,
+            model_arg.dim // model_arg.n_head,
+            model_arg.rope_base,
+        )
+        continue
+      new_dict[k] = v
+
+    model_ours = mixtral.Transformer(model_arg, env)
+
+    # Invoke the original model
+    seqlen = 32
+    x = torch.randint(0, 32000, (1, seqlen))  # (batch, seqlen) token ids
+    start_pos = 0
+    mask = self._prefill_mask(seqlen, start_pos)
+    input_pos = torch.arange(0, seqlen)
+    inputs_orig = (x, input_pos)
+
+    expected_out = model_orig(*inputs_orig)
+
+    # Invoke the exported model
+    caches = env.make_caches_prefill()
+
+    input_ours = (
+        x,
+        input_pos,
+        caches,
+        mask,
+    )
+    result_torch = helpers.call_xla_model(model_ours, new_dict, input_ours)
+
+    print("Transformer: Diff norm", (result_torch - expected_out).norm())
+    self.assertTrue(torch.allclose(result_torch, expected_out, atol=1e-4))
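The wqkv branch above assumes the fused projection is laid out as [q; k; v] along dim 0. With the Mixtral-tiny sizes (dim=128, n_head=32 so head_dim=4, n_local_heads=8), the split works out as below (standalone sketch, not part of the PR):

import torch

dim, n_head, n_local_heads = 128, 32, 8
head_dim = dim // n_head  # 4
kv_size = n_local_heads * head_dim  # 32
wqkv = torch.rand(dim + 2 * kv_size, dim)  # fused [q; k; v] projection
wq, wk, wv = wqkv.split([dim, kv_size, kv_size], dim=0)
assert wq.shape == (128, 128)
assert wk.shape == (32, 128) and wv.shape == (32, 128)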

   def test_mixtral_moe(self):
     config = mixtral_config.ModelArgs()
     config.intermediate_size = 16