Skip to content

Commit

Permalink
change _fitter to lifelines_model
Browse files Browse the repository at this point in the history
  • Loading branch information
CamDavidsonPilon committed Jun 6, 2019
1 parent 9e6d817 commit 84466e5
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 15 deletions.
14 changes: 7 additions & 7 deletions docs/Compatibility with scikit-learn.rst
Expand Up @@ -25,23 +25,23 @@ New to lifelines in version 0.21.3 is a wrapper that allows you to use lifeline'
CoxRegression = sklearn_adapter(CoxPHFitter, event_col='arrest')
# CoxRegression is a class like the `LinearRegression` class or `SVC` class in scikit-learn
cph = CoxRegression(penalizer=1.0)
cph.fit(X, Y)
print(cph)
sk_cph = CoxRegression(penalizer=1.0)
sk_cph.fit(X, Y)
print(sk_cph)
"""
CoxPHFitter(alpha=0.05, penalizer=1.0, strata=None, tie_method='Efron')
"""
cph.predict(X)
cph.score(X, Y)
sk_cph.predict(X)
sk_cph.score(X, Y)
If needed, the original lifeline's instance is available as the ``_fitter`` attribute.
If needed, the original lifeline's instance is available as the ``lifelines_model`` attribute.

.. code:: python
cph._fitter.print_summary()
sk_cph.lifelines_model.print_summary()
Expand Down
16 changes: 8 additions & 8 deletions lifelines/utils/sklearn_adapter.py
Expand Up @@ -20,7 +20,7 @@ def filter_kwargs(f, kwargs):
class _SklearnModel(BaseEstimator, MetaEstimatorMixin, RegressorMixin):
def __init__(self, **kwargs):
self._params = kwargs
self._fitter = self._fitter(**filter_kwargs(self._fitter.__init__, self._params))
self.lifelines_model = self.lifelines_model(**filter_kwargs(self.lifelines_model.__init__, self._params))

self._params["duration_col"] = "duration_col"
self._params["event_col"] = self._event_col
Expand Down Expand Up @@ -51,19 +51,19 @@ def fit(self, X, y=None, sample_weight=None):
if y is not None:
X.insert(len(X.columns), self._yColumn, y, allow_duplicates=False)

fit = getattr(self._fitter, self._fit_method)
self._fitter = fit(df=X, **filter_kwargs(fit, self._params))
fit = getattr(self.lifelines_model, self._fit_method)
self.lifelines_model = fit(df=X, **filter_kwargs(fit, self._params))
return self

def set_params(self, **params):
for key, value in params.items():
setattr(self._fitter, key, value)
setattr(self.lifelines_model, key, value)
return self

def get_params(self, deep=True):
out = {}
for k in inspect.signature(self._fitter.__init__).parameters:
out[k] = getattr(self._fitter, k)
for k in inspect.signature(self.lifelines_model.__init__).parameters:
out[k] = getattr(self.lifelines_model, k)
return out

def predict(self, X):
Expand All @@ -73,7 +73,7 @@ def predict(self, X):
X: DataFrame or numpy array
"""
return getattr(self._fitter, self._predict_method)(X)[0].values
return getattr(self.lifelines_model, self._predict_method)(X)[0].values

def score(self, X, y, sample_weight=None):
"""
Expand Down Expand Up @@ -117,7 +117,7 @@ class that can be instantiated with parameters (similar to a scikit-learn class)
klass = type(
name,
(_SklearnModel,),
{"_fitter": fitter, "_event_col": event_col, "_predict_method": predict_method, "_fit_method": "fit"},
{"lifelines_model": fitter, "_event_col": event_col, "_predict_method": predict_method, "_fit_method": "fit"},
)
globals()[klass.__name__] = klass
return klass

0 comments on commit 84466e5

Please sign in to comment.