@NuojCheng NuojCheng commented Aug 5, 2025

Description

Migrates the Mistral model to NNX. Tests compare f4f286c (before the Mistral migration) with 5984d9c (after the Mistral migration).
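
For context, here is a minimal sketch of the Linen-to-NNX pattern this kind of migration follows. It is illustrative only: the layer names and simplified shapes are assumptions, and the real MaxText layers also carry sharding and quantization logic not shown here.

```python
# Minimal sketch of the Linen -> NNX migration pattern (not the actual
# MaxText code). In NNX, submodules are constructed eagerly in __init__
# with explicit shapes, instead of lazily inside a Linen @nn.compact
# __call__ method.
import jax
import jax.numpy as jnp
from flax import nnx


class MistralMLP(nnx.Module):
  """Gated (SwiGLU) feed-forward block in NNX style."""

  def __init__(self, emb_dim: int, hidden_dim: int, *, rngs: nnx.Rngs):
    self.wi_0 = nnx.Linear(emb_dim, hidden_dim, use_bias=False, rngs=rngs)
    self.wi_1 = nnx.Linear(emb_dim, hidden_dim, use_bias=False, rngs=rngs)
    self.wo = nnx.Linear(hidden_dim, emb_dim, use_bias=False, rngs=rngs)

  def __call__(self, x: jax.Array) -> jax.Array:
    # SwiGLU feed-forward, as used by Mistral-style decoder blocks.
    return self.wo(jax.nn.silu(self.wi_0(x)) * self.wi_1(x))


# Mistral-7B uses emb_dim=4096 and hidden_dim=14336.
mlp = MistralMLP(4096, 14336, rngs=nnx.Rngs(0))
y = mlp(jnp.ones((1, 8, 4096)))
```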

Tests

Training tests

Command

python -m MaxText.train MaxText/configs/base.yml \
    run_name=nc_test_mistral_$RANDOM \
    steps=5 \
    base_output_directory=gs://chengnuojin-maxtext-logs \
    dataset_path=gs://chengnuojin-maxtext-dataset \
    model_name=mistral-7b \
    enable_checkpointing=False \
    per_device_batch_size=1

Webdiff

https://diff.googleplex.com/#key=qvnNP8tANEIK

Decode inference tests

Command

python3 -m MaxText.decode src/MaxText/configs/base.yml \
model_name=mistral-7b \
tokenizer_path=assets/tokenizer.mistral-v1 \
tokenizer_type=sentencepiece \
scan_layers=false \
per_device_batch_size=1 \
ici_fsdp_parallelism=1 \
ici_autoregressive_parallelism=-1 \
max_prefill_predict_length=128 \
max_target_length=256 \
prompt="I love to" \
attention=dot_product \
load_parameters_path=gs://chengnuojin-maxtext-logs/chengnuojin_decode_32458/checkpoints/0/items

Webdiff

https://diff.googleplex.com/#key=dbqiUSGbkqov

JetStream inference test

Command

Step 1: https://paste.googleplex.com/6590290169298944

Step 2: https://paste.googleplex.com/6674235103772672

Step 3: https://paste.googleplex.com/6448284356968448

Before Mistral Migration

xprof

Memstats: After load_params:
        Using (GB) 3.37 / 95.74 (3.519950%) on TPU_0(process=0,(0,0,0,0))
        Using (GB) 3.37 / 95.74 (3.519950%) on TPU_1(process=0,(1,0,0,0))
        Using (GB) 3.37 / 95.74 (3.519950%) on TPU_2(process=0,(0,1,0,0))
        Using (GB) 3.37 / 95.74 (3.519950%) on TPU_3(process=0,(1,1,0,0))

RAMstats: After load_params:
        Using (GB) 27.17 / 440.83 (6.163374%) -->  Available:410.84

After Mistral Migration

xprof

Memstats: After load_params:
        Using (GB) 3.37 / 95.74 (3.519950%) on TPU_0(process=0,(0,0,0,0))
        Using (GB) 3.37 / 95.74 (3.519950%) on TPU_1(process=0,(1,0,0,0))
        Using (GB) 3.37 / 95.74 (3.519950%) on TPU_2(process=0,(0,1,0,0))
        Using (GB) 3.37 / 95.74 (3.519950%) on TPU_3(process=0,(1,1,0,0))

RAMstats: After load_params:
        Using (GB) 27.94 / 440.83 (6.338044%) -->  Available:410.06
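
For reference, the only post-migration delta is in host RAM: 27.94 − 27.17 = 0.77 GB more (about 0.17% of the 440.83 GB total), while per-TPU HBM usage is identical at 3.37 GB.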

Webdiff

https://diff.googleplex.com/#key=siw4m0hLYUAT

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed.

@NuojCheng NuojCheng force-pushed the chengnuojin-mistral-migration branch 4 times, most recently from b2bfd1d to 65c58f2 Compare August 6, 2025 01:08
@NuojCheng NuojCheng changed the title from "NNX Migration for Mistral models" to "[Draft] NNX Migration for Mistral models" Aug 6, 2025
@NuojCheng NuojCheng changed the title from "[Draft] NNX Migration for Mistral models" to "[Draft, NO MERGE] NNX Migration for Mistral models" Aug 6, 2025
@NuojCheng NuojCheng force-pushed the chengnuojin-mistral-migration branch 13 times, most recently from 0591ca3 to fa13350 Compare August 13, 2025 00:42
@NuojCheng NuojCheng force-pushed the chengnuojin-mistral-migration branch 10 times, most recently from 3675744 to 5973fdb Compare August 18, 2025 20:53

@bvandermoon bvandermoon left a comment


LGTM, just a couple of small comments. Are the decode inference test results in the description from after the rebase with #2370? I see a very slight memory diff.

Also discussed this offline already, but do you mind running some Maxengine/Jetstream test from a checkpoint?

@NuojCheng NuojCheng force-pushed the chengnuojin-mistral-migration branch from c325f3e to 5984d9c Compare September 30, 2025 20:53

@RissyRan RissyRan left a comment


I see that the accuracy is 0 from both JetStream runs. Did you kill the process, or did something else happen?

Results

{'accuracy': 0.0, 'gen_num': 5000}

@NuojCheng

I see that the accuracy is 0 from both JetStream runs. Did you kill the process, or did something else happen?

Results

{'accuracy': 0.0, 'gen_num': 5000}

I did not kill the process. I am not sure what happened here; maybe the wrong tokenizer was used? (I used the Mistral tokenizer under assets instead of the Hugging Face one used for llama2.)
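
If it helps to narrow this down, below is a minimal sketch that compares the two tokenizers on the same prompt. The Hugging Face model id and the prompt are assumptions for illustration, not the exact eval setup.

```python
# Sanity check for the tokenizer hypothesis: compare the SentencePiece
# tokenizer under assets/ with a Hugging Face tokenizer (placeholder
# model id; the eval path may use a different one).
import sentencepiece as spm
from transformers import AutoTokenizer

sp = spm.SentencePieceProcessor(model_file="assets/tokenizer.mistral-v1")
hf = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")

text = "I love to"
sp_ids = sp.encode(text)
hf_ids = hf.encode(text, add_special_tokens=False)
print(sp_ids)
print(hf_ids)  # a mismatch here would explain an accuracy of 0.0
```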


@bvandermoon bvandermoon left a comment


LGTM after the testing you have been continuing. Just one comment.

@NuojCheng NuojCheng force-pushed the chengnuojin-mistral-migration branch 4 times, most recently from 45af6de to bbb8ce7 Compare October 2, 2025 18:23

@bvandermoon bvandermoon left a comment


LGTM, thanks for the thorough testing so far. Please follow up here when you have the results of the golden logits test you mentioned you were running.

@NuojCheng NuojCheng force-pushed the chengnuojin-mistral-migration branch from bbb8ce7 to 6bcfdc2 Compare October 3, 2025 23:37
@NuojCheng NuojCheng force-pushed the chengnuojin-mistral-migration branch from 6bcfdc2 to ad74a01 Compare October 6, 2025 16:45
@copybara-service copybara-service bot merged commit 13872bd into main Oct 6, 2025
23 of 24 checks passed
@copybara-service copybara-service bot deleted the chengnuojin-mistral-migration branch October 6, 2025 17:56