Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add LinearRegression class #134

Merged
merged 4 commits into from Mar 20, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 4 additions & 1 deletion CHANGELOG.rst
Expand Up @@ -6,7 +6,10 @@ v0.9.0 - unreleased

New Features
^^^^^^^^^^^^

- Create :py:class:`mesmer.core.linear_regression.LinearRegression` which encapsulates
``fit``, ``predict``, etc. methods aroung linear regression
mathause marked this conversation as resolved.
Show resolved Hide resolved
(`#134 <https://github.com/MESMER-group/mesmer/pull/134>`_).
By `Mathias Hauser <https://github.com/mathause>`_.
- Add ``mesmer.core.linear_regression``: xarray wrapper for ``mesmer.core._linear_regression``.
(`#123 <https://github.com/MESMER-group/mesmer/pull/123>`_).
By `Mathias Hauser <https://github.com/mathause>`_.
Expand Down
8 changes: 6 additions & 2 deletions docs/source/api.rst
Expand Up @@ -16,8 +16,12 @@ Statistical core functions
.. autosummary::
:toctree: generated/

~core.linear_regression.linear_regression

~core.linear_regression.LinearRegression
~core.linear_regression.LinearRegression.fit
~core.linear_regression.LinearRegression.predict
~core.linear_regression.LinearRegression.residuals
~core.linear_regression.LinearRegression.to_netcdf
~core.linear_regression.LinearRegression.from_netcdf

Train mesmer
------------
Expand Down
6 changes: 3 additions & 3 deletions docs/source/development.rst
Expand Up @@ -79,7 +79,7 @@ We include links with each of these tools to starting points that we think are u

- `Jupyter Notebooks <https://medium.com/codingthesmartway-com-blog/getting-started-with-jupyter-notebook-for-python-4e7082bd5d46>`_

- Jupyter is automatically included in your virtual environment if you follow our `Getting setup`_ instructions
- Jupyter is automatically included in your virtual environment if you follow our `Development setup`_ instructions

- Sphinx_

Expand Down Expand Up @@ -124,15 +124,15 @@ We use the following tools:
- `flake8 <https://flake8.pycqa.org/en/latest/>`_ to check the format and small errors

These automatically format the code for us and tell use where the errors are.
To use them, after setting yourself up (see `Getting setup`_), simply run ``make format``.
To use them, after setting yourself up (see `Development setup`_), simply run ``make format``.
Note that ``make format`` can only be run if you have committed all your work i.e. your working directory is 'clean'.
This restriction is made to ensure that you don't format code without being able to undo it, just in case something goes wrong.


Buiding the docs
----------------

After setting yourself up (see `Getting setup`_), building the docs is as simple as running ``make docs`` (note, run ``make -B docs`` to force the docs to rebuild and ignore make when it says '... index.html is up to date').
After setting yourself up (see `Development setup`_), building the docs is as simple as running ``make docs`` (note, run ``make -B docs`` to force the docs to rebuild and ignore make when it says '... index.html is up to date').
This will build the docs for you.
You can preview them by opening ``docs/build/html/index.html`` in a browser.

Expand Down
171 changes: 169 additions & 2 deletions mesmer/core/linear_regression.py
Expand Up @@ -2,9 +2,168 @@

import numpy as np
import xarray as xr
from sklearn.linear_model import LinearRegression

from .utils import _check_dataarray_form
from .utils import _check_dataarray_form, _check_dataset_form


class LinearRegression:
"""Ordinary least squares Linear Regression for xarray.DataArray objects."""
mathause marked this conversation as resolved.
Show resolved Hide resolved

def __init__(self):
self._params = None

def fit(
self,
predictors: Mapping[str, xr.DataArray],
target: xr.DataArray,
dim: str,
weights: Optional[xr.DataArray] = None,
):
"""
Fit a linear model

Parameters
----------
predictors : dict of xr.DataArray
A dict of DataArray objects used as predictors. Must be 1D and contain
`dim`.

target : xr.DataArray
Target DataArray. Must be 2D and contain `dim`.

dim : str
Dimension along which to fit the polynomials.

weights : xr.DataArray, default: None.
Individual weights for each sample. Must be 1D and contain `dim`.
"""

params = linear_regression(
predictors=predictors,
target=target,
dim=dim,
weights=weights,
)

self._params = params

def predict(
self,
predictors: Mapping[str, xr.DataArray],
):
"""
Predict using the linear model.

Parameters
----------
predictors : dict of xr.DataArray
A dict of DataArray objects used as predictors. Must be 1D and contain `dim`.

Returns
-------
prediction : xr.DataArray
Returns predicted values.
"""

params = self.params

required_predictors = set(params.data_vars) - set(["intercept", "weights"])
available_predictors = set(predictors.keys())

if required_predictors != available_predictors:
raise ValueError("Missing or superflous predictors.")

prediction = params.intercept
for key in required_predictors:
prediction = prediction + predictors[key] * params[key]

return prediction

def residuals(
self,
predictors: Mapping[str, xr.DataArray],
target: xr.DataArray,
):
"""
Calculate the residuals of the fitted linear model

Parameters
----------
predictors : dict of xr.DataArray
A dict of DataArray objects used as predictors. Must be 1D and contain `dim`.

target : xr.DataArray
Target DataArray. Must be 2D and contain `dim`.

Returns
-------
residuals : xr.DataArray
Returns residuals - the difference between the predicted values and target.

"""

prediction = self.predict(predictors)

residuals = target - prediction

return residuals

@property
def params(self):
"""The parameters of this estimator."""

if self._params is None:
raise ValueError(
"'params' not set - call `fit` or assign them to "
"`LinearRegression().params`."
)

return self._params

@params.setter
def params(self, params):

_check_dataset_form(
params,
"params",
required_vars="intercept",
optional_vars="weights",
requires_other_vars=True,
)

self._params = params

@classmethod
def from_netcdf(cls, filename, **kwargs):
"""read params from a netCDF file

Parameters
----------
filename : str
Name of the netCDF file to open.
kwargs : Any
Additional keyword arguments passed to ``xr.open_dataset``
"""
ds = xr.open_dataset(filename, **kwargs)

obj = cls()
obj.params = ds

return obj

def to_netcdf(self, filename, **kwargs):
"""save params to a netCDF file

Parameters
----------
filename : str
Name of the netCDF file to save.
kwargs : Any
Additional keyword arguments passed to ``xr.Dataset.to_netcf``
"""

params = self.params()
params.to_netcdf(filename, **kwargs)


def linear_regression(
Expand Down Expand Up @@ -40,6 +199,11 @@ def linear_regression(
if not isinstance(predictors, Mapping):
raise TypeError(f"predictors should be a dict, got {type(predictors)}.")

if ("weights" in predictors) or ("intercept" in predictors):
raise ValueError(
"A predictor with the name 'weights' or 'intercept' is not allowed"
)

for key, pred in predictors.items():
_check_dataarray_form(pred, ndim=1, required_dims=dim, name=f"predictor: {key}")

Expand Down Expand Up @@ -100,6 +264,9 @@ def _linear_regression(predictors, target, weights=None):
followed by the intercept for each predictor (in the same order as the
columns of ``predictors``).
"""

from sklearn.linear_model import LinearRegression

reg = LinearRegression()
reg.fit(X=predictors, y=target, sample_weight=weights)

Expand Down
103 changes: 87 additions & 16 deletions mesmer/core/utils.py
Expand Up @@ -3,30 +3,101 @@
import xarray as xr


def _to_set(arg):

if arg is None:
arg = {}

if isinstance(arg, str):
arg = {arg}

arg = set(arg)

return arg


def _check_dataset_form(
obj,
name: str = "obj",
*,
required_vars: Union[str, Set[str]] = set(),
optional_vars: Union[str, Set[str]] = set(),
requires_other_vars: bool = False,
):
"""check if a dataset conforms to some conditions

obj: Any
object to check.
name : str, default: 'obj'
Name to use in error messages.
required_vars, str, set of str, optional
Variables that obj is required to contain.
optional_vars: str, set of str, optional
Variables that the obj may contain, only
relevant if `requires_other_vars` is True
requires_other_vars: bool, default: False
obj is required to contain other variables than
required_vars or optional_vars

Raises
------
TypeError: if obj is not a xr.Dataset
ValueError: if any of the conditions is violated

"""

required_vars = _to_set(required_vars)
optional_vars = _to_set(optional_vars)

if not isinstance(obj, xr.Dataset):
raise TypeError(f"Expected {name} to be an xr.Dataset, got {type(obj)}")

data_vars = set(obj.data_vars)

missing_vars = required_vars - data_vars
if missing_vars:
missing_vars = ",".join(missing_vars)
raise ValueError(f"{name} is missing the required data_vars: {missing_vars}")

n_vars_except = len(data_vars - (required_vars | optional_vars))
if requires_other_vars and n_vars_except == 0:

raise ValueError(f"Expected additional variables on {name}")


def _check_dataarray_form(
da: xr.DataArray,
name: str = None,
obj,
name: str = "obj",
*,
ndim: int = None,
required_dims: Union[str, Set[str]] = {},
required_dims: Union[str, Set[str]] = set(),
):
"""check if a dataset conforms to some conditions

if name is None:
name = "da"
obj: Any
object to check.
name : str, default: 'obj'
Name to use in error messages.
ndim, int, optional
Number of required dimensions
required_dims: str, set of str, optional
Names of dims that are required for obj

if isinstance(required_dims, str):
required_dims = {required_dims}
Raises
------
TypeError: if obj is not a xr.DataArray
ValueError: if any of the conditions is violated

required_dims = set(required_dims)
"""

if required_dims is None:
required_dims = {}
required_dims = _to_set(required_dims)

if not isinstance(da, xr.DataArray):
raise TypeError(f"Expected {name} to be an xr.DataArray, got {type(da)}")
if not isinstance(obj, xr.DataArray):
raise TypeError(f"Expected {name} to be an xr.DataArray, got {type(obj)}")

if ndim is not None and ndim != da.ndim:
raise ValueError(f"{name} should be {ndim}-dimensional, but is {da.ndim}D")
if ndim is not None and ndim != obj.ndim:
raise ValueError(f"{name} should be {ndim}-dimensional, but is {obj.ndim}D")

if required_dims - set(da.dims):
missing_dims = " ,".join(required_dims - set(da.dims))
if required_dims - set(obj.dims):
missing_dims = " ,".join(required_dims - set(obj.dims))
raise ValueError(f"{name} is missing the required dims: {missing_dims}")