Skip to content

Commit

Permalink
Merge pull request #106 from antoinedemathelin/master
Browse files Browse the repository at this point in the history
feat: RegularTransfer Gaussian Process
  • Loading branch information
antoinedemathelin committed Jun 18, 2023
2 parents 324930c + 62384af commit ff5a6e3
Show file tree
Hide file tree
Showing 9 changed files with 380 additions and 7 deletions.
2 changes: 1 addition & 1 deletion adapt/instance_based/_kliep.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def __init__(self,
max_centers=100,
cv=5,
algo="FW",
lr=np.logspace(-3,1,5),
lr=[0.001, 0.01, 0.1, 1.0, 10.0],
tol=1e-6,
max_iter=2000,
copy=True,
Expand Down
3 changes: 2 additions & 1 deletion adapt/parameter_based/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Parameter-Based Methods Module
"""

from ._regular import RegularTransferLR, RegularTransferLC, RegularTransferNN
from ._regular import RegularTransferLR, RegularTransferLC, RegularTransferNN, RegularTransferGP
from ._finetuning import FineTuning
from ._transfer_tree import TransferTreeClassifier
from ._transfer_tree import TransferForestClassifier
Expand All @@ -13,6 +13,7 @@
__all__ = ["RegularTransferLR",
"RegularTransferLC",
"RegularTransferNN",
"RegularTransferGP",
"FineTuning",
"TransferTreeClassifier",
"TransferForestClassifier",
Expand Down
186 changes: 183 additions & 3 deletions adapt/parameter_based/_regular.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import numpy as np
from sklearn.preprocessing import LabelBinarizer
from scipy.sparse.linalg import lsqr
from sklearn.gaussian_process import GaussianProcessRegressor, GaussianProcessClassifier
from sklearn.linear_model import LinearRegression
import tensorflow as tf
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Flatten, Dense
Expand Down Expand Up @@ -186,7 +188,8 @@ def fit(self, Xt=None, yt=None, **fit_params):

if yt_ndim_below_one_:
self.coef_ = self.coef_.reshape(-1)
self.intercept_ = self.intercept_[0]
if self.estimator_.fit_intercept:
self.intercept_ = self.intercept_[0]

self.estimator_.coef_ = self.coef_
if self.estimator_.fit_intercept:
Expand Down Expand Up @@ -267,7 +270,11 @@ def fit(self, Xt=None, yt=None, **fit_params):
Xt, yt = check_arrays(Xt, yt)

_label_binarizer = LabelBinarizer(pos_label=1, neg_label=-1)
yt = _label_binarizer.fit_transform(yt)
_label_binarizer.fit(self.estimator.classes_)
yt = _label_binarizer.transform(yt)

print(yt.shape)

return super().fit(Xt, yt, **fit_params)


Expand Down Expand Up @@ -467,4 +474,177 @@ def predict_disc(self, X):
"""
Not used.
"""
pass
pass


@make_insert_doc(supervised=True)
class RegularTransferGP(BaseAdaptEstimator):
"""
Regular Transfer with Gaussian Process
RegularTransferGP is a parameter-based domain adaptation method.
The method is based on the assumption that a good target estimator
can be obtained by adapting the parameters of a pre-trained source
estimator using a few labeled target data.
The approach consist in fitting the `alpha` coeficients of a
Gaussian Process estimator on target data according to an
objective function regularized by the euclidean distance between
the source and target `alpha`:
.. math::
\\alpha_T = \\underset{\\alpha \in \\mathbb{R}^n}{\\text{argmin}}
\\, ||K_{TS} \\alpha - y_T||^2 + \\lambda ||\\alpha - \\alpha_S||^2
Where:
- :math:`\\alpha_T` are the target model coeficients.
- :math:`\\alpha_S = \\underset{\\alpha \\in \\mathbb{R}^n}{\\text{argmin}}
\\, ||K_{SS} \\alpha - y_S||^2` are the source model coeficients.
- :math:`y_S, y_T` are respectively the source and
the target labels.
- :math:`K_{SS}` is the pariwise kernel distance matrix between source
input data.
- :math:`K_{TS}` is the pariwise kernel distance matrix between target
and source input data.
- :math:`n` is the number of source data in :math:`X_S`
- :math:`\\lambda` is a trade-off parameter. The larger :math:`\\lambda`
the closer the target model will be from the source model.
The ``estimator`` given to ``RegularTransferGP`` should be from classes
``sklearn.gaussian_process.GaussianProcessRegressor`` or
``sklearn.gaussian_process.GaussianProcessClassifier``
Parameters
----------
lambda_ : float (default=1.0)
Trade-Off parameter. For large ``lambda_``, the
target model will be similar to the source model.
Attributes
----------
estimator_ : Same class as estimator
Fitted Estimator.
Examples
--------
>>> from sklearn.gaussian_process import GaussianProcessRegressor
>>> from sklearn.gaussian_process.kernels import Matern, WhiteKernel
>>> from adapt.utils import make_regression_da
>>> from adapt.parameter_based import RegularTransferGP
>>> Xs, ys, Xt, yt = make_regression_da()
>>> kernel = Matern() + WhiteKernel()
>>> src_model = GaussianProcessRegressor(kernel)
>>> src_model.fit(Xs, ys)
>>> print(src_model.score(Xt, yt))
-2.3409379221035382
>>> tgt_model = RegularTransferGP(src_model, lambda_=1.)
>>> tgt_model.fit(Xt[:3], yt[:3])
>>> tgt_model.score(Xt, yt)
-0.21947435769240653
See also
--------
RegularTransferLR, RegularTransferNN
References
----------
.. [1] `[1] <https://www.microsoft.com/en-us/research/wp-\
content/uploads/2004/07/2004-chelba-emnlp.pdf>`_ C. Chelba and \
A. Acero. "Adaptation of maximum entropy classifier: Little data \
can help a lot". In EMNLP, 2004.
"""

def __init__(self,
estimator=None,
Xt=None,
yt=None,
lambda_=1.,
copy=True,
verbose=1,
random_state=None,
**params):

if not hasattr(estimator, "kernel_"):
raise ValueError("`estimator` argument has no ``kernel_`` attribute, "
"please call `fit` on `estimator` or use "
"another estimator as `GaussianProcessRegressor` or "
"`GaussianProcessClassifier`.")

estimator = check_fitted_estimator(estimator)

names = self._get_param_names()
kwargs = {k: v for k, v in locals().items() if k in names}
kwargs.update(params)
super().__init__(**kwargs)


def fit(self, Xt=None, yt=None, **fit_params):
"""
Fit RegularTransferGP.
Parameters
----------
Xt : numpy array (default=None)
Target input data.
yt : numpy array (default=None)
Target output data.
fit_params : key, value arguments
Not used. Here for sklearn compatibility.
Returns
-------
self : returns an instance of self
"""
Xt, yt = self._get_target_data(Xt, yt)
Xt, yt = check_arrays(Xt, yt)
set_random_seed(self.random_state)

self.estimator_ = check_estimator(self.estimator,
copy=self.copy,
force_copy=True)

if isinstance(self.estimator, GaussianProcessRegressor):
src_linear_model = LinearRegression(fit_intercept=False)
src_linear_model.coef_ = self.estimator_.alpha_.transpose()

Kt = self.estimator_.kernel_(Xt, self.estimator_.X_train_)
tgt_linear_model = RegularTransferLR(src_linear_model, lambda_=self.lambda_)

tgt_linear_model.fit(Kt, yt)

self.estimator_.alpha_ = np.copy(tgt_linear_model.coef_).transpose()

elif isinstance(self.estimator, GaussianProcessClassifier):

if hasattr(self.estimator_.base_estimator_, "estimators_"):
for i in range(len(self.estimator_.base_estimator_.estimators_)):
c = self.estimator_.classes_[i]
if sum(yt == c) > 0:
yt_c = np.zeros(yt.shape[0])
yt_c[yt == c] = 1
self.estimator_.base_estimator_.estimators_[i] = self._fit_one_vs_one_classifier(
self.estimator_.base_estimator_.estimators_[i], Xt, yt_c)

else:
self.estimator_.base_estimator_ = self._fit_one_vs_one_classifier(
self.estimator_.base_estimator_, Xt, yt)
return self


def _fit_one_vs_one_classifier(self, estimator, Xt, yt):
src_linear_model = LinearRegression(fit_intercept=False)
src_linear_model.coef_ = (estimator.y_train_ - estimator.pi_)
src_linear_model.classes_ = estimator.classes_
Kt = estimator.kernel_(Xt, estimator.X_train_)

tgt_linear_model = RegularTransferLC(src_linear_model, lambda_=self.lambda_)

tgt_linear_model.fit(Kt, yt)

estimator.pi_ = (estimator.y_train_ - np.copy(tgt_linear_model.coef_).ravel())
return estimator
3 changes: 3 additions & 0 deletions src_docs/_templates/layout.html
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,12 @@
<li class="toctree-l2"><a class="reference internal" href="{{ pathto("generated/adapt.parameter_based.RegularTransferLR") }}">RegularTransferLR</a></li>
<li class="toctree-l2"><a class="reference internal" href="{{ pathto("generated/adapt.parameter_based.RegularTransferLC") }}">RegularTransferLC</a></li>
<li class="toctree-l2"><a class="reference internal" href="{{ pathto("generated/adapt.parameter_based.RegularTransferNN") }}">RegularTransferNN</a></li>
<li class="toctree-l2"><a class="reference internal" href="{{ pathto("generated/adapt.parameter_based.RegularTransferGP") }}">RegularTransferGP</a></li>
<li class="toctree-l2"><a class="reference internal" href="{{ pathto("generated/adapt.parameter_based.FineTuning") }}">FineTuning</a></li>
<li class="toctree-l2"><a class="reference internal" href="{{ pathto("generated/adapt.parameter_based.TransferTreeClassifier") }}">TransferTreeClassifier</a></li>
<li class="toctree-l2"><a class="reference internal" href="{{ pathto("generated/adapt.parameter_based.TransferForestClassifier") }}">TransferForestClassifier</a></li>
<li class="toctree-l2"><a class="reference internal" href="{{ pathto("generated/adapt.parameter_based.TransferTreeSelector") }}">TransferTreeSelector</a></li>
<li class="toctree-l2"><a class="reference internal" href="{{ pathto("generated/adapt.parameter_based.TransferForestSelector") }}">TransferForestSelector</a></li>
</ul>
</li>
<li class="toctree-l1"><a class="reference internal" href="{{ pathto("contents") }}{{ contents }}{{ "adapt-metrics" }}">Metrics</a><ul>
Expand Down
3 changes: 3 additions & 0 deletions src_docs/contents.rst
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,12 @@ the **source** data are adapted to build a suited model for the **task** on the
parameter_based.RegularTransferLR
parameter_based.RegularTransferLC
parameter_based.RegularTransferNN
parameter_based.RegularTransferGP
parameter_based.FineTuning
parameter_based.TransferTreeClassifier
parameter_based.TransferForestClassifier
parameter_based.TransferTreeSelector
parameter_based.TransferForestSelector


.. _adapt.metrics:
Expand Down
47 changes: 47 additions & 0 deletions src_docs/generated/adapt.parameter_based.RegularTransferGP.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
:ref:`adapt.parameter_based <adapt.parameter_based>`.RegularTransferGP
===========================================================================

.. currentmodule:: adapt.parameter_based

.. autoclass:: RegularTransferGP
:no-members:
:no-inherited-members:
:no-special-members:




.. rubric:: Methods

.. autosummary::

~RegularTransferGP.__init__
~RegularTransferGP.fit
~RegularTransferGP.fit_estimator
~RegularTransferGP.get_params
~RegularTransferGP.predict
~RegularTransferGP.predict_estimator
~RegularTransferGP.score
~RegularTransferGP.set_params
~RegularTransferGP.unsupervised_score


.. automethod:: __init__
.. automethod:: fit
.. automethod:: fit_estimator
.. automethod:: get_params
.. automethod:: predict
.. automethod:: predict_estimator
.. automethod:: score
.. automethod:: set_params
.. automethod:: unsupervised_score




.. raw:: html

<h2> Examples </h2>

.. include:: ../gallery/RegularTransferGP.rst

Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
:ref:`adapt.parameter_based <adapt.parameter_based>`.TransferForestSelector
================================================================================

.. currentmodule:: adapt.parameter_based

.. autoclass:: TransferForestSelector
:no-members:
:no-inherited-members:
:no-special-members:




.. rubric:: Methods

.. autosummary::

~TransferForestSelector.__init__
~TransferForestSelector.fit
~TransferForestSelector.fit_estimator
~TransferForestSelector.get_params
~TransferForestSelector.model_selection
~TransferForestSelector.predict
~TransferForestSelector.predict_estimator
~TransferForestSelector.score
~TransferForestSelector.set_params
~TransferForestSelector.unsupervised_score


.. automethod:: __init__
.. automethod:: fit
.. automethod:: fit_estimator
.. automethod:: get_params
.. automethod:: model_selection
.. automethod:: predict
.. automethod:: predict_estimator
.. automethod:: score
.. automethod:: set_params
.. automethod:: unsupervised_score




.. raw:: html

<h2> Examples </h2>

.. include:: ../gallery/TransferForestSelector.rst

0 comments on commit ff5a6e3

Please sign in to comment.