Skip to content

Commit

Permalink
linear regression: remove all dim coords on target (#334)
Browse files Browse the repository at this point in the history
* linear regression: remove all dim coords on target

* changelog
  • Loading branch information
mathause committed Nov 13, 2023
1 parent 60c15d2 commit 71e41f3
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 6 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.rst
Expand Up @@ -21,6 +21,11 @@ New Features
- Allow to pass 1-dimensional targets to :py:meth:`mesmer.stats.linear_regression.LinearRegression.fit`
(`#221 <https://github.com/MESMER-group/mesmer/pull/221>`_).
By `Mathias Hauser`_.
- Fixed two bugs related to (non-dimension) coordinates (
`#332 <https://github.com/MESMER-group/mesmer/issues/332>`_,
`#333 <https://github.com/MESMER-group/mesmer/issues/333>`_ and
`#334 <https://github.com/MESMER-group/mesmer/pull/313>`_).
By `Mathias Hauser`_.

- Extracted statistical functionality for auto regression:
- Add ``mesmer.stats.auto_regression._fit_auto_regression_xr``: xarray wrapper to fit an
Expand Down
10 changes: 5 additions & 5 deletions mesmer/stats/linear_regression.py
Expand Up @@ -246,16 +246,16 @@ def _fit_linear_regression_xr(
fit_intercept,
)

# remove (non-dimension) coords from target (#332, #333)
target = target.drop_vars(target[dim].coords)

# split `out` into individual DataArrays
keys = ["intercept"] + list(predictors)
dataarrays = {key: (target_dim, out[:, i]) for i, key in enumerate(keys)}
out = xr.Dataset(dataarrays, coords=target.coords)
data_vars = {key: (target_dim, out[:, i]) for i, key in enumerate(keys)}
out = xr.Dataset(data_vars, coords=target.coords)

out["fit_intercept"] = fit_intercept

if dim in out.coords:
out = out.drop_vars(dim)

if weights is not None:
out["weights"] = weights

Expand Down
49 changes: 48 additions & 1 deletion tests/unit/test_linear_regression.py
Expand Up @@ -222,7 +222,7 @@ def test_linear_regression_one_predictor(
):

pred0 = trend_data_1D(slope=1, scale=0)
trend_data_1D_or_2D

tgt = trend_data_1D_or_2D(as_2D=as_2D, slope=slope, scale=0, intercept=intercept)

result = lr_method_or_function({"pred0": pred0}, tgt, "time")
Expand All @@ -242,6 +242,53 @@ def test_linear_regression_one_predictor(
xr.testing.assert_allclose(result, expected)


@pytest.mark.parametrize("lr_method_or_function", LR_METHOD_OR_FUNCTION)
@pytest.mark.parametrize("as_2D", [True, False])
def test_linear_regression_predictor_named_like_dim(lr_method_or_function, as_2D):

slope, intercept = 0.3, 0.2
tgt = trend_data_1D_or_2D(as_2D=as_2D, slope=slope, scale=0, intercept=intercept)

result = lr_method_or_function({"time": tgt.time}, tgt, "time")
template = tgt.isel(time=0, drop=True)

expected_intercept = xr.full_like(template, intercept)
expected_time = xr.full_like(template, slope)

expected = xr.Dataset(
{
"intercept": expected_intercept,
"time": expected_time,
"fit_intercept": True,
}
)
xr.testing.assert_allclose(result, expected)


@pytest.mark.parametrize("lr_method_or_function", LR_METHOD_OR_FUNCTION)
@pytest.mark.parametrize("as_2D", [True, False])
def test_linear_regression_predictor_has_non_dim_coors(lr_method_or_function, as_2D):

slope, intercept = 0.3, 0.2
tgt = trend_data_1D_or_2D(as_2D=as_2D, slope=slope, scale=0, intercept=intercept)
tgt = tgt.assign_coords(year=("time", tgt.time.values + 1850))

result = lr_method_or_function({"pred0": tgt.time}, tgt, "time")
template = tgt.isel(time=0, drop=True)

expected_intercept = xr.full_like(template, intercept)
expected_pred0 = xr.full_like(template, slope)

expected = xr.Dataset(
{
"intercept": expected_intercept,
"pred0": expected_pred0,
"fit_intercept": True,
}
)
xr.testing.assert_allclose(result, expected)


@pytest.mark.parametrize("lr_method_or_function", LR_METHOD_OR_FUNCTION)
@pytest.mark.parametrize("as_2D", [True, False])
def test_linear_regression_fit_intercept(lr_method_or_function, as_2D):
Expand Down

0 comments on commit 71e41f3

Please sign in to comment.