# Cross-Validation on k

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

# Fetch the CIFAR-10 train/test splits through the Keras dataset helper.
(X_train, Y_train), (X_test, Y_test) = tf.keras.datasets.cifar10.load_data()

# Print the shape of every array so the layout is visible before any
# subsampling or reshaping happens further down.
for caption, array in (
    ('Training data shape: ', X_train),
    ('Training labels shape: ', Y_train),
    ('Test data shape: ', X_test),
    ('Test labels shape: ', Y_test),
):
    print(caption, array.shape)


## Subsampling

In [None]:
# Sizes of the subsets kept for the experiment.
num_training = 5000
num_test = 500

# Keep only the leading `num_training` training examples (data and labels
# are indexed with the same positions so they stay aligned).
mask = np.arange(num_training)
X_train, Y_train = X_train[mask], Y_train[mask]

# Same treatment for the leading `num_test` test examples.
mask = np.arange(num_test)
X_test, Y_test = X_test[mask], Y_test[mask]

## Reshape Image Data

In [None]:
# Collapse every example into a single row vector: the first axis (one
# entry per example) is kept and all remaining axes are merged into one.
X_train = X_train.reshape(X_train.shape[0], -1)
X_test = X_test.reshape(X_test.shape[0], -1)
print(X_train.shape, X_test.shape)

In [None]:
from KNN import KNearestNeighbor

# k-fold cross-validation over a set of candidate values for k.
# For every k, the classifier is trained num_folds times, each time holding
# out one fold for validation and training on the remaining folds.
num_folds = 5
k_choices = [1, 3, 5, 8, 10, 12, 15, 20, 50, 100]

# Split the training data into num_folds (nearly) equal consecutive folds.
# The labels are flattened to 1-D first: the Keras CIFAR-10 loader documents
# its label arrays as shape (N, 1), and comparing an (n, 1) array with 1-D
# predictions would broadcast to an (n, n) matrix and corrupt the count of
# correct predictions.
X_train_folds = np.array_split(X_train, num_folds)
y_train_folds = np.array_split(np.ravel(Y_train), num_folds)

# The training portion of each split does not depend on k, so assemble it
# once here instead of rebuilding it inside the loop over k.
X_fit_folds = []
y_fit_folds = []
for fold in range(num_folds):
    X_fit_folds.append(
        np.concatenate(X_train_folds[:fold] + X_train_folds[fold + 1:]))
    y_fit_folds.append(
        np.concatenate(y_train_folds[:fold] + y_train_folds[fold + 1:]))

# Maps each k to the list of num_folds validation accuracies obtained for it.
k_to_accuracies = {}

classifier = KNearestNeighbor()
for k in k_choices:
    accuracies = []
    for fold in range(num_folds):
        X_val_fold = X_train_folds[fold]
        y_val_fold = y_train_folds[fold]
        classifier.train(X_fit_folds[fold], y_fit_folds[fold])
        y_val_pred = classifier.predict(X_val_fold, k=k)
        # Flatten the predictions as well so the comparison is element-wise
        # whatever shape the classifier returns them in.
        num_correct = np.sum(y_val_fold == np.ravel(y_val_pred))
        accuracies.append(num_correct / y_val_fold.shape[0])
    k_to_accuracies[k] = accuracies

# Report every per-fold accuracy, grouped by k in ascending order.
for k in sorted(k_to_accuracies):
    for accuracy in k_to_accuracies[k]:
        print('k = %d, accuracy = %f' % (k, accuracy))
        



## Plotting

In [None]:
# Plot the cross-validation results: one scatter point per fold for each k,
# plus a line through the per-k mean with the standard deviation as error bar.
# The x-values and the statistics are both derived from the same sorted key
# list, so they stay aligned even if the candidate k values were supplied
# unsorted or with duplicates.
sorted_ks = sorted(k_to_accuracies)

for k in sorted_ks:
    accuracies = k_to_accuracies[k]
    plt.scatter([k] * len(accuracies), accuracies)

accuracies_mean = np.array([np.mean(k_to_accuracies[k]) for k in sorted_ks])
accuracies_std = np.array([np.std(k_to_accuracies[k]) for k in sorted_ks])
plt.errorbar(sorted_ks, accuracies_mean, yerr=accuracies_std)
plt.title('Cross-validation on k')
plt.xlabel('k')
plt.ylabel('Cross-validation accuracy')
plt.show()