In [1]:
from absl import app
from absl import flags

import jax
import jax.numpy as jnp
from jax import vmap, jit
from jaxopt import projection
from jaxopt import ProjectedGradient
from jaxopt import BoxOSQP
from typing import Mapping
import numpy as onp
from sklearn import datasets
from sklearn import preprocessing
from sklearn import svm

from functools import partial

In [None]:
def rbf_kernel(x: jnp.array, x_prime: jnp.array):
    return jnp.exp(-0.5 * jnp.linalg.norm(x - x_prime) ** 2)


def poly_kernel(x: jnp.array, x_prime: jnp.array, a: float):
    return jnp.dot(x, x_prime) ** a


def get_poly_kernel(a):
    return partial(poly_kernel, a=a)


def solve_svm(K: jnp.array, y: jnp.array, C: float):

    def objective_fun(beta, K, y):
        # print(K.shape, beta.shape, y.shape)
        return 0.5 * jnp.dot(jnp.dot(beta, K), beta) - jnp.dot(beta, y)

    # TODO add it if intercept needed
    # w = jnp.zeros(y.shape[0])

    def proj(beta, C):
        box_lower = jnp.where(y == 1, 0, -C)
        box_upper = jnp.where(y == 1, C, 0)
        proj_params = (box_lower, box_upper)
        return projection.projection_box(beta, proj_params)

    # Run solver.
    beta_init = jnp.ones(y.shape[0])
    solver = ProjectedGradient(
        fun=objective_fun, projection=proj, tol=tol, maxiter=500, verbose=verbose
    )
    beta_fit = solver.run(beta_init, hyperparams_proj=C, K=K, y=y).params

    return beta_fit


vect_solve_svm = jax.vmap(solve_svm, in_axes=(None, 0, None), out_axes=0)


class BinarySVM:
    def __init__(self, kernel_func: Mapping, c: float, threshold=0):
        self.c = c
        self.threshold = threshold
        self.kernel_vec = vmap(kernel_func, (None, 0), 0)
        self.kernel_mat = vmap(self.kernel_vec, in_axes=(0, None), out_axes=0)

    def fit(self, X, y):
        alpha = solve_svm(self.kernel_mat(X, X), y, self.c)

        if self.threshold > 0:
            threshold_mask = jnp.abs(alpha) > self.threshold
            x_pred = X[threshold_mask]
            self._alpha_pred = alpha[threshold_mask]
            self._kernel_pred = partial(self.kernel_mat, x_pred)
        else:
            self._alpha_pred = alpha
            self._kernel_pred = partial(self.kernel_mat, X)

    def predict(self, X):
        if X.ndim < 2:
            X = X[None, :]
        preds = self._alpha_pred @ self._kernel_pred(X)
        X = jnp.squeeze(X)
        return preds


class MultiClassSVM:
    def __init__(self, num_classes: int, kernel_func: Mapping, c: float, comp_num=None):
        self.num_classes = num_classes
        self.c = c
        self.comp_num = comp_num
        self.kernel_vec = vmap(kernel_func, (None, 0), 0)
        self.kernel_mat = vmap(self.kernel_vec, in_axes=(0, None), out_axes=0)

    def fit(self, X, y):
        y_onehot = jax.nn.one_hot(y, num_classes=self.num_classes, axis=0)
        y_onehot = 2 * y_onehot - 1
        K = self.kernel_mat(X, X)
        alphas = vect_solve_svm(K, y_onehot, self.c)
        if self.comp_num == None:
            ref_points = X
            self._alpha_pred = alphas
            self.kern_fun = partial(partial(self.kernel_mat, ref_points))
        else:
            # print(alphas.shape)
            sort_index = jnp.argsort(-jnp.abs(alphas), axis=1)[:,:self.comp_num]
            # print(sort_index.shape)
            self._alpha_pred = []
            self._chosen_points = []
            self.kern_fun = []
            for i in range(self.num_classes):
                self._alpha_pred.append(alphas[i, sort_index[i]])
                self._chosen_points.append(X[sort_index[i], :])
                self.kern_fun.append(partial(self.kernel_mat, self._chosen_points[-1]))
            self._alpha_pred = jnp.array(self._alpha_pred)
            self._chosen_points = jnp.array(self._chosen_points)
            # self.kernel_comp_vec = 
    def predict(self, X:jnp.array):
        if X.ndim<2:
            X = X[None, :]
        if self.comp_num==None:
            kern_dist = self.kern_fun(X)
            prob = self._alpha_pred@kern_dist
            return jnp.argsort(-prob, axis=0)[0,:]
        else:
            raise NotImplementedError("custom compoenents number is not yet implemented.")

In [None]:
## Testing the kernel 


lam = 0.5
tol = 1e-6
num_samples = 200
num_features = 5
verbose = False
c = 1/(lam*num_samples)
n_classes = 3


X, y = datasets.make_classification(
    n_samples=num_samples, n_features=num_features,n_clusters_per_class=1,n_classes=n_classes, random_state=0
)


k_fun = rbf_kernel
X = preprocessing.Normalizer().fit_transform(X)
multi_svm = MultiClassSVM(3, k_fun, c=0.1, comp_num=None)
multi_svm.fit(X[:180], y[:180])
print("My results", jnp.mean(multi_svm.predict(X[180:]) == y[180:]))



sk_svm = svm.SVC(kernel="rbf", C=0.1 ).fit(X[:180], y[:180])
print("Sklearn", jnp.mean(sk_svm.predict(X[180:]) == y[180:]))

# multi_svm.

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


My results 0.85
Sklearn 0.90000004
