From cd6aacdab95b3ff61a111727e821da49aff99b0b Mon Sep 17 00:00:00 2001 From: James Robinson Date: Mon, 6 Oct 2025 17:03:25 +0100 Subject: [PATCH 1/3] :wrench: Add scheduler config --- ice_station_zebra/config/train/default.yaml | 1 + ice_station_zebra/config/train/scheduler/default.yaml | 11 +++++++++++ 2 files changed, 12 insertions(+) create mode 100644 ice_station_zebra/config/train/scheduler/default.yaml diff --git a/ice_station_zebra/config/train/default.yaml b/ice_station_zebra/config/train/default.yaml index 02414daf..d3983207 100644 --- a/ice_station_zebra/config/train/default.yaml +++ b/ice_station_zebra/config/train/default.yaml @@ -3,5 +3,6 @@ defaults: - device_stats - ema_weight_averaging - optimizer: default + - scheduler: default - trainer: default - _self_ diff --git a/ice_station_zebra/config/train/scheduler/default.yaml b/ice_station_zebra/config/train/scheduler/default.yaml new file mode 100644 index 00000000..8de0dcfd --- /dev/null +++ b/ice_station_zebra/config/train/scheduler/default.yaml @@ -0,0 +1,11 @@ +# PyTorch Lightning scheduler settings +_target_: torch.optim.lr_scheduler.CosineAnnealingLR + +# These parameters are part of an `lr_scheduler_config` +frequency: 1 +interval: epoch + +# These parameters are passed to the scheduler instance +scheduler_parameters: + eta_min: 1e-6 + T_max: 10 From 7deab419a24a0ddbf6f32eb8bef8b18dda0bc90e Mon Sep 17 00:00:00 2001 From: James Robinson Date: Mon, 6 Oct 2025 17:04:18 +0100 Subject: [PATCH 2/3] :sparkles: Add support for scheduler initialised from config file --- ice_station_zebra/models/zebra_model.py | 37 +++++++++++++++++++++---- ice_station_zebra/training/trainer.py | 1 + 2 files changed, 33 insertions(+), 5 deletions(-) diff --git a/ice_station_zebra/models/zebra_model.py b/ice_station_zebra/models/zebra_model.py index 37cddba9..813c4e1f 100644 --- a/ice_station_zebra/models/zebra_model.py +++ b/ice_station_zebra/models/zebra_model.py @@ -4,7 +4,12 @@ import hydra import torch from lightning import LightningModule -from lightning.pytorch.utilities.types import OptimizerLRScheduler +from lightning.pytorch.utilities.types import ( + LRSchedulerConfigType, + OptimizerConfig, + OptimizerLRScheduler, + OptimizerLRSchedulerConfig, +) from omegaconf import DictConfig from ice_station_zebra.types import DataSpace, ModelTestOutput, TensorNTCHW @@ -13,7 +18,7 @@ class ZebraModel(LightningModule, ABC): """A base class for all models used in the Ice Station Zebra project.""" - def __init__( + def __init__( # noqa: PLR0913 self, *, name: str, @@ -22,6 +27,7 @@ def __init__( n_history_steps: int, output_space: DictConfig, optimizer: DictConfig, + scheduler: DictConfig, ) -> None: """Initialise a ZebraModel. 
@@ -49,8 +55,9 @@ def __init__( self.input_spaces = [DataSpace.from_dict(space) for space in input_spaces] self.output_space = DataSpace.from_dict(output_space) - # Store the optimizer config + # Store the optimizer and scheduler configs self.optimizer_cfg = optimizer + self.scheduler_cfg = scheduler # Save all of the arguments to __init__ as hyperparameters # This will also save the parameters of whichever child class is used @@ -58,8 +65,9 @@ def __init__( self.save_hyperparameters() def configure_optimizers(self) -> OptimizerLRScheduler: - """Construct the optimizer from the config.""" - return hydra.utils.instantiate( + """Construct the optimizer and optional scheduler from the config.""" + # Optimizer + optimizer = hydra.utils.instantiate( dict(**self.optimizer_cfg) | { "params": itertools.chain( @@ -67,6 +75,25 @@ def configure_optimizers(self) -> OptimizerLRScheduler: ) } ) + # If no scheduler config is provided, return just the optimizer + if not self.scheduler_cfg: + return OptimizerConfig(optimizer=optimizer) + + # Scheduler + scheduler_args = self.scheduler_cfg + scheduler = hydra.utils.instantiate( + { + "_target_": scheduler_args.pop("_target_"), + "optimizer": optimizer, + **scheduler_args.pop("scheduler_parameters", {}), + } + ) + + # Return the optimizer and scheduler + return OptimizerLRSchedulerConfig( + optimizer=optimizer, + lr_scheduler=LRSchedulerConfigType(scheduler=scheduler, **scheduler_args), + ) @abstractmethod def forward(self, inputs: dict[str, TensorNTCHW]) -> TensorNTCHW: diff --git a/ice_station_zebra/training/trainer.py b/ice_station_zebra/training/trainer.py index f7f5ed4c..ad7a3d7a 100644 --- a/ice_station_zebra/training/trainer.py +++ b/ice_station_zebra/training/trainer.py @@ -39,6 +39,7 @@ def __init__(self, config: DictConfig) -> None: "n_history_steps": self.data_module.n_history_steps, "output_space": self.data_module.output_space.to_dict(), "optimizer": config["train"]["optimizer"], + "scheduler": config["train"]["scheduler"], }, **config["model"], ), From c5a1f3c58a2398a881e91f9b86c2b8cbef8f51dc Mon Sep 17 00:00:00 2001 From: James Robinson Date: Mon, 6 Oct 2025 17:04:38 +0100 Subject: [PATCH 3/3] :white_check_mark: Update tests to support scheduler usage --- tests/conftest.py | 13 +++++++ tests/models/test_encode_process_decode.py | 2 ++ tests/models/test_persistence.py | 2 ++ tests/models/test_zebra_model.py | 41 +++++++++++++++++++++- 4 files changed, 57 insertions(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index fa9c7d5f..24998843 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -65,6 +65,19 @@ def cfg_processor() -> DictConfig: return DictConfig({"_target_": "ice_station_zebra.models.processors.NullProcessor"}) +@pytest.fixture +def cfg_scheduler() -> DictConfig: + """Test configuration for a scheduler.""" + return DictConfig( + { + "_target_": "torch.optim.lr_scheduler.LinearLR", + "frequency": 1, + "interval": "epoch", + "scheduler_parameters": {"start_factor": 0.2, "end_factor": 0.8}, + } + ) + + @pytest.fixture(scope="session") def mock_data() -> dict[str, dict[str, Any]]: """Fixture to create a mock dataset for testing.""" diff --git a/tests/models/test_encode_process_decode.py b/tests/models/test_encode_process_decode.py index f3dcabe6..d0a565ae 100644 --- a/tests/models/test_encode_process_decode.py +++ b/tests/models/test_encode_process_decode.py @@ -28,6 +28,7 @@ def test_init( n_history_steps=test_n_history_steps, output_space=cfg_output_space, optimizer=DictConfig({}), + scheduler=DictConfig({}), 
) assert model.name == "encode-null-decode" @@ -62,6 +63,7 @@ def test_forward( n_history_steps=test_n_history_steps, output_space=cfg_output_space, optimizer=DictConfig({}), + scheduler=DictConfig({}), ) result: torch.Tensor = model( { diff --git a/tests/models/test_persistence.py b/tests/models/test_persistence.py index 2bfe60ca..8f68fe44 100644 --- a/tests/models/test_persistence.py +++ b/tests/models/test_persistence.py @@ -35,6 +35,7 @@ def test_forward_shape( n_history_steps=test_n_history_steps, output_space=output_space, optimizer={}, + scheduler={}, ) batch = { "input": torch.randn( @@ -73,6 +74,7 @@ def test_optimizer(self) -> None: "shape": (1, 1), }, optimizer={}, + scheduler={}, ) assert model.configure_optimizers() is None, ( "No optimizer should be initialized" diff --git a/tests/models/test_zebra_model.py b/tests/models/test_zebra_model.py index 1b77b451..959629b5 100644 --- a/tests/models/test_zebra_model.py +++ b/tests/models/test_zebra_model.py @@ -63,6 +63,7 @@ def test_init( n_history_steps=test_n_history_steps, output_space=output_space, optimizer=DictConfig({}), + scheduler=DictConfig({}), ) return @@ -78,6 +79,7 @@ def test_init( n_history_steps=test_n_history_steps, output_space=output_space, optimizer=DictConfig({}), + scheduler=DictConfig({}), ) return @@ -88,6 +90,7 @@ def test_init( n_history_steps=test_n_history_steps, output_space=output_space, optimizer=DictConfig({}), + scheduler=DictConfig({}), ) assert model.name == "dummy" @@ -110,6 +113,7 @@ def test_loss( n_history_steps=1, output_space=cfg_output_space, optimizer=DictConfig({}), + scheduler=DictConfig({}), ) # Test loss prediction = torch.zeros(1, 1, 1, 1) @@ -129,17 +133,47 @@ def test_optimizer( n_history_steps=1, output_space=cfg_output_space, optimizer=cfg_optimizer, + scheduler=DictConfig({}), ) - optimizer = model.configure_optimizers() + opt_sched_cfg = model.configure_optimizers() + assert isinstance(opt_sched_cfg, dict) + optimizer = opt_sched_cfg.get("optimizer", None) assert isinstance(optimizer, torch.optim.Optimizer) assert isinstance(optimizer, torch.optim.AdamW) assert optimizer.defaults["lr"] == 5e-4 + def test_scheduler( + self, + cfg_input_space: DictConfig, + cfg_optimizer: DictConfig, + cfg_output_space: DictConfig, + cfg_scheduler: DictConfig, + ) -> None: + model = DummyModel( + name="dummy", + input_spaces=[cfg_input_space], + n_forecast_steps=1, + n_history_steps=1, + output_space=cfg_output_space, + optimizer=cfg_optimizer, + scheduler=cfg_scheduler, + ) + opt_sched_cfg = model.configure_optimizers() + assert isinstance(opt_sched_cfg, dict) + lr_scheduler_cfg = opt_sched_cfg.get("lr_scheduler", None) + assert isinstance(lr_scheduler_cfg, dict) + scheduler = lr_scheduler_cfg.get("scheduler", None) + assert isinstance(scheduler, torch.optim.lr_scheduler.LRScheduler) + assert isinstance(scheduler, torch.optim.lr_scheduler.LinearLR) + assert scheduler.start_factor == 0.2 + assert scheduler.end_factor == 0.8 + def test_test_step( self, cfg_input_space: DictConfig, cfg_output_space: DictConfig, cfg_optimizer: DictConfig, + cfg_scheduler: DictConfig, ) -> None: batch_size = n_history_steps = n_forecast_steps = 1 batch = { @@ -165,6 +199,7 @@ def test_test_step( n_history_steps=n_history_steps, output_space=cfg_output_space, optimizer=cfg_optimizer, + scheduler=cfg_scheduler, ) output_shape = batch["target"].shape output = model.test_step(batch, 0) @@ -178,6 +213,7 @@ def test_training_step( cfg_input_space: DictConfig, cfg_output_space: DictConfig, cfg_optimizer: DictConfig, + 
cfg_scheduler: DictConfig, ) -> None: batch_size = n_history_steps = n_forecast_steps = 1 batch = { @@ -203,6 +239,7 @@ def test_training_step( n_history_steps=n_history_steps, output_space=cfg_output_space, optimizer=cfg_optimizer, + scheduler=cfg_scheduler, ) output = model.training_step(batch, 0) assert isinstance(output, torch.Tensor) @@ -213,6 +250,7 @@ def test_validation_step( cfg_input_space: DictConfig, cfg_output_space: DictConfig, cfg_optimizer: DictConfig, + cfg_scheduler: DictConfig, ) -> None: batch_size = n_history_steps = n_forecast_steps = 1 batch = { @@ -238,6 +276,7 @@ def test_validation_step( n_history_steps=n_history_steps, output_space=cfg_output_space, optimizer=cfg_optimizer, + scheduler=cfg_scheduler, ) output = model.validation_step(batch, 0) assert isinstance(output, torch.Tensor)
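
Not part of the patch series above: a minimal sketch of the value the new ZebraModel.configure_optimizers is expected to return once a scheduler config like ice_station_zebra/config/train/scheduler/default.yaml is supplied. Plain torch calls stand in for hydra.utils.instantiate, the nn.Linear module is a hypothetical placeholder for a real model's parameters, and the optimizer settings simply mirror the AdamW/lr=5e-4 values asserted in test_optimizer.

import torch
from torch import nn

# Hypothetical stand-in for a ZebraModel's trainable parameters
module = nn.Linear(4, 4)

# Optimizer settings matching what test_optimizer asserts for its optimizer fixture
optimizer = torch.optim.AdamW(module.parameters(), lr=5e-4)

# Mirrors config/train/scheduler/default.yaml: keys under `scheduler_parameters`
# go to the scheduler constructor, while `interval` and `frequency` stay in the
# Lightning lr_scheduler config.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=1e-6)

# Shape of the OptimizerLRSchedulerConfig that configure_optimizers returns when a
# scheduler config is present; with an empty scheduler config it returns only
# {"optimizer": optimizer}.
opt_sched_cfg = {
    "optimizer": optimizer,
    "lr_scheduler": {"scheduler": scheduler, "interval": "epoch", "frequency": 1},
}

print(opt_sched_cfg["lr_scheduler"]["scheduler"].get_last_lr())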