Add total memory usage logging in GB and unit tests for activation offload #1813
Conversation
MaxText/train.py (Outdated)
@@ -963,11 +963,16 @@ def train_loop(config, recorder, state=None):
compiled = p_train_step.lower(state, example_batch, nextrng).compile()
compiled_stats = compiled.memory_analysis()
if compiled_stats is not None:
total = (compiled_stats.output_size_in_bytes +
I realize the existing code doesn't follow this pattern, but can we move this formula + print statement into max_utils? Other files (e.g. sft_trainer or grpo_trainer) may also want to call this function, and it cleans up train.py, which we prefer to keep lean.
Yes, moved this formula + print statement into max_utils. Done.
Also, as requested in G Chat ("printing the stats at the end of the first step instead of at the end of all of the steps"): this is done. The logging now happens before the first step.
The current logs are in this order:
Memstats: After params initialized:
Using (GB) XX / YY (ZZ%) on cuda:0
Total memory size: AA GB, Output size: BB GB, Temp size: CC GB, Argument size: DD GB, Host temp size: EE GB.
completed step: 0, seconds: xxx, TFLOP/s/device: yyy, Tokens/s/device: zzz, total_weights: www, loss: ooo
To see full metrics 'tensorboard --logdir=/tmp/tmp.Exyqpj9oUF/logdir/tensorboard/'
completed step: 1, seconds: xxx, TFLOP/s/device: yyy, Tokens/s/device: zzz, total_weights: www, loss: ooo
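For reference, a minimal sketch of what the formula-plus-print helper moved into max_utils could look like; the function name and exact log wording are assumptions inferred from this conversation, not the merged code.

```python
# Hypothetical helper for MaxText/max_utils.py; name and log format are assumed.
from MaxText import max_logging


def print_compiled_memory_stats(compiled_stats):
  """Log the compiled train step's memory footprint in GB, broken down by component."""
  if compiled_stats is None:
    return
  gb = 1024 ** 3  # bytes per GiB
  total = (compiled_stats.output_size_in_bytes
           + compiled_stats.temp_size_in_bytes
           + compiled_stats.argument_size_in_bytes
           + compiled_stats.host_temp_size_in_bytes)
  max_logging.log(
      f"Total memory size: {total / gb:.1f} GB, "
      f"Output size: {compiled_stats.output_size_in_bytes / gb:.1f} GB, "
      f"Temp size: {compiled_stats.temp_size_in_bytes / gb:.1f} GB, "
      f"Argument size: {compiled_stats.argument_size_in_bytes / gb:.1f} GB, "
      f"Host temp size: {compiled_stats.host_temp_size_in_bytes / gb:.1f} GB."
  )
```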
@pytest.mark.integration_test
@pytest.mark.gpu_only
def test_gpu_activation_offload_without_scan(self):
Curious why we have scan=False; is this something we really care about?
We also want to keep the explicit loop (non-scan) version working with good performance, as it may be needed for some non-LLM models. Also, because the two paths use different HLOs (copy-start/done vs. dynamic-update-start/done and dynamic-slice-start/done), we want to track both.
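For context, a rough sketch of how the two tests could be shaped. This is only an illustration: the entry point (train.main), config path, run_name values, and the remat_policy flag value are assumptions and may not match the tests actually added in this PR.

```python
# Rough test sketch; flag values below are placeholders, not the PR's exact test code.
import pytest

from MaxText import train


class TestGpuActivationOffload:

  @pytest.mark.integration_test
  @pytest.mark.gpu_only
  def test_gpu_activation_offload_with_scan(self):
    # Scanned layers: host offload shows up as dynamic-(update-)slice start/done HLOs.
    train.main([
        None, "MaxText/configs/base.yml",
        "run_name=gpu_activation_offload_with_scan",
        "scan_layers=True",
        "remat_policy=minimal_offloaded",  # assumed name for the offloading remat policy
        "steps=2", "enable_checkpointing=False",
    ])

  @pytest.mark.integration_test
  @pytest.mark.gpu_only
  def test_gpu_activation_offload_without_scan(self):
    # Explicit layer loop: host offload shows up as copy-start/done HLOs.
    train.main([
        None, "MaxText/configs/base.yml",
        "run_name=gpu_activation_offload_without_scan",
        "scan_layers=False",
        "remat_policy=minimal_offloaded",  # assumed name for the offloading remat policy
        "steps=2", "enable_checkpointing=False",
    ])
```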
MaxText/train.py (Outdated)
@@ -841,6 +841,10 @@ def train_loop(config, recorder, state=None):
if step == first_profiling_step or prof.should_activate_periodic_profile(step):
optional_postfix = f"step_{step}" if config.profile_periodically_period > 0 else ""
prof.activate(blocking_object=state, optional_postfix=optional_postfix)
with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
This fails our internal linter since nextrng may not be defined yet. Is it possible to move this before the main train for loop (before "for step in np.arange(start_step, config.steps):")? You may have to plug in a random rng key.
Yes, moved this before the main train for loop and defined it in max_utils.py to keep train.py lean.
The data_iterator is copied to get the "example_batch" without disrupting the training loop. Please take a look.
MaxText/max_utils.py (Outdated)
import orbax.checkpoint as ocp

from tensorboardX import writer

from MaxText import max_logging
from MaxText.train import load_next_batch
We don't want max_utils to depend on train. Can we instead get a shaped version of the batch, similar to how it's done in train_compile here:
maxtext/MaxText/train_compile.py, line 102 (at 3296117):
shaped_batch = maxtext_utils.get_shaped_batch(config)
That makes sense. I took the maxtext_utils.get_shaped_batch(config) approach. Done.
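Combining the two threads above, a sketch of how the pre-loop memory report could be wired up; only maxtext_utils.get_shaped_batch(config) comes from the review itself, while the helper name, its arguments, and the use of PRNGKey(0) are assumptions.

```python
# Hypothetical wiring, called from train_loop just before the
# "for step in np.arange(start_step, config.steps):" loop; names are assumed.
import jax
from flax.linen import partitioning as nn_partitioning

from MaxText import maxtext_utils


def log_compiled_memory_stats(config, p_train_step, state, mesh):
  """Compile one train step on an abstract (shaped) batch and log its memory stats."""
  shaped_batch = maxtext_utils.get_shaped_batch(config)  # no data_iterator copy needed
  rng = jax.random.PRNGKey(0)  # any key works; only shapes/dtypes matter for compilation
  with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
    compiled = p_train_step.lower(state, shaped_batch, rng).compile()
  print_compiled_memory_stats(compiled.memory_analysis())  # helper sketched earlier
```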
These changes are included in #1904, which has been merged. Closing this PR.
Add total memory usage logging in GB, which is more readable and useful to track, especially for offloading.
Also add two unit tests for activation offloading, with and without scan.