1 change: 1 addition & 0 deletions ice_station_zebra/config/train/default.yaml
@@ -3,5 +3,6 @@ defaults:
   - device_stats
   - ema_weight_averaging
   - optimizer: default
+  - scheduler: default
   - trainer: default
   - _self_
11 changes: 11 additions & 0 deletions ice_station_zebra/config/train/scheduler/default.yaml
@@ -0,0 +1,11 @@
+# PyTorch Lightning scheduler settings
+_target_: torch.optim.lr_scheduler.CosineAnnealingLR
+
+# These parameters are part of an `lr_scheduler_config`
+frequency: 1
+interval: epoch
+
+# These parameters are passed to the scheduler instance
+scheduler_parameters:
+  eta_min: 1e-6
+  T_max: 10
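
Note (illustrative, not part of the diff): the keys in this file end up in two different places. `frequency` and `interval` belong to Lightning's `lr_scheduler_config`, while everything under `scheduler_parameters` is forwarded to the scheduler constructor. A minimal Python sketch of that split, assuming only that `torch` and `hydra-core` are installed; the stand-in optimizer is hypothetical:

    import hydra
    import torch
    from torch import nn

    # Hypothetical stand-in optimizer; in the project it comes from the optimizer config.
    optimizer = torch.optim.AdamW(nn.Linear(4, 2).parameters(), lr=5e-4)

    # Keys under `scheduler_parameters` become constructor arguments for the target ...
    scheduler = hydra.utils.instantiate(
        {
            "_target_": "torch.optim.lr_scheduler.CosineAnnealingLR",
            "optimizer": optimizer,
            "eta_min": 1e-6,
            "T_max": 10,
        }
    )

    # ... while `interval` and `frequency` only tell Lightning when to step it.
    lr_scheduler_config = {"scheduler": scheduler, "interval": "epoch", "frequency": 1}
    print(type(scheduler).__name__, scheduler.eta_min, scheduler.T_max)
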
37 changes: 32 additions & 5 deletions ice_station_zebra/models/zebra_model.py
@@ -4,7 +4,12 @@
 import hydra
 import torch
 from lightning import LightningModule
-from lightning.pytorch.utilities.types import OptimizerLRScheduler
+from lightning.pytorch.utilities.types import (
+    LRSchedulerConfigType,
+    OptimizerConfig,
+    OptimizerLRScheduler,
+    OptimizerLRSchedulerConfig,
+)
 from omegaconf import DictConfig

 from ice_station_zebra.types import DataSpace, ModelTestOutput, TensorNTCHW
@@ -13,7 +18,7 @@
 class ZebraModel(LightningModule, ABC):
     """A base class for all models used in the Ice Station Zebra project."""

-    def __init__(
+    def __init__(  # noqa: PLR0913
         self,
         *,
         name: str,
@@ -22,6 +27,7 @@ def __init__(
         n_history_steps: int,
         output_space: DictConfig,
         optimizer: DictConfig,
+        scheduler: DictConfig,
     ) -> None:
         """Initialise a ZebraModel.

@@ -49,24 +55,45 @@ def __init__(
         self.input_spaces = [DataSpace.from_dict(space) for space in input_spaces]
         self.output_space = DataSpace.from_dict(output_space)

-        # Store the optimizer config
+        # Store the optimizer and scheduler configs
         self.optimizer_cfg = optimizer
+        self.scheduler_cfg = scheduler

         # Save all of the arguments to __init__ as hyperparameters
         # This will also save the parameters of whichever child class is used
         # Note that W&B will log all hyperparameters
         self.save_hyperparameters()

     def configure_optimizers(self) -> OptimizerLRScheduler:
-        """Construct the optimizer from the config."""
-        return hydra.utils.instantiate(
+        """Construct the optimizer and optional scheduler from the config."""
+        # Optimizer
+        optimizer = hydra.utils.instantiate(
             dict(**self.optimizer_cfg)
             | {
                 "params": itertools.chain(
                     *[module.parameters() for module in self.children()]
                 )
             }
         )
+        # If no scheduler config is provided, return just the optimizer
+        if not self.scheduler_cfg:
+            return OptimizerConfig(optimizer=optimizer)
+
+        # Scheduler
+        scheduler_args = self.scheduler_cfg
+        scheduler = hydra.utils.instantiate(
+            {
+                "_target_": scheduler_args.pop("_target_"),
+                "optimizer": optimizer,
+                **scheduler_args.pop("scheduler_parameters", {}),
+            }
+        )
+
+        # Return the optimizer and scheduler
+        return OptimizerLRSchedulerConfig(
+            optimizer=optimizer,
+            lr_scheduler=LRSchedulerConfigType(scheduler=scheduler, **scheduler_args),
+        )

     @abstractmethod
     def forward(self, inputs: dict[str, TensorNTCHW]) -> TensorNTCHW:
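
For context (a sketch, not code from this PR): `configure_optimizers` now returns either an optimizer-only config or an optimizer-plus-scheduler config, both of which Lightning consumes as plain mappings. Written out without Hydra, and assuming the AdamW/lr=5e-4 optimizer default implied by the tests below:

    import torch
    from torch import nn

    model = nn.Linear(4, 2)  # stand-in for the model's child modules
    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=1e-6)

    # With a scheduler config: the OptimizerLRSchedulerConfig / LRSchedulerConfigType shape
    with_scheduler = {
        "optimizer": optimizer,
        "lr_scheduler": {"scheduler": scheduler, "interval": "epoch", "frequency": 1},
    }

    # Without a scheduler config: the OptimizerConfig fallback
    optimizer_only = {"optimizer": optimizer}
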
1 change: 1 addition & 0 deletions ice_station_zebra/training/trainer.py
@@ -39,6 +39,7 @@ def __init__(self, config: DictConfig) -> None:
                     "n_history_steps": self.data_module.n_history_steps,
                     "output_space": self.data_module.output_space.to_dict(),
                     "optimizer": config["train"]["optimizer"],
+                    "scheduler": config["train"]["scheduler"],
                 },
                 **config["model"],
             ),
13 changes: 13 additions & 0 deletions tests/conftest.py
@@ -65,6 +65,19 @@ def cfg_processor() -> DictConfig:
     return DictConfig({"_target_": "ice_station_zebra.models.processors.NullProcessor"})


+@pytest.fixture
+def cfg_scheduler() -> DictConfig:
+    """Test configuration for a scheduler."""
+    return DictConfig(
+        {
+            "_target_": "torch.optim.lr_scheduler.LinearLR",
+            "frequency": 1,
+            "interval": "epoch",
+            "scheduler_parameters": {"start_factor": 0.2, "end_factor": 0.8},
+        }
+    )
+
+
 @pytest.fixture(scope="session")
 def mock_data() -> dict[str, dict[str, Any]]:
     """Fixture to create a mock dataset for testing."""
2 changes: 2 additions & 0 deletions tests/models/test_encode_process_decode.py
@@ -28,6 +28,7 @@ def test_init(
             n_history_steps=test_n_history_steps,
             output_space=cfg_output_space,
             optimizer=DictConfig({}),
+            scheduler=DictConfig({}),
         )

         assert model.name == "encode-null-decode"
@@ -62,6 +63,7 @@ def test_forward(
             n_history_steps=test_n_history_steps,
             output_space=cfg_output_space,
             optimizer=DictConfig({}),
+            scheduler=DictConfig({}),
         )
         result: torch.Tensor = model(
             {
2 changes: 2 additions & 0 deletions tests/models/test_persistence.py
@@ -35,6 +35,7 @@ def test_forward_shape(
             n_history_steps=test_n_history_steps,
             output_space=output_space,
             optimizer={},
+            scheduler={},
         )
         batch = {
             "input": torch.randn(
@@ -73,6 +74,7 @@ def test_optimizer(self) -> None:
                 "shape": (1, 1),
             },
             optimizer={},
+            scheduler={},
         )
         assert model.configure_optimizers() is None, (
             "No optimizer should be initialized"
41 changes: 40 additions & 1 deletion tests/models/test_zebra_model.py
@@ -63,6 +63,7 @@ def test_init(
                 n_history_steps=test_n_history_steps,
                 output_space=output_space,
                 optimizer=DictConfig({}),
+                scheduler=DictConfig({}),
             )
             return

@@ -78,6 +79,7 @@ def test_init(
                 n_history_steps=test_n_history_steps,
                 output_space=output_space,
                 optimizer=DictConfig({}),
+                scheduler=DictConfig({}),
             )
             return

@@ -88,6 +90,7 @@ def test_init(
             n_history_steps=test_n_history_steps,
             output_space=output_space,
             optimizer=DictConfig({}),
+            scheduler=DictConfig({}),
         )

         assert model.name == "dummy"
@@ -110,6 +113,7 @@ def test_loss(
             n_history_steps=1,
             output_space=cfg_output_space,
             optimizer=DictConfig({}),
+            scheduler=DictConfig({}),
         )
         # Test loss
         prediction = torch.zeros(1, 1, 1, 1)
@@ -129,17 +133,47 @@ def test_optimizer(
             n_history_steps=1,
             output_space=cfg_output_space,
             optimizer=cfg_optimizer,
+            scheduler=DictConfig({}),
         )
-        optimizer = model.configure_optimizers()
+        opt_sched_cfg = model.configure_optimizers()
+        assert isinstance(opt_sched_cfg, dict)
+        optimizer = opt_sched_cfg.get("optimizer", None)
         assert isinstance(optimizer, torch.optim.Optimizer)
         assert isinstance(optimizer, torch.optim.AdamW)
         assert optimizer.defaults["lr"] == 5e-4

+    def test_scheduler(
+        self,
+        cfg_input_space: DictConfig,
+        cfg_optimizer: DictConfig,
+        cfg_output_space: DictConfig,
+        cfg_scheduler: DictConfig,
+    ) -> None:
+        model = DummyModel(
+            name="dummy",
+            input_spaces=[cfg_input_space],
+            n_forecast_steps=1,
+            n_history_steps=1,
+            output_space=cfg_output_space,
+            optimizer=cfg_optimizer,
+            scheduler=cfg_scheduler,
+        )
+        opt_sched_cfg = model.configure_optimizers()
+        assert isinstance(opt_sched_cfg, dict)
+        lr_scheduler_cfg = opt_sched_cfg.get("lr_scheduler", None)
+        assert isinstance(lr_scheduler_cfg, dict)
+        scheduler = lr_scheduler_cfg.get("scheduler", None)
+        assert isinstance(scheduler, torch.optim.lr_scheduler.LRScheduler)
+        assert isinstance(scheduler, torch.optim.lr_scheduler.LinearLR)
+        assert scheduler.start_factor == 0.2
+        assert scheduler.end_factor == 0.8
+
     def test_test_step(
         self,
         cfg_input_space: DictConfig,
         cfg_output_space: DictConfig,
         cfg_optimizer: DictConfig,
+        cfg_scheduler: DictConfig,
     ) -> None:
         batch_size = n_history_steps = n_forecast_steps = 1
         batch = {
@@ -165,6 +199,7 @@ def test_test_step(
             n_history_steps=n_history_steps,
             output_space=cfg_output_space,
             optimizer=cfg_optimizer,
+            scheduler=cfg_scheduler,
         )
         output_shape = batch["target"].shape
         output = model.test_step(batch, 0)
@@ -178,6 +213,7 @@ def test_training_step(
         cfg_input_space: DictConfig,
         cfg_output_space: DictConfig,
         cfg_optimizer: DictConfig,
+        cfg_scheduler: DictConfig,
     ) -> None:
         batch_size = n_history_steps = n_forecast_steps = 1
         batch = {
@@ -203,6 +239,7 @@
             n_history_steps=n_history_steps,
             output_space=cfg_output_space,
             optimizer=cfg_optimizer,
+            scheduler=cfg_scheduler,
         )
         output = model.training_step(batch, 0)
         assert isinstance(output, torch.Tensor)
@@ -213,6 +250,7 @@ def test_validation_step(
         cfg_input_space: DictConfig,
         cfg_output_space: DictConfig,
         cfg_optimizer: DictConfig,
+        cfg_scheduler: DictConfig,
     ) -> None:
         batch_size = n_history_steps = n_forecast_steps = 1
         batch = {
@@ -238,6 +276,7 @@
             n_history_steps=n_history_steps,
             output_space=cfg_output_space,
             optimizer=cfg_optimizer,
+            scheduler=cfg_scheduler,
         )
         output = model.validation_step(batch, 0)
         assert isinstance(output, torch.Tensor)
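
A final illustrative note (not part of the PR): `interval: epoch` with `frequency: 1` means Lightning steps the scheduler once after each training epoch. A rough stand-alone sketch of that cadence, using the LinearLR values from the `cfg_scheduler` fixture; the toy model and data are hypothetical:

    import torch
    from torch import nn

    model = nn.Linear(4, 2)
    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4)
    scheduler = torch.optim.lr_scheduler.LinearLR(
        optimizer, start_factor=0.2, end_factor=0.8
    )

    data = [(torch.randn(8, 4), torch.randn(8, 2)) for _ in range(4)]
    for epoch in range(3):
        for x, y in data:
            optimizer.zero_grad()
            loss = nn.functional.mse_loss(model(x), y)
            loss.backward()
            optimizer.step()
        # interval="epoch", frequency=1 -> one scheduler step per epoch
        scheduler.step()
        print(epoch, scheduler.get_last_lr())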