Downscaling model handles static input instead of DataLoader #954
Conversation
```python
@property
def lat_interval(self) -> ClosedInterval:
    lat = self.latlon_coordinates.lat[0]  # all batch members identical; use first
```
Given that we're no longer doing per-item random patch definitions, we could update BatchData to just use LatLonCoordinates
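For context, here is a minimal self-contained sketch of what these interval properties compute. NumPy arrays stand in for torch tensors, and `ClosedInterval` and `BatchCoords` are toy stand-ins, not the project's actual classes:

```python
from dataclasses import dataclass

import numpy as np


@dataclass(frozen=True)
class ClosedInterval:
    """Toy stand-in for the fme ClosedInterval container."""
    start: float
    end: float

    def __contains__(self, x: float) -> bool:
        return self.start <= x <= self.end


class BatchCoords:
    """Stand-in for BatchData's latlon coordinates: batched lat/lon arrays."""

    def __init__(self, lat: np.ndarray, lon: np.ndarray):
        self.lat = lat  # shape (batch, n_lat)
        self.lon = lon  # shape (batch, n_lon)

    @property
    def lat_interval(self) -> ClosedInterval:
        lat = self.lat[0]  # all batch members identical; use first
        return ClosedInterval(lat.min().item(), lat.max().item())


coords = BatchCoords(
    lat=np.tile(np.linspace(30.0, 40.0, 11), (4, 1)),
    lon=np.tile(np.linspace(230.0, 240.0, 11), (4, 1)),
)
print(coords.lat_interval)  # ClosedInterval(start=30.0, end=40.0)
```

The `lon_interval` property discussed later in this thread is the same pattern applied to the longitude axis.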
Force-pushed from 555cb92 to f144258
AnnaKwa left a comment:
I really like these changes; this is a really clean way to remove the static inputs from the data loading! Just to be super sure, can you add a test that the generated batches (when patching is used) have their lat/lon coords subset as expected? I think the current tests only check data values, since we haven't used the coord info before.
fme/downscaling/models.py (outdated diff)

```python
self,
coarse_data: TensorMapping,
static_inputs: StaticInputs | None,  # TODO: remove in follow-on PR
```
This should be kept, since `generate` will be passed subsetted `static_inputs` from `generate_on_batch` or `generate_on_batch_no_target`, right?
```python
n_samples: int = 1,
) -> tuple[TensorDict, torch.Tensor, list[torch.Tensor]]:
    # static_inputs receives an internally-subsetted value from the calling method;
    # external callers should use generate_on_batch / generate_on_batch_no_target.
```
Out of scope for this PR, but the only external caller is `CascadePredictor.generate`, which uses this instead of `generate_on_batch_no_target` because it was simpler to pass the first model's output tensor directly than to build a new `BatchData` object out of it for the next model's input. It would be worth adding a helper function to construct a new `BatchData` object from `generate`'s output; then we could make this method private.
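A rough sketch of the helper suggested here. The names (`SimpleBatch`, `batch_from_generated`) are purely illustrative, not the project's API; the point is just that re-wrapping `generate`'s output with coordinate metadata would let chained models go through the public batch-based entry points:

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class SimpleBatch:
    """Illustrative stand-in for BatchData: data plus coordinate metadata."""
    data: dict[str, Any]  # variable name -> generated array
    coords: Any           # fine-grid coordinates for the next model in the cascade


def batch_from_generated(generated: dict[str, Any], fine_coords: Any) -> SimpleBatch:
    # The key step: re-attach coordinate metadata so a downstream model can
    # derive its own lat/lon intervals instead of receiving raw tensors.
    return SimpleBatch(data=dict(generated), coords=fine_coords)
```

With such a helper, `CascadePredictor` could call `generate_on_batch_no_target` on each stage's rewrapped output, and the raw `generate` method could become private.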
I mentioned this in #959, but I ran into the problem that inference needs knowledge of the output sizes, which forces some awkward handling within that module. I think this could be solved by passing along a richer set of output information (e.g., coordinates) in the generation/prediction. That would also allow for some smarter handling in `CascadedModels` as well.
```python
generated, _, _ = self.generate(batch.data, static_inputs, n_samples)
# Ignore the passed static_inputs; derive the fine lat/lon interval from coarse
# batch coordinates via adjust_fine_coord_range, then subset self.static_inputs.
if self.config.use_fine_topography:
```
Is this check necessary? Could it instead just check whether `self.static_inputs` is None?
Since the previous PR removed the old option of loading `HGTsfc` from the fine dataset, this config option can be deprecated and the downstream checks removed.
Yes, I agree; I was thinking this and `DataRequirements.use_fine_topography` would be handled in another PR, to keep things simple.
```python
def lon_interval(self) -> ClosedInterval:
    lon = self.latlon_coordinates.lon[0]  # all batch members identical; use first
    return ClosedInterval(lon.min().item(), lon.max().item())
```
Could you add a test that the coordinates are as expected when `BatchData.generate_from_patches` is called? It looks like the existing tests for that usage only check data values.
Added tests for the data and coordinates for `generate_from_patches` under `test_datasets.py`. I did not adjust the tests in `test_static.py`, since the patch generation will be removed in #956.
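The shape of such a coordinate assertion might look like this toy example; `subset_coords` is a placeholder for the patch subsetting done inside the real pipeline, not the actual test code:

```python
import numpy as np


def subset_coords(full: np.ndarray, start: int, size: int) -> np.ndarray:
    """Placeholder for the coordinate subsetting a patch performs."""
    return full[start : start + size]


full_lat = np.linspace(-60.0, 60.0, 121)
patch_lat = subset_coords(full_lat, start=10, size=16)

# The requested test asserts coordinate values, not just data values:
assert patch_lat.shape == (16,)
assert patch_lat[0] == full_lat[10]
assert np.array_equal(patch_lat, full_lat[10:26])
```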
```python
def __getitem__(self, k):
    return BatchItem(
        {key: value[k].squeeze() for key, value in self.data.items()},
```
Tests uncovered an error for patches of length 1 in the x/y dimension!
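To see why a bare `.squeeze()` fails for length-1 patches: it drops *every* size-1 dimension, including spatial ones, not just the one you meant. NumPy is used here for the demonstration; `torch.Tensor.squeeze` has the same all-dims behavior:

```python
import numpy as np

value = np.zeros((4, 1, 8))   # (batch, y, x) with a length-1 y patch
item = value[0].squeeze()     # intended shape (1, 8) ...
print(item.shape)             # -> (8,): the size-1 patch dimension vanished too

# A safer pattern squeezes only the known batch axis:
safe = value[0:1].squeeze(axis=0)
print(safe.shape)             # -> (1, 8): spatial dims preserved
```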
This PR finalizes the removal of `StaticInput` handling from the data pipeline. The passing of static-input objects is removed from the data configuration, batch iteration, and model call signatures in favor of the direct model handling introduced in the previous downscaling PR (#954).

Changes:

- add `get_fine_coords_for_batch` to translate an input batch domain to output coordinates via the model's stored information. For now, this relies on the model's `static_inputs`, but it will be switched to the model's stored coordinates in #971
- inference `Downscaler` now takes the batch `input_shape` instead of `static_inputs` to check the domain size and model type (regular `DiffusionModel` or `PatchPredictor`)
- downscaling `torch.datasets` generators for `BatchData` no longer include `StaticInputs`
- removed `_apply_patch` and `_generate_from_patches` from `StaticInputs`
- `config.py` no longer references static inputs as an argument

- [x] Tests added
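Conceptually, `get_fine_coords_for_batch` maps a coarse batch's lat/lon extent onto the model's stored fine-grid coordinates. The sketch below is a guess at the idea only; the real signature and interval representation in the codebase differ:

```python
import numpy as np


def get_fine_coords_for_batch(fine_lat, fine_lon, lat_interval, lon_interval):
    """Select the stored fine-grid coords falling inside the batch's extent.

    Intervals are (min, max) tuples here for illustration.
    """
    lat_mask = (fine_lat >= lat_interval[0]) & (fine_lat <= lat_interval[1])
    lon_mask = (fine_lon >= lon_interval[0]) & (fine_lon <= lon_interval[1])
    return fine_lat[lat_mask], fine_lon[lon_mask]


# 0.25-degree fine grids built from integer multiples so values are exact
fine_lat = np.arange(-40, 41) * 0.25   # -10.0 .. 10.0
fine_lon = np.arange(400, 481) * 0.25  # 100.0 .. 120.0
lat, lon = get_fine_coords_for_batch(fine_lat, fine_lon, (-5.0, 5.0), (105.0, 115.0))
print(lat[0], lat[-1])  # -5.0 5.0
```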
Previously, the `DataLoader` was responsible for subsetting and device-placing static inputs (e.g. fine-resolution topography) to match each batch's spatial extent before passing them to the model. This moves that responsibility into `DiffusionModel` itself: the model stores the full-domain static inputs on construction and subsets them per-batch using the batch's coordinate metadata.

Changes:

- `DiffusionModel.__init__` now calls `.to_device()` on static inputs at construction
- New `DiffusionModel._subset_static_inputs` encapsulates fine lat/lon subsetting for `train_on_batch` and `generate_on_batch`
- `generate_on_batch_no_target` derives the fine coordinate interval from coarse batch coordinates via `adjust_fine_coord_range`, then subsets stored static inputs
- `BatchData` gains `lat_interval` and `lon_interval` properties for use in subsetting
- `adjust_fine_coord_range` now raises a clear `ValueError` when the coordinate range is too close to the domain boundary for the required number of fine points to exist (documents the implicit ±88° latitude requirement)
- The `static_inputs` parameter on all public model methods is retained but ignored; removal is deferred to a follow-on PR
- Tests added
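A minimal sketch of the boundary failure mode described for `adjust_fine_coord_range`. The real function's signature and windowing logic will differ; this only illustrates raising a clear `ValueError` when the requested window cannot fit near the domain edge:

```python
import numpy as np


def adjust_fine_coord_range(fine_coord: np.ndarray, center: float, n_fine: int) -> np.ndarray:
    """Return a window of n_fine points of fine_coord nearest to center.

    Raises ValueError when the window would fall off the domain boundary
    (illustrating the implicit +/-88 degree latitude requirement).
    """
    i = int(np.abs(fine_coord - center).argmin())
    start = i - n_fine // 2
    if start < 0 or start + n_fine > fine_coord.size:
        raise ValueError(
            f"coordinate range around {center} is too close to the domain "
            f"boundary for {n_fine} fine points to exist"
        )
    return fine_coord[start : start + n_fine]


lat = np.arange(-360, 361) * 0.25  # global latitude grid, -90 .. 90 at 0.25 deg
window = adjust_fine_coord_range(lat, 0.0, 16)
print(window.size)  # 16
# adjust_fine_coord_range(lat, 89.5, 16) would raise ValueError
```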