In [1]:
import numpy as np

###  Custom estimator

In [2]:
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.utils.validation import check_X_y, check_array, check_is_fitted

from sklearn.mixture import GaussianMixture
from scipy.stats import multivariate_normal

In [3]:
class GaussianMixtureDiscriminantAnalysis(BaseEstimator, ClassifierMixin):
    def __init__(self, n_cmp = (2,2) , tol=1e-4 , max_iter=100 ):
        super().__init__()
        self.n_cmp = n_cmp
        self.tol = tol
        self.max_iter = max_iter

    def fit(self, X, y):
        self.X_, self.y_ = check_X_y(X, y)
        self.cls_ = [GaussianMixture(n_components=n_cmp, tol=self.tol, max_iter=self.max_iter) for n_cmp in self.n_cmp]
    
        for l, cmp in enumerate(self.cls_):
                cmp.fit(self.X_[self.y_==l])
        
        self.means_ = [cmp.means_ for cmp in self.cls_]
        self.covs_  = [cmp.covariances_ for cmp in self.cls_]
        self.pdfs_ = [[multivariate_normal(cmp.means_[i], cmp.covariances_[i]).pdf for i in range(cmp.n_components)] for cmp in self.cls_]
        
        return self  
    
    def predict_proba(self, X):
        check_is_fitted(self)
        p = np.zeros((len(X),2) )
        for k,cmp in enumerate(self.cls_):
            for i in range(cmp.n_components):
                p[:,k]+= cmp.weights_[i]*self.pdfs_[k][i](X)
           
        return p[:,1]/(p.sum(1))
        
    def predict(self, X):
        X = check_array(X)
        return (self.predict_proba(X)>0.5).astype('int64')
    

In [4]:
from sklearn.model_selection import train_test_split

In [5]:
half_circles = np.loadtxt('half_circles.txt')
hc_train, hc_test = train_test_split(half_circles, test_size=0.25, stratify=half_circles[:,2])

In [6]:
hc_train_labels = hc_train[:,2].astype('int32')
hc_train_data   = hc_train[:,:2]

In [7]:
hc_test_labels = hc_test[:,2].astype('int32')
hc_test_data   = hc_test[:,:2]

In [8]:
gmd = GaussianMixtureDiscriminantAnalysis(n_cmp=(2,2))

In [9]:
gmd.fit(hc_train_data, hc_train_labels)

GaussianMixtureDiscriminantAnalysis()

In [10]:
gmd.score(hc_train_data, hc_train_labels)

0.996

In [11]:
gmd.score(hc_test_data, hc_test_labels)

0.984

In [12]:
from sklearn.model_selection import GridSearchCV

In [13]:
param_grid = {'n_cmp': [[1,1],(2,2),(3,3), (4,4), (5,5)]}

In [14]:
grid_search = GridSearchCV(GaussianMixtureDiscriminantAnalysis(), 
                           param_grid=param_grid, n_jobs=6, cv=5, scoring='accuracy')

In [15]:
%%time
grid_search.fit(hc_train_data, hc_train_labels)

CPU times: user 246 ms, sys: 62.2 ms, total: 308 ms
Wall time: 1.06 s


GridSearchCV(cv=5, estimator=GaussianMixtureDiscriminantAnalysis(), n_jobs=6,
             param_grid={'n_cmp': [[1, 1], (2, 2), (3, 3), (4, 4), (5, 5)]},
             scoring='accuracy')

In [16]:
import pandas as pd

In [17]:
pd.DataFrame(grid_search.cv_results_).sort_values(by='rank_test_score')[['param_n_cmp','mean_test_score']]

Unnamed: 0,param_n_cmp,mean_test_score
2,"(3, 3)",0.997333
3,"(4, 4)",0.996
1,"(2, 2)",0.996
4,"(5, 5)",0.992
0,"[1, 1]",0.818667


In [18]:
grid_search.best_estimator_.score(hc_test_data, hc_test_labels)

0.996

In [19]:
from sklearn.model_selection import RandomizedSearchCV
from scipy.stats import randint

In [20]:
class random_pairs:
    def __init__(self,high1=2, high2=2):
        self.high1_ = high1
        self.high2_ = high2
    
    def rvs(self, size=1, random_state=None):
        r1 = randint(1,self.high1_).rvs(size=size, random_state=random_state)
        r2 = randint(1,self.high2_).rvs(size=size, random_state=None)
        if size>1:
            return list(zip(r1,r2))
        else:
            return r1[0], r2[0]
        

In [21]:
randomized_search = RandomizedSearchCV(GaussianMixtureDiscriminantAnalysis(max_iter=500), {'n_cmp': random_pairs(7,7)},n_iter=10, cv=5)

In [22]:
randomized_search.fit(hc_train_data, hc_train_labels);

In [23]:
pd.DataFrame(randomized_search.cv_results_).sort_values(['mean_test_score'], ascending=False)[['param_n_cmp','mean_test_score']]

Unnamed: 0,param_n_cmp,mean_test_score
0,"(2, 3)",0.993333
4,"(6, 6)",0.993333
8,"(6, 2)",0.993333
5,"(2, 6)",0.990667
9,"(2, 6)",0.990667
2,"(2, 5)",0.989333
3,"(4, 1)",0.966667
1,"(3, 1)",0.96
7,"(1, 2)",0.946667
6,"(1, 1)",0.818667
