In [None]:
import numpy as np


In [9]:
def distance(p, q):
    """Return the Euclidean distance between two points.

    Parameters
    ----------
    p, q : array-like of numbers
        Coordinates of the two points. They may have any number of
        components as long as their shapes are compatible for subtraction.

    Returns
    -------
    float
        The Euclidean (L2) norm of ``p - q``.
    """
    # Cast to float before subtracting: with unsigned integer inputs the
    # difference would otherwise wrap around (e.g. 0 - 3 == 253 for uint8)
    # and produce a wrong distance.
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)

    return float(np.linalg.norm(p - q))

In [10]:
def train_test_split(x, y, test_size=0.2, random_seed=None):
    """Randomly split features and targets into train and test subsets.

    Parameters
    ----------
    x : array-like
        Feature data; rows (first axis) are samples.
    y : array-like
        Targets, one per row of ``x``.
    test_size : float, default 0.2
        Fraction of the rows, in [0, 1], assigned to the test subset. The
        number of test rows is ``int(n_rows * test_size)`` (rounded down).
    random_seed : int or None, default None
        When not None, the global NumPy random generator is re-seeded with
        this value before shuffling, making the split reproducible.

    Returns
    -------
    x_train, x_test, y_train, y_test : numpy.ndarray

    Raises
    ------
    ValueError
        If ``x`` and ``y`` have a different number of rows, or ``test_size``
        is outside [0, 1].
    """
    x = np.array(x)
    y = np.array(y)

    n_rows = x.shape[0]
    if y.shape[0] != n_rows:
        raise ValueError(
            f"x and y must have the same number of rows, got {n_rows} and {y.shape[0]}"
        )
    if not 0 <= test_size <= 1:
        raise ValueError(f"test_size must be between 0 and 1, got {test_size}")

    # Compare against None explicitly: a plain truthiness test would treat
    # the perfectly valid seed 0 as "no seed" and silently skip seeding.
    if random_seed is not None:
        np.random.seed(random_seed)

    idx = np.random.permutation(n_rows)
    split_idx = int(n_rows * test_size)

    # The first `split_idx` shuffled indices form the test set, the rest the
    # training set; the same indices are used for x and y to keep them aligned.
    test_idx = idx[:split_idx]
    train_idx = idx[split_idx:]

    return x[train_idx], x[test_idx], y[train_idx], y[test_idx]



In [11]:
def standardize_data(x_train, x_test):
    """Standardize train and test features using the training statistics.

    The mean and (population) standard deviation are computed along the first
    axis of ``x_train`` only and then applied to both arrays, so no
    information from the test set is used to compute the scaling.

    Parameters
    ----------
    x_train, x_test : array-like of numbers
        Feature data with samples along the first axis. A 1-D input is
        treated as a single feature.

    Returns
    -------
    x_train_std, x_test_std : numpy.ndarray of float
        ``(x - mean) / stdev`` for each input. Features whose training
        standard deviation is zero are only centered (divided by 1).
    """
    x_train_std = np.array(x_train, dtype=float)
    x_test_std = np.array(x_test, dtype=float)

    mean = np.mean(x_train_std, axis=0)
    stdev = np.std(x_train_std, axis=0)

    # Constant features have zero spread; divide them by 1 instead to avoid
    # a division by zero. np.where also works when `stdev` is a scalar
    # (1-D input), where an index-based loop would fail.
    scale = np.where(stdev == 0, 1.0, stdev)

    x_train_std = (x_train_std - mean) / scale
    x_test_std = (x_test_std - mean) / scale

    return x_train_std, x_test_std

In [12]:
def _as_sample_matrix(data):
    """Return ``data`` as a float array of shape (n_samples, n_features)."""
    arr = np.asarray(data, dtype=float)
    if arr.ndim == 0:
        raise ValueError("expected a sequence of samples, got a scalar")
    # Flatten everything after the first axis so scalar samples become one
    # feature and multi-dimensional samples become a flat feature vector.
    n_features = int(np.prod(arr.shape[1:]))
    return arr.reshape(arr.shape[0], n_features)


def KNN(x_train, x_test, y_train, k = 5, standardize = False):
    """Classify test samples by majority vote among their k nearest neighbours.

    Distances are Euclidean. For each test sample the ``k`` training samples
    with the smallest distance are selected and the most frequent label among
    them is predicted.

    Parameters
    ----------
    x_train : array-like of numbers
        Training features, samples along the first axis.
    x_test : array-like of numbers
        Features of the samples to classify, same layout as ``x_train``.
    y_train : array-like
        One hashable label per training sample.
    k : int, default 5
        Number of neighbours that vote. If it exceeds the number of training
        samples, all training samples vote.
    standardize : bool, default False
        When True, both feature sets are first passed through
        ``standardize_data``.

    Returns
    -------
    numpy.ndarray
        Predicted label for each test sample (empty if ``x_test`` is empty).

    Raises
    ------
    ValueError
        If ``k`` is smaller than 1, the training set is empty, the number of
        labels does not match the number of training samples, or train and
        test samples have a different number of features.

    Notes
    -----
    Training samples at equal distance keep their training order. Labels with
    an equal number of votes are resolved in favour of the label that appears
    first among the neighbours, i.e. the one with the closest neighbour.
    """
    if standardize:
        x_train, x_test = standardize_data(x_train, x_test)

    train = _as_sample_matrix(x_train)
    test = _as_sample_matrix(x_test)
    labels = np.asarray(y_train)

    if test.shape[0] == 0:
        return np.array([])

    if k < 1:
        raise ValueError(f"k must be at least 1, got {k}")
    if train.shape[0] == 0:
        raise ValueError("x_train must contain at least one sample")
    if labels.shape[:1] != train.shape[:1]:
        raise ValueError("y_train must contain exactly one label per row of x_train")
    if train.shape[1] != test.shape[1]:
        raise ValueError(
            f"x_train has {train.shape[1]} features but x_test has {test.shape[1]}"
        )

    y_pred = []
    for point in test:
        # Distance from this test sample to every training sample at once.
        dists = np.linalg.norm(train - point, axis=1)

        # Order by distance only, with a stable sort: labels never take part
        # in the ordering, so they do not need to be comparable and do not
        # influence which of several equidistant neighbours is kept.
        nearest = np.argsort(dists, kind="stable")[:k]

        # Tally the votes in nearest-first order.
        votes = {}
        for label in labels[nearest]:
            votes[label] = votes.get(label, 0) + 1

        # max() returns the first key with the highest count, so a tied vote
        # goes to the label that was seen first (closest neighbour).
        y_pred.append(max(votes, key=votes.get))

    return np.array(y_pred)