In [1]:
import os
import struct
import random
import numpy as np
import matplotlib.pyplot as plt

def load_mnist(path, kind='train'):
    """Load MNIST data from `path`"""
    labels_path = os.path.join(path,
                               '%s-labels.idx1-ubyte'
                                % kind)
    images_path = os.path.join(path,
                               '%s-images.idx3-ubyte'
                               % kind)

    with open(labels_path, 'rb') as lbpath:
        magic, n = struct.unpack('>II',
                                 lbpath.read(8))
        labels = np.fromfile(lbpath,
                             dtype=np.uint8)

    with open(images_path, 'rb') as imgpath:
        magic, num, rows, cols = struct.unpack(">IIII",
                                               imgpath.read(16))
        images = np.fromfile(imgpath,
                             dtype=np.uint8).reshape(len(labels), 784)

    return images, labels

#### Loading the data

X_train, y_train = load_mnist('./data', kind='train')
X_test, y_test = load_mnist('./data', kind='t10k')

In [12]:
class MySVM(object):
    """
    1 vs 1 SVM (binary classification)
    """
    def __init__(self, C=1000, eta=0.01, batch_size=1, max_iter=25, epsilon=1e-8):
        """
        constructor of MyBinarySVM class.
        """
        
        self.max_iter = max_iter
        self.batch_size = batch_size
        self.eta = eta
        self.C = C
        self.epsilon = epsilon
        self.num_classes = 0
    
    def fit(self, X, y=None, params=None):
        """
        fit method for training svm
        
        Arguments:
        --------------------------
        X: image data. (60000, 784)
        y: label data. (60000, 1)
        
        Returns:
        --------------------------
        Z: class score
        """

        m = np.shape(X)[0] #행 개수 60000
        n = np.shape(X)[1] #열 개수 784
        self.num_classes = len(np.unique(y)) #클래스 수 = 10
        
        y_encoded = self.encode_y(y)
        
        # create weights.
        if params is None:
            self.params = {
                'W': np.random.randn(n, self.num_classes), #(784,10) 정규분포난수
                'b': np.random.randn(1, self.num_classes)
            }

        cnt = 1
        
        # main loop: how much iterate on entire dataset.
        for epoch in range(self.max_iter):
            # before dive into SGD, shuffle dataset
            X_shuffled, y_shuffled = self.shuffle(X, y_encoded)
            
            # cost variable for printing/logging
            avg_loss = 0
            
            # batch_count = dataset_size / batch_size
            batch_count = int(np.ceil(np.shape(X)[0] / self.batch_size))
            
            # mini-batch loop
            for t in range(batch_count):
                # draw the {batch_size} number of samples from X and y
                X_batch, y_batch, bs = self.next_batch(X_shuffled, y_shuffled, t)
                
                # just in case, reshape batch of X and y into proper shape.
                X_batch = np.reshape(X_batch, (bs, n))
                y_batch = np.reshape(y_batch, (bs, self.num_classes))
                
                # prediction phase
                Z = self.forward_prop(X_batch)
                Z = np.reshape(Z, (bs, self.num_classes))
                
                # compute cost phase
                loss = self.compute_cost(y_batch, Z)
                # update weights phase
                self.backward_prop(X_batch, y_batch, Z, bs, cnt)
                
                # accumulate loss
                avg_loss += loss
                cnt += 1
        
            # logging
            avg_loss /= batch_count
            if epoch % (self.max_iter / 10) == 0:
                print('Cost at epoch {0}: {1}'.format(epoch, avg_loss))

        return self
    
    def encode_y(self, y):
        y_encoded = np.ones((np.shape(y)[0], self.num_classes)) #1로 이루어진 배열(60000,10)
        
        for i in range(self.num_classes):
            y_encoded[:, i][y != i] = -1
            
        return y_encoded
    
    def shuffle(self, X, y):
        """
        Random selection is required for SGD.
        But, my approach is to shuffle entire data before every iteration.
        This has same effect as random selection.
        
        Arguments:
        ---------------------------
        X: images (BATCH_SIZE, 784)
        y: labels (BATCH_SIZE, 1)
        
        Returns:
        ---------------------------
        shuffled data
        """
        
        # the number of dataset samples
        m = np.shape(X)[0]
        
        # variable for shuffle
        r = np.arange(0, m)
        
        np.random.shuffle(r)
        
        return X[r], y[r]
    
    def next_batch(self, X, y, t):
        """
        Get next batch.
        If it is SGD, next_batch function just pick one sample from dataset.
        
        Arguments:
        ---------------------------------
        X: images (60000, 784)
        y: labels (60000, 1)
        
        Returns:
        ---------------------------------
        X_batch: small subset of X (BATCH_SIZE, 784)
        y_batch: small subset of y (BATCH_SIZE, 1)
        """
        
        # the number of dataset samples
        m = np.shape(X)[0] #60000
        
        # draw the {batch_size} number of samples from X and y
        X_batch = X[t * self.batch_size : min(m, (t + 1) * self.batch_size)]
        y_batch = y[t * self.batch_size : min(m, (t + 1) * self.batch_size)]
        bs = min(m, (t + 1) * self.batch_size) - t * self.batch_size
        
        return X_batch, y_batch, bs
    
    def forward_prop(self, X):
        """
        Process of inference (prediction).
        
        Arguments:
        -----------------------
        X: images e.g (BATCH_SIZE, 784)
        params: weights dictionary(map in other programming language)
        
        Returns:
        -----------------------
        A: 
        """
        
        # prediction
        Z = np.matmul(X, self.params['W']) + self.params['b']
        print(np.matmul(X, self.params['W'])
        return Z
    
#     def sigmoid(self, Z):
#         """
#         sigmoid activation for binary classification
        
#         Arguments:
#         ----------------------
#         Z: class score (W.T * X)
        
#         Returns:
#         ----------------------
#         sigmoid activation
#         """
        
#         return 1 / (1 + np.exp(-Z))
    
    def compute_cost(self, y, Z):
        """
        compute cost function (loss function)
        
        Arguments:
        ------------------------------
        y: true label
        Z: class score (W.T * X)
        
        Returns:
        ------------------------------
        loss: total cost (loss)
        """
        
        # compute loss function
        temp = 1 - np.multiply(y, Z)
        temp[temp < 0] = 0
        loss = np.mean(temp)
        return loss
    
    def backward_prop(self, X, y, Z, bs, cnt):
        """
        update weights
        
        Arguments:
        ----------------------------
        X: images e.g (BATCH_SIZE, 784)
        y: labels e.g (BATCH_SIZE, 1)
        Z: class score after forward propagation
        params: weights dictionary(map in other programming language)
        eta: learning rate
        
        Returns:
        ----------------------------
        params: weights dictionary
        """
        
        # number of features
        n = np.shape(X)[1]
        
        # differential vector of loss function to update weights
        dw = np.zeros(self.params['W'].shape)
        db = np.zeros(self.params['b'].shape)
        
        Z = np.reshape(Z, (bs, self.num_classes))
        
        temp = np.multiply(y, Z)
        temp = 1 - temp
        
        temp[temp <= 0] = 0
        temp[temp > 0] = 1
        
        y_temp = np.multiply(y, temp.reshape(bs, self.num_classes))
        
        dw = -(1 / bs) * np.matmul(X.T, y_temp) + (1 / self.C) * self.params['W']
        db = -(1 / bs) * np.sum(y_temp, axis=0)
       
        self.params['W'] = self.params['W'] - (self.eta / (1 + self.epsilon * cnt)) * dw
        self.params['b'] = self.params['b'] - (self.eta / (1 + self.epsilon * cnt)) * db
       
        return self.params
    
    def predict(self, X, y=None):
        m = np.shape(X)[0]
        
        class_score = self.forward_prop(X)
        pred = np.argmax(class_score, axis=1)
        
        return pred
    
    def score(self, X, y=None):
        pred = self.predict(X)
        score = np.mean(pred == y)
        
        return score
    
    def get_parameters(self):
        return self.params

SyntaxError: invalid syntax (<ipython-input-12-6fd3fc078cdd>, line 163)

In [13]:
mine=MySVM()
mine.fit(X_test, y_test)

[[ 1.8532655  -1.52634006 -1.10966435 ... -0.68417371  0.62452735
   0.41758795]
 [-0.39519526  1.27535824  1.43665833 ...  0.53373509  1.041939
   1.53538226]
 [-0.74729575  1.86883594 -0.3664837  ...  0.86412377  0.11118531
  -0.64306802]
 ...
 [-0.02913711  0.95023071 -0.57683022 ...  0.18845177  1.19566812
   0.12395642]
 [ 0.67353486  0.34896877  0.01243046 ...  0.14851567  0.17343001
   0.16055064]
 [ 0.6310023   1.57455514  1.35109785 ...  0.21084373  0.76638996
   0.24331147]]
[[ 1.85324696 -1.5263248  -1.10965326 ... -0.68416687  0.62452111
   0.41758377]
 [-0.39519131  1.27534548  1.43664396 ...  0.53372975  1.04192858
   1.53536691]
 [-0.74728827  1.86881725 -0.36648004 ...  0.86411513  0.1111842
  -0.64306159]
 ...
 [-0.02913682  0.95022121 -0.57682445 ...  0.18844989  1.19565616
   0.12395518]
 [ 0.67352812  0.34896528  0.01243033 ...  0.14851419  0.17342828
   0.16054904]
 [ 0.63099599  1.57453939  1.35108434 ...  0.21084162  0.76638229
   0.24330904]]
[[ 1.85322843 -1.52

[[ 1.84891542 -1.52275736 -1.10705969 ... -0.68256778  0.62306143
   0.41660776]
 [-0.39426764  1.27236465  1.43328613 ...  0.53248228  1.03949331
   1.53177834]
 [-0.74554166  1.86444932 -0.36562347 ...  0.86209546  0.11092433
  -0.64155858]
 ...
 [-0.02906872  0.94800028 -0.57547625 ...  0.18800943  1.19286158
   0.12366546]
 [ 0.6719539   0.34814965  0.01240128 ...  0.14816707  0.17302293
   0.16017379]
 [ 0.62952118  1.57085927  1.34792649 ...  0.21034883  0.76459105
   0.24274036]]
[[ 1.84889693 -1.52274213 -1.10704862 ... -0.68256095  0.6230552
   0.4166036 ]
 [-0.3942637   1.27235193  1.4332718  ...  0.53247695  1.03948292
   1.53176302]
 [-0.7455342   1.86443067 -0.36561982 ...  0.86208684  0.11092322
  -0.64155216]
 ...
 [-0.02906843  0.9479908  -0.5754705  ...  0.18800755  1.19284966
   0.12366423]
 [ 0.67194718  0.34814617  0.01240115 ...  0.14816559  0.1730212
   0.16017219]
 [ 0.62951488  1.57084356  1.34791301 ...  0.21034672  0.7645834
   0.24273793]]
[[ 1.84887844 -1.52

[[ 1.84455712 -1.51916788 -1.10445011 ... -0.68095882  0.62159274
   0.41562573]
 [-0.39333826  1.26936541  1.42990756 ...  0.5312271   1.037043
   1.5281676 ]
 [-0.74378425  1.8600544  -0.36476162 ...  0.86006331  0.11066286
  -0.64004628]
 ...
 [-0.0290002   0.94576564 -0.57411973 ...  0.18756625  1.19004975
   0.12337396]
 [ 0.67036996  0.34732899  0.01237205 ...  0.14781781  0.17261507
   0.15979622]
 [ 0.62803725  1.5671564   1.34474913 ...  0.20985299  0.76278874
   0.24216817]]
[[ 1.84453867 -1.51915269 -1.10443906 ... -0.68095201  0.62158652
   0.41562157]
 [-0.39333433  1.26935271  1.42989326 ...  0.53122179  1.03703263
   1.52815232]
 [-0.74377681  1.8600358  -0.36475797 ...  0.86005471  0.11066175
  -0.64003988]
 ...
 [-0.02899991  0.94575618 -0.57411399 ...  0.18756437  1.19003785
   0.12337272]
 [ 0.67036326  0.34732552  0.01237192 ...  0.14781633  0.17261335
   0.15979463]
 [ 0.62803097  1.56714073  1.34473568 ...  0.20985089  0.76278111
   0.24216575]]
[[ 1.84452023 -1.5

[[ 1.84035632 -1.51570813 -1.10193483 ... -0.679408    0.62017712
   0.41467918]
 [-0.39244248  1.26647455  1.42665108 ...  0.53001728  1.03468123
   1.52468735]
 [-0.74209036  1.85581831 -0.36393091 ...  0.8581046   0.11041083
  -0.63858864]
 ...
 [-0.02893416  0.94361175 -0.57281223 ...  0.18713909  1.18733953
   0.12309299]
 [ 0.66884326  0.34653798  0.01234387 ...  0.14748117  0.17222196
   0.1594323 ]
 [ 0.62660696  1.56358736  1.34168659 ...  0.20937507  0.76105156
   0.24161666]]
[[ 1.84033792 -1.51569297 -1.10192381 ... -0.67940121  0.62017092
   0.41467503]
 [-0.39243855  1.26646189  1.42663682 ...  0.53001198  1.03467089
   1.5246721 ]
 [-0.74208294  1.85579975 -0.36392727 ...  0.85809602  0.11040973
  -0.63858225]
 ...
 [-0.02893387  0.94360231 -0.5728065  ...  0.18713722  1.18732765
   0.12309176]
 [ 0.66883657  0.34653452  0.01234375 ...  0.14747969  0.17222024
   0.15943071]
 [ 0.6266007   1.56357172  1.34167318 ...  0.20937298  0.76104395
   0.24161424]]
[[ 1.84031951 -1

[[ 1.83706504 -1.51299745 -1.09996414 ... -0.67819295  0.619068
   0.41393757]
 [-0.39174063  1.2642096   1.42409967 ...  0.5290694   1.03283082
   1.52196061]
 [-0.74076321  1.85249938 -0.36328006 ...  0.85656997  0.11021338
  -0.63744659]
 ...
 [-0.02888241  0.9419242  -0.57178781 ...  0.18680441  1.1852161
   0.12287285]
 [ 0.6676471   0.34591824  0.01232179 ...  0.14721741  0.17191396
   0.15914718]
 [ 0.62548634  1.56079105  1.33928713 ...  0.20900062  0.7596905
   0.24118455]]
[[ 1.83704667 -1.51298232 -1.09995314 ... -0.67818617  0.61906181
   0.41393343]
 [-0.39173672  1.26419696  1.42408543 ...  0.52906411  1.03282049
   1.52194539]
 [-0.7407558   1.85248085 -0.36327642 ...  0.85656141  0.11021227
  -0.63744022]
 ...
 [-0.02888212  0.94191478 -0.57178209 ...  0.18680254  1.18520424
   0.12287162]
 [ 0.66764043  0.34591478  0.01232167 ...  0.14721594  0.17191224
   0.15914559]
 [ 0.62548009  1.56077544  1.33927374 ...  0.20899853  0.75968291
   0.24118214]]
[[ 1.8370283  -1.512

[[ 1.83341293 -1.50998959 -1.0977774  ... -0.6768447   0.61783729
   0.41311466]
 [-0.39096185  1.26169634  1.42126854 ...  0.52801761  1.03077754
   1.51893493]
 [-0.73929056  1.84881658 -0.36255785 ...  0.8548671   0.10999427
  -0.63617934]
 ...
 [-0.02882499  0.94005164 -0.57065109 ...  0.18643304  1.18285987
   0.12262857]
 [ 0.66631981  0.34523055  0.0122973  ...  0.14692474  0.17157219
   0.15883079]
 [ 0.62424287  1.55768818  1.33662461 ...  0.20858513  0.75818023
   0.24070507]]
[[ 1.8333946  -1.50997449 -1.09776642 ... -0.67683793  0.61783111
   0.41311053]
 [-0.39095794  1.26168372  1.42125433 ...  0.52801233  1.03076723
   1.51891974]
 [-0.73928317  1.8487981  -0.36255423 ...  0.85485855  0.10999317
  -0.63617298]
 ...
 [-0.0288247   0.94004224 -0.57064538 ...  0.18643117  1.18284804
   0.12262735]
 [ 0.66631315  0.34522709  0.01229717 ...  0.14692327  0.17157048
   0.1588292 ]
 [ 0.62423663  1.5576726   1.33661125 ...  0.20858304  0.75817265
   0.24070266]]
[[ 1.83337627 -1

[[ 1.83077474 -1.50781678 -1.09619775 ... -0.67587075  0.61694825
   0.4125202 ]
 [-0.39039927  1.25988081  1.4192234  ...  0.52725781  1.0292943
   1.51674925]
 [-0.73822675  1.84615622 -0.36203615 ...  0.85363699  0.10983599
  -0.63526391]
 ...
 [-0.02878351  0.93869895 -0.56982995 ...  0.18616477  1.18115779
   0.12245212]
 [ 0.66536101  0.34473377  0.0122796  ...  0.14671333  0.17132531
   0.15860224]
 [ 0.62334461  1.55544674  1.33470127 ...  0.20828498  0.75708924
   0.24035871]]
[[ 1.83075643 -1.5078017  -1.09618678 ... -0.67586399  0.61694208
   0.41251608]
 [-0.39039537  1.25986821  1.41920921 ...  0.52725254  1.029284
   1.51673409]
 [-0.73821937  1.84613776 -0.36203253 ...  0.85362845  0.10983489
  -0.63525756]
 ...
 [-0.02878323  0.93868957 -0.56982425 ...  0.18616291  1.18114598
   0.12245089]
 [ 0.66535436  0.34473033  0.01227948 ...  0.14671186  0.17132359
   0.15860065]
 [ 0.62333838  1.55543118  1.33468792 ...  0.2082829   0.75708167
   0.2403563 ]]
[[ 1.83073812 -1.50

[[ 1.82737269 -1.50501487 -1.09416073 ... -0.67461481  0.6158018
   0.41175364]
 [-0.38967381  1.25753963  1.41658612 ...  0.52627803  1.0273816
   1.51393075]
 [-0.73685494  1.84272559 -0.36136339 ...  0.85205071  0.10963189
  -0.63408343]
 ...
 [-0.02873003  0.93695461 -0.56877106 ...  0.18581883  1.1789629
   0.12222457]
 [ 0.6641246   0.34409317  0.01225678 ...  0.14644069  0.17100694
   0.15830752]
 [ 0.62218628  1.55255632  1.33222105 ...  0.20789794  0.75568238
   0.23991206]]
[[ 1.82735441 -1.50499982 -1.09414979 ... -0.67460806  0.61579564
   0.41174952]
 [-0.38966991  1.25752706  1.41657196 ...  0.52627277  1.02737133
   1.51391561]
 [-0.73684757  1.84270716 -0.36135978 ...  0.85204219  0.10963079
  -0.63407708]
 ...
 [-0.02872974  0.93694524 -0.56876537 ...  0.18581697  1.17895111
   0.12222335]
 [ 0.66411796  0.34408973  0.01225666 ...  0.14643923  0.17100523
   0.15830593]
 [ 0.62218006  1.55254079  1.33220773 ...  0.20789586  0.75567482
   0.23990966]]
[[ 1.82733614 -1.50

[[ 1.82333869 -1.5016925  -1.09174533 ... -0.67312557  0.61444239
   0.41084467]
 [-0.38881359  1.25476357  1.41345896 ...  0.52511626  1.02511362
   1.51058869]
 [-0.73522831  1.83865771 -0.36056567 ...  0.85016978  0.10938987
  -0.63268366]
 ...
 [-0.0286666   0.93488625 -0.56751548 ...  0.18540863  1.17636029
   0.12195476]
 [ 0.66265852  0.34333357  0.01222973 ...  0.14611742  0.17062944
   0.15795805]
 [ 0.62081278  1.54912899  1.32928012 ...  0.20743899  0.75401418
   0.23938245]]
[[ 1.82332046 -1.50167748 -1.09173441 ... -0.67311884  0.61443625
   0.41084057]
 [-0.3888097   1.25475102  1.41344482 ...  0.52511101  1.02510337
   1.51057358]
 [-0.73522095  1.83863932 -0.36056206 ...  0.85016127  0.10938878
  -0.63267734]
 ...
 [-0.02866632  0.9348769  -0.5675098  ...  0.18540677  1.17634853
   0.12195354]
 [ 0.66265189  0.34333014  0.0122296  ...  0.14611596  0.17062773
   0.15795647]
 [ 0.62080657  1.5491135   1.32926683 ...  0.20743692  0.75400664
   0.23938005]]
[[ 1.82330223 -1

KeyboardInterrupt: 

In [None]:
mine.score(X_test, y_test)

In [13]:
mine.get_parameters()

{'W': array([[ 4.10051652e-02,  5.58986254e-05, -5.78061388e-02, ...,
         -5.14885801e-02,  7.23898075e-02, -6.85019548e-02],
        [-1.20400609e-02,  7.16937556e-02,  3.40646263e-02, ...,
          4.59066264e-02, -3.70312405e-02,  3.50930063e-02],
        [ 1.24550732e-01, -1.36888130e-01, -2.09572012e-02, ...,
          5.38206552e-02, -1.15575827e-01,  4.71531725e-02],
        ...,
        [ 1.49150833e-02,  1.04881523e-01, -7.19492174e-02, ...,
          7.18512381e-03,  5.35393174e-02, -7.67888475e-02],
        [ 1.07968480e-01,  3.41621839e-02,  8.97009075e-02, ...,
          9.61664080e-02, -7.02416061e-02, -1.03713799e-02],
        [-2.84248512e-02,  2.60784422e-02, -1.63517960e-01, ...,
          9.20050052e-02,  1.49520833e-01,  5.87362281e-02]]),
 'b': array([[ -2.9077112 ,  -2.83231557,  -5.21436629, -10.71951527,
          -2.56542998,   0.45731117,  -4.59340067,  -2.53443188,
         -32.89286865, -12.52483581]])}

In [19]:
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import make_scorer, accuracy_score, f1_score

In [20]:
mine=MySVM()
mine_param={"C":[0.01, 0.03, 0.1, 0.3, 1, 3],
            "eta":[0.001, 0.01, 0.1, 1, 10],
            "batch_size":[1, 10, 100, 200, 300, 400, 500],
            "num_iter":[10, 15, 20, 25, 30],
            "epsilon":[0.001, 0.01, 0.1]
          }

"""scoring = {'f1 macro': make_scorer(f1_score , average='macro'),
           'f1 micro': make_scorer(f1_score, average = 'micro'),
           'Accuracy': make_scorer(accuracy_score)
          }"""

mine_grid = GridSearchCV(mine, mine_param, cv=5)