# Multi-Class SVMs

In [None]:
from time import time
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np

from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split

In [None]:
from pystruct.models import MultiClassClf
from pystruct.learners import NSlackSSVM

# Multi-class classification of the 8x8 handwritten-digit images
# (digits 0-9, i.e. 10 classes) -- this is not a binary task.
digits = load_digits()
X, y = digits.data, digits.target
# The digit pixel intensities are integers in [0, 16]; rescale them to [0, 1].
X = X / 16.
# A fixed random_state makes the split, and therefore the reported scores,
# reproducible between runs of the notebook.
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# we add a constant 1 feature for the bias, so the linear model needs no
# separate intercept term
X_train_bias = np.hstack([X_train, np.ones((X_train.shape[0], 1))])
X_test_bias = np.hstack([X_test, np.ones((X_test.shape[0], 1))])

print(X_train_bias.shape)

In [None]:
# Multi-class model over every input column, including the appended bias column.
n_input_features = X_train_bias.shape[1]
model = MultiClassClf(n_classes=10, n_features=n_input_features)

In [None]:
# Joint feature vector for the first training sample paired with label 1.
# Its length must match the learned weight vector w, because the two are
# combined with np.dot when scoring.
joint_vec = model.joint_feature(X_train_bias[0], 1)
joint_vec.shape

In [None]:
# n-slack cutting-plane structural SVM learner for the model above
# (see the pystruct NSlackSSVM documentation for parameter details).
n_slack_svm = NSlackSSVM(
    model,
    C=0.1,
    tol=1e-2,
    batch_size=100,
    check_constraints=False,
    verbose=2,
)

In [None]:
# n-slack cutting plane ssvm: fit, then report test accuracy and fit time.
# perf_counter is a monotonic, high-resolution clock intended for measuring
# elapsed intervals; time() is wall-clock time and can jump if the system
# clock is adjusted during the fit.
from time import perf_counter

start = perf_counter()
n_slack_svm.fit(X_train_bias, y_train)
time_n_slack_svm = perf_counter() - start
# np.hstack flattens the returned predictions into a 1-D array so they can be
# compared element-wise with y_test.
y_pred = np.hstack(n_slack_svm.predict(X_test_bias))
accuracy = np.mean(y_pred == y_test)
print(f"Score with pystruct n-slack ssvm: {accuracy:f} "
      f"(took {time_n_slack_svm:f} seconds)")

In [None]:
# Shape of the learned weight vector.
learned_w = n_slack_svm.w
learned_w.shape

In [None]:
# Show the first test image: its flat pixel row reshaped back to 8x8.
first_image = X_test[0].reshape(8, 8)
plt.imshow(first_image, cmap='gray_r')

In [None]:
# Predict the first test sample; it is wrapped in a list because predict is
# called on a batch of samples.
single_sample = [X_test_bias[0]]
n_slack_svm.predict(single_sample)

In [None]:
# Score of every class label for the first test sample: the inner product of
# the learned weights with the joint feature vector of (sample, label).
x_first = X_test_bias[0]
for i in range(10):
    score = np.dot(n_slack_svm.w, n_slack_svm.model.joint_feature(x_first, i))
    # ".2f" prints two decimals; the former "{:2f}" only set a minimum field
    # width of 2 and kept the default six-digit precision.
    print("{}: {:.2f}".format(i, score))

In [None]:
# Visualize the learned weights of each class as an 8x8 image.
# w is reshaped to one row per class; the last entry of each row is the
# weight of the constant bias feature and is dropped before reshaping.
w_perclass = n_slack_svm.w.reshape(10, -1)
pixel_weights = w_perclass[:, :-1]
# RdBu is a diverging colormap: symmetric limits make zero weight map to the
# neutral middle colour, and sharing them makes all panels comparable.
vmax = np.abs(pixel_weights).max()
fig, axes = plt.subplots(2, 5, figsize=(10, 5),
                         subplot_kw={'xticks': (), 'yticks': ()})
for i, ax in zip(range(10), axes.ravel()):
    ax.imshow(pixel_weights[i].reshape(8, 8), cmap=plt.cm.RdBu,
              vmin=-vmax, vmax=vmax)
    ax.set_title("class {}".format(i))

In [None]:
# Plot the primal and dual objective values recorded by the learner during fit.
curves = (
    (n_slack_svm.primal_objective_curve_, "primal objective"),
    (n_slack_svm.objective_curve_, "dual objective"),
)
for values, curve_label in curves:
    plt.plot(values, label=curve_label)
plt.legend()

# Exercises
1) Replace the n-slack SSVM with a subgradient SSVM (`SubgradientSSVM` from `pystruct.learners`).

2) Replace the MultiClassClf with a GraphCRF, treating each sample as a graph with a single node and no edges.