# Example: using the Kernel MSM model through its scikit-learn-style API

In [1]:
import sys, os
sys.path.insert(0, os.pardir)

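### Sample face dataset
 - Place `Face_data.mat` in the `sample_dataset` directory one level above this notebook; the code below reads it from `../sample_dataset/Face_data.mat`.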

In [2]:
import h5py
import numpy as np

def load_mat_file(file_name, *var_name):
    """Load variables from a MATLAB v7.3 (HDF5-based) .mat file as ndarrays.

    Example:
        arrays = load_mat_file("hoge.mat")                    # every variable
        arrays = load_mat_file("hoge.mat", "var1", "result")  # selected ones

    The file is opened read-only and closed before this function returns.
    Every dataset is read fully into memory, so the returned arrays remain
    valid after the file is closed. MATLAB's internal bookkeeping entries
    ``#refs#`` and ``#subsystem#`` are never included in the result.

    NOTE: no transpose is applied. Because MATLAB is column-major, the axes
    typically appear in reverse order compared with MATLAB.

    Args:
        file_name (str): Path of the .mat file to read.
        *var_name (str): Names of the variables to load. When none are given,
            all top-level variables are loaded.

    Returns:
        dict: Maps each variable name to a ``numpy.ndarray``. A variable
        stored as an HDF5 group (e.g. a MATLAB struct) becomes a nested dict
        of the same form. NOTE(review): cell-array entries are HDF5 object
        references and are returned as such; they cannot be dereferenced
        once the file is closed.

    Raises:
        KeyError: If a requested variable does not exist in the file.
    """
    # MATLAB bookkeeping entries that are not user variables.
    internal_names = {"#refs#", "#subsystem#"}

    def _read(node):
        # Read datasets eagerly so the data outlives the file handle;
        # convert groups recursively into dicts.
        if isinstance(node, h5py.Group):
            return {key: _read(child) for key, child in node.items()}
        return node[()]

    # The `with` block guarantees the handle is closed. The original code
    # left it open for the lifetime of the kernel.
    with h5py.File(file_name, "r") as f:
        names = var_name if var_name else f.keys()
        return {
            name: _read(f[name])
            for name in names
            if name not in internal_names
        }

face = load_mat_file("../sample_dataset/Face_data.mat")

# Training data: each set gets its own label (its index), so set k is class k.
train_X = np.array(face["X1"])
train_y = np.arange(len(train_X))

# Test data: flatten all leading axes so that every element of test_X is one
# set with the last two axes of X2.
test_X = np.array(face["X2"])
test_X = test_X.reshape(-1, test_X.shape[-2], test_X.shape[-1])

# Test sets are assumed to be ordered class by class, with the same number of
# sets per class (36 per class in the original hard-coded labels). Derive the
# per-class count from the data and fail loudly if it does not divide evenly,
# instead of silently producing labels that do not line up with test_X.
n_test_classes = 10
assert len(test_X) % n_test_classes == 0, (
    f"{len(test_X)} test sets cannot be split evenly into {n_test_classes} classes"
)
test_y = np.repeat(np.arange(n_test_classes), len(test_X) // n_test_classes)

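### Training

The cell below first perturbs both the training and the test image sets with random noise drawn from a seeded generator. The cell after it fits the model on the noisy training sets.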

In [3]:
# Perturb each image set with Gaussian noise from a seeded generator so the
# run is reproducible. Training sets are perturbed first, then test sets;
# this fixes the order in which random numbers are drawn.
eps = 5e-1
rs = np.random.RandomState(seed=100)


def _perturb(image_set):
    # NOTE(review): the noise has the shape of image_set[0] and is broadcast
    # along the first axis, so every slice along that axis receives the SAME
    # noise. If independent per-element noise was intended, it would be
    # rs.randn(*image_set.shape) — TODO confirm.
    return image_set + eps * rs.randn(*image_set[0].shape)


# NOTE(review): this cell overwrites its own inputs, so running it twice
# stacks the noise. Re-run the data-loading cell before re-running this one.
train_X = list(map(_perturb, train_X))
test_X = list(map(_perturb, test_X))

In [8]:
from cvt.models import KernelMSM
from sklearn.model_selection import RandomizedSearchCV
from scipy.stats import randint as sp_randint

# NOTE(review): RandomizedSearchCV and sp_randint are not used in the cells
# shown here; keep them only if a later cell performs a hyper-parameter search.

# Hyper-parameters for KernelMSM. n_subdims is presumably the subspace
# dimension, sigma the kernel width, and faster_mode a switch for a faster
# code path — TODO confirm against cvt.models.KernelMSM.
kmsm_params = dict(n_subdims=5, sigma=100, faster_mode=True)

model = KernelMSM(**kmsm_params)
model.fit(train_X, train_y)

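### Evaluation

**Note:** the saved output below has execution count 7, which is lower than the training cell's count of 8. It was therefore produced before the training cell last ran, and may not reflect the model defined above. Restart the kernel and run all cells to regenerate it.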

In [7]:
# Predict a label for every test image set and compare it with the ground
# truth; accuracy is the fraction of sets whose predicted label is correct.
pred = model.predict(test_X)
accuracy = (pred == test_y).mean()
report = "\n".join([f"pred: {pred}", f"true: {test_y}", f"accuracy: {accuracy}"])
print(report)

pred: [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1
 7 1 4 0 4 1 4 7 1 1 0 1 4 1 1 4 1 1 1 0 0 1 4 1 4 1 1 1 1 1 1 1 1 1 1 2 2
 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 3 2 6 2 3 3 3
 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 4 4 4 4
 4 4 4 4 4 4 0 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 5 8 5 5 5
 5 5 5 8 5 5 5 8 5 8 5 5 5 5 5 5 8 5 8 5 5 5 9 5 5 5 5 5 8 5 5 6 6 6 6 6 6
 6 6 6 6 6 6 6 9 6 6 6 6 6 6 6 6 6 6 6 6 9 6 6 6 6 6 6 6 6 6 7 7 7 7 7 7 7
 7 7 7 7 7 4 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 0 7 7 7 7 7 8 8 8 8 8 8 8 8
 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 6 6 6 9 9 9 9 9 9
 9 9 6 9 9 9 9 6 9 9 9 9 6 6 6 9 5 6 6 9 9 9 9 9 9 9 9]
true: [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2
 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 3 3 3
 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3