Skip to content

Downscaling loss weighting#1056

Merged
AnnaKwa merged 6 commits into
mainfrom
feature/downsc-loss-weighting
Apr 20, 2026
Merged

Downscaling loss weighting#1056
AnnaKwa merged 6 commits into
mainfrom
feature/downsc-loss-weighting

Conversation

@AnnaKwa
Copy link
Copy Markdown
Contributor

@AnnaKwa AnnaKwa commented Apr 17, 2026

This PR adds two configurable options to loss weighting in the downscaling training config, contained in the optional field loss_weights. Both were found to improve the generated outputs.

  • loss_weights.output_channels: dict with key, value corresponding to output channel name and weight to multiply that variable's loss by. Output names not in this dict are not adjusted. Downweighting precipitation improves the overall outputs' CRPS and extreme values with no adverse effect on precip skill.
  • loss_weights.noise_weight_exponent: power to raise the default EDM noise-weighted loss by. By default 1.0. Experiments indicate lowering this to ~0.75 improves the generation of pressure and wind extremes; these fields tend to have their noise vs loss dominated by small noise values if using the default EDM weighting.

AnnaKwa added 2 commits April 16, 2026 15:21
Adds a `loss_weights` dict field to `DiffusionModelConfig` that applies
multiplicative per-variable scaling to the denoising loss. Variables not
listed default to a weight of 1.0. Includes a test confirming that
zeroing a variable's weight eliminates its gradient contribution.

Made-with: Cursor
@AnnaKwa AnnaKwa changed the title Feature/downsc loss weighting Downscaling loss weighting Apr 17, 2026
@AnnaKwa AnnaKwa marked this pull request as ready for review April 17, 2026 18:00
Copy link
Copy Markdown
Collaborator

@frodre frodre left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Thanks for the addition. Just a couple minor suggestions.

Comment thread fme/downscaling/models.py
output_channels: Per-variable multiplicative weights applied to the loss.
Keys are variable names from out_names; variables not listed default to 1.0.
noise_weight_exponent: Exponent applied to the EDM noise-level loss weight
``(sigma^2 + sigma_data^2) / (sigma * sigma_data)^2``. The default
Copy link
Copy Markdown
Collaborator

@frodre frodre Apr 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would add the value found helpful for variables dominated by the lower-noise losses for easy reference like you did docs for the noise code.

Comment thread fme/downscaling/models.py
def _build_variable_loss_weight_tensor(
weights: dict[str, float], out_names: list[str]
) -> torch.Tensor:
values = [weights.get(name, 1.0) for name in out_names]
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would consider putting something in here that errors if a requested loss weight name is not in out_names so that we get a fast turnaround on typos, etc.

@AnnaKwa AnnaKwa enabled auto-merge (squash) April 20, 2026 18:16
@AnnaKwa AnnaKwa merged commit 9986838 into main Apr 20, 2026
7 checks passed
@AnnaKwa AnnaKwa deleted the feature/downsc-loss-weighting branch April 20, 2026 18:31
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants