In [152]:
import numpy as np
from sklearn.datasets import fetch_mldata
from random import shuffle
from random import seed

In [91]:
class KNN:
    """Brute-force k-nearest-neighbours classifier with majority voting.

    The full training set is kept in memory. ``predict`` computes the distance
    from every training sample to every test sample, picks the ``k`` closest
    training samples for each test sample and returns the most frequent label
    among them (ties go to the smallest label, because labels are kept sorted
    and ``np.argmax`` returns the first maximum).
    """

    def __init__(self, train_data, train_target, k, metric_selection):
        """Store the training set and select the distance function.

        Parameters
        ----------
        train_data : 2-D array, one row per training sample.
        train_target : 1-D array with one label per row of ``train_data``.
        k : number of neighbours that vote; must be at least 1.
        metric_selection : ``"euclidean"`` or ``"cosine"``.

        Raises
        ------
        TypeError
            If the number of samples and labels differ, or the metric name
            is not recognised.
        ValueError
            If ``k`` is smaller than 1.
        """
        if train_data.shape[0] != train_target.shape[0]:
            raise TypeError(
                "train_data and train_target must contain the same number of samples")
        if k < 1:
            raise ValueError("k must be at least 1")

        self.data = train_data
        self.target = train_target
        self.k = k
        # np.unique already returns the distinct labels in sorted order.
        self.clusters = np.unique(train_target)
        self.clusters_amount = len(self.clusters)

        if metric_selection == "euclidean":
            self.metric = self.euclidean_distance
        elif metric_selection == "cosine":
            self.metric = self.cosine_distance
        else:
            raise TypeError("metric_selection must be 'euclidean' or 'cosine'")

    def euclidean_distance(self, arr1, arr2):
        """Return the matrix of Euclidean distances between rows.

        The result has shape ``(arr1.shape[0], arr2.shape[0])``; entry
        ``[i, j]`` is the distance between ``arr1[i]`` and ``arr2[j]``.

        Raises
        ------
        TypeError
            If the two arrays do not have the same number of columns.
        """
        if arr1.shape[1] != arr2.shape[1]:
            raise TypeError("arr1 and arr2 must have the same number of columns")

        res_arr = np.empty((arr1.shape[0], arr2.shape[0]))
        # One row of arr1 at a time, so the temporary difference array never
        # grows beyond the size of arr2.
        for row_idx, row in enumerate(arr1):
            res_arr[row_idx] = np.sqrt(np.sum((arr2 - row) ** 2, axis=1))

        return res_arr

    def cosine_distance(self, arr1, arr2):
        """Return the matrix of cosine distances (1 - cosine similarity).

        The result has shape ``(arr1.shape[0], arr2.shape[0])``. A row whose
        norm is zero yields NaN entries, because the similarity is computed
        by dividing by the product of the norms.

        Raises
        ------
        TypeError
            If the two arrays do not have the same number of columns.
        """
        if arr1.shape[1] != arr2.shape[1]:
            raise TypeError("arr1 and arr2 must have the same number of columns")

        # The norms of arr2 do not depend on the current row of arr1.
        arr2_norms = np.sqrt(np.sum(arr2 * arr2, axis=1))

        res_arr = np.empty((arr1.shape[0], arr2.shape[0]))
        for row_idx, row in enumerate(arr1):
            row_norm = np.sqrt(np.sum(row * row))
            res_arr[row_idx] = np.dot(arr2, row) / (arr2_norms * row_norm)

        return 1 - res_arr

    def predict(self, test_data):
        """Predict a label for every row of ``test_data``.

        Parameters
        ----------
        test_data : 2-D array with the same number of columns as the
            training data.

        Returns
        -------
        1-D array with one predicted label per test row, taken from the
        labels seen at construction time.

        Raises
        ------
        TypeError
            If ``test_data`` has a different number of columns than the
            training data.
        """
        if test_data.shape[1] != self.data.shape[1]:
            raise TypeError(
                "test_data must have the same number of columns as the training data")

        # distances[i, j] is the distance from training row i to test row j.
        distances = self.metric(self.data, test_data)
        n_train, n_test = distances.shape

        # Never ask for more neighbours than there are training samples.
        k = min(self.k, n_train)

        test_columns = np.arange(n_test)
        closest = np.empty((k, n_test), dtype=int)
        for rank in range(k):
            # For every test sample (column), the index of the nearest
            # training sample that has not been selected yet.
            nearest = np.argmin(distances, axis=0)
            closest[rank] = nearest
            # Exclude the selected pairs from later passes. Infinity (rather
            # than adding the largest distance) guarantees a row is never
            # picked twice, even when all distances are equal or zero.
            distances[nearest, test_columns] = np.inf

        # Translate neighbour labels into positions inside self.clusters;
        # this is exact because self.clusters is sorted and contains every label.
        neighbour_labels = np.asarray(self.target)[closest]
        label_positions = np.searchsorted(self.clusters, neighbour_labels)

        test_target = np.empty(n_test, dtype=self.clusters.dtype)
        for col in range(n_test):
            votes = np.bincount(label_positions[:, col],
                                minlength=self.clusters_amount)
            # Index self.clusters once: argmax is already a position in it.
            test_target[col] = self.clusters[np.argmax(votes)]

        return test_target
        
        
    

In [162]:
class Cross_validation:
    """K-fold splitter that returns index lists for train/test folds.

    Despite its name, ``stratified`` does not balance classes: when it is
    true the indices are shuffled before being cut into folds, otherwise the
    folds are consecutive index ranges.
    """

    def __init__(self, folds_amount = 5, random_seed = 0, stratified = False):
        """Configure the splitter.

        Parameters
        ----------
        folds_amount : number of folds; must be at least 1.
        random_seed : seed for the shuffle. ``0`` means "no fixed seed", in
            which case the module-level ``random.shuffle`` is used.
        stratified : when true, shuffle the indices before splitting.

        Raises
        ------
        ValueError
            If ``folds_amount`` is smaller than 1.
        """
        if folds_amount < 1:
            raise ValueError("folds_amount must be at least 1")
        self.folds = folds_amount
        self.stratified = stratified
        self.seed = random_seed

    def split(self, train_set):
        """Split ``range(len(train_set))`` into ``self.folds`` folds.

        Every fold but the last holds ``len(train_set) // self.folds``
        indices; the last fold also takes the remainder.

        Returns
        -------
        (train_list, test_list) : two lists of index lists. ``test_list[i]``
        is fold ``i`` and ``train_list[i]`` holds all remaining indices in
        their (possibly shuffled) order.
        """
        from random import Random

        total = len(train_set)
        index_list = list(range(total))

        if self.stratified:
            if self.seed != 0:
                # A private generator seeded on every call yields the same
                # permutation each time, regardless of what else has used
                # the global random state, and does not modify that state.
                Random(self.seed).shuffle(index_list)
            else:
                shuffle(index_list)

        each_len = total // self.folds

        train_list = []
        test_list = []
        for part in range(self.folds):
            start = part * each_len
            # The last fold absorbs the remainder of the division.
            stop = total if part == self.folds - 1 else start + each_len
            test_list.append(index_list[start:stop])
            # Indices are unique, so everything outside [start, stop) is
            # exactly the complement of the test fold, in original order.
            train_list.append(index_list[:start] + index_list[stop:])

        return train_list, test_list
                

In [92]:
# Download the MNIST digits data set.
# NOTE(review): fetch_mldata was deprecated in scikit-learn 0.20 and removed in
# 0.22, so this cell only runs on old releases. TODO: migrate to fetch_openml
# and confirm the returned data/target types still suit the cells below.
mnist = fetch_mldata('MNIST original')



In [93]:
from sklearn.model_selection import train_test_split

In [94]:
# Scale the data by 1/255, cast the labels to a platform integer and hold out
# 1/200 of the samples as the test set.
# np.intp is the type the old "int0" alias referred to; the alias itself is
# deprecated in recent NumPy releases.
# NOTE(review): no random_state is passed, so the split differs between runs.
trX, teX, trY, teY = train_test_split(mnist.data / 255.0, mnist.target.astype(np.intp), test_size = 1/200)

In [95]:
trX.shape, trY.shape, teX.shape, teY.shape

((69650, 784), (69650,), (350, 784), (350,))

In [96]:
model = KNN(trX, trY, 5, "euclidean")

[0 1 2 3 4 5 6 7 8 9]


In [97]:
res = model.predict(teX)

In [98]:
res

array([7, 6, 8, 4, 9, 1, 4, 8, 0, 1, 5, 1, 8, 6, 0, 2, 0, 1, 2, 7, 2, 5,
       0, 5, 5, 0, 0, 2, 6, 4, 8, 5, 0, 0, 7, 5, 6, 9, 7, 3, 9, 0, 1, 5,
       9, 3, 5, 9, 5, 9, 6, 6, 7, 4, 3, 9, 5, 4, 2, 2, 5, 0, 9, 8, 2, 1,
       1, 1, 9, 1, 2, 7, 8, 2, 3, 4, 6, 7, 4, 0, 2, 9, 8, 1, 2, 1, 7, 4,
       1, 2, 8, 8, 3, 9, 0, 4, 6, 0, 8, 0, 6, 8, 9, 5, 7, 4, 8, 1, 8, 5,
       9, 2, 1, 1, 7, 4, 2, 3, 6, 4, 7, 4, 9, 7, 0, 3, 2, 0, 5, 3, 5, 0,
       0, 5, 8, 2, 3, 1, 0, 9, 2, 2, 3, 7, 4, 3, 5, 4, 9, 3, 1, 5, 3, 6,
       0, 4, 6, 4, 1, 4, 8, 9, 7, 0, 0, 4, 3, 1, 6, 3, 7, 2, 2, 5, 6, 1,
       7, 3, 5, 7, 8, 6, 1, 3, 0, 4, 0, 8, 1, 5, 1, 7, 9, 1, 1, 7, 0, 5,
       5, 4, 8, 8, 3, 4, 6, 2, 5, 9, 9, 4, 2, 4, 6, 2, 7, 7, 7, 3, 4, 2,
       2, 1, 4, 0, 6, 9, 6, 4, 9, 3, 2, 4, 2, 6, 1, 6, 2, 8, 3, 8, 1, 4,
       0, 7, 5, 3, 6, 9, 6, 4, 9, 7, 7, 3, 7, 1, 0, 4, 4, 0, 8, 6, 7, 2,
       3, 4, 7, 7, 4, 5, 7, 3, 5, 3, 6, 5, 0, 8, 1, 1, 0, 9, 3, 4, 7, 4,
       0, 9, 8, 9, 5, 7, 5, 2, 0, 3, 8, 1, 7, 8, 7,

In [99]:
from sklearn.metrics import accuracy_score as ac_s

In [100]:
ac_s(teY, res)

0.9742857142857143

In [102]:
# Number of training labels.
len(trY)

69650

In [103]:
test_list = [0, 1, 2, 3, 4]

array([8, 7, 8, 5, 6])

In [106]:
test_list[2:]

[2, 3, 4]

In [112]:
# Dummy "data set" of 110 consecutive integers used to try out the splitter.
test_list = list(range(110))

In [163]:
# 4-fold splitter. In this class `stratified=True` means the indices are
# shuffled before splitting (it does not balance classes); random_seed=317
# fixes the seed used for that shuffle.
KFolds = Cross_validation(folds_amount = 4, stratified = True, random_seed = 317)

In [164]:
train, test = KFolds.split(test_list)

In [158]:
print(train)
print(test)

[[98, 45, 19, 66, 105, 68, 84, 65, 88, 92, 85, 73, 83, 27, 78, 53, 28, 55, 79, 41, 7, 50, 40, 30, 106, 18, 72, 1, 22, 16, 74, 13, 5, 109, 48, 76, 80, 89, 12, 101, 32, 58, 69, 99, 20, 4, 10, 100, 77, 37, 0, 104, 103, 24, 38, 11, 47, 54, 64, 8, 15, 81, 44, 62, 52, 93, 33, 42, 36, 82, 56, 60, 107, 57, 71, 31, 25, 108, 49, 96, 90, 2, 59], [43, 9, 87, 102, 46, 34, 23, 21, 63, 14, 95, 86, 6, 91, 94, 26, 35, 39, 3, 67, 17, 61, 97, 51, 75, 29, 70, 1, 22, 16, 74, 13, 5, 109, 48, 76, 80, 89, 12, 101, 32, 58, 69, 99, 20, 4, 10, 100, 77, 37, 0, 104, 103, 24, 38, 11, 47, 54, 64, 8, 15, 81, 44, 62, 52, 93, 33, 42, 36, 82, 56, 60, 107, 57, 71, 31, 25, 108, 49, 96, 90, 2, 59], [43, 9, 87, 102, 46, 34, 23, 21, 63, 14, 95, 86, 6, 91, 94, 26, 35, 39, 3, 67, 17, 61, 97, 51, 75, 29, 70, 98, 45, 19, 66, 105, 68, 84, 65, 88, 92, 85, 73, 83, 27, 78, 53, 28, 55, 79, 41, 7, 50, 40, 30, 106, 18, 72, 38, 11, 47, 54, 64, 8, 15, 81, 44, 62, 52, 93, 33, 42, 36, 82, 56, 60, 107, 57, 71, 31, 25, 108, 49, 96, 90, 2, 59