Add configurable global mean removal transform #1193
Extracts global mean removal from the normalizer into a separate transform that wraps SingleModuleStep, supporting shared-reference and per-channel modes with optional extra input channels. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…ed field names Remove dead from_state method on GlobalMeanRemovalConfig. Update docstrings to explain that output-only fields are intentionally un-shifted by inverse_transform (the network learns to compensate via end-to-end training). Log a warning when field_names entries appear in neither in_names nor out_names. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Move result tensors to CPU before comparing with CPU-created expected tensors, fixing failures on GPU CI. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…onfig ABC Corrector, ocean, and prescribed prognostics now run in physical space (after inverse_transform) when global_mean_removal is active, so they see un-shifted values. Remove the unused GlobalMeanRemovalConfig ABC since only the union type is used in type hints. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Moves forward_transform/inverse_transform into step_with_adjustments so corrector/ocean/prescribed adjustments stay in one place instead of being duplicated outside. Adds NoGlobalMeanRemoval null class to eliminate the forked code path in SingleModuleStep.step(). Fixes device mismatch in test_per_channel_masked_uses_zero by moving data_mask to the test device. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
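The null-object approach described in this commit can be sketched roughly as follows. This is an illustration only, not the PR's code: the names `NoGlobalMeanRemoval`, `forward_transform` and `inverse_transform` come from the commit message, while the method signatures, the dict-of-arrays layout, the `ConstantShift` class and the `step` function are hypothetical, and NumPy stands in for PyTorch tensors.

```python
import numpy as np


class NoGlobalMeanRemoval:
    """Null object: exposes the transform interface but changes nothing."""

    def forward_transform(self, data):
        return data

    def inverse_transform(self, data):
        return data


class ConstantShift:
    """Hypothetical stand-in for a real transform: subtracts a fixed offset."""

    def __init__(self, offset):
        self.offset = offset

    def forward_transform(self, data):
        return {name: value - self.offset for name, value in data.items()}

    def inverse_transform(self, data):
        return {name: value + self.offset for name, value in data.items()}


def step(data, transform):
    """Single code path: no `if transform is None` branch is required."""
    shifted = transform.forward_transform(data)
    # Placeholder for the network call (assumption: a simple doubling).
    predicted = {name: 2.0 * value for name, value in shifted.items()}
    return transform.inverse_transform(predicted)


data = {"t": np.array([[1.0, 3.0]])}
baseline = step(data, NoGlobalMeanRemoval())
shifted_run = step(data, ConstantShift(1.0))
```

Because the null class satisfies the same interface, the caller always invokes both transforms and the "feature off" case needs no special handling.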
The per-channel forward_transform was subtracting each field's per-sample spatial mean in physical space, leaving a post-normalization bias of -clim_mean/clim_std on every input pixel. For fields with significant climatology means (e.g. absolute temperatures) this fed the network large constant offsets and produced NaNs during training. Shift each field by clim_mean - sample_mean instead, mirroring the shared variant, so the post-normalization spatial mean is approximately zero. For masked samples the shift is zero (no forward or inverse shift), and the extra-channel formula becomes -shift/std (the anomaly for unmasked samples, zero for masked). Add regression tests that assert the post-normalization spatial mean is ~0 with realistic climatologies, including the masked case. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
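The bias described in this commit can be reproduced with a few lines of arithmetic. The sketch below assumes the normalizer computes `(x - clim_mean) / clim_std`; the climatology values, grid size and variable names are made up for illustration, and NumPy is used in place of PyTorch. It covers only the unmasked case.

```python
import numpy as np

rng = np.random.default_rng(0)

# Assumed climatology for an absolute-temperature field, in kelvin.
clim_mean, clim_std = 288.0, 15.0

# A batch of 4 samples on an 8x16 grid.
x = rng.normal(loc=290.0, scale=10.0, size=(4, 8, 16))
sample_mean = x.mean(axis=(1, 2), keepdims=True)


def normalize(field):
    return (field - clim_mean) / clim_std


# Old behaviour: subtract the raw per-sample mean in physical space.
old = normalize(x - sample_mean)

# Fixed behaviour: shift by clim_mean - sample_mean instead.
shift = clim_mean - sample_mean
new = normalize(x + shift)

# Post-normalization spatial means, one value per sample.
old_bias = old.mean(axis=(1, 2))  # -clim_mean / clim_std for every sample
new_bias = new.mean(axis=(1, 2))  # approximately zero for every sample

# Extra input channel from the commit message: -shift / std, which is the
# normalized anomaly of the sample mean relative to the climatology.
extra_channel = (-shift / clim_std).squeeze()
```

With these assumed values the old path leaves a constant offset of -288 / 15 = -19.2 on every normalized pixel, while the fixed path centres each sample near zero.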
The way this is implemented raises some questions about interaction with masked-input training. The way masking is implemented right now, the step requires surface temperature always be provided when this feature is on, so that it can compute the global-mean and use it to normalize the other features. I think it's also fairly clear we would want to be able to train while independently masking e.g. surface temperature and this new global-mean surface temperature input.
On the question of information leakage, we could later add a feature to noise this removed global-mean so it doesn't contain such reliable information about the global mean surface temperature.
We have two types of masking for inputs - missing data, and "we want to train this batch without this input". For missing data, the current behavior is basically correct - we can't train on a sample without surface temperature, unless we modify the scheme for the temperature removal to be less dependent on that field. For the "we want to train this batch without this input", we could maybe think about this as a dropout instead?
yyexela
left a comment
Overall looks great! Just some minor comments.
Links the unit-level value coverage in test_global_mean_removal.py to the full step by asserting that enabling global_mean_removal changes the output relative to the baseline with the same seed/weights/inputs. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
yyexela
left a comment
Issue: The new commit adds a test called test_step_global_mean_removal_affects_output that tests if the outputs of a full step with/without PerChannelGlobalMeanRemovalConfig differ. This test does not verify that SharedGlobalMeanRemovalConfig also produces different outputs.
Suggestion: Copy this test into two:
test_step_per_channel_global_mean_removal_affects_output
and
test_step_shared_global_mean_removal_affects_output
This then verifies that the transform is invoked during the step for both config types.
Question: Should |
Yes, I would say this is already covered by the backwards-compatibility inference tests.
Adds symmetric coverage for SharedGlobalMeanRemovalConfig alongside the existing PerChannelGlobalMeanRemovalConfig case via a shared helper. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
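A shared helper with two thin test functions, as suggested in the review, might look roughly like the sketch below. Only the two test names come from the review; `toy_step`, the two shift functions and the helper are hypothetical stand-ins for the real step and config classes, and the tanh "network" is a placeholder chosen so that shifting the input visibly changes the output.

```python
import numpy as np


def per_channel_shift(x):
    # Toy per-channel mode: each channel's own per-sample spatial mean.
    return x.mean(axis=(2, 3), keepdims=True)


def shared_shift(x):
    # Toy shared mode: channel 0 is the reference for every channel.
    return x[:, :1].mean(axis=(2, 3), keepdims=True)


def toy_step(x, shift_fn=None):
    """Hypothetical step: remove a mean, apply a nonlinear map, restore it."""
    shift = 0.0 if shift_fn is None else shift_fn(x)
    return np.tanh(x - shift) + shift


def assert_removal_affects_output(shift_fn):
    # Same seed and inputs for both runs, so only the transform differs.
    rng = np.random.default_rng(0)
    x = rng.normal(loc=1.0, size=(2, 3, 4, 8))
    baseline = toy_step(x)
    with_removal = toy_step(x, shift_fn)
    assert baseline.shape == with_removal.shape
    assert not np.allclose(baseline, with_removal)


def test_step_per_channel_global_mean_removal_affects_output():
    assert_removal_affects_output(per_channel_shift)


def test_step_shared_global_mean_removal_affects_output():
    assert_removal_affects_output(shared_shift)
```

Keeping the assertion logic in one helper means both config types are held to the same check, and a third mode would need only one more two-line test.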
Changes look good to me! ✅ I approve
Arcomano1234
left a comment
Left a question that should be addressed or answered, and a nit that I think should also be fixed.
"which is not supported for shared global mean removal."
)
ref = input[ref_name]
sample_mean = ref.mean(dim=tuple(range(1, ref.ndim)))
Question: When we do the stats calculation for variables we use the area-weighted mean, correct? I was curious whether you think this makes a difference, or whether we should use an area-weighted mean here just to match everything else we do.
I did think about doing it that way, but I don't think it will make a difference or be worth the added complexity.
Yeah, I don't think it should have a profound effect, but I think we should document it somewhere, as I think most if not all of our "global means" in ACE are area-weighted. So if in the docstrings we say something like "subtract the raw global mean", that should be good enough.
How about "cellwise" global mean?
I think our convention is generally to call the area-weighted global mean the area-weighted or weighted global mean, and I think cellwise global mean would be more verbose than needed in the variable/configuration names, but I think we could use it in the class names and docstrings.
Yeah, we don't need to be that verbose for the variable names / function names; I just want to make sure it's explicitly documented in the docstrings so we have a reference and so it's obvious to external users.
Claude: Added cellwise-vs-area-weighted note in docstrings on GlobalMeanRemoval (the primary explanation, with rationale: simpler, network compensates during end-to-end training), and propagated short references to the two concrete impls (SharedGlobalMeanRemoval, PerChannelGlobalMeanRemoval) and both Config dataclasses so the distinction is obvious to external users at the user-facing entry points. Pushed in 786b074.
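To make the cellwise versus area-weighted distinction from this thread concrete, here is a small comparison. It assumes a regular latitude-longitude grid with cell area proportional to cos(latitude) and uses a synthetic field; the grid size and values are illustrative only, and NumPy's `axis=` plays the role of the `dim=` argument in the diff.

```python
import numpy as np

# Assumed regular lat-lon grid; cell area is proportional to cos(latitude).
n_lat, n_lon = 18, 36
lat = np.linspace(-85.0, 85.0, n_lat)  # cell-centre latitudes in degrees
cos_lat = np.cos(np.deg2rad(lat))[:, None]
weights = cos_lat * np.ones((1, n_lon))

# Synthetic temperature-like field: warm at the equator, cold at the poles.
field = 250.0 + 50.0 * cos_lat * np.ones((1, n_lon))
batch = np.stack([field, field + 2.0])  # shape (batch, lat, lon)

# Cellwise mean, as in the diff: average over every non-batch dimension.
cellwise_mean = batch.mean(axis=tuple(range(1, batch.ndim)))

# Area-weighted mean: each cell contributes in proportion to its area.
area_weighted_mean = (batch * weights).sum(axis=(1, 2)) / weights.sum()

# Difference between the two definitions, one value per sample.
offset = area_weighted_mean - cellwise_mean
```

On this synthetic field the cellwise mean is lower, because the small polar cells count as much as the large equatorial ones, but the offset between the two definitions is the same for both samples, so sample-to-sample changes in the mean agree.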
The method was only used by tests; production code reads ``n_extra_input_channels`` from the built transform object. Drop it from both configs and refactor the one remaining test to use the build path. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Arcomano1234
left a comment
I still think we should document somewhere that we are using "cellwise" means for the subtraction, but other than that this looks good to go. That documentation should be added, but I don't need to re-review this, so I am approving the PR.
ACE conventionally uses area-weighted global means for stats and metrics; this transform deliberately uses the simpler cellwise (unweighted) mean and relies on end-to-end training to absorb the difference. Note this in the ABC and propagate to the concrete impl and config docstrings so the distinction is obvious to external users. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Add an optional global_mean_removal config to SingleModuleStepConfig (and SingleModuleStepperConfig) that removes per-sample global means from fields before normalization and restores them after denormalization. This lets the network operate on anomalies relative to the current global mean, which can improve generalization for temperature-like fields under climate drift.
Changes:
- fme.core.step.global_mean_removal: new module with SharedGlobalMeanRemoval (single reference field offset applied to a set of fields) and PerChannelGlobalMeanRemoval (each field's own mean removed independently). Both support optionally appending the removed mean as extra normalized input channels.
- fme.core.step.SingleModuleStepConfig: new optional global_mean_removal field; forward transform applied before normalization, inverse transform applied after denormalization
- fme.ace.stepper.SingleModuleStepperConfig: passes global_mean_removal through to SingleModuleStepConfig
- Tests added
- If dependencies changed, "deps only" image rebuilt and "latest_deps_only_image.txt" file updated
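The shared-reference mode summarised above can be sketched as a forward/inverse pair. This is a simplified illustration under stated assumptions, not the module's API: the function names, the returned `shift`, the field names and the climatology value are hypothetical, the shift formula `clim_mean - sample_mean` is taken from the per-channel fix commit (which says it mirrors the shared variant), and NumPy replaces PyTorch.

```python
import numpy as np


def shared_forward(fields, ref_name, clim_mean_ref):
    """Shift every field by the reference field's per-sample offset."""
    ref = fields[ref_name]
    # Cellwise (unweighted) mean over all non-batch dimensions.
    sample_mean = ref.mean(axis=tuple(range(1, ref.ndim)), keepdims=True)
    shift = clim_mean_ref - sample_mean
    shifted = {name: value + shift for name, value in fields.items()}
    return shifted, shift


def shared_inverse(fields, shift):
    """Undo the forward shift after denormalization."""
    return {name: value - shift for name, value in fields.items()}


rng = np.random.default_rng(0)
fields = {
    # Hypothetical field names and values, for illustration only.
    "surface_temperature": rng.normal(291.0, 8.0, size=(2, 4, 8)),
    "air_temperature_0": rng.normal(250.0, 8.0, size=(2, 4, 8)),
}

shifted, shift = shared_forward(fields, "surface_temperature", clim_mean_ref=288.0)
restored = shared_inverse(shifted, shift)
```

After the forward shift the reference field's per-sample mean equals the assumed climatological mean, every other field moves by the same offset, and the inverse recovers the original values.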