In [64]:
import numpy as np
import matplotlib.pyplot as plt
import mnist
from mnist.loader import MNIST

#Data: read MNIST
DATA_PATH = './mnist'
mn = MNIST(DATA_PATH)
mn.gz = True # Enable loading of gzip-ed files
train_img, train_label = mn.load_training()
test_img, test_label = mn.load_testing()
train_X = np.array(train_img)[0:10000]
train_Y = np.array(train_label).reshape(-1, 1)
test_X = np.array(test_img)
test_Y = np.array(test_label).reshape(-1, 1)

In [61]:
print(train_X.shape)
print(len(train_Y))

(10000, 784)
10000


# SVM loss
对于样本i，其损失函数：
$$ \begin{aligned} L_{i} &=\sum_{j \neq y_{i}}\left\{\begin{array}{ll}0 & \text { if } s_{y_{i}} \geq s_{j}+1 \\ s_{j}-s_{y_{i}}+1 & \text { otherwise }\end{array}\right.\\ &=\sum_{j \neq y_{i}} \max \left(0, s_{j}-s_{y_{i}}+1\right) \end{aligned}
 $$

In [65]:
def svm_loss_process(X,Y,W,lamda):
    dW = np.zeros((784, 10))
    loss = 0.0
    for i in range(X.shape[0]):
        scores = X[i].dot(W)
        correct_class_score = scores[Y[i]]
        for j in range(10):
            if j == Y[i]:
                continue
            margin = scores[j] - correct_class_score + 1
            if margin > 0:
                loss += margin
                dW[:,Y[i][0]] += -X[i].T
                dW[:,j] += X[i].T
    loss = (loss + lamda * np.sum(W*W)) / X.shape[0]
    dW = dW / X.shape[0] + 2 * lamda * W
    return dW,loss

In [66]:
#Hyperparameters
lr = 0.00001
epochs = 100

#Model
W = np.random.rand(784, 10)
lamda = 0.95

for step in range(epochs):
    dW, loss = svm_loss_process(train_X,train_Y,W,lamda)
    W = W - lr * dW
    test = test_X.dot(W)
    correct_num = np.sum(np.argmax(test,axis=1) == test_Y.flatten())
    acc = correct_num / train_X.shape[0]
    print("[%d/%d]LOSS:%.3f, ACC:%.3f" %(step+1, epochs, loss, acc))

[1/100]LOSS:2830.986, ACC:0.177
[2/100]LOSS:2776.955, ACC:0.177
[3/100]LOSS:2724.139, ACC:0.177
[4/100]LOSS:2672.603, ACC:0.177
[5/100]LOSS:2622.408, ACC:0.178
[6/100]LOSS:2573.428, ACC:0.178
[7/100]LOSS:2525.665, ACC:0.177
[8/100]LOSS:2479.219, ACC:0.177
[9/100]LOSS:2433.979, ACC:0.177
[10/100]LOSS:2390.075, ACC:0.177
[11/100]LOSS:2347.405, ACC:0.177
[12/100]LOSS:2305.902, ACC:0.178
[13/100]LOSS:2265.710, ACC:0.178
[14/100]LOSS:2226.740, ACC:0.178
[15/100]LOSS:2189.048, ACC:0.178
[16/100]LOSS:2152.522, ACC:0.179
[17/100]LOSS:2117.079, ACC:0.179
[18/100]LOSS:2082.841, ACC:0.178
[19/100]LOSS:2049.673, ACC:0.178
[20/100]LOSS:2017.623, ACC:0.178
[21/100]LOSS:1986.723, ACC:0.177
[22/100]LOSS:1956.950, ACC:0.178
[23/100]LOSS:1928.233, ACC:0.178
[24/100]LOSS:1900.485, ACC:0.178
[25/100]LOSS:1873.688, ACC:0.179
[26/100]LOSS:1847.837, ACC:0.180
[27/100]LOSS:1822.908, ACC:0.180
[28/100]LOSS:1798.857, ACC:0.181
[29/100]LOSS:1775.671, ACC:0.183
[30/100]LOSS:1753.347, ACC:0.183
[31/100]LOSS:1731.8