
Add total memory usage logging in GB and unit tests for activation offload #1813


Closed
wants to merge 4 commits

Conversation

zhenying-liu
Contributor

@zhenying-liu zhenying-liu commented Jun 10, 2025

Add total memory usage logging in GB, which is more readable and useful to track, especially for offloading.
Also add two unit tests for activation offloading, with and without scan.

MaxText/train.py Outdated
@@ -963,11 +963,16 @@ def train_loop(config, recorder, state=None):
  compiled = p_train_step.lower(state, example_batch, nextrng).compile()
  compiled_stats = compiled.memory_analysis()
  if compiled_stats is not None:
    total = (compiled_stats.output_size_in_bytes +
Collaborator

I realize the existing code doesn't follow this pattern, but can we move this formula + print statement into max_utils? Other files (e.g. sft_trainer or grpo_trainer) may also want to call this function, and it cleans up train.py, which we prefer to keep lean.

Contributor Author

Yes, moved this formula + print statement into max_utils. Done.
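
A rough sketch of what such a helper in max_utils could look like (the function name, the GB formatting, and the choice to sum exactly these four fields are assumptions; only the field names, which match the log line quoted below, come from this thread):

# Hypothetical helper; the name and the exact total formula are assumptions.
from MaxText import max_logging


def print_compiled_memory_stats(compiled_stats):
  """Logs the compiled memory analysis from JAX in GB (sketch, not the merged code)."""
  gb = 1024**3
  output = compiled_stats.output_size_in_bytes / gb
  temp = compiled_stats.temp_size_in_bytes / gb
  argument = compiled_stats.argument_size_in_bytes / gb
  host_temp = compiled_stats.host_temp_size_in_bytes / gb
  total = output + temp + argument + host_temp
  max_logging.log(
      f"Total memory size: {total:.1f} GB, Output size: {output:.1f} GB, "
      f"Temp size: {temp:.1f} GB, Argument size: {argument:.1f} GB, "
      f"Host temp size: {host_temp:.1f} GB."
  )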

Contributor Author

Also, as requested in G Chat ("printing the stats at the end of the first step instead of at the end of all of the steps"): this is done. The stats are now printed before the first step.

The current logs are in this order:

Memstats: After params initialized:
        Using (GB) XX / YY (ZZ%) on cuda:0
Total memory size: AA GB, Output size: BB GB, Temp size: CC GB, Argument size: DD GB, Host temp size: EE GB.
completed step: 0, seconds: xxx, TFLOP/s/device: yyy, Tokens/s/device: zzz, total_weights: www, loss: ooo
To see full metrics 'tensorboard --logdir=/tmp/tmp.Exyqpj9oUF/logdir/tensorboard/'
completed step: 1, seconds: xxx, TFLOP/s/device: yyy, Tokens/s/device: zzz, total_weights: www, loss: ooo


@pytest.mark.integration_test
@pytest.mark.gpu_only
def test_gpu_activation_offload_without_scan(self):
Collaborator

Curious why we have scan=False; is this something we really care about?

Contributor Author

We also want to keep the explicit loop (the non-scan version) working with good performance, as it may be needed for some non-LLM models. Also, because the two versions produce different HLOs for offloading (copy-start/done vs. dynamic-update-slice-start/done and dynamic-slice-start/done), we want to track both.
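
For context, a minimal sketch of how the scan / non-scan test pair might be shaped (the train_main entry point, config path, and flag values such as the offloading remat policy are assumptions, not the merged test code):

# Hypothetical sketch only; config path and flag values are assumptions.
import unittest

import pytest

from MaxText.train import main as train_main


class GpuActivationOffloadTest(unittest.TestCase):
  """Smoke tests: a short training run with activation offloading enabled."""

  def _run_offload_test(self, scan_layers: bool):
    train_main([
        None,
        "MaxText/configs/base.yml",
        "run_name=gpu_activation_offload_test",
        "steps=2",
        "enable_checkpointing=False",
        f"scan_layers={scan_layers}",
        "remat_policy=minimal_offloaded",  # assumed name of the offloading policy
    ])

  @pytest.mark.integration_test
  @pytest.mark.gpu_only
  def test_gpu_activation_offload_with_scan(self):
    self._run_offload_test(scan_layers=True)

  @pytest.mark.integration_test
  @pytest.mark.gpu_only
  def test_gpu_activation_offload_without_scan(self):
    self._run_offload_test(scan_layers=False)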

@zhenying-liu zhenying-liu requested a review from gobbleturk June 10, 2025 16:59
MaxText/train.py Outdated
@@ -841,6 +841,10 @@ def train_loop(config, recorder, state=None):
    if step == first_profiling_step or prof.should_activate_periodic_profile(step):
      optional_postfix = f"step_{step}" if config.profile_periodically_period > 0 else ""
      prof.activate(blocking_object=state, optional_postfix=optional_postfix)
    with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
Collaborator

This fails our internal linter since nextrng may not be defined yet. Is it possible to move this before the main train for loop (before `for step in np.arange(start_step, config.steps):`)? You may have to plug in a random rng key.

Contributor Author

Yes, moved this before the main train for loop and defined it in max_utils.py to make train.py lean.
The data_iterator is copied to get the "example_batch" without disrupting the training loop. Please take a look.

@zhenying-liu zhenying-liu changed the title from "Add total memory usage in GB and unit tests for activation offload" to "Add total memory usage logging in GB and unit tests for activation offload" Jun 11, 2025

import orbax.checkpoint as ocp

from tensorboardX import writer

from MaxText import max_logging
from MaxText.train import load_next_batch
Collaborator
@gobbleturk gobbleturk Jun 13, 2025

We don't want max_utils to depend on train; can we instead get a shaped version of the batch, similar to how it's done in train_compile here?

shaped_batch = maxtext_utils.get_shaped_batch(config)

Contributor Author

It makes sense. Took this approach of maxtext_utils.get_shaped_batch(config). Done.
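
Putting the thread together, a sketch of roughly how the pre-loop memory logging could end up wired (config, state, mesh, p_train_step, and nn_partitioning are the names from the snippets above; the rng placeholder and the print_compiled_memory_stats helper are the hypothetical pieces sketched earlier in this thread):

# Sketch only; assumes the hypothetical print_compiled_memory_stats helper above.
import jax
from flax.linen import partitioning as nn_partitioning

from MaxText import maxtext_utils

# Inside train_loop, before `for step in np.arange(start_step, config.steps):`.
shaped_batch = maxtext_utils.get_shaped_batch(config)  # shapes only, no real data
rng = jax.random.PRNGKey(0)  # placeholder key, as the reviewer suggested
with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
  compiled = p_train_step.lower(state, shaped_batch, rng).compile()
compiled_stats = compiled.memory_analysis()
if compiled_stats is not None:
  print_compiled_memory_stats(compiled_stats)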

@SurbhiJainUSC
Collaborator

These changes are included in #1904 and it has been merged. Closing this PR.
