Downscaling model handles static input instead of DataLoader #954

Merged
frodre merged 8 commits into main from refactor/static-input-handled-by-model
Mar 13, 2026

Conversation

Collaborator

@frodre frodre commented Mar 10, 2026

Previously, the DataLoader was responsible for subsetting and device-placing static inputs (e.g. fine-resolution topography) to match each batch's spatial extent before passing them to the model. This moves that responsibility into DiffusionModel itself: the model stores the full-domain static inputs on construction and subsets them per-batch using the batch's coordinate metadata.

Changes:

  • DiffusionModel.__init__ now calls .to_device() on static inputs at construction

  • New DiffusionModel._subset_static_inputs encapsulates fine lat/lon subsetting for train_on_batch and generate_on_batch

  • generate_on_batch_no_target derives the fine coordinate interval from coarse batch coordinates via adjust_fine_coord_range, then subsets stored static inputs

  • BatchData gains lat_interval and lon_interval properties for use in subsetting

  • adjust_fine_coord_range now raises a clear ValueError when the coordinate range is too close to the domain boundary for the required number of fine points to exist (documents the implicit ±88° latitude requirement)

  • The static_inputs parameter on all public model methods is retained but ignored — removal is deferred to a follow-on PR

  • Tests added
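
A minimal sketch of the subsetting idea described above, in plain Python with no torch dependency. The names `Interval`, `StaticField`, `subset`, and `batch_interval` are hypothetical stand-ins, not the repository's `ClosedInterval`, `StaticInputs`, or `_subset_static_inputs`. The only things carried over from the PR are that the model keeps the full-domain field and that the batch's lat/lon intervals are built from the min and max of its coordinates.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Interval:
    """Hypothetical closed interval [start, stop]."""

    start: float
    stop: float

    def indices(self, coords: list[float]) -> list[int]:
        # Positions of the coordinates that fall inside the closed interval.
        return [i for i, c in enumerate(coords) if self.start <= c <= self.stop]


@dataclass
class StaticField:
    """Hypothetical full-domain static input, e.g. fine-resolution topography."""

    lat: list[float]
    lon: list[float]
    values: list[list[float]]  # values[i][j] is the value at (lat[i], lon[j])

    def subset(self, lat_interval: Interval, lon_interval: Interval) -> "StaticField":
        # Select the rows/columns whose coordinates lie inside the intervals.
        rows = lat_interval.indices(self.lat)
        cols = lon_interval.indices(self.lon)
        return StaticField(
            lat=[self.lat[i] for i in rows],
            lon=[self.lon[j] for j in cols],
            values=[[self.values[i][j] for j in cols] for i in rows],
        )


def batch_interval(coords: list[float]) -> Interval:
    # Same min/max construction as the lat_interval / lon_interval properties.
    return Interval(min(coords), max(coords))


# The "model" holds the full domain once; each batch only supplies coordinates.
full = StaticField(
    lat=[0.0, 1.0, 2.0, 3.0],
    lon=[10.0, 11.0, 12.0],
    values=[[float(10 * i + j) for j in range(3)] for i in range(4)],
)
patch = full.subset(batch_interval([1.0, 2.0]), batch_interval([11.0, 12.0]))
```

Here `patch` covers only the 2x2 region spanned by the batch coordinates, which is the behavior the data loader previously had to provide.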


@property
def lat_interval(self) -> ClosedInterval:
    lat = self.latlon_coordinates.lat[0]  # all batch members identical; use first
Collaborator Author

Given that we're no longer doing per-item random patch definitions, we could update BatchData to just use LatLonCoordinates

Base automatically changed from refactor/remove-topography-dataloader-pathway to main March 11, 2026 00:12
@frodre frodre force-pushed the refactor/static-input-handled-by-model branch from 555cb92 to f144258 Compare March 11, 2026 00:44
@frodre frodre marked this pull request as ready for review March 11, 2026 04:17
Contributor

@AnnaKwa AnnaKwa left a comment

I really like these changes; this is a really clean way to remove the static inputs from the data loading! Just to be super sure, can you add a test that the generated batches (when patching is used) have their lat/lon coords subset as expected? I think the current tests only check data values since we haven't used the coord info before.

     self,
     coarse_data: TensorMapping,
-    static_inputs: StaticInputs | None,
+    static_inputs: StaticInputs | None,  # TODO: remove in follow-on PR
Contributor

This should be kept since generate will be passed subsetted static_inputs from generate_on_batch or generate_on_batch_no_target, right?

    n_samples: int = 1,
) -> tuple[TensorDict, torch.Tensor, list[torch.Tensor]]:
    # static_inputs receives an internally-subsetted value from the calling method;
    # external callers should use generate_on_batch / generate_on_batch_no_target.
Contributor

Out of scope for this PR, but the only external caller is CascadePredictor.generate, which uses this instead of generate_on_batch_no_target because it was simpler to just pass the first model's output tensor instead of making a new BatchData object out of it to input to the next model. It would be worth adding a helper function to construct a new BatchData object out of generate's output; then we could make this method private.

Collaborator Author

@frodre frodre Mar 12, 2026

I mentioned this in #959, but I ran into the problem that inference needs knowledge of the output sizes, which forces some awkward handling within that module. I think this could be solved by passing a richer set of output information (e.g., coordinates) along with the generation/prediction. That would also allow for some smarter handling in CascadedModels.

generated, _, _ = self.generate(batch.data, static_inputs, n_samples)
# Ignore the passed static_inputs; derive the fine lat/lon interval from coarse
# batch coordinates via adjust_fine_coord_range, then subset self.static_inputs.
if self.config.use_fine_topography:
Contributor

Is this check necessary? Could this instead just check if self.static_inputs is None?

Since the previous PR removed the old option of loading HGTsfc from the fine dataset, this config option can be deprecated and downstream checks could be removed.

Collaborator Author

Yes, I agree and was thinking this and the DataRequirements.use_fine_topography would be handled in another PR to simplify things.
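
For readers following the diff comment above, here is a rough sketch of how a fine coordinate interval might be derived from a coarse coordinate range, including the boundary error the PR description mentions. This is not the repository's `adjust_fine_coord_range`: the function name `fine_interval_for_coarse`, its signature, and the assumption that each coarse cell center has `downscale_factor // 2` fine points on either side (an even downscale factor on a sorted grid) are all hypothetical.

```python
def fine_interval_for_coarse(
    coarse_min: float,
    coarse_max: float,
    fine_coords: list[float],
    downscale_factor: int,
) -> tuple[float, float]:
    """Hypothetical helper: widen a coarse-center range to the fine points covering it.

    Assumes fine_coords is sorted ascending and downscale_factor is even, so
    each coarse cell center has downscale_factor // 2 fine points on each side.
    """
    n_half = downscale_factor // 2
    below = [c for c in fine_coords if c < coarse_min]
    above = [c for c in fine_coords if c > coarse_max]
    if len(below) < n_half or len(above) < n_half:
        # The coarse range sits too close to the edge of the fine domain.
        raise ValueError(
            f"Coarse range [{coarse_min}, {coarse_max}] is too close to the domain "
            f"boundary: need {n_half} fine points on each side, "
            f"found {len(below)} below and {len(above)} above."
        )
    # Outermost fine points belonging to the first and last coarse cells.
    return below[-n_half], above[n_half - 1]


# Fine grid at 0.25, 0.75, ..., 3.75; coarse centers at 0.5, 1.5, 2.5, 3.5.
fine = [0.25 + 0.5 * k for k in range(8)]
interior = fine_interval_for_coarse(1.5, 2.5, fine, downscale_factor=2)
```

With these inputs, `interior` spans the four fine points that cover the two coarse cells. Calling the helper with a fine grid that stops short of the coarse edge raises the `ValueError` instead of silently returning too few points.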

def lon_interval(self) -> ClosedInterval:
    lon = self.latlon_coordinates.lon[0]  # all batch members identical; use first
    return ClosedInterval(lon.min().item(), lon.max().item())

Contributor

Could you add a test that the coordinates are as expected when BatchData.generate_from_patches is called? It looks like the existing tests for that usage only check data values.

Collaborator Author

Added tests for the data and coordinates for generate_from_patches under test_datasets.py. I did not adjust the tests in test_static.py since the patch generation will be removed in #956.


def __getitem__(self, k):
    return BatchItem(
        {key: value[k].squeeze() for key, value in self.data.items()},
Collaborator Author

Tests uncovered an error for patches of length 1 in the x/y dimension!
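
One plausible reading of this failure, assuming the argument-free `squeeze()` in `__getitem__` is the cause: `squeeze()` with no arguments drops every size-1 dimension, so a patch that is one point long in x or y loses that spatial dimension. The sketch below tracks shapes only, in plain Python; `squeeze_shape` and `squeezed_item_shape` are hypothetical helpers, and the fix actually applied in the PR is not shown in this conversation.

```python
def squeeze_shape(shape: tuple[int, ...]) -> tuple[int, ...]:
    # Shape after an argument-free squeeze: every size-1 dimension is removed.
    return tuple(d for d in shape if d != 1)


def squeezed_item_shape(batch_shape: tuple[int, ...]) -> tuple[int, ...]:
    # value[k] indexes away the leading batch dimension, then squeeze() runs.
    return squeeze_shape(batch_shape[1:])


# A regular 16x16 patch keeps both spatial dimensions.
regular = squeezed_item_shape((4, 16, 16))

# A patch of length 1 in y collapses from (1, 16) to (16,).
thin = squeezed_item_shape((4, 1, 16))
```

Downstream code that expects a 2D (lat, lon) item would then see a 1D tensor for the thin patch.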

@frodre frodre requested a review from AnnaKwa March 12, 2026 20:43
Contributor

@AnnaKwa AnnaKwa left a comment

LGTM, thanks!

@frodre frodre enabled auto-merge (squash) March 13, 2026 18:04
@frodre frodre merged commit a7d8555 into main Mar 13, 2026
7 checks passed
@frodre frodre deleted the refactor/static-input-handled-by-model branch March 13, 2026 18:18
frodre added a commit that referenced this pull request Mar 17, 2026
This PR finalizes the removal of the `StaticInput` handling by the data
pipeline. The passing of static_input objects is removed from the data
configuration, batch iteration, and model call signatures in favor of
the direct model handling introduced in the previous downscaling PR
(#954).

Changes:
- add `get_fine_coords_for_batch` to facilitate translation of an input
batch domain to output coordinates via the model's stored information.
For now, this relies on the model's `static_inputs`, but will be
switched to the model's stored coordinates in #971
- inference `Downscaler` now takes the batch `input_shape` instead of
`static_inputs` to check the domain size and model type (regular
`DiffusionModel` or `PatchPredictor`)
- downscaling `torch.datasets` generators for `BatchData` no longer
include `StaticInputs`
- removed `_apply_patch` and `_generate_from_patches` from
`StaticInputs`
- `config.py` no longer references static inputs as an argument

- [x] Tests added