# 最近邻分类器

In [1]:
import numpy as np

class NearstNeighbor:
    """Nearest-neighbor classifier using the L1 (Manhattan) distance.

    ``train`` simply memorizes the training set; ``predict`` labels each
    query row by a majority vote among its ``k`` nearest training rows
    (``k=1`` by default, i.e. plain nearest neighbor).
    """

    def __init__(self):
        pass

    def train(self, X, y):
        """Memorize the training data.

        Args:
            X: array-like of shape (N, D); each row is one training example.
            y: array-like of shape (N,); ``y[j]`` is the label of ``X[j]``.
        """
        # A nearest-neighbor classifier has no fitting step: it just remembers
        # all the training data. np.asarray does not copy existing ndarrays.
        self.Xtr = np.asarray(X)
        self.ytr = np.asarray(y)

    def predict(self, X, k=1):
        """Predict a label for every row of X.

        Args:
            X: array-like of shape (M, D); each row is an example to label.
            k: number of nearest training examples that vote (default 1).
                If k exceeds the number of training rows, all rows vote.

        Returns:
            ndarray of shape (M,) with the same dtype as the training labels.

        Raises:
            ValueError: if k < 1.
        """
        if k < 1:
            raise ValueError("k must be a positive integer, got %r" % (k,))
        X = np.asarray(X)
        num_test = X.shape[0]
        # Make sure the output type matches the training-label type.
        Ypred = np.zeros(num_test, dtype=self.ytr.dtype)

        # Loop over all test rows.
        for i in range(num_test):
            distances = self._l1_distances(X[i])
            if k == 1:
                # argmin returns the first (lowest-index) minimum on ties.
                Ypred[i] = self.ytr[np.argmin(distances)]
            else:
                # A stable sort keeps index order among equal distances.
                nearest = np.argsort(distances, kind="stable")[:k]
                labels, counts = np.unique(self.ytr[nearest], return_counts=True)
                # np.unique returns sorted labels, so argmax breaks vote ties
                # in favor of the smallest label.
                Ypred[i] = labels[np.argmax(counts)]

        return Ypred

    def _l1_distances(self, x):
        """Return the L1 distance from the query row ``x`` to every training row."""
        # Use max(a, b) - min(a, b) rather than abs(a - b): with unsigned-integer
        # data (e.g. raw uint8 pixels) a - b wraps around whenever b > a, which
        # silently produces wrong distances. For signed/float data it is equal.
        diff = np.maximum(self.Xtr, x)
        diff -= np.minimum(self.Xtr, x)
        return np.sum(diff, axis=1)

# K-最近邻分类器：超参数的设置

In [None]:
# Assume Xtr_rows, Ytr, Xte_rows, Yte already exist from an earlier cell;
# recall that Xtr_rows is a 50,000 x 3072 matrix.
# NOTE: this cell rebinds Xtr_rows / Ytr to the shrunken training split, so
# re-running it would carve yet another 1,000 rows off the training data.
Xval_rows = Xtr_rows[:1000, :]  # take the first 1,000 rows as a hold-out validation set
Yval = Ytr[:1000]
Xtr_rows = Xtr_rows[1000:, :]  # keep the last 49,000 rows for training
Ytr = Ytr[1000:]

# The training data does not depend on k, so build and train the classifier
# once instead of once per candidate k.
# Here we assume a modified NearestNeighbor class whose predict accepts k.
nn = NearstNeighbor()
nn.train(Xtr_rows, Ytr)

# Find the hyperparameter (k) that works best on the validation set.
validation_accuracies = []
for k in [1, 3, 5, 10, 20, 50, 100]:
    # Evaluate this particular value of k on the validation data.
    Yval_predict = nn.predict(Xval_rows, k=k)
    acc = np.mean(Yval_predict == Yval)
    print('accuracy:%f' % acc)

    # Keep track of what works on the validation set.
    validation_accuracies.append((k, acc))