From 35e96b9cadbbf17c9c49f8bf5fac35ca4ed10dfd Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Wed, 23 Feb 2022 22:11:50 +0100 Subject: [PATCH 1/3] add LinearRegression class --- mesmer/core/linear_regression.py | 151 +++++++++++++++++++- mesmer/core/utils.py | 103 ++++++++++--- tests/integration/test_linear_regression.py | 120 +++++++++++++--- tests/integration/test_utils.py | 75 +++++++++- 4 files changed, 404 insertions(+), 45 deletions(-) diff --git a/mesmer/core/linear_regression.py b/mesmer/core/linear_regression.py index 57f40da6..2547c456 100644 --- a/mesmer/core/linear_regression.py +++ b/mesmer/core/linear_regression.py @@ -2,9 +2,148 @@ import numpy as np import xarray as xr -from sklearn.linear_model import LinearRegression -from .utils import _check_dataarray_form +from .utils import _check_dataarray_form, _check_dataset_form + + +class LinearRegression: + def __init__(self): + self._params = None + + def fit( + self, + predictors: Mapping[str, xr.DataArray], + target: xr.DataArray, + dim: str, + weights: Optional[xr.DataArray] = None, + ): + """ + Fit a linear model + + Parameters + ---------- + predictors : dict of xr.DataArray + A dict of DataArray objects used as predictors. Must be 1D and contain `dim`. + + target : xr.DataArray + Target DataArray. Must be 2D and contain `dim`. + + dim : str + Dimension along which to fit the polynomials. + + weights : xr.DataArray, default: None. + Individual weights for each sample. Must be 1D and contain `dim`. + """ + + params = linear_regression( + predictors=predictors, + target=target, + dim=dim, + weights=weights, + ) + + self._params = params + + def predict( + self, + predictors: Mapping[str, xr.DataArray], + ): + """ + Predict using the linear model. + + Parameters + ---------- + predictors : dict of xr.DataArray + A dict of DataArray objects used as predictors. Must be 1D and contain `dim`. + + Returns + ------- + prediction : xr.DataArray + Returns predicted values. 
+ """ + + params = self.params + + required_predictors = set(params.data_vars) - set(["intercept", "weights"]) + available_predictors = set(predictors.keys()) + + if required_predictors != available_predictors: + raise ValueError("Missing or superflous predictors.") + + prediction = params.intercept + for key in required_predictors: + prediction = prediction + predictors[key] * params[key] + + return prediction + + def residuals( + self, + predictors: Mapping[str, xr.DataArray], + target: xr.DataArray, + ): + """ + Calculate the residuals of the fitted linear model + + Parameters + ---------- + predictors : dict of xr.DataArray + A dict of DataArray objects used as predictors. Must be 1D and contain `dim`. + + target : xr.DataArray + Target DataArray. Must be 2D and contain `dim`. + + Returns + ------- + residuals : xr.DataArray + Returns residuals - the difference between the predicted values and target. + + """ + + prediction = self.predict(predictors) + + residuals = target - prediction + + return residuals + + @property + def params(self): + """ + The parameters of this estimator. + """ + + if self._params is None: + raise ValueError( + "'params' not set - call `fit` or assign them to `LinearRegression().params`." 
+ ) + + return self._params + + @params.setter + def params(self, params): + + _check_dataset_form( + params, + "params", + required_vars="intercept", + optional_vars="weights", + requires_other_vars=True, + ) + + self._params = params + + @classmethod + def from_netcdf(cls, filename, **kwargs): + ds = xr.open_dataset(filename, **kwargs) + + obj = cls() + obj.params = ds + + return obj + + def to_netcdf(self, filename, **kwargs): + params = self.params() + + params.to_netcdf(filename, **kwargs) def linear_regression( @@ -40,6 +179,11 @@ def linear_regression( if not isinstance(predictors, Mapping): raise TypeError(f"predictors should be a dict, got {type(predictors)}.") + if ("weights" in predictors) or ("intercept" in predictors): + raise ValueError( + "A predictor with the name 'weights' or 'intercept' is not allowed" + ) + for key, pred in predictors.items(): _check_dataarray_form(pred, ndim=1, required_dims=dim, name=f"predictor: {key}") @@ -100,6 +244,9 @@ def _linear_regression(predictors, target, weights=None): followed by the intercept for each predictor (in the same order as the columns of ``predictors``). """ + + from sklearn.linear_model import LinearRegression + reg = LinearRegression() reg.fit(X=predictors, y=target, sample_weight=weights) diff --git a/mesmer/core/utils.py b/mesmer/core/utils.py index 7b805541..5fe1dc55 100644 --- a/mesmer/core/utils.py +++ b/mesmer/core/utils.py @@ -3,30 +3,101 @@ import xarray as xr +def _to_set(arg): + + if arg is None: + arg = {} + + if isinstance(arg, str): + arg = {arg} + + arg = set(arg) + + return arg + + +def _check_dataset_form( + obj, + name: str = "obj", + *, + required_vars: Union[str, Set[str]] = set(), + optional_vars: Union[str, Set[str]] = set(), + requires_other_vars: bool = False, +): + """check if a dataset conforms to some conditions + + obj: Any + object to check. + name : str, default: 'obj' + Name to use in error messages. 
+ required_vars, str, set of str, optional + Variables that obj is required to contain. + optional_vars: str, set of str, optional + Variables that the obj may contain, only + relevant if `requires_other_vars` is True + requires_other_vars: bool, default: False + obj is required to contain other variables than + required_vars or optional_vars + + Raises + ------ + TypeError: if obj is not a xr.Dataset + ValueError: if any of the conditions is violated + + """ + + required_vars = _to_set(required_vars) + optional_vars = _to_set(optional_vars) + + if not isinstance(obj, xr.Dataset): + raise TypeError(f"Expected {name} to be an xr.Dataset, got {type(obj)}") + + data_vars = set(obj.data_vars) + + missing_vars = required_vars - data_vars + if missing_vars: + missing_vars = ",".join(missing_vars) + raise ValueError(f"{name} is missing the required data_vars: {missing_vars}") + + n_vars_except = len(data_vars - (required_vars | optional_vars)) + if requires_other_vars and n_vars_except == 0: + + raise ValueError(f"Expected additional variables on {name}") + + def _check_dataarray_form( - da: xr.DataArray, - name: str = None, + obj, + name: str = "obj", + *, ndim: int = None, - required_dims: Union[str, Set[str]] = {}, + required_dims: Union[str, Set[str]] = set(), ): + """check if a dataarray conforms to some conditions - if name is None: - name = "da" + obj: Any + object to check. + name : str, default: 'obj' + Name to use in error messages. 
+ ndim, int, optional + Number of required dimensions + required_dims: str, set of str, optional + Names of dims that are required for obj - if isinstance(required_dims, str): - required_dims = {required_dims} + Raises + ------ + TypeError: if obj is not a xr.DataArray + ValueError: if any of the conditions is violated - required_dims = set(required_dims) + """ - if required_dims is None: - required_dims = {} + required_dims = _to_set(required_dims) - if not isinstance(da, xr.DataArray): - raise TypeError(f"Expected {name} to be an xr.DataArray, got {type(da)}") + if not isinstance(obj, xr.DataArray): + raise TypeError(f"Expected {name} to be an xr.DataArray, got {type(obj)}") - if ndim is not None and ndim != da.ndim: - raise ValueError(f"{name} should be {ndim}-dimensional, but is {da.ndim}D") + if ndim is not None and ndim != obj.ndim: + raise ValueError(f"{name} should be {ndim}-dimensional, but is {obj.ndim}D") - if required_dims - set(da.dims): - missing_dims = " ,".join(required_dims - set(da.dims)) + if required_dims - set(obj.dims): + missing_dims = " ,".join(required_dims - set(obj.dims)) raise ValueError(f"{name} is missing the required dims: {missing_dims}") diff --git a/tests/integration/test_linear_regression.py b/tests/integration/test_linear_regression.py index 5ffa9079..adccf975 100644 --- a/tests/integration/test_linear_regression.py +++ b/tests/integration/test_linear_regression.py @@ -9,10 +9,93 @@ from .utils import trend_data_1D, trend_data_2D -# TEST XARRAY WRAPPER +def LinearRegression_fit_wrapper(*args, **kwargs): + # wrapper for LinearRegression().fit() because it has no return value - should it? 
-def test_linear_regression_errors(): + lr = mesmer.core.linear_regression.LinearRegression() + + lr.fit(*args, **kwargs) + return lr.params + + +LR_METHOD_OR_FUNCTION = [ + mesmer.core.linear_regression.linear_regression, + LinearRegression_fit_wrapper, +] + +# TEST LinearRegression class + + +def test_LR_params(): + + lr = mesmer.core.linear_regression.LinearRegression() + + with pytest.raises(ValueError, match="'params' not set"): + lr.params + + with pytest.raises(TypeError, match="Expected params to be an xr.Dataset"): + lr.params = None + + with pytest.raises(ValueError, match="missing the required data_vars"): + lr.params = xr.Dataset() + + with pytest.raises(ValueError, match="missing the required data_vars"): + lr.params = xr.Dataset(data_vars={"weights": ("x", [5])}) + + with pytest.raises(ValueError, match="Expected additional variables"): + lr.params = xr.Dataset(data_vars={"intercept": ("x", [5])}) + + ds = xr.Dataset(data_vars={"intercept": ("x", [5]), "weights": ("y", [5])}) + with pytest.raises(ValueError, match="Expected additional variables"): + lr.params = ds + + ds = xr.Dataset(data_vars={"intercept": ("x", [5]), "tas": ("y", [5])}) + lr.params = ds + + xr.testing.assert_equal(ds, lr.params) + + +def test_LR_predict(): + lr = mesmer.core.linear_regression.LinearRegression() + + params = xr.Dataset(data_vars={"intercept": ("x", [5]), "tas": ("x", [3])}) + lr.params = params + + with pytest.raises(ValueError, match="Missing or superflous predictors"): + lr.predict({}) + + with pytest.raises(ValueError, match="Missing or superflous predictors"): + lr.predict({"tas": None, "something else": None}) + + tas = xr.DataArray([0, 1, 2], dims="time") + + result = lr.predict({"tas": tas}) + expected = xr.DataArray([[5, 8, 11]], dims=("x", "time")) + + xr.testing.assert_equal(result, expected) + + +def test_LR_residuals(): + + lr = mesmer.core.linear_regression.LinearRegression() + + params = xr.Dataset(data_vars={"intercept": ("x", [5]), "tas": ("x", [0])}) 
+ lr.params = params + + tas = xr.DataArray([0, 1, 2], dims="time") + target = xr.DataArray([[5, 8, 0]], dims=("x", "time")) + + expected = xr.DataArray([[0, 3, -5]], dims=("x", "time")) + + result = lr.residuals({"tas": tas}, target) + + xr.testing.assert_equal(expected, result) + + +# TEST XARRAY WRAPPER & LinearRegression().fit +@pytest.mark.parametrize("lr_method_or_function", LR_METHOD_OR_FUNCTION) +def test_linear_regression_errors(lr_method_or_function): pred0 = trend_data_1D() pred1 = trend_data_1D() @@ -24,14 +107,14 @@ def test_linear_regression_errors(): weights = trend_data_1D(intercept=1, slope=0, scale=0) with pytest.raises(TypeError, match="predictors should be a dict"): - mesmer.core.linear_regression.linear_regression(pred0, tgt, dim="time") + lr_method_or_function(pred0, tgt, dim="time") def test_unequal_coords(pred0, pred1, tgt, weights): with pytest.raises( ValueError, match="indexes along dimension 'time' are not equal" ): - mesmer.core.linear_regression.linear_regression( + lr_method_or_function( {"pred0": pred0, "pred1": pred1}, tgt, dim="time", weights=weights ) @@ -42,7 +125,7 @@ def test_unequal_coords(pred0, pred1, tgt, weights): def test_wrong_type(pred0, pred1, tgt, weights, name): with pytest.raises(TypeError, match=f"Expected {name} to be an xr.DataArray"): - mesmer.core.linear_regression.linear_regression( + lr_method_or_function( {"pred0": pred0, "pred1": pred1}, tgt, dim="time", weights=weights ) @@ -53,7 +136,7 @@ def test_wrong_type(pred0, pred1, tgt, weights, name): def test_wrong_shape(pred0, pred1, tgt, weights, name, ndim): with pytest.raises(ValueError, match=f"{name} should be {ndim}-dimensional"): - mesmer.core.linear_regression.linear_regression( + lr_method_or_function( {"pred0": pred0, "pred1": pred1}, tgt, dim="time", weights=weights ) @@ -72,7 +155,7 @@ def test_wrong_shape(pred0, pred1, tgt, weights, name, ndim): def test_missing_dim(pred0, pred1, tgt, weights, name): with pytest.raises(ValueError, match=f"{name} is 
missing the required dims"): - mesmer.core.linear_regression.linear_regression( + lr_method_or_function( {"pred0": pred0, "pred1": pred1}, tgt, dim="time", weights=weights ) @@ -86,16 +169,15 @@ def test_missing_dim(pred0, pred1, tgt, weights, name): test_missing_dim(pred0, pred1, tgt, weights.rename(time="t"), name="weights") +@pytest.mark.parametrize("lr_method_or_function", LR_METHOD_OR_FUNCTION) @pytest.mark.parametrize("intercept", (0, 3.14)) @pytest.mark.parametrize("slope", (0, 3.14)) -def test_linear_regression_one_predictor(intercept, slope): +def test_linear_regression_one_predictor(lr_method_or_function, intercept, slope): pred0 = trend_data_1D(slope=1, scale=0) tgt = trend_data_2D(slope=slope, scale=0, intercept=intercept) - result = mesmer.core.linear_regression.linear_regression( - {"pred0": pred0}, tgt, "time" - ) + result = lr_method_or_function({"pred0": pred0}, tgt, "time") template = tgt.isel(time=0, drop=True) @@ -107,17 +189,16 @@ def test_linear_regression_one_predictor(intercept, slope): xr.testing.assert_allclose(result, expected) +@pytest.mark.parametrize("lr_method_or_function", LR_METHOD_OR_FUNCTION) @pytest.mark.parametrize("intercept", (0, 3.14)) @pytest.mark.parametrize("slope", (0, 3.14)) -def test_linear_regression_two_predictors(intercept, slope): +def test_linear_regression_two_predictors(lr_method_or_function, intercept, slope): pred0 = trend_data_1D(slope=1, scale=0) pred1 = trend_data_1D(slope=1, scale=0) tgt = trend_data_2D(slope=slope, scale=0, intercept=intercept) - result = mesmer.core.linear_regression.linear_regression( - {"pred0": pred0, "pred1": pred1}, tgt, "time" - ) + result = lr_method_or_function({"pred0": pred0, "pred1": pred1}, tgt, "time") template = tgt.isel(time=0, drop=True) @@ -136,8 +217,9 @@ def test_linear_regression_two_predictors(intercept, slope): xr.testing.assert_allclose(result, expected) +@pytest.mark.parametrize("lr_method_or_function", LR_METHOD_OR_FUNCTION) @pytest.mark.parametrize("intercept", 
(0, 3.14)) -def test_linear_regression_weights(intercept): +def test_linear_regression_weights(lr_method_or_function, intercept): pred0 = trend_data_1D(slope=1, scale=0) tgt = trend_data_2D(slope=1, scale=0, intercept=intercept) @@ -145,9 +227,7 @@ def test_linear_regression_weights(intercept): weights = trend_data_1D(intercept=0, slope=0, scale=0) weights[0] = 1 - result = mesmer.core.linear_regression.linear_regression( - {"pred0": pred0}, tgt, "time", weights=weights - ) + result = lr_method_or_function({"pred0": pred0}, tgt, "time", weights=weights) template = tgt.isel(time=0, drop=True) @@ -344,7 +424,7 @@ def test_linear_regression_np(predictors, target, weight): mock_regressor.coef_ = [123, -38] with mock.patch( - "mesmer.core.linear_regression.LinearRegression" + "sklearn.linear_model.LinearRegression" ) as mocked_linear_regression: mocked_linear_regression.return_value = mock_regressor diff --git a/tests/integration/test_utils.py b/tests/integration/test_utils.py index 1401062b..3bc7031f 100644 --- a/tests/integration/test_utils.py +++ b/tests/integration/test_utils.py @@ -5,14 +5,75 @@ import mesmer.core.utils -@pytest.mark.parametrize("da", (None, xr.Dataset())) -def test_check_dataarray_form_wrong_type(da): +@pytest.mark.parametrize("obj", (None, xr.DataArray())) +def test_check_dataset_form_wrong_type(obj): - with pytest.raises(TypeError, match="Expected da to be an xr.DataArray"): - mesmer.core.utils._check_dataarray_form(da) + with pytest.raises(TypeError, match="Expected obj to be an xr.Dataset"): + mesmer.core.utils._check_dataset_form(obj) + + with pytest.raises(TypeError, match="Expected test to be an xr.Dataset"): + mesmer.core.utils._check_dataset_form(obj, name="test") + + +def test_check_dataset_form_required_vars(): + + ds = xr.Dataset() + + with pytest.raises(ValueError, match="obj is missing the required data_vars"): + mesmer.core.utils._check_dataset_form(ds, required_vars="missing") + + with pytest.raises(ValueError, match="test is 
missing the required data_vars"): + mesmer.core.utils._check_dataset_form(ds, "test", required_vars="missing") + + # no error + mesmer.core.utils._check_dataset_form(ds) + mesmer.core.utils._check_dataset_form(ds, required_vars=set()) + mesmer.core.utils._check_dataset_form(ds, required_vars=None) + + ds = xr.Dataset(data_vars={"var": ("x", [0])}) + + # no error + mesmer.core.utils._check_dataset_form(ds) + mesmer.core.utils._check_dataset_form(ds, required_vars="var") + mesmer.core.utils._check_dataset_form(ds, required_vars={"var"}) + + +def test_check_dataset_form_requires_other_vars(): + + ds = xr.Dataset() + + with pytest.raises(ValueError, match="Expected additional variables on obj"): + mesmer.core.utils._check_dataset_form(ds, requires_other_vars=True) + + with pytest.raises(ValueError, match="Expected additional variables on test"): + mesmer.core.utils._check_dataset_form(ds, "test", requires_other_vars=True) + + with pytest.raises(ValueError, match="Expected additional variables on obj"): + mesmer.core.utils._check_dataset_form( + ds, optional_vars="var", requires_other_vars=True + ) + + ds = xr.Dataset(data_vars={"var": ("x", [0])}) + + with pytest.raises(ValueError, match="Expected additional variables on obj"): + mesmer.core.utils._check_dataset_form( + ds, required_vars="var", requires_other_vars=True + ) + + with pytest.raises(ValueError, match="Expected additional variables on obj"): + mesmer.core.utils._check_dataset_form( + ds, optional_vars="var", requires_other_vars=True + ) + + +@pytest.mark.parametrize("obj", (None, xr.Dataset())) +def test_check_dataarray_form_wrong_type(obj): + + with pytest.raises(TypeError, match="Expected obj to be an xr.DataArray"): + mesmer.core.utils._check_dataarray_form(obj) with pytest.raises(TypeError, match="Expected test to be an xr.DataArray"): - mesmer.core.utils._check_dataarray_form(da, name="test") + mesmer.core.utils._check_dataarray_form(obj, name="test") @pytest.mark.parametrize("ndim", (0, 1, 3)) @@ 
-20,7 +81,7 @@ def test_check_dataarray_form_ndim(ndim): da = xr.DataArray(np.ones((2, 2))) - with pytest.raises(ValueError, match=f"da should be {ndim}-dimensional"): + with pytest.raises(ValueError, match=f"obj should be {ndim}-dimensional"): mesmer.core.utils._check_dataarray_form(da, ndim=ndim) with pytest.raises(ValueError, match=f"test should be {ndim}-dimensional"): @@ -35,7 +96,7 @@ def test_check_dataarray_form_required_dims(required_dims): da = xr.DataArray(np.ones((2, 2)), dims=("x", "y")) - with pytest.raises(ValueError, match="da is missing the required dims"): + with pytest.raises(ValueError, match="obj is missing the required dims"): mesmer.core.utils._check_dataarray_form(da, required_dims=required_dims) with pytest.raises(ValueError, match="test is missing the required dims"): From 2c39343e0a3eb7f5f5f0ae29f3b27fd87021cc54 Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Sun, 20 Mar 2022 18:57:48 +0100 Subject: [PATCH 2/3] changelog and docs --- CHANGELOG.rst | 5 +++- docs/source/api.rst | 8 ++++-- docs/source/development.rst | 6 ++-- mesmer/core/linear_regression.py | 32 +++++++++++++++++---- tests/integration/test_linear_regression.py | 2 ++ 5 files changed, 41 insertions(+), 12 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 98d9a963..530464a6 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -6,7 +6,10 @@ v0.9.0 - unreleased New Features ^^^^^^^^^^^^ - +- Create :py:class:`mesmer.core.linear_regression.LinearRegression` which encapsulates + ``fit``, ``predict``, etc. methods aroung linear regression + (`#134 `_). + By `Mathias Hauser `_. - Add ``mesmer.core.linear_regression``: xarray wrapper for ``mesmer.core._linear_regression``. (`#123 `_). By `Mathias Hauser `_. diff --git a/docs/source/api.rst b/docs/source/api.rst index 0ee1d7dc..22f679ba 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -16,8 +16,12 @@ Statistical core functions .. 
autosummary:: :toctree: generated/ - ~core.linear_regression.linear_regression - + ~core.linear_regression.LinearRegression + ~core.linear_regression.LinearRegression.fit + ~core.linear_regression.LinearRegression.predict + ~core.linear_regression.LinearRegression.residuals + ~core.linear_regression.LinearRegression.to_netcdf + ~core.linear_regression.LinearRegression.from_netcdf Train mesmer ------------ diff --git a/docs/source/development.rst b/docs/source/development.rst index 54278f4c..30312316 100644 --- a/docs/source/development.rst +++ b/docs/source/development.rst @@ -79,7 +79,7 @@ We include links with each of these tools to starting points that we think are u - `Jupyter Notebooks `_ - - Jupyter is automatically included in your virtual environment if you follow our `Getting setup`_ instructions + - Jupyter is automatically included in your virtual environment if you follow our `Development setup`_ instructions - Sphinx_ @@ -124,7 +124,7 @@ We use the following tools: - `flake8 `_ to check the format and small errors These automatically format the code for us and tell use where the errors are. -To use them, after setting yourself up (see `Getting setup`_), simply run ``make format``. +To use them, after setting yourself up (see `Development setup`_), simply run ``make format``. Note that ``make format`` can only be run if you have committed all your work i.e. your working directory is 'clean'. This restriction is made to ensure that you don't format code without being able to undo it, just in case something goes wrong. @@ -132,7 +132,7 @@ This restriction is made to ensure that you don't format code without being able Buiding the docs ---------------- -After setting yourself up (see `Getting setup`_), building the docs is as simple as running ``make docs`` (note, run ``make -B docs`` to force the docs to rebuild and ignore make when it says '... index.html is up to date'). 
+After setting yourself up (see `Development setup`_), building the docs is as simple as running ``make docs`` (note, run ``make -B docs`` to force the docs to rebuild and ignore make when it says '... index.html is up to date'). This will build the docs for you. You can preview them by opening ``docs/build/html/index.html`` in a browser. diff --git a/mesmer/core/linear_regression.py b/mesmer/core/linear_regression.py index 2547c456..2a57e9dc 100644 --- a/mesmer/core/linear_regression.py +++ b/mesmer/core/linear_regression.py @@ -7,6 +7,8 @@ class LinearRegression: + """Ordinary least squares Linear Regression for xarray.DataArray objects.""" + def __init__(self): self._params = None @@ -23,7 +25,8 @@ def fit( Parameters ---------- predictors : dict of xr.DataArray - A dict of DataArray objects used as predictors. Must be 1D and contain `dim`. + A dict of DataArray objects used as predictors. Must be 1D and contain + `dim`. target : xr.DataArray Target DataArray. Must be 2D and contain `dim`. @@ -107,13 +110,12 @@ def residuals( @property def params(self): - """ - The parameters of this estimator. - """ + """The parameters of this estimator.""" if self._params is None: raise ValueError( - "'params' not set - call `fit` or assign them to `LinearRegression().params`." + "'params' not set - call `fit` or assign them to " + "`LinearRegression().params`." ) return self._params @@ -133,6 +135,15 @@ def params(self, params): @classmethod def from_netcdf(cls, filename, **kwargs): + """read params from a netCDF file + + Parameters + ---------- + filename : str + Name of the netCDF file to open. + kwargs : Any + Additional keyword arguments passed to ``xr.open_dataset`` + """ ds = xr.open_dataset(filename, **kwargs) obj = cls() @@ -141,8 +152,17 @@ def from_netcdf(cls, filename, **kwargs): return obj def to_netcdf(self, filename, **kwargs): - params = self.params() + """save params to a netCDF file + Parameters + ---------- + filename : str + Name of the netCDF file to save. 
+ kwargs : Any + Additional keyword arguments passed to ``xr.Dataset.to_netcdf`` + """ + + params = self.params() params.to_netcdf(filename, **kwargs) diff --git a/tests/integration/test_linear_regression.py b/tests/integration/test_linear_regression.py index adccf975..58a0e10f 100644 --- a/tests/integration/test_linear_regression.py +++ b/tests/integration/test_linear_regression.py @@ -12,6 +12,8 @@ def LinearRegression_fit_wrapper(*args, **kwargs): # wrapper for LinearRegression().fit() because it has no return value - should it? + # -> no: a class method should either change state or have a return value, it's a + # bit awkward for testing but better overall lr = mesmer.core.linear_regression.LinearRegression() From d772b60bbfcea62438f67b2e010a5d0a6b989ae0 Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Sun, 20 Mar 2022 19:36:14 +0100 Subject: [PATCH 3/3] Apply suggestions from code review --- CHANGELOG.rst | 2 +- mesmer/core/linear_regression.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 530464a6..15f2a4e8 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -7,7 +7,7 @@ v0.9.0 - unreleased New Features ^^^^^^^^^^^^ - Create :py:class:`mesmer.core.linear_regression.LinearRegression` which encapsulates - ``fit``, ``predict``, etc. methods aroung linear regression + ``fit``, ``predict``, etc. methods around linear regression (`#134 `_). By `Mathias Hauser `_. - Add ``mesmer.core.linear_regression``: xarray wrapper for ``mesmer.core._linear_regression``. diff --git a/mesmer/core/linear_regression.py b/mesmer/core/linear_regression.py index 2a57e9dc..25178554 100644 --- a/mesmer/core/linear_regression.py +++ b/mesmer/core/linear_regression.py @@ -7,7 +7,7 @@ class LinearRegression: - """Ordinary least squares Linear Regression for xarray.DataArray objects.""" + """Ordinary least squares Linear Regression for xr.DataArray objects.""" def __init__(self): self._params = None