In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import Lasso
from sklearn.tree import DecisionTreeRegressor
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

In [None]:
# Generate some data: y is an exact line on a 21-point grid over [-10, 10];
# X is the same grid plus standard-normal noise, shaped (21, 1) as sklearn
# expects a 2-D feature matrix. A seeded Generator makes Restart & Run All
# reproduce the same points every time.
SEED = 0
rng = np.random.default_rng(SEED)
X = (np.linspace(-10, 10, 21) + rng.normal(size=21)).reshape(-1, 1)
y = np.linspace(-10, 10, 21)

# Visual check of the generated data before any model is fitted.
fig, ax = plt.subplots()
ax.scatter(X, y)
ax.set(title="Generated data: noisy X vs. linear y", xlabel="X", ylabel="y")
plt.show()

## Using a class decorator

#### Decorator definition

In [None]:
def non_negative_decorate(Cls):
    """Build a subclass of ``Cls`` whose ``predict`` clips negative outputs to 0.

    Parameters
    ----------
    Cls : type
        Estimator class with a ``predict(X)`` method. Its return value is
        assumed to be a mutable NumPy array (boolean-mask assignment is used).

    Returns
    -------
    type
        A subclass of ``Cls`` named ``NonNegative<Cls.__name__>``. It has the
        same constructor as ``Cls``, so e.g. ``non_negative_decorate(Lasso)``
        accepts ``alpha=...`` exactly like ``Lasso``.
    """
    class NonNegativeModel(Cls):
        # ``__init__`` is intentionally NOT overridden. scikit-learn's
        # ``get_params`` (used by ``clone``, grid search and ``repr``) reads the
        # constructor signature; a ``**kwargs``-only ``__init__`` hides every
        # hyperparameter, so cloned copies would silently revert to defaults.

        def predict(self, X, y=None):
            """Predict with ``Cls.predict`` and replace negative values with 0.

            ``y`` is ignored; it is accepted only so existing calls keep working.
            """
            preds = super().predict(X)
            preds[preds < 0] = 0
            return preds

    # Give each generated class a distinct, descriptive name instead of the
    # shared ``NonNegativeModel`` (visible in reprs and error messages).
    NonNegativeModel.__name__ = f"NonNegative{Cls.__name__}"
    return NonNegativeModel

#### Standalone usage

In [None]:
# The same decorator wraps two unrelated regressors; both end up with a
# predict() that clips values below 0 to 0.
DecoratedLasso = non_negative_decorate(Lasso)
lasso_preds = DecoratedLasso(alpha=0.1, random_state=0).fit(X, y).predict(X)

DecoratedTree = non_negative_decorate(DecisionTreeRegressor)
# Fix random_state so the tree fit is deterministic across runs
# (scikit-learn randomly permutes features at each split otherwise).
tree_preds = DecoratedTree(max_depth=4, random_state=0).fit(X, y).predict(X)

# Compare the true targets with both decorated models' clipped predictions.
fig, ax = plt.subplots()
ax.scatter(X, y, label='True')
ax.scatter(X, lasso_preds, label='Non-negative Lasso predictions')
ax.scatter(X, tree_preds, label='Non-negative DT predictions')
ax.set(title="Decorated estimators: predictions clipped at 0", xlabel="X", ylabel="y")
ax.legend()
plt.show()

#### Usage in pipeline

In [None]:
# Re-create the decorated class so this cell is self-contained, then use it
# as the final step of a pipeline; Pipeline.fit returns the pipeline itself.
DecoratedLasso = non_negative_decorate(Lasso)

ppl = make_pipeline(
    StandardScaler(),
    DecoratedLasso(alpha=0.1, random_state=0),
).fit(X, y)
preds = ppl.predict(X)
preds

## Inherit and modify .predict()

#### Class definition

In [None]:
class NonNegativeLasso(Lasso):
    """``Lasso`` regressor whose predictions are clipped at 0 from below.

    Takes exactly the same constructor arguments as ``Lasso``. ``__init__`` is
    deliberately not overridden: scikit-learn's ``get_params`` reads the
    constructor signature, and a ``**kwargs``-only override would hide
    ``alpha`` and every other hyperparameter, so ``clone`` (used by grid
    search / cross-validation) would silently rebuild the model with defaults.
    """

    def predict(self, X, y=None):
        """Return ``Lasso.predict(X)`` with negative values replaced by 0.

        ``y`` is ignored; it is accepted only so existing calls keep working.
        """
        preds = super().predict(X)
        preds[preds < 0] = 0
        return preds

#### Standalone usage

In [None]:
# Lasso.fit returns the fitted estimator, so construction and fitting chain.
reg = NonNegativeLasso(alpha=0.1, random_state=0).fit(X, y)
preds = reg.predict(X)
preds

In [None]:
# True targets vs. the subclass's clipped predictions.
fig, ax = plt.subplots()
ax.scatter(X, y, label='True')
ax.scatter(X, preds, label='Non-negative Lasso predictions')
ax.set(title="NonNegativeLasso: predictions clipped at 0", xlabel="X", ylabel="y")
ax.legend()
plt.show()

#### Usage in pipeline

In [None]:
# Scale features, then regress with the clipping subclass as the last step.
ppl = make_pipeline(
    StandardScaler(),
    NonNegativeLasso(alpha=0.1, random_state=0),
).fit(X, y)
ppl.predict(X)