1 change: 1 addition & 0 deletions MaxText/common_types.py
@@ -87,6 +87,7 @@ class DecoderBlockType(enum.Enum):
QWEN3 = "qwen3"
QWEN3_MOE = "qwen3_moe"
GPT3 = "gpt3"
GPT_OSS = "gpt_oss"
SIMPLE = "simple"
SIMPLE_MLP = "simple_mlp"
LLAMA4 = "llama4"
6 changes: 5 additions & 1 deletion MaxText/configs/base.yml
@@ -138,6 +138,7 @@ base_mlp_dim: 7168
base_num_decoder_layers: 16
head_dim: 128
mlp_activations: ["silu", "linear"]
mlp_activations_limit: -1.0 # if positive, cap MLP activation values at this limit (negative disables it)
dropout_rate: 0.0
logits_via_embedding: False
normalize_embedding_logits: True # whether to normalize pre-softmax logits if logits_via_embedding is true
@@ -183,7 +184,8 @@ first_num_dense_layers: 0 # number of initial dense layers in the model
shared_experts: 1
routed_scaling_factor: 1.0 # scaling factor for routing scores
routed_score_func: "" # scoring function for routing
routed_bias: False # a flag if a bias term is added for routing
routed_bias: False # whether to add a learnable bias for routing
mlp_bias: False # whether to add a learnable bias to the MLP matmuls
n_routing_groups: -1 # number of groups for routing, disabled by default
topk_routing_group: -1 # number of top groups to route inputs. For EP,
# sending activations to a maximum of topk_routing_group distinct devices can yield performance benefits.
@@ -268,6 +270,8 @@ param_scan_axis: 1
# The attention_type parameter determines the variants of attention, e.g. global or local_sliding
attention: 'autoselected' # Supported attention: autoselected, dot_product, flash, cudnn_flash_te
attention_type: 'global' # Supported attention_type: global, local_sliding, chunk, mla
attention_bias: False # If True, adds a learnable bias to the query, key, and value projections
attention_sink: False # If True, adds a learnable attention sink logit per query head
sliding_window_size: 0
chunk_attn_window_size: 0
attn_logits_soft_cap: 0.0
55 changes: 55 additions & 0 deletions MaxText/configs/models/gpt-oss-120b.yml
@@ -0,0 +1,55 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# model config for gpt-oss-120b
# https://huggingface.co/openai/gpt-oss-120b/blob/main/config.json

# tokenizer_type: "huggingface"

# Attention
base_emb_dim: 2880
base_num_query_heads: 64
base_num_kv_heads: 8
head_dim: 64
sliding_window_size: 128
attention_bias: True
attention_sink: True

# RoPE
rope_type: "yarn"
rope_max_timescale: 150_000
max_position_embeddings: 131072
original_max_position_embeddings: 4096
rope_factor: 32
beta_fast: 32
beta_slow: 1

# MLP
base_mlp_dim: 2880
base_moe_mlp_dim: 2880
mlp_activations: ["sigmoid","linear"]
mlp_activations_limit: 7.0
routed_bias: True
mlp_bias: True
num_experts: 128
num_experts_per_tok: 4

# General
base_num_decoder_layers: 36
vocab_size: 201088
normalization_layer_epsilon: 1.0e-5
enable_dropout: False
logits_via_embedding: False
decoder_block: "gpt_oss"
inhomogeneous_layer_cycle_interval: 2
55 changes: 55 additions & 0 deletions MaxText/configs/models/gpt-oss-20b.yml
@@ -0,0 +1,55 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# model config for gpt-oss-20b
# https://huggingface.co/openai/gpt-oss-20b/blob/main/config.json

# tokenizer_type: "huggingface"

# Attention
base_emb_dim: 2880
base_num_query_heads: 64
base_num_kv_heads: 8
head_dim: 64
sliding_window_size: 128
attention_bias: True
attention_sink: True

# RoPE
rope_type: "yarn"
rope_max_timescale: 150_000
max_position_embeddings: 131072
original_max_position_embeddings: 4096
rope_factor: 32
beta_fast: 32
beta_slow: 1

# MLP
base_mlp_dim: 2880
base_moe_mlp_dim: 2880
mlp_activations: ["sigmoid","linear"]
mlp_activations_limit: 7.0
routed_bias: True
mlp_bias: True
num_experts: 32
num_experts_per_tok: 4

# General
base_num_decoder_layers: 24
vocab_size: 201088
normalization_layer_epsilon: 1.0e-5
enable_dropout: False
logits_via_embedding: False
decoder_block: "gpt_oss"
inhomogeneous_layer_cycle_interval: 2
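
As a rough sanity check on the two configs above, the expert parameter counts implied by base_emb_dim, base_moe_mlp_dim, num_experts, and base_num_decoder_layers line up with the model names. The sketch below is only a back-of-the-envelope estimate, not MaxText's parameter accounting: it treats each expert as three emb_dim x moe_mlp_dim matrices (a simplification of the gate/up/down layout) and ignores biases, attention, and embedding weights.

# Back-of-the-envelope parameter estimate for the gpt-oss configs above.
EMB = 2880      # base_emb_dim
MOE_MLP = 2880  # base_moe_mlp_dim
PER_EXPERT = 3 * EMB * MOE_MLP  # gate, up, down projections: ~24.9M params

for name, experts, layers, topk in [("gpt-oss-120b", 128, 36, 4), ("gpt-oss-20b", 32, 24, 4)]:
  total = experts * layers * PER_EXPERT   # all expert weights
  active = topk * layers * PER_EXPERT     # expert weights touched per token
  print(f"{name}: ~{total / 1e9:.0f}B expert params, ~{active / 1e9:.1f}B active per token")
# Prints roughly 115B / 3.6B and 19B / 2.4B of expert weights; attention, embeddings,
# and biases add more, which is consistent with the 120b / 20b model names.
# Grouped-query attention: 64 query heads share 8 KV heads (8 queries per KV head);
# with head_dim=64 the query projection maps 2880 -> 64 * 64 = 4096 and the KV
# projections map 2880 -> 8 * 64 = 512 each.
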
66 changes: 49 additions & 17 deletions MaxText/layers/attention_op.py
@@ -20,6 +20,7 @@
import math

import numpy as np
from packaging import version

from jax import lax
from jax.ad_checkpoint import checkpoint_name
@@ -131,6 +132,12 @@ def apply_mask_to_logits(logits: Array, mask: Array):
return jnp.where((mask >= DEFAULT_MASK_VALUE * 0.5), logits, DEFAULT_MASK_VALUE)


def validate_flash_attention_with_sinks_on_gpu(sinks: Array | None) -> None:
"""Helper function to check for sinks with flash attention on GPU."""
if sinks is not None:
raise ValueError("The flash attention with sinks is not supported on GPU yet.")


# TODO(agagik): change splash_attention_mask._ComputableMask to be non protected
class ChunkedCausalMask(splash_attention_mask._ComputableMask): # pylint: disable=protected-access
"""Lazy chunked causal mask.
@@ -677,6 +684,7 @@ def apply_attention(
use_ragged_attention: bool = False,
previous_chunk: Any = None,
bidirectional_mask: Any = None,
sinks: Array | None = None,
*,
qk_product_einsum: Callable[..., Array],
wv_product_einsum: Callable[..., Array],
@@ -712,6 +720,7 @@
model_mode,
previous_chunk,
bidirectional_mask=bidirectional_mask,
sinks=sinks,
qk_product_einsum=qk_product_einsum,
wv_product_einsum=wv_product_einsum,
)
@@ -727,8 +736,9 @@
"""Decode not supported with flash attention.
Use `dot_product` instead."""
)
return self.tpu_flash_attention(query, key, value, decoder_segment_ids, self.attn_logits_soft_cap), None, None
return self.tpu_flash_attention(query, key, value, decoder_segment_ids, self.attn_logits_soft_cap, sinks), None, None
else:
validate_flash_attention_with_sinks_on_gpu(sinks)
if model_mode == MODEL_MODE_AUTOREGRESSIVE:
# fallback to dot_product as pallas gpu flash attention doesn't support decode stage
return self.apply_attention_dot(
@@ -763,6 +773,7 @@ def apply_attention(
out = gpu_pallas_attention.mha(query, key, value, decoder_segment_ids, sm_scale=1.0, causal=True)
return out, None, None
elif self.attention_kernel == "cudnn_flash_te":
validate_flash_attention_with_sinks_on_gpu(sinks)
if isinstance(key, KVTensor):
key = key.dequant()
if isinstance(value, KVTensor):
@@ -774,6 +785,7 @@
)
return self.cudnn_flash_attention(query, key, value, decoder_segment_ids, model_mode), None, None
elif self.attention_kernel == "cudnn_flash_jax":
validate_flash_attention_with_sinks_on_gpu(sinks)
if isinstance(key, KVTensor):
key = key.dequant()
if isinstance(value, KVTensor):
@@ -877,6 +889,7 @@ def tpu_flash_attention(
value: Array,
decoder_segment_ids: Array | None,
attn_logits_soft_cap: float | None = None,
sinks: Array | None = None,
) -> Array:
"""TPU Flash Attention."""

@@ -1029,6 +1042,7 @@ def wrap_splash_kernel(multi_head_mask, shard_head_size=1):
segment_axis_names_splash_kernel,
None, # no sharding for cp_size
None, # no sharding for load_balanced_context_parallel
None, # no sharding for sinks
),
out_specs=axis_names_q,
check_rep=False,
@@ -1042,6 +1056,7 @@ def wrap_flash_attention(
splash_kernel,
cp_size,
load_balanced_context_parallel,
sinks,
):
# If load_balanced_context_parallel is enabled, reorder the key and value tensors
# to ensure that they are contiguous in memory.
@@ -1065,8 +1080,13 @@ def wrap_flash_attention(
decoder_segment_ids_tuple = splash_attention_kernel.SegmentIds(decoder_segment_ids_q, decoder_segment_ids_kv)
else:
decoder_segment_ids_tuple = None
attention_output = jax.vmap(splash_kernel)(query, key, value, segment_ids=decoder_segment_ids_tuple)

# TODO(ranran): remove if/else branch once b/441336842 is fixed
if version.parse(jax.__version__) < version.parse("0.7.2.dev20250824"):
attention_output = jax.vmap(splash_kernel)(query, key, value, decoder_segment_ids_tuple)
else:
attention_output = jax.vmap(splash_kernel, in_axes=(0, 0, 0, 0, None))(
query, key, value, decoder_segment_ids_tuple, sinks
)
return attention_output

x = wrap_flash_attention(
Expand All @@ -1078,6 +1098,7 @@ def wrap_flash_attention(
splash_kernel,
cp_size,
load_balanced_context_parallel,
sinks,
)

x = jnp.transpose(x, axes=(0, 2, 1, 3))
@@ -1194,6 +1215,7 @@ def compute_local_attention(
q_seq_len: int,
model_mode: str,
wv_product_einsum: Callable[..., Array],
sinks: Array | None = None,
) -> tuple[Array, Array, Array]:
"""Computes the attention of a local subset of the kv cache.
Local attention results will need to be combined with any other local attentions and normalized
@@ -1210,19 +1232,26 @@
local_max is the local max of exponentials
local_sum is the sum of exponentials for this chunk, divided by exp(local_max).
"""
local_max = jnp.max(attn_weights, axis=-1, keepdims=True)
local_exps = jnp.exp(attn_weights - local_max)
local_sum = jnp.sum(local_exps, axis=-1, keepdims=True)

local_sum = jnp.moveaxis(local_sum, -2, 1)
local_max = jnp.moveaxis(local_max, -2, 1)

local_max = jnp.reshape(
local_max, (local_max.shape[0], local_max.shape[1], local_max.shape[2] * local_max.shape[3], 1)
)
local_sum = jnp.reshape(
local_sum, (local_sum.shape[0], local_sum.shape[1], local_sum.shape[2] * local_sum.shape[3], 1)
)
b, n_kv, g, t, s = attn_weights.shape
n_q = n_kv * g
logits = jnp.reshape(attn_weights, (b, n_q, t, s))
if sinks is not None:
# broadcast sinks to match the attn weights dimension and combine
sinks_param = sinks.astype(attn_weights.dtype) # (n_q,)
sinks_logits = sinks_param[jnp.newaxis, :, jnp.newaxis, jnp.newaxis] # (1, n_q, 1, 1)
sinks_logits = jnp.broadcast_to(sinks_logits, (b, n_q, t, 1))
logits = jnp.concatenate([logits, sinks_logits], axis=-1)

# softmax
local_max = jnp.max(logits, axis=-1, keepdims=True)
local_exps_combined = jnp.exp(logits - local_max)
local_sum = jnp.sum(local_exps_combined, axis=-1, keepdims=True)

# reshape and transpose
local_exps = local_exps_combined[..., :s]
local_exps = jnp.reshape(local_exps, (b, n_kv, g, t, s))
local_max = jnp.transpose(local_max, (0, 2, 1, 3)) # (b, t, n_q, 1)
local_sum = jnp.transpose(local_sum, (0, 2, 1, 3)) # (b, t, n_q, 1)

local_out = self.wv_product(local_exps, value, model_mode, wv_product_einsum)
if model_mode == MODEL_MODE_AUTOREGRESSIVE and self.is_partition_in_decode(q_seq_len):
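
The sink handling above only changes the softmax normalizer: the learnable per-query-head sink logit is appended as one extra column of attn_weights, exponentiated and summed together with the real scores, and then dropped before the weighted sum over values, so any probability mass routed to the sink simply shrinks every real attention weight. A minimal standalone sketch of that behavior (toy shapes; the function and variable names here are illustrative, not MaxText APIs):

import jax.numpy as jnp

def softmax_with_sink(scores, sink_logit):
  # scores: (q_len, kv_len) attention logits for one head; sink_logit: scalar.
  sink_col = jnp.full((scores.shape[0], 1), sink_logit, dtype=scores.dtype)
  logits = jnp.concatenate([scores, sink_col], axis=-1)  # append sink column
  m = jnp.max(logits, axis=-1, keepdims=True)            # numerically stable softmax
  exps = jnp.exp(logits - m)
  probs = exps / jnp.sum(exps, axis=-1, keepdims=True)
  # Drop the sink column: the remaining weights sum to < 1, letting a token
  # attend "to nothing" by giving mass to the sink instead of real keys.
  return probs[:, :-1]

print(softmax_with_sink(jnp.array([[1.0, 2.0, 0.5]]), sink_logit=3.0).sum())  # < 1.0
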
@@ -1254,6 +1283,7 @@ def apply_attention_dot(
model_mode: str = MODEL_MODE_TRAIN,
previous_chunk: Any = None,
bidirectional_mask: Any = None,
sinks: Array | None = None,
*,
qk_product_einsum: Callable[..., Array],
wv_product_einsum: Callable[..., Array],
@@ -1312,7 +1342,7 @@ def apply_attention_dot(
attn_mask = partitioning.with_sharding_constraint(attn_mask, (BATCH, HEAD, None, PREFILL_LENGTH, KV_LENGTH))
if attn_mask is not None:
attn_weights = apply_mask_to_logits(attn_weights, attn_mask)
return self.compute_local_attention(attn_weights, value, q_seq_len, model_mode, wv_product_einsum)
return self.compute_local_attention(attn_weights, value, q_seq_len, model_mode, wv_product_einsum, sinks)

def qk_product(
self, query: Array, key: Array | KVTensor, q_seq_len: int, model_mode: str, einsum: Callable[..., Array]
@@ -1450,6 +1480,7 @@ def __call__(
cached_values=None,
previous_chunk=None,
bidirectional_mask=None,
sinks=None,
slot: Optional[int] = None,
page_state: Optional[page_manager.PageState] = None,
):
@@ -1471,6 +1502,7 @@
use_ragged_attention=self.use_ragged_attention,
previous_chunk=previous_chunk,
bidirectional_mask=bidirectional_mask,
sinks=sinks,
qk_product_einsum=self.AqtEinsum_0,
wv_product_einsum=self.AqtEinsum_1,
)
12 changes: 10 additions & 2 deletions MaxText/layers/attentions.py
@@ -66,7 +66,7 @@
RotaryEmbedding,
YarnRotaryEmbedding,
)
from MaxText.layers.initializers import nd_dense_init, NdInitializer, variable_to_logically_partitioned
from MaxText.layers.initializers import nd_dense_init, NdInitializer, variable_to_logically_partitioned, default_bias_init
from MaxText.layers.linears import DenseGeneral, canonicalize_tuple, normalize_axes
from MaxText.layers.normalizations import RMSNorm
from MaxText.layers.quantizations import AqtQuantization as Quant
@@ -477,6 +477,14 @@ def __init__(

self.out = self.init_out_w(output_dim=inputs_q_shape[-1])

if self.config.attention_sink:
self.sinks = nnx.Param(
default_bias_init(self.rngs.params(), (self.config.num_query_heads,), self.weight_dtype),
sharding=(None,),
)
else:
self.sinks = None

is_llama4_decoder_block = self.config.decoder_block == DecoderBlockType.LLAMA4
if self.use_qk_norm and not is_llama4_decoder_block:
self.query_norm = RMSNorm(
@@ -917,7 +925,7 @@ def __call__(
if model_mode != MODEL_MODE_TRAIN:
cached_values = self.update_kv_caches(key, value, decoder_segment_ids, model_mode, previous_chunk)
out = self.attention_op(
query, key, value, decoder_segment_ids, model_mode, cached_values, previous_chunk, bidirectional_mask
query, key, value, decoder_segment_ids, model_mode, cached_values, previous_chunk, bidirectional_mask, self.sinks
)

if model_mode == MODEL_MODE_PREFILL: