Don't add unoptimized steps to computational graph in coupled training#1013

Merged
jpdunc23 merged 3 commits into main from refactor/coupled-no-grad-steps
Mar 26, 2026
Conversation

@jpdunc23
Member

@jpdunc23 jpdunc23 commented Mar 26, 2026

Avoid adding unoptimized steps (i.e., those where LossContributionsConfig settings result in a loss weight of 0) to the computational graph by computing those steps under torch.no_grad(). In a production job, this change resulted in a ~13% decrease in GPU memory utilization.

Changes:

  • Adds step_is_optimized() helper method to CoupledStepperTrainLoss which can be passed to CoupledStepper.get_prediction_generator() via its new argument of the same name.

  • Adds tests

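A minimal sketch of the idea described above, assuming a simplified rollout loop (the function name `rollout` and its parameters are illustrative, not the repository's actual API). Steps for which the `step_is_optimized` callable returns False are computed under `torch.no_grad()`, so their activations are never stored for backpropagation:

```python
from contextlib import nullcontext
from typing import Callable

import torch


def rollout(
    model: torch.nn.Module,
    state: torch.Tensor,
    n_steps: int,
    component: str,
    step_is_optimized: Callable[[str, int], bool],
) -> list[torch.Tensor]:
    """Generate n_steps predictions; steps for which step_is_optimized
    returns False run under torch.no_grad() and stay out of the graph."""
    predictions = []
    for step in range(n_steps):
        # nullcontext() keeps gradient tracking on for optimized steps;
        # torch.no_grad() frees intermediate activations for the others.
        ctx = nullcontext() if step_is_optimized(component, step) else torch.no_grad()
        with ctx:
            state = model(state)
        predictions.append(state)
    return predictions
```

Note that a no_grad step produces a detached tensor, so gradients at later optimized steps flow only through the model parameters used from that point onward.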
@jpdunc23 jpdunc23 changed the title from "Don't add unoptimized steps to computational graph" to "Don't add unoptimized steps to computational graph in coupled training" on Mar 26, 2026
@jpdunc23
Member Author

jpdunc23 commented Mar 26, 2026

[Slide showing the reduction in GPU memory utilization.]

@jpdunc23 jpdunc23 marked this pull request as ready for review March 26, 2026 16:02
Contributor

@mcgibbon mcgibbon left a comment


Comment: Not something you need to change, but noting the separation of responsibilities is different between the coupled code and ace. In ace, the TrainStepper is responsible for deciding/knowing which steps should be optimized, keeping the loss object a simpler "gets the loss on a particular step" object. Here the loss defines the loss on a series of steps in a window, though the way it's called to compute the loss is still by passing particular steps.

I think this leads to more coupling between the train stepper and the loss, but also, I can see the feeling that because the window of losses is more complicated in the coupled case, it's nice to pull it out into a level other than the stepper.

initial_condition: CoupledPrognosticState,
forcing_data: CoupledBatchData,
optimizer: OptimizationABC,
step_is_optimized: Callable[[str, int], bool] | None = None,
Contributor


Issue: It took me a while to understand what this was doing; at first I misread the code below and thought that this argument overrides a default implementation that calls self.step_is_optimized, but then I noticed there's no self.

Suggestion: I think the behavior would be clear and the logic below simpler if you made the default lambda n, c: True or something similar.

Contributor


Suggested change
step_is_optimized: Callable[[str, int], bool] | None = None,
step_is_optimized: Callable[[str, int], bool] = lambda n, c: True,

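The reviewer's suggestion can be sketched as follows (a hedged illustration, not the repository's code; the function name `get_step_context` and the example component name are hypothetical). Defaulting the callable to one that returns True means every step is treated as optimized unless the caller says otherwise, eliminating the None check at the call site:

```python
from contextlib import nullcontext
from typing import Callable

import torch


def get_step_context(
    name: str,
    step: int,
    step_is_optimized: Callable[[str, int], bool] = lambda n, c: True,
):
    """Return the autograd context for one step. With the default, every
    step is optimized, so no `if step_is_optimized is None` branch is needed."""
    return nullcontext() if step_is_optimized(name, step) else torch.no_grad()
```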
Contributor

@mcgibbon mcgibbon left a comment


Approving pending the line suggestion or something similar.

@jpdunc23
Member Author


Agreed. Will refactor as in the implementation in #868 when I get back to that PR.

@jpdunc23 jpdunc23 merged commit c905556 into main Mar 26, 2026
7 checks passed
@jpdunc23 jpdunc23 deleted the refactor/coupled-no-grad-steps branch March 26, 2026 18:36
