In [1]:
from abc import ABCMeta, abstractmethod
from sklearn.base import BaseEstimator, ClassifierMixin
import numpy as np
from scipy.stats import norm
from sklearn.utils import shuffle
from math import sqrt
from sklearn.datasets import fetch_mldata
from sklearn.preprocessing import scale, LabelEncoder
from sklearn.model_selection import GridSearchCV, train_test_split

In [2]:
class BaseLinearOnline(BaseEstimator, metaclass=ABCMeta):
    @abstractmethod
    def __init__(self, C, eta, n_iter, random_state, shuffle):
        self.C = C
        self.eta = eta
        self.n_iter = n_iter
        self.random_state = random_state
        self.shuffle = shuffle
        
        self.phi = norm.cdf(self.eta)**(-1)

    def fit(self, X, Y):
        n_samples, n_features = X.shape
        self.mu = np.zeros(n_features).reshape(n_features, 1)
        self.sigma = np.diag([1.0] * n_features)
        
        for epoch in range(self.n_iter):
            if self.shuffle:
                X, Y = shuffle(X, Y, random_state=self.random_state)

            for i in range(n_samples):
                self._update(X[i:i + 1].T, Y[i:i + 1])
                
        return self

    @abstractmethod
    def _update(self, X, y):
        pass
    
    def predict(self, X):
        return np.sign(X @ self.mu)

In [3]:
class SCW1(BaseLinearOnline, ClassifierMixin):
    def __init__(self, C=1.0, eta=0.90, n_iter=3, random_state=0, shuffle=True):
        super(SCW1, self).__init__(C=C, eta=eta, n_iter=n_iter, random_state=random_state, shuffle=shuffle)
        
        self.psi = 1 + self.phi**2 / 2
        self.zeta = 1 + self.phi**2
        
    def _update(self, X, y):
        m = float(y * self.mu.T @ X)
        v = float(X.T @ (self.sigma @ X))
        alpha = min(self.C, max(0, (-m * self.psi + sqrt(m**2 * self.phi**4 / 4 + v * self.phi**2 * self.zeta)) / (v * self.zeta)))
        u = (-alpha * v * self.phi + sqrt(alpha** 2 * v**2 * self.phi**2 + 4 * v))**2 / 4
        beta = alpha * self.phi / (sqrt(u) + v * alpha * self.phi)
        
        self.mu = self.mu + alpha * y * self.sigma @ X
        self.sigma = self.sigma - beta * self.sigma @ X @ X.T @ self.sigma

In [4]:
class SCW2(BaseLinearOnline, ClassifierMixin):
    def __init__(self, C=1.0, eta=0.90, n_iter=3, random_state=0, shuffle=True):
        super(SCW2, self).__init__(C=C, eta=eta, n_iter=n_iter, random_state=random_state, shuffle=shuffle)
        
    def _update(self, X, y):
        m = float(y * self.mu.T @ X)
        v = float(X.T @ (self.sigma @ X))
        n = v + 1 / (2 * self.C)
        gamma = self.phi * sqrt(self.phi**2 * m**2 * v**2 + 4 * n * v * (n + v * self.phi**2))
        alpha = max(0, (-(2 * m * n + self.phi**2 * m * v) + gamma) / (2 * (n**2 + n * v * self.phi**2)))
        u = (-alpha * v * self.phi + sqrt(alpha**2 * v**2 * self.phi**2 + 4 * v))**2 / 4
        beta = alpha * self.phi / (sqrt(u) + v * alpha * self.phi)
        
        self.mu = self.mu + alpha * y * self.sigma @ X
        self.sigma = self.sigma - beta * self.sigma @ X @ X.T @ self.sigma

In [5]:
usps = fetch_mldata('usps')
X = scale(usps.data)
y = usps.target
y_bin = np.array([1 if i >= 5 else -1 for i in y])

X_train, X_test, y_train, y_test = train_test_split(X, y_bin, test_size=0.3, random_state=0)

scoreSCW1 = GridSearchCV(
    estimator = SCW1(n_iter=3, random_state=0, shuffle=True),
    param_grid = {'C': [0.25, 0.5, 1, 2, 4, 8, 16], 'eta': [0.5, 0.6, 0.7, 0.8, 0.9, 1.0]},
    cv = 3,
    verbose = 2
).fit(X_train, y_train).score(X_test, y_test)
print('SCW1: {0:.4f}'.format(scoreSCW1))

scoreSCW2 = GridSearchCV(
    estimator = SCW2(n_iter=3, random_state=0, shuffle=True),
    param_grid = {'C': [0.25, 0.5, 1, 2, 4, 8, 16], 'eta': [0.5, 0.6, 0.7, 0.8, 0.9, 1.0]},
    cv = 3,
    verbose = 2
).fit(X_train, y_train).score(X_test, y_test)
print('SCW2: {0:.4f}'.format(scoreSCW2))

print('result :')
print('SCW1: {0:.4f}'.format(scoreSCW1))
print('SCW2: {0:.4f}'.format(scoreSCW2))

Fitting 3 folds for each of 42 candidates, totalling 126 fits
[CV] eta=0.5, C=0.25 .................................................
[CV] ........................................ eta=0.5, C=0.25 -   0.0s
[CV] eta=0.5, C=0.25 .................................................


[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    9.0s remaining:    0.0s


[CV] ........................................ eta=0.5, C=0.25 -   0.0s
[CV] eta=0.5, C=0.25 .................................................
[CV] ........................................ eta=0.5, C=0.25 -   0.0s
[CV] eta=0.6, C=0.25 .................................................
[CV] ........................................ eta=0.6, C=0.25 -   0.0s
[CV] eta=0.6, C=0.25 .................................................
[CV] ........................................ eta=0.6, C=0.25 -   0.0s
[CV] eta=0.6, C=0.25 .................................................
[CV] ........................................ eta=0.6, C=0.25 -   0.0s
[CV] eta=0.7, C=0.25 .................................................
[CV] ........................................ eta=0.7, C=0.25 -   0.0s
[CV] eta=0.7, C=0.25 .................................................
[CV] ........................................ eta=0.7, C=0.25 -   0.0s
[CV] eta=0.7, C=0.25 .................................................
[CV] .

[Parallel(n_jobs=1)]: Done 126 out of 126 | elapsed: 14.1min finished


SCW1: 0.9125
Fitting 3 folds for each of 42 candidates, totalling 126 fits
[CV] eta=0.5, C=0.25 .................................................
[CV] ........................................ eta=0.5, C=0.25 -   0.0s
[CV] eta=0.5, C=0.25 .................................................


[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    6.0s remaining:    0.0s


[CV] ........................................ eta=0.5, C=0.25 -   0.0s
[CV] eta=0.5, C=0.25 .................................................
[CV] ........................................ eta=0.5, C=0.25 -   0.0s
[CV] eta=0.6, C=0.25 .................................................
[CV] ........................................ eta=0.6, C=0.25 -   0.0s
[CV] eta=0.6, C=0.25 .................................................
[CV] ........................................ eta=0.6, C=0.25 -   0.0s
[CV] eta=0.6, C=0.25 .................................................
[CV] ........................................ eta=0.6, C=0.25 -   0.0s
[CV] eta=0.7, C=0.25 .................................................
[CV] ........................................ eta=0.7, C=0.25 -   0.0s
[CV] eta=0.7, C=0.25 .................................................
[CV] ........................................ eta=0.7, C=0.25 -   0.0s
[CV] eta=0.7, C=0.25 .................................................
[CV] .

[Parallel(n_jobs=1)]: Done 126 out of 126 | elapsed: 22.4min finished


SCW2: 0.9011
result :
SCW1: 0.9125
SCW2: 0.9011
