# Custom Estimators

In [None]:
from sklearn.base import BaseEstimator, TransformerMixin

class MyTransformer(BaseEstimator, TransformerMixin):
    """Template showing the structure of a custom transformer.

    Parameters
    ----------
    first_paramter : any, default=1
        Example hyper-parameter (the spelling is part of the public
        signature and is therefore kept as-is).
    second_parameter : any, default=2
        Second example hyper-parameter.
    """

    def __init__(self, first_paramter=1, second_parameter=2):
        # all parameters must be specified in the __init__ function and
        # stored unchanged under the same name as the argument; storing
        # constants here would silently discard the caller's values
        self.first_paramter = first_paramter
        self.second_parameter = second_parameter

    def fit(self, X, y=None):
        """Fit the transformer and return ``self``.

        fit should only take X and y as parameters; even if your model is
        unsupervised, you need to accept a y argument!
        """
        # Model fitting code goes here
        print("fitting the model right here")
        # fit returns self
        return self

    def transform(self, X):
        """Return a transformed version of ``X`` (here simply ``X + 1``).

        transform takes only X as parameter.
        """
        # apply some transformation to X:
        X_transformed = X + 1
        return X_transformed


In [None]:
from sklearn.utils.validation import check_X_y

class MyEstimator(object):
    """Minimal estimator that implements the parameter API by hand.

    Parameters
    ----------
    my_parameter : any, default="stuff"
        The single hyper-parameter; stored unchanged on the instance.
    """

    def __init__(self, my_parameter="stuff"):
        self.my_parameter = my_parameter

    def fit(self, X, y):
        """Validate ``X`` and ``y`` with ``check_X_y`` and return ``self``."""
        X, y = check_X_y(X, y)
        return self

    def set_params(self, **kwargs):
        """Set hyper-parameters by keyword and return ``self``.

        Raises
        ------
        ValueError
            If a keyword other than ``my_parameter`` is given.
        """
        # Iterate over (name, value) pairs; iterating the dict directly
        # would yield only the key strings.
        for key, value in kwargs.items():
            if key == "my_parameter":
                self.my_parameter = value
            else:
                raise ValueError("Unknown parameter %s" % key)
        return self

    def get_params(self, deep=None):
        """Return the hyper-parameters as a dict; ``deep`` is accepted but unused."""
        return {'my_parameter': self.my_parameter}

In [None]:
# Build an instance and print it. MyEstimator defines no __repr__ of its own,
# so the output is whatever the default object representation produces.
est = MyEstimator(my_parameter="bla")
print(est)

In [None]:
from sklearn.utils.estimator_checks import check_estimator
# check_estimator expects an estimator *instance*; recent scikit-learn
# releases reject a bare class, so instantiate with the default parameters.
check_estimator(MyEstimator())

In [None]:
from sklearn.utils.validation import check_X_y, check_array

class MyBrokenEstimator(object):
    """Deliberately non-compliant estimator used to demonstrate a failing check.

    The intentional defect is in ``__init__``: the constructor modifies
    ``my_parameter`` instead of storing it unchanged. The rest of the class
    is kept working so that this is the only violation on display.

    Parameters
    ----------
    my_parameter : str, default="stuff"
        Hyper-parameter; ``" more stuff"`` is appended to it on construction.
    """

    def __init__(self, my_parameter="stuff"):
        # Intentionally wrong: parameters should be stored as passed.
        self.my_parameter = my_parameter + " more stuff"

    def fit(self, X, y):
        """Validate ``X`` and ``y`` with ``check_X_y`` and return ``self``."""
        X, y = check_X_y(X, y)
        return self

    def set_params(self, **kwargs):
        """Set hyper-parameters by keyword and return ``self``.

        Raises
        ------
        ValueError
            If a keyword other than ``my_parameter`` is given.
        """
        # Iterate over (name, value) pairs; iterating the dict directly
        # would yield only the key strings.
        for key, value in kwargs.items():
            if key == "my_parameter":
                self.my_parameter = value
            else:
                raise ValueError("Unknown parameter %s" % key)
        return self

    def get_params(self, deep=None):
        """Return the hyper-parameters as a dict; ``deep`` is accepted but unused."""
        return {'my_parameter': self.my_parameter}

In [None]:
# check_estimator expects an estimator *instance*; recent scikit-learn
# releases reject a bare class. NOTE(review): this call is presumably meant
# to report a failure, since the class is broken on purpose.
check_estimator(MyBrokenEstimator())

In [None]:
from sklearn.base import BaseEstimator

class MyInheritingEstimator(BaseEstimator):
    """Estimator that relies on ``BaseEstimator`` instead of hand-written
    ``get_params`` / ``set_params`` methods.

    Parameters
    ----------
    my_parameter : any, default="stuff"
        The single hyper-parameter; stored unchanged on the instance.
    """

    def __init__(self, my_parameter="stuff"):
        self.my_parameter = my_parameter

    def fit(self, X, y):
        """Validate the training data and return ``self``; nothing is learned."""
        # Only the validation side effect matters; the validated arrays
        # are not used any further.
        check_X_y(X, y)
        return self

In [None]:
# Build an instance and print it. NOTE(review): the representation presumably
# comes from BaseEstimator now -- compare with the output for MyEstimator.
est = MyInheritingEstimator(my_parameter="bla")
print(est)

In [None]:
# check_estimator expects an estimator *instance*; recent scikit-learn
# releases reject a bare class, so instantiate with the default parameters.
check_estimator(MyInheritingEstimator())

In [None]:
from sklearn.base import TransformerMixin
class MyTransformer(BaseEstimator, TransformerMixin):
    """Toy transformer that subtracts 2 from every entry of ``X``.

    Parameters
    ----------
    my_parameter : any, default="stuff"
        Example hyper-parameter; stored unchanged and otherwise unused.

    Attributes
    ----------
    n_features_ : int
        Number of columns of the ``X`` seen during ``fit``.
    """

    def __init__(self, my_parameter="stuff"):
        self.my_parameter = my_parameter

    def fit(self, X, y=None):
        """Validate the input, record its number of features, return ``self``.

        ``y`` is optional so the transformer can also be fitted without a
        target (e.g. ``fit_transform(X)``); when given, it is validated
        together with ``X``.
        """
        if y is None:
            # Unsupervised use: only X can be validated.
            X = check_array(X)
        else:
            X, y = check_X_y(X, y)
        self.n_features_ = X.shape[1]
        return self

    def transform(self, X):
        """Return the validated ``X`` with 2 subtracted from every entry."""
        X = check_array(X)
        return X - 2

In [None]:
# check_estimator expects an estimator *instance*; recent scikit-learn
# releases reject a bare class, so instantiate with the default parameters.
check_estimator(MyTransformer())

In [None]:
import numpy as np
from sklearn.base import ClassifierMixin

class MyBrokenClassifier(BaseEstimator, ClassifierMixin):
    """Deliberately faulty classifier used to demonstrate a failing check.

    ``predict`` ignores its input and always returns the same two labels,
    whatever the number of rows in ``X``.

    Parameters
    ----------
    my_parameter : any, default="stuff"
        Example hyper-parameter; stored unchanged and otherwise unused.
    """

    def __init__(self, my_parameter="stuff"):
        self.my_parameter = my_parameter

    def fit(self, X, y):
        """Validate the training data and return ``self``; nothing is learned."""
        check_X_y(X, y)
        return self

    def predict(self, X):
        """Validate ``X``, then return the constant array ``[1, 2]``."""
        check_array(X)
        # Intentionally wrong: the output does not depend on X at all.
        constant_labels = np.array([1, 2])
        return constant_labels

In [None]:
# check_estimator expects an estimator *instance*; recent scikit-learn
# releases reject a bare class. NOTE(review): this call is presumably meant
# to report a failure, since the classifier is broken on purpose.
check_estimator(MyBrokenClassifier())

# Exercise
- Reimplement a simple version of the standard scaler (one that removes the mean and scales to unit variance) with a scikit-learn interface. Can you make it pass the tests? Does it give the same result as sklearn.preprocessing.StandardScaler?
- Reimplement a one-nearest-neighbor classifier with a scikit-learn interface (one that memorizes the training set and assigns a new test point to the class of the closest training point). Again, try to make it pass the tests.

Hint: use sklearn.utils.validation.check_is_fitted and sklearn.utils.multiclass.unique_labels (though you don't have to).

In [None]:
# %load solutions/custom_estimators.py