In [9]:
import numpy as np
from sklearn import datasets
from sklearn.base import BaseEstimator, ClassifierMixin
# NOTE: `sklearn.grid_search` was deprecated in scikit-learn 0.18 and removed
# in 0.20; newer versions provide GridSearchCV in `sklearn.model_selection`
# (exposing `cv_results_` instead of `grid_scores_`).
from sklearn.grid_search import GridSearchCV
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KernelDensity

In [15]:
class KDEClassifier(BaseEstimator, ClassifierMixin):
    """Bayesian generative classifier built on per-class kernel density estimates.

    One ``KernelDensity`` model is fitted per class. A sample is assigned to
    the class maximising ``log p(x | class) + log p(class)``, where the class
    prior is the class's share of the training samples.

    Parameters
    ----------
    bandwidth : float
        The kernel bandwidth used for every per-class density.
    kernel : str
        The kernel name, passed to ``KernelDensity``.

    Attributes
    ----------
    classes_ : ndarray
        Sorted unique class labels seen in ``fit``.
    models_ : list of KernelDensity
        One fitted density model per entry of ``classes_``.
    logpriors_ : list of float
        Log of each class's fraction of the training samples.
    """

    def __init__(self, bandwidth=1.0, kernel='gaussian'):
        # sklearn convention: __init__ only stores hyper-parameters so that
        # get_params/set_params (used by GridSearchCV) work.
        self.bandwidth = bandwidth
        self.kernel = kernel

    def fit(self, X, y):
        """Fit one KDE per class and record the empirical class log-priors.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training samples.
        y : array-like of shape (n_samples,)
            Class labels.

        Returns
        -------
        self
        """
        X = np.asarray(X)
        y = np.asarray(y)
        # np.unique already returns the labels in sorted order.
        self.classes_ = np.unique(y)
        training_sets = [X[y == yi] for yi in self.classes_]
        self.models_ = [KernelDensity(bandwidth=self.bandwidth,
                                      kernel=self.kernel).fit(Xi)
                        for Xi in training_sets]
        self.logpriors_ = [np.log(Xi.shape[0] / X.shape[0])
                           for Xi in training_sets]
        return self

    def each_prior_prob(self, x, X, y, k, eps=1e-6):
        """Locally weighted log class priors around a single query point.

        Let ``d_k`` be the distance from ``x`` to the training point at index
        ``k`` of the sorted distances. Each training point at distance ``d``
        gets the triangular weight ``max(1 - d / d_k, 0)``. The per-class
        weight totals, divided by the overall total, are the local priors.

        Parameters
        ----------
        x : array-like of shape (n_features,)
            Query point.
        X : array-like of shape (n_samples, n_features)
            Reference (training) points.
        y : array-like of shape (n_samples,)
            Labels of ``X``.
        k : int
            Index into the sorted distances that sets the neighbourhood
            radius; must satisfy ``0 <= k < n_samples`` (IndexError otherwise).
        eps : float, default=1e-6
            Value used in place of a zero class weight so its log stays finite.

        Returns
        -------
        ndarray of shape (n_classes,)
            Log priors, one per label of ``np.unique(y)`` (sorted order, the
            same ordering ``fit`` uses for ``classes_``).
        """
        X = np.asarray(X)
        y = np.asarray(y)
        dist = np.linalg.norm(X - np.asarray(x), axis=1)
        radius = np.sort(dist)[k]
        if radius > 0:
            # Triangular weights: 1 at distance 0, falling to 0 at the radius.
            weights = np.maximum(1.0 - dist / radius, 0.0)
        else:
            # Zero radius (x coincides with at least k+1 training points):
            # the limit of the triangular weight is 1 for exact matches and 0
            # elsewhere; dividing by the radius would produce nan/inf.
            weights = (dist == 0).astype(float)

        class_weight = np.array([weights[y == c].sum() for c in np.unique(y)])
        weight_sum = weights.sum()

        # Zero-weight classes get the (unnormalised) floor `eps` so log() is
        # finite; any nonzero class weight implies weight_sum > 0.
        local_prior = np.full(class_weight.shape, eps, dtype=float)
        nonzero = class_weight != 0
        local_prior[nonzero] = class_weight[nonzero] / weight_sum
        return np.log(local_prior)

    def prior_prob_fit(self, X_test, X_train, y, k):
        """Compute ``each_prior_prob`` for every row of ``X_test``.

        Parameters
        ----------
        X_test : array-like of shape (n_test, n_features)
            Query points.
        X_train : array-like of shape (n_samples, n_features)
            Reference points.
        y : array-like of shape (n_samples,)
            Labels of ``X_train``.
        k : int
            Neighbourhood index forwarded to ``each_prior_prob``.

        Returns
        -------
        ndarray of shape (n_test, n_classes)
            Row ``i`` holds the local log priors of ``X_test[i]``; columns
            follow ``np.unique(y)`` order.
        """
        X_test = np.asarray(X_test)
        n_classes = len(np.unique(y))
        prior_prob = np.zeros((X_test.shape[0], n_classes))
        for i, x in enumerate(X_test):
            prior_prob[i] = self.each_prior_prob(x, X_train, y, k)
        return prior_prob

    def predict_proba(self, X):
        """Posterior class probabilities for each sample.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)

        Returns
        -------
        ndarray of shape (n_samples, n_classes)
            Columns are ordered as ``classes_``; each row sums to 1.
        """
        logprobs = np.array([model.score_samples(X)
                             for model in self.models_]).T
        log_joint = logprobs + np.asarray(self.logpriors_)
        # Subtract each row's maximum before exponentiating. The normalised
        # result is mathematically unchanged, but exp() no longer underflows
        # to 0 (yielding 0/0 = nan) when all log-densities are very negative.
        log_joint = log_joint - log_joint.max(axis=1, keepdims=True)
        result = np.exp(log_joint)
        return result / result.sum(axis=1, keepdims=True)

    def predict(self, X):
        """Return the most probable class label for each sample in ``X``."""
        return self.classes_[np.argmax(self.predict_proba(X), 1)]

In [16]:
# Load the iris dataset (a sklearn Bunch with `.data` and `.target`).
iris = datasets.load_iris()

In [17]:
# Hold out a third of the samples for evaluation; fixed seed so the split is
# reproducible across runs.
X_train, X_test, y_train, y_test = train_test_split(
    iris.data,
    iris.target,
    test_size=0.33,
    random_state=42,
)

In [18]:
# 100 candidate bandwidths, log-spaced from 10**0 = 1 to 10**2 = 100.
# NOTE(review): in the recorded scores the best value is at the smallest
# candidate (1.0), so the true optimum may lie below this range.
bandwidths = np.logspace(0, 2, 100)
grid = GridSearchCV(KDEClassifier(), param_grid={'bandwidth': bandwidths})

In [19]:
# Cross-validate every candidate bandwidth on the training split; the cell's
# value is the fitted search object.
grid.fit(X=X_train, y=y_train)

GridSearchCV(cv=None, error_score='raise',
       estimator=KDEClassifier(bandwidth=1.0, kernel='gaussian'),
       fit_params={}, iid=True, n_jobs=1,
       param_grid={'bandwidth': array([  1.     ,   1.04762, ...,  95.45485, 100.     ])},
       pre_dispatch='2*n_jobs', refit=True, scoring=None, verbose=0)

In [20]:
# Mean cross-validation score per candidate bandwidth. `grid_scores_` exists
# only on the legacy `sklearn.grid_search` API (removed in scikit-learn 0.20);
# newer GridSearchCV objects expose the same means via `cv_results_`.
if hasattr(grid, 'cv_results_'):
    scores = grid.cv_results_['mean_test_score'].tolist()
else:
    scores = [val.mean_validation_score for val in grid.grid_scores_]

In [21]:
# Show the mean CV score of every bandwidth candidate (one entry per value in
# `bandwidths`, presumably in the same order as the parameter grid).
scores

[0.9,
 0.89,
 0.89,
 0.89,
 0.89,
 0.89,
 0.89,
 0.89,
 0.88,
 0.88,
 0.87,
 0.87,
 0.87,
 0.87,
 0.87,
 0.87,
 0.87,
 0.87,
 0.87,
 0.87,
 0.87,
 0.88,
 0.87,
 0.88,
 0.87,
 0.87,
 0.87,
 0.87,
 0.86,
 0.86,
 0.86,
 0.86,
 0.86,
 0.86,
 0.84,
 0.83,
 0.83,
 0.8,
 0.78,
 0.73,
 0.65,
 0.57,
 0.56,
 0.5,
 0.47,
 0.46,
 0.43,
 0.42,
 0.42,
 0.42,
 0.42,
 0.42,
 0.42,
 0.42,
 0.42,
 0.42,
 0.42,
 0.42,
 0.42,
 0.42,
 0.42,
 0.42,
 0.42,
 0.42,
 0.42,
 0.42,
 0.42,
 0.42,
 0.42,
 0.42,
 0.42,
 0.42,
 0.42,
 0.42,
 0.42,
 0.42,
 0.42,
 0.42,
 0.42,
 0.42,
 0.42,
 0.42,
 0.42,
 0.42,
 0.42,
 0.42,
 0.42,
 0.42,
 0.42,
 0.42,
 0.42,
 0.42,
 0.42,
 0.42,
 0.42,
 0.42,
 0.42,
 0.42,
 0.42,
 0.42]