In [11]:
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn      import datasets
import numpy as np
from sklearn.neighbors import KernelDensity

In [68]:
class KDEClassifier(BaseEstimator, ClassifierMixin):
    
    """Bayesian generative classification based on KDE
    
    Parameters
    ----------
    bandwidth : float
        the kernel bandwidth within each class
    kernel : str
        the kernel name, passed to KernelDensity
    """
    
    def __init__(self, bandwidth=1.0, kernel='gaussian'):
        self.bandwidth = bandwidth
        self.kernel = kernel
        
    def fit(self, X, y, k):
        
        self.classes_ = np.sort(np.unique(y))
        training_sets = [X[y == yi] for yi in self.classes_]
        print(training_sets)
        self.models_ = [KernelDensity(bandwidth=self.bandwidth,
                                      kernel=self.kernel).fit(Xi)
                        for Xi in training_sets]
        self.logpriors_ = [np.log(Xi.shape[0] / X.shape[0])
                           for Xi in training_sets]

        return self
    
    def distance(test,train):
        
        
        map(lamba x:  )
        
        
    def predict_proba(self, X):
        logprobs = np.array([model.score_samples(X)
                             for model in self.models_]).T
        print(logprobs.shape)
        print(logprobs)
        print(self.logpriors_)
        print(logprobs+self.logpriors_)
        result = np.exp(logprobs + self.logpriors_)
       
        return result / result.sum(1, keepdims=True)
        
    def predict(self, X):
        return self.classes_[np.argmax(self.predict_proba(X), 1)]

In [69]:
#y_pred = gnb.fit(iris.data, iris.target).predict(iris.data)
iris = datasets.load_iris()

In [70]:
kde  = KDEClassifier()
kde.fit(iris.data, iris.target)
kde.predict_proba(iris.data)

[array([[5.1, 3.5, 1.4, 0.2],
       [4.9, 3. , 1.4, 0.2],
       [4.7, 3.2, 1.3, 0.2],
       [4.6, 3.1, 1.5, 0.2],
       [5. , 3.6, 1.4, 0.2],
       [5.4, 3.9, 1.7, 0.4],
       [4.6, 3.4, 1.4, 0.3],
       [5. , 3.4, 1.5, 0.2],
       [4.4, 2.9, 1.4, 0.2],
       [4.9, 3.1, 1.5, 0.1],
       [5.4, 3.7, 1.5, 0.2],
       [4.8, 3.4, 1.6, 0.2],
       [4.8, 3. , 1.4, 0.1],
       [4.3, 3. , 1.1, 0.1],
       [5.8, 4. , 1.2, 0.2],
       [5.7, 4.4, 1.5, 0.4],
       [5.4, 3.9, 1.3, 0.4],
       [5.1, 3.5, 1.4, 0.3],
       [5.7, 3.8, 1.7, 0.3],
       [5.1, 3.8, 1.5, 0.3],
       [5.4, 3.4, 1.7, 0.2],
       [5.1, 3.7, 1.5, 0.4],
       [4.6, 3.6, 1. , 0.2],
       [5.1, 3.3, 1.7, 0.5],
       [4.8, 3.4, 1.9, 0.2],
       [5. , 3. , 1.6, 0.2],
       [5. , 3.4, 1.6, 0.4],
       [5.2, 3.5, 1.5, 0.2],
       [5.2, 3.4, 1.4, 0.2],
       [4.7, 3.2, 1.6, 0.2],
       [4.8, 3.1, 1.6, 0.2],
       [5.4, 3.4, 1.5, 0.4],
       [5.2, 4.1, 1.5, 0.1],
       [5.5, 4.2, 1.4, 0.2],
       [4.9, 

array([[9.86677064e-01, 1.32117523e-02, 1.11183873e-04],
       [9.81067224e-01, 1.87946390e-02, 1.38137073e-04],
       [9.87845151e-01, 1.20774116e-02, 7.74371697e-05],
       [9.80841522e-01, 1.90076479e-02, 1.50830493e-04],
       [9.88222880e-01, 1.16800210e-02, 9.70990588e-05],
       [9.75758380e-01, 2.38550092e-02, 3.86610421e-04],
       [9.86757518e-01, 1.31394161e-02, 1.03065412e-04],
       [9.82743644e-01, 1.71007141e-02, 1.55641470e-04],
       [9.82810369e-01, 1.70767822e-02, 1.12849150e-04],
       [9.80381702e-01, 1.94643450e-02, 1.53953380e-04],
       [9.84467205e-01, 1.53662138e-02, 1.66581494e-04],
       [9.80272798e-01, 1.95362219e-02, 1.90980296e-04],
       [9.83250615e-01, 1.66375594e-02, 1.11825648e-04],
       [9.92262665e-01, 7.70442518e-03, 3.29094690e-05],
       [9.92823563e-01, 7.11077225e-03, 6.56650271e-05],
       [9.88745755e-01, 1.10823437e-02, 1.71901367e-04],
       [9.89906796e-01, 9.99293659e-03, 1.00267077e-04],
       [9.85450536e-01, 1.44197