In [2]:
from functools import partial
from typing import Callable
from typing import Mapping

from absl import app
from absl import flags
import jax
import jax.numpy as jnp
from jax import vmap, jit
from jaxopt import projection
from jaxopt import ProjectedGradient
from jaxopt import BoxOSQP
import numpy as onp
from sklearn import datasets
from sklearn import preprocessing
from sklearn import svm

In [43]:
# Regularization strength. The demo cells derive the SVM box bound from it as
# C = 1.0 / lam; note that a later cell reassigns `lam` before doing so.
lam = 0.5
# Convergence tolerance handed to both ProjectedGradient (in solve_svm) and
# sklearn's SVC (in binary_kernel_svm_skl).
tol = 1e-6
# NOTE(review): not referenced by any visible cell — the data cell hard-codes
# n_samples=100. Confirm whether this is still needed.
num_samples = 20
# Feature count passed to datasets.make_classification.
num_features = 5
# Forwarded to ProjectedGradient(verbose=...) inside solve_svm.
verbose = False

In [44]:
def rbf_kernel(x: jnp.ndarray, x_prime: jnp.ndarray):
    """Gaussian (RBF) kernel with unit bandwidth: exp(-0.5 * ||x - x_prime||^2).

    Args:
        x: first input point.
        x_prime: second input point, same shape as `x`.

    Returns:
        Scalar kernel value; equals 1 when the two points coincide.
    """
    diff = x - x_prime
    # Sum the squares directly instead of squaring `jnp.linalg.norm`: the norm
    # goes through a sqrt whose derivative is infinite at 0, which turns the
    # gradient into NaN whenever x == x_prime (e.g. on a Gram-matrix diagonal).
    sq_dist = jnp.sum(diff * diff)
    return jnp.exp(-0.5 * sq_dist)


def poly_kernel(x: jnp.array, x_prime: jnp.array, a: float):
    """Homogeneous polynomial kernel: the dot product of the inputs raised to `a`.

    Args:
        x: first input point.
        x_prime: second input point.
        a: exponent applied to the dot product.

    Returns:
        `jnp.dot(x, x_prime) ** a`.
    """
    inner_product = jnp.dot(x, x_prime)
    kernel_value = inner_product ** a
    return kernel_value


def get_poly_kernel(a):
    """Build a two-argument polynomial kernel with the exponent fixed to `a`.

    Returns a `functools.partial` of `poly_kernel` with `a` bound as a keyword.
    """
    bound_kernel = partial(poly_kernel, a=a)
    return bound_kernel


def solve_svm(K: jnp.array, y: jnp.array, C: float):
    """Solve the box-constrained dual of a kernel SVM by projected gradient.

    Minimizes 0.5 * beta^T K beta - beta^T y, keeping each beta_i in [0, C]
    where y_i == 1 and in [-C, 0] otherwise. No intercept constraint is
    applied. Solver tolerance and verbosity come from the notebook-level
    `tol` and `verbose` settings.

    Args:
        K: kernel (Gram) matrix of the training points.
        y: label vector; entries equal to 1 are treated as the positive class.
        C: box-constraint bound.

    Returns:
        The fitted signed coefficient vector `beta`, one entry per element of `y`.
    """
    n_train = y.shape[0]

    def objective_fun(beta, K, y):
        quadratic_term = jnp.dot(jnp.dot(beta, K), beta)
        linear_term = jnp.dot(beta, y)
        return 0.5 * quadratic_term - linear_term

    # TODO add it if intercept needed
    # w = jnp.zeros(y.shape[0])

    def proj(beta, C):
        # The sign of each coefficient is tied to the label of its example.
        is_positive = y == 1
        bounds = (jnp.where(is_positive, 0, -C), jnp.where(is_positive, C, 0))
        return projection.projection_box(beta, bounds)

    # Run solver.
    pg_solver = ProjectedGradient(
        fun=objective_fun,
        projection=proj,
        tol=tol,
        maxiter=500,
        verbose=verbose,
    )
    result = pg_solver.run(jnp.ones(n_train), hyperparams_proj=C, K=K, y=y)
    return result.params

# Batched solver: shares K and C across the batch and maps over the leading
# axis of the label array, returning one coefficient vector per label row.
vect_solve_svm = vmap(solve_svm, in_axes=(None, 0, None), out_axes=0)


class BinarySVM:
    """Kernel SVM for binary labels, trained in the dual via `solve_svm`.

    Args:
        kernel_func: callable `k(x, x_prime)` returning a scalar for two points.
        c: box-constraint bound passed to the dual solver.
        threshold: when > 0, only training points whose coefficient magnitude
            exceeds it are kept for prediction; otherwise all points are kept.
    """

    def __init__(self, kernel_func: Callable, c: float, threshold=0):
        self.c = c
        self.threshold = threshold
        # kernel_vec(x, B)[j]    = k(x, B[j])
        # kernel_mat(A, B)[i, j] = k(A[i], B[j])
        self.kernel_vec = vmap(kernel_func, (None, 0), 0)
        self.kernel_mat = vmap(self.kernel_vec, in_axes=(0, None), out_axes=0)

    def fit(self, X, y):
        """Fit the dual coefficients on training points `X` with labels `y`."""
        alpha = solve_svm(self.kernel_mat(X, X), y, self.c)

        if self.threshold > 0:
            # Keep only the points with a non-negligible coefficient.
            keep = jnp.abs(alpha) > self.threshold
            reference_points = X[keep]
            alpha = alpha[keep]
        else:
            reference_points = X

        self._alpha_pred = alpha
        self._kernel_pred = partial(self.kernel_mat, reference_points)

    def predict(self, X):
        """Return raw decision scores (not labels) for the query points.

        Args:
            X: array of query points; a single 1-D point is treated as one row.

        Returns:
            1-D array with one score per query point; take its sign to get a
            class label.

        Raises:
            AttributeError: if called before `fit`.
        """
        if not hasattr(self, "_kernel_pred"):
            raise AttributeError("BinarySVM.predict() called before fit().")
        if X.ndim < 2:
            X = X[None, :]
        # (n_ref,) @ (n_ref, n_query) -> (n_query,)
        return self._alpha_pred @ self._kernel_pred(X)
    
class MultiClassSVM:
    """One-vs-rest kernel SVM: one binary dual problem per class, solved in a batch.

    Args:
        num_classes: number of classes; labels are expected in `0..num_classes-1`.
        kernel_func: callable `k(x, x_prime)` returning a scalar for two points.
        c: box-constraint bound passed to the dual solver.
        comp_num: reserved; any value other than `None` is not implemented yet.
    """

    def __init__(self, num_classes: int, kernel_func: Callable, c: float, comp_num=None):
        self.num_classes = num_classes
        self.c = c
        self.comp_num = comp_num
        # kernel_mat(A, B)[i, j] = k(A[i], B[j])
        self.kernel_vec = vmap(kernel_func, (None, 0), 0)
        self.kernel_mat = vmap(self.kernel_vec, in_axes=(0, None), out_axes=0)

    def fit(self, X, y):
        """Fit one coefficient vector per class.

        Args:
            X: training points, one per row.
            y: integer class labels.

        Raises:
            NotImplementedError: if `comp_num` was set (not supported yet).
        """
        # Fail before the expensive solve rather than after it.
        if self.comp_num is not None:
            raise NotImplementedError(
                "MultiClassSVM with comp_num set is not implemented."
            )

        # Rows are classes (axis=0); entries are +1 for "is this class", -1 otherwise.
        y_onehot = jax.nn.one_hot(y, num_classes=self.num_classes, axis=0)
        y_signed = 2 * y_onehot - 1

        K = self.kernel_mat(X, X)
        self.ref_point = X
        # Shape (num_classes, n_samples): one dual solution per class.
        self._alpha_pred = vect_solve_svm(K, y_signed, self.c)

    def decision_function(self, X):
        """Return per-class raw scores with shape (num_classes, n_query).

        Raises:
            AttributeError: if called before `fit`.
        """
        if not hasattr(self, "_alpha_pred"):
            raise AttributeError("MultiClassSVM used before fit().")
        if X.ndim < 2:
            X = X[None, :]
        # (num_classes, n_ref) @ (n_ref, n_query)
        return self._alpha_pred @ self.kernel_mat(self.ref_point, X)

    def predict(self, X):
        """Return, for each query point, the index of the highest-scoring class."""
        return jnp.argmax(self.decision_function(X), axis=0)
        

In [45]:
def binary_kernel_svm_skl(K, y, C):
    """Fit scikit-learn's SVC on a precomputed kernel matrix as a reference solver.

    Args:
        K: square precomputed kernel (Gram) matrix of the training points.
        y: training labels.
        C: SVC regularization parameter.

    Returns:
        A pair `(dual_coef, svc)`: a dense vector with one entry per training
        point (zero for non-support vectors, the first row of `svc.dual_coef_`
        at the support indices) and the fitted classifier.
    """
    print("Solve SVM with sklearn.svm.SVC: ")
    classifier = svm.SVC(kernel="precomputed", C=C, tol=tol)
    classifier.fit(K, y)

    # Scatter the support-vector coefficients into a full-length vector.
    dense_coef = onp.zeros(K.shape[0])
    dense_coef[classifier.support_] = classifier.dual_coef_[0]
    return dense_coef, classifier


def print_svm_result(beta, threshold=1e-4):
    """Print the coefficient vector and the indices of its support vectors.

    The entries of `beta` are signed (the sign follows the label of the
    corresponding example), so support vectors are detected on the absolute
    value: an entry counts when its magnitude exceeds `threshold`.
    """
    support_mask = jnp.abs(beta) > threshold
    support_indices = onp.asarray(support_mask).nonzero()[0]
    print(f"Beta: {beta}")
    print(f"Support vector indices: {support_indices}")
    print("")

In [41]:
# Prepare a 3-class dataset and L2-normalize each sample.
X, y = datasets.make_classification(
    n_samples=100,
    n_features=num_features,
    n_clusters_per_class=1,
    n_classes=3,
    random_state=0,
)
X = preprocessing.Normalizer().fit_transform(X)

# Use rbf_kernel directly: the `k_fun` alias is only created in a later cell,
# so relying on it here breaks Restart & Run All.
multi_svm = MultiClassSVM(3, rbf_kernel, 1)
multi_svm.fit(X, y)
# One coefficient vector per class: (num_classes, n_samples).
multi_svm._alpha_pred.shape

(3, 100)


In [36]:



# Build a dedicated two-class dataset for the binary SVM. Regenerating it here
# keeps the cell idempotent and stops it from inheriting the 3-class labels of
# the multi-class cell, which `y * 2 - 1` would map to {-1, 1, 3}.
X, y = datasets.make_classification(
    n_samples=100,
    n_features=num_features,
    n_clusters_per_class=1,
    n_classes=2,
    random_state=0,
)
X = preprocessing.Normalizer().fit_transform(X)
y = jnp.array(y * 2.0 - 1)  # Transform labels from {0, 1} to {-1., 1.}.

lam = 20
k_fun = rbf_kernel
C = 1.0 / lam
bsvm = BinarySVM(k_fun, c=C)
bsvm.fit(X, y)
# A prediction is correct when its raw score has the same sign as the label.
print("acc, ", jnp.mean(jnp.sign(bsvm.predict(X)*y)>0))


acc,  0.84999996


In [37]:
import numpy as np

In [38]:
# Reference fit with scikit-learn on the same Gram matrix.
C = 1.0 / lam
# Use the `onp` alias from the import cell rather than depending on a
# separate mid-notebook `import numpy as np`.
K = onp.array(bsvm.kernel_mat(X,X))
print(K.shape)
beta_fit_pg, svc= binary_kernel_svm_skl(K, y, C)

print("acc, ", jnp.mean(jnp.sign(svc.predict(K)*y)>0))

(100, 100)
Solve SVM with sklearn.svm.SVC: 
acc,  0.91999996
