Skip to content

fix gemma4 dtype mismatch#3746

Merged
copybara-service[bot] merged 1 commit intomainfrom
aireen/fix_gemma4_dtype
Apr 25, 2026
Merged

fix gemma4 dtype mismatch#3746
copybara-service[bot] merged 1 commit intomainfrom
aireen/fix_gemma4_dtype

Conversation

@aireenmei
Copy link
Copy Markdown
Collaborator

@aireenmei aireenmei commented Apr 25, 2026

Description

When fixing another bug in #3727 , I changed to use config.weight_dtype from config.dtype to initialize layer_scalar because it's a weight, this requires casting layer_scalar back to dtype during use time, or it causes error The input carry component c[1] has type bfloat16[2048,4096,2816] but the corresponding output carry component has type float32[2048,4096,2816], so the dtypes do not match. when weight_dtype=float32, dtype=bfloat16

Tests

log with error
working log

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@codecov
Copy link
Copy Markdown

codecov Bot commented Apr 25, 2026

Codecov Report

❌ Patch coverage is 0% with 1 line in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
src/maxtext/models/gemma4.py 0.00% 1 Missing ⚠️

📢 Thoughts on this report? Let us know!

Copy link
Copy Markdown
Collaborator

@shralex shralex left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typo in title dtye

@aireenmei aireenmei changed the title fix gemma4 dtye mismatch fix gemma4 dtype mismatch Apr 25, 2026
next_layer_addition = mlp_lnx + residual
layer_output = next_layer_addition
layer_output = layer_output * self.layer_scalar.value
layer_output = layer_output * jnp.asarray(self.layer_scalar.value, cfg.dtype)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: self.layer_scalar.value.astype(cfg.dtype) probably matches better maxtext style.

@copybara-service copybara-service Bot merged commit 59e0f17 into main Apr 25, 2026
63 of 68 checks passed
@copybara-service copybara-service Bot deleted the aireen/fix_gemma4_dtype branch April 25, 2026 06:04
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants