Fix/ptl version 200 (unit8co#1651)
dennisbader authored and alexcolpitts96 committed May 31, 2023
1 parent 7050e6e commit db13a35
Showing 4 changed files with 83 additions and 24 deletions.
2 changes: 1 addition & 1 deletion darts/models/forecasting/pl_forecasting_module.py
@@ -21,7 +21,7 @@

# Check whether we are running pytorch-lightning >= 1.6.0 or not:
tokens = pl.__version__.split(".")
-pl_160_or_above = int(tokens[0]) >= 1 and int(tokens[1]) >= 6
+pl_160_or_above = int(tokens[0]) > 1 or int(tokens[0]) == 1 and int(tokens[1]) >= 6


class PLForecastingModule(pl.LightningModule, ABC):
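Why the old check misclassified pytorch-lightning 2.x: with a version string such as "2.0.2", tokens[1] is "0", so the old expression evaluated to False even though 2.x is newer than 1.6. The fixed expression relies on "and" binding tighter than "or", so it reads as (major > 1) or (major == 1 and minor >= 6). A minimal, illustration-only sketch (not part of the commit) comparing both checks:

# Illustration only: compare the old and the fixed version check on a few version strings.
for version in ["1.5.10", "1.6.0", "2.0.2"]:
    tokens = version.split(".")
    old_check = int(tokens[0]) >= 1 and int(tokens[1]) >= 6
    new_check = int(tokens[0]) > 1 or int(tokens[0]) == 1 and int(tokens[1]) >= 6
    print(version, old_check, new_check)
# 1.5.10 -> False, False; 1.6.0 -> True, True; 2.0.2 -> False under the old check, True under the fixed one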
41 changes: 35 additions & 6 deletions darts/models/forecasting/torch_forecasting_model.py
@@ -20,6 +20,7 @@
import datetime
import inspect
import os
+import re
import shutil
import sys
from abc import ABC, abstractmethod
@@ -85,6 +86,10 @@

logger = get_logger(__name__)

+# Check whether we are running pytorch-lightning >= 2.0.0 or not:
+tokens = pl.__version__.split(".")
+pl_200_or_above = int(tokens[0]) >= 2


def _get_checkpoint_folder(work_dir, model_name):
    return os.path.join(work_dir, model_name, CHECKPOINTS_FOLDER)
@@ -427,25 +432,49 @@ def _init_model(self, trainer: Optional[pl.Trainer] = None) -> None:
        dtype = self.train_sample[0].dtype
        if np.issubdtype(dtype, np.float32):
            logger.info("Time series values are 32-bits; casting model to float32.")
-           precision = 32
+           precision = "32" if not pl_200_or_above else "32-true"
        elif np.issubdtype(dtype, np.float64):
            logger.info("Time series values are 64-bits; casting model to float64.")
-           precision = 64
+           precision = "64" if not pl_200_or_above else "64-true"
        else:
            raise_log(
                ValueError(
                    f"Invalid time series data type `{dtype}`. Cast your data to `np.float32` "
                    f"or `np.float64`, e.g. with `TimeSeries.astype(np.float32)`."
                ),
                logger,
            )
+       precision_int = int(re.findall(r"\d+", str(precision))[0])

        precision_user = (
            self.trainer_params.get("precision", None)
            if trainer is None
            else trainer.precision
        )
+       if precision_user is not None:
+           # currently, we only support float 64 and 32
+           valid_precisions = (
+               ["64", "32"] if not pl_200_or_above else ["64-true", "32-true"]
+           )
+           if str(precision_user) not in valid_precisions:
+               raise_log(
+                   ValueError(
+                       f"Invalid user-defined trainer_kwarg `precision={precision_user}`. "
+                       f"Use one of ({valid_precisions})"
+                   ),
+                   logger,
+               )
+           precision_user_int = int(re.findall(r"\d+", str(precision_user))[0])
+       else:
+           precision_user_int = None

        raise_if(
-           precision_user is not None and int(precision_user) != precision,
-           f"User-defined trainer_kwarg `precision={precision_user}` does not match dtype: `{dtype}` of the "
+           precision_user is not None and precision_user_int != precision_int,
+           f"User-defined trainer_kwarg `precision='{precision_user}'` does not match dtype: `{dtype}` of the "
            f"underlying TimeSeries. Set `precision` to `{precision}` or cast your data to `{precision_user}"
-           f"` with `TimeSeries.astype(np.float{precision_user})`.",
+           f"` with `TimeSeries.astype(np.float{precision_user_int})`.",
            logger,
        )

        self.trainer_params["precision"] = precision

        # we need to save the initialized TorchForecastingModel as PyTorch-Lightning only saves module checkpoints
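A side note on the precision flags (not part of the commit): pytorch-lightning 2.0 expects the full-precision settings to be spelled "32-true" / "64-true", while earlier versions accept plain 32 / "32" and 64 / "64". That is why both spellings appear above and why the numeric bit width is recovered with a regular expression. An illustration-only sketch of that extraction, using the same re.findall call as the code above:

# Illustration only: recover the bit width from either the old- or the new-style precision flag.
import re

for flag in [32, "32", "64-true", "32-true"]:
    bits = int(re.findall(r"\d+", str(flag))[0])
    print(flag, "->", bits)  # 32 -> 32, "32" -> 32, "64-true" -> 64, "32-true" -> 32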
@@ -127,7 +127,7 @@
        {
            "input_chunk_length": 10,
            "output_chunk_length": 5,
-           "n_epochs": 5,
+           "n_epochs": 10,
            "random_state": 0,
            "likelihood": GaussianLikelihood(),
        },
62 changes: 46 additions & 16 deletions darts/tests/models/forecasting/test_ptl_trainer.py
@@ -99,10 +99,9 @@ def test_custom_trainer_setup(self):
        self.assertEqual(trainer.max_epochs, model.epochs_trained)

    def test_builtin_extended_trainer(self):
-       invalid_trainer_kwarg = {"precisionn": 32}
-
-       # error will be raised at training time
+       # wrong precision parameter name
        with self.assertRaises(TypeError):
+           invalid_trainer_kwarg = {"precisionn": "32-true"}
            model = RNNModel(
                12,
                "RNN",
@@ -113,20 +112,51 @@ def test_builtin_extended_trainer(self):
            )
            model.fit(self.series, epochs=1)

-       valid_trainer_kwargs = {
-           "precision": 32,
-       }
-
-       # valid parameters shouldn't raise error
-       model = RNNModel(
-           12,
-           "RNN",
-           10,
-           10,
-           random_state=42,
-           pl_trainer_kwargs=valid_trainer_kwargs,
-       )
-       model.fit(self.series, epochs=1)
+       # float16 not supported
+       with self.assertRaises(ValueError):
+           invalid_trainer_kwarg = {"precision": "16-mixed"}
+           model = RNNModel(
+               12,
+               "RNN",
+               10,
+               10,
+               random_state=42,
+               pl_trainer_kwargs=invalid_trainer_kwarg,
+           )
+           model.fit(self.series.astype(np.float16), epochs=1)
+
+       # precision value doesn't match `series` dtype
+       with self.assertRaises(ValueError):
+           invalid_trainer_kwarg = {"precision": "64-true"}
+           model = RNNModel(
+               12,
+               "RNN",
+               10,
+               10,
+               random_state=42,
+               pl_trainer_kwargs=invalid_trainer_kwarg,
+           )
+           model.fit(self.series.astype(np.float32), epochs=1)
+
+       for precision, precision_int in zip(["64-true", "32-true"], [64, 32]):
+           valid_trainer_kwargs = {
+               "precision": precision,
+           }
+
+           # valid parameters shouldn't raise error
+           model = RNNModel(
+               12,
+               "RNN",
+               10,
+               10,
+               random_state=42,
+               pl_trainer_kwargs=valid_trainer_kwargs,
+           )
+           ts_dtype = getattr(np, f"float{precision_int}")
+           model.fit(self.series.astype(ts_dtype), epochs=1)
+           preds = model.predict(n=3)
+           assert model.trainer.precision == precision
+           assert preds.dtype == ts_dtype

    def test_custom_callback(self):
        class CounterCallback(pl.callbacks.Callback):
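To try the behavior these tests cover outside the test suite, here is a minimal illustration-only sketch (assuming darts with pytorch-lightning >= 2.0 installed; AirPassengersDataset is just a convenient example series and the RNNModel hyperparameters are arbitrary):

# Illustration only: the precision passed via pl_trainer_kwargs must match the series dtype.
import numpy as np
from darts.datasets import AirPassengersDataset
from darts.models import RNNModel

series = AirPassengersDataset().load().astype(np.float32)
model = RNNModel(
    input_chunk_length=12,
    model="RNN",
    hidden_dim=10,
    n_rnn_layers=1,
    random_state=42,
    pl_trainer_kwargs={"precision": "32-true"},  # "64-true" would raise, since the series is float32
)
model.fit(series, epochs=1)
preds = model.predict(n=3)
print(preds.dtype)  # float32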
