Merged
2 changes: 1 addition & 1 deletion src/MaxText/layers/decoders.py
@@ -377,7 +377,7 @@ def get_decoder_layers(self):
return [llama2.LlamaDecoderLayerToLinen]
case DecoderBlockType.MISTRAL:
# TODO(ranran): update to Mistral with sliding window attention
- return [mistral.MistralDecoderLayer]
+ return [mistral.MistralDecoderLayerToLinen]
case DecoderBlockType.MIXTRAL:
return [mixtral.MixtralDecoderLayer]
case DecoderBlockType.DEEPSEEK:
194 changes: 110 additions & 84 deletions src/MaxText/layers/mistral.py
@@ -21,131 +21,151 @@
import jax.numpy as jnp

from flax import linen as nn
from flax import nnx

from MaxText.layers.linears import mlp_block
from MaxText.layers import models
from MaxText import max_utils
from MaxText.layers import nnx_wrappers, initializers
from MaxText.layers.linears import Dropout, MlpBlock
from MaxText.layers.models import Config
from MaxText.layers.attentions import Attention
from MaxText.layers import quantizations
from MaxText.layers.attentions import attention_as_linen
from MaxText.layers.quantizations import AqtQuantization as Quant
from MaxText.layers.normalizations import rms_norm
from MaxText.layers.normalizations import RMSNorm


# -----------------------------------------
# The Decoder Layer for Mistral
# -----------------------------------------


class MistralDecoderLayer(nn.Module):
class MistralDecoderLayer(nnx.Module):
"""Transformer decoder layer that attends to the encoder."""

config: models.Config
mesh: Mesh
model_mode: str
quant: None | Quant = None
def __init__(
self,
config: Config,
model_mode: str,
mesh: Mesh,
*,
rngs: nnx.Rngs,
quant: None | Quant = None,
):
self.config = config
self.mesh = mesh
self.quant = quant
self.rngs = rngs

batch_size, seq_len = max_utils.get_batch_seq_len_for_mode(config, model_mode)
dummy_inputs_shape = (batch_size, seq_len, config.emb_dim)

self.pre_self_attention_layer_norm = RMSNorm(
num_features=config.emb_dim,
dtype=config.dtype,
weight_dtype=config.weight_dtype,
kernel_axes=("norm",),
epsilon=config.normalization_layer_epsilon,
rngs=self.rngs,
)

self.self_attention = Attention(
config=config,
num_query_heads=config.num_query_heads,
num_kv_heads=config.num_kv_heads,
head_dim=config.head_dim,
max_target_length=config.max_target_length,
max_prefill_predict_length=config.max_prefill_predict_length,
attention_kernel=config.attention,
inputs_q_shape=dummy_inputs_shape,
inputs_kv_shape=dummy_inputs_shape,
mesh=mesh,
dtype=config.dtype,
weight_dtype=config.weight_dtype,
dropout_rate=config.dropout_rate,
float32_qk_product=config.float32_qk_product,
float32_logits=config.float32_logits,
quant=self.quant,
kv_quant=quantizations.configure_kv_quant(config),
prefill_cache_axis_order=tuple(map(int, config.prefill_cache_axis_order.split(","))),
ar_cache_axis_order=tuple(map(int, config.ar_cache_axis_order.split(","))),
compute_axis_order=tuple(map(int, config.compute_axis_order.split(","))),
reshape_q=config.reshape_q,
use_ragged_attention=config.use_ragged_attention,
ragged_block_size=config.ragged_block_size,
model_mode=model_mode,
rngs=self.rngs,
)

self.post_self_attention_layer_norm = RMSNorm(
num_features=config.emb_dim,
dtype=config.dtype,
weight_dtype=config.weight_dtype,
kernel_axes=("norm",),
epsilon=config.normalization_layer_epsilon,
rngs=self.rngs,
)

self.mlp = MlpBlock(
in_features=config.emb_dim,
intermediate_dim=config.mlp_dim,
activations=config.mlp_activations,
intermediate_dropout_rate=config.dropout_rate,
dtype=config.dtype,
weight_dtype=config.weight_dtype,
config=config,
quant=self.quant,
model_mode=model_mode,
rngs=self.rngs,
)

self.dropout = Dropout(rate=config.dropout_rate, broadcast_dims=(-2,), rngs=self.rngs)

self.activation_axis_names = ("activation_batch", "activation_norm_length", "activation_embed")

@nn.compact
def __call__(
self,
inputs,
decoder_segment_ids,
decoder_positions,
deterministic,
model_mode,
page_state: None | int = None,
slot: None | int = None,
previous_chunk=None,
page_state=None,
slot=None,
):

cfg = self.config
mesh = self.mesh

inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_norm_length", "activation_embed"))
inputs = nn.with_logical_constraint(inputs, self.activation_axis_names)
inputs = checkpoint_name(inputs, "decoder_layer_input")
lnx_rms = rms_norm(
num_features=inputs.shape[-1],
dtype=cfg.dtype,
weight_dtype=cfg.weight_dtype,
name="pre_self_attention_layer_norm",
kernel_axes=("norm",),
epsilon=cfg.normalization_layer_epsilon,
)
lnx = lnx_rms(inputs)

lnx = nn.with_logical_constraint(lnx, ("activation_batch", "activation_norm_length", "activation_embed"))

# Self-attention block
attention_layer = attention_as_linen(
config=cfg,
num_query_heads=cfg.num_query_heads,
num_kv_heads=cfg.num_kv_heads,
head_dim=cfg.head_dim,
max_target_length=cfg.max_target_length,
max_prefill_predict_length=cfg.max_prefill_predict_length,
attention_kernel=cfg.attention,
inputs_q_shape=lnx.shape,
inputs_kv_shape=lnx.shape,
mesh=mesh,
dtype=cfg.dtype,
weight_dtype=cfg.weight_dtype,
dropout_rate=cfg.dropout_rate,
name="self_attention",
float32_qk_product=cfg.float32_qk_product,
float32_logits=cfg.float32_logits,
quant=self.quant,
kv_quant=quantizations.configure_kv_quant(cfg),
prefill_cache_axis_order=tuple(map(int, cfg.prefill_cache_axis_order.split(","))),
ar_cache_axis_order=tuple(map(int, cfg.ar_cache_axis_order.split(","))),
compute_axis_order=tuple(map(int, cfg.compute_axis_order.split(","))),
model_mode=model_mode,
)
lnx = self.pre_self_attention_layer_norm(inputs)
lnx = nn.with_logical_constraint(lnx, self.activation_axis_names)

attention_lnx = attention_layer(
attention_lnx = self.self_attention(
lnx,
lnx,
decoder_positions,
decoder_segment_ids=decoder_segment_ids,
deterministic=deterministic,
model_mode=model_mode,
slot=slot,
page_state=page_state,
previous_chunk=previous_chunk,
)

attention_lnx = nn.with_logical_constraint(
attention_lnx, ("activation_batch", "activation_norm_length", "activation_embed")
)
attention_lnx = nn.with_logical_constraint(attention_lnx, self.activation_axis_names)
intermediate_inputs = inputs + attention_lnx

# Fully Connected
hidden_states = rms_norm(
num_features=intermediate_inputs.shape[-1],
dtype=cfg.dtype,
weight_dtype=cfg.weight_dtype,
name="post_self_attention_layer_norm",
kernel_axes=("norm",),
epsilon=cfg.normalization_layer_epsilon,
)(intermediate_inputs)
hidden_states = nn.with_logical_constraint(
hidden_states, ("activation_batch", "activation_norm_length", "activation_embed")
)
hidden_states = self.post_self_attention_layer_norm(intermediate_inputs)
hidden_states = nn.with_logical_constraint(hidden_states, self.activation_axis_names)

mlp_lnx = mlp_block(
in_features=hidden_states.shape[-1],
intermediate_dim=cfg.mlp_dim,
activations=cfg.mlp_activations,
intermediate_dropout_rate=cfg.dropout_rate,
dtype=cfg.dtype,
weight_dtype=cfg.weight_dtype,
name="mlp",
config=cfg,
quant=self.quant,
)(hidden_states, deterministic=deterministic)
mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_norm_length", "activation_embed"))
# MLP block.
mlp_lnx = self.mlp(hidden_states, deterministic=deterministic)
mlp_lnx = nn.with_logical_constraint(mlp_lnx, self.activation_axis_names)

layer_output = mlp_lnx + intermediate_inputs
layer_output = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(layer_output, deterministic=deterministic)

layer_output = nn.with_logical_constraint(
layer_output,
("activation_batch", "activation_norm_length", "activation_embed"),
)
layer_output = self.dropout(layer_output, deterministic=deterministic)
layer_output = nn.with_logical_constraint(layer_output, self.activation_axis_names)

if cfg.record_internal_nn_metrics:
self.sow("intermediates", "activation_mean", jnp.mean(layer_output))
@@ -160,3 +180,9 @@ def __call__(
return layer_output, None
else:
return layer_output


MistralDecoderLayerToLinen = nnx_wrappers.to_linen_class(
MistralDecoderLayer,
base_metadata_fn=initializers.variable_to_logically_partitioned,
)
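
For readers less familiar with the Linen-to-NNX migration this PR performs: the key change is that the sublayers (RMSNorm, Attention, MlpBlock, Dropout) are now built eagerly in `__init__` with explicit feature sizes and an `nnx.Rngs` stream, instead of lazily inside an `@nn.compact` `__call__`, and the resulting `nnx.Module` is wrapped back into a Linen-compatible class via `nnx_wrappers.to_linen_class` so `get_decoder_layers` can keep returning a Linen layer. Below is a minimal, self-contained sketch of that construction pattern using only public `flax.nnx` APIs; `TinyPreNormBlock` and its dimensions are illustrative and are not part of the PR.

```python
from flax import nnx
import jax
import jax.numpy as jnp


class TinyPreNormBlock(nnx.Module):
  """Illustrative pre-norm residual block (not from the PR)."""

  def __init__(self, features: int, hidden: int, *, rngs: nnx.Rngs):
    # Eager construction: parameters are created here, at instantiation time,
    # rather than lazily on the first call as in Linen's @nn.compact style.
    self.norm = nnx.RMSNorm(num_features=features, rngs=rngs)
    self.wi = nnx.Linear(features, hidden, rngs=rngs)
    self.wo = nnx.Linear(hidden, features, rngs=rngs)

  def __call__(self, x: jax.Array) -> jax.Array:
    y = self.norm(x)                      # pre-normalization
    y = self.wo(jax.nn.silu(self.wi(y)))  # feed-forward with SiLU activation
    return x + y                          # residual connection


block = TinyPreNormBlock(features=16, hidden=64, rngs=nnx.Rngs(0))
out = block(jnp.ones((2, 8, 16)))
print(out.shape)  # (2, 8, 16)
```

Because everything is constructed up front, input shapes must be known at `__init__` time; this is why the new `MistralDecoderLayer` derives `dummy_inputs_shape` from the config and `model_mode` before building its `Attention` sublayer.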