linear regression: allow 1D targets (#221)
mathause committed Nov 16, 2022
1 parent fe14ca6 commit 4b68d26
Showing 3 changed files with 74 additions and 32 deletions.
25 changes: 15 additions & 10 deletions CHANGELOG.rst
@@ -7,16 +7,21 @@ v0.9.0 - unreleased
New Features
^^^^^^^^^^^^

- Create :py:class:`mesmer.stats.linear_regression.LinearRegression` which encapsulates
``fit``, ``predict``, etc. methods around linear regression
(`#134 <https://github.com/MESMER-group/mesmer/pull/134>`_).
By `Mathias Hauser <https://github.com/mathause>`_.
- Add ``mesmer.stats._fit_linear_regression_xr``: xarray wrapper for ``mesmer.stats._fit_linear_regression_np``.
(`#123 <https://github.com/MESMER-group/mesmer/pull/123>`_ and `#142 <https://github.com/MESMER-group/mesmer/pull/142>`_).
By `Mathias Hauser <https://github.com/mathause>`_.
- Add ``fit_intercept`` argument to the ``linear_regression`` fitting methods and
functions (`#144 <https://github.com/MESMER-group/mesmer/pull/144>`_).
By `Mathias Hauser <https://github.com/mathause>`_.
- Create statistical functionality for linear regression:
- Create :py:class:`mesmer.stats.linear_regression.LinearRegression` which encapsulates
``fit``, ``predict``, etc. methods around linear regression
(`#134 <https://github.com/MESMER-group/mesmer/pull/134>`_).
By `Mathias Hauser <https://github.com/mathause>`_.
- Add ``mesmer.stats._fit_linear_regression_xr``: xarray wrapper for ``mesmer.stats._fit_linear_regression_np``.
(`#123 <https://github.com/MESMER-group/mesmer/pull/123>`_ and `#142 <https://github.com/MESMER-group/mesmer/pull/142>`_).
By `Mathias Hauser <https://github.com/mathause>`_.
  - Add ``fit_intercept`` argument to the ``linear_regression`` fitting methods and
functions (`#144 <https://github.com/MESMER-group/mesmer/pull/144>`_).
By `Mathias Hauser <https://github.com/mathause>`_.
  - Allow passing 1-dimensional targets to :py:meth:`mesmer.stats.linear_regression.LinearRegression.fit`
(`#221 <https://github.com/MESMER-group/mesmer/pull/221>`_).
By `Mathias Hauser <https://github.com/mathause>`_.

- Add ``mesmer.stats.auto_regression._fit_auto_regression_xr``: xarray wrapper to fit an
auto regression model (`#139 <https://github.com/MESMER-group/mesmer/pull/139>`_).
By `Mathias Hauser <https://github.com/mathause>`_.
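As a rough illustration of the 1D-target support this commit adds: a target with only the fit dimension can now be passed directly to ``fit``. This is a minimal sketch, not code from the commit; the variable names, data values, and coordinate layout are illustrative assumptions, and the positional call signature follows the usage in the tests below.

import numpy as np
import xarray as xr

import mesmer

# single predictor and a purely 1D target along "time" (no cell/grid dimension)
pred = xr.DataArray(np.arange(30.0), dims="time")
tgt = xr.DataArray(3.14 * np.arange(30.0) + 5.0, dims="time")

lr = mesmer.stats.linear_regression.LinearRegression()

# before this change a 2D (cell x time) target was required; a 1D target now works too
lr.fit({"pred": pred}, tgt, "time")

# Dataset with "intercept", "fit_intercept", and one variable per predictor
print(lr.params)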
13 changes: 10 additions & 3 deletions mesmer/stats/linear_regression.py
@@ -230,7 +230,14 @@ def _fit_linear_regression_xr(
coords="minimal",
)

_check_dataarray_form(target, ndim=2, required_dims=dim, name="target")
_check_dataarray_form(target, required_dims=dim, name="target")

if target.ndim == 1:
# a 2D target array is required, extra dim is squeezed at the end
extra_dim = f"__{dim}__"
target = target.expand_dims(extra_dim)
elif target.ndim != 2:
raise ValueError(f"target should be 1D or 2D, but has {target.ndim}D")

# ensure `dim` is equal
xr.align(predictors_concat, target, join="exact")
@@ -239,7 +246,7 @@
_check_dataarray_form(weights, ndim=1, required_dims=dim, name="weights")
xr.align(weights, target, join="exact")

target_dim = list(set(target.dims) - {dim})[0]
(target_dim,) = list(set(target.dims) - {dim})

out = _fit_linear_regression_np(
predictors_concat.transpose(dim, "predictor"),
@@ -261,7 +268,7 @@
if weights is not None:
out["weights"] = weights

return out
return out.squeeze()


def _fit_linear_regression_np(predictors, target, weights=None, fit_intercept=True):
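The change above handles 1D targets by temporarily adding a dummy dimension and squeezing it away again before returning. A minimal, self-contained sketch of that xarray pattern, independent of mesmer; the dimension name and the reduction used as a stand-in for the regression are illustrative assumptions:

import numpy as np
import xarray as xr

dim = "time"
target = xr.DataArray(np.arange(5.0), dims=dim)  # 1D target

if target.ndim == 1:
    # add a dummy leading dimension so downstream code can assume a 2D (cells x time) array
    target = target.expand_dims(f"__{dim}__")

print(target.dims)  # ('__time__', 'time')

out = target.mean(dim)       # stand-in for the regression over `dim`
print(out.squeeze().dims)    # () -- the length-1 dummy dimension is dropped again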
68 changes: 49 additions & 19 deletions tests/unit/test_linear_regression.py
@@ -10,6 +10,13 @@
from mesmer.testing import trend_data_1D, trend_data_2D


def trend_data_1D_or_2D(as_2D, slope, scale, intercept):
if as_2D:
return trend_data_2D(slope=slope, scale=scale, intercept=intercept)

return trend_data_1D(slope=slope, scale=scale, intercept=intercept)


def LinearRegression_fit_wrapper(*args, **kwargs):
# wrapper for LinearRegression().fit() because it has no return value - should it?
# -> no: a class method should either change state or have a return value, it's a
@@ -54,27 +61,32 @@ def test_LR_params():
data_vars={
"intercept": ("x", [5]),
"fit_intercept": True,
"weights": ("y", [5]),
"weights": ("x", [5]),
}
)
with pytest.raises(ValueError, match="Expected additional variables"):
lr.params = ds

ds = xr.Dataset(
data_vars={"intercept": ("x", [5]), "fit_intercept": True, "tas": ("y", [5])}
data_vars={"intercept": ("x", [5]), "fit_intercept": True, "tas": ("x", [5])}
)
lr.params = ds

xr.testing.assert_equal(ds, lr.params)

ds = xr.Dataset(data_vars={"intercept": 5, "fit_intercept": True, "tas": 5})
lr.params = ds
xr.testing.assert_equal(ds, lr.params)


def test_LR_predict():
@pytest.mark.parametrize("as_2D", [True, False])
def test_LR_predict(as_2D):
lr = mesmer.stats.linear_regression.LinearRegression()

params = xr.Dataset(
data_vars={"intercept": ("x", [5]), "fit_intercept": True, "tas": ("x", [3])}
)
lr.params = params
lr.params = params if as_2D else params.squeeze()

with pytest.raises(ValueError, match="Missing or superflous predictors"):
lr.predict({})
@@ -86,25 +98,28 @@ def test_LR_predict():

result = lr.predict({"tas": tas})
expected = xr.DataArray([[5, 8, 11]], dims=("x", "time"))
expected = expected if as_2D else expected.squeeze()

xr.testing.assert_equal(result, expected)


def test_LR_residuals():
@pytest.mark.parametrize("as_2D", [True, False])
def test_LR_residuals(as_2D):

lr = mesmer.stats.linear_regression.LinearRegression()

params = xr.Dataset(
data_vars={"intercept": ("x", [5]), "fit_intercept": True, "tas": ("x", [0])}
)
lr.params = params
lr.params = params if as_2D else params.squeeze()

tas = xr.DataArray([0, 1, 2], dims="time")
target = xr.DataArray([[5, 8, 0]], dims=("x", "time"))

expected = xr.DataArray([[0, 3, -5]], dims=("x", "time"))
target = target if as_2D else target.squeeze()

result = lr.residuals({"tas": tas}, target)
expected = xr.DataArray([[0, 3, -5]], dims=("x", "time"))
expected = expected if as_2D else expected.squeeze()

xr.testing.assert_equal(expected, result)

@@ -166,13 +181,19 @@ def test_wrong_shape(pred0, pred1, tgt, weights, name, ndim):
test_wrong_shape(
pred0, pred1.expand_dims("new"), tgt, weights, name="predictor: pred1", ndim=1
)
test_wrong_shape(
pred0, pred1, tgt.expand_dims("new"), weights, name="target", ndim=2
)
test_wrong_shape(
pred0, pred1, tgt, weights.expand_dims("new"), name="weights", ndim=1
)

# target ndim test has a different error message
with pytest.raises(ValueError, match="target should be 1D or 2D"):
lr_method_or_function(
{"pred0": pred0, "pred1": pred1},
tgt.expand_dims("new"),
dim="time",
weights=weights,
)

def test_missing_dim(pred0, pred1, tgt, weights, name):
with pytest.raises(ValueError, match=f"{name} is missing the required dims"):
lr_method_or_function(
@@ -195,10 +216,14 @@ def test_missing_dim(pred0, pred1, tgt, weights, name):
@pytest.mark.parametrize("lr_method_or_function", LR_METHOD_OR_FUNCTION)
@pytest.mark.parametrize("intercept", (0, 3.14))
@pytest.mark.parametrize("slope", (0, 3.14))
def test_linear_regression_one_predictor(lr_method_or_function, intercept, slope):
@pytest.mark.parametrize("as_2D", [True, False])
def test_linear_regression_one_predictor(
lr_method_or_function, intercept, slope, as_2D
):

pred0 = trend_data_1D(slope=1, scale=0)
tgt = trend_data_2D(slope=slope, scale=0, intercept=intercept)
tgt = trend_data_1D_or_2D(as_2D=as_2D, slope=slope, scale=0, intercept=intercept)

result = lr_method_or_function({"pred0": pred0}, tgt, "time")

@@ -218,10 +243,11 @@ def test_linear_regression_one_predictor(lr_method_or_function, intercept, slope


@pytest.mark.parametrize("lr_method_or_function", LR_METHOD_OR_FUNCTION)
def test_linear_regression_fit_intercept(lr_method_or_function):
@pytest.mark.parametrize("as_2D", [True, False])
def test_linear_regression_fit_intercept(lr_method_or_function, as_2D):

pred0 = trend_data_1D(slope=1, scale=0)
tgt = trend_data_2D(slope=1, scale=0, intercept=1)
tgt = trend_data_1D_or_2D(as_2D=as_2D, slope=1, scale=0, intercept=1)

result = lr_method_or_function({"pred0": pred0}, tgt, "time", fit_intercept=False)

@@ -241,12 +267,13 @@ def test_linear_regression_fit_intercept(lr_method_or_function):


@pytest.mark.parametrize("lr_method_or_function", LR_METHOD_OR_FUNCTION)
def test_linear_regression_no_coords(lr_method_or_function):
@pytest.mark.parametrize("as_2D", [True, False])
def test_linear_regression_no_coords(lr_method_or_function, as_2D):

slope, intercept = 3.14, 3.14

pred0 = trend_data_1D(slope=1, scale=0)
tgt = trend_data_2D(slope=slope, scale=0, intercept=intercept)
tgt = trend_data_1D_or_2D(as_2D=as_2D, slope=slope, scale=0, intercept=intercept)

# remove the coords
pred0 = pred0.drop_vars(pred0.coords.keys())
@@ -272,11 +299,14 @@ def test_linear_regression_no_coords(lr_method_or_function):
@pytest.mark.parametrize("lr_method_or_function", LR_METHOD_OR_FUNCTION)
@pytest.mark.parametrize("intercept", (0, 3.14))
@pytest.mark.parametrize("slope", (0, 3.14))
def test_linear_regression_two_predictors(lr_method_or_function, intercept, slope):
@pytest.mark.parametrize("as_2D", [True, False])
def test_linear_regression_two_predictors(
lr_method_or_function, intercept, slope, as_2D
):

pred0 = trend_data_1D(slope=1, scale=0)
pred1 = trend_data_1D(slope=1, scale=0)
tgt = trend_data_2D(slope=slope, scale=0, intercept=intercept)
tgt = trend_data_1D_or_2D(as_2D=as_2D, slope=slope, scale=0, intercept=intercept)

result = lr_method_or_function({"pred0": pred0, "pred1": pred1}, tgt, "time")

