Skip to content

Remove static_input from data pipeline and model call signatures#956

Merged
frodre merged 12 commits intomainfrom
refactor/remove-static-input-from-data-and-call-sigs
Mar 17, 2026
Merged

Remove static_input from data pipeline and model call signatures#956
frodre merged 12 commits intomainfrom
refactor/remove-static-input-from-data-and-call-sigs

Conversation

@frodre
Copy link
Collaborator

@frodre frodre commented Mar 11, 2026

This PR finalizes the removal of the StaticInput handling by the data pipeline. The passing of static_input objects are removed from the data configuration, batch iteration, and model call signatures in favor of the direct model handling introduced in the previous downscaling PR (#954).

Changes:

  • add get_fine_coords_for_batch to facilitate translation of an input batch domain to output coordinates via the models stored information. For now, this relies on the model's static_inputs, but will be switched to model's stored coordinates in (Add fine coordinates to the model for easier inference handling #971)

  • inference Downscaler now takes the batch input_shape instead of static_inputs to check the domain size and model type (regular DiffusionModel or PatchPredictor

  • downscaling torch.datasets generators for BatchData no longer include StaticInputs

  • removed _apply_patch and _generate_from_patches from StaticInputs

  • config.py no longer references static inputs as an argument

  • Tests added

Base automatically changed from refactor/static-input-handled-by-model to main March 13, 2026 18:18
@frodre frodre force-pushed the refactor/remove-static-input-from-data-and-call-sigs branch from 83abe41 to af4e22d Compare March 13, 2026 22:19
for i, (batch, static_inputs) in enumerate(self.batch_generator):
aggregator: NoTargetAggregator | None = None
for i, batch in enumerate(self.batch_generator):
if aggregator is None:
Copy link
Collaborator Author

@frodre frodre Mar 16, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lazy initialization, like done with inference.py to wait for availability on the output domain information

@frodre frodre marked this pull request as ready for review March 16, 2026 22:15
@frodre frodre requested a review from AnnaKwa March 16, 2026 22:16
Copy link
Contributor

@AnnaKwa AnnaKwa left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, just a couple small suggestions

)
static_inputs_patches = static_inputs.generate_from_patches(fine_patches)
else:
static_inputs_patches = null_generator(len(coarse_patches))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can remove null_generator from utils now

)
return self.static_inputs.subset_latlon(lat_interval, lon_interval)

def get_fine_coords_for_batch(self, batch: BatchData) -> LatLonCoordinates:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggestion: take the ClosedInterval and LatLonCoordinates as args rather than the BatchData to remove the dependence on how BatchData stores this info (e.g. won't have to change this code if switching BatchData to store the coords as LatLonCoordinates instead of BatchedLatLonCoordinates)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm going to leave the single argument method rather than expand into 3 arguments since I think the consolidated passing works in our favor. Expanding would force the repeat of the access pattern that wherever the method is used (and then we'd have to make those updates anyways in multiple locations when we to LatLonCoordinates at the batch level).

@frodre frodre enabled auto-merge (squash) March 17, 2026 20:47
@frodre frodre merged commit d4828e4 into main Mar 17, 2026
7 checks passed
@frodre frodre deleted the refactor/remove-static-input-from-data-and-call-sigs branch March 17, 2026 21:01
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants