Skip to content

Commit 59e0f17

Browse files
Merge pull request #3746 from AI-Hypercomputer:aireen/fix_gemma4_dtype
PiperOrigin-RevId: 905389377
2 parents dff1d0c + 9184714 commit 59e0f17

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

src/maxtext/models/gemma4.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -370,7 +370,7 @@ def __call__(
370370

371371
next_layer_addition = mlp_lnx + residual
372372
layer_output = next_layer_addition
373-
layer_output = layer_output * self.layer_scalar.value
373+
layer_output = layer_output * jnp.asarray(self.layer_scalar.value, cfg.dtype)
374374

375375
layer_output = nn.with_logical_constraint(layer_output, self.activation_axis_names)
376376

0 commit comments

Comments
 (0)