In [2]:
import numpy as np
import time
import tqdm



def load_data(file_path):
    """
    加载Mnist数据，返回数据集和标签
    """
    data = []
    label = []
    with open(file_path, 'r') as f:
        for line in f:
            cur_line = line.strip().split(',')
            if int(cur_line[0]) >= 5:
                label.append(-1)
            else:
                label.append(1)
            data.append([int(num) / 255 for num in cur_line[1:]])
    return data, label

def train(data, label, epoch=50, lr = 0.0001):
    """
    感知器训练过程
    data: 训练集的数据
    label: 训练集的标签
    epoch: 迭代次数，默认50
    """
    print('Start to train:')
    start_time = time.time()

    dimension = len(data[0])
    w = np.random.randn(dimension)
    b = 0
    
    for _ in tqdm.tqdm(range(epoch)):
        for i in range(len(data)):
            x = np.array(data[i])
            y = label[i]
            if y * (np.dot(w, x) + b) <= 0:
                w += lr * y * x
                b += lr * y
    end_time = time.time()
    print('Training cost %.2f seconds' % (end_time - start_time))
    return w, b

def test(data, label, w, b):
    print('Start to test:')
    start_time = time.time()
    error_count = 0
    for i in range(len(data)):
        x = data[i]
        y = label[i]
        if y * (np.dot(w, x) + b) <= 0:
            error_count += 1
    error_rate = error_count / len(data)
    end_time = time.time()
    print('Testing cost %.2f seconds' % (end_time - start_time))
    return 1 - error_rate

if __name__ == '__main__':
    train_data, train_label = load_data('../mnist_train.csv')
    test_data, test_label = load_data('../mnist_test.csv')
    w, b = train(train_data, train_label)
    acc = test(test_data, test_label, w, b)
    print('The accuracy is %.2f' % acc)



Start to train:


100%|██████████| 50/50 [01:23<00:00,  1.66s/it]


Training cost 83.16 seconds
Start to test:
Testing cost 0.26 seconds
The accuracy is 0.80
